Vanson's Eternal Blog

Python中的网络编程

Python pool.png
Published on
21 min read

Network

TCP

服务端

  • 创建 socket (socket())
  • 绑定地址 (bind())
  • 开始监听 (listen())
  • 接受连接 (accept())
  • 接收/发送数据 (recv()/send())
  • 关闭连接 (close())
import asyncio
 
class AsyncTCPServer:
    """Echo server built on asyncio streams.

    Each connection is served by its own ``handle_client`` coroutine. The
    StreamWriter of every open connection is kept in ``self.clients`` so that
    ``stop()`` can close all of them.
    """

    def __init__(self, host='0.0.0.0', port=8888):
        """
        Args:
            host: address to bind (default: all interfaces).
            port: TCP port to listen on.
        """
        self.host = host
        self.port = port
        self.server = None
        self.clients = set()  # StreamWriter of every open client connection

    async def handle_client(self, reader, writer):
        """Echo each chunk received from one client until it disconnects.

        Args:
            reader: StreamReader of the connection.
            writer: StreamWriter of the connection.
        """
        client_addr = writer.get_extra_info('peername')
        print(f"New connection from {client_addr}")

        self.clients.add(writer)

        try:
            while True:
                # Read up to 1024 bytes; b'' means the peer closed the connection.
                data = await reader.read(1024)
                if not data:
                    break

                message = data.decode().strip()
                print(f"Received from {client_addr}: {message}")

                response = f"Async Echo: {message}"
                writer.write(response.encode())
                await writer.drain()  # wait until the write buffer has been flushed

        except ConnectionError as e:
            print(f"Connection error with {client_addr}: {e}")
        except Exception as e:
            print(f"Error with client {client_addr}: {e}")
        finally:
            print(f"Closing connection with {client_addr}")
            # discard() instead of remove(): never fail cleanup on a missing entry.
            self.clients.discard(writer)
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                # The peer already reset the connection; it is closed either way.
                pass

    async def start(self):
        """Start listening and serve until the surrounding task is cancelled."""
        self.server = await asyncio.start_server(
            self.handle_client,
            self.host,
            self.port
        )

        print(f"Async TCP Server started on {self.host}:{self.port}")

        async with self.server:
            await self.server.serve_forever()

    async def stop(self):
        """Close every client connection, then the listening server."""
        if self.server:
            # Iterate over a snapshot: while we await wait_closed(), each
            # handle_client() removes its own writer from self.clients, and
            # mutating a set during iteration raises RuntimeError.
            for writer in list(self.clients):
                writer.close()
                try:
                    await writer.wait_closed()
                except OSError:
                    # Best effort: one broken connection must not abort shutdown.
                    pass

            self.server.close()
            await self.server.wait_closed()
            print("Async TCP Server stopped")
 
async def main_tcp_server():
    """Serve with an AsyncTCPServer until this task is cancelled, then stop it."""
    tcp_server = AsyncTCPServer()
    serving = tcp_server.start()
    try:
        await serving
    except asyncio.CancelledError:
        # Cancellation is the shutdown signal: close clients and the listener.
        await tcp_server.stop()
 
if __name__ == '__main__':
    entry_point = main_tcp_server()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Ctrl+C ends the program with a short message instead of a traceback.
        print("Server shutdown by user")

客户端

  • 创建 socket (socket())
  • 连接服务器 (connect())
  • 发送/接收数据 (send()/recv())
  • 关闭连接 (close())
import asyncio
 
class AsyncTCPClient:
    """Interactive TCP client: sends lines typed on stdin and prints replies."""

    def __init__(self, host='127.0.0.1', port=8888):
        """
        Args:
            host: server address.
            port: server TCP port.
        """
        self.host = host
        self.port = port
        self.reader = None
        self.writer = None
        self.running = False  # cleared by either loop to end the session

    async def connect(self):
        """Connect to the server, run one interactive session, then disconnect."""
        try:
            self.reader, self.writer = await asyncio.open_connection(
                self.host,
                self.port
            )
            self.running = True
            print(f"Connected to server at {self.host}:{self.port}")

            await self._run_session()

        except Exception as e:
            print(f"Connection error: {e}")
        finally:
            await self.disconnect()

    async def _run_session(self):
        """Receive in the background while the send loop reads user input.

        The receive task is cancelled once the send loop ends. Otherwise it
        would stay blocked in reader.read() until the server closed the
        connection, and typing 'quit' would hang the client.
        """
        receive_task = asyncio.create_task(self.receive_messages())
        try:
            await self.send_messages()
        finally:
            receive_task.cancel()
            try:
                await receive_task
            except asyncio.CancelledError:
                pass

    async def receive_messages(self):
        """Print data from the server until it closes or an error occurs."""
        while self.running:
            try:
                data = await self.reader.read(1024)
                if not data:  # b'' means the server closed the connection
                    print("Server closed the connection")
                    self.running = False
                    break

                print(f"Received: {data.decode()}")

            except ConnectionError:
                print("Connection lost with server")
                self.running = False
                break
            except Exception as e:
                print(f"Receive error: {e}")
                self.running = False
                break

    async def send_messages(self):
        """Send lines typed by the user until 'quit' or a send failure."""
        loop = asyncio.get_running_loop()
        try:
            while self.running:
                # input() blocks, so run it in the default thread-pool executor
                # to keep the event loop (and receive_messages) responsive.
                message = await loop.run_in_executor(
                    None,
                    input,
                    "Enter message (or 'quit' to exit): "
                )

                if message.lower() == 'quit':
                    self.running = False
                    break

                try:
                    self.writer.write(message.encode())
                    await self.writer.drain()  # wait until the data has been flushed
                except ConnectionError:
                    print("Failed to send message: connection lost")
                    self.running = False
                    break
                except Exception as e:
                    print(f"Send error: {e}")
                    self.running = False
                    break

        except asyncio.CancelledError:
            pass  # cancelled by the caller: normal exit
        except KeyboardInterrupt:
            self.running = False

    async def disconnect(self):
        """Close the connection (if one was opened) and report it."""
        self.running = False
        if self.writer:
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except OSError:
                # The peer already reset the connection; it is closed either way.
                pass
        print("Disconnected from server")
 
