netty使用redis发布订阅实现消息推送
场景
项目中需要给用户推送消息:既可以推送给指定用户,也可以按模块群发;服务多节点部署时,通过 redis 发布订阅把消息转发到其他节点。
接口
@RestController
public class PushApi {

    @Autowired
    private PushService pushService;

    /**
     * Accepts a push request and hands it to {@link PushService}.
     *
     * @param query push configuration read from the request body
     * @return the literal {@code "success"} once the service call has returned
     */
    @PostMapping("/push/message")
    public String push(@RequestBody MessagePushConfigDto query) {
        final String result = "success";
        pushService.push(query);
        return result;
    }
}
@Component
@Slf4j
public class PushService {@Autowiredprivate StringRedisTemplate redisTemplate;@Autowiredprivate MessageService messageService;public void push(MessagePushConfigDto query) {String messageNo = UUID.randomUUID().toString();if (query.getType()== Constants.MSG_TYPE_ALL){doPushGroup(query, messageNo);}else {doPushToUser(query, messageNo);}}private void doPushGroup(MessagePushConfigDto query, String messageNo) {MessageDto dto = new MessageDto();dto.setModule(query.getModule());dto.setType(query.getType());dto.setMessageNo(messageNo);dto.setContent(query.getContent());//转发至其他节点redisTemplate.convertAndSend(Constants.TOPIC_MODULE, JSON.toJSONString(dto));}private void doPushToUser(MessagePushConfigDto query, String messageNo) {for (String identityNo : query.getIdentityList()) {MessageDto dto = new MessageDto();dto.setModule(query.getModule());dto.setType(query.getType());dto.setMessageNo(messageNo);dto.setContent(query.getContent());dto.setIdentityNo(identityNo);String key = MessageFormat.format(Constants.USER_KEY, query.getModule(),identityNo);String nodeIp = redisTemplate.opsForValue().get(key);if (StrUtil.isBlank(nodeIp)){log.info("no user found: {}-{}",identityNo, key);return;}if (NodeConfig.node.equals(nodeIp)){log.info("send from local: {}", identityNo);messageService.sendToUser(dto.getMessageNo(),dto.getModule(),dto.getIdentityNo(),dto.getContent());}else {//转发至其他节点redisTemplate.convertAndSend(Constants.TOPIC_USER, JSON.toJSONString(dto));}}}
}
实体
// The message that is sent
@Data
public class MessageDto {

    /** Module the message belongs to. */
    private String module;

    /**
     * Message type. Original note: 1 = specified users, 2 = everyone.
     * NOTE(review): Constants in this file declares MSG_TYPE_ALL = 1 and
     * MSG_TYPE_SINGLE = 0, which does not match that note; confirm the mapping.
     */
    private Integer type;

    /** Message number. */
    private String messageNo;

    /** Message content. */
    private String content;

    /** Identity of the recipient. */
    private String identityNo;
}
// Message push configuration
@Data
public class MessagePushConfigDto {

    /** Module the message belongs to. */
    private String module;

    /**
     * Message type. Original note: 1 = specified users, 2 = everyone.
     * NOTE(review): Constants in this file declares MSG_TYPE_ALL = 1 and
     * MSG_TYPE_SINGLE = 0, which does not match that note; confirm the mapping.
     */
    private Integer type;

    /** Message content. */
    private String content;

    /** Identities of the recipients. */
    private List<String> identityList;
}
// Constants
/** Shared message-type codes, Redis topic names and the Redis key template. */
public interface Constants {

    /** Message type code: push to everyone in a module. */
    int MSG_TYPE_ALL = 1;

    /** Message type code: push to individual users. */
    int MSG_TYPE_SINGLE = 0;

    /** Redis topic for module-wide messages. */
    String TOPIC_MODULE = "topic:module";

    /** Redis topic for per-user messages. */
    String TOPIC_USER = "topic:module:user";

    /** {@link java.text.MessageFormat} template: {0} is the module, {1} the identity. */
    String USER_KEY = "socket:module:{0}:userId:{1}";
}
MessageService 发送消息接口
/** Sends message content to the sessions connected to this node. */
public interface MessageService {

