mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-06-26 01:31:42 +08:00
feat: add session locking mechanism in Telegram adapter and improve SPA fallback handling
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
@@ -10,6 +11,16 @@ from telethon.sessions import StringSession
|
||||
from telethon.tl import types
|
||||
import socks
|
||||
|
||||
_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _get_session_lock(session_string: str) -> asyncio.Lock:
|
||||
lock = _SESSION_LOCKS.get(session_string)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_SESSION_LOCKS[session_string] = lock
|
||||
return lock
|
||||
|
||||
# 适配器类型标识
|
||||
ADAPTER_TYPE = "telegram"
|
||||
|
||||
@@ -359,6 +370,8 @@ class TelegramAdapter:
|
||||
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
lock = _get_session_lock(self.session_string)
|
||||
await lock.acquire()
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
@@ -396,7 +409,6 @@ class TelegramAdapter:
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": mime_type,
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
if range_header:
|
||||
@@ -408,7 +420,6 @@ class TelegramAdapter:
|
||||
if start >= file_size or end >= file_size or start > end:
|
||||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||||
status = 206
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid Range header")
|
||||
@@ -427,18 +438,28 @@ class TelegramAdapter:
|
||||
if downloaded >= limit:
|
||||
break
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
try:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers)
|
||||
|
||||
except HTTPException:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
|
||||
43
main.py
43
main.py
@@ -9,8 +9,9 @@ from api.routers import include_routers
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from middleware.exception_handler import (
|
||||
global_exception_handler,
|
||||
http_exception_handler,
|
||||
@@ -26,27 +27,38 @@ load_dotenv()
|
||||
|
||||
class SPAStaticFiles(StaticFiles):
|
||||
async def get_response(self, path, scope):
|
||||
response = await super().get_response(path, scope)
|
||||
if response.status_code == 404:
|
||||
return await super().get_response("index.html", scope)
|
||||
try:
|
||||
response = await super().get_response(path, scope)
|
||||
except StarletteHTTPException as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
if self._should_spa_fallback(scope):
|
||||
return FileResponse(INDEX_FILE)
|
||||
raise
|
||||
|
||||
if response.status_code == 404 and self._should_spa_fallback(scope):
|
||||
return FileResponse(INDEX_FILE)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _should_spa_fallback(scope) -> bool:
|
||||
return (
|
||||
scope.get("method") == "GET"
|
||||
and _request_accepts_html(scope)
|
||||
and not (scope.get("path") or "").startswith(SPA_EXCLUDE_PREFIXES)
|
||||
and INDEX_FILE.exists()
|
||||
)
|
||||
|
||||
|
||||
INDEX_FILE = Path("web/dist/index.html")
|
||||
SPA_EXCLUDE_PREFIXES = ("/api", "/docs", "/openapi.json", "/webdav", "/s3")
|
||||
|
||||
|
||||
async def spa_fallback_middleware(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
if (
|
||||
response.status_code == 404
|
||||
and request.method == "GET"
|
||||
and "text/html" in request.headers.get("accept", "")
|
||||
and not request.url.path.startswith(SPA_EXCLUDE_PREFIXES)
|
||||
and INDEX_FILE.exists()
|
||||
):
|
||||
return FileResponse(INDEX_FILE)
|
||||
return response
|
||||
def _request_accepts_html(scope) -> bool:
|
||||
for k, v in scope.get("headers") or []:
|
||||
if k == b"accept":
|
||||
return "text/html" in v.decode("latin-1")
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -78,7 +90,6 @@ def create_app() -> FastAPI:
|
||||
description="A highly extensible private cloud storage solution for individuals and teams",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.middleware("http")(spa_fallback_middleware)
|
||||
include_routers(app)
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
|
||||
Reference in New Issue
Block a user