Skip to content

FastAPI WebSocket实战:从实时聊天室到生产级部署

WebSocket是实时Web应用的基石。当你的应用需要推送通知、实时聊天或协作编辑时,WebSocket提供了比HTTP轮询更高效的双向通信方案。本文将带你从WebSocket基础到生产级部署,全面掌握FastAPI中的WebSocket开发。

引言:为什么需要WebSocket?

2023年,全球实时通信市场规模已达200亿美元。我曾在参与一个在线协作平台项目时,亲眼见证了从HTTP轮询切换到WebSocket带来的变化:

  • API请求量从每分钟10万次降至5000次
  • 服务器负载降低60%
  • 消息延迟从3秒降至50毫秒

WebSocket协议解决了HTTP在实时通信中的根本问题:持久连接和双向通信。相比HTTP的请求-响应模式,WebSocket允许服务器主动推送数据,极大地提升了实时应用的性能和用户体验。

1. WebSocket协议基础

HTTP vs WebSocket:握手过程对比

python
# HTTP请求-响应(每次都需要建立连接)
import aiohttp

async def http_polling():
    """HTTP轮询示例(低效但简单)"""
    async with aiohttp.ClientSession() as session:
        while True:
            # 每次请求都需要建立新连接
            async with session.get('http://api.example.com/messages') as resp:
                messages = await resp.json()
                process_messages(messages)
            await asyncio.sleep(1)  # 延迟和浪费资源

# WebSocket(一次连接,双向通信)
async def websocket_client():
    """WebSocket客户端示例"""
    async with websockets.connect('ws://api.example.com/ws') as websocket:
        # 一次握手,持久连接
        await websocket.send('Hello Server!')

        # 可以随时接收服务器推送
        async for message in websocket:
            process_message(message)

        # 也可以随时发送消息
        await websocket.send('Another message')

WebSocket握手协议详解

python
# WebSocket握手过程(RFC 6455)
"""
客户端请求(HTTP升级):
GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
Sec-WebSocket-Protocol: chat, superchat

服务器响应:
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
"""

# 手动实现WebSocket握手验证
import base64
import hashlib

def generate_websocket_accept(key: str) -> str:
    """生成WebSocket Accept头"""
    GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    accept_key = key + GUID
    accept_key = hashlib.sha1(accept_key.encode()).digest()
    return base64.b64encode(accept_key).decode()

# 验证示例
client_key = "dGhlIHNhbXBsZSBub25jZQ=="
server_accept = generate_websocket_accept(client_key)
print(f"Sec-WebSocket-Accept: {server_accept}")
# 输出: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=

WebSocket帧结构

python
# WebSocket数据帧结构(简化版)
"""
 0                   1                   2                   3
 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
|N|V|V|V|       |S|             |   (if payload len==126/127)   |
| |1|2|3|       |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+
"""

# WebSocket操作码(opcode)枚举
class WebSocketOpcode:
    """WebSocket帧操作码"""
    CONTINUATION = 0x0  # 延续帧
    TEXT = 0x1          # 文本帧
    BINARY = 0x2        # 二进制帧
    CLOSE = 0x8         # 关闭连接
    PING = 0x9          # Ping帧
    PONG = 0xA          # Pong帧

# WebSocket关闭状态码
class WebSocketCloseCode:
    """WebSocket关闭状态码"""
    NORMAL_CLOSURE = 1000      # 正常关闭
    GOING_AWAY = 1001          # 服务端/客户端正在离开
    PROTOCOL_ERROR = 1002      # 协议错误
    UNSUPPORTED_DATA = 1003    # 不支持的数据类型
    NO_STATUS_RCVD = 1005      # 未收到状态码
    ABNORMAL_CLOSURE = 1006    # 异常关闭
    INVALID_PAYLOAD = 1007     # 无效负载数据
    POLICY_VIOLATION = 1008    # 策略违规
    MESSAGE_TOO_BIG = 1009     # 消息太大
    MANDATORY_EXT = 1010       # 需要扩展
    INTERNAL_ERROR = 1011      # 服务器内部错误
    SERVICE_RESTART = 1012     # 服务重启
    TRY_AGAIN_LATER = 1013     # 稍后重试
    TLS_HANDSHAKE = 1015       # TLS握手失败

2. FastAPI WebSocket端点

基础WebSocket端点

python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import uvicorn

app = FastAPI()

# 简单HTML页面用于测试
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>WebSocket Test</title>
    </head>
    <body>
        <h1>WebSocket Test</h1>
        <form action="" onsubmit="sendMessage(event)">
            <input type="text" id="messageText" autocomplete="off"/>
            <button>Send</button>
        </form>
        <ul id='messages'>
        </ul>
        <script>
            var ws = new WebSocket("ws://localhost:8000/ws");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages')
                var message = document.createElement('li')
                var content = document.createTextNode(event.data)
                message.appendChild(content)
                messages.appendChild(message)
            };
            function sendMessage(event) {
                var input = document.getElementById("messageText")
                ws.send(input.value)
                input.value = ''
                event.preventDefault()
            }
        </script>
    </body>
