feat: Add OneDrive storage adapter with support for file operations and thumbnail retrieval

This commit is contained in:
shiyu
2025-08-29 12:08:05 +08:00
parent 86e81bf40c
commit 2f92fa353c
2 changed files with 477 additions and 14 deletions

View File

@@ -0,0 +1,431 @@
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Tuple, AsyncIterator
import httpx
from fastapi.responses import StreamingResponse
from fastapi import HTTPException
from models import StorageAdapter
MS_GRAPH_URL = "https://graph.microsoft.com/v1.0"
MS_OAUTH_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
class OneDriveAdapter:
"""OneDrive 存储适配器"""
def __init__(self, record: StorageAdapter):
self.record = record
cfg = record.config
self.client_id = cfg.get("client_id")
self.client_secret = cfg.get("client_secret")
self.refresh_token = cfg.get("refresh_token")
self.root = cfg.get("root", "/").strip("/")
if not all([self.client_id, self.client_secret, self.refresh_token]):
raise ValueError(
"OneDrive 适配器需要 client_id, client_secret, 和 refresh_token")
self._access_token: str | None = None
self._token_expiry: datetime | None = None
def get_effective_root(self, sub_path: str | None) -> str:
"""
获取有效根路径。
:param sub_path: 子路径。
:return: 完整的有效路径。
"""
if sub_path:
return f"/{self.root.strip('/')}/{sub_path.strip('/')}".strip()
return f"/{self.root.strip('/')}".strip()
def _get_api_path(self, rel_path: str) -> str:
"""
将用户可见的相对路径转换为 Graph API 路径段。
:param rel_path: 相对路径。
:return: Graph API 路径段。
"""
full_path = self.get_effective_root(rel_path).strip('/')
if not full_path:
return ""
return f":/{full_path}"
async def _get_access_token(self) -> str:
"""
获取或刷新 access token。
:return: access token。
"""
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
return self._access_token
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"refresh_token": self.refresh_token,
"grant_type": "refresh_token",
}
async with httpx.AsyncClient() as client:
resp = await client.post(MS_OAUTH_URL, data=data)
resp.raise_for_status()
token_data = resp.json()
self._access_token = token_data["access_token"]
self._token_expiry = datetime.now(
timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300)
return self._access_token
async def _request(self, method: str, api_path_segment: str | None = None, *, full_url: str | None = None, **kwargs):
"""
向 Microsoft Graph API 发送请求。
:param method: HTTP 方法。
:param api_path_segment: API 路径段 (与 full_url 互斥)。
:param full_url: 完整的请求 URL (与 api_path_segment 互斥)。
:param kwargs: 其他请求参数。
:return: 响应对象。
"""
if not ((api_path_segment is not None) ^ (full_url is not None)):
raise ValueError("必须提供 api_path_segment 或 full_url 中的一个,且仅一个")
token = await self._get_access_token()
headers = {"Authorization": f"Bearer {token}"}
if "headers" in kwargs:
headers.update(kwargs.pop("headers"))
url = full_url if full_url else f"{MS_GRAPH_URL}/me/drive/root{api_path_segment}"
async with httpx.AsyncClient() as client:
resp = await client.request(method, url, headers=headers, **kwargs)
# 如果 token 过期 (401),刷新并重试一次
if resp.status_code == 401:
self._access_token = None # 强制刷新
token = await self._get_access_token()
headers["Authorization"] = f"Bearer {token}"
resp = await client.request(method, url, headers=headers, **kwargs)
return resp
def _format_item(self, item: Dict) -> Dict:
"""
将 Graph API 返回的 item 格式化为统一的格式。
:param item: Graph API 返回的 item 字典。
:return: 格式化后的字典。
"""
is_dir = "folder" in item
return {
"name": item["name"],
"is_dir": is_dir,
"size": 0 if is_dir else item.get("size", 0),
"mtime": int(datetime.fromisoformat(item["lastModifiedDateTime"].replace("Z", "+00:00")).timestamp()),
"type": "dir" if is_dir else "file",
}
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
"""
列出目录内容。
:param root: 根路径 (在此适配器中未使用,通过配置的 root 确定)。
:param rel: 相对路径。
:param page_num: 页码。
:param page_size: 每页大小。
:return: 文件/目录列表和总数。
"""
api_path = self._get_api_path(rel)
children_path = f"{api_path}:/children" if api_path else "/children"
# Graph API 的分页是基于 @odata.nextLink token 的。
# 为了支持自定义排序(文件夹在前),我们必须获取所有项目,
# 然后在内存中进行排序和分页。此版本通过处理分页链接来稳健地获取所有项目。
all_items = []
# 初始请求
resp = await self._request("GET", api_path_segment=children_path, params={"$top": 200})
while True:
if resp.status_code == 404 and not all_items:
return [], 0
resp.raise_for_status()
try:
data = resp.json()
except Exception as e:
raise IOError(f"解析 Graph API 响应失败: {e}") from e
all_items.extend(data.get("value", []))
next_link = data.get("@odata.nextLink")
if not next_link:
break
# 后续分页请求
resp = await self._request("GET", full_url=next_link)
formatted_items = [self._format_item(item) for item in all_items]
# 排序:文件夹在前,然后按名称排序
formatted_items.sort(key=lambda x: (
not x["is_dir"], x["name"].lower()))
total_count = len(formatted_items)
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
return formatted_items[start_idx:end_idx], total_count
async def read_file(self, root: str, rel: str) -> bytes:
"""
读取文件内容。
:param root: 根路径。
:param rel: 相对路径。
:return: 文件内容的字节流。
"""
api_path = self._get_api_path(rel)
if not api_path:
raise IsADirectoryError("不能将根目录作为文件读取")
resp = await self._request("GET", api_path_segment=f"{api_path}:/content")
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
return resp.content
async def write_file(self, root: str, rel: str, data: bytes):
"""
写入文件。
:param root: 根路径。
:param rel: 相对路径。
:param data: 文件内容的字节流。
"""
api_path = self._get_api_path(rel)
if not api_path:
raise ValueError("不能直接写入根路径")
resp = await self._request("PUT", api_path_segment=f"{api_path}:/content", content=data)
resp.raise_for_status()
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
"""
以流式方式写入文件。
:param root: 根路径。
:param rel: 相对路径。
:param data_iter: 文件内容的异步迭代器。
:return: 文件大小。
"""
api_path = self._get_api_path(rel)
if not api_path:
raise ValueError("不能直接写入根路径")
resp = await self._request("PUT", api_path_segment=f"{api_path}:/content", content=data_iter)
resp.raise_for_status()
return resp.json().get("size", 0)
async def mkdir(self, root: str, rel: str):
"""
创建目录。
:param root: 根路径。
:param rel: 相对路径。
"""
parent_path_str, new_dir_name = rel.rstrip(
'/').rsplit('/', 1) if '/' in rel.rstrip('/') else ('', rel)
parent_api_path = self._get_api_path(parent_path_str)
children_path = f"{parent_api_path}:/children" if parent_api_path else "/children"
payload = {
"name": new_dir_name,
"folder": {},
"@microsoft.graph.conflictBehavior": "fail" # 如果已存在则失败
}
resp = await self._request("POST", api_path_segment=children_path, json=payload)
resp.raise_for_status()
async def delete(self, root: str, rel: str):
"""
删除文件或目录。
:param root: 根路径。
:param rel: 相对路径。
"""
api_path = self._get_api_path(rel)
if not api_path:
raise ValueError("不能删除根目录")
resp = await self._request("DELETE", api_path_segment=api_path)
if resp.status_code not in (204, 404):
resp.raise_for_status()
async def move(self, root: str, src_rel: str, dst_rel: str):
"""
移动或重命名文件/目录。
:param root: 根路径。
:param src_rel: 源相对路径。
:param dst_rel: 目标相对路径。
"""
src_api_path = self._get_api_path(src_rel)
if not src_api_path:
raise ValueError("不能移动根目录")
dst_parent_rel, dst_name = dst_rel.rstrip(
'/').rsplit('/', 1) if '/' in dst_rel.rstrip('/') else ('', dst_rel)
dst_parent_api_path = self._get_api_path(dst_parent_rel)
# 获取父项目的 ID
parent_resp = await self._request("GET", api_path_segment=dst_parent_api_path)
parent_resp.raise_for_status()
parent_id = parent_resp.json()["id"]
payload = {
"parentReference": {"id": parent_id},
"name": dst_name
}
resp = await self._request("PATCH", api_path_segment=src_api_path, json=payload)
resp.raise_for_status()
async def rename(self, root: str, src_rel: str, dst_rel: str):
"""
重命名文件或目录。
在 Graph API 中,移动和重命名是同一个 PATCH 操作。
"""
await self.move(root, src_rel, dst_rel)
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
"""
复制文件或目录。
:param root: 根路径。
:param src_rel: 源相对路径。
:param dst_rel: 目标相对路径。
:param overwrite: 是否覆盖 (在此 API 中未直接使用)。
"""
src_api_path = self._get_api_path(src_rel)
if not src_api_path:
raise ValueError("不能复制根目录")
dst_parent_rel, dst_name = dst_rel.rstrip(
'/').rsplit('/', 1) if '/' in dst_rel.rstrip('/') else ('', dst_rel)
dst_parent_api_path = self._get_api_path(dst_parent_rel)
parent_resp = await self._request("GET", api_path_segment=dst_parent_api_path)
parent_resp.raise_for_status()
parent_id = parent_resp.json()["id"]
payload = {"parentReference": {"id": parent_id}, "name": dst_name}
copy_path = f"{src_api_path}:/copy"
resp = await self._request("POST", api_path_segment=copy_path, json=payload)
resp.raise_for_status()
async def stream_file(self, root: str, rel: str, range_header: str | None):
"""
流式传输文件(支持范围请求)。
:param root: 根路径。
:param rel: 相对路径。
:param range_header: HTTP Range 头。
:return: FastAPI StreamingResponse 对象。
"""
api_path = self._get_api_path(rel)
if not api_path:
raise IsADirectoryError("不能对目录进行流式传输")
resp = await self._request("GET", api_path_segment=api_path)
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
item_data = resp.json()
download_url = item_data.get("@microsoft.graph.downloadUrl")
if not download_url:
raise Exception("无法获取下载 URL")
file_size = item_data.get("size", 0)
content_type = item_data.get("file", {}).get(
"mimeType", "application/octet-stream")
start = 0
end = file_size - 1
status = 200
headers = {
"Accept-Ranges": "bytes",
"Content-Type": content_type,
"Content-Disposition": f"inline; filename=\"{item_data.get('name')}\""
}
if range_header and range_header.startswith("bytes="):
try:
part = range_header.removeprefix("bytes=")
s, e = part.split("-", 1)
if s.strip():
start = int(s)
if e.strip():
end = int(e)
if start >= file_size:
raise HTTPException(416, "Requested Range Not Satisfiable")
if end >= file_size:
end = file_size - 1
status = 206
except ValueError:
raise HTTPException(400, "Invalid Range header")
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
headers["Content-Length"] = str(end - start + 1)
else:
headers["Content-Length"] = str(file_size)
async def file_iterator():
nonlocal start, end
async with httpx.AsyncClient() as client:
req_headers = {'Range': f'bytes={start}-{end}'}
async with client.stream("GET", download_url, headers=req_headers) as stream_resp:
stream_resp.raise_for_status()
async for chunk in stream_resp.aiter_bytes():
yield chunk
return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
"""
获取文件的缩略图。
:param root: 根路径。
:param rel: 相对路径。
:param size: 缩略图大小 (large, medium, small)。
:return: 缩略图内容的字节流,或在不支持时返回 None。
"""
api_path = self._get_api_path(rel)
if not api_path:
return None
thumb_path = f"{api_path}:/thumbnails/0/{size}"
try:
resp = await self._request("GET", api_path_segment=thumb_path)
if resp.status_code == 200:
thumb_data = resp.json()
async with httpx.AsyncClient() as client:
thumb_resp = await client.get(thumb_data['url'])
thumb_resp.raise_for_status()
return thumb_resp.content
elif resp.status_code == 404:
return None
else:
resp.raise_for_status()
except Exception:
return None
async def stat_file(self, root: str, rel: str):
"""
获取文件或目录的元数据。
:param root: 根路径。
:param rel: 相对路径。
:return: 格式化后的文件/目录信息。
"""
api_path = self._get_api_path(rel)
resp = await self._request("GET", api_path_segment=api_path)
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
return self._format_item(resp.json())
ADAPTER_TYPE = "OneDrive"
CONFIG_SCHEMA = [
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
{"key": "client_secret", "label": "Client Secret",
"type": "password", "required": True},
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
"required": True, "help_text": "可以通过运行 'python -m services.adapters.onedrive' 获取"},
{"key": "root", "label": "根目录 (Root Path)", "type": "string",
"required": False, "placeholder": "默认为根目录 /"},
]
def ADAPTER_FACTORY(rec): return OneDriveAdapter(rec)

