背景
本文实现一个类似于nginx或gateway的反向代理网关,实现思路是访客通过网络请求反向代理服务,代理服务连接到真实服务,维护访客和真实服务的数据交互。
这个实现和之前的内网穿透项目思路相似,只不过内网穿透是由客户端主动跟代理服务维护连接,这个是代理服务主动和真实服务连接。
本文使用MySQL存储真实服务的配置,目前只支持配置服务的IP、端口和请求路径。
实践
项目结构
1、cc-gateway项目:实现反向代理功能
原理分析
反向代理的实现过程主要分两步
1、启动服务端,代理服务开始监听8888端口(端口由常量Constant.SERVER_PORT指定)。
2、访客访问代理服务的8888端口(例如http://127.0.0.1:8888/sso),代理服务收到请求后解析出请求路径(/sso),再到服务配置中查找匹配的服务标识(code)。服务配置由定时任务每5秒从数据库刷新一次并缓存在内存中(缓存为空时会立即查询一次数据库);匹配规则是取与请求路径匹配的最长前缀,因此code为"/"的配置可以作为兜底。如果匹配到(/sso)对应的真实服务的IP和端口,代理服务就会发起与真实服务的连接,并建立访客和真实服务之间的数据传输通道。
这两步最终形成了(访客-代理-真实服务)完整的通道。
代码实现
数据库设计
建表
-- Routing table of the gateway. RealServerServiceImpl loads the rows WHERE able = 1 ORDER BY sno,
-- groups them by `code` into Constant.ALL_SERVERS, and ReverseProxyHandler routes a request to the
-- longest `code` that is a prefix of the request path (the first row of that group is used).
-- NOTE(review): `code` is the routing key; a NULL value can never be used for routing -- consider NOT NULL.
-- NOTE(review): `ip_type` is varchar here, but RealServerEntity.ipType is an Integer and the sample rows
-- store '1' rather than 'ipv4'/'ipv6' -- confirm the intended encoding.
-- NOTE(review): `status`, `weight` and `gray` are loaded but not used for routing by the visible gateway code.
CREATE TABLE `server` (`id` varchar(32) NOT NULL COMMENT '主键ID',`name` varchar(255) DEFAULT NULL COMMENT '服务名称',`code` varchar(255) DEFAULT NULL COMMENT '服务标识',`ip` varchar(255) DEFAULT NULL COMMENT '服务IP',`port` int(11) DEFAULT NULL COMMENT '服务端口',`ip_type` varchar(255) DEFAULT NULL COMMENT '服务IP类型(ipv4,ipv6)',`weight` double DEFAULT NULL COMMENT '服务权重',`status` int(1) DEFAULT NULL COMMENT '服务状态(运行中、掉线)',`able` int(1) DEFAULT NULL COMMENT '操作状态(启用、禁用)',`gray` varchar(255) DEFAULT NULL COMMENT '灰度信息',`sno` int(11) DEFAULT NULL COMMENT '排序',PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='服务信息表';
插入示例数据
### Change these rows to match the configuration of your own real (backend) services.
# `code` is the path prefix a request must start with to be sent to ip:port; the longest matching
# prefix wins, so the row with code '/' acts as a catch-all for every other path.
# NOTE(review): `ip_type` is stored as '1' although the column comment mentions ipv4/ipv6.
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('1', 'im', '/im', '10.0.0.2', 8889, '1', 1, 1, 1, '', 1);
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('2', '/', '/', '10.0.0.2', 8886, '1', 1, 1, 1, '', 1);
INSERT INTO `server`(`id`, `name`, `code`, `ip`, `port`, `ip_type`, `weight`, `status`, `able`, `gray`, `sno`) VALUES ('3', 'sso', '/sso', '10.0.0.3', 8885, '1', 1, 1, 1, '', 1);
cc-gateway项目
pom文件
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.cc</groupId>
    <artifactId>cc-gateway</artifactId>
    <version>1.0-SNAPSHOT</version>
    <name>cc-gateway</name>
    <url>http://maven.apache.org</url>
    <properties>
        <java.home>${env.JAVA_HOME}</java.home>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <!-- Netty: TCP server (visitor side) and client (backend side) of the proxy -->
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <!-- MySQL JDBC driver -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.38</version>
        </dependency>
        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>5.7.8</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <!-- `mvn package` additionally builds target/cc-gateway-1.0-SNAPSHOT-jar-with-dependencies.jar:
                 the runnable jar (all dependencies bundled, Main-Class com.cc.gw.MainApp). -->
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.0.0</version>
                <configuration>
                    <archive>
                        <manifest>
                            <mainClass>com.cc.gw.MainApp</mainClass>
                        </manifest>
                    </archive>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
工具类
mysql工具,主要查询配置
package com.cc.gw.util;import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Minimal JDBC helper used to read the gateway's service configuration from MySQL.
 *
 * <p>The connection settings default to the values below and can be overridden with the system
 * properties {@code cc.gw.db.url}, {@code cc.gw.db.username} and {@code cc.gw.db.password}
 * (for example {@code java -Dcc.gw.db.password=secret -jar ...}).
 */
public class SQLUtil {
    private static String url = System.getProperty("cc.gw.db.url",
            "jdbc:mysql://127.0.0.1:3306/serverdb?allowMultiQueries=true&useUnicode=true&characterEncoding=utf8&useSSL=false");
    private static String username = System.getProperty("cc.gw.db.username", "root");
    private static String password = System.getProperty("cc.gw.db.password", "123456");

    /**
     * Runs a query and returns every row as a column-name to value map.
     *
     * <p>A new connection is opened for each call. The connection, statement and result set are
     * always closed (try-with-resources), also when connecting or executing the query fails.
     * Failures are not propagated: they are printed, and the rows read before the failure
     * (normally none) are returned.
     *
     * @param sqlStr the SQL to execute; it is run verbatim, so never build it from untrusted input
     * @return the rows in result-set order; never {@code null}
     */
    public static List<Map<String, Object>> query(String sqlStr) {
        List<Map<String, Object>> list = new ArrayList<>();
        try (Connection con = DriverManager.getConnection(url, username, password);
             Statement stmt = con.createStatement();
             ResultSet rs = stmt.executeQuery(sqlStr)) {
            ResultSetMetaData md = rs.getMetaData(); // result-set structure (metadata)
            int columnCount = md.getColumnCount();
            while (rs.next()) {
                Map<String, Object> rowData = new HashMap<>();
                for (int i = 1; i <= columnCount; i++) { // JDBC columns are 1-based
                    rowData.put(md.getColumnName(i), rs.getObject(i));
                }
                list.add(rowData);
            }
        } catch (Exception se) {
            System.out.println("数据库连接失败!");
            se.printStackTrace();
        }
        return list;
    }
}
HTTP协议工具:解析HTTP/1.x请求行和请求头,识别普通HTTP请求与WebSocket(Upgrade)请求
package com.cc.gw.util;import java.nio.charset.StandardCharsets;import io.netty.util.internal.StringUtil;public class HttpProtocolUtil {private static final int MIN_REQUEST_LINE_LENGTH = 14; // 最小的请求行长度,例如 "GET / HTTP/1." /*** 获取请求的URI,根据URI去匹配后端服务列表,如果不是合法请求,返回null* @param bytes 请求头字节数组* @return*/public static String getRequestURI(byte[] bytes) {if (null == bytes || bytes.length < MIN_REQUEST_LINE_LENGTH) {return null;}String requestStr = new String(bytes, StandardCharsets.UTF_8);String[] lines = requestStr.split("\r\n");// 读取第一行 String request = lines[0];String[] split = request.split(" ");if (split.length != 3) {return null;}String method = split[0];if (!("GET".equals(method) || "POST".equals(method))) {return null;}String version = split[2];if (!version.startsWith("HTTP/1.")) {return null;}String uri = split[1];if (uri.startsWith("/") && requestStr.toLowerCase().contains("connection: upgrade")) {return "ws://1.1" + uri;}return uri;}/*** 获取请求的路径,根据路径去匹配后端服务列表* @param path 整个请求path* @return*/public static String getContextPath(String path) {if (StringUtil.isNullOrEmpty(path)) {return null;}int endLength = path.length();if (path.contains("?")) {endLength = path.indexOf("?");}String uri = null;if (path.startsWith("ws")) {uri = path.substring(path.indexOf("/", 7), endLength);} else if (path.startsWith("/")) {uri = path.substring(0, endLength);}if (null != uri) {uri = uri.replaceAll("//+", "/");return uri;}return null;}
}
配置类
服务实体信息
package com.cc.gw.domain;/*** 服务信息*/
/**
 * A backend ("real") service the gateway can route to; one row of the {@code server} table.
 * Plain mutable JavaBean.
 */
public class RealServerEntity {

    /** Service ID. */
    private String id;
    /** Service name. */
    private String name;
    /** Service code (routing key). */
    private String code;
    /** Service IP. */
    private String ip;
    /** Service port. */
    private Integer port;
    /** IP type of the service (ipv4, ipv6). */
    private Integer ipType;
    /** Service weight. */
    private Double weight;
    /** Gray-release information. */
    private String gray;
    /** Runtime status (running, offline). */
    private Integer status;
    /** Operational switch (enabled, disabled). */
    private Integer able;
    /** Sort order. */
    private Integer sno;

    public String getId() { return id; }
    public void setId(String id) { this.id = id; }

    public String getName() { return name; }
    public void setName(String name) { this.name = name; }

    public String getCode() { return code; }
    public void setCode(String code) { this.code = code; }

    public String getIp() { return ip; }
    public void setIp(String ip) { this.ip = ip; }

    public Integer getPort() { return port; }
    public void setPort(Integer port) { this.port = port; }

    public Integer getIpType() { return ipType; }
    public void setIpType(Integer ipType) { this.ipType = ipType; }

    public Double getWeight() { return weight; }
    public void setWeight(Double weight) { this.weight = weight; }

    public String getGray() { return gray; }
    public void setGray(String gray) { this.gray = gray; }

    public Integer getStatus() { return status; }
    public void setStatus(Integer status) { this.status = status; }

    public Integer getAble() { return able; }
    public void setAble(Integer able) { this.able = able; }

    public Integer getSno() { return sno; }
    public void setSno(Integer sno) { this.sno = sno; }

    /** Renders every property; {@code null} values print as {@code null}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("RealServerEntity [id=");
        sb.append(id)
          .append(", name=").append(name)
          .append(", code=").append(code)
          .append(", ip=").append(ip)
          .append(", port=").append(port)
          .append(", ipType=").append(ipType)
          .append(", weight=").append(weight)
          .append(", gray=").append(gray)
          .append(", status=").append(status)
          .append(", able=").append(able)
          .append(", sno=").append(sno)
          .append("]");
        return sb.toString();
    }
}
常量
package com.cc.gw.config;import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;import com.cc.gw.domain.RealServerEntity;import io.netty.channel.Channel;
import io.netty.util.AttributeKey;

/**
 * Process-wide constants and shared state of the gateway.
 */
public class Constant {

    /** TCP port the reverse proxy listens on. */
    public static final int SERVER_PORT = 8888;

    /**
     * Enabled backend services grouped by their {@code code}, which is the request-path prefix
     * used for routing. Refreshed by RealServerServiceImpl#getAllServers().
     */
    public static final Map<String, List<RealServerEntity>> ALL_SERVERS =
            new ConcurrentHashMap<String, List<RealServerEntity>>();

    /** Visitor-channel attribute: the backend channel currently paired with the visitor. */
    public static final AttributeKey<Channel> C = AttributeKey.newInstance("c");

    /** Visitor-channel attribute: protocol of the connection ("ws" once it is an upgraded connection). */
    public static final AttributeKey<String> T = AttributeKey.newInstance("t");
}
服务配置接口
package com.cc.gw.service;

import java.util.List;

import com.cc.gw.domain.RealServerEntity;

/**
 * Source of the backend ("real") services the gateway can route to.
 */
public interface IRealServerService {

    /**
     * Returns all services that are currently available for routing.
     *
     * @return the available services (implementation-defined when there are none)
     */
    List<RealServerEntity> getAllServers();

    /**
     * Looks up one service by its code.
     *
     * @param serverCode the service code (routing key)
     * @return the matching service, or {@code null} if none matches
     */
    RealServerEntity getServer(String serverCode);

    /**
     * Looks up one service by its code, with extra metadata that an implementation may use to
     * choose between candidates (e.g. load balancing or gray release).
     *
     * @param serverCode the service code (routing key)
     * @param meta       additional selection metadata
     * @return the matching service, or {@code null} if none matches
     */
    RealServerEntity getServer(String serverCode, String meta);
}
服务配置接口实现
package com.cc.gw.service.impl;import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;import com.cc.gw.config.Constant;
import com.cc.gw.domain.RealServerEntity;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.util.SQLUtil;import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.TypeReference;
import cn.hutool.json.JSONUtil;public class RealServerServiceImpl implements IRealServerService {@Overridepublic List<RealServerEntity> getAllServers() {String sqlStr = "SELECT s.id, s.`name`, s.`code`, s.ip, s.`port`, s.ip_type, s.weight, s.`status`, s.able, s.gray, s.sno FROM `server` AS s WHERE s.able = 1 ORDER BY s.sno";List<Map<String, Object>> query = SQLUtil.query(sqlStr);if (CollectionUtil.isNotEmpty(query)) {List<RealServerEntity> servers = JSONUtil.toBean(JSONUtil.toJsonStr(query), new TypeReference<List<RealServerEntity>>() {}, true);if (CollectionUtil.isNotEmpty(servers)) {Map<String, List<RealServerEntity>> collect = servers.stream().collect(Collectors.groupingBy(RealServerEntity::getCode));synchronized (Constant.ALL_SERVERS) {Constant.ALL_SERVERS.clear();collect.forEach((k, v) -> {Constant.ALL_SERVERS.put(k, v);});}}return servers;}return null;}@Overridepublic RealServerEntity getServer(String serverCode) {List<RealServerEntity> servers = getAllServers();if (CollectionUtil.isNotEmpty(servers)) {// TODO 负载均衡、灰度等策略return servers.stream().filter(o -> serverCode.equals(o.getCode())).findFirst().orElse(null);}return null;}@Overridepublic RealServerEntity getServer(String serverCode, String meta) {List<RealServerEntity> servers = getAllServers();if (CollectionUtil.isNotEmpty(servers)) {// TODO 负载均衡、灰度等策略return servers.stream().filter(o -> serverCode.equals(o.getCode())).findFirst().orElse(null);}return null;}}
定时任务:启动3秒后开始,每隔5秒查询一次数据库,获取最新的服务配置并刷新内存中的路由表
package com.cc.gw.config;import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;import com.cc.gw.service.IRealServerService;
import com.cc.gw.service.impl.RealServerServiceImpl;

import io.netty.channel.Channel;

/**
 * Background jobs of the gateway.
 */
public class ScheduledTasks {
    private static final ScheduledExecutorService REFRESH_EXECUTOR = Executors.newSingleThreadScheduledExecutor();

    /**
     * Schedules the periodic reload of the service list from the database. The first run happens
     * after 3 seconds, then every 5 seconds. An exception thrown by one run is printed and does not
     * stop later runs.
     *
     * @return always {@code null}; the return type is kept for compatibility
     */
    public static Channel refreshServerList() throws Exception {
        Runnable reload = () -> {
            try {
                IRealServerService service = new RealServerServiceImpl();
                service.getAllServers();
            } catch (Exception e) {
                e.printStackTrace();
            }
        };
        REFRESH_EXECUTOR.scheduleAtFixedRate(reload, 3, 5, TimeUnit.SECONDS);
        return null;
    }
}
代理服务类
真实服务处理类
package com.cc.gw.socket;import com.cc.gw.config.Constant;import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;public class RealServerHandler extends SimpleChannelInboundHandler<ByteBuf> {public Channel proxyChannel;public RealServerHandler(Channel proxyChannel) {this.proxyChannel = proxyChannel;}@Overridepublic void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) {byte[] bytes = new byte[buf.readableBytes()];buf.readBytes(bytes);ByteBuf byteBuf = ctx.alloc().buffer(bytes.length);byteBuf.writeBytes(bytes);proxyChannel.writeAndFlush(byteBuf);}@Overridepublic void channelInactive(ChannelHandlerContext ctx) throws Exception {super.channelInactive(ctx);proxyChannel.attr(Constant.C).set(null);}
}
反向代理处理类,实现反向代理的核心功能:解析访客请求、匹配后端服务、建立后端连接并双向转发数据
package com.cc.gw.socket;import java.util.List;
import java.util.stream.Collectors;import com.cc.gw.config.Constant;
import com.cc.gw.domain.RealServerEntity;
import com.cc.gw.service.IRealServerService;
import com.cc.gw.service.impl.RealServerServiceImpl;
import com.cc.gw.util.HttpProtocolUtil;import cn.hutool.core.collection.CollectionUtil;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.internal.StringUtil;public class ReverseProxyHandler extends SimpleChannelInboundHandler<ByteBuf> {static EventLoopGroup eventLoopGroup = new NioEventLoopGroup();@Overridepublic void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {byte[] bytes = new byte[buf.readableBytes()];buf.readBytes(bytes);Channel realChannel = ctx.channel().attr(Constant.C).get();String protocol = ctx.channel().attr(Constant.T).get();if (null != realChannel && realChannel.isActive() && "ws".equals(protocol)) {sendMsg(ctx, bytes, realChannel);return;}String uri = HttpProtocolUtil.getRequestURI(bytes);if (StringUtil.isNullOrEmpty(uri)) {ctx.close();return;}String contextPath = HttpProtocolUtil.getContextPath(uri);if (StringUtil.isNullOrEmpty(contextPath)) {ctx.close();return;}checkServers();List<String> collect = Constant.ALL_SERVERS.keySet().stream().filter(o -> contextPath.startsWith(o)).sorted().collect(Collectors.toList());String serverCode = null;if (CollectionUtil.isNotEmpty(collect)) {serverCode = collect.get(collect.size() - 1);}if (StringUtil.isNullOrEmpty(serverCode)) {ctx.close();return;}if (uri.startsWith("ws")) {// 长连接if (null == realChannel || !realChannel.isActive()) {createSocket(ctx, bytes, serverCode);ctx.channel().attr(Constant.T).set("ws");} else {sendMsg(ctx, bytes, realChannel);}} else {// 短连接// TODO 短时间内可复用链接createSocket(ctx, bytes, serverCode);}}/*** 创建一个socket链接* @param ctx 当前会话* @param bytes 向后台服务传输的数据* @param serverCode 后台服务CODE*/private void createSocket(ChannelHandlerContext ctx, byte[] bytes, String serverCode) {try {List<RealServerEntity> list = Constant.ALL_SERVERS.get(serverCode);RealServerEntity server = list.get(0);Bootstrap bootstrap = new Bootstrap();bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {@Overridepublic void initChannel(SocketChannel ch) throws Exception {ChannelPipeline pipeline = ch.pipeline();pipeline.addLast(new 
RealServerHandler(ctx.channel()));}});bootstrap.connect(server.getIp(), server.getPort()).addListener(new ChannelFutureListener() {@Overridepublic void operationComplete(ChannelFuture future) throws Exception {if (future.isSuccess()) {Channel channel = future.channel();ctx.channel().attr(Constant.C).set(channel);sendMsg(ctx, bytes, channel);}}});} catch (Exception e) {e.printStackTrace();}}/*** 如果还未初始化服务列表,进行初始化*/private void checkServers() {if (Constant.ALL_SERVERS.isEmpty()) {IRealServerService realServerService = new RealServerServiceImpl();realServerService.getAllServers();}}private void sendMsg(ChannelHandlerContext ctx, byte[] bytes, Channel channel) {ByteBuf byteBuf = ctx.alloc().buffer(bytes.length);byteBuf.writeBytes(bytes);channel.writeAndFlush(byteBuf);}@Overridepublic void channelInactive(ChannelHandlerContext ctx) throws Exception {super.channelInactive(ctx);try {Channel channel = ctx.channel().attr(Constant.C).get();if (null != channel && channel.isActive()) {channel.close();}ctx.channel().attr(Constant.C).set(null);} catch (Exception e) {e.printStackTrace();}}
}
反向代理服务类
package com.cc.gw.socket;import com.cc.gw.config.Constant;import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;public class ReverseProxySocket {/*** 启动服务代理* * @throws Exception*/public static void startServer() throws Exception {EventLoopGroup bossGroup = new NioEventLoopGroup();EventLoopGroup workerGroup = new NioEventLoopGroup();try {ServerBootstrap b = new ServerBootstrap();b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<SocketChannel>() {@Overridepublic void initChannel(SocketChannel ch) throws Exception {ChannelPipeline pipeline = ch.pipeline();pipeline.addLast(new ReverseProxyHandler());}});ChannelFuture f = b.bind(Constant.SERVER_PORT).sync();f.channel().closeFuture().sync();} finally {workerGroup.shutdownGracefully();bossGroup.shutdownGracefully();}}}
启动类
package com.cc.gw;import com.cc.gw.config.ScheduledTasks;
import com.cc.gw.socket.ReverseProxySocket;

/**
 * Entry point of the gateway.
 */
public class MainApp {

    /**
     * Schedules the periodic service-list refresh, then runs the proxy server. This call blocks
     * until the server channel is closed.
     *
     * @param args ignored
     */
    public static void main(final String[] args) throws Exception {
        // Background refresh first; startServer() does not return while the server is running
        ScheduledTasks.refreshServerList();
        ReverseProxySocket.startServer();
    }
}
使用
启动服务端
# Runnable jar built by `mvn package` (maven-assembly-plugin, jar-with-dependencies); the plain
# cc-gateway-1.0-SNAPSHOT.jar has no Main-Class and no bundled dependencies.
java -jar target/cc-gateway-1.0-SNAPSHOT-jar-with-dependencies.jar