refactor: optimize backend module

This commit is contained in:
shiyu
2025-12-08 17:46:45 +08:00
parent cf8d10f71c
commit 8f515aaaf4
124 changed files with 6884 additions and 6390 deletions

189
domain/virtual_fs/api.py Normal file
View File

@@ -0,0 +1,189 @@
from typing import Annotated
from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.types import MkdirRequest, MoveRequest
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
async def get_file(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
return await VirtualFSService.serve_file(full_path, request.headers.get("Range"))
@router.get("/thumb/{full_path:path}")
@audit(action=AuditAction.READ, description="获取缩略图")
async def get_thumb(
full_path: str,
request: Request,
w: int = Query(256, ge=8, le=1024),
h: int = Query(256, ge=8, le=1024),
fit: str = Query("cover"),
):
return await VirtualFSService.get_thumbnail(full_path, w, h, fit)
@router.get("/stream/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="流式读取文件")
async def stream_endpoint(
full_path: str,
request: Request,
):
return await VirtualFSService.stream_response(full_path, request.headers.get("Range"))
@router.get("/temp-link/{full_path:path}")
@audit(action=AuditAction.SHARE, description="创建临时链接")
async def get_temp_link(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久"),
):
data = await VirtualFSService.create_temp_link(full_path, expires_in)
return success(data)
@router.get("/public/{token}")
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
async def access_public_file(
token: str,
request: Request,
):
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
@router.get("/stat/{full_path:path}")
@audit(action=AuditAction.READ, description="查看文件信息")
async def get_file_stat(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
stat = await VirtualFSService.stat(full_path)
return success(stat)
@router.post("/file/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="上传文件")
async def put_file(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
file: UploadFile = File(...),
):
data = await file.read()
result = await VirtualFSService.write_uploaded_file(full_path, data)
return success(result)
@router.post("/mkdir")
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
async def api_mkdir(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MkdirRequest,
):
result = await VirtualFSService.mkdir(body.path)
return success(result)
@router.post("/move")
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
async def api_move(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
result = await VirtualFSService.move(body.src, body.dst, overwrite)
return success(result)
@router.post("/rename")
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
async def api_rename(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
result = await VirtualFSService.rename(body.src, body.dst, overwrite)
return success(result)
@router.post("/copy")
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
async def api_copy(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
):
result = await VirtualFSService.copy(body.src, body.dst, overwrite)
return success(result)
@router.post("/upload/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
async def upload_stream(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
file: UploadFile = File(...),
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
chunk_size: int = Query(1024 * 1024, ge=8 * 1024, le=8 * 1024 * 1024, description="单次读取块大小"),
):
result = await VirtualFSService.upload_stream_from_upload_file(full_path, file, chunk_size, overwrite)
return success(result)
@router.get("/{full_path:path}")
@audit(action=AuditAction.READ, description="浏览目录")
async def browse_fs(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
page_num: int = Query(1, alias="page", ge=1, description="页码"),
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
return success(data)
@router.delete("/{full_path:path}")
@audit(action=AuditAction.DELETE, description="删除路径")
async def api_delete(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
):
result = await VirtualFSService.delete(full_path)
return success(result)
@router.get("/")
@audit(action=AuditAction.READ, description="浏览根目录")
async def root_listing(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
page_num: int = Query(1, alias="page", ge=1, description="页码"),
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
return success(data)

537
domain/virtual_fs/s3_api.py Normal file
View File

@@ -0,0 +1,537 @@
from __future__ import annotations
import base64
import datetime as dt
import hashlib
import hmac
import uuid
from typing import Dict, Iterable, List, Optional, Tuple
from fastapi import APIRouter, Request, Response
from fastapi import HTTPException
from domain.config.service import ConfigService
from domain.virtual_fs.service import VirtualFSService
router = APIRouter(prefix="/s3", tags=["s3"])
FALSEY = {"0", "false", "off", "no"}
_XML_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3Settings(Dict[str, str]):
bucket: str
region: str
base_path: str
access_key: str
secret_key: str
def _now_iso() -> str:
return dt.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000Z")
def _etag(key: str, size: Optional[int], mtime: Optional[int]) -> str:
raw = f"{key}|{size or 0}|{mtime or 0}".encode("utf-8")
return '"' + hashlib.md5(raw).hexdigest() + '"'
def _meta_headers() -> Tuple[str, Dict[str, str]]:
req_id = uuid.uuid4().hex
headers = {
"x-amz-request-id": req_id,
"x-amz-id-2": uuid.uuid4().hex,
"Server": "FoxelS3",
}
return req_id, headers
def _s3_error(code: str, message: str, resource: str = "", status: int = 400) -> Response:
req_id, headers = _meta_headers()
xml = (
f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
f"<Error>"
f"<Code>{code}</Code>"
f"<Message>{message}</Message>"
f"<Resource>{resource}</Resource>"
f"<RequestId>{req_id}</RequestId>"
f"</Error>"
)
return Response(content=xml, status_code=status, media_type="application/xml", headers=headers)
async def _ensure_enabled() -> Optional[Response]:
flag = await ConfigService.get("S3_MAPPING_ENABLED", "1")
if str(flag).strip().lower() in FALSEY:
return _s3_error("ServiceUnavailable", "S3 mapping disabled", status=503)
return None
async def _get_settings() -> Tuple[Optional[S3Settings], Optional[Response]]:
bucket = (await ConfigService.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
region = (await ConfigService.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
base_path = (await ConfigService.get("S3_MAPPING_BASE_PATH", "/")) or "/"
access_key = (await ConfigService.get("S3_MAPPING_ACCESS_KEY")) or ""
secret_key = (await ConfigService.get("S3_MAPPING_SECRET_KEY")) or ""
if not access_key or not secret_key:
return None, _s3_error(
"InvalidAccessKeyId",
"S3 mapping access key/secret are not configured.",
status=403,
)
settings: S3Settings = {
"bucket": bucket,
"region": region,
"base_path": base_path,
"access_key": access_key,
"secret_key": secret_key,
}
return settings, None
def _canonical_uri(path: str) -> str:
from urllib.parse import quote
if not path:
return "/"
return quote(path, safe="/-_.~")
def _canonical_query(params: Iterable[Tuple[str, str]]) -> str:
from urllib.parse import quote
encoded = []
for key, value in params:
enc_key = quote(key, safe="-_.~")
enc_val = quote(value or "", safe="-_.~")
encoded.append((enc_key, enc_val))
encoded.sort()
return "&".join(f"{k}={v}" for k, v in encoded)
def _normalize_ws(value: str) -> str:
return " ".join(value.strip().split())
def _sign(key: bytes, msg: str) -> bytes:
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
async def _authorize_sigv4(request: Request, settings: S3Settings) -> Optional[Response]:
auth = request.headers.get("authorization")
if not auth:
return _s3_error("AccessDenied", "Missing Authorization header", status=403)
scheme = "AWS4-HMAC-SHA256"
if not auth.startswith(scheme + " "):
return _s3_error("InvalidRequest", "Signature Version 4 is required", status=400)
parts: Dict[str, str] = {}
for segment in auth[len(scheme) + 1 :].split(","):
k, _, v = segment.strip().partition("=")
parts[k] = v
credential = parts.get("Credential")
signed_headers = parts.get("SignedHeaders")
signature = parts.get("Signature")
if not credential or not signed_headers or not signature:
return _s3_error("InvalidRequest", "Authorization header is malformed", status=400)
cred_parts = credential.split("/")
if len(cred_parts) != 5 or cred_parts[-1] != "aws4_request":
return _s3_error("InvalidRequest", "Credential scope is invalid", status=400)
access_key, datestamp, region, service, _ = cred_parts
if access_key != settings["access_key"]:
return _s3_error("InvalidAccessKeyId", "The AWS Access Key Id you provided does not exist in our records.", status=403)
if service != "s3":
return _s3_error("InvalidRequest", "Only service 's3' is supported", status=400)
if region != settings["region"]:
return _s3_error("AuthorizationHeaderMalformed", f"Region '{region}' is invalid", status=400)
amz_date = request.headers.get("x-amz-date")
if not amz_date or not amz_date.startswith(datestamp):
return _s3_error("AuthorizationHeaderMalformed", "x-amz-date does not match credential scope", status=400)
payload_hash = request.headers.get("x-amz-content-sha256")
if not payload_hash:
return _s3_error("AuthorizationHeaderMalformed", "Missing x-amz-content-sha256", status=400)
if payload_hash.upper().startswith("STREAMING-AWS4-HMAC-SHA256"):
return _s3_error("NotImplemented", "Chunked uploads are not supported", status=400)
signed_header_names = [h.strip().lower() for h in signed_headers.split(";") if h.strip()]
headers = {k.lower(): v for k, v in request.headers.items()}
canonical_headers = []
for name in signed_header_names:
value = headers.get(name)
if value is None:
return _s3_error("AuthorizationHeaderMalformed", f"Signed header '{name}' missing", status=400)
canonical_headers.append(f"{name}:{_normalize_ws(value)}\n")
canonical_request = "\n".join(
[
request.method,
_canonical_uri(request.url.path),
_canonical_query(request.query_params.multi_items()),
"".join(canonical_headers),
";".join(signed_header_names),
payload_hash,
]
)
hashed_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
scope = "/".join([datestamp, region, "s3", "aws4_request"])
string_to_sign = "\n".join([scheme, amz_date, scope, hashed_request])
k_date = _sign(("AWS4" + settings["secret_key"]).encode("utf-8"), datestamp)
k_region = hmac.new(k_date, region.encode("utf-8"), hashlib.sha256).digest()
k_service = hmac.new(k_region, b"s3", hashlib.sha256).digest()
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
expected = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
if expected != signature:
return _s3_error("SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", status=403)
return None
def _virtual_path(settings: S3Settings, key: str) -> str:
key_norm = key.strip("/")
base_norm = settings["base_path"].strip("/")
segments = [seg for seg in [base_norm, key_norm] if seg]
if not segments:
return "/"
return "/" + "/".join(segments)
def _join_virtual(base: str, name: str) -> str:
if not base or base == "/":
return "/" + name.strip("/")
return base.rstrip("/") + "/" + name.strip("/")
async def _list_dir_all(path: str) -> List[Dict]:
items: List[Dict] = []
page_num = 1
page_size = 1000
while True:
try:
res = await VirtualFSService.list_virtual_dir(path, page_num=page_num, page_size=page_size)
except HTTPException as exc: # directory missing
if exc.status_code in (400, 404):
return []
raise
chunk = res.get("items", [])
items.extend(chunk)
total = int(res.get("total", len(items)))
if len(items) >= total or not chunk or len(chunk) < page_size:
break
page_num += 1
return items
async def _collect_objects(path: str, key_prefix: str, recursive: bool, collect_prefixes: bool) -> Tuple[List[Tuple[str, Dict]], List[str]]:
entries = await _list_dir_all(path)
files: List[Tuple[str, Dict]] = []
prefixes: List[str] = []
for entry in entries:
name = entry.get("name")
if not name:
continue
if entry.get("is_dir"):
dir_key = f"{key_prefix}{name.strip('/')}/"
if collect_prefixes:
prefixes.append(dir_key)
if recursive:
sub_path = _join_virtual(path, name)
sub_files, _ = await _collect_objects(sub_path, dir_key, True, False)
files.extend(sub_files)
else:
key = f"{key_prefix}{name}"
files.append((key, entry))
files.sort(key=lambda item: item[0])
prefixes.sort()
return files, prefixes
def _encode_token(key: str) -> str:
raw = base64.urlsafe_b64encode(key.encode("utf-8")).decode("ascii")
return raw.rstrip("=")
def _decode_token(token: str) -> Optional[str]:
if not token:
return None
padding = "=" * (-len(token) % 4)
try:
return base64.urlsafe_b64decode(token + padding).decode("utf-8")
except Exception:
return None
def _apply_pagination(entries: List[Tuple[str, Dict]], prefixes: List[str], max_keys: int, start_after: Optional[str], continuation_token: Optional[str]) -> Tuple[List[Tuple[str, Dict]], List[str], bool, Optional[str]]:
combined = [(key, data, True) for key, data in entries] + [(prefix, None, False) for prefix in prefixes]
combined.sort(key=lambda item: item[0])
start_key = start_after or _decode_token(continuation_token or "")
if start_key:
combined = [item for item in combined if item[0] > start_key]
is_truncated = len(combined) > max_keys
sliced = combined[:max_keys]
next_token = _encode_token(sliced[-1][0]) if is_truncated and sliced else None
contents = [(key, data) for key, data, is_file in sliced if is_file]
next_prefixes = [key for key, _, is_file in sliced if not is_file]
return contents, next_prefixes, is_truncated, next_token
def _format_contents(entries: List[Tuple[str, Dict]]) -> str:
blocks = []
for key, meta in entries:
size = int(meta.get("size", 0))
mtime = meta.get("mtime")
if mtime is not None:
try:
mtime_val = int(mtime)
except Exception:
mtime_val = 0
else:
mtime_val = 0
last_modified = dt.datetime.utcfromtimestamp(mtime_val or dt.datetime.utcnow().timestamp()).strftime("%Y-%m-%dT%H:%M:%S.000Z")
etag = _etag(key, size, mtime_val)
blocks.append(
f"<Contents><Key>{key}</Key><LastModified>{last_modified}</LastModified><ETag>{etag}</ETag><Size>{size}</Size><StorageClass>STANDARD</StorageClass></Contents>"
)
return "".join(blocks)
def _format_common_prefixes(prefixes: List[str]) -> str:
return "".join(f"<CommonPrefixes><Prefix>{p}</Prefix></CommonPrefixes>" for p in prefixes)
def _resource_path(bucket: str, key: Optional[str] = None) -> str:
if key:
return f"/s3/{bucket}/{key}"
return f"/s3/{bucket}"
@router.get("")
async def list_buckets(request: Request):
if (resp := await _ensure_enabled()) is not None:
return resp
settings, err = await _get_settings()
if err:
return err
assert settings
if (auth := await _authorize_sigv4(request, settings)) is not None:
return auth
req_id, headers = _meta_headers()
xml = (
f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
f"<ListAllMyBucketsResult xmlns=\"{_XML_NS}\">"
f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
f"<Buckets><Bucket><Name>{settings['bucket']}</Name><CreationDate>{_now_iso()}</CreationDate></Bucket></Buckets>"
f"</ListAllMyBucketsResult>"
)
headers.update({"Content-Type": "application/xml"})
return Response(content=xml, media_type="application/xml", headers=headers)
@router.get("/{bucket}")
async def list_objects(request: Request, bucket: str):
if (resp := await _ensure_enabled()) is not None:
return resp
settings, err = await _get_settings()
if err:
return err
assert settings
if bucket != settings["bucket"]:
return _s3_error("NoSuchBucket", "The specified bucket does not exist.", _resource_path(bucket), status=404)
if (auth := await _authorize_sigv4(request, settings)) is not None:
return auth
params = request.query_params
if params.get("list-type", "2") != "2":
return _s3_error("InvalidArgument", "Only ListObjectsV2 (list-type=2) is supported.", _resource_path(bucket), status=400)
prefix = (params.get("prefix") or "").lstrip("/")
delimiter = params.get("delimiter")
recursive = not delimiter
max_keys_raw = params.get("max-keys", "1000")
try:
max_keys = max(1, min(1000, int(max_keys_raw)))
except ValueError:
max_keys = 1000
start_after = (params.get("start-after") or "").lstrip("/") or None
continuation = params.get("continuation-token")
# Exact file match if prefix is non-empty and does not end with '/'
files: List[Tuple[str, Dict]] = []
prefixes: List[str] = []
if prefix and not prefix.endswith("/"):
try:
info = await VirtualFSService.stat_file(_virtual_path(settings, prefix))
if not info.get("is_dir"):
files = [(prefix, info)]
except HTTPException as exc:
if exc.status_code not in (400, 404):
raise
if files:
contents, next_prefixes, is_truncated, next_token = _apply_pagination(files, [], max_keys, start_after, continuation)
xml = _build_list_result(bucket, prefix, delimiter, contents, next_prefixes, max_keys, is_truncated, continuation, next_token, start_after)
return xml
dir_prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
virtual_dir = _virtual_path(settings, dir_prefix)
files, prefixes = await _collect_objects(virtual_dir, dir_prefix, recursive, bool(delimiter))
contents, next_prefixes, is_truncated, next_token = _apply_pagination(files, prefixes if delimiter else [], max_keys, start_after, continuation)
return _build_list_result(bucket, prefix, delimiter, contents, next_prefixes if delimiter else [], max_keys, is_truncated, continuation, next_token, start_after)
@router.get("/{bucket}/", include_in_schema=False)
async def list_objects_with_slash(request: Request, bucket: str):
return await list_objects(request, bucket)
def _build_list_result(
bucket: str,
prefix: str,
delimiter: Optional[str],
contents: List[Tuple[str, Dict]],
prefixes: List[str],
max_keys: int,
is_truncated: bool,
continuation: Optional[str],
next_token: Optional[str],
start_after: Optional[str],
):
req_id, headers = _meta_headers()
body = [f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListBucketResult xmlns=\"{_XML_NS}\">"]
body.append(f"<Name>{bucket}</Name>")
body.append(f"<Prefix>{prefix}</Prefix>")
if delimiter:
body.append(f"<Delimiter>{delimiter}</Delimiter>")
if continuation:
body.append(f"<ContinuationToken>{continuation}</ContinuationToken>")
if start_after:
body.append(f"<StartAfter>{start_after}</StartAfter>")
body.append(f"<MaxKeys>{max_keys}</MaxKeys>")
body.append(f"<KeyCount>{len(contents) + len(prefixes)}</KeyCount>")
body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
if next_token:
body.append(f"<NextContinuationToken>{next_token}</NextContinuationToken>")
body.append(_format_contents(contents))
if prefixes:
body.append(_format_common_prefixes(prefixes))
body.append("</ListBucketResult>")
xml = "".join(body)
headers.update({"Content-Type": "application/xml"})
return Response(content=xml, media_type="application/xml", headers=headers)
async def _ensure_bucket_and_auth(request: Request, bucket: str) -> Tuple[Optional[S3Settings], Optional[Response]]:
if (resp := await _ensure_enabled()) is not None:
return None, resp
settings, err = await _get_settings()
if err:
return None, err
assert settings
if bucket != settings["bucket"]:
return None, _s3_error("NoSuchBucket", "The specified bucket does not exist.", _resource_path(bucket), status=404)
if (auth := await _authorize_sigv4(request, settings)) is not None:
return None, auth
return settings, None
def _object_headers(meta: Dict, key: str) -> Dict[str, str]:
size = int(meta.get("size", 0))
mtime = meta.get("mtime")
if mtime is not None:
try:
mtime_val = int(mtime)
except Exception:
mtime_val = 0
else:
mtime_val = 0
last_modified = dt.datetime.utcfromtimestamp(mtime_val or dt.datetime.utcnow().timestamp()).strftime("%a, %d %b %Y %H:%M:%S GMT")
headers = {
"Content-Length": str(size),
"ETag": _etag(key, size, mtime_val),
"Last-Modified": last_modified,
"Accept-Ranges": "bytes",
"x-amz-version-id": "null",
}
return headers
async def _stat_object(settings: S3Settings, key: str) -> Tuple[Optional[Dict], Optional[Response]]:
try:
info = await VirtualFSService.stat_file(_virtual_path(settings, key))
if info.get("is_dir"):
return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
return info, None
except HTTPException as exc:
if exc.status_code == 404:
return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
raise
@router.api_route("/{bucket}/{object_path:path}", methods=["GET", "HEAD"])
async def object_get_head(request: Request, bucket: str, object_path: str):
settings, error = await _ensure_bucket_and_auth(request, bucket)
if error:
return error
assert settings
key = object_path.lstrip("/")
meta, err = await _stat_object(settings, key)
if err:
return err
assert meta
_, base_headers = _meta_headers()
base_headers.update(_object_headers(meta, key))
if request.method == "HEAD":
return Response(status_code=200, headers=base_headers)
resp = await VirtualFSService.stream_file(_virtual_path(settings, key), request.headers.get("range"))
safe_merge_keys = {"ETag", "Last-Modified", "x-amz-version-id", "Accept-Ranges"}
for hk, hv in base_headers.items():
if hk in safe_merge_keys:
resp.headers.setdefault(hk, hv)
resp.headers.setdefault("Content-Type", meta.get("mime") or "application/octet-stream")
return resp
@router.put("/{bucket}/{object_path:path}")
async def put_object(request: Request, bucket: str, object_path: str):
settings, error = await _ensure_bucket_and_auth(request, bucket)
if error:
return error
assert settings
key = object_path.lstrip("/")
await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
meta, err = await _stat_object(settings, key)
if err:
return err
headers = _object_headers(meta, key)
headers.pop("Content-Length", None)
headers.pop("Accept-Ranges", None)
headers["Content-Length"] = "0"
_, extra = _meta_headers()
headers.update(extra)
return Response(status_code=200, headers=headers)
@router.delete("/{bucket}/{object_path:path}")
async def delete_object(request: Request, bucket: str, object_path: str):
settings, error = await _ensure_bucket_and_auth(request, bucket)
if error:
return error
assert settings
key = object_path.lstrip("/")
try:
await VirtualFSService.delete_path(_virtual_path(settings, key))
except HTTPException as exc:
if exc.status_code not in (400, 404):
raise
_, headers = _meta_headers()
return Response(status_code=204, headers=headers)

View File

@@ -0,0 +1,26 @@
from fastapi import APIRouter, Depends, Query
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.search_service import VirtualFSSearchService
router = APIRouter(prefix="/api/fs/search", tags=["search"])
@router.get("")
async def search_files(
q: str = Query(..., description="搜索查询"),
top_k: int = Query(10, description="返回结果数量"),
mode: str = Query("vector", description="搜索模式: 'vector''filename'"),
page: int = Query(1, description="分页页码,仅在文件名搜索模式下生效"),
page_size: int = Query(10, description="分页大小,仅在文件名搜索模式下生效"),
user: User = Depends(get_current_active_user),
):
if not q.strip():
return {"items": [], "query": q}
top_k = max(top_k, 1)
page = max(page, 1)
page_size = max(min(page_size, 100), 1)
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)

View File

@@ -0,0 +1,118 @@
from typing import Any, Dict, List, Tuple
from domain.virtual_fs.types import SearchResultItem
from domain.ai.inference import get_text_embedding
from domain.ai.service import VectorDBService
def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
entity = dict(raw.get("entity") or {})
source_path = entity.get("source_path")
stored_path = entity.get("path")
path = source_path or stored_path or ""
chunk_id_value = entity.get("chunk_id")
chunk_id = str(chunk_id_value) if chunk_id_value is not None else None
snippet = entity.get("text") or entity.get("description") or entity.get("name")
mime = entity.get("mime")
start_offset = entity.get("start_offset")
end_offset = entity.get("end_offset")
raw_score = raw.get("distance")
score = float(raw_score) if raw_score is not None else fallback_score
metadata = {
"retrieval_source": source,
"raw_distance": raw_score,
}
if stored_path and stored_path != path:
metadata["stored_path"] = stored_path
vector_id = entity.get("vector_id")
if vector_id:
metadata["vector_id"] = vector_id
return SearchResultItem(
id=str(raw.get("id")),
path=path,
score=score,
chunk_id=chunk_id,
snippet=snippet,
mime=mime,
source_type=entity.get("type") or source,
start_offset=start_offset,
end_offset=end_offset,
metadata=metadata,
)
async def _vector_search(query: str, top_k: int) -> List[SearchResultItem]:
vector_db = VectorDBService()
try:
embedding = await get_text_embedding(query)
except Exception:
embedding = None
if not embedding:
return []
try:
raw_results = await vector_db.search_vectors("vector_collection", embedding, max(top_k, 10))
except Exception:
return []
results: List[SearchResultItem] = []
for bucket in raw_results or []:
for record in bucket or []:
results.append(_normalize_result(record, "vector"))
return results
async def _filename_search(query: str, page: int, page_size: int) -> Tuple[List[SearchResultItem], bool]:
vector_db = VectorDBService()
limit = max(page * page_size + 1, page_size * (page + 2))
limit = min(limit, 2000)
try:
raw_results = await vector_db.search_by_path("vector_collection", query, limit)
except Exception:
return [], False
records = raw_results[0] if raw_results else []
deduped: List[SearchResultItem] = []
seen_paths: set[str] = set()
for record in records or []:
item = _normalize_result(record, "filename", fallback_score=1.0)
stored_path = item.metadata.get("stored_path") if item.metadata else None
key = item.path or stored_path or ""
if key in seen_paths:
continue
seen_paths.add(key)
deduped.append(item)
start = max(page - 1, 0) * page_size
end = start + page_size
page_items = deduped[start:end]
for offset, item in enumerate(page_items):
if item.metadata is None:
item.metadata = {}
item.metadata.setdefault("retrieval_rank", start + offset)
has_more = len(deduped) > end
return page_items, has_more
class VirtualFSSearchService:
@staticmethod
async def search(query: str, top_k: int, mode: str, page: int, page_size: int):
if mode == "vector":
items = (await _vector_search(query, top_k))[:top_k]
return {"items": items, "query": query, "mode": mode}
if mode == "filename":
items, has_more = await _filename_search(query, page, page_size)
return {
"items": items,
"query": query,
"mode": mode,
"pagination": {
"page": page,
"page_size": page_size,
"has_more": has_more,
},
}
items = (await _vector_search(query, top_k))[:top_k]
return {"items": items, "query": query, "mode": mode}

1360
domain/virtual_fs/service.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,330 @@
from __future__ import annotations
import asyncio
import inspect
import io
import hashlib
import tempfile
from contextlib import suppress
from pathlib import Path
from typing import Tuple
from fastapi import HTTPException
ALLOWED_EXT = {"jpg", "jpeg", "png", "webp", "gif", "bmp",
"tiff", "arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
RAW_EXT = {"arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
VIDEO_EXT = {"mp4", "mov", "m4v", "avi", "mkv", "wmv", "flv", "webm", "mpg", "mpeg", "3gp"}
MAX_IMAGE_SOURCE_SIZE = 200 * 1024 * 1024
VIDEO_RANGE_LIMIT = 16 * 1024 * 1024 # 16MB
VIDEO_INITIAL_CHUNK = 4 * 1024 * 1024
CACHE_ROOT = Path('data/.thumb_cache')
def is_image_filename(name: str) -> bool:
parts = name.rsplit('.', 1)
if len(parts) < 2:
return False
return parts[1].lower() in ALLOWED_EXT
def is_raw_filename(name: str) -> bool:
parts = name.rsplit('.', 1)
if len(parts) < 2:
return False
return parts[1].lower() in RAW_EXT
def is_video_filename(name: str) -> bool:
parts = name.rsplit('.', 1)
if len(parts) < 2:
return False
return parts[1].lower() in VIDEO_EXT
def _cache_key(adapter_id: int, rel: str, size: int, mtime: int, w: int, h: int, fit: str) -> str:
raw = f"{adapter_id}|{rel}|{size}|{mtime}|{w}x{h}|{fit}".encode()
return hashlib.sha1(raw).hexdigest()
def _cache_path(key: str) -> Path:
sub = Path(key[:2]) / key[2:4]
return CACHE_ROOT / sub / f"{key}.webp"
def _ensure_cache_dir(p: Path):
p.parent.mkdir(parents=True, exist_ok=True)
def _image_to_webp(im, w: int, h: int, fit: str) -> Tuple[bytes, str]:
from PIL import Image
if im.mode not in ("RGB", "RGBA"):
im = im.convert("RGBA" if im.mode in ("P", "LA") else "RGB")
if fit == 'cover':
im_ratio = im.width / im.height
target_ratio = w / h
if im_ratio > target_ratio:
new_h = h
new_w = int(h * im_ratio)
else:
new_w = w
new_h = int(w / im_ratio)
im = im.resize((new_w, new_h))
left = max(0, (im.width - w)//2)
top = max(0, (im.height - h)//2)
im = im.crop((left, top, left + w, top + h))
else:
im.thumbnail((w, h))
buf = io.BytesIO()
im.save(buf, 'WEBP', quality=80)
return buf.getvalue(), 'image/webp'
def generate_thumb(data: bytes, w: int, h: int, fit: str, is_raw: bool = False) -> Tuple[bytes, str]:
from PIL import Image
if is_raw:
try:
import rawpy
with rawpy.imread(io.BytesIO(data)) as raw:
try:
thumb = raw.extract_thumb()
except rawpy.LibRawNoThumbnailError:
thumb = None
if thumb is not None and thumb.format in [rawpy.ThumbFormat.JPEG, rawpy.ThumbFormat.BITMAP]:
im = Image.open(io.BytesIO(thumb.data))
else:
rgb = raw.postprocess(
use_camera_wb=False, use_auto_wb=True, output_bps=8)
im = Image.fromarray(rgb)
except Exception as e:
print(f"rawpy processing failed: {e}")
raise e
else:
im = Image.open(io.BytesIO(data))
return _image_to_webp(im, w, h, fit)
async def _collect_response_bytes(response, limit: int) -> bytes:
if response is None:
return b""
try:
if isinstance(response, (bytes, bytearray)):
return bytes(response[:limit])
body = getattr(response, "body", None)
if body is not None:
return bytes(body[:limit])
iterator = getattr(response, "body_iterator", None)
if iterator is not None:
data = bytearray()
async for chunk in iterator:
if not chunk:
continue
need = limit - len(data)
if need <= 0:
break
data.extend(chunk[:need])
if len(data) >= limit:
break
return bytes(data)
if hasattr(response, "__aiter__"):
data = bytearray()
async for chunk in response:
if not chunk:
continue
need = limit - len(data)
if need <= 0:
break
data.extend(chunk[:need])
if len(data) >= limit:
break
return bytes(data)
finally:
close_func = getattr(response, "close", None)
if callable(close_func):
result = close_func()
if inspect.isawaitable(result):
await result
return b""
async def _read_range_slice(adapter, root: str, rel: str, start: int, end: int) -> bytes:
read_range = getattr(adapter, "read_file_range", None)
if callable(read_range):
try:
return await read_range(root, rel, start, end)
except TypeError:
return await read_range(root, rel, start, end=end)
stream_impl = getattr(adapter, "stream_file", None)
if callable(stream_impl):
range_header = f"bytes={start}-{end}"
response = await stream_impl(root, rel, range_header)
expected = end - start + 1
return await _collect_response_bytes(response, expected)
read_file = getattr(adapter, "read_file", None)
if callable(read_file) and start == 0:
data = await read_file(root, rel)
slice_end = end + 1
return data[:slice_end]
return b""
async def _read_video_prefix(adapter, root: str, rel: str, size: int, limit: int = VIDEO_RANGE_LIMIT) -> bytes:
chunk_size = min(VIDEO_INITIAL_CHUNK, limit)
offset = 0
collected = bytearray()
while len(collected) < limit:
end = offset + chunk_size - 1
data = await _read_range_slice(adapter, root, rel, offset, end)
if not data:
break
collected.extend(data)
if len(data) < chunk_size:
break
offset += len(data)
remaining = limit - len(collected)
if remaining <= 0:
break
chunk_size = min(chunk_size * 2, remaining)
if not collected and size <= limit:
read_file = getattr(adapter, "read_file", None)
if callable(read_file):
blob = await read_file(root, rel)
if blob:
return bytes(blob[:limit])
return bytes(collected[:limit])
async def _run_ffmpeg_extract_frame(src_path: str, dst_path: str):
cmd = [
"ffmpeg",
"-y",
"-hide_banner",
"-loglevel", "error",
"-i", src_path,
"-frames:v", "1",
dst_path,
]
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except FileNotFoundError as e:
raise RuntimeError("未找到 ffmpeg可执行文件需要在 PATH 中") from e
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
message = stderr.decode().strip() or stdout.decode().strip() or "ffmpeg 执行失败"
raise RuntimeError(message)
async def _generate_video_thumb(video_bytes: bytes, rel: str, w: int, h: int, fit: str) -> Tuple[bytes, str]:
from PIL import Image
suffix = Path(rel).suffix or ".mp4"
src_tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
src_path = src_tmp.name
try:
src_tmp.write(video_bytes)
src_tmp.flush()
finally:
src_tmp.close()
dst_tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
dst_path = dst_tmp.name
dst_tmp.close()
try:
await _run_ffmpeg_extract_frame(src_path, dst_path)
with Image.open(dst_path) as im:
im.load()
return _image_to_webp(im, w, h, fit)
finally:
with suppress(FileNotFoundError):
Path(src_path).unlink()
with suppress(FileNotFoundError):
Path(dst_path).unlink()
async def get_or_create_thumb(adapter, adapter_id: int, root: str, rel: str, w: int, h: int, fit: str = 'cover'):
stat = await adapter.stat_file(root, rel)
size = int(stat.get('size') or 0)
is_video = is_video_filename(rel)
if not is_video and size > MAX_IMAGE_SOURCE_SIZE:
raise HTTPException(400, detail="Image too large for thumbnail")
key = _cache_key(adapter_id, rel, size, int(
stat.get('mtime', 0)), w, h, fit)
path = _cache_path(key)
if path.exists():
return path.read_bytes(), 'image/webp', key
_ensure_cache_dir(path)
thumb_bytes, mime = None, None
get_thumb_impl = getattr(adapter, "get_thumbnail", None)
if callable(get_thumb_impl):
size_str = "large" if w > 400 else "medium" if w > 100 else "small"
native_thumb_bytes = await get_thumb_impl(root, rel, size_str)
if native_thumb_bytes:
try:
from PIL import Image
im = Image.open(io.BytesIO(native_thumb_bytes))
buf = io.BytesIO()
im.save(buf, 'WEBP', quality=85)
thumb_bytes = buf.getvalue()
mime = 'image/webp'
except Exception as e:
print(
f"Failed to convert native thumbnail to WebP: {e}, falling back.")
thumb_bytes, mime = None, None
if not thumb_bytes:
if is_video:
try:
video_bytes = await _read_video_prefix(adapter, root, rel, size)
except HTTPException:
raise
except Exception as e:
print(f"Video prefix read failed: {e}")
raise HTTPException(500, detail=f"Video read failed: {e}")
if not video_bytes:
raise HTTPException(500, detail="Unable to read video data for thumbnail")
try:
thumb_bytes, mime = await _generate_video_thumb(video_bytes, rel, w, h, fit)
except Exception as e:
print(f"Video thumbnail generation failed: {e}")
raise HTTPException(
500, detail=f"Video thumbnail generation failed: {e}")
else:
read_data = await adapter.read_file(root, rel)
try:
thumb_bytes, mime = generate_thumb(
read_data, w, h, fit, is_raw=is_raw_filename(rel))
except Exception as e:
print(e)
raise HTTPException(
500, detail=f"Thumbnail generation failed: {e}")
if thumb_bytes:
path.write_bytes(thumb_bytes)
return thumb_bytes, mime, key
raise HTTPException(
500, detail="Failed to generate thumbnail by any means")

View File

@@ -0,0 +1,40 @@
from typing import List, Optional
from pydantic import BaseModel
class VfsEntry(BaseModel):
name: str
is_dir: bool
size: int
mtime: int
type: Optional[str] = None
has_thumbnail: Optional[bool] = None
class DirListing(BaseModel):
path: str
entries: List[VfsEntry]
pagination: Optional[dict] = None
class SearchResultItem(BaseModel):
id: int | str
path: str
score: float
chunk_id: Optional[str] = None
snippet: Optional[str] = None
mime: Optional[str] = None
source_type: Optional[str] = None
start_offset: Optional[int] = None
end_offset: Optional[int] = None
metadata: Optional[dict] = None
class MkdirRequest(BaseModel):
path: str
class MoveRequest(BaseModel):
src: str
dst: str

View File

@@ -0,0 +1,301 @@
from __future__ import annotations
import base64
import hashlib
import mimetypes
from email.utils import formatdate
from urllib.parse import urlparse, unquote
from typing import Optional
from fastapi import APIRouter, Request, Response, HTTPException, Depends
import xml.etree.ElementTree as ET
from domain.auth.service import AuthService
from domain.auth.types import User, UserInDB
from domain.virtual_fs.service import VirtualFSService
from domain.config.service import ConfigService
_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
async def _ensure_webdav_enabled() -> None:
enabled = await ConfigService.get(_WEBDAV_ENABLED_KEY, "1")
if str(enabled).strip().lower() in ("0", "false", "off", "no"):
raise HTTPException(503, detail="WebDAV mapping disabled")
router = APIRouter(prefix="/webdav", tags=["webdav"])
def _dav_headers(extra: Optional[dict] = None) -> dict:
headers = {
"DAV": "1",
"MS-Author-Via": "DAV",
"Accept-Ranges": "bytes",
"Allow": ", ".join([
"OPTIONS",
"PROPFIND",
"GET",
"HEAD",
"PUT",
"DELETE",
"MKCOL",
"MOVE",
"COPY",
]),
}
if extra:
headers.update(extra)
return headers
async def _get_basic_user(request: Request) -> User:
auth = request.headers.get("Authorization", "")
if not auth:
raise HTTPException(401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic realm=webdav"})
scheme, _, param = auth.partition(" ")
scheme_lower = scheme.lower()
if scheme_lower == "basic":
try:
decoded = base64.b64decode(param).decode("utf-8")
username, _, password = decoded.partition(":")
except Exception:
raise HTTPException(401, detail="Invalid Basic auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
user_or_false: Optional[UserInDB] = await AuthService.authenticate_user_db(username, password)
if not user_or_false:
raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
u: UserInDB = user_or_false
return User(id=u.id, username=u.username, email=u.email, full_name=u.full_name, disabled=u.disabled)
elif scheme_lower == "bearer":
if not param:
raise HTTPException(401, detail="Invalid Bearer token")
return User(id=0, username="bearer", email=None, full_name=None, disabled=False)
else:
raise HTTPException(401, detail="Unsupported auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
def _httpdate(ts: int | float) -> str:
return formatdate(ts, usegmt=True)
def _etag(path: str, size: int | None, mtime: int | None) -> str:
raw = f"{path}|{size or 0}|{mtime or 0}".encode("utf-8")
return '"' + hashlib.md5(raw).hexdigest() + '"'
def _href_for(path: str, is_dir: bool) -> str:
from urllib.parse import quote
p = "/webdav" + (path if path.startswith("/") else "/" + path)
if is_dir and not p.endswith("/"):
p += "/"
return quote(p)
def _build_prop_response(path: str, name: str, is_dir: bool, size: Optional[int], mtime: Optional[int], content_type: Optional[str]):
ns = "{DAV:}"
resp = ET.Element(ns + "response")
href = ET.SubElement(resp, ns + "href")
href.text = _href_for(path, is_dir)
propstat = ET.SubElement(resp, ns + "propstat")
prop = ET.SubElement(propstat, ns + "prop")
displayname = ET.SubElement(prop, ns + "displayname")
displayname.text = name
resourcetype = ET.SubElement(prop, ns + "resourcetype")
if is_dir:
ET.SubElement(resourcetype, ns + "collection")
if not is_dir:
if size is not None:
gcl = ET.SubElement(prop, ns + "getcontentlength")
gcl.text = str(size)
if content_type:
gct = ET.SubElement(prop, ns + "getcontenttype")
gct.text = content_type
if mtime is not None:
glm = ET.SubElement(prop, ns + "getlastmodified")
glm.text = _httpdate(mtime)
etag = ET.SubElement(prop, ns + "getetag")
etag.text = _etag(path, size, mtime)
status = ET.SubElement(propstat, ns + "status")
status.text = "HTTP/1.1 200 OK"
return resp
def _multistatus_xml(responses: list[ET.Element]) -> bytes:
ns = "{DAV:}"
ms = ET.Element(ns + "multistatus")
for r in responses:
ms.append(r)
return ET.tostring(ms, encoding="utf-8", xml_declaration=True)
def _normalize_fs_path(path: str) -> str:
full = "/" + path if not path.startswith("/") else path
return unquote(full)
@router.options("/{path:path}")
async def options_root(path: str = "", _enabled: None = Depends(_ensure_webdav_enabled)):
return Response(status_code=200, headers=_dav_headers())
@router.api_route("/{path:path}", methods=["PROPFIND"])
async def propfind(
request: Request,
path: str,
_enabled: None = Depends(_ensure_webdav_enabled),
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
depth = request.headers.get("Depth", "1").lower()
if depth not in ("0", "1", "infinity"):
depth = "1"
responses: list[ET.Element] = []
# 先获取当前路径信息
try:
st = await VirtualFSService.stat_file(full_path)
is_dir = bool(st.get("is_dir"))
name = st.get("name") or full_path.rsplit("/", 1)[-1] or "/"
size = None if is_dir else int(st.get("size", 0))
mtime = int(st.get("mtime", 0)) if st.get("mtime") is not None else None
ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
responses.append(_build_prop_response(full_path, name, is_dir, size, mtime, ctype))
except FileNotFoundError:
raise HTTPException(404, detail="Not found")
if depth in ("1", "infinity"):
try:
listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
for ent in listing["items"]:
is_dir = bool(ent.get("is_dir"))
name = ent.get("name")
child_path = full_path.rstrip("/") + "/" + name
size = None if is_dir else int(ent.get("size", 0))
mtime = int(ent.get("mtime", 0)) if ent.get("mtime") is not None else None
ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
responses.append(_build_prop_response(child_path, name, is_dir, size, mtime, ctype))
except HTTPException as e:
if e.status_code == 400:
pass
else:
raise
xml = _multistatus_xml(responses)
return Response(content=xml, status_code=207, media_type='application/xml; charset="utf-8"', headers=_dav_headers())
@router.get("/{path:path}")
async def dav_get(
path: str,
request: Request,
_enabled: None = Depends(_ensure_webdav_enabled),
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
range_header = request.headers.get("Range")
return await VirtualFSService.stream_file(full_path, range_header)
@router.head("/{path:path}")
async def dav_head(
path: str,
_enabled: None = Depends(_ensure_webdav_enabled),
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
try:
st = await VirtualFSService.stat_file(full_path)
except FileNotFoundError:
raise HTTPException(404, detail="Not found")
is_dir = bool(st.get("is_dir"))
headers = _dav_headers()
if not is_dir:
size = int(st.get("size", 0))
name = st.get("name") or full_path.rsplit("/", 1)[-1]
ctype = mimetypes.guess_type(name)[0] or "application/octet-stream"
mtime = int(st.get("mtime", 0)) if st.get("mtime") is not None else None
headers.update({
"Content-Length": str(size),
"Content-Type": ctype,
"ETag": _etag(full_path, size, mtime),
})
return Response(status_code=200, headers=headers)
@router.api_route("/{path:path}", methods=["PUT"])
async def dav_put(
path: str,
request: Request,
_enabled: None = Depends(_ensure_webdav_enabled),
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
async def body_iter():
async for chunk in request.stream():
if chunk:
yield chunk
size = await VirtualFSService.write_file_stream(full_path, body_iter(), overwrite=True)
return Response(status_code=201, headers=_dav_headers({"Content-Length": "0"}))
@router.api_route("/{path:path}", methods=["DELETE"])
async def dav_delete(
path: str,
_enabled: None = Depends(_ensure_webdav_enabled),
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await VirtualFSService.delete_path(full_path)
return Response(status_code=204, headers=_dav_headers())
@router.api_route("/{path:path}", methods=["MKCOL"])
async def dav_mkcol(
path: str,
_enabled: None = Depends(_ensure_webdav_enabled),
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await VirtualFSService.make_dir(full_path)
return Response(status_code=201, headers=_dav_headers())
def _parse_destination(dest: str) -> str:
if not dest:
raise HTTPException(400, detail="Missing Destination header")
p = urlparse(dest)
path = p.path if p.scheme else dest
if path.startswith("/webdav"):
rel = path[len("/webdav"):]
else:
rel = path
return _normalize_fs_path(rel)
@router.api_route("/{path:path}", methods=["MOVE"])
async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_user)):
full_src = _normalize_fs_path(path)
dest_header = request.headers.get("Destination")
dst = _parse_destination(dest_header or "")
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
return Response(status_code=204, headers=_dav_headers())
@router.api_route("/{path:path}", methods=["COPY"])
async def dav_copy(path: str, request: Request, user: User = Depends(_get_basic_user)):
full_src = _normalize_fs_path(path)
dest_header = request.headers.get("Destination")
dst = _parse_destination(dest_header or "")
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())