From 20bc1cfbb71539e0c6f4825e8c0ca6d876551c11 Mon Sep 17 00:00:00 2001 From: shiyu Date: Fri, 29 Aug 2025 12:50:51 +0800 Subject: [PATCH] feat: Implement S3Adapter for S3 compatible object storage with file operations --- services/adapters/s3.py | 366 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 366 insertions(+) create mode 100644 services/adapters/s3.py diff --git a/services/adapters/s3.py b/services/adapters/s3.py new file mode 100644 index 0000000..f904be4 --- /dev/null +++ b/services/adapters/s3.py @@ -0,0 +1,366 @@ +from __future__ import annotations +import asyncio +import mimetypes +from datetime import datetime +from typing import List, Dict, Tuple, AsyncIterator +from urllib.parse import quote + +import aioboto3 +from botocore.exceptions import ClientError +from fastapi import HTTPException +from fastapi.responses import StreamingResponse +from models import StorageAdapter +from services.logging import LogService + + +class S3Adapter: + """S3 兼容对象存储适配器""" + + def __init__(self, record: StorageAdapter): + self.record = record + cfg = record.config + self.bucket_name = cfg.get("bucket_name") + self.aws_access_key_id = cfg.get("access_key_id") + self.aws_secret_access_key = cfg.get("secret_access_key") + self.region_name = cfg.get("region_name") + self.endpoint_url = cfg.get("endpoint_url") + self.root = cfg.get("root", "").strip("/") + + if not all([self.bucket_name, self.aws_access_key_id, self.aws_secret_access_key]): + raise ValueError( + "S3 适配器需要 bucket_name, access_key_id, 和 secret_access_key") + + self.session = aioboto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + region_name=self.region_name, + ) + + def get_effective_root(self, sub_path: str | None) -> str: + """获取 S3 中的有效根路径 (key prefix)""" + if sub_path: + return f"{self.root}/{sub_path.strip('/')}".strip("/") + return self.root + + def _get_s3_key(self, rel_path: str) -> str: + """将相对路径转换为 S3 key""" + rel_path = rel_path.strip("/") + if self.root: + return f"{self.root}/{rel_path}" + return rel_path + + def _get_client(self): + return self.session.client("s3", endpoint_url=self.endpoint_url) + + async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]: + prefix = self._get_s3_key(rel) + if prefix and not prefix.endswith("/"): + prefix += "/" + + all_items = [] + + async with self._get_client() as s3: + paginator = s3.get_paginator("list_objects_v2") + async for result in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter="/"): + # 添加子目录 + for common_prefix in result.get("CommonPrefixes", []): + dir_name = common_prefix.get( + "Prefix").removeprefix(prefix).strip("/") + if dir_name: + all_items.append({ + "name": dir_name, + "is_dir": True, + "size": 0, + "mtime": 0, + "type": "dir", + }) + + # 添加文件 + for content in result.get("Contents", []): + file_key = content.get("Key") + if file_key == prefix: # 忽略目录本身 + continue + file_name = file_key.removeprefix(prefix) + if file_name: + all_items.append({ + "name": file_name, + "is_dir": False, + "size": content.get("Size", 0), + "mtime": int(content.get("LastModified", datetime.now()).timestamp()), + "type": "file", + }) + + # 在内存中排序和分页 + all_items.sort(key=lambda x: (not x["is_dir"], x["name"].lower())) + total_count = len(all_items) + start_idx = (page_num - 1) * page_size + end_idx = start_idx + page_size + + return all_items[start_idx:end_idx], total_count + + async def read_file(self, root: str, rel: str) -> bytes: + key = self._get_s3_key(rel) + async with self._get_client() as s3: + try: + resp = await s3.get_object(Bucket=self.bucket_name, Key=key) + return await resp["Body"].read() + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError(rel) + raise + + async def write_file(self, root: str, rel: str, data: bytes): + key = self._get_s3_key(rel) + async with self._get_client() as s3: + await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data) + await LogService.info( + "adapter:s3", f"Wrote file to {rel}", + details={"adapter_id": self.record.id, + "bucket": self.bucket_name, "key": key, "size": len(data)} + ) + + async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]): + key = self._get_s3_key(rel) + MIN_PART_SIZE = 5 * 1024 * 1024 + + async with self._get_client() as s3: + mpu = await s3.create_multipart_upload(Bucket=self.bucket_name, Key=key) + upload_id = mpu['UploadId'] + + parts = [] + part_number = 1 + total_size = 0 + buffer = bytearray() + + try: + async for chunk in data_iter: + if not chunk: + continue + buffer.extend(chunk) + + while len(buffer) >= MIN_PART_SIZE: + part_data = buffer[:MIN_PART_SIZE] + del buffer[:MIN_PART_SIZE] + + part = await s3.upload_part( + Bucket=self.bucket_name, + Key=key, + PartNumber=part_number, + UploadId=upload_id, + Body=part_data + ) + + parts.append({'PartNumber': part_number, 'ETag': part['ETag']}) + total_size += len(part_data) + part_number += 1 + + if buffer: + part = await s3.upload_part( + Bucket=self.bucket_name, + Key=key, + PartNumber=part_number, + UploadId=upload_id, + Body=bytes(buffer) + ) + parts.append({'PartNumber': part_number, 'ETag': part['ETag']}) + total_size += len(buffer) + + await s3.complete_multipart_upload( + Bucket=self.bucket_name, + Key=key, + UploadId=upload_id, + MultipartUpload={'Parts': parts} + ) + except Exception as e: + await s3.abort_multipart_upload( + Bucket=self.bucket_name, + Key=key, + UploadId=upload_id + ) + raise IOError(f"S3 stream upload failed: {e}") from e + + await LogService.info( + "adapter:s3", f"Wrote file stream to {rel}", + details={"adapter_id": self.record.id, "bucket": self.bucket_name, "key": key, "size": total_size} + ) + return total_size + + async def mkdir(self, root: str, rel: str): + key = self._get_s3_key(rel) + if not key.endswith("/"): + key += "/" + async with self._get_client() as s3: + await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"") + await LogService.info( + "adapter:s3", f"Created directory {rel}", + details={"adapter_id": self.record.id, + "bucket": self.bucket_name, "key": key} + ) + + async def delete(self, root: str, rel: str): + key = self._get_s3_key(rel) + async with self._get_client() as s3: + is_dir_like = False + try: + head = await s3.head_object(Bucket=self.bucket_name, Key=key) + if head['ContentLength'] == 0 and key.endswith('/'): + is_dir_like = True + except ClientError as e: + if e.response['Error']['Code'] != '404': + raise + + # 如果是目录,删除目录下的所有对象 + if is_dir_like or not await self.stat_file(root, rel): + dir_key = key if key.endswith('/') else key + '/' + paginator = s3.get_paginator("list_objects_v2") + objects_to_delete = [] + async for result in paginator.paginate(Bucket=self.bucket_name, Prefix=dir_key): + for content in result.get("Contents", []): + objects_to_delete.append({"Key": content["Key"]}) + if objects_to_delete: + await s3.delete_objects(Bucket=self.bucket_name, Delete={"Objects": objects_to_delete}) + # 如果是文件,直接删除 + else: + await s3.delete_object(Bucket=self.bucket_name, Key=key) + + await LogService.info( + "adapter:s3", f"Deleted {rel}", + details={"adapter_id": self.record.id, + "bucket": self.bucket_name, "key": key} + ) + + async def move(self, root: str, src_rel: str, dst_rel: str): + await self.copy(root, src_rel, dst_rel, overwrite=True) + await self.delete(root, src_rel) + await LogService.info( + "adapter:s3", f"Moved {src_rel} to {dst_rel}", + details={"adapter_id": self.record.id, "bucket": self.bucket_name, + "src_key": self._get_s3_key(src_rel), "dst_key": self._get_s3_key(dst_rel)} + ) + + async def rename(self, root: str, src_rel: str, dst_rel: str): + await self.move(root, src_rel, dst_rel) + + async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False): + src_key = self._get_s3_key(src_rel) + dst_key = self._get_s3_key(dst_rel) + + async with self._get_client() as s3: + if not overwrite: + try: + await s3.head_object(Bucket=self.bucket_name, Key=dst_key) + raise FileExistsError(dst_rel) + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise + + copy_source = {"Bucket": self.bucket_name, "Key": src_key} + await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key) + await LogService.info( + "adapter:s3", f"Copied {src_rel} to {dst_rel}", + details={"adapter_id": self.record.id, "bucket": self.bucket_name, + "src_key": src_key, "dst_key": dst_key} + ) + + async def stat_file(self, root: str, rel: str): + key = self._get_s3_key(rel) + async with self._get_client() as s3: + try: + head = await s3.head_object(Bucket=self.bucket_name, Key=key) + return { + "name": rel.split("/")[-1], + "is_dir": False, + "size": head["ContentLength"], + "mtime": int(head["LastModified"].timestamp()), + "type": "file", + } + except ClientError as e: + if e.response["Error"]["Code"] == "404": + # 检查是否为一个 "目录" + dir_key = key if key.endswith('/') else key + '/' + resp = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=dir_key, MaxKeys=1) + if resp.get('KeyCount', 0) > 0: + return { + "name": rel.split("/")[-1], + "is_dir": True, + "size": 0, + "mtime": 0, + "type": "dir", + } + raise FileNotFoundError(rel) + raise + + async def stream_file(self, root: str, rel: str, range_header: str | None): + key = self._get_s3_key(rel) + async with self._get_client() as s3: + try: + head = await s3.head_object(Bucket=self.bucket_name, Key=key) + file_size = head["ContentLength"] + content_type = head.get("ContentType", mimetypes.guess_type(key)[ + 0] or "application/octet-stream") + except ClientError as e: + if e.response["Error"]["Code"] == "404": + raise HTTPException( + status_code=404, detail="File not found") + raise + + start = 0 + end = file_size - 1 + status = 200 + headers = { + "Accept-Ranges": "bytes", + "Content-Type": content_type, + "Content-Length": str(file_size), + "Content-Disposition": f"inline; filename=\"{quote(rel.split('/')[-1])}\"" + } + + if range_header: + range_val = range_header.strip().partition("=")[2] + s, _, e = range_val.partition("-") + try: + start = int(s) if s else 0 + end = int(e) if e else file_size - 1 + if start >= file_size or end >= file_size or start > end: + raise HTTPException( + status_code=416, detail="Requested Range Not Satisfiable") + status = 206 + headers["Content-Length"] = str(end - start + 1) + headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" + except ValueError: + raise HTTPException( + status_code=400, detail="Invalid Range header") + + range_arg = f"bytes={start}-{end}" + + async def iterator(): + try: + resp = await s3.get_object(Bucket=self.bucket_name, Key=key, Range=range_arg) + body = resp["Body"] + while chunk := await body.read(65536): + yield chunk + except Exception as e: + LogService.error( + "adapter:s3", f"Error streaming file {key}: {e}") + + return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type) + + +ADAPTER_TYPE = "S3" + +CONFIG_SCHEMA = [ + {"key": "bucket_name", "label": "Bucket 名称", + "type": "string", "required": True}, + {"key": "access_key_id", "label": "Access Key ID", + "type": "string", "required": True}, + {"key": "secret_access_key", "label": "Secret Access Key", + "type": "password", "required": True}, + {"key": "region_name", "label": "区域 (Region)", "type": "string", + "required": False, "placeholder": "例如 us-east-1"}, + {"key": "endpoint_url", "label": "Endpoint URL", "type": "string", + "required": False, "placeholder": "对于 S3 兼容存储, 例如 https://minio.example.com"}, + {"key": "root", "label": "根路径 (Root Path)", "type": "string", + "required": False, "placeholder": "在 bucket 内的路径前缀"}, +] + + +def ADAPTER_FACTORY(rec): return S3Adapter(rec)