WebSocket 服务端增加定时发送心跳机制
@ServerEndpoint ( value = "/websocket/{uuid}" )
@Component
public class DevMessageHandleController { private static final Logger logger = LoggerFactory . getLogger ( DevMessageHandleController . class ) ; public static CopyOnWriteArraySet < DevMessageHandleController > webSocketSet = new CopyOnWriteArraySet < > ( ) ; private static ConcurrentHashMap < String , DevMessageHandleController > webSocketMap = new ConcurrentHashMap < > ( ) ; private Session session; private String uuid; private AtomicInteger heartbeatAttempts; @OnOpen public void onOpen ( @PathParam ( "uuid" ) String uuid, Session session) { logger. info ( "uuid: {}, sessionId: {}" , uuid, session. getId ( ) ) ; try { if ( webSocketMap. containsKey ( uuid) ) { webSocketMap. get ( uuid) . session. close ( ) ; webSocketSet. remove ( webSocketMap. get ( uuid) ) ; } this . session = session; this . uuid = uuid; heartbeatAttempts = new AtomicInteger ( 0 ) ; webSocketSet. add ( this ) ; webSocketMap. put ( uuid, this ) ; } catch ( Exception e) { logger. error ( "onOpen error:" + e. getMessage ( ) ) ; } } @OnClose public void onClose ( @PathParam ( "uuid" ) String uuid, Session session) { logger. info ( "会话关闭" ) ; webSocketSet. remove ( this ) ; webSocketMap. remove ( uuid) ; } @OnMessage public void onMessage ( String message, Session session) { logger. info ( "Message from client: " + message) ; if ( "pong" . equals ( message) ) { this . heartbeatAttempts. set ( 0 ) ; System . out. println ( "Received pong from: " + session. getId ( ) ) ; } } @OnError public void onError ( Session session, Throwable error) { logger. error ( "发生错误 session:" + session. getId ( ) + ",error:" + error) ; try { session. close ( ) ; webSocketSet. remove ( this ) ; webSocketMap. remove ( this . uuid) ; } catch ( IOException e) { logger. error ( "onError error:" + e. getMessage ( ) ) ; } } public void sendMessage ( Session session, String msg) { logger. info ( "发送消息" ) ; try { if ( session. isOpen ( ) ) { session. getAsyncRemote ( ) . sendText ( msg) ; } else { session. 
close ( ) ; webSocketSet. remove ( this ) ; webSocketMap. remove ( this . uuid) ; } } catch ( IOException e) { e. printStackTrace ( ) ; } } public static CopyOnWriteArraySet < DevMessageHandleController > getWebSocketSet ( ) { return webSocketSet; } public static void setWebSocketSet ( CopyOnWriteArraySet < DevMessageHandleController > webSocketSet) { DevMessageHandleController . webSocketSet = webSocketSet; } public static ConcurrentHashMap < String , DevMessageHandleController > getWebSocketMap ( ) { return webSocketMap; } public static void setWebSocketMap ( ConcurrentHashMap < String , DevMessageHandleController > webSocketMap) { DevMessageHandleController . webSocketMap = webSocketMap; } public Session getSession ( ) { return session; } public void setSession ( Session session) { this . session = session; } public String getUuid ( ) { return uuid; } public void setUuid ( String uuid) { this . uuid = uuid; } public AtomicInteger getHeartbeatAttempts ( ) { return heartbeatAttempts; } public void setHeartbeatAttempts ( AtomicInteger heartbeatAttempts) { this . heartbeatAttempts = heartbeatAttempts; }
}
每隔 10 秒向所有客户端发送一次心跳 ping（fixedDelay：上一次执行结束后再等待 10 秒）；若某个连接已连续 3 次未回复 pong，则关闭该连接
/** Number of consecutive unanswered pings after which a connection is considered dead and closed. */
private static final int MAX_HEARTBEAT_ATTEMPTS = 3;

/**
 * Sends a {@code "ping"} text frame to every registered WebSocket connection.
 *
 * <p>Runs with {@code fixedDelay = 10000}, i.e. 10 s after the previous run finished. A connection
 * that has already received {@link #MAX_HEARTBEAT_ATTEMPTS} pings without a "pong" reset is removed
 * from the registries and closed instead of pinged. Each connection is handled independently, so a
 * failure on one session does not stop heartbeats to the others.
 */
@Scheduled(fixedDelay = 10000)
public void sendHeartBeat() {
    CopyOnWriteArraySet<DevMessageHandleController> webSocketSet = DevMessageHandleController.getWebSocketSet();
    if (webSocketSet == null) {
        return;
    }
    logger.info("连接数量:" + webSocketSet.size());
    if (webSocketSet.isEmpty()) {
        return;
    }
    logger.info("定时发送心跳");
    // CopyOnWriteArraySet iterates over a snapshot, so removing entries inside the loop is safe.
    for (DevMessageHandleController obj : webSocketSet) {
        try {
            sendHeartBeatTo(webSocketSet, obj);
        } catch (Exception e) {
            // Isolate per-session failures (e.g. an async send already pending on this session).
            logger.error("发送心跳 error:" + e.getMessage(), e);
        }
    }
}

/** Pings one connection, or deregisters and closes it once it has too many unanswered pings. */
private void sendHeartBeatTo(CopyOnWriteArraySet<DevMessageHandleController> webSocketSet,
                             DevMessageHandleController obj) {
    Session session = obj.getSession();
    logger.info("sessionId:" + session.getId() + " 心跳ping发送次数:" + obj.getHeartbeatAttempts().get());
    if (obj.getHeartbeatAttempts().get() >= MAX_HEARTBEAT_ATTEMPTS) {
        // Deregister explicitly: if the session is already closed, close() may not trigger onClose,
        // and the dead entry would otherwise be revisited on every run.
        webSocketSet.remove(obj);
        String uuid = obj.getUuid();
        if (uuid != null) {
            // Conditional remove so a newer connection for the same uuid is left untouched.
            DevMessageHandleController.getWebSocketMap().remove(uuid, obj);
        }
        try {
            session.close();
        } catch (IOException e) {
            logger.error("session close error:" + e.getMessage(), e);
        }
        return;
    }
    obj.getHeartbeatAttempts().incrementAndGet();
    if (session.isOpen()) {
        session.getAsyncRemote().sendText("ping");
    }
}