From e7eafdee973770fdd9aceb04adbc2badf5b6e6b2 Mon Sep 17 00:00:00 2001 From: shiyu Date: Sun, 11 Jan 2026 14:08:52 +0800 Subject: [PATCH] feat: add session locking mechanism in Telegram adapter and improve SPA fallback handling --- domain/adapters/providers/telegram.py | 29 +++++++++++++++--- main.py | 43 +++++++++++++++++---------- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/domain/adapters/providers/telegram.py b/domain/adapters/providers/telegram.py index 308f9f2..e56e079 100644 --- a/domain/adapters/providers/telegram.py +++ b/domain/adapters/providers/telegram.py @@ -1,4 +1,5 @@ from typing import List, Dict, Tuple, AsyncIterator +import asyncio import base64 import io import os @@ -10,6 +11,16 @@ from telethon.sessions import StringSession from telethon.tl import types import socks +_SESSION_LOCKS: Dict[str, asyncio.Lock] = {} + + +def _get_session_lock(session_string: str) -> asyncio.Lock: + lock = _SESSION_LOCKS.get(session_string) + if lock is None: + lock = asyncio.Lock() + _SESSION_LOCKS[session_string] = lock + return lock + # 适配器类型标识 ADAPTER_TYPE = "telegram" @@ -359,6 +370,8 @@ class TelegramAdapter: raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}") client = self._get_client() + lock = _get_session_lock(self.session_string) + await lock.acquire() try: await client.connect() @@ -396,7 +409,6 @@ class TelegramAdapter: headers = { "Accept-Ranges": "bytes", "Content-Type": mime_type, - "Content-Length": str(file_size), } if range_header: @@ -408,7 +420,6 @@ class TelegramAdapter: if start >= file_size or end >= file_size or start > end: raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable") status = 206 - headers["Content-Length"] = str(end - start + 1) headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" except ValueError: raise HTTPException(status_code=400, detail="Invalid Range header") @@ -427,18 +438,28 @@ class TelegramAdapter: if downloaded >= limit: break finally: - if client.is_connected(): - await client.disconnect() + try: + if client.is_connected(): + await client.disconnect() + finally: + lock.release() return StreamingResponse(iterator(), status_code=status, headers=headers) + except HTTPException: + if client.is_connected(): + await client.disconnect() + lock.release() + raise except FileNotFoundError as e: if client.is_connected(): await client.disconnect() + lock.release() raise HTTPException(status_code=404, detail=str(e)) except Exception as e: if client.is_connected(): await client.disconnect() + lock.release() raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}") async def stat_file(self, root: str, rel: str): diff --git a/main.py b/main.py index 662c786..c582ad9 100644 --- a/main.py +++ b/main.py @@ -9,8 +9,9 @@ from api.routers import include_routers from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException from middleware.exception_handler import ( global_exception_handler, http_exception_handler, @@ -26,27 +27,38 @@ load_dotenv() class SPAStaticFiles(StaticFiles): async def get_response(self, path, scope): - response = await super().get_response(path, scope) - if response.status_code == 404: - return await super().get_response("index.html", scope) + try: + response = await super().get_response(path, scope) + except StarletteHTTPException as exc: + if exc.status_code != 404: + raise + if self._should_spa_fallback(scope): + return FileResponse(INDEX_FILE) + raise + + if response.status_code == 404 and self._should_spa_fallback(scope): + return FileResponse(INDEX_FILE) return response + @staticmethod + def _should_spa_fallback(scope) -> bool: + return ( + scope.get("method") == "GET" + and _request_accepts_html(scope) + and not (scope.get("path") or "").startswith(SPA_EXCLUDE_PREFIXES) + and INDEX_FILE.exists() + ) + INDEX_FILE = Path("web/dist/index.html") SPA_EXCLUDE_PREFIXES = ("/api", "/docs", "/openapi.json", "/webdav", "/s3") -async def spa_fallback_middleware(request: Request, call_next): - response = await call_next(request) - if ( - response.status_code == 404 - and request.method == "GET" - and "text/html" in request.headers.get("accept", "") - and not request.url.path.startswith(SPA_EXCLUDE_PREFIXES) - and INDEX_FILE.exists() - ): - return FileResponse(INDEX_FILE) - return response +def _request_accepts_html(scope) -> bool: + for k, v in scope.get("headers") or []: + if k == b"accept": + return "text/html" in v.decode("latin-1") + return False @asynccontextmanager @@ -78,7 +90,6 @@ def create_app() -> FastAPI: description="A highly extensible private cloud storage solution for individuals and teams", lifespan=lifespan, ) - app.middleware("http")(spa_fallback_middleware) include_routers(app) app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(RequestValidationError, validation_exception_handler)