</html>
"""

@app.get("/")
async def get():
    return HTMLResponse(html)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """基础WebSocket端点"""

    # 1. 接受WebSocket连接
    await websocket.accept()

    try:
        # 2. 持续监听消息
        while True:
            # 接收文本消息
            data = await websocket.receive_text()

            # 处理消息
            response = f"Echo: {data}"

            # 发送响应
            await websocket.send_text(response)

    except WebSocketDisconnect:
        # 3. 客户端断开连接
        print("Client disconnected")

    except Exception as e:
        # 4. 其他异常处理
        print(f"WebSocket error: {e}")
        await websocket.close(code=1011)  # 内部错误

# 运行服务器
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

高级WebSocket端点

python
from fastapi import FastAPI, WebSocket, Query, Cookie, status
from fastapi.exceptions import WebSocketException
from typing import Optional
import json

app = FastAPI()

@app.websocket("/ws/{client_id}")
async def websocket_advanced(
    websocket: WebSocket,
    client_id: int,
    token: Optional[str] = Query(None),
    session_id: Optional[str] = Cookie(None)
):
    """高级WebSocket端点,支持认证和参数"""

    # 1. 认证检查
    if not token or not verify_token(token):
        raise WebSocketException(
            code=status.WS_1008_POLICY_VIOLATION,
            reason="Authentication required"
        )

    # 2. 获取用户信息
    user = await get_user_from_token(token)
    if not user:
        raise WebSocketException(
            code=status.WS_1008_POLICY_VIOLATION,
            reason="Invalid user"
        )

    # 3. 设置WebSocket子协议(可选)
    subprotocols = websocket.headers.get("sec-websocket-protocol", "")
    if "chat.v1" in subprotocols:
        await websocket.accept(subprotocol="chat.v1")
    else:
        await websocket.accept()

    # 4. 设置连接元数据
    websocket.state.user_id = user.id
    websocket.state.client_id = client_id
    websocket.state.session_id = session_id

    try:
        # 5. 处理不同类型的消息
        while True:
            # 等待消息
            message = await websocket.receive()

            if message["type"] == "websocket.receive":
                # 判断消息类型
                if "text" in message:
                    await handle_text_message(websocket, message["text"])
                elif "bytes" in message:
                    await handle_binary_message(websocket, message["bytes"])

            elif message["type"] == "websocket.disconnect":
                break

    except WebSocketDisconnect as e:
        print(f"Client {client_id} disconnected: {e.code}")

    except Exception as e:
        print(f"WebSocket error for client {client_id}: {e}")
        await websocket.close(code=1011, reason=str(e))

async def handle_text_message(websocket: WebSocket, text: str):
    """处理文本消息"""
    try:
        # 解析JSON消息
        data = json.loads(text)
        message_type = data.get("type")

        if message_type == "chat":
            await handle_chat_message(websocket, data)
        elif message_type == "ping":
            await handle_ping(websocket, data)
        elif message_type == "subscribe":
            await handle_subscribe(websocket, data)
        else:
            await websocket.send_json({
                "type": "error",
                "message": f"Unknown message type: {message_type}"
            })

    except json.JSONDecodeError:
        await websocket.send_text(f"Echo: {text}")

async def handle_binary_message(websocket: WebSocket, binary: bytes):
    """处理二进制消息"""
    # 例如:处理文件上传、图片等
    await websocket.send_bytes(binary[::-1])  # 反转字节并返回

async def handle_chat_message(websocket: WebSocket, data: dict):
    """处理聊天消息"""
    message = data.get("message", "")
    user_id = websocket.state.user_id

    response = {
        "type": "chat",
        "user_id": user_id,
        "message": message,
        "timestamp": get_timestamp()
    }

    await websocket.send_json(response)

async def handle_ping(websocket: WebSocket, data: dict):
    """处理Ping消息"""
    await websocket.send_json({
        "type": "pong",
        "timestamp": get_timestamp()
    })

async def handle_subscribe(websocket: WebSocket, data: dict):
    """处理订阅消息"""
    channel = data.get("channel")
    websocket.state.subscribed_channels.add(channel)

    await websocket.send_json({
        "type": "subscribed",
        "channel": channel,
        "timestamp": get_timestamp()
    })

# 辅助函数
async def verify_token(token: str) -> bool:
    """验证令牌"""
    # 实际项目中应该实现真正的验证逻辑
    return token == "valid_token"

async def get_user_from_token(token: str):
    """从令牌获取用户信息"""
    # 实际项目中应该查询数据库
    return {"id": 1, "name": "Test User"}

def get_timestamp() -> int:
    """获取时间戳"""
    import time
    return int(time.time() * 1000)

WebSocket异常处理

python
from fastapi import WebSocket, WebSocketException
from starlette import status

class WebSocketErrorHandler:
    """WebSocket异常处理器"""

    @staticmethod
    async def handle_websocket(websocket: WebSocket):
        """包装WebSocket处理,提供统一的异常处理"""
        try:
            await websocket.accept()

            async for message in websocket.iter_text():
                try:
                    await WebSocketErrorHandler.process_message(websocket, message)
                except ProcessingError as e:
                    await websocket.send_json({
                        "error": str(e),
                        "code": "PROCESSING_ERROR"
                    })
                except ValidationError as e:
                    await websocket.send_json({
                        "error": str(e),
                        "code": "VALIDATION_ERROR"
                    })

        except WebSocketException as e:
            print(f"WebSocket异常: {e}")
            # 记录异常但不需要重新抛出

        except Exception as e:
            print(f"未预期的异常: {e}")
            # 发送关闭帧
            await websocket.close(
                code=status.WS_1011_INTERNAL_ERROR,
                reason="Internal server error"
            )

    @staticmethod
    async def process_message(websocket: WebSocket, message: str):
        """处理消息,可能抛出各种异常"""
        import json

        # 解析JSON
        data = json.loads(message)

        # 验证消息格式
        if "type" not in data:
            raise ValidationError("消息必须包含type字段")

        # 处理消息
        if data["type"] == "dangerous":
            raise ProcessingError("危险操作被拒绝")

        # 正常处理
        await websocket.send_text(f"Processed: {data}")

# 自定义异常
class ProcessingError(Exception):
    pass

class ValidationError(Exception):
    pass

# 使用异常处理器
@app.websocket("/ws/safe")
async def safe_websocket_endpoint(websocket: WebSocket):
    """安全的WebSocket端点"""
    handler = WebSocketErrorHandler()
    await handler.handle_websocket(websocket)

3. 实时聊天室实现

完整聊天室后端

python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from typing import Dict, List, Set
import asyncio
import json
import uuid
from datetime import datetime
from dataclasses import dataclass, asdict
from enum import Enum

app = FastAPI()

# 数据模型
class MessageType(Enum):
    TEXT = "text"
    IMAGE = "image"
    FILE = "file"
    SYSTEM = "system"

@dataclass
class ChatMessage:
    id: str
    type: MessageType
    content: str
    sender_id: str
    sender_name: str
    room_id: str
    timestamp: datetime
    reply_to: str = None

    def to_dict(self):
        data = asdict(self)
        data["type"] = self.type.value
        data["timestamp"] = self.timestamp.isoformat()
        return data

@dataclass
class User:
    id: str
    username: str
    websocket: WebSocket
    current_room: str = None
    joined_at: datetime = None

    def to_dict(self):
        return {
            "id": self.id,
            "username": self.username,
            "current_room": self.current_room,
            "joined_at": self.joined_at.isoformat() if self.joined_at else None
        }

class ChatRoomManager:
    """聊天室管理器"""

    def __init__(self):
        # 用户管理
        self.users: Dict[str, User] = {}  # user_id -> User
        self.username_to_id: Dict[str, str] = {}  # username -> user_id

        # 房间管理
        self.rooms: Dict[str, Set[str]] = {}  # room_id -> set of user_ids
        self.room_messages: Dict[str, List[ChatMessage]] = {}  # room_id -> messages

        # 消息历史限制
        self.max_messages_per_room = 1000

        # 创建默认房间
        self.create_room("general", "General Chat")
        self.create_room("random", "Random Talk")

    def create_room(self, room_id: str, room_name: str = None):
        """创建新房间"""
        if room_id not in self.rooms:
            self.rooms[room_id] = set()
            self.room_messages[room_id] = []
            self.room_names[room_id] = room_name or room_id.capitalize()

    async def register_user(self, websocket: WebSocket, username: str) -> User:
        """注册新用户"""

        # 检查用户名是否已存在
        if username in self.username_to_id:
            # 用户名已存在,生成唯一用户名
            username = f"{username}_{len(self.users)}"

        # 创建用户
        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            username=username,
            websocket=websocket,
            joined_at=datetime.now()
        )

        # 存储用户
        self.users[user_id] = user
        self.username_to_id[username] = user_id

        # 发送欢迎消息
        await self.send_system_message(
            user_id,
            f"欢迎 {username} 加入聊天室!"
        )

        # 广播用户上线
        await self.broadcast_user_status(user_id, "online")

        return user

    async def handle_message(self, user_id: str, message_data: dict):
        """处理用户消息"""

        user = self.users.get(user_id)
        if not user or not user.current_room:
            return

        room_id = user.current_room

        # 创建消息
        message = ChatMessage(
            id=str(uuid.uuid4()),
            type=MessageType(message_data.get("type", "text")),
            content=message_data["content"],
            sender_id=user_id,
            sender_name=user.username,
            room_id=room_id,
            timestamp=datetime.now(),
            reply_to=message_data.get("reply_to")
        )

        # 保存消息
        self.room_messages[room_id].append(message)

        # 限制消息历史
        if len(self.room_messages[room_id]) > self.max_messages_per_room:
            self.room_messages[room_id] = self.room_messages[room_id][-self.max_messages_per_room:]

        # 广播消息
        await self.broadcast_to_room(room_id, {
            "type": "message",
            "message": message.to_dict()
        }, exclude_user_id=user_id)

        # 发送确认给发送者
        await user.websocket.send_json({
            "type": "message_sent",
            "message_id": message.id,
            "timestamp": message.timestamp.isoformat()
        })

    async def join_room(self, user_id: str, room_id: str):
        """用户加入房间"""

        user = self.users.get(user_id)
        if not user:
            return

        # 离开当前房间
        if user.current_room:
            await self.leave_room(user_id, user.current_room)

        # 加入新房间
        if room_id not in self.rooms:
            self.create_room(room_id)

        self.rooms[room_id].add(user_id)
        user.current_room = room_id

        # 发送房间信息
        await user.websocket.send_json({
            "type": "room_joined",
            "room_id": room_id,
            "room_name": self.room_names.get(room_id, room_id),
            "user_count": len(self.rooms[room_id]),
            "recent_messages": [
                msg.to_dict()
                for msg in self.room_messages[room_id][-50:]  # 最近50条消息
            ]
        })

        # 广播用户加入
        await self.send_system_message(
            room_id,
            f"{user.username} 加入了房间"
        )

    async def leave_room(self, user_id: str, room_id: str):
        """用户离开房间"""

        if room_id in self.rooms and user_id in self.rooms[room_id]:
            self.rooms[room_id].remove(user_id)

            user = self.users.get(user_id)
            if user:
                user.current_room = None

            # 广播用户离开
            await self.send_system_message(
                room_id,
                f"{user.username if user else '某用户'} 离开了房间"
            )

    async def send_system_message(self, target: str, content: str):
        """发送系统消息"""

        message = ChatMessage(
            id=str(uuid.uuid4()),
            type=MessageType.SYSTEM,
            content=content,
            sender_id="system",
            sender_name="System",
            room_id=target if target in self.rooms else None,
            timestamp=datetime.now()
        )

        if target in self.rooms:
            # 发送给房间
            self.room_messages[target].append(message)
            await self.broadcast_to_room(target, {
                "type": "system_message",
                "message": message.to_dict()
            })
        elif target in self.users:
            # 发送给用户
            user = self.users[target]
            await user.websocket.send_json({
                "type": "system_message",
                "message": message.to_dict()
            })

    async def broadcast_to_room(
        self,
        room_id: str,
        data: dict,
        exclude_user_id: str = None
    ):
        """广播消息到房间"""

        if room_id not in self.rooms:
            return

        tasks = []
        for user_id in self.rooms[room_id]:
            if user_id == exclude_user_id:
                continue

            user = self.users.get(user_id)
            if user and user.websocket:
                tasks.append(
                    user.websocket.send_json(data)
                )

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def broadcast_user_status(self, user_id: str, status: str):
        """广播用户状态变化"""

        user = self.users.get(user_id)
        if not user:
            return

        # 广播给所有房间的用户
        for room_id, user_ids in self.rooms.items():
            if user_id in user_ids:
                await self.broadcast_to_room(room_id, {
                    "type": "user_status",
                    "user_id": user_id,
                    "username": user.username,
                    "status": status,
                    "timestamp": datetime.now().isoformat()
                }, exclude_user_id=user_id)

    async def get_online_users(self, room_id: str = None) -> List[dict]:
        """获取在线用户"""

        if room_id:
            # 房间内的在线用户
            user_ids = self.rooms.get(room_id, set())
        else:
            # 所有在线用户
            user_ids = set(self.users.keys())

        return [
            self.users[uid].to_dict()
            for uid in user_ids
            if uid in self.users
        ]

    async def disconnect_user(self, user_id: str):
        """用户断开连接"""

        user = self.users.pop(user_id, None)
        if user:
            # 从用户名映射中移除
            if user.username in self.username_to_id:
                del self.username_to_id[user.username]

            # 离开所有房间
            for room_id in list(self.rooms.keys()):
                if user_id in self.rooms[room_id]:
                    self.rooms[room_id].remove(user_id)

                    # 广播用户离线
                    await self.send_system_message(
                        room_id,
                        f"{user.username} 已离线"
                    )

            # 关闭WebSocket连接
            try:
                await user.websocket.close()
            except:
                pass

# 创建聊天室管理器实例
chat_manager = ChatRoomManager()

@app.websocket("/ws/chat")
async def chat_websocket(websocket: WebSocket):
    """聊天室WebSocket端点"""

    await websocket.accept()
    user = None

    try:
        # 1. 用户注册
        init_data = await websocket.receive_json()
        username = init_data.get("username", f"User_{uuid.uuid4().hex[:8]}")

        user = await chat_manager.register_user(websocket, username)

        # 2. 发送初始数据
        await websocket.send_json({
            "type": "init",
            "user_id": user.id,
            "username": user.username,
            "available_rooms": list(chat_manager.rooms.keys()),
            "online_users": await chat_manager.get_online_users()
        })

        # 3. 加入默认房间
        await chat_manager.join_room(user.id, "general")

        # 4. 消息处理循环
        while True:
            data = await websocket.receive_json()
            message_type = data.get("type")

            if message_type == "chat_message":
                await chat_manager.handle_message(user.id, data)

            elif message_type == "join_room":
                room_id = data.get("room_id")
                if room_id:
                    await chat_manager.join_room(user.id, room_id)

            elif message_type == "leave_room":
                room_id = data.get("room_id") or user.current_room
                if room_id:
                    await chat_manager.leave_room(user.id, room_id)

            elif message_type == "list_users":
                room_id = data.get("room_id")
                users = await chat_manager.get_online_users(room_id)
                await websocket.send_json({
                    "type": "user_list",
                    "users": users
                })

            elif message_type == "typing":
                # 广播用户正在输入
                if user.current_room:
                    await chat_manager.broadcast_to_room(
                        user.current_room,
                        {
                            "type": "user_typing",
                            "user_id": user.id,
                            "username": user.username,
                            "is_typing": data.get("is_typing", True)
                        },
                        exclude_user_id=user.id
                    )

            elif message_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": datetime.now().isoformat()
                })

    except WebSocketDisconnect:
        print(f"用户断开连接: {user.username if user else 'Unknown'}")

    except Exception as e:
        print(f"聊天室错误: {e}")

    finally:
        # 清理用户
        if user:
            await chat_manager.disconnect_user(user.id)

# 提供前端页面
@app.get("/chat")
async def chat_page():
    """聊天室前端页面"""
    html = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>FastAPI 聊天室</title>
        <style>
            /* 样式代码... */
        </style>
    </head>
    <body>
        <div id="app">
            <!-- 聊天室UI代码... -->
        </div>
        <script>
            // JavaScript代码...
        </script>
    </body>
    </html>
    """
    return HTMLResponse(html)