async def main_tcp_client():
    """Run one interactive session with a default-configured AsyncTCPClient."""
    await AsyncTCPClient().connect()
 
if __name__ == '__main__':
    entry_point = main_tcp_client()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Ctrl+C ends the program with a short message instead of a traceback.
        print("Client shutdown by user")

最佳实践

处理粘包问题

TCP 是面向字节流的协议,不保留消息边界:一次 recv() 可能读到多条消息拼接在一起(粘包),也可能只读到半条消息(拆包)。因此应用层必须自行划分消息边界,常见方法:

  • 固定长度消息
  • 特殊分隔符 (如 \n)
  • 消息头 + 消息体 (包含长度信息)
import struct
 
def send_msg(sock, msg):
    """Send one length-prefixed message over a connected stream socket.

    Frame layout: a 4-byte big-endian unsigned length, then the payload.

    Args:
        sock: connected socket (anything with ``sendall``).
        msg: payload bytes. A ``str`` is UTF-8 encoded first, so the prefix
            counts bytes rather than characters.

    Raises:
        struct.error: if the payload does not fit in 32 bits ('>I').
    """
    if isinstance(msg, str):
        msg = msg.encode('utf-8')
    sock.sendall(struct.pack('>I', len(msg)) + msg)
 
def recv_msg(sock):
    """Receive one length-prefixed message produced by send_msg.

    Returns:
        The payload, or None if the connection closed before a complete
        header or payload arrived.
    """
    header = recvall(sock, 4)
    if not header:
        return None
    (payload_len,) = struct.unpack('>I', header)
    return recvall(sock, payload_len)
 