    /**
     * Sends content to a whole module group.
     *
     * @param messageNo message number
     * @param module    module whose group receives the content
     * @param content   content to send
     */
    void sendToGroup(String messageNo, String module, String content);

    /**
     * Sends content to a single user.
     *
     * @param messageNo  message number
     * @param module     module the user is registered under
     * @param identityNo identity of the recipient
     * @param content    content to send
     */
    void sendToUser(String messageNo, String module, String identityNo, String content);
}

/**
 * {@link MessageService} backed by a {@link SessionRegistry}; a send is a no-op when the
 * registry has no matching group or session.
 */
public class MessageServiceImpl implements MessageService {

    // Wildcards instead of raw types: only String-based lookups and sends are needed here.
    private final SessionRegistry<?, ?> sessionRegistry;

    public MessageServiceImpl(SessionRegistry<?, ?> sessionRegistry) {
        this.sessionRegistry = sessionRegistry;
    }

    @Override
    public void sendToGroup(String messageNo, String module, String content) {
        SessionGroup<?, ?> sessionGroup = sessionRegistry.retrieveGroup(module);
        if (sessionGroup != null) {
            sessionGroup.sendGroup(content);
        }
    }

    @Override
    public void sendToUser(String messageNo, String module, String identityNo, String content) {
        WssSession<?> wssSession = sessionRegistry.retrieveSession(module, identityNo);
        if (wssSession != null) {
            wssSession.send(content);
        }
    }
}
SessionService
操作 session 的服务:添加或移除 session,并把用户所在节点写入 redis(移除时删除对应的 key)
/**
 * Registers and unregisters sessions.
 *
 * @param <WS> session type
 * @param <C>  channel type carried by the session
 */
public interface SessionService<WS extends WssSession<C>, C> {

    /**
     * Adds a session.
     *
     * @param session session to add
     */
    void addSession(WS session);

    /**
     * Removes a session.
     *
     * @param session session to remove
     */
    void removeSession(WS session);
}
public abstract class AbstractSessionService<SR extends SessionRegistry<WS, C>, WS extends WssSession<C>, C>implements SessionService<WS, C> {@Getterprivate SR sessionRegistry;public AbstractSessionService(SR sessionRegistry) {this.sessionRegistry = sessionRegistry;}
}public class SessionServiceImpl<SR extends SessionRegistry<WS, C>, WS extends WssSession<C>, C>extends AbstractSessionService<SR, WS, C> {private StringRedisTemplate redisTemplate;public SessionServiceImpl(SR sessionRegistry, StringRedisTemplate redisTemplate) {super(sessionRegistry);this.redisTemplate = redisTemplate;}@Overridepublic void addSession(WS session) {getSessionRegistry().addSession(session);String key = MessageFormat.format(Constants.USER_KEY, session.getModule(), session.getIdentityNo());redisTemplate.opsForValue().set(key, NodeConfig.node);}@Overridepublic void removeSession(WS session) {getSessionRegistry().removeSession(session);String key = MessageFormat.format(Constants.USER_KEY, session.getModule(), session.getIdentityNo());redisTemplate.delete(key);}}
websocket 实现
定义session接口相关
/**
 * One connected session.
 *
 * @param <C> type of the underlying communication channel
 */
public interface WssSession<C> {

    /**
     * @return the module
     */
    String getModule();

    /**
     * @return the user's unique identifier
     */
    String getIdentityNo();

    /**
     * @return the communication channel
     */
    C getChannel();

    /**
     * Sends a message.
     *
     * @param message message to send
     */
    void send(String message);
}

/**
 * A group of sessions.
 *
 * @param <T> session type
 * @param <C> channel type carried by the sessions
 */
public interface SessionGroup<T extends WssSession<C>, C> {

    /**
     * Adds a session.
     *
     * @param session session to add
     */
    void addSession(T session);

    /**
     * Removes a session.
     *
     * @param session session to remove
     */
    void removeSession(T session);