4. 连接管理与状态维护

连接池管理器

python
from typing import Dict, Set, Optional
import asyncio
from datetime import datetime, timedelta
import logging

logger = logging.getLogger(__name__)

class ConnectionPool:
    """WebSocket连接池管理器"""

    def __init__(self, max_connections: int = 10000):
        self.max_connections = max_connections
        self.connections: Dict[str, WebSocket] = {}  # connection_id -> websocket
        self.user_connections: Dict[str, Set[str]] = {}  # user_id -> set of connection_ids
        self.connection_users: Dict[str, str] = {}  # connection_id -> user_id

        # 连接元数据
        self.connection_metadata: Dict[str, dict] = {}  # connection_id -> metadata

        # 心跳跟踪
        self.last_heartbeat: Dict[str, datetime] = {}  # connection_id -> last_heartbeat

        # 连接统计
        self.connection_stats = {
            "total_connections": 0,
            "active_connections": 0,
            "max_concurrent": 0,
            "total_disconnections": 0,
            "connection_errors": 0,
        }

    async def add_connection(
        self,
        websocket: WebSocket,
        user_id: Optional[str] = None,
        metadata: Optional[dict] = None
    ) -> str:
        """添加新连接"""

        # 检查连接限制
        if len(self.connections) >= self.max_connections:
            raise ConnectionLimitError(f"达到最大连接数限制: {self.max_connections}")

        # 生成连接ID
        connection_id = self._generate_connection_id()

        # 存储连接
        self.connections[connection_id] = websocket
        self.connection_metadata[connection_id] = metadata or {}
        self.last_heartbeat[connection_id] = datetime.now()

        # 关联用户
        if user_id:
            if user_id not in self.user_connections:
                self.user_connections[user_id] = set()
            self.user_connections[user_id].add(connection_id)
            self.connection_users[connection_id] = user_id

        # 更新统计
        self.connection_stats["total_connections"] += 1
        self.connection_stats["active_connections"] = len(self.connections)
        self.connection_stats["max_concurrent"] = max(
            self.connection_stats["max_concurrent"],
            self.connection_stats["active_connections"]
        )

        logger.info(f"新连接: {connection_id}, 用户: {user_id}")
        return connection_id

    async def remove_connection(self, connection_id: str):
        """移除连接"""

        if connection_id not in self.connections:
            return

        # 获取用户ID
        user_id = self.connection_users.get(connection_id)

        # 清理连接
        del self.connections[connection_id]

        if connection_id in self.connection_metadata:
            del self.connection_metadata[connection_id]

        if connection_id in self.last_heartbeat:
            del self.last_heartbeat[connection_id]

        # 清理用户关联
        if user_id and user_id in self.user_connections:
            self.user_connections[user_id].discard(connection_id)
            if not self.user_connections[user_id]:
                del self.user_connections[user_id]

        if connection_id in self.connection_users:
            del self.connection_users[connection_id]

        # 更新统计
        self.connection_stats["active_connections"] = len(self.connections)
        self.connection_stats["total_disconnections"] += 1

        logger.info(f"连接断开: {connection_id}, 用户: {user_id}")

    async def update_heartbeat(self, connection_id: str):
        """更新心跳时间"""
        if connection_id in self.connections:
            self.last_heartbeat[connection_id] = datetime.now()

    async def check_connection_health(self, timeout_seconds: int = 60) -> Set[str]:
        """检查连接健康状态,返回超时的连接ID"""

        now = datetime.now()
        timeout_connections = set()

        for connection_id, last_beat in self.last_heartbeat.items():
            if (now - last_beat).total_seconds() > timeout_seconds:
                timeout_connections.add(connection_id)

        return timeout_connections

    async def cleanup_dead_connections(self, timeout_seconds: int = 60):
        """清理死连接"""

        dead_connections = await self.check_connection_health(timeout_seconds)

        for connection_id in dead_connections:
            logger.warning(f"清理死连接: {connection_id}")
            await self.remove_connection(connection_id)

        return len(dead_connections)

    async def send_to_connection(self, connection_id: str, message: dict) -> bool:
        """发送消息到指定连接"""

        if connection_id not in self.connections:
            return False

        websocket = self.connections[connection_id]

        try:
            import json
            await websocket.send_json(message)
            return True

        except Exception as e:
            logger.error(f"发送消息到连接 {connection_id} 失败: {e}")
            await self.remove_connection(connection_id)
            return False

    async def send_to_user(self, user_id: str, message: dict) -> int:
        """发送消息到用户的所有连接"""

        if user_id not in self.user_connections:
            return 0

        sent_count = 0
        for connection_id in list(self.user_connections[user_id]):
            success = await self.send_to_connection(connection_id, message)
            if success:
                sent_count += 1

        return sent_count

    async def broadcast(self, message: dict, exclude_connection_ids: Set[str] = None) -> int:
        """广播消息到所有连接"""

        sent_count = 0
        exclude_connection_ids = exclude_connection_ids or set()

        for connection_id in list(self.connections.keys()):
            if connection_id in exclude_connection_ids:
                continue

            success = await self.send_to_connection(connection_id, message)
            if success:
                sent_count += 1

        return sent_count

    def get_connection_info(self, connection_id: str) -> Optional[dict]:
        """获取连接信息"""

        if connection_id not in self.connections:
            return None

        user_id = self.connection_users.get(connection_id)
        metadata = self.connection_metadata.get(connection_id, {})
        last_heartbeat = self.last_heartbeat.get(connection_id)

        return {
            "connection_id": connection_id,
            "user_id": user_id,
            "metadata": metadata,
            "last_heartbeat": last_heartbeat.isoformat() if last_heartbeat else None,
            "connected_at": metadata.get("connected_at"),
            "user_agent": metadata.get("user_agent"),
            "ip_address": metadata.get("ip_address"),
        }

    def get_user_connections(self, user_id: str) -> List[dict]:
        """获取用户的所有连接信息"""

        if user_id not in self.user_connections:
            return []

        connections = []
        for connection_id in self.user_connections[user_id]:
            info = self.get_connection_info(connection_id)
            if info:
                connections.append(info)

        return connections

    def get_stats(self) -> dict:
        """获取统计信息"""

        now = datetime.now()
        active_count = len(self.connections)

        # 计算平均连接时长
        total_duration = 0
        for metadata in self.connection_metadata.values():
            connected_at = metadata.get("connected_at")
            if connected_at:
                try:
                    connected_time = datetime.fromisoformat(connected_at)
                    total_duration += (now - connected_time).total_seconds()
                except:
                    pass

        avg_duration = total_duration / max(active_count, 1)

        stats = self.connection_stats.copy()
        stats.update({
            "timestamp": now.isoformat(),
            "average_connection_duration": avg_duration,
            "unique_users": len(self.user_connections),
            "heartbeat_timeouts": len(await self.check_connection_health(60)),
        })

        return stats

    def _generate_connection_id(self) -> str:
        """生成连接ID"""
        import uuid
        return f"conn_{uuid.uuid4().hex}"

