websocket服务端增加定时发送心跳机制
@ServerEndpoint ( value = "/websocket/{uuid}" )
@Component
public class DevMessageHandleController {
private static final Logger logger = LoggerFactory . getLogger ( DevMessageHandleController . class ) ;
public static CopyOnWriteArraySet < DevMessageHandleController > webSocketSet = new CopyOnWriteArraySet < > ( ) ;
private static ConcurrentHashMap < String , DevMessageHandleController > webSocketMap = new ConcurrentHashMap < > ( ) ;
private Session session;
private String uuid;
private AtomicInteger heartbeatAttempts;
@OnOpen
public void onOpen ( @PathParam ( "uuid" ) String uuid, Session session) {
logger. info ( "uuid: {}, sessionId: {}" , uuid, session. getId ( ) ) ;
try {
if ( webSocketMap. containsKey ( uuid) ) {
webSocketMap. get ( uuid) . session. close ( ) ;
webSocketSet. remove ( webSocketMap. get ( uuid) ) ;
}
this . session = session;
this . uuid = uuid;
heartbeatAttempts = new AtomicInteger ( 0 ) ;
webSocketSet. add ( this ) ;
webSocketMap. put ( uuid, this ) ;
} catch ( Exception e) {
logger. error ( "onOpen error:" + e. getMessage ( ) ) ;
}
}
@OnClose
public void onClose ( @PathParam ( "uuid" ) String uuid, Session session) {
logger. info ( "会话关闭" ) ;
webSocketSet. remove ( this ) ;
webSocketMap. remove ( uuid) ;
}
@OnMessage
public void onMessage ( String message, Session session) {
logger. info ( "Message from client: " + message) ;
if ( "pong" . equals ( message) ) {
this . heartbeatAttempts. set ( 0 ) ;
System . out. println ( "Received pong from: " + session. getId ( ) ) ;
}
}
@OnError
public void onError ( Session session, Throwable error) {
logger. error ( "发生错误 session:" + session. getId ( ) + ",error:" + error) ;
try {
session. close ( ) ;
webSocketSet. remove ( this ) ;
webSocketMap. remove ( this . uuid) ;
} catch ( IOException e) {
logger. error ( "onError error:" + e. getMessage ( ) ) ;
}
}
public void sendMessage ( Session session, String msg) {
logger. info ( "发送消息" ) ;
try {
if ( session. isOpen ( ) ) {
session. getAsyncRemote ( ) . sendText ( msg) ;
} else {
session. close ( ) ;
webSocketSet. remove ( this ) ;
webSocketMap. remove ( this . uuid) ;
}
} catch ( IOException e) {
e. printStackTrace ( ) ;
}
}
public static CopyOnWriteArraySet < DevMessageHandleController > getWebSocketSet ( ) {
return webSocketSet;
}
public static void setWebSocketSet ( CopyOnWriteArraySet < DevMessageHandleController > webSocketSet) {
DevMessageHandleController . webSocketSet = webSocketSet;
}
public static ConcurrentHashMap < String , DevMessageHandleController > getWebSocketMap ( ) {
return webSocketMap;
}
public static void setWebSocketMap ( ConcurrentHashMap < String , DevMessageHandleController > webSocketMap) {
DevMessageHandleController . webSocketMap = webSocketMap;
}
public Session getSession ( ) {
return session;
}
public void setSession ( Session session) {
this . session = session;
}
public String getUuid ( ) {
return uuid;
}
public void setUuid ( String uuid) {
this . uuid = uuid;
}
public AtomicInteger getHeartbeatAttempts ( ) {
return heartbeatAttempts;
}
public void setHeartbeatAttempts ( AtomicInteger heartbeatAttempts) {
this . heartbeatAttempts = heartbeatAttempts;
}
}
每隔 10 秒向所有已连接的客户端发送一次心跳（ping）；客户端回复 pong 后计数清零，连续 3 次 ping 未收到 pong 的会话会在下一轮被关闭：
/**
 * Number of consecutive unanswered pings after which a session is closed, on the following
 * heartbeat run.
 */
private static final int MAX_HEARTBEAT_ATTEMPTS = 3;

/**
 * Sends a {@code "ping"} text frame to every registered WebSocket connection.
 *
 * <p>Runs with a fixed delay of 10 s, measured from the end of the previous run. Each ping
 * increments the connection's heartbeat counter; the endpoint resets the counter when the client
 * answers {@code "pong"}. A connection whose counter has reached {@link #MAX_HEARTBEAT_ATTEMPTS}
 * is closed and removed from the registries instead of being pinged.
 *
 * <p>Failures are handled per connection, so one broken session cannot stop heartbeats for the
 * others.
 */
@Scheduled(fixedDelay = 10000)
public void sendHeartBeat() {
    CopyOnWriteArraySet<DevMessageHandleController> connections =
            DevMessageHandleController.getWebSocketSet();
    int count = connections == null ? 0 : connections.size();
    logger.info("连接数量:{}", count);
    if (count == 0) {
        return;
    }
    logger.info("定时发送心跳");
    // Iterating a CopyOnWriteArraySet walks a snapshot, so removals during the loop are safe.
    for (DevMessageHandleController connection : connections) {
        try {
            heartbeat(connection);
        } catch (Exception e) {
            logger.error("发送心跳 error:", e);
        }
    }
}

/** Pings one connection, or closes it if too many pings went unanswered. */
private void heartbeat(DevMessageHandleController connection) {
    Session session = connection.getSession();
    // Read the counter once so the log line and the decision agree.
    int sent = connection.getHeartbeatAttempts().get();
    logger.info("sessionId:{} 心跳ping发送次数:{}", session.getId(), sent);
    if (sent >= MAX_HEARTBEAT_ATTEMPTS) {
        closeUnresponsive(connection, session);
        return;
    }
    connection.getHeartbeatAttempts().incrementAndGet();
    if (session.isOpen()) {
        session.getAsyncRemote().sendText("ping");
    }
}

/**
 * Closes an unresponsive session and always removes it from the registries.
 *
 * <p>Removal happens even if close() fails, so a dead connection is not retried forever.
 */
private void closeUnresponsive(DevMessageHandleController connection, Session session) {
    try {
        session.close();
    } catch (IOException e) {
        logger.error("session close error:", e);
    } finally {
        DevMessageHandleController.getWebSocketSet().remove(connection);
        String uuid = connection.getUuid();
        if (uuid != null) {
            // Conditional remove: never evict a newer connection registered under the same uuid.
            DevMessageHandleController.getWebSocketMap().remove(uuid, connection);
        }
    }
}