View File

@@ -5,7 +5,8 @@ from pathlib import Path
from typing import Tuple
from fastapi import HTTPException
ALLOWED_EXT = {"jpg", "jpeg", "png", "webp", "gif", "bmp", "tiff", "arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
ALLOWED_EXT = {"jpg", "jpeg", "png", "webp", "gif", "bmp",
"tiff", "arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
RAW_EXT = {"arw", "cr2", "cr3", "nef", "rw2", "orf", "pef", "dng"}
MAX_SOURCE_SIZE = 200 * 1024 * 1024
CACHE_ROOT = Path('data/.thumb_cache')
@@ -49,11 +50,12 @@ def generate_thumb(data: bytes, w: int, h: int, fit: str, is_raw: bool = False)
thumb = raw.extract_thumb()
except rawpy.LibRawNoThumbnailError:
thumb = None
if thumb is not None and thumb.format in [rawpy.ThumbFormat.JPEG, rawpy.ThumbFormat.BITMAP]:
im = Image.open(io.BytesIO(thumb.data))
else:
rgb = raw.postprocess(use_camera_wb=False, use_auto_wb=True, output_bps=8)
rgb = raw.postprocess(
use_camera_wb=False, use_auto_wb=True, output_bps=8)
im = Image.fromarray(rgb)
except Exception as e:
print(f"rawpy processing failed: {e}")
@@ -87,18 +89,48 @@ def generate_thumb(data: bytes, w: int, h: int, fit: str, is_raw: bool = False)
async def get_or_create_thumb(adapter, adapter_id: int, root: str, rel: str, w: int, h: int, fit: str = 'cover'):
stat = await adapter.stat_file(root, rel)
if stat['size'] > MAX_SOURCE_SIZE:
raise HTTPException(400, detail="Image too large for thumbnail")
key = _cache_key(adapter_id, rel, stat['size'], int(stat['mtime']), w, h, fit)
raise HTTPException(400, detail="Image too large for thumbnail")
key = _cache_key(adapter_id, rel, stat['size'], int(
stat['mtime']), w, h, fit)
path = _cache_path(key)
if path.exists():
return path.read_bytes(), 'image/webp', key
_ensure_cache_dir(path)
read_data = await adapter.read_file(root, rel)
try:
thumb_bytes, mime = generate_thumb(read_data, w, h, fit, is_raw=is_raw_filename(rel))
except Exception as e:
print(e)
raise HTTPException(500, detail=f"Thumbnail generation failed: {e}")
path.write_bytes(thumb_bytes)
return thumb_bytes, mime, key
thumb_bytes, mime = None, None
get_thumb_impl = getattr(adapter, "get_thumbnail", None)
if callable(get_thumb_impl):
size_str = "large" if w > 400 else "medium" if w > 100 else "small"
native_thumb_bytes = await get_thumb_impl(root, rel, size_str)
if native_thumb_bytes:
try:
from PIL import Image
im = Image.open(io.BytesIO(native_thumb_bytes))
buf = io.BytesIO()
im.save(buf, 'WEBP', quality=85)
thumb_bytes = buf.getvalue()
mime = 'image/webp'
except Exception as e:
print(
f"Failed to convert native thumbnail to WebP: {e}, falling back.")
thumb_bytes, mime = None, None
if not thumb_bytes:
read_data = await adapter.read_file(root, rel)
try:
thumb_bytes, mime = generate_thumb(
read_data, w, h, fit, is_raw=is_raw_filename(rel))
except Exception as e:
print(e)
raise HTTPException(
500, detail=f"Thumbnail generation failed: {e}")
if thumb_bytes:
path.write_bytes(thumb_bytes)
return thumb_bytes, mime, key
raise HTTPException(
500, detail="Failed to generate thumbnail by any means")