class ConnectionLimitError(Exception):
    """连接限制异常"""
    pass

# 使用连接池
connection_pool = ConnectionPool(max_connections=5000)

# 健康检查任务
async def connection_health_check():
    """定期检查连接健康状态"""
    while True:
        try:
            cleaned = await connection_pool.cleanup_dead_connections(120)  # 2分钟超时
            if cleaned > 0:
                logger.info(f"清理了 {cleaned} 个死连接")

            # 记录统计信息
            stats = connection_pool.get_stats()
            logger.debug(f"连接池统计: {stats}")

        except Exception as e:
            logger.error(f"连接健康检查失败: {e}")

        await asyncio.sleep(30)  # 每30秒检查一次

# 启动健康检查
@app.on_event("startup")
async def startup_connection_health_check():
    asyncio.create_task(connection_health_check())

会话状态管理

python
from typing import Dict, Any, Optional
import pickle
import hashlib
from datetime import datetime, timedelta

class WebSocketSessionManager:
    """WebSocket会话管理器"""

    def __init__(self, redis_client = None):
        self.redis = redis_client
        self.local_sessions: Dict[str, dict] = {}
        self.session_timeout = timedelta(hours=24)

    async def create_session(
        self,
        connection_id: str,
        user_id: Optional[str] = None,
        initial_data: Optional[dict] = None
    ) -> str:
        """创建新会话"""

        session_id = self._generate_session_id(connection_id, user_id)
        session_data = {
            "session_id": session_id,
            "connection_id": connection_id,
            "user_id": user_id,
            "created_at": datetime.now().isoformat(),
            "last_accessed": datetime.now().isoformat(),
            "data": initial_data or {},
            "active": True,
        }

        if self.redis:
            # 存储到Redis
            await self.redis.setex(
                f"websocket_session:{session_id}",
                int(self.session_timeout.total_seconds()),
                pickle.dumps(session_data)
            )
        else:
            # 存储到内存
            self.local_sessions[session_id] = session_data

        return session_id

    async def get_session(self, session_id: str) -> Optional[dict]:
        """获取会话数据"""

        session_data = None

        if self.redis:
            # 从Redis获取
            data = await self.redis.get(f"websocket_session:{session_id}")
            if data:
                session_data = pickle.loads(data)
        else:
            # 从内存获取
            session_data = self.local_sessions.get(session_id)

        if session_data:
            # 更新访问时间
            session_data["last_accessed"] = datetime.now().isoformat()
            await self._save_session(session_id, session_data)

        return session_data

    async def update_session(self, session_id: str, updates: dict) -> bool:
        """更新会话数据"""

        session_data = await self.get_session(session_id)
        if not session_data:
            return False

        # 更新数据
        if "data" in updates:
            session_data["data"].update(updates["data"])
        else:
            session_data.update(updates)

        # 保存
        await self._save_session(session_id, session_data)
        return True

    async def delete_session(self, session_id: str) -> bool:
        """删除会话"""

        if self.redis:
            deleted = await self.redis.delete(f"websocket_session:{session_id}")
            return deleted > 0
        else:
            if session_id in self.local_sessions:
                del self.local_sessions[session_id]
                return True
            return False

    async def invalidate_user_sessions(self, user_id: str, keep_current: str = None):
        """使用户的所有会话失效(除当前会话)"""

        # 这里需要实现查找用户所有会话的逻辑
        # 实际项目中可能需要维护 user_id -> [session_ids] 的映射

        if self.redis:
            # 使用Redis扫描匹配的key
            pattern = f"websocket_session:*"
            cursor = 0
            deleted_count = 0

            while True:
                cursor, keys = await self.redis.scan(
                    cursor=cursor,
                    match=pattern,
                    count=100
                )

                for key in keys:
                    data = await self.redis.get(key)
                    if data:
                        session = pickle.loads(data)
                        if session["user_id"] == user_id:
                            if keep_current and session["session_id"] == keep_current:
                                continue
                            await self.redis.delete(key)
                            deleted_count += 1

                if cursor == 0:
                    break

            return deleted_count

        else:
            # 内存版本
            deleted_count = 0
            for session_id, session in list(self.local_sessions.items()):
                if session["user_id"] == user_id:
                    if keep_current and session_id == keep_current:
                        continue
                    del self.local_sessions[session_id]
                    deleted_count += 1

            return deleted_count

    async def cleanup_expired_sessions(self) -> int:
        """清理过期会话"""

        if self.redis:
            # Redis会自动过期
            return 0

        else:
            # 内存版本需要手动清理
            now = datetime.now()
            expired_count = 0

            for session_id, session in list(self.local_sessions.items()):
                last_accessed = datetime.fromisoformat(session["last_accessed"])
                if now - last_accessed > self.session_timeout:
                    del self.local_sessions[session_id]
                    expired_count += 1

            return expired_count

    async def _save_session(self, session_id: str, session_data: dict):
        """保存会话数据"""

        if self.redis:
            await self.redis.setex(
                f"websocket_session:{session_id}",
                int(self.session_timeout.total_seconds()),
                pickle.dumps(session_data)
            )
        else:
            self.local_sessions[session_id] = session_data

    def _generate_session_id(self, connection_id: str, user_id: Optional[str]) -> str:
        """生成会话ID"""

        components = [
            connection_id,
            user_id or "anonymous",
            datetime.now().isoformat(),
            hashlib.md5(str(id(self)).encode()).hexdigest()[:8]
        ]

        raw_id = "|".join(components)
        return hashlib.sha256(raw_id.encode()).hexdigest()[:32]

# 使用会话管理器
session_manager = WebSocketSessionManager()

@app.websocket("/ws/with-session")
async def websocket_with_session(websocket: WebSocket):
    """带会话管理的WebSocket端点"""

    await websocket.accept()

    # 提取用户信息(从查询参数、头或cookie)
    user_id = websocket.query_params.get("user_id")
    connection_id = None

    try:
        # 创建会话
        session_id = await session_manager.create_session(
            connection_id="pending",  # 稍后更新
            user_id=user_id,
            initial_data={
                "user_agent": websocket.headers.get("user-agent"),
                "ip_address": websocket.client.host if websocket.client else None,
                "connected_at": datetime.now().isoformat(),
            }
        )

        # 存储到WebSocket状态
        websocket.state.session_id = session_id

        # 主循环
        while True:
            data = await websocket.receive_json()

            # 更新会话访问时间
            await session_manager.update_session(session_id, {
                "last_accessed": datetime.now().isoformat()
            })

            # 处理消息
            if data.get("type") == "update_session":
                await session_manager.update_session(session_id, {
                    "data": data.get("data", {})
                })

            elif data.get("type") == "get_session":
                session = await session_manager.get_session(session_id)
                await websocket.send_json({
                    "type": "session_data",
                    "session": session
                })

    except WebSocketDisconnect:
        # 清理会话
        if hasattr(websocket.state, "session_id"):
            await session_manager.delete_session(websocket.state.session_id)

    except Exception as e:
        logger.error(f"WebSocket会话错误: {e}")

5. 广播消息与房间概念

房间广播系统

python
from typing import Dict, Set, List, Optional
import asyncio
from collections import defaultdict
import json

