mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-10 17:43:35 +08:00
691 lines
27 KiB
Python
691 lines
27 KiB
Python
from typing import List, Dict, Tuple, AsyncIterator, Optional
|
||
import asyncio
|
||
import base64
|
||
import io
|
||
import os
|
||
import struct
|
||
import time
|
||
from models import StorageAdapter
|
||
from api.response import cursor_page
|
||
from telethon import TelegramClient, errors, utils
|
||
from telethon.crypto import AuthKey
|
||
from telethon.sessions import StringSession
|
||
from telethon.tl import types
|
||
import socks
|
||
|
||
|
||
class _NamedFile:
|
||
def __init__(self, file_obj, name: str):
|
||
self._file = file_obj
|
||
self.name = name
|
||
|
||
def read(self, *args, **kwargs):
|
||
return self._file.read(*args, **kwargs)
|
||
|
||
def seek(self, *args, **kwargs):
|
||
return self._file.seek(*args, **kwargs)
|
||
|
||
def tell(self):
|
||
return self._file.tell()
|
||
|
||
def seekable(self):
|
||
return self._file.seekable()
|
||
|
||
def close(self):
|
||
return self._file.close()
|
||
|
||
def __getattr__(self, name):
|
||
return getattr(self._file, name)
|
||
|
||
# Type identifier of this storage adapter.
ADAPTER_TYPE: str = "telegram"
|
||
|
||
# Configuration fields accepted by this adapter. The first four are mandatory
# credentials/target; the proxy_* trio is optional and only used when all three are set.
CONFIG_SCHEMA = [
    dict(key="api_id", label="API ID", type="string", required=True,
         help_text="从 my.telegram.org 获取"),
    dict(key="api_hash", label="API Hash", type="password", required=True,
         help_text="从 my.telegram.org 获取"),
    dict(key="session_string", label="Session String", type="password", required=True,
         help_text="通过 generate_session.py 生成"),
    dict(key="chat_id", label="Chat ID", type="string", required=True,
         placeholder="频道/群组的ID或用户名, 例如: -100123456789 或 'channel_username'"),
    dict(key="proxy_protocol", label="代理协议", type="string", required=False,
         placeholder="例如: socks5, http"),
    dict(key="proxy_host", label="代理主机", type="string", required=False,
         placeholder="例如: 127.0.0.1"),
    dict(key="proxy_port", label="代理端口", type="number", required=False,
         placeholder="例如: 1080"),
]
|
||
|
||
class TelegramAdapter:
|
||
"""Telegram 存储适配器 (使用用户 Session)"""
|
||
native_video_thumbnail_only = True
|
||
_message_cache_ttl = 300
|
||
_message_cache_limit = 200
|
||
_download_chunk_size = 512 * 1024
|
||
|
||
def __init__(self, record: StorageAdapter):
|
||
self.record = record
|
||
cfg = record.config
|
||
self.api_id = int(cfg.get("api_id"))
|
||
self.api_hash = cfg.get("api_hash")
|
||
self.session_string = cfg.get("session_string")
|
||
self.chat_id_str = cfg.get("chat_id")
|
||
|
||
# 代理设置
|
||
self.proxy_protocol = cfg.get("proxy_protocol")
|
||
self.proxy_host = cfg.get("proxy_host")
|
||
self.proxy_port = cfg.get("proxy_port")
|
||
|
||
self.proxy = None
|
||
if self.proxy_protocol and self.proxy_host and self.proxy_port:
|
||
proto_map = {
|
||
"socks5": socks.SOCKS5,
|
||
"http": socks.HTTP,
|
||
}
|
||
proxy_type = proto_map.get(self.proxy_protocol.lower())
|
||
if proxy_type:
|
||
self.proxy = (proxy_type, self.proxy_host, int(self.proxy_port))
|
||
|
||
try:
|
||
self.chat_id = int(self.chat_id_str)
|
||
except (ValueError, TypeError):
|
||
self.chat_id = self.chat_id_str
|
||
|
||
if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
|
||
raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")
|
||
|
||
self._client: TelegramClient | None = None
|
||
self._client_lock = asyncio.Lock()
|
||
self._download_lock = asyncio.Lock()
|
||
self._active_stream_message_id: int | None = None
|
||
self._message_cache: Dict[int, Tuple[float, object]] = {}
|
||
|
||
@staticmethod
|
||
def _parse_legacy_session_string(value: str) -> StringSession:
|
||
"""
|
||
兼容旧版 session_string 格式:
|
||
- version(1B char) + base64(data)
|
||
- data: dc_id(1B) + ip_len(2B) + ip(ASCII, ip_len bytes) + port(2B) + auth_key(256B)
|
||
"""
|
||
s = (value or "").strip()
|
||
if not s:
|
||
raise ValueError("session_string 为空")
|
||
|
||
body = s[1:] if s.startswith("1") else s
|
||
raw = base64.urlsafe_b64decode(body)
|
||
if len(raw) < 1 + 2 + 2 + 256:
|
||
raise ValueError("legacy session 数据长度不足")
|
||
|
||
dc_id = raw[0]
|
||
ip_len = struct.unpack(">H", raw[1:3])[0]
|
||
expected_len = 1 + 2 + ip_len + 2 + 256
|
||
if len(raw) != expected_len:
|
||
raise ValueError("legacy session 数据长度不匹配")
|
||
|
||
ip_start = 3
|
||
ip_end = ip_start + ip_len
|
||
ip = raw[ip_start:ip_end].decode("utf-8")
|
||
port = struct.unpack(">H", raw[ip_end : ip_end + 2])[0]
|
||
key = raw[ip_end + 2 : ip_end + 2 + 256]
|
||
|
||
sess = StringSession()
|
||
sess.set_dc(dc_id, ip, port)
|
||
sess.auth_key = AuthKey(key)
|
||
return sess
|
||
|
||
@staticmethod
|
||
def _pick_photo_thumb(thumbs: list | None):
|
||
if not thumbs:
|
||
return None
|
||
|
||
cached = []
|
||
downloadable = []
|
||
for t in thumbs:
|
||
if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize)):
|
||
cached.append(t)
|
||
elif isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive)):
|
||
if not isinstance(t, types.PhotoSizeEmpty):
|
||
downloadable.append(t)
|
||
|
||
if downloadable:
|
||
def _sz(x):
|
||
if isinstance(x, types.PhotoSizeProgressive):
|
||
return max(x.sizes or [0])
|
||
return int(getattr(x, "size", 0) or 0)
|
||
|
||
downloadable.sort(key=_sz)
|
||
return downloadable[-1]
|
||
|
||
if cached:
|
||
cached.sort(key=lambda x: len(getattr(x, "bytes", b"") or b""))
|
||
return cached[-1]
|
||
|
||
return None
|
||
|
||
@staticmethod
|
||
def _get_message_thumbs(message) -> list:
|
||
doc = message.document or message.video
|
||
if doc and getattr(doc, "thumbs", None):
|
||
return list(doc.thumbs or [])
|
||
if message.photo and getattr(message.photo, "sizes", None):
|
||
return list(message.photo.sizes or [])
|
||
return []
|
||
|
||
@classmethod
|
||
def _message_has_thumbnail(cls, message) -> bool:
|
||
return cls._pick_photo_thumb(cls._get_message_thumbs(message)) is not None
|
||
|
||
def _build_session(self) -> StringSession:
|
||
s = (self.session_string or "").strip()
|
||
if not s:
|
||
raise ValueError("Telegram 适配器 session_string 为空")
|
||
|
||
try:
|
||
return StringSession(s)
|
||
except Exception:
|
||
pass
|
||
|
||
# 少数工具可能去掉了 version 前缀,这里做一次兼容
|
||
if not s.startswith("1"):
|
||
try:
|
||
return StringSession("1" + s)
|
||
except Exception:
|
||
pass
|
||
|
||
try:
|
||
return self._parse_legacy_session_string(s)
|
||
except Exception as exc:
|
||
raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc
|
||
|
||
def _get_client(self) -> TelegramClient:
|
||
"""创建一个新的 TelegramClient 实例"""
|
||
return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)
|
||
|
||
async def _get_connected_client(self) -> TelegramClient:
|
||
async with self._client_lock:
|
||
if self._client is None:
|
||
self._client = self._get_client()
|
||
if not self._client.is_connected():
|
||
await self._client.connect()
|
||
return self._client
|
||
|
||
async def _disconnect_shared_client(self):
|
||
if self._client and self._client.is_connected():
|
||
await self._client.disconnect()
|
||
|
||
def _clear_message_cache(self):
|
||
self._message_cache.clear()
|
||
|
||
async def _get_cached_message(self, message_id: int):
|
||
now = time.monotonic()
|
||
cached = self._message_cache.get(message_id)
|
||
if cached and cached[0] > now:
|
||
return cached[1]
|
||
|
||
client = await self._get_connected_client()
|
||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||
if message:
|
||
if len(self._message_cache) >= self._message_cache_limit:
|
||
oldest_key = min(self._message_cache, key=lambda k: self._message_cache[k][0])
|
||
self._message_cache.pop(oldest_key, None)
|
||
self._message_cache[message_id] = (now + self._message_cache_ttl, message)
|
||
else:
|
||
self._message_cache.pop(message_id, None)
|
||
return message
|
||
|
||
@staticmethod
|
||
def _get_message_media(message):
|
||
return message.document or message.video or message.photo
|
||
|
||
@staticmethod
|
||
def _flood_wait_http_exception(exc: errors.FloodWaitError):
|
||
from fastapi import HTTPException
|
||
|
||
seconds = int(getattr(exc, "seconds", 0) or 0)
|
||
if seconds > 0:
|
||
return HTTPException(
|
||
status_code=429,
|
||
detail=f"Telegram 请求过于频繁,请等待 {seconds} 秒后重试",
|
||
headers={"Retry-After": str(seconds)},
|
||
)
|
||
return HTTPException(status_code=429, detail="Telegram 请求过于频繁,请稍后重试")
|
||
|
||
@staticmethod
|
||
def _get_message_file_size(message, media) -> int:
|
||
file_meta = message.file
|
||
size = file_meta.size if file_meta and file_meta.size is not None else None
|
||
if size is None:
|
||
if hasattr(media, "size") and media.size is not None:
|
||
size = media.size
|
||
elif message.photo and getattr(message.photo, "sizes", None):
|
||
photo_size = message.photo.sizes[-1]
|
||
size = getattr(photo_size, "size", 0) or 0
|
||
else:
|
||
size = 0
|
||
return int(size or 0)
|
||
|
||
@staticmethod
|
||
def _get_message_mime_type(message, media) -> str:
|
||
file_meta = message.file
|
||
if file_meta and getattr(file_meta, "mime_type", None):
|
||
return file_meta.mime_type
|
||
if hasattr(media, "mime_type") and media.mime_type:
|
||
return media.mime_type
|
||
if message.photo:
|
||
return "image/jpeg"
|
||
return "application/octet-stream"
|
||
|
||
@staticmethod
|
||
def _parse_message_id(rel: str) -> int:
|
||
try:
|
||
message_id_str, _ = rel.split('_', 1)
|
||
return int(message_id_str)
|
||
except (ValueError, IndexError):
|
||
raise FileNotFoundError(f"无效的文件路径格式: {rel}")
|
||
|
||
def get_effective_root(self, sub_path: str | None) -> str:
|
||
return ""
|
||
|
||
async def list_dir(
|
||
self,
|
||
root: str,
|
||
rel: str,
|
||
page_num: int = 1,
|
||
page_size: int = 50,
|
||
sort_by: str = "name",
|
||
sort_order: str = "asc",
|
||
cursor: str | None = None,
|
||
):
|
||
if rel:
|
||
return cursor_page([], page_size, cursor=cursor)
|
||
|
||
client = self._get_client()
|
||
entries = []
|
||
next_cursor = None
|
||
try:
|
||
await client.connect()
|
||
offset_id = int(cursor) if cursor else 0
|
||
batch_limit = min(max(page_size, 50), 200)
|
||
while len(entries) < page_size:
|
||
messages = await client.get_messages(self.chat_id, limit=batch_limit, offset_id=offset_id)
|
||
if not messages:
|
||
next_cursor = None
|
||
break
|
||
|
||
offset_id = messages[-1].id
|
||
next_cursor = str(offset_id)
|
||
for message in messages:
|
||
if not message:
|
||
continue
|
||
|
||
media = message.document or message.video or message.photo
|
||
if not media:
|
||
continue
|
||
|
||
file_meta = message.file
|
||
if not file_meta:
|
||
continue
|
||
|
||
filename = file_meta.name
|
||
if not filename:
|
||
if message.text and '.' in message.text and len(message.text) < 256 and '\n' not in message.text:
|
||
filename = message.text
|
||
else:
|
||
filename = f"unknown_{message.id}"
|
||
|
||
size = file_meta.size
|
||
if size is None:
|
||
# 兼容缺失 size 的情况
|
||
if hasattr(media, "size") and media.size is not None:
|
||
size = media.size
|
||
elif message.photo and getattr(message.photo, "sizes", None):
|
||
photo_size = message.photo.sizes[-1]
|
||
size = getattr(photo_size, "size", 0) or 0
|
||
else:
|
||
size = 0
|
||
|
||
entries.append({
|
||
"name": f"{message.id}_{filename}",
|
||
"is_dir": False,
|
||
"size": size,
|
||
"mtime": int(message.date.timestamp()),
|
||
"type": "file",
|
||
"has_thumbnail": False,
|
||
})
|
||
if len(entries) >= page_size:
|
||
break
|
||
finally:
|
||
if client.is_connected():
|
||
await client.disconnect()
|
||
|
||
return cursor_page(entries, page_size, cursor=cursor, next_cursor=next_cursor)
|
||
|
||
async def read_file(self, root: str, rel: str) -> bytes:
|
||
message_id = self._parse_message_id(rel)
|
||
|
||
client = await self._get_connected_client()
|
||
message = await self._get_cached_message(message_id)
|
||
if not message or not self._get_message_media(message):
|
||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||
|
||
try:
|
||
async with self._download_lock:
|
||
file_bytes = await client.download_media(message, file=bytes)
|
||
return file_bytes
|
||
except errors.FloodWaitError as exc:
|
||
await self._disconnect_shared_client()
|
||
raise self._flood_wait_http_exception(exc)
|
||
|
||
async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
|
||
from fastapi import HTTPException
|
||
|
||
message_id = self._parse_message_id(rel)
|
||
client = await self._get_connected_client()
|
||
message = await self._get_cached_message(message_id)
|
||
if not message:
|
||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||
|
||
media = self._get_message_media(message)
|
||
if not media:
|
||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||
|
||
file_size = self._get_message_file_size(message, media)
|
||
if file_size > 0:
|
||
if start >= file_size:
|
||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||
if end is None or end >= file_size:
|
||
end = file_size - 1
|
||
elif end is None:
|
||
end = start
|
||
|
||
if end < start:
|
||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||
|
||
limit = end - start + 1
|
||
data = bytearray()
|
||
try:
|
||
async with self._download_lock:
|
||
async for chunk in client.iter_download(
|
||
media,
|
||
offset=start,
|
||
request_size=self._download_chunk_size,
|
||
chunk_size=self._download_chunk_size,
|
||
file_size=file_size or None,
|
||
):
|
||
if not chunk:
|
||
continue
|
||
need = limit - len(data)
|
||
if need <= 0:
|
||
break
|
||
data.extend(chunk[:need])
|
||
if len(data) >= limit:
|
||
break
|
||
return bytes(data)
|
||
except errors.FloodWaitError as exc:
|
||
await self._disconnect_shared_client()
|
||
raise self._flood_wait_http_exception(exc)
|
||
|
||
async def write_file(self, root: str, rel: str, data: bytes):
|
||
"""将字节数据作为文件上传"""
|
||
client = self._get_client()
|
||
file_like = io.BytesIO(data)
|
||
file_like.name = os.path.basename(rel) or "file"
|
||
|
||
try:
|
||
await client.connect()
|
||
sent = await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||
actual_rel = rel
|
||
if message:
|
||
stored_name = file_like.name
|
||
file_meta = getattr(message, "file", None)
|
||
if file_meta and getattr(file_meta, "name", None):
|
||
stored_name = file_meta.name
|
||
if getattr(message, "id", None) is not None:
|
||
actual_rel = f"{message.id}_{stored_name}"
|
||
self._clear_message_cache()
|
||
return {"rel": actual_rel, "size": len(data)}
|
||
finally:
|
||
if client.is_connected():
|
||
await client.disconnect()
|
||
|
||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||
client = self._get_client()
|
||
name = filename or os.path.basename(rel) or "file"
|
||
file_like = _NamedFile(file_obj, name)
|
||
|
||
try:
|
||
await client.connect()
|
||
sent = await client.send_file(
|
||
self.chat_id,
|
||
file_like,
|
||
caption=file_like.name,
|
||
file_size=file_size,
|
||
mime_type=content_type,
|
||
)
|
||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||
actual_rel = rel
|
||
size = file_size or 0
|
||
if message:
|
||
stored_name = file_like.name
|
||
file_meta = getattr(message, "file", None)
|
||
if file_meta and getattr(file_meta, "name", None):
|
||
stored_name = file_meta.name
|
||
if getattr(message, "id", None) is not None:
|
||
actual_rel = f"{message.id}_{stored_name}"
|
||
self._clear_message_cache()
|
||
if file_meta and getattr(file_meta, "size", None):
|
||
size = int(file_meta.size)
|
||
return {"rel": actual_rel, "size": size}
|
||
finally:
|
||
if client.is_connected():
|
||
await client.disconnect()
|
||
|
||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||
"""以流式方式上传文件"""
|
||
client = self._get_client()
|
||
filename = os.path.basename(rel) or "file"
|
||
import tempfile
|
||
suffix = os.path.splitext(filename)[1]
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||
temp_path = tf.name
|
||
|
||
total_size = 0
|
||
try:
|
||
with open(temp_path, "wb") as f:
|
||
async for chunk in data_iter:
|
||
if chunk:
|
||
f.write(chunk)
|
||
total_size += len(chunk)
|
||
|
||
await client.connect()
|
||
sent = await client.send_file(self.chat_id, temp_path, caption=filename)
|
||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||
actual_rel = rel
|
||
if message:
|
||
stored_name = filename
|
||
file_meta = getattr(message, "file", None)
|
||
if file_meta and getattr(file_meta, "name", None):
|
||
stored_name = file_meta.name
|
||
if getattr(message, "id", None) is not None:
|
||
actual_rel = f"{message.id}_{stored_name}"
|
||
self._clear_message_cache()
|
||
|
||
finally:
|
||
if os.path.exists(temp_path):
|
||
os.remove(temp_path)
|
||
if client.is_connected():
|
||
await client.disconnect()
|
||
return {"rel": actual_rel, "size": total_size}
|
||
|
||
async def mkdir(self, root: str, rel: str):
|
||
raise NotImplementedError("Telegram 适配器不支持创建目录。")
|
||
|
||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||
return None
|
||
|
||
async def delete(self, root: str, rel: str):
|
||
"""删除一个文件 (即一条消息)"""
|
||
try:
|
||
message_id_str, _ = rel.split('_', 1)
|
||
message_id = int(message_id_str)
|
||
except (ValueError, IndexError):
|
||
raise FileNotFoundError(f"无效的文件路径格式,无法解析消息ID: {rel}")
|
||
|
||
client = self._get_client()
|
||
try:
|
||
await client.connect()
|
||
result = await client.delete_messages(self.chat_id, [message_id])
|
||
if not result or not result[0].pts:
|
||
raise FileNotFoundError(f"在 {self.chat_id} 中删除消息 {message_id} 失败,可能消息不存在或无权限")
|
||
self._message_cache.pop(message_id, None)
|
||
finally:
|
||
if client.is_connected():
|
||
await client.disconnect()
|
||
if self._client is client:
|
||
self._client = None
|
||
|
||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||
raise NotImplementedError("Telegram 适配器不支持移动。")
|
||
|
||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||
raise NotImplementedError("Telegram 适配器不支持重命名。")
|
||
|
||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||
raise NotImplementedError("Telegram 适配器不支持复制。")
|
||
|
||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi import HTTPException
|
||
|
||
try:
|
||
message_id = self._parse_message_id(rel)
|
||
except FileNotFoundError:
|
||
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
|
||
|
||
try:
|
||
client = await self._get_connected_client()
|
||
message = await self._get_cached_message(message_id)
|
||
if not message:
|
||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||
media = self._get_message_media(message)
|
||
if not media:
|
||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||
|
||
file_size = self._get_message_file_size(message, media)
|
||
mime_type = self._get_message_mime_type(message, media)
|
||
|
||
start = 0
|
||
end = file_size - 1
|
||
status = 200
|
||
|
||
headers = {
|
||
"Accept-Ranges": "bytes",
|
||
"Content-Type": mime_type,
|
||
}
|
||
|
||
if file_size <= 0:
|
||
headers["Content-Length"] = "0"
|
||
return StreamingResponse(iter(()), status_code=status, headers=headers)
|
||
|
||
if range_header:
|
||
try:
|
||
range_val = range_header.strip().partition("=")[2]
|
||
s, _, e = range_val.partition("-")
|
||
start = int(s) if s else 0
|
||
end = int(e) if e else file_size - 1
|
||
if start >= file_size or end >= file_size or start > end:
|
||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||
status = 206
|
||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||
except ValueError:
|
||
raise HTTPException(status_code=400, detail="Invalid Range header")
|
||
|
||
headers["Content-Length"] = str(end - start + 1)
|
||
self._active_stream_message_id = message_id
|
||
|
||
async def iterator():
|
||
downloaded = 0
|
||
try:
|
||
limit = end - start + 1
|
||
if self._active_stream_message_id != message_id:
|
||
return
|
||
async with self._download_lock:
|
||
async for chunk in client.iter_download(
|
||
media,
|
||
offset=start,
|
||
request_size=self._download_chunk_size,
|
||
chunk_size=self._download_chunk_size,
|
||
file_size=file_size,
|
||
):
|
||
if self._active_stream_message_id != message_id:
|
||
return
|
||
if not chunk:
|
||
continue
|
||
remaining = limit - downloaded
|
||
if remaining <= 0:
|
||
break
|
||
data = chunk[:remaining]
|
||
downloaded += len(data)
|
||
yield data
|
||
if downloaded >= limit:
|
||
break
|
||
except errors.FloodWaitError as exc:
|
||
await self._disconnect_shared_client()
|
||
if downloaded == 0:
|
||
raise self._flood_wait_http_exception(exc)
|
||
seconds = int(getattr(exc, "seconds", 0) or 0)
|
||
print(f"Telegram streaming stopped by FloodWait after partial response, wait={seconds}s")
|
||
return
|
||
except Exception:
|
||
await self._disconnect_shared_client()
|
||
raise
|
||
|
||
agen = iterator()
|
||
try:
|
||
first_chunk = await agen.__anext__()
|
||
except StopAsyncIteration:
|
||
first_chunk = b""
|
||
except HTTPException:
|
||
raise
|
||
|
||
async def response_iterator():
|
||
try:
|
||
if first_chunk:
|
||
yield first_chunk
|
||
async for chunk in agen:
|
||
yield chunk
|
||
finally:
|
||
await agen.aclose()
|
||
|
||
return StreamingResponse(response_iterator(), status_code=status, headers=headers)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except FileNotFoundError as e:
|
||
raise HTTPException(status_code=404, detail=str(e))
|
||
except Exception as e:
|
||
await self._disconnect_shared_client()
|
||
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
|
||
|
||
async def stat_file(self, root: str, rel: str):
|
||
try:
|
||
message_id_str, filename = rel.split('_', 1)
|
||
message_id = int(message_id_str)
|
||
except (ValueError, IndexError):
|
||
raise FileNotFoundError(f"无效的文件路径格式: {rel}")
|
||
|
||
message = await self._get_cached_message(message_id)
|
||
media = self._get_message_media(message) if message else None
|
||
if not message or not media:
|
||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||
|
||
size = self._get_message_file_size(message, media)
|
||
|
||
return {
|
||
"name": rel,
|
||
"is_dir": False,
|
||
"size": size,
|
||
"mtime": int(message.date.timestamp()),
|
||
"type": "file",
|
||
"has_thumbnail": False,
|
||
}
|
||
|
||
def ADAPTER_FACTORY(rec: StorageAdapter) -> TelegramAdapter:
    """Construct a TelegramAdapter for the given storage-adapter record."""
    adapter = TelegramAdapter(rec)
    return adapter
|