    /**
     * Sends data to the group.
     *
     * @param message message to send
     */
    void sendGroup(String message);

    /**
     * Looks a session up by its unique identifier.
     *
     * @param identityNo unique identifier
     * @return the matching session
     */
    T getSession(String identityNo);
}
public interface SessionRegistry<T extends WssSession<C>, C> {/*** 添加 session** @param session*/void addSession(T session);/*** 移除 session** @param session*/void removeSession(T session);/*** 查询 SessionGroup* @param module* @return*/SessionGroup<T, C> retrieveGroup(String module);/*** 查询 session* @param module* @param identityNo* @return*/T retrieveSession(String module, String identityNo);}public abstract class AbstractSession<C> implements WssSession<C>{private String module;private String identityNo;private C channel;public AbstractSession(String module, String identityNo, C channel) {this.module = module;this.identityNo = identityNo;this.channel = channel;}@Overridepublic String getModule() {return module;}@Overridepublic String getIdentityNo() {return identityNo;}@Overridepublic C getChannel() {return channel;}
}public abstract class AbstractSessionRegistry<T extends WssSession<C>, C> implements SessionRegistry<T, C> {private Map<String, SessionGroup<T, C>> map = new ConcurrentHashMap<>();@Overridepublic void addSession(T session) {SessionGroup<T, C> sessionGroup = map.computeIfAbsent(session.getModule(), key -> newSessionGroup());sessionGroup.addSession(session);}protected abstract SessionGroup<T, C> newSessionGroup();@Overridepublic void removeSession(T session) {SessionGroup<T, C> sessionGroup = map.get(session.getModule());sessionGroup.removeSession(session);}@Overridepublic SessionGroup<T, C> retrieveGroup(String module) {return map.get(module);}@Overridepublic T retrieveSession(String module, String identityNo) {SessionGroup<T, C> sessionGroup = map.get(module);if (sessionGroup != null) {return (T) sessionGroup.getSession(identityNo);}return null;}
}
使用 netty 容器
@Slf4j
@Component
public class NettyServer {private NioEventLoopGroup boss;private NioEventLoopGroup worker;@Value("${namespace:/ns}")private String namespace;@Autowiredprivate SessionService sessionService;@PostConstructpublic void start() {try {boss = new NioEventLoopGroup(1);worker = new NioEventLoopGroup();ServerBootstrap serverBootstrap = new ServerBootstrap();serverBootstrap.group(boss, worker).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<SocketChannel>() {@Overrideprotected void initChannel(SocketChannel ch) throws Exception {ChannelPipeline pipeline = ch.pipeline();pipeline.addLast(new IdleStateHandler(0, 0, 60));pipeline.addLast(new HeartBeatInboundHandler());pipeline.addLast(new HttpServerCodec());pipeline.addLast(new HttpObjectAggregator(64 * 1024));pipeline.addLast(new ChunkedWriteHandler());pipeline.addLast(new HttpRequestInboundHandler(namespace));pipeline.addLast(new WebSocketServerProtocolHandler(namespace, true));pipeline.addLast(new WebSocketHandShakeHandler(sessionService));}});int port = 9999;serverBootstrap.bind(port).addListener((ChannelFutureListener) future -> {if (future.isSuccess()) {log.info("server start at port successfully: {}", port);} else {log.info("server start at port error: {}", port);}}).sync();} catch (InterruptedException e) {log.error("start error", e);close();}}@PreDestroypublic void destroy() {close();}private void close() {log.info("websocket server close..");if (boss != null) {boss.shutdownGracefully();}if (worker != null) {worker.shutdownGracefully();}}}public class NettySessionGroup implements SessionGroup<NWssSession,Channel> {private ChannelGroup group = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);//Map<identityNo,channel>private Map<String, NWssSession> map = new ConcurrentHashMap<>();@Overridepublic void addSession(NWssSession session) {group.add(session.getChannel());map.put(session.getIdentityNo(), session);}@Overridepublic void removeSession(NWssSession session) 
{group.remove(session.getChannel());map.remove(session.getIdentityNo());}@Overridepublic void sendGroup(String message){group.writeAndFlush(new TextWebSocketFrame(message));}@Overridepublic NWssSession getSession(String identityNo) {return map.get(identityNo);}}public class NettySessionRegistry extends AbstractSessionRegistry<NWssSession, Channel> {@Overrideprotected SessionGroup<NWssSession, Channel> newSessionGroup() {return new NettySessionGroup();}
}public class NWssSession extends AbstractSession<Channel> {public NWssSession(String module, String identityNo, Channel channel) {super(module, identityNo, channel);}@Overridepublic void send(String message) {getChannel().writeAndFlush(new TextWebSocketFrame(message));}
}public class NettyUtil {//参数-module<->user-codepublic static AttributeKey<String> G_U = AttributeKey.valueOf("GU");//参数-uripublic static AttributeKey<String> P = AttributeKey.valueOf("P");/*** 设置上下文参数** @param channel* @param attributeKey* @param data* @param <T>*/public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {Attribute<T> attr = channel.attr(attributeKey);if (attr != null) {attr.set(data);}}/*** 获取上下文参数** @param channel* @param attributeKey* @param <T>* @return*/public static <T> T getAttr(Channel channel, AttributeKey<T> attributeKey) {return channel.attr(attributeKey).get();}/*** 根据 渠道获取 session** @param channel* @return*/public static NWssSession getSession(Channel channel) {String attr = channel.attr(G_U).get();if (StrUtil.isNotBlank(attr)) {String[] split = attr.split(",");String groupId = split[0];String username = split[1];return new NWssSession(groupId, username, channel);}return null;}public static void writeForbiddenRepose(ChannelHandlerContext ctx) {String res = "FORBIDDEN";FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.wrappedBuffer(res.getBytes(StandardCharsets.UTF_8)));response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain");response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());ctx.writeAndFlush(response);ctx.close();}}public interface WebSocketListener {void handShakeSuccessful(ChannelHandlerContext ctx, String uri);void handShakeFailed(ChannelHandlerContext ctx,String uri);
}//解析 request uri参数
@Slf4j
public class DefaultWebSocketListener implements WebSocketListener {private static final String G = "module";private static final String U = "userCode";@Overridepublic void handShakeSuccessful(ChannelHandlerContext ctx, String uri) {QueryStringDecoder decoderQuery = new QueryStringDecoder(uri);Map<String, List<String>> params = decoderQuery.parameters();String groupId = getParameter(G, params);String userCode = getParameter(U, params);if (StrUtil.isBlank(groupId) || StrUtil.isBlank(userCode)) {log.info("module or userCode is null: {}", uri);NettyUtil.writeForbiddenRepose(ctx);return;}//传递参数NettyUtil.setAttr(ctx.channel(), NettyUtil.G_U, groupId.concat(",").concat(userCode));}@Overridepublic void handShakeFailed(ChannelHandlerContext ctx, String uri) {log.info("handShakeFailed failed,close channel");ctx.close();}private String getParameter(String key, Map<String, List<String>> params) {if (CollectionUtils.isEmpty(params)) {return null;}List<String> value = params.get(key);if (CollectionUtils.isEmpty(value)) {return null;}return value.get(0);}}
netty handler
//心跳
@Slf4j
public class HeartBeatInboundHandler extends ChannelInboundHandlerAdapter {@Overridepublic void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {if (evt instanceof IdleStateEvent ise){if (ise.state()== IdleState.ALL_IDLE){//关闭连接log.info("HeartBeatInboundHandler heart beat close");ctx.channel().close();return;}}super.userEventTriggered(ctx,evt);}}/*** @Date: 2024/7/17 13:06* 处理 http 协议 的请求参数并传递*/
@Slf4j
public class HttpRequestInboundHandler extends ChannelInboundHandlerAdapter {