class RoomBroadcastSystem:
    """房间广播系统"""

    def __init__(self):
        # 房间成员:room_id -> set of connection_ids
        self.room_members: Dict[str, Set[str]] = defaultdict(set)

        # 连接的房间:connection_id -> set of room_ids
        self.connection_rooms: Dict[str, Set[str]] = defaultdict(set)

        # 房间元数据
        self.room_metadata: Dict[str, dict] = {}

        # 消息队列(用于处理大量消息时的流量控制)
        self.message_queues: Dict[str, asyncio.Queue] = {}

        # 广播工作者
        self.broadcast_workers: Dict[str, asyncio.Task] = {}

        # 统计
        self.stats = {
            "total_rooms": 0,
            "active_rooms": 0,
            "total_broadcasts": 0,
            "total_messages_sent": 0,
        }

    async def create_room(
        self,
        room_id: str,
        metadata: Optional[dict] = None,
        max_members: int = 1000
    ) -> bool:
        """创建房间"""

        if room_id in self.room_members:
            return False

        self.room_members[room_id] = set()
        self.room_metadata[room_id] = {
            "room_id": room_id,
            "created_at": datetime.now().isoformat(),
            "max_members": max_members,
            "is_public": metadata.get("is_public", True) if metadata else True,
            "owner": metadata.get("owner"),
            "password": metadata.get("password"),  # 实际项目中应该哈希存储
            **(metadata or {})
        }

        # 创建消息队列和工作者
        self.message_queues[room_id] = asyncio.Queue(maxsize=1000)
        self.broadcast_workers[room_id] = asyncio.create_task(
            self._room_broadcast_worker(room_id)
        )

        self.stats["total_rooms"] += 1
        self.stats["active_rooms"] += 1

        return True

    async def join_room(
        self,
        connection_id: str,
        room_id: str,
        password: Optional[str] = None
    ) -> tuple[bool, str]:
        """加入房间"""

        if room_id not in self.room_members:
            return False, "房间不存在"

        room_meta = self.room_metadata[room_id]

        # 检查密码
        if room_meta.get("password") and room_meta["password"] != password:
            return False, "密码错误"

        # 检查人数限制
        if len(self.room_members[room_id]) >= room_meta["max_members"]:
            return False, "房间已满"

        # 加入房间
        self.room_members[room_id].add(connection_id)
        self.connection_rooms[connection_id].add(room_id)

        # 发送加入通知
        await self.broadcast_to_room(room_id, {
            "type": "user_joined",
            "room_id": room_id,
            "connection_id": connection_id,
            "timestamp": datetime.now().isoformat(),
            "member_count": len(self.room_members[room_id])
        }, exclude_connection_id=connection_id)

        return True, "加入成功"

    async def leave_room(self, connection_id: str, room_id: str):
        """离开房间"""

        if room_id in self.room_members and connection_id in self.room_members[room_id]:
            self.room_members[room_id].remove(connection_id)

            if connection_id in self.connection_rooms:
                self.connection_rooms[connection_id].discard(room_id)

            # 发送离开通知
            await self.broadcast_to_room(room_id, {
                "type": "user_left",
                "room_id": room_id,
                "connection_id": connection_id,
                "timestamp": datetime.now().isoformat(),
                "member_count": len(self.room_members[room_id])
            })

            # 如果房间为空,清理
            if not self.room_members[room_id]:
                await self.delete_room(room_id)

    async def delete_room(self, room_id: str):
        """删除房间"""

        if room_id not in self.room_members:
            return

        # 通知所有成员房间关闭
        await self.broadcast_to_room(room_id, {
            "type": "room_closed",
            "room_id": room_id,
            "reason": "房间已被删除",
            "timestamp": datetime.now().isoformat()
        })

        # 移除所有成员
        for connection_id in list(self.room_members[room_id]):
            if connection_id in self.connection_rooms:
                self.connection_rooms[connection_id].discard(room_id)

        # 清理数据结构
        del self.room_members[room_id]
        if room_id in self.room_metadata:
            del self.room_metadata[room_id]

        # 停止广播工作者
        if room_id in self.broadcast_workers:
            self.broadcast_workers[room_id].cancel()
            del self.broadcast_workers[room_id]

        if room_id in self.message_queues:
            del self.message_queues[room_id]

        self.stats["active_rooms"] -= 1

    async def broadcast_to_room(
        self,
        room_id: str,
        message: dict,
        exclude_connection_id: Optional[str] = None
    ) -> int:
        """广播消息到房间(异步队列)"""

        if room_id not in self.message_queues:
            return 0

        try:
            # 将消息放入队列
            broadcast_data = {
                "message": message,
                "exclude_connection_id": exclude_connection_id,
                "timestamp": datetime.now().isoformat()
            }

            self.message_queues[room_id].put_nowait(broadcast_data)

            self.stats["total_broadcasts"] += 1
            return len(self.room_members.get(room_id, set()))

        except asyncio.QueueFull:
            logger.warning(f"房间 {room_id} 的消息队列已满")
            return 0

    async def _room_broadcast_worker(self, room_id: str):
        """房间广播工作者"""

        queue = self.message_queues[room_id]

        while True:
            try:
                # 从队列获取消息
                broadcast_data = await queue.get()
                message = broadcast_data["message"]
                exclude_connection_id = broadcast_data["exclude_connection_id"]

                # 获取房间成员
                members = self.room_members.get(room_id, set())

                # 发送消息
                sent_count = 0
                tasks = []

                for connection_id in members:
                    if connection_id == exclude_connection_id:
                        continue

                    # 获取连接(需要从连接池获取)
                    websocket = connection_pool.connections.get(connection_id)
                    if websocket:
                        task = asyncio.create_task(
                            self._send_to_websocket(websocket, message)
                        )
                        tasks.append((connection_id, task))

                # 等待所有发送完成
                for connection_id, task in tasks:
                    try:
                        success = await task
                        if success:
                            sent_count += 1
                    except Exception as e:
                        logger.error(f"向连接 {connection_id} 发送消息失败: {e}")

                self.stats["total_messages_sent"] += sent_count

                # 标记任务完成
                queue.task_done()

            except asyncio.CancelledError:
                break

            except Exception as e:
                logger.error(f"房间广播工作者错误 (房间: {room_id}): {e}")
                await asyncio.sleep(1)  # 避免快速失败循环

    async def _send_to_websocket(self, websocket: WebSocket, message: dict) -> bool:
        """发送消息到WebSocket"""

        try:
            await websocket.send_json(message)
            return True

        except Exception as e:
            logger.debug(f"发送消息失败: {e}")
            return False

    async def direct_broadcast_to_room(
        self,
        room_id: str,
        message: dict,
        exclude_connection_id: Optional[str] = None
    ) -> int:
        """直接广播到房间(同步,用于重要消息)"""

        if room_id not in self.room_members:
            return 0

        members = self.room_members[room_id]
        sent_count = 0

        for connection_id in members:
            if connection_id == exclude_connection_id:
                continue

            websocket = connection_pool.connections.get(connection_id)
            if websocket:
                try:
                    await websocket.send_json(message)
                    sent_count += 1
                except Exception as e:
                    logger.debug(f"直接广播失败: {e}")

        return sent_count

    def get_room_info(self, room_id: str) -> Optional[dict]:
        """获取房间信息"""

        if room_id not in self.room_members:
            return None

        return {
            "room_id": room_id,
            "member_count": len(self.room_members[room_id]),
            "metadata": self.room_metadata.get(room_id, {}),
            "queue_size": self.message_queues[room_id].qsize() if room_id in self.message_queues else 0,
        }

    def get_user_rooms(self, connection_id: str) -> List[dict]:
        """获取用户加入的房间"""

        rooms = []
        for room_id in self.connection_rooms.get(connection_id, set()):
            room_info = self.get_room_info(room_id)
            if room_info:
                rooms.append(room_info)

        return rooms

    def get_all_rooms(self, include_private: bool = False) -> List[dict]:
        """获取所有房间"""

        rooms = []
        for room_id in self.room_members:
            room_meta = self.room_metadata.get(room_id, {})

            if not include_private and not room_meta.get("is_public", True):
                continue

            rooms.append(self.get_room_info(room_id))

        return rooms

    async def cleanup_empty_rooms(self, min_age_seconds: int = 300) -> int:
        """清理空房间"""

        now = datetime.now()
        cleaned_count = 0

        for room_id in list(self.room_members.keys()):
            if not self.room_members[room_id]:
                room_meta = self.room_metadata.get(room_id, {})
                created_at = datetime.fromisoformat(room_meta.get("created_at", now.isoformat()))

                # 只清理创建时间超过min_age_seconds的空房间
                if (now - created_at).total_seconds() > min_age_seconds:
                    await self.delete_room(room_id)
                    cleaned_count += 1

        return cleaned_count

# 使用广播系统
broadcast_system = RoomBroadcastSystem()

