from typing import List, Dict, Tuple, AsyncIterator, Optional
import asyncio
import base64
import io
import logging
import os
import struct
import tempfile
import time

from models import StorageAdapter
from api.response import cursor_page
from telethon import TelegramClient, errors, utils
from telethon.crypto import AuthKey
from telethon.sessions import StringSession
from telethon.tl import types
import socks

logger = logging.getLogger(__name__)


class _NamedFile:
    """Wrap a file-like object and expose a caller-chosen ``name`` attribute.

    The adapter hands this wrapper to ``send_file`` so the upload carries the
    requested file name instead of whatever name the underlying object has
    (presumably Telethon derives the document's filename from ``.name``).
    All other attribute access is delegated to the wrapped object.
    """

    def __init__(self, file_obj, name: str):
        self._file = file_obj
        self.name = name

    def read(self, *args, **kwargs):
        return self._file.read(*args, **kwargs)

    def seek(self, *args, **kwargs):
        return self._file.seek(*args, **kwargs)

    def tell(self):
        return self._file.tell()

    def seekable(self):
        return self._file.seekable()

    def close(self):
        return self._file.close()

    def __getattr__(self, name):
        return getattr(self._file, name)


# Adapter type identifier
ADAPTER_TYPE = "telegram"

# Configuration fields accepted by this adapter (help/placeholder texts are user-facing)
CONFIG_SCHEMA = [
    {"key": "api_id", "label": "API ID", "type": "string", "required": True, "help_text": "从 my.telegram.org 获取"},
    {"key": "api_hash", "label": "API Hash", "type": "password", "required": True, "help_text": "从 my.telegram.org 获取"},
    {"key": "session_string", "label": "Session String", "type": "password", "required": True, "help_text": "通过 generate_session.py 生成"},
    {"key": "chat_id", "label": "Chat ID", "type": "string", "required": True, "placeholder": "频道/群组的ID或用户名, 例如: -100123456789 或 'channel_username'"},
    {"key": "proxy_protocol", "label": "代理协议", "type": "string", "required": False, "placeholder": "例如: socks5, http"},
    {"key": "proxy_host", "label": "代理主机", "type": "string", "required": False, "placeholder": "例如: 127.0.0.1"},
    {"key": "proxy_port", "label": "代理端口", "type": "number", "required": False, "placeholder": "例如: 1080"},
]