    private final String namespace;

    public HttpRequestInboundHandler(String namespace) {
        this.namespace = namespace;
    }

    /**
     * Checks that the request path starts with the namespace, stores the original URI in
     * the channel's {@code NettyUtil.P} attribute, rewrites the request URI to the bare
     * namespace, removes itself from the pipeline and forwards the request.
     *
     * <p>Messages that are not a {@link FullHttpRequest} are forwarded untouched. A
     * request that is not forwarded (rejected, or an exception was thrown) is released
     * here, because no later handler will receive it.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof FullHttpRequest request)) {
            // Not handled here: hand it on instead of dropping it silently.
            ctx.fireChannelRead(msg);
            return;
        }
        boolean forwarded = false;
        try {
            // e.g. ws://localhost:8080/n/ws?groupId=xx&username=tom
            String requestUri = request.uri();
            String decode = URLDecoder.decode(requestUri, StandardCharsets.UTF_8);
            log.info("raw request url: {}", decode);
            URI uri = new URI(requestUri);
            if (!uri.getPath().startsWith(namespace)) {
                NettyUtil.writeForbiddenRepose(ctx);
                return;
            }
            // TODO: 2024/7/17 validate the token, e.g. read it from a header.
            // A custom WebSocket handshake handler could be built here; Netty's own
            // WebSocketServerProtocolHandler can be used as well.
            // shakeHandsIfNecessary(ctx, request, requestUri);

            // Pass the parameters on, then strip them ===> ws://localhost:8080/n/ws
            NettyUtil.setAttr(ctx.channel(), NettyUtil.P, requestUri);
            request.setUri(namespace);
            ctx.pipeline().remove(this);
            forwarded = true;
            ctx.fireChannelRead(request);
        } finally {
            if (!forwarded) {
                request.release();
            }
        }
    }

    /*
    private void shakeHandsIfNecessary(ChannelHandlerContext ctx, FullHttpRequest request, String requestUri) {
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(request), null, true);
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(request);
        if (handshaker == null) {
            // WebSocket version not supported: send the unsupported-version response
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        } else {
            ChannelPipeline pipeline = ctx.channel().pipeline();
            handshaker.handshake(ctx.channel(), request).addListener((ChannelFutureListener) future -> {
                if (future.isSuccess()) {
                    // handshake succeeded (WebSocketListener listener)
                    listener.handShakeSuccessful(ctx, requestUri);
                } else {
                    // handshake failed
                    listener.handShakeFailed(ctx, requestUri);
                }
            });
        }
    }

    private String getWebSocketLocation(FullHttpRequest req) {
        return "ws://" + req.headers().get(HttpHeaderNames.HOST) + prefix;
    }
    */
}
@Slf4j
public class WebSocketBizHandler extends SimpleChannelInboundHandler<WebSocketFrame> {private SessionService sessionService;public WebSocketBizHandler(SessionService sessionService){this.sessionService = sessionService;}@Overridepublic void handlerAdded(ChannelHandlerContext ctx) throws Exception {log.info("handlerAdded");NWssSession session = NettyUtil.getSession(ctx.channel());if (session == null) {log.info("session is null: {}", ctx.channel().id());NettyUtil.writeForbiddenRepose(ctx);return;}sessionService.addSession(session);}@Overridepublic void handlerRemoved(ChannelHandlerContext ctx) throws Exception {log.info("handlerRemoved");NWssSession session = NettyUtil.getSession(ctx.channel());if (session == null) {log.info("session is null: {}", ctx.channel().id());return;}sessionService.removeSession(session);}@Overrideprotected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {if (msg instanceof TextWebSocketFrame) {} else if (msg instanceof BinaryWebSocketFrame) {} else if (msg instanceof PingWebSocketFrame) {} else if (msg instanceof PongWebSocketFrame) {} else if (msg instanceof CloseWebSocketFrame) {if (ctx.channel().isActive()) {ctx.close();}}ctx.writeAndFlush(new TextWebSocketFrame("默认回复"));}@Overridepublic void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {//处理最后的业务异常log.info("WebSocketBizHandler error: ", cause);}
}//处理websocket协议握手
@Slf4j
public class WebSocketHandShakeHandler extends ChannelInboundHandlerAdapter {private SessionService sessionService;private WebSocketListener webSocketListener = new DefaultWebSocketListener();public WebSocketHandShakeHandler(SessionService sessionService) {this.sessionService = sessionService;}@Overridepublic void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {log.info("WebSocketHandShakeHandler shake-hands success");// 在此处获取URL、Headers等信息并做校验,通过throw异常来中断链接。String uri = NettyUtil.getAttr(ctx.channel(), NettyUtil.P);if (StrUtil.isBlank(uri)) {log.info("request uri is null");NettyUtil.writeForbiddenRepose(ctx);return;}webSocketListener.handShakeSuccessful(ctx, uri);ChannelPipeline pipeline = ctx.channel().pipeline();pipeline.addLast(new WebSocketBizHandler(sessionService));pipeline.remove(this);return;}super.userEventTriggered(ctx, evt);}@Overridepublic void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {if (cause instanceof WebSocketHandshakeException) {//只处理 websocket 握手相关异常FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.BAD_REQUEST,Unpooled.wrappedBuffer(cause.getMessage().getBytes()));ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);return;}super.exceptionCaught(ctx,cause);}}
配置
@Component
public class NodeConfig {