# 房间管理端点
@app.websocket("/ws/room/{room_id}")
async def room_websocket(websocket: WebSocket, room_id: str):
    """房间WebSocket端点"""

    await websocket.accept()

    # 获取连接ID
    connection_id = None
    password = websocket.query_params.get("password")

    try:
        # 创建或加入房间
        if room_id not in broadcast_system.room_members:
            await broadcast_system.create_room(room_id, {
                "creator": "system",
                "created_at": datetime.now().isoformat()
            })

        # 从连接池获取连接ID
        # 这里需要实现连接池的集成

        # 加入房间
        success, message = await broadcast_system.join_room(
            connection_id, room_id, password
        )

        if not success:
            await websocket.send_json({
                "type": "error",
                "message": message
            })
            await websocket.close(code=1008)  # 策略违规
            return

        # 发送欢迎消息
        await websocket.send_json({
            "type": "room_joined",
            "room_id": room_id,
            "message": message,
            "member_count": len(broadcast_system.room_members[room_id]),
            "timestamp": datetime.now().isoformat()
        })

        # 主循环
        while True:
            data = await websocket.receive_json()

            if data.get("type") == "chat_message":
                # 广播聊天消息
                await broadcast_system.broadcast_to_room(room_id, {
                    "type": "chat_message",
                    "room_id": room_id,
                    "sender": connection_id,  # 实际项目中应该是用户ID
                    "message": data["message"],
                    "timestamp": datetime.now().isoformat()
                }, exclude_connection_id=connection_id)

            elif data.get("type") == "room_info":
                # 获取房间信息
                room_info = broadcast_system.get_room_info(room_id)
                await websocket.send_json({
                    "type": "room_info",
                    "room": room_info
                })

    except WebSocketDisconnect:
        # 离开房间
        if connection_id:
            await broadcast_system.leave_room(connection_id, room_id)

    except Exception as e:
        logger.error(f"房间WebSocket错误: {e}")

6. 心跳检测与重连机制

完整的心跳检测系统

python
import asyncio
from datetime import datetime, timedelta
from typing import Dict, Optional
import logging

logger = logging.getLogger(__name__)

class HeartbeatManager:
    """心跳管理器"""

    def __init__(self, timeout_seconds: int = 30, interval_seconds: int = 10):
        self.timeout_seconds = timeout_seconds
        self.interval_seconds = interval_seconds

        # 心跳跟踪
        self.last_heartbeat: Dict[str, datetime] = {}  # connection_id -> last_heartbeat
        self.heartbeat_tasks: Dict[str, asyncio.Task] = {}  # connection_id -> heartbeat_task

        # 统计
        self.stats = {
            "total_heartbeats": 0,
            "timeout_count": 0,
            "reconnect_count": 0,
        }

    async def start_monitoring(self, connection_id: str, websocket: WebSocket):
        """开始监控连接的心跳"""

        # 记录初始心跳
        self.last_heartbeat[connection_id] = datetime.now()

        # 启动心跳任务
        self.heartbeat_tasks[connection_id] = asyncio.create_task(
            self._heartbeat_monitor(connection_id, websocket)
        )

        # 启动ping任务
        asyncio.create_task(
            self._send_ping_messages(connection_id, websocket)
        )

    async def stop_monitoring(self, connection_id: str):
        """停止监控连接"""

        if connection_id in self.heartbeat_tasks:
            self.heartbeat_tasks[connection_id].cancel()
            del self.heartbeat_tasks[connection_id]

        if connection_id in self.last_heartbeat:
            del self.last_heartbeat[connection_id]

    async def record_heartbeat(self, connection_id: str):
        """记录心跳"""

        self.last_heartbeat[connection_id] = datetime.now()
        self.stats["total_heartbeats"] += 1

    async def _heartbeat_monitor(self, connection_id: str, websocket: WebSocket):
        """心跳监控器"""

        try:
            while True:
                await asyncio.sleep(self.timeout_seconds)

                # 检查最后心跳时间
                last_beat = self.last_heartbeat.get(connection_id)
                if not last_beat:
                    break

                now = datetime.now()
                time_since_last_beat = (now - last_beat).total_seconds()

                if time_since_last_beat > self.timeout_seconds:
                    logger.warning(f"心跳超时: {connection_id}, 最后心跳: {last_beat}")
                    self.stats["timeout_count"] += 1

                    # 尝试发送ping确认
                    try:
                        await websocket.send_json({
                            "type": "ping",
                            "timestamp": now.isoformat()
                        })

                        # 等待pong响应
                        try:
                            response = await asyncio.wait_for(
                                websocket.receive_json(),
                                timeout=5
                            )

                            if response.get("type") == "pong":
                                # 收到pong,更新心跳
                                await self.record_heartbeat(connection_id)
                                continue

                        except asyncio.TimeoutError:
                            logger.error(f"Ping无响应: {connection_id}")

                    except Exception as e:
                        logger.error(f"发送ping失败: {e}")

                    # 关闭连接
                    try:
                        await websocket.close(code=1000, reason="Heartbeat timeout")
                    except:
                        pass

                    break

        except asyncio.CancelledError:
            pass

        except Exception as e:
            logger.error(f"心跳监控器错误: {e}")

    async def _send_ping_messages(self, connection_id: str, websocket: WebSocket):
        """发送ping消息"""

        try:
            while True:
                await asyncio.sleep(self.interval_seconds)

                # 发送ping
                try:
                    await websocket.send_json({
                        "type": "ping",
                        "timestamp": datetime.now().isoformat(),
                        "sequence": self.stats["total_heartbeats"]
                    })
                except Exception as e:
                    logger.debug(f"发送ping失败: {e}")
                    break

        except asyncio.CancelledError:
            pass

    def get_connection_health(self, connection_id: str) -> Dict:
        """获取连接健康状态"""

        last_beat = self.last_heartbeat.get(connection_id)

        if not last_beat:
            return {
                "connection_id": connection_id,
                "status": "unknown",
                "monitoring": False
            }

        now = datetime.now()
        time_since_last_beat = (now - last_beat).total_seconds()

        status = "healthy"
        if time_since_last_beat > self.timeout_seconds:
            status = "timeout"
        elif time_since_last_beat > self.timeout_seconds * 0.8:
            status = "warning"

        return {
            "connection_id": connection_id,
            "status": status,
            "last_heartbeat": last_beat.isoformat(),
            "seconds_since_last_heartbeat": time_since_last_beat,
            "monitoring": connection_id in self.heartbeat_tasks,
            "timeout_threshold": self.timeout_seconds
        }

class ReconnectionManager:
    """重连管理器"""

    def __init__(self, max_attempts: int = 5, base_delay: float = 1.0):
        self.max_attempts = max_attempts
        self.base_delay = base_delay

        # 重连跟踪
        self.reconnection_attempts: Dict[str, int] = {}  # connection_id -> attempts
        self.reconnection_timers: Dict[str, asyncio.Task] = {}

        # 回调函数
        self.on_reconnect_callbacks = []

        # 统计
        self.stats = {
            "total_reconnection_attempts": 0,
            "successful_reconnections": 0,
            "failed_reconnections": 0,
        }

    async def schedule_reconnection(
        self,
        connection_id: str,
        user_id: Optional[str] = None,
        metadata: Optional[dict] = None
    ):
        """调度重连"""

        if connection_id in self.reconnection_timers:
            # 已经有一个重连任务
            return

        attempts = self.reconnection_attempts.get(connection_id, 0)

        if attempts >= self.max_attempts:
            logger.info(f"达到最大重连次数: {connection_id}")
            self.stats["failed_reconnections"] += 1
            return

        # 计算延迟(指数退避)
        delay = self.base_delay * (2 ** attempts)

        logger.info(f"调度重连: {connection_id}, 尝试 {attempts + 1}/{self.max_attempts}, 延迟 {delay}秒")

        # 创建重连任务
        self.reconnection_attempts[connection_id] = attempts + 1
        self.reconnection_timers[connection_id] = asyncio.create_task(
            self._reconnect_after_delay(connection_id, user_id, metadata, delay)
        )

    async def _reconnect_after_delay(
        self,
        connection_id: str,
        user_id: Optional[str],
        metadata: Optional[dict],
        delay: float
    ):
        """延迟后重连"""

        try:
            await asyncio.sleep(delay)

            # 尝试重连
            success = await self._attempt_reconnection(connection_id, user_id, metadata)

            if success:
                # 重连成功,重置尝试计数
                if connection_id in self.reconnection_attempts:
                    del self.reconnection_attempts[connection_id]
                self.stats["successful_reconnections"] += 1

                # 调用回调
                for callback in self.on_reconnect_callbacks:
                    try:
                        await callback(connection_id, user_id, metadata)
                    except Exception as e:
                        logger.error(f"重连回调错误: {e}")

            else:
                # 重连失败,可能再次调度
                if connection_id in self.reconnection_attempts:
                    attempts = self.reconnection_attempts[connection_id]
                    if attempts < self.max_attempts:
                        # 再次调度
                        await self.schedule_reconnection(connection_id, user_id, metadata)
                    else:
                        logger.warning(f"重连最终失败: {connection_id}")
                        self.stats["failed_reconnections"] += 1

        except asyncio.CancelledError:
            pass

        finally:
            if connection_id in self.reconnection_timers:
                del self.reconnection_timers[connection_id]

    async def _attempt_reconnection(
        self,
        connection_id: str,
        user_id: Optional[str],
        metadata: Optional[dict]
    ) -> bool:
        """尝试重连"""

        # 这里应该实现实际的重连逻辑
        # 例如:重新建立WebSocket连接,恢复会话状态等

        logger.info(f"尝试重连: {connection_id}")

        # 模拟重连
        import random
        success = random.random() > 0.3  # 70%成功率

        if success:
            logger.info(f"重连成功: {connection_id}")
        else:
            logger.info(f"重连失败: {connection_id}")

        return success

    def cancel_reconnection(self, connection_id: str):
        """取消重连"""

        if connection_id in self.reconnection_timers:
            self.reconnection_timers[connection_id].cancel()
            del self.reconnection_timers[connection_id]

        if connection_id in self.reconnection_attempts:
            del self.reconnection_attempts[connection_id]

    def add_reconnect_callback(self, callback):
        """添加重连回调"""
        self.on_reconnect_callbacks.append(callback)

