feat(s3): implement multipart upload functionality and related endpoints

shiyu
2025-12-30 12:16:18 +08:00
parent 53130383c1
commit 28ede26801
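
For reference, the client-side flow these endpoints implement, as a minimal boto3 sketch. The endpoint URL, bucket, key, and credentials below are placeholders, not values from this repo; the only assumption taken from the diff is that the S3 API is mounted under the /s3 prefix with path-style addressing.

import boto3
from botocore.config import Config

# Placeholder endpoint and credentials; substitute the access key pair and
# base URL configured in your Foxel S3 settings.
s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:8000/s3",
    aws_access_key_id="AK",
    aws_secret_access_key="SK",
    config=Config(s3={"addressing_style": "path"}),
)

mpu = s3.create_multipart_upload(Bucket="media", Key="backup.bin")  # POST ?uploads
part = s3.upload_part(                                              # PUT ?partNumber&uploadId
    Bucket="media", Key="backup.bin", PartNumber=1,
    UploadId=mpu["UploadId"], Body=b"hello world",
)
s3.complete_multipart_upload(                                       # POST ?uploadId
    Bucket="media", Key="backup.bin", UploadId=mpu["UploadId"],
    MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": part["ETag"]}]},
)
# s3.abort_multipart_upload(...) maps to DELETE ?uploadId and removes the staging directory.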


@@ -2,9 +2,15 @@ import base64
import datetime as dt
import hashlib
import hmac
import json
import os
import re
import shutil
import uuid
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple

import aiofiles
from fastapi import APIRouter, Request, Response
from fastapi import HTTPException
@@ -19,6 +25,12 @@ router = APIRouter(prefix="/s3", tags=["s3"])
FALSEY = {"0", "false", "off", "no"}
_XML_NS = "http://s3.amazonaws.com/doc/2006-03-01/"

_MPU_ROOT = "data/s3_multipart"
_MPU_META_NAME = "meta.json"
_MPU_PART_DATA_TMPL = "part-{part_number:06d}.bin"
_MPU_PART_META_TMPL = "part-{part_number:06d}.json"
_MPU_PART_META_RE = re.compile(r"^part-(\d{6})\.json$")
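# Each in-flight upload gets its own staging directory keyed by upload id:
#   data/s3_multipart/<upload_id>/meta.json         bucket, key, virtual_path, initiated
#   data/s3_multipart/<upload_id>/part-000001.bin   raw bytes of part 1
#   data/s3_multipart/<upload_id>/part-000001.json  PartNumber, ETag, Size, LastModified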


class S3Settings(Dict[str, str]):
    bucket: str
@@ -413,6 +425,380 @@ def _resource_path(bucket: str, key: Optional[str] = None) -> str:
    return f"/s3/{bucket}"


def _safe_upload_id(upload_id: Optional[str]) -> Optional[str]:
    if not upload_id:
        return None
    value = upload_id.strip()
    if not value:
        return None
    # Reject separators and dot entries: the id is joined into a filesystem path.
    if "/" in value or "\\" in value or value in {".", ".."}:
        return None
    return value


def _mpu_dir(upload_id: str) -> str:
    return os.path.join(_MPU_ROOT, upload_id)


def _mpu_meta_path(upload_id: str) -> str:
    return os.path.join(_mpu_dir(upload_id), _MPU_META_NAME)


def _mpu_part_data_path(upload_id: str, part_number: int) -> str:
    return os.path.join(_mpu_dir(upload_id), _MPU_PART_DATA_TMPL.format(part_number=part_number))


def _mpu_part_meta_path(upload_id: str, part_number: int) -> str:
    return os.path.join(_mpu_dir(upload_id), _MPU_PART_META_TMPL.format(part_number=part_number))


async def _read_json(path: str) -> Optional[Dict[str, Any]]:
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            raw = await f.read()
        data = json.loads(raw or "{}")
        return data if isinstance(data, dict) else None
    except FileNotFoundError:
        return None
    except Exception:
        return None


async def _write_json(path: str, data: Dict[str, Any]) -> None:
    os.makedirs(os.path.dirname(path), exist_ok=True)
    async with aiofiles.open(path, "w", encoding="utf-8") as f:
        await f.write(json.dumps(data, ensure_ascii=False))


async def _load_mpu_meta(bucket: str, key: str, upload_id: Optional[str]) -> Tuple[Optional[Dict[str, Any]], Optional[Response]]:
    safe_id = _safe_upload_id(upload_id)
    if not safe_id:
        return None, _s3_error(
            "NoSuchUpload",
            "The specified upload does not exist.",
            _resource_path(bucket, key),
            status=404,
        )
    meta = await _read_json(_mpu_meta_path(safe_id))
    if not meta or meta.get("bucket") != bucket or meta.get("key") != key:
        return None, _s3_error(
            "NoSuchUpload",
            "The specified upload does not exist.",
            _resource_path(bucket, key),
            status=404,
        )
    return meta, None


def _parse_int(value: Optional[str], default: int) -> int:
    if value is None:
        return default
    try:
        return int(value)
    except ValueError:
        return default


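# CreateMultipartUpload (POST /s3/<bucket>/<key>?uploads): allocate a fresh
# staging directory; the retry loop guards against the (vanishingly unlikely)
# uuid4 collision instead of silently reusing an existing directory.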
async def _create_multipart_upload(request: Request, settings: S3Settings, bucket: str, key: str) -> Response:
    os.makedirs(_MPU_ROOT, exist_ok=True)
    upload_id = uuid.uuid4().hex
    dir_path = _mpu_dir(upload_id)
    while True:
        try:
            os.makedirs(dir_path, exist_ok=False)
            break
        except FileExistsError:
            upload_id = uuid.uuid4().hex
            dir_path = _mpu_dir(upload_id)
    meta = {
        "bucket": bucket,
        "key": key,
        "virtual_path": _virtual_path(settings, key),
        "initiated": _now_iso(),
    }
    await _write_json(_mpu_meta_path(upload_id), meta)
    _, headers = _meta_headers()
    xml = (
        f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
        f"<CreateMultipartUploadResult xmlns=\"{_XML_NS}\">"
        f"<Bucket>{bucket}</Bucket>"
        f"<Key>{key}</Key>"
        f"<UploadId>{upload_id}</UploadId>"
        f"</CreateMultipartUploadResult>"
    )
    headers.update({"Content-Type": "application/xml"})
    return Response(content=xml, media_type="application/xml", headers=headers)


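# UploadPart (PUT ...?partNumber=N&uploadId=...): the body is streamed to a
# .tmp file and published with os.replace, so a part either exists completely
# or not at all, and retrying the same part number atomically overwrites it.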
async def _upload_part(request: Request, bucket: str, key: str, upload_id: Optional[str], part_number_raw: Optional[str]) -> Response:
    part_number = _parse_int(part_number_raw, 0)
    if part_number <= 0:
        return _s3_error("InvalidArgument", "partNumber is invalid", _resource_path(bucket, key), status=400)
    meta, err = await _load_mpu_meta(bucket, key, upload_id)
    if err:
        return err
    assert meta
    safe_id = _safe_upload_id(upload_id)
    assert safe_id
    part_path = _mpu_part_data_path(safe_id, part_number)
    tmp_path = part_path + ".tmp"
    md5 = hashlib.md5()
    size = 0
    async with aiofiles.open(tmp_path, "wb") as f:
        async for chunk in request.stream():
            if not chunk:
                continue
            await f.write(chunk)
            md5.update(chunk)
            size += len(chunk)
    etag = '"' + md5.hexdigest() + '"'
    os.replace(tmp_path, part_path)
    await _write_json(
        _mpu_part_meta_path(safe_id, part_number),
        {"PartNumber": part_number, "ETag": etag, "Size": size, "LastModified": _now_iso()},
    )
    _, headers = _meta_headers()
    headers.update({"ETag": etag, "Content-Length": "0"})
    return Response(status_code=200, headers=headers)


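# ListParts (GET ...?uploadId=...): part metadata is rebuilt from the
# part-NNNNNN.json files, then paginated with the standard S3 parameters
# max-parts (capped at 1000) and part-number-marker (an exclusive lower bound).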
async def _list_parts(request: Request, settings: S3Settings, bucket: str, key: str, upload_id: Optional[str]) -> Response:
    meta, err = await _load_mpu_meta(bucket, key, upload_id)
    if err:
        return err
    assert meta
    safe_id = _safe_upload_id(upload_id)
    assert safe_id
    dir_path = _mpu_dir(safe_id)
    part_metas: List[Dict[str, Any]] = []
    try:
        filenames = os.listdir(dir_path)
    except FileNotFoundError:
        filenames = []
    for name in filenames:
        m = _MPU_PART_META_RE.match(name)
        if not m:
            continue
        pn = int(m.group(1))
        info = await _read_json(os.path.join(dir_path, name))
        if not info:
            continue
        info.setdefault("PartNumber", pn)
        part_metas.append(info)
    part_metas.sort(key=lambda item: int(item.get("PartNumber") or 0))
    max_parts = max(1, min(1000, _parse_int(request.query_params.get("max-parts"), 1000)))
    marker = max(0, _parse_int(request.query_params.get("part-number-marker"), 0))
    filtered = [p for p in part_metas if int(p.get("PartNumber") or 0) > marker]
    is_truncated = len(filtered) > max_parts
    shown = filtered[:max_parts]
    next_marker = int(shown[-1]["PartNumber"]) if is_truncated and shown else 0
    _, headers = _meta_headers()
    body = [f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListPartsResult xmlns=\"{_XML_NS}\">"]
    body.append(f"<Bucket>{bucket}</Bucket>")
    body.append(f"<Key>{key}</Key>")
    body.append(f"<UploadId>{safe_id}</UploadId>")
    body.append(
        f"<Initiator><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Initiator>"
    )
    body.append(
        f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
    )
    body.append("<StorageClass>STANDARD</StorageClass>")
    body.append(f"<PartNumberMarker>{marker}</PartNumberMarker>")
    body.append(f"<NextPartNumberMarker>{next_marker}</NextPartNumberMarker>")
    body.append(f"<MaxParts>{max_parts}</MaxParts>")
    body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
    for part in shown:
        pn = int(part.get("PartNumber") or 0)
        etag = part.get("ETag") or ""
        size = int(part.get("Size") or 0)
        last_modified = part.get("LastModified") or _now_iso()
        body.append(
            f"<Part><PartNumber>{pn}</PartNumber><LastModified>{last_modified}</LastModified><ETag>{etag}</ETag><Size>{size}</Size></Part>"
        )
    body.append("</ListPartsResult>")
    xml = "".join(body)
    headers.update({"Content-Type": "application/xml"})
    return Response(content=xml, media_type="application/xml", headers=headers)


async def _abort_multipart_upload(bucket: str, key: str, upload_id: Optional[str]) -> Response:
    _, err = await _load_mpu_meta(bucket, key, upload_id)
    if err:
        return err
    safe_id = _safe_upload_id(upload_id)
    assert safe_id
    shutil.rmtree(_mpu_dir(safe_id), ignore_errors=True)
    _, headers = _meta_headers()
    return Response(status_code=204, headers=headers)


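# The {*} namespace wildcard below matches Part/PartNumber/ETag whether or not
# the client namespaces its CompleteMultipartUpload body; note this wildcard
# syntax requires Python 3.8+.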
def _parse_complete_parts(body_bytes: bytes) -> List[Tuple[int, str]]:
    if not body_bytes:
        return []
    root = ET.fromstring(body_bytes)
    parts: List[Tuple[int, str]] = []
    for part_el in root.findall(".//{*}Part"):
        pn_el = part_el.find("{*}PartNumber")
        etag_el = part_el.find("{*}ETag")
        if pn_el is None or pn_el.text is None:
            continue
        pn = _parse_int(pn_el.text.strip(), 0)
        if pn <= 0:
            continue
        etag = (etag_el.text or "").strip() if etag_el is not None else ""
        parts.append((pn, etag))
    parts.sort(key=lambda item: item[0])
    return parts


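# CompleteMultipartUpload: the staged parts are streamed in order into the
# virtual filesystem, then the staging directory is removed. The aggregate
# ETag follows the usual S3 multipart convention: MD5 over the concatenated
# binary part digests, suffixed with "-<part count>"; a single-part upload
# simply reuses that part's ETag.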
async def _complete_multipart_upload(request: Request, settings: S3Settings, bucket: str, key: str, upload_id: Optional[str]) -> Response:
    meta, err = await _load_mpu_meta(bucket, key, upload_id)
    if err:
        return err
    assert meta
    safe_id = _safe_upload_id(upload_id)
    assert safe_id
    try:
        body_bytes = await request.body()
    except Exception:
        body_bytes = b""
    try:
        parts_req = _parse_complete_parts(body_bytes)
    except Exception:
        return _s3_error("MalformedXML", "The XML you provided was not well-formed.", _resource_path(bucket, key), status=400)
    if not parts_req:
        return _s3_error("MalformedXML", "CompleteMultipartUpload parts missing.", _resource_path(bucket, key), status=400)
    part_metas: List[Dict[str, Any]] = []
    for pn, _etag in parts_req:
        info = await _read_json(_mpu_part_meta_path(safe_id, pn))
        if not info:
            return _s3_error("InvalidPart", "One or more of the specified parts could not be found.", _resource_path(bucket, key), status=400)
        info.setdefault("PartNumber", pn)
        part_metas.append(info)

    async def merged_iter() -> AsyncIterator[bytes]:
        for info in part_metas:
            pn = int(info.get("PartNumber") or 0)
            part_path = _mpu_part_data_path(safe_id, pn)
            async with aiofiles.open(part_path, "rb") as f:
                while True:
                    chunk = await f.read(1024 * 1024)
                    if not chunk:
                        break
                    yield chunk

    await VirtualFSService.write_file_stream(meta.get("virtual_path") or _virtual_path(settings, key), merged_iter(), overwrite=True)
    etag = ""
    if len(part_metas) == 1:
        etag = str(part_metas[0].get("ETag") or "")
    else:
        md5_bytes = bytearray()
        for info in part_metas:
            raw = str(info.get("ETag") or "").strip().strip('"')
            try:
                md5_bytes.extend(bytes.fromhex(raw))
            except ValueError:
                pass
        digest = hashlib.md5(bytes(md5_bytes)).hexdigest() if md5_bytes else hashlib.md5(b"").hexdigest()
        etag = '"' + f"{digest}-{len(part_metas)}" + '"'
    shutil.rmtree(_mpu_dir(safe_id), ignore_errors=True)
    _, headers = _meta_headers()
    headers.update({"Content-Type": "application/xml", "ETag": etag})
    location = str(request.url.replace(query=""))
    xml = (
        f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
        f"<CompleteMultipartUploadResult xmlns=\"{_XML_NS}\">"
        f"<Location>{location}</Location>"
        f"<Bucket>{bucket}</Bucket>"
        f"<Key>{key}</Key>"
        f"<ETag>{etag}</ETag>"
        f"</CompleteMultipartUploadResult>"
    )
    return Response(content=xml, media_type="application/xml", headers=headers)


async def _list_multipart_uploads(request: Request, settings: S3Settings, bucket: str) -> Response:
    os.makedirs(_MPU_ROOT, exist_ok=True)
    prefix = request.query_params.get("prefix") or ""
    max_uploads = max(1, min(1000, _parse_int(request.query_params.get("max-uploads"), 1000)))
    key_marker = request.query_params.get("key-marker") or ""
    upload_id_marker = request.query_params.get("upload-id-marker") or ""
    uploads: List[Tuple[str, str, str]] = []
    try:
        ids = os.listdir(_MPU_ROOT)
    except FileNotFoundError:
        ids = []
    for uid in ids:
        safe_id = _safe_upload_id(uid)
        if not safe_id:
            continue
        meta = await _read_json(_mpu_meta_path(safe_id))
        if not meta:
            continue
        if meta.get("bucket") != bucket:
            continue
        key = str(meta.get("key") or "")
        if prefix and not key.startswith(prefix):
            continue
        initiated = str(meta.get("initiated") or _now_iso())
        uploads.append((key, safe_id, initiated))
    uploads.sort(key=lambda item: (item[0], item[1]))
    if key_marker:
        uploads = [
            it
            for it in uploads
            if (it[0] > key_marker) or (it[0] == key_marker and it[1] > upload_id_marker)
        ]
    is_truncated = len(uploads) > max_uploads
    shown = uploads[:max_uploads]
    next_key_marker = shown[-1][0] if is_truncated and shown else ""
    next_upload_id_marker = shown[-1][1] if is_truncated and shown else ""
    _, headers = _meta_headers()
    body = [f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListMultipartUploadsResult xmlns=\"{_XML_NS}\">"]
    body.append(f"<Bucket>{bucket}</Bucket>")
    body.append(f"<Prefix>{prefix}</Prefix>")
    body.append(f"<KeyMarker>{key_marker}</KeyMarker>")
    body.append(f"<UploadIdMarker>{upload_id_marker}</UploadIdMarker>")
    body.append(f"<NextKeyMarker>{next_key_marker}</NextKeyMarker>")
    body.append(f"<NextUploadIdMarker>{next_upload_id_marker}</NextUploadIdMarker>")
    body.append(f"<MaxUploads>{max_uploads}</MaxUploads>")
    body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
    for key, uid, initiated in shown:
        body.append(
            f"<Upload><Key>{key}</Key><UploadId>{uid}</UploadId>"
            f"<Initiator><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Initiator>"
            f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
            f"<StorageClass>STANDARD</StorageClass><Initiated>{initiated}</Initiated></Upload>"
        )
    body.append("</ListMultipartUploadsResult>")
    xml = "".join(body)
    headers.update({"Content-Type": "application/xml"})
    return Response(content=xml, media_type="application/xml", headers=headers)


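# Route dispatch mirrors the S3 wire protocol:
#   GET    /s3/<bucket>?uploads                         -> ListMultipartUploads
#   POST   /s3/<bucket>/<key>?uploads                   -> CreateMultipartUpload
#   PUT    /s3/<bucket>/<key>?partNumber=N&uploadId=ID  -> UploadPart
#   GET    /s3/<bucket>/<key>?uploadId=ID               -> ListParts
#   POST   /s3/<bucket>/<key>?uploadId=ID               -> CompleteMultipartUpload
#   DELETE /s3/<bucket>/<key>?uploadId=ID               -> AbortMultipartUpload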
@router.get("")
@audit(action=AuditAction.READ, description="S3: 列出桶")
async def list_buckets(request: Request):
@@ -451,6 +837,8 @@ async def list_objects(request: Request, bucket: str):
        return auth
    params = request.query_params
    if "uploads" in params:
        return await _list_multipart_uploads(request, settings, bucket)
    if params.get("list-type", "2") != "2":
        return _s3_error("InvalidArgument", "Only ListObjectsV2 (list-type=2) is supported.", _resource_path(bucket), status=400)
@@ -585,6 +973,11 @@ async def object_get_head(request: Request, bucket: str, object_path: str):
        return error
    assert settings
    key = object_path.lstrip("/")
    upload_id = request.query_params.get("uploadId") or request.query_params.get("uploadid")
    if upload_id and request.method == "GET":
        return await _list_parts(request, settings, bucket, key, upload_id)
    if upload_id and request.method == "HEAD":
        return _s3_error("MethodNotAllowed", "Method Not Allowed", _resource_path(bucket, key), status=405)
    meta, err = await _stat_object(settings, key)
    if err:
        return err
@@ -610,6 +1003,10 @@ async def put_object(request: Request, bucket: str, object_path: str):
        return error
    assert settings
    key = object_path.lstrip("/")
    upload_id = request.query_params.get("uploadId") or request.query_params.get("uploadid")
    part_number = request.query_params.get("partNumber") or request.query_params.get("partnumber")
    if upload_id and part_number:
        return await _upload_part(request, bucket, key, upload_id, part_number)
    await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
    meta, err = await _stat_object(settings, key)
    if err:
@@ -623,6 +1020,24 @@ async def put_object(request: Request, bucket: str, object_path: str):
    return Response(status_code=200, headers=headers)


@router.post("/{bucket}/{object_path:path}")
@audit(action=AuditAction.UPLOAD, description="S3: Multipart upload")
async def post_object(request: Request, bucket: str, object_path: str):
    settings, error = await _ensure_bucket_and_auth(request, bucket)
    if error:
        return error
    assert settings
    key = object_path.lstrip("/")
    params = request.query_params
    upload_id = params.get("uploadId") or params.get("uploadid")
    if "uploads" in params:
        return await _create_multipart_upload(request, settings, bucket, key)
    if upload_id:
        return await _complete_multipart_upload(request, settings, bucket, key, upload_id)
    return _s3_error("InvalidRequest", "Unsupported POST operation.", _resource_path(bucket, key), status=400)


@router.delete("/{bucket}/{object_path:path}")
@audit(action=AuditAction.DELETE, description="S3: Delete object")
async def delete_object(request: Request, bucket: str, object_path: str):
@@ -631,6 +1046,9 @@ async def delete_object(request: Request, bucket: str, object_path: str):
        return error
    assert settings
    key = object_path.lstrip("/")
    upload_id = request.query_params.get("uploadId") or request.query_params.get("uploadid")
    if upload_id:
        return await _abort_multipart_upload(bucket, key, upload_id)
    try:
        await VirtualFSService.delete_path(_virtual_path(settings, key))
    except HTTPException as exc: