From 93d5e5e31399baaaef39fb708894d0932f92b009 Mon Sep 17 00:00:00 2001 From: shiyu Date: Wed, 6 May 2026 22:12:35 +0800 Subject: [PATCH] feat: enhance TelegramAdapter with message caching and connection management --- domain/adapters/providers/telegram.py | 268 ++++++++++++++------------ 1 file changed, 149 insertions(+), 119 deletions(-) diff --git a/domain/adapters/providers/telegram.py b/domain/adapters/providers/telegram.py index 843fbec..06a2e31 100644 --- a/domain/adapters/providers/telegram.py +++ b/domain/adapters/providers/telegram.py @@ -1,8 +1,10 @@ from typing import List, Dict, Tuple, AsyncIterator, Optional +import asyncio import base64 import io import os import struct +import time from models import StorageAdapter from telethon import TelegramClient, utils from telethon.crypto import AuthKey @@ -51,6 +53,9 @@ CONFIG_SCHEMA = [ class TelegramAdapter: """Telegram 存储适配器 (使用用户 Session)""" native_video_thumbnail_only = True + _message_cache_ttl = 300 + _message_cache_limit = 200 + _download_chunk_size = 512 * 1024 def __init__(self, record: StorageAdapter): self.record = record @@ -83,6 +88,10 @@ class TelegramAdapter: if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]): raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id") + self._client: TelegramClient | None = None + self._client_lock = asyncio.Lock() + self._message_cache: Dict[int, Tuple[float, object]] = {} + @staticmethod def _parse_legacy_session_string(value: str) -> StringSession: """ @@ -184,6 +193,67 @@ class TelegramAdapter: """创建一个新的 TelegramClient 实例""" return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy) + async def _get_connected_client(self) -> TelegramClient: + async with self._client_lock: + if self._client is None: + self._client = self._get_client() + if not self._client.is_connected(): + await self._client.connect() + return self._client + + async def _disconnect_shared_client(self): + if self._client and self._client.is_connected(): + await self._client.disconnect() + + def _clear_message_cache(self): + self._message_cache.clear() + + async def _get_cached_message(self, message_id: int): + now = time.monotonic() + cached = self._message_cache.get(message_id) + if cached and cached[0] > now: + return cached[1] + + client = await self._get_connected_client() + message = await client.get_messages(self.chat_id, ids=message_id) + if message: + if len(self._message_cache) >= self._message_cache_limit: + oldest_key = min(self._message_cache, key=lambda k: self._message_cache[k][0]) + self._message_cache.pop(oldest_key, None) + self._message_cache[message_id] = (now + self._message_cache_ttl, message) + else: + self._message_cache.pop(message_id, None) + return message + + @staticmethod + def _get_message_media(message): + return message.document or message.video or message.photo + + @staticmethod + def _get_message_file_size(message, media) -> int: + file_meta = message.file + size = file_meta.size if file_meta and file_meta.size is not None else None + if size is None: + if hasattr(media, "size") and media.size is not None: + size = media.size + elif message.photo and getattr(message.photo, "sizes", None): + photo_size = message.photo.sizes[-1] + size = getattr(photo_size, "size", 0) or 0 + else: + size = 0 + return int(size or 0) + + @staticmethod + def _get_message_mime_type(message, media) -> str: + file_meta = message.file + if file_meta and getattr(file_meta, "mime_type", None): + return file_meta.mime_type + if hasattr(media, "mime_type") and media.mime_type: + return media.mime_type + if message.photo: + return "image/jpeg" + return "application/octet-stream" + @staticmethod def _parse_message_id(rel: str) -> int: try: @@ -274,62 +344,57 @@ class TelegramAdapter: async def read_file(self, root: str, rel: str) -> bytes: message_id = self._parse_message_id(rel) - client = self._get_client() - try: - await client.connect() - message = await client.get_messages(self.chat_id, ids=message_id) - if not message or not (message.document or message.video or message.photo): - raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") - - file_bytes = await client.download_media(message, file=bytes) - return file_bytes - finally: - if client.is_connected(): - await client.disconnect() + client = await self._get_connected_client() + message = await self._get_cached_message(message_id) + if not message or not self._get_message_media(message): + raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") + + file_bytes = await client.download_media(message, file=bytes) + return file_bytes async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes: from fastapi import HTTPException message_id = self._parse_message_id(rel) - client = self._get_client() - try: - await client.connect() - message = await client.get_messages(self.chat_id, ids=message_id) - if not message: - raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") + client = await self._get_connected_client() + message = await self._get_cached_message(message_id) + if not message: + raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") - media = message.document or message.video or message.photo - if not media: - raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") + media = self._get_message_media(message) + if not media: + raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") - file_meta = message.file - file_size = file_meta.size if file_meta and file_meta.size is not None else getattr(media, "size", 0) or 0 - if file_size > 0: - if start >= file_size: - raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable") - if end is None or end >= file_size: - end = file_size - 1 - elif end is None: - end = start - - if end < start: + file_size = self._get_message_file_size(message, media) + if file_size > 0: + if start >= file_size: raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable") + if end is None or end >= file_size: + end = file_size - 1 + elif end is None: + end = start - limit = end - start + 1 - data = bytearray() - async for chunk in client.iter_download(media, offset=start): - if not chunk: - continue - need = limit - len(data) - if need <= 0: - break - data.extend(chunk[:need]) - if len(data) >= limit: - break - return bytes(data) - finally: - if client.is_connected(): - await client.disconnect() + if end < start: + raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable") + + limit = end - start + 1 + data = bytearray() + async for chunk in client.iter_download( + media, + offset=start, + request_size=self._download_chunk_size, + chunk_size=self._download_chunk_size, + file_size=file_size or None, + ): + if not chunk: + continue + need = limit - len(data) + if need <= 0: + break + data.extend(chunk[:need]) + if len(data) >= limit: + break + return bytes(data) async def write_file(self, root: str, rel: str, data: bytes): """将字节数据作为文件上传""" @@ -349,6 +414,7 @@ class TelegramAdapter: stored_name = file_meta.name if getattr(message, "id", None) is not None: actual_rel = f"{message.id}_{stored_name}" + self._clear_message_cache() return {"rel": actual_rel, "size": len(data)} finally: if client.is_connected(): @@ -378,6 +444,7 @@ class TelegramAdapter: stored_name = file_meta.name if getattr(message, "id", None) is not None: actual_rel = f"{message.id}_{stored_name}" + self._clear_message_cache() if file_meta and getattr(file_meta, "size", None): size = int(file_meta.size) return {"rel": actual_rel, "size": size} @@ -413,6 +480,7 @@ class TelegramAdapter: stored_name = file_meta.name if getattr(message, "id", None) is not None: actual_rel = f"{message.id}_{stored_name}" + self._clear_message_cache() finally: if os.path.exists(temp_path): @@ -431,10 +499,9 @@ class TelegramAdapter: except (ValueError, IndexError): return None - client = self._get_client() try: - await client.connect() - message = await client.get_messages(self.chat_id, ids=message_id) + client = await self._get_connected_client() + message = await self._get_cached_message(message_id) if not message: return None @@ -454,9 +521,6 @@ class TelegramAdapter: return None except Exception: return None - finally: - if client.is_connected(): - await client.disconnect() async def delete(self, root: str, rel: str): """删除一个文件 (即一条消息)""" @@ -472,9 +536,12 @@ class TelegramAdapter: result = await client.delete_messages(self.chat_id, [message_id]) if not result or not result[0].pts: raise FileNotFoundError(f"在 {self.chat_id} 中删除消息 {message_id} 失败,可能消息不存在或无权限") + self._message_cache.pop(message_id, None) finally: if client.is_connected(): await client.disconnect() + if self._client is client: + self._client = None async def move(self, root: str, src_rel: str, dst_rel: str): raise NotImplementedError("Telegram 适配器不支持移动。") @@ -494,38 +561,17 @@ class TelegramAdapter: except FileNotFoundError: raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}") - client = self._get_client() - try: - await client.connect() - message = await client.get_messages(self.chat_id, ids=message_id) + client = await self._get_connected_client() + message = await self._get_cached_message(message_id) if not message: raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") - media = message.document or message.video or message.photo + media = self._get_message_media(message) if not media: raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") - file_meta = message.file - file_size = file_meta.size if file_meta and file_meta.size is not None else None - if file_size is None: - if hasattr(media, "size") and media.size is not None: - file_size = media.size - elif message.photo and getattr(message.photo, "sizes", None): - photo_size = message.photo.sizes[-1] - file_size = getattr(photo_size, "size", 0) or 0 - else: - file_size = 0 - - mime_type = None - if file_meta and getattr(file_meta, "mime_type", None): - mime_type = file_meta.mime_type - if not mime_type: - if hasattr(media, "mime_type") and media.mime_type: - mime_type = media.mime_type - elif message.photo: - mime_type = "image/jpeg" - else: - mime_type = "application/octet-stream" + file_size = self._get_message_file_size(message, media) + mime_type = self._get_message_mime_type(message, media) start = 0 end = file_size - 1 @@ -538,8 +584,6 @@ class TelegramAdapter: if file_size <= 0: headers["Content-Length"] = "0" - if client.is_connected(): - await client.disconnect() return StreamingResponse(iter(()), status_code=status, headers=headers) if range_header: @@ -562,7 +606,13 @@ class TelegramAdapter: limit = end - start + 1 downloaded = 0 - async for chunk in client.iter_download(media, offset=start): + async for chunk in client.iter_download( + media, + offset=start, + request_size=self._download_chunk_size, + chunk_size=self._download_chunk_size, + file_size=file_size, + ): if downloaded + len(chunk) > limit: yield chunk[:limit - downloaded] break @@ -570,23 +620,18 @@ class TelegramAdapter: downloaded += len(chunk) if downloaded >= limit: break - finally: - if client.is_connected(): - await client.disconnect() + except Exception: + await self._disconnect_shared_client() + raise return StreamingResponse(iterator(), status_code=status, headers=headers) except HTTPException: - if client.is_connected(): - await client.disconnect() raise except FileNotFoundError as e: - if client.is_connected(): - await client.disconnect() raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - if client.is_connected(): - await client.disconnect() + await self._disconnect_shared_client() raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}") async def stat_file(self, root: str, rel: str): @@ -596,36 +641,21 @@ class TelegramAdapter: except (ValueError, IndexError): raise FileNotFoundError(f"无效的文件路径格式: {rel}") - client = self._get_client() - try: - await client.connect() - message = await client.get_messages(self.chat_id, ids=message_id) - media = message.document or message.video or message.photo - if not message or not media: - raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") + message = await self._get_cached_message(message_id) + media = self._get_message_media(message) if message else None + if not message or not media: + raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件") - file_meta = message.file - size = file_meta.size if file_meta and file_meta.size is not None else None - if size is None: - if hasattr(media, "size") and media.size is not None: - size = media.size - elif message.photo and getattr(message.photo, "sizes", None): - photo_size = message.photo.sizes[-1] - size = getattr(photo_size, "size", 0) or 0 - else: - size = 0 + size = self._get_message_file_size(message, media) - return { - "name": rel, - "is_dir": False, - "size": size, - "mtime": int(message.date.timestamp()), - "type": "file", - "has_thumbnail": self._message_has_thumbnail(message), - } - finally: - if client.is_connected(): - await client.disconnect() + return { + "name": rel, + "is_dir": False, + "size": size, + "mtime": int(message.date.timestamp()), + "type": "file", + "has_thumbnail": self._message_has_thumbnail(message), + } def ADAPTER_FACTORY(rec: StorageAdapter) -> TelegramAdapter: return TelegramAdapter(rec)