# 集成心跳和重连
heartbeat_manager = HeartbeatManager(timeout_seconds=30, interval_seconds=10)
reconnection_manager = ReconnectionManager(max_attempts=5, base_delay=1.0)

@app.websocket("/ws/reliable")
async def reliable_websocket(
    websocket: WebSocket,
    client_id: str,
    token: Optional[str] = None
):
    """可靠的WebSocket端点(带心跳和重连)"""

    await websocket.accept()

    # 验证
    if token and not await verify_token(token):
        await websocket.close(code=1008, reason="Invalid token")
        return

    # 注册连接
    connection_id = f"{client_id}_{int(datetime.now().timestamp())}"

    try:
        # 发送连接确认
        await websocket.send_json({
            "type": "connected",
            "connection_id": connection_id,
            "timestamp": datetime.now().isoformat(),
            "heartbeat_interval": heartbeat_manager.interval_seconds
        })

        # 开始心跳监控
        await heartbeat_manager.start_monitoring(connection_id, websocket)

        # 主消息循环
        while True:
            try:
                # 设置接收超时
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=heartbeat_manager.timeout_seconds * 2
                )

                # 处理消息
                message_type = data.get("type")

                if message_type == "ping":
                    # 心跳响应
                    await heartbeat_manager.record_heartbeat(connection_id)
                    await websocket.send_json({
                        "type": "pong",
                        "timestamp": datetime.now().isoformat(),
                        "original_timestamp": data.get("timestamp")
                    })

                elif message_type == "reconnect_request":
                    # 客户端请求重连
                    await websocket.send_json({
                        "type": "reconnect_ack",
                        "timestamp": datetime.now().isoformat()
                    })
                    # 可以在这里执行优雅的重连逻辑

                else:
                    # 业务消息
                    await handle_business_message(websocket, data)

            except asyncio.TimeoutError:
                # 接收超时,检查心跳
                health = heartbeat_manager.get_connection_health(connection_id)
                if health["status"] == "timeout":
                    logger.warning(f"连接 {connection_id} 心跳超时")
                    break

            except WebSocketDisconnect:
                break

    except WebSocketDisconnect as e:
        logger.info(f"客户端断开连接: {connection_id}, 代码: {e.code}")

        # 如果异常断开,调度重连
        if e.code not in [1000, 1001]:  # 正常关闭和离开
            await reconnection_manager.schedule_reconnection(
                connection_id,
                user_id=client_id,
                metadata={"disconnect_code": e.code}
            )

    except Exception as e:
        logger.error(f"WebSocket错误: {connection_id}, 错误: {e}")

        # 调度重连
        await reconnection_manager.schedule_reconnection(
            connection_id,
            user_id=client_id,
            metadata={"error": str(e)}
        )

    finally:
        # 清理
        await heartbeat_manager.stop_monitoring(connection_id)

        # 发送断开通知(如果可能)
        try:
            await websocket.close(code=1000, reason="Connection ended")
        except:
            pass

async def handle_business_message(websocket: WebSocket, data: dict):
    """处理业务消息"""
    # 业务逻辑...
    await websocket.send_json({
        "type": "ack",
        "message_id": data.get("message_id"),
        "timestamp": datetime.now().isoformat()
    })

7. 生产环境部署要点

Nginx配置

nginx
# /etc/nginx/sites-available/websocket-app
upstream websocket_backend {
    # 负载均衡
    server 127.0.0.1:8000;
    server 127.0.0.1:8001;
    server 127.0.0.1:8002;

    # 使用least_conn进行WebSocket连接均衡
    least_conn;

    # 保持连接(重要!)
    keepalive 32;
}

server {
    listen 80;
    server_name websocket.example.com;

    # 重定向到HTTPS
    return 301 https://$server_name$request_uri;
}

server {
    listen 443 ssl http2;
    server_name websocket.example.com;

    # SSL配置
    ssl_certificate /etc/letsencrypt/live/websocket.example.com/fullchain.pem;
    ssl_certificate_key /etc/letsencrypt/live/websocket.example.com/privkey.pem;

    ssl_protocols TLSv1.2 TLSv1.3;
    ssl_ciphers ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512;
    ssl_prefer_server_ciphers off;

    # WebSocket支持
    location /ws/ {
        proxy_pass http://websocket_backend;

        # WebSocket必需的头
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";

        # 超时设置
        proxy_read_timeout 86400s;  # 24小时
        proxy_send_timeout 86400s;

        # 缓冲设置
        proxy_buffering off;
        proxy_buffer_size 4k;

        # 保持连接
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;

        # 连接数限制
        limit_conn websocket_conn 1000;
        limit_conn_status 429;
    }

    # 静态文件
    location /static/ {
        alias /var/www/websocket-app/static/;
        expires 1y;
        add_header Cache-Control "public, immutable";
    }

    # API端点
    location /api/ {
        proxy_pass http://websocket_backend;
        proxy_http_version 1.1;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
    }

    # 健康检查
    location /health {
        proxy_pass http://websocket_backend;
        access_log off;
    }

    # 限制连接区域
    limit_conn_zone $binary_remote_addr zone=websocket_conn:10m;
}

Docker部署配置

dockerfile
# Dockerfile
FROM python:3.11-slim

# 安装系统依赖
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# 设置工作目录
WORKDIR /app

# 复制依赖文件
COPY requirements.txt .

# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt

# 复制应用代码
COPY . .

# 创建非root用户
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"

# 启动命令
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", \
     "--workers", "4", "--loop", "uvloop", "--http", "httptools"]
yaml
# docker-compose.yml
version: "3.8"

services:
  websocket-app:
    build: .
    ports:
      - "8000:8000"
      - "8001:8000"
      - "8002:8000"
    environment:
      - REDIS_URL=redis://redis:6379/0
      - DATABASE_URL=postgresql://user:password@postgres/websocket_db
      - SECRET_KEY=${SECRET_KEY}
      - ENVIRONMENT=production
    depends_on:
      - redis
      - postgres
    deploy:
      replicas: 3
      update_config:
        parallelism: 1
        delay: 10s
      restart_policy:
        condition: on-failure
        delay: 5s
        max_attempts: 3
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    command: redis-server --appendonly yes --maxmemory 512mb --maxmemory-policy allkeys-lru
    volumes:
      - redis-data:/data
    deploy:
      resources:
        limits:
          memory: 1G
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 3

  postgres:
    image: postgres:15-alpine
    environment:
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=websocket_db
    volumes:
      - postgres-data:/var/lib/postgresql/data
    ports:
      - "5432:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U user"]
      interval: 10s
      timeout: 5s
      retries: 5

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      - ./ssl:/etc/nginx/ssl:ro
    depends_on:
      - websocket-app
    deploy:
      resources:
        limits:
          memory: 512M

volumes:
  redis-data:
  postgres-data:

监控和日志配置

python
# logging_config.py
import logging
import sys
from logging.handlers import RotatingFileHandler
import json
from datetime import datetime