    /** Value of {@code NetUtil.getLocalhostStr()} captured at start-up. */
    public static String node;

    /** Resolves the local address once and fails start-up when none is found. */
    @PostConstruct
    public void init() {
        NodeConfig.node = NetUtil.getLocalhostStr();
        Assert.notNull(NodeConfig.node, "local ip is null");
    }
}
@Slf4j
@Configuration
public class RedisPublishConfig {

    /**
     * Creates the listener container and subscribes the given listener to
     * {@link Constants#TOPIC_USER} and {@link Constants#TOPIC_MODULE}.
     *
     * @param connectionFactory Redis connection factory
     * @param messageListener   listener that receives the published messages
     * @return the configured container
     */
    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory, MessageListener messageListener) {
        List<PatternTopic> topics = List.of(
                new PatternTopic(Constants.TOPIC_USER),
                new PatternTopic(Constants.TOPIC_MODULE));

        RedisMessageListenerContainer listenerContainer = new RedisMessageListenerContainer();
        listenerContainer.setConnectionFactory(connectionFactory);
        listenerContainer.addMessageListener(messageListener, topics);
        log.info("RedisMessageListenerContainer listen topic: {}", Constants.TOPIC_USER);
        return listenerContainer;
    }
}
@Slf4j
@Component
public class RedisPublisherListener implements MessageListener {

