mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-06-09 01:19:42 +08:00
refactor: optimize backend module
This commit is contained in:
189
domain/virtual_fs/api.py
Normal file
189
domain/virtual_fs/api.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.virtual_fs.types import MkdirRequest, MoveRequest
|
||||
|
||||
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
|
||||
|
||||
|
||||
@router.get("/file/{full_path:path}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
|
||||
async def get_file(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
return await VirtualFSService.serve_file(full_path, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/thumb/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="获取缩略图")
|
||||
async def get_thumb(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
w: int = Query(256, ge=8, le=1024),
|
||||
h: int = Query(256, ge=8, le=1024),
|
||||
fit: str = Query("cover"),
|
||||
):
|
||||
return await VirtualFSService.get_thumbnail(full_path, w, h, fit)
|
||||
|
||||
|
||||
@router.get("/stream/{full_path:path}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="流式读取文件")
|
||||
async def stream_endpoint(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
):
|
||||
return await VirtualFSService.stream_response(full_path, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/temp-link/{full_path:path}")
|
||||
@audit(action=AuditAction.SHARE, description="创建临时链接")
|
||||
async def get_temp_link(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久"),
|
||||
):
|
||||
data = await VirtualFSService.create_temp_link(full_path, expires_in)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/public/{token}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
|
||||
async def access_public_file(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/stat/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="查看文件信息")
|
||||
async def get_file_stat(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
stat = await VirtualFSService.stat(full_path)
|
||||
return success(stat)
|
||||
|
||||
|
||||
@router.post("/file/{full_path:path}")
|
||||
@audit(action=AuditAction.UPLOAD, description="上传文件")
|
||||
async def put_file(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
data = await file.read()
|
||||
result = await VirtualFSService.write_uploaded_file(full_path, data)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/mkdir")
|
||||
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
|
||||
async def api_mkdir(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MkdirRequest,
|
||||
):
|
||||
result = await VirtualFSService.mkdir(body.path)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
|
||||
async def api_move(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
):
|
||||
result = await VirtualFSService.move(body.src, body.dst, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
|
||||
async def api_rename(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
):
|
||||
result = await VirtualFSService.rename(body.src, body.dst, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/copy")
|
||||
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
|
||||
async def api_copy(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
|
||||
):
|
||||
result = await VirtualFSService.copy(body.src, body.dst, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/upload/{full_path:path}")
|
||||
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
|
||||
async def upload_stream(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
|
||||
chunk_size: int = Query(1024 * 1024, ge=8 * 1024, le=8 * 1024 * 1024, description="单次读取块大小"),
|
||||
):
|
||||
result = await VirtualFSService.upload_stream_from_upload_file(full_path, file, chunk_size, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.get("/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="浏览目录")
|
||||
async def browse_fs(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
|
||||
):
|
||||
data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.delete("/{full_path:path}")
|
||||
@audit(action=AuditAction.DELETE, description="删除路径")
|
||||
async def api_delete(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
):
|
||||
result = await VirtualFSService.delete(full_path)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="浏览根目录")
|
||||
async def root_listing(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
|
||||
):
|
||||
data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
|
||||
return success(data)
|
||||
537
domain/virtual_fs/s3_api.py
Normal file
537
domain/virtual_fs/s3_api.py
Normal file
@@ -0,0 +1,537 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime as dt
|
||||
import hashlib
|
||||
import hmac
|
||||
import uuid
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/s3", tags=["s3"])
|
||||
|
||||
|
||||
FALSEY = {"0", "false", "off", "no"}
|
||||
_XML_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
||||
|
||||
|
||||
class S3Settings(Dict[str, str]):
|
||||
bucket: str
|
||||
region: str
|
||||
base_path: str
|
||||
access_key: str
|
||||
secret_key: str
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return dt.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
|
||||
|
||||
def _etag(key: str, size: Optional[int], mtime: Optional[int]) -> str:
|
||||
raw = f"{key}|{size or 0}|{mtime or 0}".encode("utf-8")
|
||||
return '"' + hashlib.md5(raw).hexdigest() + '"'
|
||||
|
||||
|
||||
def _meta_headers() -> Tuple[str, Dict[str, str]]:
|
||||
req_id = uuid.uuid4().hex
|
||||
headers = {
|
||||
"x-amz-request-id": req_id,
|
||||
"x-amz-id-2": uuid.uuid4().hex,
|
||||
"Server": "FoxelS3",
|
||||
}
|
||||
return req_id, headers
|
||||
|
||||
|
||||
def _s3_error(code: str, message: str, resource: str = "", status: int = 400) -> Response:
|
||||
req_id, headers = _meta_headers()
|
||||
xml = (
|
||||
f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
|
||||
f"<Error>"
|
||||
f"<Code>{code}</Code>"
|
||||
f"<Message>{message}</Message>"
|
||||
f"<Resource>{resource}</Resource>"
|
||||
f"<RequestId>{req_id}</RequestId>"
|
||||
f"</Error>"
|
||||
)
|
||||
return Response(content=xml, status_code=status, media_type="application/xml", headers=headers)
|
||||
|
||||
|
||||
async def _ensure_enabled() -> Optional[Response]:
|
||||
flag = await ConfigService.get("S3_MAPPING_ENABLED", "1")
|
||||
if str(flag).strip().lower() in FALSEY:
|
||||
return _s3_error("ServiceUnavailable", "S3 mapping disabled", status=503)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_settings() -> Tuple[Optional[S3Settings], Optional[Response]]:
|
||||
bucket = (await ConfigService.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
|
||||
region = (await ConfigService.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
|
||||
base_path = (await ConfigService.get("S3_MAPPING_BASE_PATH", "/")) or "/"
|
||||
access_key = (await ConfigService.get("S3_MAPPING_ACCESS_KEY")) or ""
|
||||
secret_key = (await ConfigService.get("S3_MAPPING_SECRET_KEY")) or ""
|
||||
if not access_key or not secret_key:
|
||||
return None, _s3_error(
|
||||
"InvalidAccessKeyId",
|
||||
"S3 mapping access key/secret are not configured.",
|
||||
status=403,
|
||||
)
|
||||
settings: S3Settings = {
|
||||
"bucket": bucket,
|
||||
"region": region,
|
||||
"base_path": base_path,
|
||||
"access_key": access_key,
|
||||
"secret_key": secret_key,
|
||||
}
|
||||
return settings, None
|
||||
|
||||
|
||||
def _canonical_uri(path: str) -> str:
|
||||
from urllib.parse import quote
|
||||
|
||||
if not path:
|
||||
return "/"
|
||||
return quote(path, safe="/-_.~")
|
||||
|
||||
|
||||
def _canonical_query(params: Iterable[Tuple[str, str]]) -> str:
|
||||
from urllib.parse import quote
|
||||
|
||||
encoded = []
|
||||
for key, value in params:
|
||||
enc_key = quote(key, safe="-_.~")
|
||||
enc_val = quote(value or "", safe="-_.~")
|
||||
encoded.append((enc_key, enc_val))
|
||||
encoded.sort()
|
||||
return "&".join(f"{k}={v}" for k, v in encoded)
|
||||
|
||||
|
||||
def _normalize_ws(value: str) -> str:
|
||||
return " ".join(value.strip().split())
|
||||
|
||||
|
||||
def _sign(key: bytes, msg: str) -> bytes:
|
||||
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
|
||||
|
||||
|
||||
async def _authorize_sigv4(request: Request, settings: S3Settings) -> Optional[Response]:
|
||||
auth = request.headers.get("authorization")
|
||||
if not auth:
|
||||
return _s3_error("AccessDenied", "Missing Authorization header", status=403)
|
||||
scheme = "AWS4-HMAC-SHA256"
|
||||
if not auth.startswith(scheme + " "):
|
||||
return _s3_error("InvalidRequest", "Signature Version 4 is required", status=400)
|
||||
|
||||
parts: Dict[str, str] = {}
|
||||
for segment in auth[len(scheme) + 1 :].split(","):
|
||||
k, _, v = segment.strip().partition("=")
|
||||
parts[k] = v
|
||||
|
||||
credential = parts.get("Credential")
|
||||
signed_headers = parts.get("SignedHeaders")
|
||||
signature = parts.get("Signature")
|
||||
if not credential or not signed_headers or not signature:
|
||||
return _s3_error("InvalidRequest", "Authorization header is malformed", status=400)
|
||||
|
||||
cred_parts = credential.split("/")
|
||||
if len(cred_parts) != 5 or cred_parts[-1] != "aws4_request":
|
||||
return _s3_error("InvalidRequest", "Credential scope is invalid", status=400)
|
||||
|
||||
access_key, datestamp, region, service, _ = cred_parts
|
||||
if access_key != settings["access_key"]:
|
||||
return _s3_error("InvalidAccessKeyId", "The AWS Access Key Id you provided does not exist in our records.", status=403)
|
||||
if service != "s3":
|
||||
return _s3_error("InvalidRequest", "Only service 's3' is supported", status=400)
|
||||
if region != settings["region"]:
|
||||
return _s3_error("AuthorizationHeaderMalformed", f"Region '{region}' is invalid", status=400)
|
||||
|
||||
amz_date = request.headers.get("x-amz-date")
|
||||
if not amz_date or not amz_date.startswith(datestamp):
|
||||
return _s3_error("AuthorizationHeaderMalformed", "x-amz-date does not match credential scope", status=400)
|
||||
|
||||
payload_hash = request.headers.get("x-amz-content-sha256")
|
||||
if not payload_hash:
|
||||
return _s3_error("AuthorizationHeaderMalformed", "Missing x-amz-content-sha256", status=400)
|
||||
if payload_hash.upper().startswith("STREAMING-AWS4-HMAC-SHA256"):
|
||||
return _s3_error("NotImplemented", "Chunked uploads are not supported", status=400)
|
||||
|
||||
signed_header_names = [h.strip().lower() for h in signed_headers.split(";") if h.strip()]
|
||||
headers = {k.lower(): v for k, v in request.headers.items()}
|
||||
canonical_headers = []
|
||||
for name in signed_header_names:
|
||||
value = headers.get(name)
|
||||
if value is None:
|
||||
return _s3_error("AuthorizationHeaderMalformed", f"Signed header '{name}' missing", status=400)
|
||||
canonical_headers.append(f"{name}:{_normalize_ws(value)}\n")
|
||||
|
||||
canonical_request = "\n".join(
|
||||
[
|
||||
request.method,
|
||||
_canonical_uri(request.url.path),
|
||||
_canonical_query(request.query_params.multi_items()),
|
||||
"".join(canonical_headers),
|
||||
";".join(signed_header_names),
|
||||
payload_hash,
|
||||
]
|
||||
)
|
||||
|
||||
hashed_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
|
||||
scope = "/".join([datestamp, region, "s3", "aws4_request"])
|
||||
string_to_sign = "\n".join([scheme, amz_date, scope, hashed_request])
|
||||
|
||||
k_date = _sign(("AWS4" + settings["secret_key"]).encode("utf-8"), datestamp)
|
||||
k_region = hmac.new(k_date, region.encode("utf-8"), hashlib.sha256).digest()
|
||||
k_service = hmac.new(k_region, b"s3", hashlib.sha256).digest()
|
||||
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
|
||||
expected = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
if expected != signature:
|
||||
return _s3_error("SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", status=403)
|
||||
return None
|
||||
|
||||
|
||||
def _virtual_path(settings: S3Settings, key: str) -> str:
|
||||
key_norm = key.strip("/")
|
||||
base_norm = settings["base_path"].strip("/")
|
||||
segments = [seg for seg in [base_norm, key_norm] if seg]
|
||||
if not segments:
|
||||
return "/"
|
||||
return "/" + "/".join(segments)
|
||||
|
||||
|
||||
def _join_virtual(base: str, name: str) -> str:
|
||||
if not base or base == "/":
|
||||
return "/" + name.strip("/")
|
||||
return base.rstrip("/") + "/" + name.strip("/")
|
||||
|
||||
|
||||
async def _list_dir_all(path: str) -> List[Dict]:
|
||||
items: List[Dict] = []
|
||||
page_num = 1
|
||||
page_size = 1000
|
||||
while True:
|
||||
try:
|
||||
res = await VirtualFSService.list_virtual_dir(path, page_num=page_num, page_size=page_size)
|
||||
except HTTPException as exc: # directory missing
|
||||
if exc.status_code in (400, 404):
|
||||
return []
|
||||
raise
|
||||
chunk = res.get("items", [])
|
||||
items.extend(chunk)
|
||||
total = int(res.get("total", len(items)))
|
||||
if len(items) >= total or not chunk or len(chunk) < page_size:
|
||||
break
|
||||
page_num += 1
|
||||
return items
|
||||
|
||||
|
||||
async def _collect_objects(path: str, key_prefix: str, recursive: bool, collect_prefixes: bool) -> Tuple[List[Tuple[str, Dict]], List[str]]:
|
||||
entries = await _list_dir_all(path)
|
||||
files: List[Tuple[str, Dict]] = []
|
||||
prefixes: List[str] = []
|
||||
for entry in entries:
|
||||
name = entry.get("name")
|
||||
if not name:
|
||||
continue
|
||||
if entry.get("is_dir"):
|
||||
dir_key = f"{key_prefix}{name.strip('/')}/"
|
||||
if collect_prefixes:
|
||||
prefixes.append(dir_key)
|
||||
if recursive:
|
||||
sub_path = _join_virtual(path, name)
|
||||
sub_files, _ = await _collect_objects(sub_path, dir_key, True, False)
|
||||
files.extend(sub_files)
|
||||
else:
|
||||
key = f"{key_prefix}{name}"
|
||||
files.append((key, entry))
|
||||
files.sort(key=lambda item: item[0])
|
||||
prefixes.sort()
|
||||
return files, prefixes
|
||||
|
||||
|
||||
def _encode_token(key: str) -> str:
|
||||
raw = base64.urlsafe_b64encode(key.encode("utf-8")).decode("ascii")
|
||||
return raw.rstrip("=")
|
||||
|
||||
|
||||
def _decode_token(token: str) -> Optional[str]:
|
||||
if not token:
|
||||
return None
|
||||
padding = "=" * (-len(token) % 4)
|
||||
try:
|
||||
return base64.urlsafe_b64decode(token + padding).decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _apply_pagination(entries: List[Tuple[str, Dict]], prefixes: List[str], max_keys: int, start_after: Optional[str], continuation_token: Optional[str]) -> Tuple[List[Tuple[str, Dict]], List[str], bool, Optional[str]]:
|
||||
combined = [(key, data, True) for key, data in entries] + [(prefix, None, False) for prefix in prefixes]
|
||||
combined.sort(key=lambda item: item[0])
|
||||
|
||||
start_key = start_after or _decode_token(continuation_token or "")
|
||||
if start_key:
|
||||
combined = [item for item in combined if item[0] > start_key]
|
||||
|
||||
is_truncated = len(combined) > max_keys
|
||||
sliced = combined[:max_keys]
|
||||
next_token = _encode_token(sliced[-1][0]) if is_truncated and sliced else None
|
||||
|
||||
contents = [(key, data) for key, data, is_file in sliced if is_file]
|
||||
next_prefixes = [key for key, _, is_file in sliced if not is_file]
|
||||
return contents, next_prefixes, is_truncated, next_token
|
||||
|
||||
|
||||
def _format_contents(entries: List[Tuple[str, Dict]]) -> str:
|
||||
blocks = []
|
||||
for key, meta in entries:
|
||||
size = int(meta.get("size", 0))
|
||||
mtime = meta.get("mtime")
|
||||
if mtime is not None:
|
||||
try:
|
||||
mtime_val = int(mtime)
|
||||
except Exception:
|
||||
mtime_val = 0
|
||||
else:
|
||||
mtime_val = 0
|
||||
last_modified = dt.datetime.utcfromtimestamp(mtime_val or dt.datetime.utcnow().timestamp()).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
etag = _etag(key, size, mtime_val)
|
||||
blocks.append(
|
||||
f"<Contents><Key>{key}</Key><LastModified>{last_modified}</LastModified><ETag>{etag}</ETag><Size>{size}</Size><StorageClass>STANDARD</StorageClass></Contents>"
|
||||
)
|
||||
return "".join(blocks)
|
||||
|
||||
|
||||
def _format_common_prefixes(prefixes: List[str]) -> str:
|
||||
return "".join(f"<CommonPrefixes><Prefix>{p}</Prefix></CommonPrefixes>" for p in prefixes)
|
||||
|
||||
|
||||
def _resource_path(bucket: str, key: Optional[str] = None) -> str:
|
||||
if key:
|
||||
return f"/s3/{bucket}/{key}"
|
||||
return f"/s3/{bucket}"
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_buckets(request: Request):
|
||||
if (resp := await _ensure_enabled()) is not None:
|
||||
return resp
|
||||
settings, err = await _get_settings()
|
||||
if err:
|
||||
return err
|
||||
assert settings
|
||||
if (auth := await _authorize_sigv4(request, settings)) is not None:
|
||||
return auth
|
||||
req_id, headers = _meta_headers()
|
||||
xml = (
|
||||
f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
|
||||
f"<ListAllMyBucketsResult xmlns=\"{_XML_NS}\">"
|
||||
f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
|
||||
f"<Buckets><Bucket><Name>{settings['bucket']}</Name><CreationDate>{_now_iso()}</CreationDate></Bucket></Buckets>"
|
||||
f"</ListAllMyBucketsResult>"
|
||||
)
|
||||
headers.update({"Content-Type": "application/xml"})
|
||||
return Response(content=xml, media_type="application/xml", headers=headers)
|
||||
|
||||
|
||||
@router.get("/{bucket}")
|
||||
async def list_objects(request: Request, bucket: str):
|
||||
if (resp := await _ensure_enabled()) is not None:
|
||||
return resp
|
||||
settings, err = await _get_settings()
|
||||
if err:
|
||||
return err
|
||||
assert settings
|
||||
if bucket != settings["bucket"]:
|
||||
return _s3_error("NoSuchBucket", "The specified bucket does not exist.", _resource_path(bucket), status=404)
|
||||
if (auth := await _authorize_sigv4(request, settings)) is not None:
|
||||
return auth
|
||||
|
||||
params = request.query_params
|
||||
if params.get("list-type", "2") != "2":
|
||||
return _s3_error("InvalidArgument", "Only ListObjectsV2 (list-type=2) is supported.", _resource_path(bucket), status=400)
|
||||
|
||||
prefix = (params.get("prefix") or "").lstrip("/")
|
||||
delimiter = params.get("delimiter")
|
||||
recursive = not delimiter
|
||||
max_keys_raw = params.get("max-keys", "1000")
|
||||
try:
|
||||
max_keys = max(1, min(1000, int(max_keys_raw)))
|
||||
except ValueError:
|
||||
max_keys = 1000
|
||||
start_after = (params.get("start-after") or "").lstrip("/") or None
|
||||
continuation = params.get("continuation-token")
|
||||
|
||||
# Exact file match if prefix is non-empty and does not end with '/'
|
||||
files: List[Tuple[str, Dict]] = []
|
||||
prefixes: List[str] = []
|
||||
if prefix and not prefix.endswith("/"):
|
||||
try:
|
||||
info = await VirtualFSService.stat_file(_virtual_path(settings, prefix))
|
||||
if not info.get("is_dir"):
|
||||
files = [(prefix, info)]
|
||||
except HTTPException as exc:
|
||||
if exc.status_code not in (400, 404):
|
||||
raise
|
||||
if files:
|
||||
contents, next_prefixes, is_truncated, next_token = _apply_pagination(files, [], max_keys, start_after, continuation)
|
||||
xml = _build_list_result(bucket, prefix, delimiter, contents, next_prefixes, max_keys, is_truncated, continuation, next_token, start_after)
|
||||
return xml
|
||||
|
||||
dir_prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
|
||||
virtual_dir = _virtual_path(settings, dir_prefix)
|
||||
files, prefixes = await _collect_objects(virtual_dir, dir_prefix, recursive, bool(delimiter))
|
||||
|
||||
contents, next_prefixes, is_truncated, next_token = _apply_pagination(files, prefixes if delimiter else [], max_keys, start_after, continuation)
|
||||
return _build_list_result(bucket, prefix, delimiter, contents, next_prefixes if delimiter else [], max_keys, is_truncated, continuation, next_token, start_after)
|
||||
|
||||
|
||||
@router.get("/{bucket}/", include_in_schema=False)
|
||||
async def list_objects_with_slash(request: Request, bucket: str):
|
||||
return await list_objects(request, bucket)
|
||||
|
||||
|
||||
def _build_list_result(
|
||||
bucket: str,
|
||||
prefix: str,
|
||||
delimiter: Optional[str],
|
||||
contents: List[Tuple[str, Dict]],
|
||||
prefixes: List[str],
|
||||
max_keys: int,
|
||||
is_truncated: bool,
|
||||
continuation: Optional[str],
|
||||
next_token: Optional[str],
|
||||
start_after: Optional[str],
|
||||
):
|
||||
req_id, headers = _meta_headers()
|
||||
body = [f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListBucketResult xmlns=\"{_XML_NS}\">"]
|
||||
body.append(f"<Name>{bucket}</Name>")
|
||||
body.append(f"<Prefix>{prefix}</Prefix>")
|
||||
if delimiter:
|
||||
body.append(f"<Delimiter>{delimiter}</Delimiter>")
|
||||
if continuation:
|
||||
body.append(f"<ContinuationToken>{continuation}</ContinuationToken>")
|
||||
if start_after:
|
||||
body.append(f"<StartAfter>{start_after}</StartAfter>")
|
||||
body.append(f"<MaxKeys>{max_keys}</MaxKeys>")
|
||||
body.append(f"<KeyCount>{len(contents) + len(prefixes)}</KeyCount>")
|
||||
body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
|
||||
if next_token:
|
||||
body.append(f"<NextContinuationToken>{next_token}</NextContinuationToken>")
|
||||
body.append(_format_contents(contents))
|
||||
if prefixes:
|
||||
body.append(_format_common_prefixes(prefixes))
|
||||
body.append("</ListBucketResult>")
|
||||
xml = "".join(body)
|
||||
headers.update({"Content-Type": "application/xml"})
|
||||
return Response(content=xml, media_type="application/xml", headers=headers)
|
||||
|
||||
|
||||
async def _ensure_bucket_and_auth(request: Request, bucket: str) -> Tuple[Optional[S3Settings], Optional[Response]]:
|
||||
if (resp := await _ensure_enabled()) is not None:
|
||||
return None, resp
|
||||
settings, err = await _get_settings()
|
||||
if err:
|
||||
return None, err
|
||||
assert settings
|
||||
if bucket != settings["bucket"]:
|
||||
return None, _s3_error("NoSuchBucket", "The specified bucket does not exist.", _resource_path(bucket), status=404)
|
||||
if (auth := await _authorize_sigv4(request, settings)) is not None:
|
||||
return None, auth
|
||||
return settings, None
|
||||
|
||||
|
||||
def _object_headers(meta: Dict, key: str) -> Dict[str, str]:
|
||||
size = int(meta.get("size", 0))
|
||||
mtime = meta.get("mtime")
|
||||
if mtime is not None:
|
||||
try:
|
||||
mtime_val = int(mtime)
|
||||
except Exception:
|
||||
mtime_val = 0
|
||||
else:
|
||||
mtime_val = 0
|
||||
last_modified = dt.datetime.utcfromtimestamp(mtime_val or dt.datetime.utcnow().timestamp()).strftime("%a, %d %b %Y %H:%M:%S GMT")
|
||||
headers = {
|
||||
"Content-Length": str(size),
|
||||
"ETag": _etag(key, size, mtime_val),
|
||||
"Last-Modified": last_modified,
|
||||
"Accept-Ranges": "bytes",
|
||||
"x-amz-version-id": "null",
|
||||
}
|
||||
return headers
|
||||
|
||||
|
||||
async def _stat_object(settings: S3Settings, key: str) -> Tuple[Optional[Dict], Optional[Response]]:
|
||||
try:
|
||||
info = await VirtualFSService.stat_file(_virtual_path(settings, key))
|
||||
if info.get("is_dir"):
|
||||
return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
|
||||
return info, None
|
||||
except HTTPException as exc:
|
||||
if exc.status_code == 404:
|
||||
return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
|
||||
raise
|
||||
|
||||
|
||||
@router.api_route("/{bucket}/{object_path:path}", methods=["GET", "HEAD"])
|
||||
async def object_get_head(request: Request, bucket: str, object_path: str):
|
||||
settings, error = await _ensure_bucket_and_auth(request, bucket)
|
||||
if error:
|
||||
return error
|
||||
assert settings
|
||||
key = object_path.lstrip("/")
|
||||
meta, err = await _stat_object(settings, key)
|
||||
if err:
|
||||
return err
|
||||
assert meta
|
||||
_, base_headers = _meta_headers()
|
||||
base_headers.update(_object_headers(meta, key))
|
||||
if request.method == "HEAD":
|
||||
return Response(status_code=200, headers=base_headers)
|
||||
resp = await VirtualFSService.stream_file(_virtual_path(settings, key), request.headers.get("range"))
|
||||
safe_merge_keys = {"ETag", "Last-Modified", "x-amz-version-id", "Accept-Ranges"}
|
||||
for hk, hv in base_headers.items():
|
||||
if hk in safe_merge_keys:
|
||||
resp.headers.setdefault(hk, hv)
|
||||
resp.headers.setdefault("Content-Type", meta.get("mime") or "application/octet-stream")
|
||||
return resp
|
||||
|
||||
|
||||
@router.put("/{bucket}/{object_path:path}")
|
||||
async def put_object(request: Request, bucket: str, object_path: str):
|
||||
settings, error = await _ensure_bucket_and_auth(request, bucket)
|
||||
if error:
|
||||
return error
|
||||
assert settings
|
||||
key = object_path.lstrip("/")
|
||||
await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
|
||||
meta, err = await _stat_object(settings, key)
|
||||
if err:
|
||||
return err
|
||||
headers = _object_headers(meta, key)
|
||||
headers.pop("Content-Length", None)
|
||||
headers.pop("Accept-Ranges", None)
|
||||
headers["Content-Length"] = "0"
|
||||
_, extra = _meta_headers()
|
||||
headers.update(extra)
|
||||
return Response(status_code=200, headers=headers)
|
||||
|
||||
|
||||
@router.delete("/{bucket}/{object_path:path}")
|
||||
async def delete_object(request: Request, bucket: str, object_path: str):
|
||||
settings, error = await _ensure_bucket_and_auth(request, bucket)
|
||||
if error:
|
||||
return error
|
||||
assert settings
|
||||
key = object_path.lstrip("/")
|
||||
try:
|
||||
await VirtualFSService.delete_path(_virtual_path(settings, key))
|
||||
except HTTPException as exc:
|
||||
if exc.status_code not in (400, 404):
|
||||
raise
|
||||
_, headers = _meta_headers()
|
||||
return Response(status_code=204, headers=headers)
|
||||
26
domain/virtual_fs/search_api.py
Normal file
26
domain/virtual_fs/search_api.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.virtual_fs.search_service import VirtualFSSearchService
|
||||
|
||||
router = APIRouter(prefix="/api/fs/search", tags=["search"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search_files(
|
||||
q: str = Query(..., description="搜索查询"),
|
||||
top_k: int = Query(10, description="返回结果数量"),
|
||||
mode: str = Query("vector", description="搜索模式: 'vector' 或 'filename'"),
|
||||
page: int = Query(1, description="分页页码,仅在文件名搜索模式下生效"),
|
||||
page_size: int = Query(10, description="分页大小,仅在文件名搜索模式下生效"),
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
if not q.strip():
|
||||
return {"items": [], "query": q}
|
||||
|
||||
top_k = max(top_k, 1)
|
||||
page = max(page, 1)
|
||||
page_size = max(min(page_size, 100), 1)
|
||||
|
||||
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
|
||||
118
domain/virtual_fs/search_service.py
Normal file
118
domain/virtual_fs/search_service.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from domain.virtual_fs.types import SearchResultItem
|
||||
from domain.ai.inference import get_text_embedding
|
||||
from domain.ai.service import VectorDBService
|
||||
|
||||
|
||||
def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
|
||||
entity = dict(raw.get("entity") or {})
|
||||
source_path = entity.get("source_path")
|
||||
stored_path = entity.get("path")
|
||||
path = source_path or stored_path or ""
|
||||
chunk_id_value = entity.get("chunk_id")
|
||||
chunk_id = str(chunk_id_value) if chunk_id_value is not None else None
|
||||
snippet = entity.get("text") or entity.get("description") or entity.get("name")
|
||||
mime = entity.get("mime")
|
||||
start_offset = entity.get("start_offset")
|
||||
end_offset = entity.get("end_offset")
|
||||
raw_score = raw.get("distance")
|
||||
score = float(raw_score) if raw_score is not None else fallback_score
|
||||
|
||||
metadata = {
|
||||
"retrieval_source": source,
|
||||
"raw_distance": raw_score,
|
||||
}
|
||||
if stored_path and stored_path != path:
|
||||
metadata["stored_path"] = stored_path
|
||||
vector_id = entity.get("vector_id")
|
||||
if vector_id:
|
||||
metadata["vector_id"] = vector_id
|
||||
|
||||
return SearchResultItem(
|
||||
id=str(raw.get("id")),
|
||||
path=path,
|
||||
score=score,
|
||||
chunk_id=chunk_id,
|
||||
snippet=snippet,
|
||||
mime=mime,
|
||||
source_type=entity.get("type") or source,
|
||||
start_offset=start_offset,
|
||||
end_offset=end_offset,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def _vector_search(query: str, top_k: int) -> List[SearchResultItem]:
|
||||
vector_db = VectorDBService()
|
||||
try:
|
||||
embedding = await get_text_embedding(query)
|
||||
except Exception:
|
||||
embedding = None
|
||||
if not embedding:
|
||||
return []
|
||||
|
||||
try:
|
||||
raw_results = await vector_db.search_vectors("vector_collection", embedding, max(top_k, 10))
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
results: List[SearchResultItem] = []
|
||||
for bucket in raw_results or []:
|
||||
for record in bucket or []:
|
||||
results.append(_normalize_result(record, "vector"))
|
||||
return results
|
||||
|
||||
|
||||
async def _filename_search(query: str, page: int, page_size: int) -> Tuple[List[SearchResultItem], bool]:
|
||||
vector_db = VectorDBService()
|
||||
limit = max(page * page_size + 1, page_size * (page + 2))
|
||||
limit = min(limit, 2000)
|
||||
try:
|
||||
raw_results = await vector_db.search_by_path("vector_collection", query, limit)
|
||||
except Exception:
|
||||
return [], False
|
||||
|
||||
records = raw_results[0] if raw_results else []
|
||||
deduped: List[SearchResultItem] = []
|
||||
seen_paths: set[str] = set()
|
||||
for record in records or []:
|
||||
item = _normalize_result(record, "filename", fallback_score=1.0)
|
||||
stored_path = item.metadata.get("stored_path") if item.metadata else None
|
||||
key = item.path or stored_path or ""
|
||||
if key in seen_paths:
|
||||
continue
|
||||
seen_paths.add(key)
|
||||
deduped.append(item)
|
||||
|
||||
start = max(page - 1, 0) * page_size
|
||||
end = start + page_size
|
||||
page_items = deduped[start:end]
|
||||
for offset, item in enumerate(page_items):
|
||||
if item.metadata is None:
|
||||
item.metadata = {}
|
||||
item.metadata.setdefault("retrieval_rank", start + offset)
|
||||
has_more = len(deduped) > end
|
||||
return page_items, has_more
|
||||
|
||||
|
||||
class VirtualFSSearchService:
|
||||
@staticmethod
|
||||
async def search(query: str, top_k: int, mode: str, page: int, page_size: int):
|
||||
if mode == "vector":
|
||||
items = (await _vector_search(query, top_k))[:top_k]
|
||||
return {"items": items, "query": query, "mode": mode}
|
||||
if mode == "filename":
|
||||
items, has_more = await _filename_search(query, page, page_size)
|
||||
return {
|
||||
"items": items,
|
||||
"query": query,
|
||||
"mode": mode,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
},
|
||||
}
|
||||
items = (await _vector_search(query, top_k))[:top_k]
|
||||
return {"items": items, "query": query, "mode": mode}
|
||||
1360
domain/virtual_fs/service.py
Normal file
1360
domain/virtual_fs/service.py
Normal file
File diff suppressed because it is too large
Load Diff
330
domain/virtual_fs/thumbnail.py
Normal file
330
domain/virtual_fs/thumbnail.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import hashlib
|
||||
import tempfile
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from fastapi import HTTPException
|
||||
|
||||
ALLOWED_EXT = {"jpg", "jpeg", "png", "webp", "gif", "bmp",
|
||||
"tiff", "arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
|
||||
RAW_EXT = {"arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
|
||||
VIDEO_EXT = {"mp4", "mov", "m4v", "avi", "mkv", "wmv", "flv", "webm", "mpg", "mpeg", "3gp"}
|
||||
MAX_IMAGE_SOURCE_SIZE = 200 * 1024 * 1024
|
||||
VIDEO_RANGE_LIMIT = 16 * 1024 * 1024 # 16MB
|
||||
VIDEO_INITIAL_CHUNK = 4 * 1024 * 1024
|
||||
CACHE_ROOT = Path('data/.thumb_cache')
|
||||
|
||||
|
||||
def is_image_filename(name: str) -> bool:
|
||||
parts = name.rsplit('.', 1)
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
return parts[1].lower() in ALLOWED_EXT
|
||||
|
||||
|
||||
def is_raw_filename(name: str) -> bool:
|
||||
parts = name.rsplit('.', 1)
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
return parts[1].lower() in RAW_EXT
|
||||
|
||||
|
||||
def is_video_filename(name: str) -> bool:
|
||||
parts = name.rsplit('.', 1)
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
return parts[1].lower() in VIDEO_EXT
|
||||
|
||||
|
||||
def _cache_key(adapter_id: int, rel: str, size: int, mtime: int, w: int, h: int, fit: str) -> str:
|
||||
raw = f"{adapter_id}|{rel}|{size}|{mtime}|{w}x{h}|{fit}".encode()
|
||||
return hashlib.sha1(raw).hexdigest()
|
||||
|
||||
|
||||
def _cache_path(key: str) -> Path:
|
||||
sub = Path(key[:2]) / key[2:4]
|
||||
return CACHE_ROOT / sub / f"{key}.webp"
|
||||
|
||||
|
||||
def _ensure_cache_dir(p: Path):
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _image_to_webp(im, w: int, h: int, fit: str) -> Tuple[bytes, str]:
|
||||
from PIL import Image
|
||||
if im.mode not in ("RGB", "RGBA"):
|
||||
im = im.convert("RGBA" if im.mode in ("P", "LA") else "RGB")
|
||||
if fit == 'cover':
|
||||
im_ratio = im.width / im.height
|
||||
target_ratio = w / h
|
||||
if im_ratio > target_ratio:
|
||||
new_h = h
|
||||
new_w = int(h * im_ratio)
|
||||
else:
|
||||
new_w = w
|
||||
new_h = int(w / im_ratio)
|
||||
im = im.resize((new_w, new_h))
|
||||
left = max(0, (im.width - w)//2)
|
||||
top = max(0, (im.height - h)//2)
|
||||
im = im.crop((left, top, left + w, top + h))
|
||||
else:
|
||||
im.thumbnail((w, h))
|
||||
buf = io.BytesIO()
|
||||
im.save(buf, 'WEBP', quality=80)
|
||||
return buf.getvalue(), 'image/webp'
|
||||
|
||||
|
||||
def generate_thumb(data: bytes, w: int, h: int, fit: str, is_raw: bool = False) -> Tuple[bytes, str]:
|
||||
from PIL import Image
|
||||
if is_raw:
|
||||
try:
|
||||
import rawpy
|
||||
with rawpy.imread(io.BytesIO(data)) as raw:
|
||||
try:
|
||||
thumb = raw.extract_thumb()
|
||||
except rawpy.LibRawNoThumbnailError:
|
||||
thumb = None
|
||||
|
||||
if thumb is not None and thumb.format in [rawpy.ThumbFormat.JPEG, rawpy.ThumbFormat.BITMAP]:
|
||||
im = Image.open(io.BytesIO(thumb.data))
|
||||
else:
|
||||
rgb = raw.postprocess(
|
||||
use_camera_wb=False, use_auto_wb=True, output_bps=8)
|
||||
im = Image.fromarray(rgb)
|
||||
except Exception as e:
|
||||
print(f"rawpy processing failed: {e}")
|
||||
raise e
|
||||
|
||||
else:
|
||||
im = Image.open(io.BytesIO(data))
|
||||
|
||||
return _image_to_webp(im, w, h, fit)
|
||||
|
||||
|
||||
async def _collect_response_bytes(response, limit: int) -> bytes:
|
||||
if response is None:
|
||||
return b""
|
||||
|
||||
try:
|
||||
if isinstance(response, (bytes, bytearray)):
|
||||
return bytes(response[:limit])
|
||||
|
||||
body = getattr(response, "body", None)
|
||||
if body is not None:
|
||||
return bytes(body[:limit])
|
||||
|
||||
iterator = getattr(response, "body_iterator", None)
|
||||
if iterator is not None:
|
||||
data = bytearray()
|
||||
async for chunk in iterator:
|
||||
if not chunk:
|
||||
continue
|
||||
need = limit - len(data)
|
||||
if need <= 0:
|
||||
break
|
||||
data.extend(chunk[:need])
|
||||
if len(data) >= limit:
|
||||
break
|
||||
return bytes(data)
|
||||
|
||||
if hasattr(response, "__aiter__"):
|
||||
data = bytearray()
|
||||
async for chunk in response:
|
||||
if not chunk:
|
||||
continue
|
||||
need = limit - len(data)
|
||||
if need <= 0:
|
||||
break
|
||||
data.extend(chunk[:need])
|
||||
if len(data) >= limit:
|
||||
break
|
||||
return bytes(data)
|
||||
finally:
|
||||
close_func = getattr(response, "close", None)
|
||||
if callable(close_func):
|
||||
result = close_func()
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
return b""
|
||||
|
||||
|
||||
async def _read_range_slice(adapter, root: str, rel: str, start: int, end: int) -> bytes:
|
||||
read_range = getattr(adapter, "read_file_range", None)
|
||||
if callable(read_range):
|
||||
try:
|
||||
return await read_range(root, rel, start, end)
|
||||
except TypeError:
|
||||
return await read_range(root, rel, start, end=end)
|
||||
|
||||
stream_impl = getattr(adapter, "stream_file", None)
|
||||
if callable(stream_impl):
|
||||
range_header = f"bytes={start}-{end}"
|
||||
response = await stream_impl(root, rel, range_header)
|
||||
expected = end - start + 1
|
||||
return await _collect_response_bytes(response, expected)
|
||||
|
||||
read_file = getattr(adapter, "read_file", None)
|
||||
if callable(read_file) and start == 0:
|
||||
data = await read_file(root, rel)
|
||||
slice_end = end + 1
|
||||
return data[:slice_end]
|
||||
|
||||
return b""
|
||||
|
||||
|
||||
async def _read_video_prefix(adapter, root: str, rel: str, size: int, limit: int = VIDEO_RANGE_LIMIT) -> bytes:
|
||||
chunk_size = min(VIDEO_INITIAL_CHUNK, limit)
|
||||
offset = 0
|
||||
collected = bytearray()
|
||||
|
||||
while len(collected) < limit:
|
||||
end = offset + chunk_size - 1
|
||||
data = await _read_range_slice(adapter, root, rel, offset, end)
|
||||
if not data:
|
||||
break
|
||||
collected.extend(data)
|
||||
if len(data) < chunk_size:
|
||||
break
|
||||
offset += len(data)
|
||||
remaining = limit - len(collected)
|
||||
if remaining <= 0:
|
||||
break
|
||||
chunk_size = min(chunk_size * 2, remaining)
|
||||
|
||||
if not collected and size <= limit:
|
||||
read_file = getattr(adapter, "read_file", None)
|
||||
if callable(read_file):
|
||||
blob = await read_file(root, rel)
|
||||
if blob:
|
||||
return bytes(blob[:limit])
|
||||
|
||||
return bytes(collected[:limit])
|
||||
|
||||
|
||||
async def _run_ffmpeg_extract_frame(src_path: str, dst_path: str):
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-hide_banner",
|
||||
"-loglevel", "error",
|
||||
"-i", src_path,
|
||||
"-frames:v", "1",
|
||||
dst_path,
|
||||
]
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
raise RuntimeError("未找到 ffmpeg,可执行文件需要在 PATH 中") from e
|
||||
|
||||
stdout, stderr = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
message = stderr.decode().strip() or stdout.decode().strip() or "ffmpeg 执行失败"
|
||||
raise RuntimeError(message)
|
||||
|
||||
|
||||
async def _generate_video_thumb(video_bytes: bytes, rel: str, w: int, h: int, fit: str) -> Tuple[bytes, str]:
|
||||
from PIL import Image
|
||||
|
||||
suffix = Path(rel).suffix or ".mp4"
|
||||
src_tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||
src_path = src_tmp.name
|
||||
try:
|
||||
src_tmp.write(video_bytes)
|
||||
src_tmp.flush()
|
||||
finally:
|
||||
src_tmp.close()
|
||||
|
||||
dst_tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
||||
dst_path = dst_tmp.name
|
||||
dst_tmp.close()
|
||||
|
||||
try:
|
||||
await _run_ffmpeg_extract_frame(src_path, dst_path)
|
||||
with Image.open(dst_path) as im:
|
||||
im.load()
|
||||
return _image_to_webp(im, w, h, fit)
|
||||
finally:
|
||||
with suppress(FileNotFoundError):
|
||||
Path(src_path).unlink()
|
||||
with suppress(FileNotFoundError):
|
||||
Path(dst_path).unlink()
|
||||
|
||||
|
||||
async def get_or_create_thumb(adapter, adapter_id: int, root: str, rel: str, w: int, h: int, fit: str = 'cover'):
|
||||
stat = await adapter.stat_file(root, rel)
|
||||
size = int(stat.get('size') or 0)
|
||||
is_video = is_video_filename(rel)
|
||||
if not is_video and size > MAX_IMAGE_SOURCE_SIZE:
|
||||
raise HTTPException(400, detail="Image too large for thumbnail")
|
||||
|
||||
key = _cache_key(adapter_id, rel, size, int(
|
||||
stat.get('mtime', 0)), w, h, fit)
|
||||
path = _cache_path(key)
|
||||
if path.exists():
|
||||
return path.read_bytes(), 'image/webp', key
|
||||
|
||||
_ensure_cache_dir(path)
|
||||
thumb_bytes, mime = None, None
|
||||
|
||||
get_thumb_impl = getattr(adapter, "get_thumbnail", None)
|
||||
if callable(get_thumb_impl):
|
||||
size_str = "large" if w > 400 else "medium" if w > 100 else "small"
|
||||
native_thumb_bytes = await get_thumb_impl(root, rel, size_str)
|
||||
|
||||
if native_thumb_bytes:
|
||||
try:
|
||||
from PIL import Image
|
||||
im = Image.open(io.BytesIO(native_thumb_bytes))
|
||||
buf = io.BytesIO()
|
||||
im.save(buf, 'WEBP', quality=85)
|
||||
thumb_bytes = buf.getvalue()
|
||||
mime = 'image/webp'
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Failed to convert native thumbnail to WebP: {e}, falling back.")
|
||||
thumb_bytes, mime = None, None
|
||||
|
||||
if not thumb_bytes:
|
||||
if is_video:
|
||||
try:
|
||||
video_bytes = await _read_video_prefix(adapter, root, rel, size)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Video prefix read failed: {e}")
|
||||
raise HTTPException(500, detail=f"Video read failed: {e}")
|
||||
|
||||
if not video_bytes:
|
||||
raise HTTPException(500, detail="Unable to read video data for thumbnail")
|
||||
|
||||
try:
|
||||
thumb_bytes, mime = await _generate_video_thumb(video_bytes, rel, w, h, fit)
|
||||
except Exception as e:
|
||||
print(f"Video thumbnail generation failed: {e}")
|
||||
raise HTTPException(
|
||||
500, detail=f"Video thumbnail generation failed: {e}")
|
||||
else:
|
||||
read_data = await adapter.read_file(root, rel)
|
||||
try:
|
||||
thumb_bytes, mime = generate_thumb(
|
||||
read_data, w, h, fit, is_raw=is_raw_filename(rel))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
500, detail=f"Thumbnail generation failed: {e}")
|
||||
|
||||
if thumb_bytes:
|
||||
path.write_bytes(thumb_bytes)
|
||||
return thumb_bytes, mime, key
|
||||
|
||||
raise HTTPException(
|
||||
500, detail="Failed to generate thumbnail by any means")
|
||||
40
domain/virtual_fs/types.py
Normal file
40
domain/virtual_fs/types.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VfsEntry(BaseModel):
|
||||
name: str
|
||||
is_dir: bool
|
||||
size: int
|
||||
mtime: int
|
||||
type: Optional[str] = None
|
||||
has_thumbnail: Optional[bool] = None
|
||||
|
||||
|
||||
class DirListing(BaseModel):
|
||||
path: str
|
||||
entries: List[VfsEntry]
|
||||
pagination: Optional[dict] = None
|
||||
|
||||
|
||||
class SearchResultItem(BaseModel):
|
||||
id: int | str
|
||||
path: str
|
||||
score: float
|
||||
chunk_id: Optional[str] = None
|
||||
snippet: Optional[str] = None
|
||||
mime: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
start_offset: Optional[int] = None
|
||||
end_offset: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class MkdirRequest(BaseModel):
|
||||
path: str
|
||||
|
||||
|
||||
class MoveRequest(BaseModel):
|
||||
src: str
|
||||
dst: str
|
||||
301
domain/virtual_fs/webdav_api.py
Normal file
301
domain/virtual_fs/webdav_api.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from __future__ import annotations
|
||||
import base64
|
||||
import hashlib
|
||||
import mimetypes
|
||||
from email.utils import formatdate
|
||||
from urllib.parse import urlparse, unquote
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, Depends
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from domain.auth.service import AuthService
|
||||
from domain.auth.types import User, UserInDB
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.config.service import ConfigService
|
||||
|
||||
|
||||
_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
|
||||
|
||||
|
||||
async def _ensure_webdav_enabled() -> None:
|
||||
enabled = await ConfigService.get(_WEBDAV_ENABLED_KEY, "1")
|
||||
if str(enabled).strip().lower() in ("0", "false", "off", "no"):
|
||||
raise HTTPException(503, detail="WebDAV mapping disabled")
|
||||
|
||||
|
||||
router = APIRouter(prefix="/webdav", tags=["webdav"])
|
||||
|
||||
|
||||
def _dav_headers(extra: Optional[dict] = None) -> dict:
|
||||
headers = {
|
||||
"DAV": "1",
|
||||
"MS-Author-Via": "DAV",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Allow": ", ".join([
|
||||
"OPTIONS",
|
||||
"PROPFIND",
|
||||
"GET",
|
||||
"HEAD",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"MKCOL",
|
||||
"MOVE",
|
||||
"COPY",
|
||||
]),
|
||||
}
|
||||
if extra:
|
||||
headers.update(extra)
|
||||
return headers
|
||||
|
||||
|
||||
async def _get_basic_user(request: Request) -> User:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth:
|
||||
raise HTTPException(401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
|
||||
scheme, _, param = auth.partition(" ")
|
||||
scheme_lower = scheme.lower()
|
||||
if scheme_lower == "basic":
|
||||
try:
|
||||
decoded = base64.b64decode(param).decode("utf-8")
|
||||
username, _, password = decoded.partition(":")
|
||||
except Exception:
|
||||
raise HTTPException(401, detail="Invalid Basic auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
user_or_false: Optional[UserInDB] = await AuthService.authenticate_user_db(username, password)
|
||||
if not user_or_false:
|
||||
raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
u: UserInDB = user_or_false
|
||||
return User(id=u.id, username=u.username, email=u.email, full_name=u.full_name, disabled=u.disabled)
|
||||
elif scheme_lower == "bearer":
|
||||
if not param:
|
||||
raise HTTPException(401, detail="Invalid Bearer token")
|
||||
return User(id=0, username="bearer", email=None, full_name=None, disabled=False)
|
||||
else:
|
||||
raise HTTPException(401, detail="Unsupported auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
|
||||
|
||||
def _httpdate(ts: int | float) -> str:
|
||||
return formatdate(ts, usegmt=True)
|
||||
|
||||
|
||||
def _etag(path: str, size: int | None, mtime: int | None) -> str:
|
||||
raw = f"{path}|{size or 0}|{mtime or 0}".encode("utf-8")
|
||||
return '"' + hashlib.md5(raw).hexdigest() + '"'
|
||||
|
||||
|
||||
def _href_for(path: str, is_dir: bool) -> str:
|
||||
from urllib.parse import quote
|
||||
p = "/webdav" + (path if path.startswith("/") else "/" + path)
|
||||
if is_dir and not p.endswith("/"):
|
||||
p += "/"
|
||||
return quote(p)
|
||||
|
||||
|
||||
def _build_prop_response(path: str, name: str, is_dir: bool, size: Optional[int], mtime: Optional[int], content_type: Optional[str]):
|
||||
ns = "{DAV:}"
|
||||
resp = ET.Element(ns + "response")
|
||||
href = ET.SubElement(resp, ns + "href")
|
||||
href.text = _href_for(path, is_dir)
|
||||
|
||||
propstat = ET.SubElement(resp, ns + "propstat")
|
||||
prop = ET.SubElement(propstat, ns + "prop")
|
||||
|
||||
displayname = ET.SubElement(prop, ns + "displayname")
|
||||
displayname.text = name
|
||||
|
||||
resourcetype = ET.SubElement(prop, ns + "resourcetype")
|
||||
if is_dir:
|
||||
ET.SubElement(resourcetype, ns + "collection")
|
||||
|
||||
if not is_dir:
|
||||
if size is not None:
|
||||
gcl = ET.SubElement(prop, ns + "getcontentlength")
|
||||
gcl.text = str(size)
|
||||
if content_type:
|
||||
gct = ET.SubElement(prop, ns + "getcontenttype")
|
||||
gct.text = content_type
|
||||
|
||||
if mtime is not None:
|
||||
glm = ET.SubElement(prop, ns + "getlastmodified")
|
||||
glm.text = _httpdate(mtime)
|
||||
|
||||
etag = ET.SubElement(prop, ns + "getetag")
|
||||
etag.text = _etag(path, size, mtime)
|
||||
|
||||
status = ET.SubElement(propstat, ns + "status")
|
||||
status.text = "HTTP/1.1 200 OK"
|
||||
return resp
|
||||
|
||||
|
||||
def _multistatus_xml(responses: list[ET.Element]) -> bytes:
|
||||
ns = "{DAV:}"
|
||||
ms = ET.Element(ns + "multistatus")
|
||||
for r in responses:
|
||||
ms.append(r)
|
||||
return ET.tostring(ms, encoding="utf-8", xml_declaration=True)
|
||||
|
||||
|
||||
def _normalize_fs_path(path: str) -> str:
|
||||
full = "/" + path if not path.startswith("/") else path
|
||||
return unquote(full)
|
||||
|
||||
|
||||
@router.options("/{path:path}")
|
||||
async def options_root(path: str = "", _enabled: None = Depends(_ensure_webdav_enabled)):
|
||||
return Response(status_code=200, headers=_dav_headers())
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["PROPFIND"])
|
||||
async def propfind(
|
||||
request: Request,
|
||||
path: str,
|
||||
_enabled: None = Depends(_ensure_webdav_enabled),
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
depth = request.headers.get("Depth", "1").lower()
|
||||
if depth not in ("0", "1", "infinity"):
|
||||
depth = "1"
|
||||
|
||||
responses: list[ET.Element] = []
|
||||
|
||||
# 先获取当前路径信息
|
||||
try:
|
||||
st = await VirtualFSService.stat_file(full_path)
|
||||
is_dir = bool(st.get("is_dir"))
|
||||
name = st.get("name") or full_path.rsplit("/", 1)[-1] or "/"
|
||||
size = None if is_dir else int(st.get("size", 0))
|
||||
mtime = int(st.get("mtime", 0)) if st.get("mtime") is not None else None
|
||||
ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
|
||||
responses.append(_build_prop_response(full_path, name, is_dir, size, mtime, ctype))
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
if depth in ("1", "infinity"):
|
||||
try:
|
||||
listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
|
||||
for ent in listing["items"]:
|
||||
is_dir = bool(ent.get("is_dir"))
|
||||
name = ent.get("name")
|
||||
child_path = full_path.rstrip("/") + "/" + name
|
||||
size = None if is_dir else int(ent.get("size", 0))
|
||||
mtime = int(ent.get("mtime", 0)) if ent.get("mtime") is not None else None
|
||||
ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
|
||||
responses.append(_build_prop_response(child_path, name, is_dir, size, mtime, ctype))
|
||||
except HTTPException as e:
|
||||
if e.status_code == 400:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
xml = _multistatus_xml(responses)
|
||||
return Response(content=xml, status_code=207, media_type='application/xml; charset="utf-8"', headers=_dav_headers())
|
||||
|
||||
|
||||
@router.get("/{path:path}")
|
||||
async def dav_get(
|
||||
path: str,
|
||||
request: Request,
|
||||
_enabled: None = Depends(_ensure_webdav_enabled),
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
range_header = request.headers.get("Range")
|
||||
return await VirtualFSService.stream_file(full_path, range_header)
|
||||
|
||||
|
||||
@router.head("/{path:path}")
|
||||
async def dav_head(
|
||||
path: str,
|
||||
_enabled: None = Depends(_ensure_webdav_enabled),
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
try:
|
||||
st = await VirtualFSService.stat_file(full_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
is_dir = bool(st.get("is_dir"))
|
||||
headers = _dav_headers()
|
||||
if not is_dir:
|
||||
size = int(st.get("size", 0))
|
||||
name = st.get("name") or full_path.rsplit("/", 1)[-1]
|
||||
ctype = mimetypes.guess_type(name)[0] or "application/octet-stream"
|
||||
mtime = int(st.get("mtime", 0)) if st.get("mtime") is not None else None
|
||||
headers.update({
|
||||
"Content-Length": str(size),
|
||||
"Content-Type": ctype,
|
||||
"ETag": _etag(full_path, size, mtime),
|
||||
})
|
||||
return Response(status_code=200, headers=headers)
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["PUT"])
|
||||
async def dav_put(
|
||||
path: str,
|
||||
request: Request,
|
||||
_enabled: None = Depends(_ensure_webdav_enabled),
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
async def body_iter():
|
||||
async for chunk in request.stream():
|
||||
if chunk:
|
||||
yield chunk
|
||||
size = await VirtualFSService.write_file_stream(full_path, body_iter(), overwrite=True)
|
||||
return Response(status_code=201, headers=_dav_headers({"Content-Length": "0"}))
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["DELETE"])
|
||||
async def dav_delete(
|
||||
path: str,
|
||||
_enabled: None = Depends(_ensure_webdav_enabled),
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await VirtualFSService.delete_path(full_path)
|
||||
return Response(status_code=204, headers=_dav_headers())
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["MKCOL"])
|
||||
async def dav_mkcol(
|
||||
path: str,
|
||||
_enabled: None = Depends(_ensure_webdav_enabled),
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await VirtualFSService.make_dir(full_path)
|
||||
return Response(status_code=201, headers=_dav_headers())
|
||||
|
||||
|
||||
def _parse_destination(dest: str) -> str:
|
||||
if not dest:
|
||||
raise HTTPException(400, detail="Missing Destination header")
|
||||
p = urlparse(dest)
|
||||
path = p.path if p.scheme else dest
|
||||
if path.startswith("/webdav"):
|
||||
rel = path[len("/webdav"):]
|
||||
else:
|
||||
rel = path
|
||||
return _normalize_fs_path(rel)
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["MOVE"])
|
||||
async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_user)):
|
||||
full_src = _normalize_fs_path(path)
|
||||
dest_header = request.headers.get("Destination")
|
||||
dst = _parse_destination(dest_header or "")
|
||||
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
|
||||
await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
|
||||
return Response(status_code=204, headers=_dav_headers())
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["COPY"])
|
||||
async def dav_copy(path: str, request: Request, user: User = Depends(_get_basic_user)):
|
||||
full_src = _normalize_fs_path(path)
|
||||
dest_header = request.headers.get("Destination")
|
||||
dst = _parse_destination(dest_header or "")
|
||||
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
|
||||
await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
|
||||
return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())
|
||||
Reference in New Issue
Block a user