package cc.mrbird.febs.websocket; import cn.hutool.core.util.StrUtil; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; /** * @date 2020-09-01 **/ public class WsSessionManager { private static final ConcurrentHashMap SESSIONS = new ConcurrentHashMap<>(); public static void add(String key, WebSocketSession session) { SESSIONS.put(key, session); } public static WebSocketSession remove(String key) { return SESSIONS.remove(key); } public static void removeAndClose(String key) { WebSocketSession session = remove(key); if (session != null) { try { // 关闭连接 session.close(); } catch (IOException e) { // todo: 关闭出现异常处理 e.printStackTrace(); } } } public static WebSocketSession get(String key) { // 获得 session return SESSIONS.get(key); } /** * 发送消息 * * @param key 用户手机号 * @param msg 消息 */ public static void sendMsgToOne(String key, String msg) { TextMessage textMessage = new TextMessage(msg); try { if (SESSIONS.containsKey(key)) { SESSIONS.get(key).sendMessage(textMessage); } } catch (IOException e) { e.printStackTrace(); } } /** * 批量发送 * * @param keys 手机号集合, 逗号隔开 * @param msg 消息 */ public static void sendMsgToMany(String keys, String msg) { TextMessage textMessage = new TextMessage(msg); List keyList = StrUtil.splitTrim(keys, ","); for (Map.Entry entry : SESSIONS.entrySet()) { if (keyList.contains(entry.getKey())) { try { entry.getValue().sendMessage(textMessage); } catch (IOException e) { e.printStackTrace(); } } } } }