def recvall(sock, n, chunk_size=65536):
    """Read exactly *n* bytes from a stream socket.

    Args:
        sock: connected stream socket.
        n: number of bytes to read.
        chunk_size: upper bound for a single recv() call. *n* may come from an
            untrusted length prefix (up to 4 GiB). CPython's recv(bufsize)
            reserves a bufsize-byte buffer per call, so large frames are read
            in bounded pieces instead.

    Returns:
        A bytearray of length *n*, or None if the peer closed the connection
        before *n* bytes arrived.

    Raises:
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    data = bytearray()
    while len(data) < n:
        packet = sock.recv(min(n - len(data), chunk_size))
        if not packet:
            return None
        data.extend(packet)
    return data
 

UDP

服务端

  • 创建 socket (socket())
  • 绑定地址 (bind())
  • 接收/发送数据 (recvfrom()/sendto())
  • 关闭 socket (close())
import asyncio
 
class AsyncUDPServer:
    """UDP echo server built on an asyncio datagram endpoint."""

    def __init__(self, host='0.0.0.0', port=8888):
        """
        Args:
            host: address to bind (default: all interfaces).
            port: UDP port to bind.
        """
        self.host = host
        self.port = port
        self.transport = None
        self.protocol = None

    class ServerProtocol(asyncio.DatagramProtocol):
        """Datagram protocol that echoes every datagram back to its sender."""

        def __init__(self, server_instance):
            self.server = server_instance  # owning AsyncUDPServer
            self.transport = None  # set in connection_made()
            super().__init__()

        def connection_made(self, transport):
            """Called once the endpoint is bound and ready."""
            print("UDP Server ready to receive")
            self.transport = transport

        def datagram_received(self, data, addr):
            """Echo the decoded datagram back to *addr*."""
            try:
                message = data.decode().strip()
                print(f"Received from {addr}: {message}")

                response = f"Async UDP Echo: {message}"
                self.transport.sendto(response.encode(), addr)

            except Exception as e:
                print(f"Error processing datagram from {addr}: {e}")

        def error_received(self, exc):
            """Called when a send or receive operation reports an error."""
            print(f"UDP Error received: {exc}")

        def connection_lost(self, exc):
            """Called when the transport is closed."""
            print("UDP Connection closed")
            if exc:
                print(f"Error: {exc}")

    async def start(self):
        """Bind the endpoint and serve until the surrounding task ends.

        The transport is released in ``finally`` however the wait ends. A
        cancellation is then re-raised rather than swallowed, so callers (and
        asyncio.run) see the shutdown.
        """
        loop = asyncio.get_running_loop()

        self.transport, self.protocol = await loop.create_datagram_endpoint(
            lambda: self.ServerProtocol(self),  # protocol factory
            local_addr=(self.host, self.port)
        )

        print(f"Async UDP Server started on {self.host}:{self.port}")

        try:
            while True:
                await asyncio.sleep(3600)  # idle; datagrams are handled by the protocol
        finally:
            self.stop()

    def stop(self):
        """Close the transport if it was created."""
        if self.transport:
            self.transport.close()
        print("Async UDP Server stopped")
 
async def main_udp_server():
    """Run an AsyncUDPServer; stop it if KeyboardInterrupt reaches this coroutine."""
    udp_server = AsyncUDPServer()
    serving = udp_server.start()
    try:
        await serving
    except KeyboardInterrupt:
        # NOTE(review): under asyncio.run() Ctrl+C usually surfaces outside this
        # coroutine (or as task cancellation), so this branch is rarely taken.
        udp_server.stop()
 
if __name__ == '__main__':
    entry_point = main_udp_server()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Ctrl+C ends the program with a short message instead of a traceback.
        print("Server shutdown by user")

客户端

  • 创建 socket (socket())
  • 发送/接收数据 (sendto()/recvfrom())
  • 关闭 socket (close())
import asyncio
 
class AsyncUDPClient:
    """Interactive UDP client: sends lines typed on stdin to one server."""

    def __init__(self, host='127.0.0.1', port=8888):
        """
        Args:
            host: server address.
            port: server UDP port.
        """
        self.host = host
        self.port = port
        self.transport = None
        self.protocol = None
        self.running = False  # cleared to end the input loop

    class ClientProtocol(asyncio.DatagramProtocol):
        """Datagram protocol that prints replies and tracks transport state."""

        def __init__(self, client_instance):
            self.client = client_instance  # owning AsyncUDPClient
            self.transport = None  # set in connection_made()
            super().__init__()

        def connection_made(self, transport):
            """Called once the endpoint is ready."""
            self.transport = transport
            print("UDP Client ready")

        def datagram_received(self, data, addr):
            """Print a reply from the server."""
            try:
                message = data.decode()
                print(f"Received from {addr}: {message}")
            except Exception as e:
                print(f"Error processing datagram: {e}")

        def error_received(self, exc):
            """Called when a send or receive operation reports an error."""
            print(f"UDP Error received: {exc}")

        def connection_lost(self, exc):
            """Called when the transport is closed; ends the input loop."""
            print("UDP Connection closed")
            if exc:
                print(f"Error: {exc}")
            self.client.running = False

    async def connect(self):
        """Create the datagram endpoint and run the interactive send loop."""
        loop = asyncio.get_running_loop()

        # remote_addr makes the endpoint "connected", so sendto() needs no address.
        self.transport, self.protocol = await loop.create_datagram_endpoint(
            lambda: self.ClientProtocol(self),  # protocol factory
            remote_addr=(self.host, self.port)
        )

        self.running = True
        print(f"Async UDP Client ready to send to {self.host}:{self.port}")

        await self.send_messages()

    async def send_messages(self):
        """Send lines typed by the user until 'quit', then disconnect."""
        loop = asyncio.get_running_loop()
        try:
            while self.running:
                # input() blocks, so run it in the default thread-pool executor.
                message = await loop.run_in_executor(
                    None,
                    input,
                    "Enter message (or 'quit' to exit): "
                )

                if message.lower() == 'quit':
                    self.running = False
                    break

                try:
                    self.transport.sendto(message.encode())
                except Exception as e:
                    print(f"Send error: {e}")
                    self.running = False
                    break

        except asyncio.CancelledError:
            pass  # cancelled by the caller: normal exit
        except KeyboardInterrupt:
            self.running = False
        finally:
            await self.disconnect()

    async def disconnect(self):
        """Close the transport (if any)."""
        self.running = False
        if self.transport:
            self.transport.close()
        print("Async UDP Client stopped")
 
async def main_udp_client():
    """Run one interactive session with a default-configured AsyncUDPClient."""
    await AsyncUDPClient().connect()
 
if __name__ == '__main__':
    entry_point = main_udp_client()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Ctrl+C ends the program with a short message instead of a traceback.
        print("Client shutdown by user")

Websocket

服务端

  • 客户端管理:使用 set() 存储所有连接,实现广播功能
  • 消息处理:支持 JSON 格式消息,区分不同类型(聊天、系统消息等)
  • 心跳检测:内置 ping/pong 机制保持连接
  • 异常处理:妥善处理连接断开等异常情况
import asyncio
import websockets
from datetime import datetime
import json
 
class WebSocketServer:
    """Chat-style WebSocket server that broadcasts chat messages to all clients."""

    def __init__(self, host='0.0.0.0', port=8765):
        """
        Args:
            host: listen address, default 0.0.0.0
            port: listen port, default 8765
        """
        self.host = host
        self.port = port
        self.server = None
        self.clients = set()  # every currently connected websocket
        self.message_count = 0  # number of messages received so far

    async def register(self, websocket):
        """Track a new connection and send it a welcome message."""
        self.clients.add(websocket)
        print(f"New client connected. Total clients: {len(self.clients)}")

        welcome_msg = {
            "type": "system",
            "content": "Welcome to the WebSocket server!",
            "timestamp": datetime.now().isoformat(),
            "client_count": len(self.clients)
        }
        await websocket.send(json.dumps(welcome_msg))

    async def unregister(self, websocket):
        """Forget a connection and tell the remaining clients about it."""
        if websocket in self.clients:
            self.clients.remove(websocket)
            print(f"Client disconnected. Remaining clients: {len(self.clients)}")

            if len(self.clients) > 0:
                leave_msg = {
                    "type": "system",
                    "content": "A client has left the chat",
                    "timestamp": datetime.now().isoformat(),
                    "client_count": len(self.clients)
                }
                await self.broadcast(json.dumps(leave_msg))

    async def broadcast(self, message):
        """Send *message* to every connected client concurrently.

        Uses asyncio.gather over a snapshot of the client set. asyncio.wait()
        rejects bare coroutines (TypeError since Python 3.11), and with
        return_exceptions=True a failure on one connection does not stop
        delivery to the others.
        """
        recipients = list(self.clients)
        if not recipients:
            return
        results = await asyncio.gather(
            *(client.send(message) for client in recipients),
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                print(f"Failed to deliver broadcast to a client: {result}")

    async def handle_client(self, websocket, path=None):
        """Serve one connection until it closes.

        Args:
            websocket: the connection object.
            path: request path. Optional because newer websockets releases call
                the handler with the connection only.
        """
        try:
            # Registered inside try so a failed welcome send still unregisters.
            await self.register(websocket)
            async for message in websocket:
                self.message_count += 1
                print(f"Received message #{self.message_count}: {message}")

                # Messages are expected to be JSON; anything else is echoed back.
                try:
                    msg_data = json.loads(message)
                    msg_type = msg_data.get("type", "unknown")

                    if msg_type == "chat":
                        response = {
                            "type": "chat",
                            "content": msg_data.get("content", ""),
                            "sender": "user",  # a real app would use the user id here
                            "timestamp": datetime.now().isoformat(),
                            "message_id": self.message_count
                        }
                        await self.broadcast(json.dumps(response))

                    elif msg_type == "ping":
                        # Application-level ping from the client.
                        await websocket.send(json.dumps({
                            "type": "pong",
                            "timestamp": datetime.now().isoformat()
                        }))

                except json.JSONDecodeError:
                    response = {
                        "type": "echo",
                        "content": message,
                        "timestamp": datetime.now().isoformat()
                    }
                    await websocket.send(json.dumps(response))

        except websockets.exceptions.ConnectionClosed:
            print("Client disconnected unexpectedly")
        finally:
            await self.unregister(websocket)

    async def start(self):
        """Start listening and wait until the server is closed."""
        print(f"Starting WebSocket server on ws://{self.host}:{self.port}")
        self.server = await websockets.serve(
            self.handle_client,
            self.host,
            self.port
        )

        await self.server.wait_closed()

    async def stop(self):
        """Close the listening server if it was started."""
        if self.server:
            self.server.close()
            await self.server.wait_closed()
            print("WebSocket server stopped")
 
async def main_server():
    """Run a WebSocketServer until this task is cancelled, then stop it."""
    ws_server = WebSocketServer()
    serving = ws_server.start()
    try:
        await serving
    except asyncio.CancelledError:
        # Cancellation is the shutdown signal.
        await ws_server.stop()
 
if __name__ == '__main__':
    entry_point = main_server()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Ctrl+C ends the program with a short message instead of a traceback.
        print("Server shutdown by user")

客户端

  • 连接管理:心跳检测(ping_interval / ping_timeout);本示例未实现自动重连,断开后需重新运行客户端
  • 消息收发:支持 JSON 和纯文本消息
  • 用户界面:异步处理用户输入不阻塞消息接收
  • 格式化显示:不同类型消息不同显示格式
import asyncio
import websockets
import json
from datetime import datetime
import argparse
 
class WebSocketClient:
    """Interactive chat client for WebSocketServer."""

    def __init__(self, server_uri='ws://localhost:8765', client_name='Anonymous'):
        """
        Args:
            server_uri: server address, default ws://localhost:8765
            client_name: display name, default Anonymous
        """
        self.server_uri = server_uri
        self.client_name = client_name
        self.websocket = None
        self.running = False  # cleared by either loop to end the session

    async def connect(self):
        """Connect, announce ourselves, run the chat session, always disconnect."""
        print(f"Connecting to {self.server_uri}...")
        try:
            self.websocket = await websockets.connect(
                self.server_uri,
                ping_interval=20,  # send a ping every 20 seconds
                ping_timeout=5     # wait at most 5 seconds for the pong
            )
            self.running = True
            print("Connected to server!")

            await self.send_message({
                "type": "system",
                "content": f"{self.client_name} joined the chat",
                "client_name": self.client_name
            })

            receive_task = asyncio.create_task(self.receive_messages())
            try:
                await self.handle_user_input()
            finally:
                # receive_messages() waits in `async for` for the next server
                # frame; cancel it so quitting does not hang until one arrives.
                receive_task.cancel()
                try:
                    await receive_task
                except asyncio.CancelledError:
                    pass

        except Exception as e:
            print(f"Connection error: {e}")
        finally:
            await self.disconnect()

    async def send_message(self, message):
        """Send a dict (as JSON) or a string while the session is running."""
        if isinstance(message, dict):
            message = json.dumps(message)

        if self.websocket and self.running:
            try:
                await self.websocket.send(message)
            except websockets.exceptions.ConnectionClosed:
                print("Connection closed while sending message")
                self.running = False

    async def receive_messages(self):
        """Print server messages until the connection closes."""
        try:
            async for message in self.websocket:
                try:
                    msg_data = json.loads(message)
                    self.display_message(msg_data)
                except json.JSONDecodeError:
                    # Not JSON: show the raw text.
                    print(f"\n[Server]: {message}")

        except websockets.exceptions.ConnectionClosed:
            print("Connection with server closed")
            self.running = False
        except Exception as e:
            print(f"Error receiving message: {e}")
            self.running = False

    def display_message(self, msg_data):
        """Pretty-print one decoded message according to its type."""
        msg_type = msg_data.get("type", "unknown")
        timestamp = msg_data.get("timestamp", "")
        content = msg_data.get("content", "")

        if timestamp:
            try:
                dt = datetime.fromisoformat(timestamp)
                timestamp = dt.strftime("%H:%M:%S")
            except ValueError:
                pass  # keep the raw timestamp text

        if msg_type == "chat":
            sender = msg_data.get("sender", "unknown")
            print(f"\n[{timestamp}] {sender}: {content}")
        elif msg_type == "system":
            print(f"\n[System @ {timestamp}]: {content}")
        else:
            print(f"\n[Unknown message type]: {msg_data}")

    async def handle_user_input(self):
        """Send each line typed by the user as a chat message until exit/quit."""
        loop = asyncio.get_running_loop()
        try:
            while self.running:
                # input() blocks, so run it in the default thread-pool executor.
                message = await loop.run_in_executor(
                    None,
                    input,
                    "> "
                )

                if not self.running:
                    break

                if message.lower() in ('exit', 'quit'):
                    self.running = False
                    break

                chat_msg = {
                    "type": "chat",
                    "content": message,
                    "sender": self.client_name,
                    "timestamp": datetime.now().isoformat()
                }
                await self.send_message(chat_msg)

        except asyncio.CancelledError:
            pass  # cancelled by the caller: normal exit
        except KeyboardInterrupt:
            self.running = False
        except Exception as e:
            print(f"Input error: {e}")
            self.running = False

    async def disconnect(self):
        """Announce departure (best effort), then close the connection."""
        self.running = False
        if self.websocket:
            # Send directly instead of via send_message(): that helper skips
            # sending once self.running is False, which is always true here.
            try:
                await self.websocket.send(json.dumps({
                    "type": "system",
                    "content": f"{self.client_name} left the chat",
                    "client_name": self.client_name
                }))
            except websockets.exceptions.ConnectionClosed:
                pass  # server already gone: nobody to announce to
            except Exception as e:
                print(f"Could not send leave message: {e}")

            await self.websocket.close()
            print("Disconnected from server")
 
async def main_client():
    """Parse --uri / --name from the command line and run a WebSocketClient."""
    cli = argparse.ArgumentParser(description='WebSocket Client')
    cli.add_argument('--uri', default='ws://localhost:8765', help='WebSocket server URI')
    cli.add_argument('--name', default='Anonymous', help='Client name')
    opts = cli.parse_args()

    await WebSocketClient(server_uri=opts.uri, client_name=opts.name).connect()
 
if __name__ == '__main__':
    entry_point = main_client()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Ctrl+C ends the program with a short message instead of a traceback.
        print("Client shutdown by user")

SSE

SSE (Server-Sent Events) 是一种服务器向客户端推送数据的轻量级协议,基于 HTTP 协议。

主要特点:

  • 单向通信(服务器→客户端)
  • 基于 HTTP,兼容性好
  • 自动重连机制
  • 简单易用
  • 支持事件ID和自定义事件类型

服务端

  • 多频道支持:客户端可以订阅不同频道
  • 自动重连:断线重连由客户端负责(见下文客户端实现),服务端只需接受新的连接请求
  • 心跳机制:防止连接超时
  • 消息ID跟踪:每个事件都带有递增的 id 字段,客户端重连时可通过 Last-Event-ID 头部回传最后收到的 ID(本示例服务端尚未根据该头部补发消息)
  • 广播API:提供HTTP接口主动推送消息
  • 统计信息:提供客户端连接数等统计
import asyncio
from aiohttp import web
import json
from datetime import datetime
from typing import Dict, Set, Optional
import signal
import logging
 
# Named logger for the SSE server; the root configuration below emits INFO and above.
logger = logging.getLogger("SSE-Server")
logging.basicConfig(level=logging.INFO)
 
class SSEServer:
    """aiohttp-based Server-Sent Events server with per-channel subscriptions."""

    def __init__(self, host: str = '0.0.0.0', port: int = 8080):
        """
        Build the aiohttp application and register the routes.

        Args:
            host: listen address
            port: listen port
        """
        self.host = host
        self.port = port
        self.app = web.Application()
        self.runner: Optional[web.AppRunner] = None
        self.site: Optional[web.TCPSite] = None

        # channel name -> open SSE responses subscribed to that channel
        self.clients: Dict[str, Set[web.StreamResponse]] = {}

        self.app.router.add_route('GET', '/events', self.sse_handler)
        self.app.router.add_route('POST', '/broadcast', self.broadcast_handler)
        self.app.router.add_route('GET', '/stats', self.stats_handler)

        # Shared counter used as the `id:` of every event sent (init, heartbeat, broadcast).
        self.message_id = 0

        # Set on shutdown; SSE handlers check it to leave their loop.
        self.shutdown_event = asyncio.Event()
        # Makes stop() idempotent: it is reached from start() and may also be called directly.
        self._stopped = False

    async def sse_handler(self, request: web.Request) -> web.StreamResponse:
        """Stream events to one subscriber until it disconnects or the server stops.

        The ``channel`` query parameter selects the channel (default
        ``default``). The optional ``X-Client-ID`` header names the client in logs.
        """
        if 'text/event-stream' not in request.headers.get('Accept', ''):
            return web.Response(status=406, text="Server-Sent Events not supported")

        client_id = request.headers.get('X-Client-ID', str(id(request)))
        channel = request.query.get('channel', 'default')

        logger.info(f"New SSE connection - Client: {client_id}, Channel: {channel}")

        response = web.StreamResponse(
            status=200,
            headers={
                'Content-Type': 'text/event-stream',
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
            }
        )
        await response.prepare(request)

        self.clients.setdefault(channel, set()).add(response)

        try:
            await self._send_sse_event(
                response,
                event="init",
                data={"message": "Connected to SSE server", "client_id": client_id},
                id=self.message_id
            )
            self.message_id += 1

            while not self.shutdown_event.is_set():
                await asyncio.sleep(1)
                # Re-check after sleeping: stop() may have ended the stream meanwhile.
                if self.shutdown_event.is_set():
                    break

                # Heartbeat so an otherwise idle stream does not time out.
                await self._send_sse_event(
                    response,
                    event="heartbeat",
                    data={"timestamp": datetime.now().isoformat()},
                    id=self.message_id
                )
                self.message_id += 1

        except (asyncio.CancelledError, ConnectionError):
            logger.info(f"Client {client_id} disconnected")
        finally:
            # .get(): another handler may already have deleted this channel once
            # it became empty (e.g. after stop() emptied every subscriber set).
            subscribers = self.clients.get(channel)
            if subscribers is not None:
                subscribers.discard(response)
                if not subscribers:
                    del self.clients[channel]
        # Returned outside `finally`: a return there would silently swallow any
        # exception in flight.
        return response

    async def broadcast_handler(self, request: web.Request) -> web.Response:
        """POST /broadcast: push a JSON ``{channel, message, event}`` to subscribers."""
        try:
            data = await request.json()
            channel = data.get('channel', 'default')
            message = data.get('message', '')
            event_type = data.get('event', 'message')

            await self.broadcast(
                channel=channel,
                event=event_type,
                data={"message": message, "timestamp": datetime.now().isoformat()},
                id=self.message_id
            )
            self.message_id += 1

            return web.json_response({"status": "success", "clients": len(self.clients.get(channel, []))})
        except Exception as e:
            logger.error(f"Broadcast error: {e}")
            return web.json_response({"status": "error", "message": str(e)}, status=400)

    async def stats_handler(self, request: web.Request) -> web.Response:
        """GET /stats: report subscriber counts and the current event id."""
        stats = {
            "total_clients": sum(len(clients) for clients in self.clients.values()),
            "channels": {
                channel: len(clients) for channel, clients in self.clients.items()
            },
            "message_count": self.message_id,
            "status": "running"
        }
        return web.json_response(stats)

    async def broadcast(self, channel: str, event: str, data: Dict, id: Optional[int] = None):
        """Send one event to every subscriber of *channel*; failures are collected, not raised."""
        # Snapshot: handlers may leave the set while the sends are in progress.
        subscribers = list(self.clients.get(channel, ()))
        if not subscribers:
            return

        logger.info(f"Broadcasting to {len(subscribers)} clients in channel '{channel}'")

        await asyncio.gather(
            *(self._send_sse_event(response, event, data, id) for response in subscribers),
            return_exceptions=True,
        )

    async def _send_sse_event(self, response: web.StreamResponse, event: str, data: Dict, id: Optional[int] = None):
        """Write one event in SSE wire format (id/event/data lines, blank-line terminated)."""
        try:
            message = []
            if id is not None:
                message.append(f"id: {id}")
            message.append(f"event: {event}")
            message.append(f"data: {json.dumps(data)}")
            message.append("\n")  # joined with "\n" this yields the terminating blank line

            await response.write("\n".join(message).encode('utf-8'))
        except ConnectionResetError:
            logger.warning("Client disconnected while sending message")
            raise
        except Exception as e:
            logger.error(f"Error sending SSE event: {e}")
            raise

    async def start(self):
        """Start the HTTP site and serve until a shutdown is requested.

        SIGINT/SIGTERM only set ``shutdown_event``. The cleanup itself runs
        here, so it completes before start() returns instead of racing
        asyncio.run() as a detached task.
        """
        self.runner = web.AppRunner(self.app)
        await self.runner.setup()

        self.site = web.TCPSite(self.runner, self.host, self.port)
        await self.site.start()

        logger.info(f"SSE Server started at http://{self.host}:{self.port}")

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, self.shutdown_event.set)
            except NotImplementedError:
                # Not supported by this event loop (e.g. on Windows); Ctrl+C
                # then reaches asyncio.run() as KeyboardInterrupt instead.
                pass

        try:
            await self.shutdown_event.wait()
        finally:
            await self.stop()

    async def stop(self):
        """End all SSE streams, then stop the site and runner. Safe to call twice."""
        if self._stopped:
            return
        self._stopped = True

        logger.info("Shutting down SSE server...")
        self.shutdown_event.set()

        # Snapshot: handlers remove themselves (and emptied channels) from
        # self.clients while we await write_eof().
        for subscribers in list(self.clients.values()):
            for response in list(subscribers):
                try:
                    await response.write_eof()
                except Exception as e:
                    # The peer may already be gone; keep closing the others.
                    logger.warning(f"Error closing SSE stream: {e}")
                subscribers.discard(response)

        if self.site:
            await self.site.stop()
        if self.runner:
            await self.runner.cleanup()

        logger.info("SSE Server stopped")
 
async def main():
    """Start an SSEServer on the default host/port and serve until shutdown."""
    await SSEServer().start()
 
if __name__ == '__main__':
    entry_point = main()
    try:
        asyncio.run(entry_point)
    except KeyboardInterrupt:
        # Exit quietly on Ctrl+C, without a traceback.
        pass

客户端

  • 事件处理:不同事件类型有不同的处理器
  • 指数退避重连:网络问题自动重连
  • 消息恢复:支持 Last-Event-ID 恢复中断的消息
  • 交互式界面:可扩展为命令行工具
  • 完善日志:详细记录连接状态
import asyncio
import aiohttp
import json
from datetime import datetime
import logging
import signal
 
# Named logger for the SSE client; the root configuration below emits INFO and above.
logger = logging.getLogger("SSE-Client")
logging.basicConfig(level=logging.INFO)
 
class SSEClient:
    """Asyncio client for a Server-Sent Events (SSE) endpoint.

    Subscribes to one channel and parses the ``text/event-stream`` line by
    line. Each ``data:`` payload is dispatched to an ``event_<type>`` coroutine
    method, falling back to ``event_default``. The client reconnects with
    exponential backoff and sends ``Last-Event-ID`` so the server can resume
    the stream.
    """

    def __init__(self, server_url: str = 'http://localhost:8080/events',
                 client_id: str = None,
                 channel: str = 'default'):
        """
        Initialize the SSE client.

        Args:
            server_url: URL of the SSE endpoint.
            client_id: Client identifier, sent as the ``X-Client-ID`` header.
                A timestamp-based ID is generated when omitted.
            channel: Channel to subscribe to, sent as the ``channel`` query
                parameter.
        """
        self.server_url = server_url
        self.client_id = client_id or f"client-{datetime.now().timestamp()}"
        self.channel = channel
        self.session = None
        self.running = False
        self.last_event_id = None
        # Type of the event currently being received. Per the SSE spec, an
        # event without an "event:" field has type "message".
        self.current_event = 'message'
        self.reconnect_delay = 1  # initial reconnect delay (seconds)
        self.max_reconnect_delay = 30  # maximum reconnect delay (seconds)
        self.message_count = 0

    def _build_headers(self) -> dict:
        """Return the request headers for one (re)connection attempt.

        Rebuilt on every attempt so ``Last-Event-ID`` carries the most recent
        event id received, allowing the server to resume interrupted streams.
        """
        headers = {
            'Accept': 'text/event-stream',
            'X-Client-ID': self.client_id,
        }
        if self.last_event_id:
            headers['Last-Event-ID'] = str(self.last_event_id)
        return headers

    async def connect(self):
        """Connect to the SSE server and consume events until stopped.

        Reconnects with exponential backoff on non-200 responses, on network
        errors and when the server closes the stream. When the loop ends, the
        client is disconnected and its session closed.
        """
        params = {
            'channel': self.channel
        }

        session = self.session = aiohttp.ClientSession()
        self.running = True

        while self.running:
            try:
                logger.info(f"Connecting to SSE server at {self.server_url}...")

                async with session.get(
                    self.server_url,
                    headers=self._build_headers(),
                    params=params
                ) as response:
                    if response.status != 200:
                        logger.error(f"Connection failed: HTTP {response.status}")
                        await self._handle_reconnect()
                        continue

                    logger.info("Connected to SSE server")
                    self.reconnect_delay = 1  # reset the backoff after a successful connect

                    # Consume the SSE stream line by line.
                    async for line in response.content:
                        if not self.running:
                            break

                        await self._process_sse_line(line)

                # The stream ended without an error. If we were not asked to
                # stop, the server closed it: back off instead of reconnecting
                # in a tight loop.
                if self.running:
                    logger.warning("SSE stream closed by server")
                    await self._handle_reconnect()

            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                logger.error(f"Connection error: {e}")
                await self._handle_reconnect()
            except Exception as e:
                logger.error(f"Unexpected error: {e}")
                await self._handle_reconnect()

        await self.disconnect()

    async def _process_sse_line(self, line: bytes):
        """Parse one raw line of the event stream.

        - ``id:`` stores the event id, sent back as ``Last-Event-ID`` on reconnect.
        - ``event:`` sets the type of the event being received. An empty value
          means ``"message"``.
        - ``data:`` dispatches the payload immediately via ``_handle_sse_event``.
        - A blank line ends the event, and the type falls back to ``"message"``.

        Comment lines (starting with ``:``) and unknown fields are ignored.
        """
        text = line.decode('utf-8').strip()
        if not text:
            # End of event: per the SSE spec the event type does not carry
            # over to the next event.
            self.current_event = 'message'
            return

        if text.startswith('id:'):
            self.last_event_id = text[3:].strip()
        elif text.startswith('event:'):
            self.current_event = text[6:].strip() or 'message'
        elif text.startswith('data:'):
            await self._handle_sse_event(text[5:].strip())

    async def _handle_sse_event(self, data: str):
        """Decode one ``data:`` payload and pass it to the matching handler.

        JSON payloads go to ``event_<current_event>`` (or ``event_default``).
        An ISO-8601 ``timestamp`` field is reformatted for display first.
        Non-JSON payloads are logged as raw text. Handler errors are logged,
        not raised, so one bad event does not drop the connection.
        """
        self.message_count += 1

        try:
            event_data = json.loads(data)
            timestamp = event_data.get('timestamp', '')

            # Format the timestamp for display; keep it as-is if it is not ISO-8601.
            if timestamp:
                try:
                    dt = datetime.fromisoformat(timestamp)
                    timestamp = dt.strftime("%Y-%m-%d %H:%M:%S")
                except ValueError:
                    pass

            # Dispatch on the event type.
            event_type = getattr(self, f"event_{self.current_event}", self.event_default)
            await event_type(event_data, timestamp)

        except json.JSONDecodeError:
            logger.info(f"Received raw message: {data}")
        except Exception as e:
            logger.error(f"Error processing event: {e}")

    async def event_default(self, data: dict, timestamp: str):
        """Fallback handler for event types without a dedicated method."""
        logger.info(f"[{timestamp}] {self.current_event}: {json.dumps(data)}")

    async def event_init(self, data: dict, timestamp: str):
        """Handle the server's ``init`` event."""
        logger.info(f"Server connection initialized. Client ID: {data.get('client_id')}")

    async def event_heartbeat(self, data: dict, timestamp: str):
        """Handle a ``heartbeat`` event (logged at DEBUG level only)."""
        logger.debug(f"Heartbeat received at {timestamp}")

    async def event_message(self, data: dict, timestamp: str):
        """Handle a ``message`` event (also the default SSE event type)."""
        logger.info(f"[{timestamp}] Message: {data.get('message', '')}")

    async def _handle_reconnect(self):
        """Wait before the next connection attempt.

        Sleeps for the current delay, then doubles it up to
        ``max_reconnect_delay`` (exponential backoff). Does nothing once the
        client has been stopped.
        """
        if not self.running:
            return

        logger.info(f"Attempting to reconnect in {self.reconnect_delay} seconds...")
        await asyncio.sleep(self.reconnect_delay)

        # Exponential backoff.
        self.reconnect_delay = min(self.reconnect_delay * 2, self.max_reconnect_delay)

    async def disconnect(self):
        """Stop the client and close its HTTP session; safe to call repeatedly."""
        self.running = False
        if self.session is not None and not self.session.closed:
            await self.session.close()
        logger.info("Disconnected from SSE server")

    async def run_interactive(self):
        """Run the client until SIGINT (where supported) or until it stops."""
        # SIGINT triggers a graceful disconnect. loop.add_signal_handler is
        # Unix-only; elsewhere Ctrl+C surfaces as KeyboardInterrupt instead.
        loop = asyncio.get_running_loop()
        try:
            loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(self.disconnect()))
        except NotImplementedError:
            pass

        # Start the connection loop.
        connect_task = asyncio.create_task(self.connect())

        # Additional interactive logic can be added here.
        try:
            await connect_task
        except asyncio.CancelledError:
            pass
        finally:
            await self.disconnect()
 
async def main():
    """Run a sample SSE client; the client ID and channel are set per instance."""
    client_options = {
        'server_url': 'http://localhost:8080/events',
        'client_id': 'my-client-1',
        'channel': 'notifications',
    }
    await SSEClient(**client_options).run_interactive()
 
if __name__ == '__main__':
    # Script entry point: run the example client until it stops.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C reaching the top level: log it and exit without a traceback.
        logger.info("Client shutdown by user")