    @Autowired
    private RedisPublisherConsumer messageService;

    /**
     * Decodes the topic and the message body as UTF-8, logs them and hands both to the
     * consumer.
     *
     * @param message received message
     * @param pattern bytes of the topic the subscription matched
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        // Charset constants cannot fail at lookup and do not depend on the platform
        // default, so no encoding exception has to be handled.
        String topic = new String(pattern, java.nio.charset.StandardCharsets.UTF_8);
        String msg = new String(message.getBody(), java.nio.charset.StandardCharsets.UTF_8);
        log.info("recv topic:{}, msg: {}", topic, msg);
        messageService.consume(topic, msg);
    }
}
@Configuration
public class WebSocketConfig {

    /** @return the registry that holds the sessions */
    @Bean
    public NettySessionRegistry sessionRegistry() {
        return new NettySessionRegistry();
    }

    /**
     * @param redisTemplate template used to record where a session lives
     * @return the session service built on {@link #sessionRegistry()}
     */
    @Bean
    public SessionService<NWssSession, Channel> sessionService(StringRedisTemplate redisTemplate) {
        NettySessionRegistry registry = sessionRegistry();
        return new SessionServiceImpl<>(registry, redisTemplate);
    }

    /** @return the message service built on {@link #sessionRegistry()} */
    @Bean
    public MessageService messageService() {
        NettySessionRegistry registry = sessionRegistry();
        return new MessageServiceImpl(registry);
    }
}
good luck!