实现websocket的方式
1.Spring Boot中有两种方式实现WebSocket:一种是基于原生注解(JSR-356,即@ServerEndpoint、@OnOpen等注解)的方式,另一种是基于Spring封装后的WebSocketHandler的方式。
基于原生注解实现websocket
1)先引入websocket的starter坐标
<dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency>
2)编写websocket的Endpoint端点类
@ServerEndpoint(value = "/ws/{token}")
@Component
public class WebsocketHandler2 {private final static Logger log = LoggerFactory.getLogger(WebsocketHandler2.class);private static final Set<Session> SESSIONS = new ConcurrentSkipListSet<>(Comparator.comparing(Session::getId));private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();@OnOpenpublic void onOpen(Session session, @PathParam("token") String token, EndpointConfig config) throws IOException {
// session.addMessageHandler(new PongHandler());ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);String queryString = session.getQueryString();futures.put(session.getId(), future);session.setMaxIdleTimeout(6 * 1000);SESSIONS.add(session);log.info("open connect sessionId={}, token={}, queryParam={}", session.getId(), token, queryString);String s = String.format("ws client(id=%s) has connected", session.getId());session.getBasicRemote().sendText(s);}static class PongHandler implements MessageHandler.Whole<PongMessage> {@Overridepublic void onMessage(PongMessage message) {ByteBuffer data = message.getApplicationData();String s = new String(data.array(), StandardCharsets.UTF_8);log.info("receive pong msg=> {}", s);}}@OnClosepublic void onClose(Session session, CloseReason reason) {log.info("session(id={}) close ,closeCode={},closeParse={}", session.getId(), reason.getCloseCode(), reason.getReasonPhrase());SESSIONS.remove(session);ScheduledFuture<?> future = futures.get(session.getId());if (future != null) {future.cancel(true);}}@OnMessagepublic void onMessage(String message, Session session) throws IOException {log.info("receive client(id={}) msg=>{}", session.getId(), message);String s = String.format("reply your(id=%s) msg=>【%s】", session.getId(), message);session.getBasicRemote().sendText(s);}@OnMessagepublic void onPong(PongMessage message, Session session) throws IOException {ByteBuffer data = message.getApplicationData();String s = new String(data.array(), StandardCharsets.UTF_8);log.info("receive client(id={}) pong msg=> {}", session.getId(), s);}@OnErrorpublic void onError(Session session, Throwable error) {log.error("Session(id={}) error occur ", session.getId(), error);}private void sendPing(Session session) {if (session.isOpen()) {String replyContent = String.format("Hello,client(id=%s)", session.getId());try 
{session.getBasicRemote().sendPing(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));} catch (IOException e) {log.error("ping client(id={}) error", session.getId(), e);}return;}SESSIONS.remove(session);ScheduledFuture<?> future = futures.remove(session.getId());if (future != null) {future.cancel(true);}}
}
注解说明
@ServerEndpoint
标记这个是一个服务端的端点类
@OnOpen
标记此方法是建立websocket连接时的回调方法
@OnMessage
标记此方法是接收到客户端消息时的回调方法
@OnClose
标记此方法是断开WebSocket连接时的回调方法
@OnError
标记此方法是WebSocket发生异常时的回调方法
@PathParam
用于获取@ServerEndpoint注解中绑定的路径模板参数(例如/ws/{token}中的token)
方法参数说明
1) onOpen方法参数
onOpen的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnOpenArgs
可以看到
/**
 * Tomcat source (PojoMethodMapping#getOnOpenArgs): builds the argument array for the
 * {@code @OnOpen} method. Only path parameters, the Session and the EndpointConfig are
 * supplied; the two trailing nulls fill the Throwable and CloseReason slots of buildArgs,
 * which do not apply to {@code @OnOpen}.
 */
public Object[] getOnOpenArgs(Map<String,String> pathParameters,
        Session session, EndpointConfig config) throws DecodeException {
    return buildArgs(onOpenParams, pathParameters, session, config, null,
            null);
}
因此可以看出@OnOpen
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)当前endpoint的配置详情EndpointConfig参数
2) onClose方法参数
onClose的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnCloseArgs
可以看到
/**
 * Tomcat source (PojoMethodMapping#getOnCloseArgs): builds the argument array for the
 * {@code @OnClose} method. Path parameters, the Session and the CloseReason are supplied;
 * the EndpointConfig and Throwable slots of buildArgs are null.
 */
public Object[] getOnCloseArgs(Map<String,String> pathParameters,
        Session session, CloseReason closeReason) throws DecodeException {
    return buildArgs(onCloseParams, pathParameters, session, null, null,
            closeReason);
}
因此可以看出@OnClose
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)当前连接关闭的原因CloseReason参数
3) onError方法参数
onError的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnErrorArgs
可以看到
/**
 * Tomcat source (PojoMethodMapping#getOnErrorArgs): builds the argument array for the
 * {@code @OnError} method. Path parameters, the Session and the Throwable are supplied;
 * the EndpointConfig and CloseReason slots of buildArgs are null.
 */
public Object[] getOnErrorArgs(Map<String,String> pathParameters,
        Session session, Throwable throwable) throws DecodeException {
    return buildArgs(onErrorParams, pathParameters, session, null,
            throwable, null);
}
因此可以看出@OnError
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)发生异常的异常对象Throwable参数
4) onMessage方法参数
onMessage的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping.MessageHandlerInfo#getMessageHandlers
可以看到
public Set<MessageHandler> getMessageHandlers(Object pojo,Map<String,String> pathParameters, Session session,EndpointConfig config) {Object[] params = new Object[m.getParameterTypes().length];for (Map.Entry<Integer,PojoPathParam> entry :indexPathParams.entrySet()) {PojoPathParam pathParam = entry.getValue();String valueString = pathParameters.get(pathParam.getName());Object value = null;try {value = Util.coerceToType(pathParam.getType(), valueString);} catch (Exception e) {DecodeException de = new DecodeException(valueString,sm.getString("pojoMethodMapping.decodePathParamFail",valueString, pathParam.getType()), e);params = new Object[] { de };break;}params[entry.getKey().intValue()] = value;}Set<MessageHandler> results = new HashSet<>(2);if (indexBoolean == -1) {// Basicif (indexString != -1 || indexPrimitive != -1) {MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,session, config, null, params, indexPayload, false,indexSession, maxMessageSize);results.add(mh);} else if (indexReader != -1) {MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,session, config, null, params, indexReader, true,indexSession, maxMessageSize);results.add(mh);} else if (indexByteArray != -1) {MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,m, session, config, null, params, indexByteArray,true, indexSession, false, maxMessageSize);results.add(mh);} else if (indexByteBuffer != -1) {MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,m, session, config, null, params, indexByteBuffer,false, indexSession, false, maxMessageSize);results.add(mh);} else if (indexInputStream != -1) {MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,m, session, config, null, params, indexInputStream,true, indexSession, true, maxMessageSize);results.add(mh);} else if (decoderMatch != null && decoderMatch.hasMatches()) {if (decoderMatch.getBinaryDecoders().size() > 0) {MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, m, session, 
config,decoderMatch.getBinaryDecoders(), params,indexPayload, true, indexSession, true,maxMessageSize);results.add(mh);}if (decoderMatch.getTextDecoders().size() > 0) {MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, session, config,decoderMatch.getTextDecoders(), params,indexPayload, true, indexSession, maxMessageSize);results.add(mh);}} else {MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m,session, params, indexPong, false, indexSession);results.add(mh);}} else {// ASyncif (indexString != -1) {MessageHandler mh = new PojoMessageHandlerPartialText(pojo,m, session, params, indexString, false,indexBoolean, indexSession, maxMessageSize);results.add(mh);} else if (indexByteArray != -1) {MessageHandler mh = new PojoMessageHandlerPartialBinary(pojo, m, session, params, indexByteArray, true,indexBoolean, indexSession, maxMessageSize);results.add(mh);} else {MessageHandler mh = new PojoMessageHandlerPartialBinary(pojo, m, session, params, indexByteBuffer, false,indexBoolean, indexSession, maxMessageSize);results.add(mh);}}return results;}}
因此可以看出@OnMessage
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)当数据是分块传输时,表示当前消息是否是最后一块数据的boolean(或Boolean)参数
(4)字符输入流Reader参数
(5)二进制输入流InputStream参数
(6)原始的ByteBuffer参数
(7)字节数组byte[]参数
(8)字符串String参数(也可以是int、long等基本类型,对应源码中的indexPrimitive)
(9) Pong响应PongMessage参数
注意:接收消息载荷的参数(4)~(9)在同一个@OnMessage方法中只能声明其中一个。从上面的源码也可以看出,处理器是按照唯一的那个载荷参数类型来创建的;Tomcat在解析端点时,如果发现一个方法声明了多个消息载荷参数,会直接报错。
ping和pong
上面的代码中,我额外给WebSocket会话增加了一个处理PongMessage的方法onPong,它的作用是接收客户端的pong回执消息。只有在服务端向客户端发送ping请求时,服务端才会收到pong响应(按照WebSocket协议,浏览器收到ping后会自动回复pong,前端无需编写代码)。这里的ping和pong类似于其他系统中的心跳机制,用来检测客户端、服务端双方是否还在线:如果连接在限定时间内没有任何数据往来(包括ping和pong),服务端就会主动断开连接。
因此,我在建立WebSocket连接的时候给当前会话设置了最大空闲时间(超过这个时间没有数据报文传输,此连接就会自动断开),同时绑定了一个定时任务,定时发送ping消息来保活。注意:ping的间隔必须小于最大空闲时间(示例中分别是5秒和6秒),否则连接会在两次ping之间因空闲超时而被关闭。
这里的onPong方法不是必须的,没有它也能保活;onPong只是用来得到ping结果的通知。
3)注册暴露端点
@Configuration
@EnableWebSocket
public class WebsocketConfig {@Beanpublic ServerEndpointExporter serverEndpointExporter(){ServerEndpointExporter serverEndpointExporter = new ServerEndpointExporter();//WebsocketHandler2如果是一个spring bean(即有@Component),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测有@ServerEndpoint注解的bean//WebsocketHandler2如果只是一个包含@ServerEndpoint注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型
// serverEndpointExporter.setAnnotatedEndpointClasses(WebsocketHandler2.class );return serverEndpointExporter;}
}
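配置类添加@EnableWebSocket,启用Spring WebSocket功能。严格来说,@EnableWebSocket服务于基于WebSocketHandler/WebSocketConfigurer的方式;如果只使用基于注解的端点,真正起作用的是下面的ServerEndpointExporter。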
另外还需配置一个Bean ServerEndpointExporter
;如果Endpoint类是一个spring bean(即有@Component
),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测含有@ServerEndpoint
注解的Bean;如果Endpoint类只是一个包含@ServerEndpoint
注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型。
注意:即使Endpoint类是Spring bean,WebSocket容器(Tomcat的WsServerContainer)也会为每个连接重新创建这个类的新实例,而不会使用Spring容器中的那个bean。也就是说,实际处理连接的Endpoint实例不受Spring管理,典型的表现就是不能使用@Autowired等注解自动注入Bean;因此上面代码中的会话集合等共享状态都声明为static。其原因是,WebSocket的默认端点配置器org.apache.tomcat.websocket.server.DefaultServerEndpointConfigurator获取endpoint实例的逻辑,是通过反射调用无参构造方法来创建一个新对象:
public class DefaultServerEndpointConfiguratorextends ServerEndpointConfig.Configurator {@Overridepublic <T> T getEndpointInstance(Class<T> clazz)throws InstantiationException {try {return clazz.getConstructor().newInstance();} catch (InstantiationException e) {throw e;} catch (ReflectiveOperationException e) {InstantiationException ie = new InstantiationException();ie.initCause(e);throw ie;}}
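当然,你可以通过注入静态属性的方式来绕过这个限制。例如,在带@Component的端点类上用@Autowired标注一个setter方法,在其中给static字段赋值:Spring创建的那个bean实例会完成注入,而容器为每个连接新建的实例可以通过static字段访问到这个依赖。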
理论上,也可以把@ServerEndpoint注解的configurator属性指定为Spring的org.springframework.web.socket.server.standard.SpringConfigurator,从而实现Bean依赖的自动注入。
@ServerEndpoint(value = "/echo", configurator = SpringConfigurator.class)public class EchoEndpoint {// ...}
SpringConfigurator它重写了获取Endpoint实例的方法逻辑getEndpointInstance
,它是直接到spring容器中去取这个bean,而不是创建一个新实例.
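但实际上,在Spring Boot项目(内嵌Servlet容器)中,上面的getEndpointInstance方法获取到的WebApplicationContext是null,也就没法从Spring容器中获取这个Endpoint bean。原因是SpringConfigurator通过ContextLoader.getCurrentWebApplicationContext()获取容器,而Spring Boot内嵌容器启动时并不会通过ContextLoaderListener初始化根上下文。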
基于spring WebSocketHandler实现websocket
了解WebSocketHandler
提前引入前面提到的websocket的starter 依赖
WebSocketHandler接口定义了5个方法,
afterConnectionEstablished
:建立连接后的回调方法
handleMessage
:接收到客户端消息后的回调方法
handleTransportError
: 数据传输异常时的回调方法
afterConnectionClosed
: 连接关闭后的回调方法
supportsPartialMessages
: 是否支持接收分块(部分)消息;开启后,一条较大的消息可能分多次回调,最后一块消息的isLast()为true
它有两个主要的子类, 一个是处理纯文本数据的TextWebSocketHandler
,另一个是处理二进制数据的BinaryWebSocketHandler
。
我们实现WebSocket时,一般是继承这两个类并重写相应的方法。通常都需要重写afterConnectionEstablished、handleTransportError、afterConnectionClosed这三个方法。除此之外,处理文本还要重写接收客户端消息后的回调方法handleTextMessage,处理二进制数据需要重写接收客户端消息后的回调方法handleBinaryMessage。如果需要得到ping结果的回调,还可以重写handlePongMessage方法。
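代码(注意:WebSocketHandler还需要在一个实现了WebSocketConfigurer的配置类中,通过registerWebSocketHandlers方法注册到某个路径上才会生效,例如registry.addHandler(websocketHandler1, "/ws1"))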
@Component
public class WebsocketHandler1 extends TextWebSocketHandler {private final Logger log = LoggerFactory.getLogger(getClass());private static final Set<WebSocketSession> sessions = new ConcurrentSkipListSet<>(Comparator.comparing(WebSocketSession::getId));private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();@Overridepublic void afterConnectionEstablished(WebSocketSession session) throws Exception {@SuppressWarnings("unchecked")AbstractWebSocketSession<Session> standardSession = (AbstractWebSocketSession) session;Session nativeSession = standardSession.getNativeSession();nativeSession.setMaxIdleTimeout(1000*4);ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);futures.put(session.getId(), future);log.info("open connect sessionId={}", session.getId());sessions.add(session);TextMessage msg = new TextMessage(String.format("ws client(id=%s) has connected", session.getId()));session.sendMessage(msg);}@Overrideprotected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {log.info("receive client(id={}) msg=>{}", session.getId(), message.getPayload());TextMessage msg = new TextMessage(String.format("reply your(id=%s) msg=>%s", session.getId(), message.getPayload()));session.sendMessage(msg);}@Overrideprotected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {ByteBuffer payload = message.getPayload();String s = new String(payload.array());log.info("receive client(id={}) pong msg=>{}", session.getId(),s);}@Overridepublic void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {log.error("client(id={}) error occur ", session.getId(), exception);}@Overridepublic void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {log.info("close 
,status={}", status);sessions.remove(session);ScheduledFuture<?> future = futures.get(session.getId());if (future != null) {future.cancel(true);}}@Overridepublic boolean supportsPartialMessages() {return true;}private void sendPing(WebSocketSession session) {if (session.isOpen()) {String replyContent = String.format("Hello,client(id=%s)", session.getId());PingMessage msg = new PingMessage(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));try {session.sendMessage(msg);} catch (IOException e) {log.error("ping client(id={}) error", session.getId(), e);}}}
}
springboot内置的WebSocketHandler
前端html代码
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>WebSocket Chat</title>
    <style>
        body { font-family: Arial, sans-serif; }
        #chat-box {
            width: 100%;
            height: 300px;
            border: 1px solid #ccc;
            overflow-y: auto;
            padding: 10px;
            margin-bottom: 10px;
            background-color: #f9f9f9;
            white-space: pre-wrap;
        }
        #input-box {
            width: calc(50% - 90px);
            padding: 10px;
            margin-right: 10px;
            display: flex;
            justify-content: center
        }
        .btn { padding: 10px; }
        #btn-container { margin: 1px; display: flex; justify-content: center; gap: 5px; }
        #input-container { margin: 1px; display: flex; justify-content: center; gap: 5px; }
    </style>
</head>
<body>
<div id="chat-box"></div>
<div id="input-container">
    <input type="text" id="input-box" placeholder="Enter your message"/>
</div>
<div id="btn-container">
    <button id="connect-button" class="btn">Connect</button>
    <button id="close-button" class="btn">Close</button>
    <button id="clear-button" class="btn">Clear</button>
    <button id="send-button" class="btn">Send</button>
</div>
<script>
    const chatBox = document.getElementById('chat-box');
    const inputBox = document.getElementById('input-box');
    const sendButton = document.getElementById('send-button');
    const connectBtn = document.getElementById('connect-button');
    const closeBtn = document.getElementById('close-button');
    const clearBtn = document.getElementById('clear-button');

    // The current socket. It is set as soon as a connection is started, so that a second click
    // on Connect during the handshake does not open a second connection.
    let ws = null;

    // Appends one chat line as plain text. The server echoes user input back, so messages must
    // never be interpreted as HTML (innerHTML would allow script injection).
    function appendLine(text) {
        chatBox.appendChild(document.createTextNode(text + '\n'));
        chatBox.scrollTop = chatBox.scrollHeight;
    }

    sendButton.addEventListener('click', () => {
        if (ws === null || ws.readyState !== WebSocket.OPEN) {
            alert("no connect");
            return;
        }
        const message = inputBox.value;
        if (message) {
            ws.send(message);
            appendLine('You: ' + message);
            inputBox.value = '';
        }
    });

    clearBtn.addEventListener('click', () => {
        chatBox.textContent = '';
    });

    closeBtn.addEventListener('click', () => {
        if (ws === null) {
            alert("no connect");
            return;
        }
        console.log("prepare close ws");
        ws.close(1000, 'Normal closure');
    });

    connectBtn.addEventListener('click', () => {
        if (ws !== null) {
            alert("already connected!");
            return;
        }
        const curWs = new WebSocket('ws://localhost:7001/ws/Hews2df?id=323&color=red');
        ws = curWs;
        curWs.onopen = event => {
            console.log('Connected to WebSocket server, event=>%s', JSON.stringify(event));
        };
        curWs.onmessage = event => {
            appendLine('Server: ' + event.data);
        };
        curWs.onclose = event => {
            // Only clear the reference if it still points at this socket.
            if (ws === curWs) {
                ws = null;
            }
            console.log('Disconnected from WebSocket server, close code=%s,close reason=%s', event.code, event.reason);
        };
        curWs.onerror = event => {
            console.log("error occur, event=>%s", JSON.stringify(event));
        };
    });
</script>
</body>
</html>
演示效果