class JSONFormatter(logging.Formatter):
    """JSON日志格式化器"""

    def format(self, record):
        log_data = {
            "timestamp": datetime.now().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # 添加异常信息
        if record.exc_info:
            log_data["exception"] = self.formatException(record.exc_info)

        # 添加额外字段
        if hasattr(record, "extra"):
            log_data.update(record.extra)

        return json.dumps(log_data, ensure_ascii=False)

class WebSocketLogger:
    """WebSocket专用日志器"""

    def __init__(self):
        self.logger = logging.getLogger("websocket")
        self.logger.setLevel(logging.INFO)

        # 控制台处理器
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(JSONFormatter())
        self.logger.addHandler(console_handler)

        # 文件处理器
        file_handler = RotatingFileHandler(
            "/var/log/websocket/app.log",
            maxBytes=10485760,  # 10MB
            backupCount=10
        )
        file_handler.setFormatter(JSONFormatter())
        self.logger.addHandler(file_handler)

    def log_connection(self, connection_id: str, event: str, details: dict = None):
        """记录连接事件"""
        self.logger.info(f"Connection {event}", extra={
            "connection_id": connection_id,
            "event": event,
            "details": details or {},
            "type": "connection"
        })

    def log_message(self, connection_id: str, direction: str, message: dict):
        """记录消息事件"""
        self.logger.debug(f"Message {direction}", extra={
            "connection_id": connection_id,
            "direction": direction,
            "message_type": message.get("type"),
            "message_size": len(str(message)),
            "type": "message"
        })

    def log_error(self, connection_id: str, error: Exception, context: dict = None):
        """记录错误事件"""
        self.logger.error(f"WebSocket error", extra={
            "connection_id": connection_id,
            "error": str(error),
            "error_type": type(error).__name__,
            "context": context or {},
            "type": "error"
        })

# 指标收集
class WebSocketMetrics:
    """WebSocket指标收集器"""

    def __init__(self, statsd_client=None):
        self.statsd = statsd_client
        self.metrics = {}

    def increment(self, metric: str, value: int = 1, tags: dict = None):
        """增加计数器"""
        if self.statsd:
            self.statsd.increment(metric, value, tags=tags)
        else:
            key = self._metric_key(metric, tags)
            self.metrics[key] = self.metrics.get(key, 0) + value

    def gauge(self, metric: str, value: float, tags: dict = None):
        """设置仪表值"""
        if self.statsd:
            self.statsd.gauge(metric, value, tags=tags)
        else:
            key = self._metric_key(metric, tags)
            self.metrics[key] = value

    def timer(self, metric: str, value: float, tags: dict = None):
        """记录计时器"""
        if self.statsd:
            self.statsd.timer(metric, value, tags=tags)
        else:
            key = self._metric_key(metric, tags)
            self.metrics[key] = value

    def _metric_key(self, metric: str, tags: dict) -> str:
        """生成指标键"""
        if not tags:
            return metric

        tag_str = ",".join(f"{k}={v}" for k, v in sorted(tags.items()))
        return f"{metric}[{tag_str}]"

    def get_metrics(self) -> dict:
        """获取所有指标"""
        return self.metrics.copy()

# 使用示例
websocket_logger = WebSocketLogger()
websocket_metrics = WebSocketMetrics()

@app.websocket("/ws/monitored")
async def monitored_websocket(websocket: WebSocket, client_id: str):
    """被监控的WebSocket端点"""

    await websocket.accept()

    connection_id = f"monitored_{client_id}"

    # 记录连接开始
    websocket_logger.log_connection(connection_id, "connected", {
        "client_id": client_id,
        "user_agent": websocket.headers.get("user-agent")
    })

    websocket_metrics.increment("websocket.connections.total")
    websocket_metrics.gauge("websocket.connections.active",
                           len(connection_pool.connections))

    try:
        while True:
            start_time = datetime.now()

            try:
                data = await websocket.receive_json()

                # 记录接收消息
                websocket_logger.log_message(connection_id, "received", data)
                websocket_metrics.increment("websocket.messages.received", tags={
                    "type": data.get("type", "unknown")
                })

                # 处理消息
                response = await process_message(data)

                # 发送响应
                await websocket.send_json(response)

                # 记录发送消息
                websocket_logger.log_message(connection_id, "sent", response)
                websocket_metrics.increment("websocket.messages.sent", tags={
                    "type": response.get("type", "unknown")
                })

                # 记录处理时间
                process_time = (datetime.now() - start_time).total_seconds()
                websocket_metrics.timer("websocket.messages.processing_time", process_time)

            except asyncio.TimeoutError:
                websocket_metrics.increment("websocket.errors.timeout")
                raise

            except json.JSONDecodeError:
                websocket_metrics.increment("websocket.errors.invalid_json")
                await websocket.send_json({
                    "type": "error",
                    "message": "Invalid JSON"
                })

    except WebSocketDisconnect as e:
        websocket_logger.log_connection(connection_id, "disconnected", {
            "code": e.code,
            "reason": e.reason
        })

        websocket_metrics.increment("websocket.disconnections", tags={
            "code": str(e.code)
        })
        websocket_metrics.gauge("websocket.connections.active",
                               len(connection_pool.connections))

    except Exception as e:
        websocket_logger.log_error(connection_id, e, {
            "client_id": client_id
        })
        websocket_metrics.increment("websocket.errors.unexpected")

    finally:
        websocket_metrics.gauge("websocket.connections.active",
                               len(connection_pool.connections))

性能优化配置

python
# performance_config.py
from fastapi import FastAPI
import uvicorn
import asyncio
import uvloop

# 使用uvloop提高性能(Linux only)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

app = FastAPI()

# WebSocket配置
WEBSOCKET_CONFIG = {
    # 连接限制
    "max_connections": 10000,
    "max_connections_per_ip": 100,

    # 消息大小限制
    "max_message_size": 16 * 1024 * 1024,  # 16MB
    "max_queue_size": 1000,

    # 超时设置
    "ping_interval": 30,      # 秒
    "ping_timeout": 60,       # 秒
    "close_timeout": 5,       # 秒

    # 压缩
    "permessage_deflate": True,
    "compression_threshold": 1024,  # 字节

    # 缓冲
    "receive_buffer_size": 4096,
    "send_buffer_size": 4096,
}

# Uvicorn配置
UVICORN_CONFIG = {
    "host": "0.0.0.0",
    "port": 8000,
    "workers": 4,  # 根据CPU核心数调整
    "loop": "uvloop",
    "http": "httptools",

    # 连接限制
    "limit_concurrency": 10000,
    "limit_max_requests": 1000,  # 每个worker处理1000个请求后重启

    # 超时
    "timeout_keep_alive": 5,
    "timeout_graceful_shutdown": 10,

    # 日志
    "log_level": "info",
    "access_log": True,

    # 性能
    "reload": False,  # 生产环境关闭
    "use_colors": False,
}

# 启动脚本
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        **UVICORN_CONFIG
    )

# 内存优化
class MemoryOptimizedConnectionManager:
    """内存优化的连接管理器"""

    def __init__(self):
        # 使用弱引用避免内存泄漏
        import weakref
        self.connections = weakref.WeakValueDictionary()

        # 使用数组存储连接ID而不是对象
        self.connection_ids = []
        self.connection_data = {}  # connection_id -> minimal data

        # 定期清理
        self.cleanup_interval = 300  # 5分钟

    async def add_connection(self, websocket: WebSocket, user_id: str):
        """添加连接(内存友好)"""
        connection_id = self._generate_id()

        # 只存储必要数据
        self.connection_data[connection_id] = {
            "id": connection_id,
            "user_id": user_id,
            "created_at": datetime.now().timestamp(),
            "last_activity": datetime.now().timestamp(),
        }

        # 存储弱引用
        self.connections[connection_id] = websocket
        self.connection_ids.append(connection_id)

        # 限制连接数
        if len(self.connection_ids) > 10000:
            await self._cleanup_oldest()

        return connection_id

    async def _cleanup_oldest(self, max_age: int = 3600):
        """清理最老的连接"""
        now = datetime.now().timestamp()
        cutoff = now - max_age

        to_remove = []
        for conn_id in self.connection_ids:
            data = self.connection_data.get(conn_id)
            if data and data["last_activity"] < cutoff:
                to_remove.append(conn_id)

        for conn_id in to_remove:
            await self.remove_connection(conn_id)

    def _generate_id(self) -> str:
        """生成连接ID"""
        import uuid
        return str(uuid.uuid4())

总结

WebSocket为FastAPI应用提供了强大的实时通信能力。通过本章的学习,你应该能够:

关键要点

  1. 协议理解:深入理解WebSocket握手过程和数据帧结构
  2. 连接管理:实现高效的连接池和会话管理系统
  3. 房间广播:构建支持房间概念的广播系统
  4. 可靠性:实现心跳检测和自动重连机制
  5. 生产部署:掌握Nginx配置、Docker部署和性能优化

最佳实践

  1. 安全性

    • 使用WSS(WebSocket Secure)
    • 验证连接令牌
    • 限制连接数和消息频率
  2. 可靠性

    • 实现心跳检测
    • 支持自动重连
    • 优雅处理断开连接
  3. 性能

    • 使用连接池管理
    • 实现消息队列
    • 监控连接健康状态
  4. 可扩展性

    • 支持水平扩展
    • 使用Redis共享状态
    • 实现负载均衡

监控指标

python
# 关键监控指标
ESSENTIAL_METRICS = [
    "websocket.connections.active",     # 活跃连接数
    "websocket.connections.total",      # 总连接数
    "websocket.messages.received",      # 接收消息数
    "websocket.messages.sent",          # 发送消息数
    "websocket.errors.total",           # 错误总数
    "websocket.ping.latency",           # Ping延迟
    "websocket.connections.reconnect",  # 重连次数
]

故障排除

  1. 连接不稳定

    • 检查网络延迟
    • 调整心跳间隔
    • 检查防火墙设置
  2. 内存泄漏

    • 使用弱引用
    • 定期清理死连接
    • 监控内存使用
  3. 性能瓶颈

    • 优化消息序列化
    • 使用消息压缩
    • 增加工作进程

扩展学习


最后提醒:WebSocket应用的生产部署需要仔细规划。务必进行充分的压力测试,监控关键指标,并准备好应对连接风暴的预案。记住,实时系统的可靠性直接影响用户体验,每一个连接都值得认真对待。

WebSocket就像一座桥梁,连接着服务器和客户端。好的桥梁需要稳固的基础(协议理解)、坚实的结构(代码实现)和持续的维护(监控运维)。投资时间在这些方面,你的实时应用将能够承受真实世界的考验。

热爱生活,喜好美食,追求未来!