class TelegramAdapter:
    """Telegram storage adapter (authenticates with a user session string).

    Every file is one message in a single chat/channel. A file's relative path
    is ``"{message_id}_{filename}"``; only the numeric prefix is used to locate it.
    """

    # Flag exposed to callers outside this file; not read within this class.
    native_video_thumbnail_only = True
    # Seconds a fetched message stays valid in ``_message_cache`` (monotonic clock).
    _message_cache_ttl = 300
    # Maximum number of cached messages before the earliest-expiring one is evicted.
    _message_cache_limit = 200
    # Bytes requested per ``iter_download`` call.
    _download_chunk_size = 512 * 1024

    def __init__(self, record: StorageAdapter):
        """Read and validate the adapter configuration from ``record.config``.

        Raises:
            ValueError: api_id / api_hash / session_string / chat_id is missing or invalid.
        """
        self.record = record
        cfg = record.config

        # Parse api_id defensively so a missing or non-numeric value ends up in
        # the adapter's own validation error below instead of a bare TypeError.
        try:
            self.api_id = int(cfg.get("api_id"))
        except (TypeError, ValueError):
            self.api_id = None
        self.api_hash = cfg.get("api_hash")
        self.session_string = cfg.get("session_string")
        self.chat_id_str = cfg.get("chat_id")

        # Optional proxy settings; an unknown protocol leaves the proxy disabled.
        self.proxy_protocol = cfg.get("proxy_protocol")
        self.proxy_host = cfg.get("proxy_host")
        self.proxy_port = cfg.get("proxy_port")
        self.proxy = None
        if self.proxy_protocol and self.proxy_host and self.proxy_port:
            proto_map = {
                "socks5": socks.SOCKS5,
                "socks4": socks.SOCKS4,
                "http": socks.HTTP,
            }
            proxy_type = proto_map.get(str(self.proxy_protocol).strip().lower())
            if proxy_type:
                self.proxy = (proxy_type, self.proxy_host, int(self.proxy_port))

        # Numeric chat ids (e.g. -100...) are used as ints; anything else is
        # passed through as a username string.
        try:
            self.chat_id = int(self.chat_id_str)
        except (ValueError, TypeError):
            self.chat_id = self.chat_id_str

        if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
            raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")

        self._client: TelegramClient | None = None
        self._client_lock = asyncio.Lock()
        # Serializes downloads on the shared client.
        self._download_lock = asyncio.Lock()
        # Message id of the most recently started stream; older streams stop when it changes.
        self._active_stream_message_id: int | None = None
        # message_id -> (expiry timestamp on time.monotonic(), message)
        self._message_cache: Dict[int, Tuple[float, object]] = {}

    @staticmethod
    def _parse_legacy_session_string(value: str) -> StringSession:
        """Build a StringSession from the legacy session_string layout.

        Layout: optional version char ``"1"`` followed by urlsafe base64 of
        ``dc_id(1B) + ip_len(2B, big-endian) + ip(ASCII, ip_len bytes) + port(2B) + auth_key(256B)``.
        Missing base64 padding is tolerated.

        Raises:
            ValueError: the string is empty, not valid base64, or has the wrong length.
        """
        s = (value or "").strip()
        if not s:
            raise ValueError("session_string 为空")
        body = s[1:] if s.startswith("1") else s
        # Some tools strip the trailing "=" padding; restore it before decoding.
        body += "=" * (-len(body) % 4)
        raw = base64.urlsafe_b64decode(body)
        if len(raw) < 1 + 2 + 2 + 256:
            raise ValueError("legacy session 数据长度不足")

        dc_id = raw[0]
        ip_len = struct.unpack(">H", raw[1:3])[0]
        expected_len = 1 + 2 + ip_len + 2 + 256
        if len(raw) != expected_len:
            raise ValueError("legacy session 数据长度不匹配")

        ip_start = 3
        ip_end = ip_start + ip_len
        ip = raw[ip_start:ip_end].decode("utf-8")
        port = struct.unpack(">H", raw[ip_end: ip_end + 2])[0]
        key = raw[ip_end + 2: ip_end + 2 + 256]

        sess = StringSession()
        sess.set_dc(dc_id, ip, port)
        sess.auth_key = AuthKey(key)
        return sess

    @staticmethod
    def _pick_photo_thumb(thumbs: list | None):
        """Choose the best thumbnail from a list of Telegram photo sizes.

        Prefers the largest downloadable size (PhotoSize / PhotoSizeProgressive);
        otherwise falls back to the largest inline one (PhotoCachedSize /
        PhotoStrippedSize). Returns None when nothing usable is present.
        """
        if not thumbs:
            return None

        downloadable = [t for t in thumbs if isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive))]
        if downloadable:
            def _byte_size(thumb):
                if isinstance(thumb, types.PhotoSizeProgressive):
                    return max(thumb.sizes or [0])
                return int(getattr(thumb, "size", 0) or 0)

            return sorted(downloadable, key=_byte_size)[-1]

        cached = [t for t in thumbs if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize))]
        if cached:
            return sorted(cached, key=lambda t: len(getattr(t, "bytes", b"") or b""))[-1]
        return None

    @staticmethod
    def _get_message_thumbs(message) -> list:
        """Return the thumbnail/size list of a message's document/video or photo ([] if none)."""
        doc = message.document or message.video
        if doc and getattr(doc, "thumbs", None):
            return list(doc.thumbs or [])
        if message.photo and getattr(message.photo, "sizes", None):
            return list(message.photo.sizes or [])
        return []

    @classmethod
    def _message_has_thumbnail(cls, message) -> bool:
        """True if the message carries at least one usable thumbnail."""
        return cls._pick_photo_thumb(cls._get_message_thumbs(message)) is not None

    def _build_session(self) -> StringSession:
        """Turn ``self.session_string`` into a StringSession, trying several formats.

        Order: Telethon's native format, the same with the ``"1"`` version prefix
        re-added, then the legacy layout.

        Raises:
            ValueError: the string is empty or matches none of the formats.
        """
        s = (self.session_string or "").strip()
        if not s:
            raise ValueError("Telegram 适配器 session_string 为空")

        try:
            return StringSession(s)
        except Exception:
            pass  # fall through to the compatibility formats below

        # Some tools drop the version prefix; retry once with it restored.
        if not s.startswith("1"):
            try:
                return StringSession("1" + s)
            except Exception:
                pass

        try:
            return self._parse_legacy_session_string(s)
        except Exception as exc:
            raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc

    def _get_client(self) -> TelegramClient:
        """Create a new, not-yet-connected TelegramClient."""
        return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)

    async def _get_connected_client(self) -> TelegramClient:
        """Return the shared client, creating and (re)connecting it under ``_client_lock``."""
        async with self._client_lock:
            if self._client is None:
                self._client = self._get_client()
            if not self._client.is_connected():
                await self._client.connect()
            return self._client

    async def _disconnect_shared_client(self):
        """Disconnect the shared client if connected; it is reconnected on next use."""
        if self._client and self._client.is_connected():
            await self._client.disconnect()

    def _clear_message_cache(self):
        self._message_cache.clear()

    async def _get_cached_message(self, message_id: int):
        """Fetch a message from the configured chat, using a small TTL cache.

        A hit is returned while its expiry is in the future. A falsy fetch result
        is not cached and removes any stale entry for that id.
        """
        now = time.monotonic()
        cached = self._message_cache.get(message_id)
        if cached and cached[0] > now:
            return cached[1]

        client = await self._get_connected_client()
        message = await client.get_messages(self.chat_id, ids=message_id)
        if not message:
            self._message_cache.pop(message_id, None)
            return message

        cache = self._message_cache
        # Only evict when a new key would grow the cache; refreshing an expired
        # entry reuses its slot.
        if message_id not in cache and len(cache) >= self._message_cache_limit:
            # Earliest expiry first, so already-expired entries are dropped before live ones.
            oldest_key = min(cache, key=lambda k: cache[k][0])
            cache.pop(oldest_key, None)
        cache[message_id] = (now + self._message_cache_ttl, message)
        return message

    @staticmethod
    def _get_message_media(message):
        """Return the message's document, video or photo (first one present), else None."""
        return message.document or message.video or message.photo

    @staticmethod
    def _flood_wait_http_exception(exc: errors.FloodWaitError):
        """Translate a Telethon FloodWaitError into an HTTP 429 (with Retry-After when known)."""
        from fastapi import HTTPException

        seconds = int(getattr(exc, "seconds", 0) or 0)
        if seconds > 0:
            return HTTPException(
                status_code=429,
                detail=f"Telegram 请求过于频繁,请等待 {seconds} 秒后重试",
                headers={"Retry-After": str(seconds)},
            )
        return HTTPException(status_code=429, detail="Telegram 请求过于频繁,请稍后重试")

    @staticmethod
    def _get_message_file_size(message, media) -> int:
        """Best-effort byte size of a message's media; 0 when unknown."""
        file_meta = message.file
        if file_meta and file_meta.size is not None:
            return int(file_meta.size or 0)
        if hasattr(media, "size") and media.size is not None:
            return int(media.size or 0)
        if message.photo and getattr(message.photo, "sizes", None):
            photo_size = message.photo.sizes[-1]
            # Progressive sizes have no .size; their largest step is the full image.
            if isinstance(photo_size, types.PhotoSizeProgressive):
                return int(max(photo_size.sizes or [0]))
            return int(getattr(photo_size, "size", 0) or 0)
        return 0

    @staticmethod
    def _get_message_mime_type(message, media) -> str:
        """MIME type of the media; photos default to image/jpeg, otherwise octet-stream."""
        file_meta = message.file
        if file_meta and getattr(file_meta, "mime_type", None):
            return file_meta.mime_type
        if hasattr(media, "mime_type") and media.mime_type:
            return media.mime_type
        if message.photo:
            return "image/jpeg"
        return "application/octet-stream"

    @staticmethod
    def _parse_message_id(rel: str) -> int:
        """Extract the message id from ``"{message_id}_{filename}"``.

        Raises:
            FileNotFoundError: ``rel`` has no ``_`` or a non-integer prefix.
        """
        try:
            message_id_str, _ = rel.split('_', 1)
            return int(message_id_str)
        except (ValueError, IndexError):
            raise FileNotFoundError(f"无效的文件路径格式: {rel}") from None

    @staticmethod
    def _parse_range_header(range_header: str, file_size: int) -> Optional[Tuple[int, int]]:
        """Parse a single-range ``bytes=`` header against a known positive size.

        Returns the inclusive ``(start, end)`` to serve, or None when the header
        uses a unit other than ``bytes`` (the header is then ignored). Supports
        ``a-b``, open-ended ``a-`` and suffix ``-n`` forms; an end past EOF is clamped.

        Raises:
            ValueError: malformed header (including multi-range lists).
            HTTPException: 416 when the range cannot be satisfied.
        """
        from fastapi import HTTPException

        unit, sep, spec = range_header.strip().partition("=")
        if not sep or unit.strip().lower() != "bytes":
            return None

        first, dash, last = spec.strip().partition("-")
        first, last = first.strip(), last.strip()
        if not dash or not (first or last):
            raise ValueError(f"invalid byte range: {spec!r}")
        if (first and not first.isdigit()) or (last and not last.isdigit()):
            raise ValueError(f"invalid byte range: {spec!r}")

        def _unsatisfiable():
            return HTTPException(
                status_code=416,
                detail="Requested Range Not Satisfiable",
                headers={"Content-Range": f"bytes */{file_size}"},
            )

        if not first:
            # Suffix range "bytes=-N": the last N bytes of the file.
            suffix_len = int(last)
            if suffix_len == 0:
                raise _unsatisfiable()
            return max(file_size - suffix_len, 0), file_size - 1

        start = int(first)
        end = int(last) if last else file_size - 1
        if start >= file_size or end < start:
            raise _unsatisfiable()
        return start, min(end, file_size - 1)

    @staticmethod
    def _first_sent_message(sent):
        """``send_file`` may return a list (album) or a single message; take the first."""
        return sent[0] if isinstance(sent, list) and sent else sent

    @staticmethod
    def _rel_for_sent_message(message, rel: str, default_name: str) -> str:
        """Build ``"{id}_{stored_name}"`` for an uploaded message, falling back to ``rel``."""
        if not message:
            return rel
        stored_name = default_name
        file_meta = getattr(message, "file", None)
        if file_meta and getattr(file_meta, "name", None):
            stored_name = file_meta.name
        if getattr(message, "id", None) is not None:
            return f"{message.id}_{stored_name}"
        return rel

    @classmethod
    def _message_to_entry(cls, message) -> dict | None:
        """Convert a chat message into a directory entry, or None if it carries no file."""
        media = cls._get_message_media(message)
        file_meta = message.file
        if not media or not file_meta:
            return None

        filename = file_meta.name
        if not filename:
            text = message.text
            # A short single-line caption containing a dot is taken as the file name.
            if text and '.' in text and len(text) < 256 and '\n' not in text:
                filename = text
            else:
                filename = f"unknown_{message.id}"

        return {
            "name": f"{message.id}_{filename}",
            "is_dir": False,
            "size": cls._get_message_file_size(message, media),
            "mtime": int(message.date.timestamp()),
            "type": "file",
            "has_thumbnail": False,
        }

    def get_effective_root(self, sub_path: str | None) -> str:
        return ""

    async def list_dir(
        self,
        root: str,
        rel: str,
        page_num: int = 1,
        page_size: int = 50,
        sort_by: str = "name",
        sort_order: str = "asc",
        cursor: str | None = None,
    ):
        """List file-carrying messages of the chat, newest first, cursor-paginated.

        The chat is flat: any non-empty ``rel`` yields an empty page. ``cursor``
        is the id of the last message consumed by the previous page;
        ``page_num``/``sort_by``/``sort_order`` are accepted but not used.

        Raises:
            HTTPException: 429 when Telegram asks us to back off (FloodWait).
        """
        if rel:
            return cursor_page([], page_size, cursor=cursor)

        client = self._get_client()
        entries = []
        next_cursor = None
        try:
            await client.connect()
            offset_id = int(cursor) if cursor else 0
            batch_limit = min(max(page_size, 50), 200)
            while len(entries) < page_size:
                messages = await client.get_messages(self.chat_id, limit=batch_limit, offset_id=offset_id)
                if not messages:
                    next_cursor = None
                    break

                last_consumed_id = None
                for message in messages:
                    if not message:
                        continue
                    # Track every message we look at, so a page that fills up in
                    # the middle of a batch resumes right after it next time
                    # instead of skipping the rest of the batch.
                    last_consumed_id = message.id
                    entry = self._message_to_entry(message)
                    if entry is None:
                        continue
                    entries.append(entry)
                    if len(entries) >= page_size:
                        break

                if last_consumed_id is None:
                    next_cursor = None
                    break
                offset_id = last_consumed_id
                next_cursor = str(offset_id)
        except errors.FloodWaitError as exc:
            raise self._flood_wait_http_exception(exc)
        finally:
            if client.is_connected():
                await client.disconnect()

        return cursor_page(entries, page_size, cursor=cursor, next_cursor=next_cursor)

    async def read_file(self, root: str, rel: str) -> bytes:
        """Download the whole file into memory.

        Raises:
            FileNotFoundError: bad path or the message has no media.
            HTTPException: 429 on FloodWait (the shared client is disconnected first).
        """
        message_id = self._parse_message_id(rel)
        client = await self._get_connected_client()
        message = await self._get_cached_message(message_id)
        if not message or not self._get_message_media(message):
            raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
        try:
            async with self._download_lock:
                file_bytes = await client.download_media(message, file=bytes)
            return file_bytes
        except errors.FloodWaitError as exc:
            await self._disconnect_shared_client()
            raise self._flood_wait_http_exception(exc)

    async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
        """Read bytes ``start..end`` (inclusive) of the file.

        With a known size, ``end`` defaults to (and is clamped to) the last byte.
        With an unknown size (0) and no ``end``, only the single byte at
        ``start`` is read — presumably a deliberate cap; confirm before relying on it.

        Raises:
            FileNotFoundError: bad path or the message has no media.
            HTTPException: 416 for an unsatisfiable range, 429 on FloodWait.
        """
        from fastapi import HTTPException

        message_id = self._parse_message_id(rel)
        client = await self._get_connected_client()
        message = await self._get_cached_message(message_id)
        if not message:
            raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
        media = self._get_message_media(message)
        if not media:
            raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")

        file_size = self._get_message_file_size(message, media)
        if file_size > 0:
            if start >= file_size:
                raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
            if end is None or end >= file_size:
                end = file_size - 1
        elif end is None:
            end = start
        if end < start:
            raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")

        limit = end - start + 1
        data = bytearray()
        try:
            async with self._download_lock:
                async for chunk in client.iter_download(
                    media,
                    offset=start,
                    request_size=self._download_chunk_size,
                    chunk_size=self._download_chunk_size,
                    file_size=file_size or None,
                ):
                    if not chunk:
                        continue
                    need = limit - len(data)
                    if need <= 0:
                        break
                    data.extend(chunk[:need])
                    if len(data) >= limit:
                        break
            return bytes(data)
        except errors.FloodWaitError as exc:
            await self._disconnect_shared_client()
            raise self._flood_wait_http_exception(exc)

    async def write_file(self, root: str, rel: str, data: bytes):
        """Upload in-memory bytes as a new message; returns the stored ``rel`` and size."""
        client = self._get_client()
        file_like = io.BytesIO(data)
        file_like.name = os.path.basename(rel) or "file"
        try:
            await client.connect()
            sent = await client.send_file(self.chat_id, file_like, caption=file_like.name)
            message = self._first_sent_message(sent)
            actual_rel = self._rel_for_sent_message(message, rel, file_like.name)
            self._clear_message_cache()
            return {"rel": actual_rel, "size": len(data)}
        finally:
            if client.is_connected():
                await client.disconnect()

    async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
        """Upload a file object (e.g. a request upload) as a new message.

        The reported size prefers the uploaded message's own size, then ``file_size``, then 0.
        """
        client = self._get_client()
        name = filename or os.path.basename(rel) or "file"
        file_like = _NamedFile(file_obj, name)
        try:
            await client.connect()
            sent = await client.send_file(
                self.chat_id,
                file_like,
                caption=file_like.name,
                file_size=file_size,
                mime_type=content_type,
            )
            message = self._first_sent_message(sent)
            actual_rel = self._rel_for_sent_message(message, rel, file_like.name)
            self._clear_message_cache()

            size = file_size or 0
            file_meta = getattr(message, "file", None) if message else None
            if file_meta and getattr(file_meta, "size", None):
                size = int(file_meta.size)
            return {"rel": actual_rel, "size": size}
        finally:
            if client.is_connected():
                await client.disconnect()

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        """Spool an async byte stream to a temp file, then upload it as a new message.

        The temp file is uploaded through ``_NamedFile`` so the stored document
        keeps the requested file name rather than the random temp-file name.
        The temp file is always removed.
        """
        client = self._get_client()
        filename = os.path.basename(rel) or "file"
        suffix = os.path.splitext(filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
            temp_path = tf.name

        total_size = 0
        try:
            with open(temp_path, "wb") as f:
                async for chunk in data_iter:
                    if chunk:
                        f.write(chunk)
                        total_size += len(chunk)

            await client.connect()
            with open(temp_path, "rb") as raw:
                upload = _NamedFile(raw, filename)
                sent = await client.send_file(
                    self.chat_id,
                    upload,
                    caption=filename,
                    file_size=total_size or None,
                )
            message = self._first_sent_message(sent)
            actual_rel = self._rel_for_sent_message(message, rel, filename)
            self._clear_message_cache()
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            if client.is_connected():
                await client.disconnect()

        return {"rel": actual_rel, "size": total_size}

    async def mkdir(self, root: str, rel: str):
        raise NotImplementedError("Telegram 适配器不支持创建目录。")

    async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
        """Thumbnails are not served by this adapter; always returns None."""
        return None

    async def delete(self, root: str, rel: str):
        """Delete a file, i.e. the message that carries it.

        Raises:
            FileNotFoundError: bad path, or Telegram reports that no message was deleted.
        """
        try:
            message_id_str, _ = rel.split('_', 1)
            message_id = int(message_id_str)
        except (ValueError, IndexError):
            raise FileNotFoundError(f"无效的文件路径格式,无法解析消息ID: {rel}") from None

        client = self._get_client()
        try:
            await client.connect()
            results = await client.delete_messages(self.chat_id, [message_id])
            # Each AffectedMessages result counts removed messages in pts_count;
            # pts itself is the chat's update sequence number and is non-zero
            # even when nothing was deleted, so it cannot signal failure.
            deleted = sum(int(getattr(r, "pts_count", 0) or 0) for r in (results or []))
            if deleted <= 0:
                raise FileNotFoundError(f"在 {self.chat_id} 中删除消息 {message_id} 失败,可能消息不存在或无权限")
            self._message_cache.pop(message_id, None)
        finally:
            if client.is_connected():
                await client.disconnect()

    async def move(self, root: str, src_rel: str, dst_rel: str):
        raise NotImplementedError("Telegram 适配器不支持移动。")

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        raise NotImplementedError("Telegram 适配器不支持重命名。")

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        raise NotImplementedError("Telegram 适配器不支持复制。")

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        """Stream a file as an HTTP response, honouring a single ``bytes=`` Range.

        Starting a new stream makes any older stream on this adapter stop at its
        next chunk (see ``_active_stream_message_id``). The first chunk is fetched
        before the response is returned so early errors still map to HTTP codes.

        Raises:
            HTTPException: 400 bad path/Range, 404 missing file, 416 unsatisfiable
                range, 429 FloodWait before any data, 500 other failures.
        """
        from fastapi.responses import StreamingResponse
        from fastapi import HTTPException

        try:
            message_id = self._parse_message_id(rel)
        except FileNotFoundError:
            raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}") from None

        try:
            client = await self._get_connected_client()
            message = await self._get_cached_message(message_id)
            if not message:
                raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
            media = self._get_message_media(message)
            if not media:
                raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")

            file_size = self._get_message_file_size(message, media)
            mime_type = self._get_message_mime_type(message, media)
            start = 0
            end = file_size - 1
            status = 200
            headers = {
                "Accept-Ranges": "bytes",
                "Content-Type": mime_type,
            }

            if file_size <= 0:
                headers["Content-Length"] = "0"
                return StreamingResponse(iter(()), status_code=status, headers=headers)

            if range_header:
                try:
                    byte_range = self._parse_range_header(range_header, file_size)
                except ValueError:
                    raise HTTPException(status_code=400, detail="Invalid Range header") from None
                if byte_range is not None:
                    start, end = byte_range
                    status = 206
                    headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"

            headers["Content-Length"] = str(end - start + 1)
            self._active_stream_message_id = message_id

            async def iterator():
                downloaded = 0
                try:
                    limit = end - start + 1
                    if self._active_stream_message_id != message_id:
                        return
                    async with self._download_lock:
                        async for chunk in client.iter_download(
                            media,
                            offset=start,
                            request_size=self._download_chunk_size,
                            chunk_size=self._download_chunk_size,
                            file_size=file_size,
                        ):
                            # A newer stream took over; stop so it can acquire the lock.
                            if self._active_stream_message_id != message_id:
                                return
                            if not chunk:
                                continue
                            remaining = limit - downloaded
                            if remaining <= 0:
                                break
                            data = chunk[:remaining]
                            downloaded += len(data)
                            yield data
                            if downloaded >= limit:
                                break
                except errors.FloodWaitError as exc:
                    await self._disconnect_shared_client()
                    if downloaded == 0:
                        raise self._flood_wait_http_exception(exc)
                    # Headers are already sent; all we can do is end the body early.
                    seconds = int(getattr(exc, "seconds", 0) or 0)
                    logger.warning(
                        "Telegram streaming stopped by FloodWait after partial response, wait=%ss",
                        seconds,
                    )
                    return
                except Exception:
                    await self._disconnect_shared_client()
                    raise

            agen = iterator()
            try:
                first_chunk = await agen.__anext__()
            except StopAsyncIteration:
                first_chunk = b""

            async def response_iterator():
                try:
                    if first_chunk:
                        yield first_chunk
                    async for chunk in agen:
                        yield chunk
                finally:
                    await agen.aclose()

            return StreamingResponse(response_iterator(), status_code=status, headers=headers)
        except HTTPException:
            raise
        except FileNotFoundError as e:
            raise HTTPException(status_code=404, detail=str(e)) from e
        except Exception as e:
            await self._disconnect_shared_client()
            raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}") from e

    async def stat_file(self, root: str, rel: str):
        """Return metadata for a single file; ``name`` echoes ``rel``.

        Raises:
            FileNotFoundError: bad path or the message has no media.
            HTTPException: 429 on FloodWait (the shared client is disconnected first).
        """
        message_id = self._parse_message_id(rel)
        try:
            message = await self._get_cached_message(message_id)
        except errors.FloodWaitError as exc:
            await self._disconnect_shared_client()
            raise self._flood_wait_http_exception(exc)

        media = self._get_message_media(message) if message else None
        if not message or not media:
            raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")

        return {
            "name": rel,
            "is_dir": False,
            "size": self._get_message_file_size(message, media),
            "mtime": int(message.date.timestamp()),
            "type": "file",
            "has_thumbnail": False,
        }


def ADAPTER_FACTORY(rec: StorageAdapter) -> TelegramAdapter:
    """Factory used by the adapter registry to build a TelegramAdapter from its DB record."""
    return TelegramAdapter(rec)