feat: add session locking mechanism in Telegram adapter and improve SPA fallback handling

This commit is contained in:
shiyu
2026-01-11 14:08:52 +08:00
parent 051b49d3f6
commit e7eafdee97
2 changed files with 52 additions and 20 deletions

View File

@@ -1,4 +1,5 @@
from typing import List, Dict, Tuple, AsyncIterator
import asyncio
import base64
import io
import os
@@ -10,6 +11,16 @@ from telethon.sessions import StringSession
from telethon.tl import types
import socks
_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
def _get_session_lock(session_string: str) -> asyncio.Lock:
lock = _SESSION_LOCKS.get(session_string)
if lock is None:
lock = asyncio.Lock()
_SESSION_LOCKS[session_string] = lock
return lock
# 适配器类型标识
ADAPTER_TYPE = "telegram"
@@ -359,6 +370,8 @@ class TelegramAdapter:
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
client = self._get_client()
lock = _get_session_lock(self.session_string)
await lock.acquire()
try:
await client.connect()
@@ -396,7 +409,6 @@ class TelegramAdapter:
headers = {
"Accept-Ranges": "bytes",
"Content-Type": mime_type,
"Content-Length": str(file_size),
}
if range_header:
@@ -408,7 +420,6 @@ class TelegramAdapter:
if start >= file_size or end >= file_size or start > end:
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
status = 206
headers["Content-Length"] = str(end - start + 1)
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
except ValueError:
raise HTTPException(status_code=400, detail="Invalid Range header")
@@ -427,18 +438,28 @@ class TelegramAdapter:
if downloaded >= limit:
break
finally:
if client.is_connected():
await client.disconnect()
try:
if client.is_connected():
await client.disconnect()
finally:
lock.release()
return StreamingResponse(iterator(), status_code=status, headers=headers)
except HTTPException:
if client.is_connected():
await client.disconnect()
lock.release()
raise
except FileNotFoundError as e:
if client.is_connected():
await client.disconnect()
lock.release()
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
if client.is_connected():
await client.disconnect()
lock.release()
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
async def stat_file(self, root: str, rel: str):

43
main.py
View File

@@ -9,8 +9,9 @@ from api.routers import include_routers
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from middleware.exception_handler import (
global_exception_handler,
http_exception_handler,
@@ -26,27 +27,38 @@ load_dotenv()
class SPAStaticFiles(StaticFiles):
async def get_response(self, path, scope):
response = await super().get_response(path, scope)
if response.status_code == 404:
return await super().get_response("index.html", scope)
try:
response = await super().get_response(path, scope)
except StarletteHTTPException as exc:
if exc.status_code != 404:
raise
if self._should_spa_fallback(scope):
return FileResponse(INDEX_FILE)
raise
if response.status_code == 404 and self._should_spa_fallback(scope):
return FileResponse(INDEX_FILE)
return response
@staticmethod
def _should_spa_fallback(scope) -> bool:
return (
scope.get("method") == "GET"
and _request_accepts_html(scope)
and not (scope.get("path") or "").startswith(SPA_EXCLUDE_PREFIXES)
and INDEX_FILE.exists()
)
INDEX_FILE = Path("web/dist/index.html")
SPA_EXCLUDE_PREFIXES = ("/api", "/docs", "/openapi.json", "/webdav", "/s3")
async def spa_fallback_middleware(request: Request, call_next):
response = await call_next(request)
if (
response.status_code == 404
and request.method == "GET"
and "text/html" in request.headers.get("accept", "")
and not request.url.path.startswith(SPA_EXCLUDE_PREFIXES)
and INDEX_FILE.exists()
):
return FileResponse(INDEX_FILE)
return response
def _request_accepts_html(scope) -> bool:
for k, v in scope.get("headers") or []:
if k == b"accept":
return "text/html" in v.decode("latin-1")
return False
@asynccontextmanager
@@ -78,7 +90,6 @@ def create_app() -> FastAPI:
description="A highly extensible private cloud storage solution for individuals and teams",
lifespan=lifespan,
)
app.middleware("http")(spa_fallback_middleware)
include_routers(app)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)