From 219f3e81b803aefe56951c6116d3bbe06640d73d Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 20 Nov 2025 11:49:32 +0800 Subject: [PATCH] feat: add Google Drive storage adapter (#50) * Initial plan * Add Google Drive storage adapter implementation Co-authored-by: DrizzleTime <169802108+DrizzleTime@users.noreply.github.com> * Add optional methods for direct download and thumbnail support Co-authored-by: DrizzleTime <169802108+DrizzleTime@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: DrizzleTime <169802108+DrizzleTime@users.noreply.github.com> --- api/routes/adapters.py | 8 +- services/adapters/googledrive.py | 560 +++++++++++++++++++++++++++++++ 2 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 services/adapters/googledrive.py diff --git a/api/routes/adapters.py b/api/routes/adapters.py index 0ff7099..0f72acb 100644 --- a/api/routes/adapters.py +++ b/api/routes/adapters.py @@ -77,11 +77,17 @@ async def list_adapters( async def available_adapter_types( current_user: Annotated[User, Depends(get_current_active_user)] ): + name_map = { + "local": "本地文件系统", + "webdav": "WebDAV", + "GoogleDrive": "Google Drive", + "OneDrive": "OneDrive", + } data = [] for t, fields in get_config_schemas().items(): data.append({ "type": t, - "name": "本地文件系统" if t == "local" else ("WebDAV" if t == "webdav" else t), + "name": name_map.get(t, t), "config_schema": fields, }) return success(data) diff --git a/services/adapters/googledrive.py b/services/adapters/googledrive.py new file mode 100644 index 0000000..913c7be --- /dev/null +++ b/services/adapters/googledrive.py @@ -0,0 +1,560 @@ +from __future__ import annotations +from datetime import datetime, timezone, timedelta +from typing import List, Dict, Tuple, AsyncIterator +import httpx +from fastapi.responses import StreamingResponse, Response +from fastapi import HTTPException +from models import StorageAdapter + +GOOGLE_OAUTH_URL = "https://oauth2.googleapis.com/token" +GOOGLE_DRIVE_API_URL = "https://www.googleapis.com/drive/v3" + + +class GoogleDriveAdapter: + """Google Drive 存储适配器""" + + def __init__(self, record: StorageAdapter): + self.record = record + cfg = record.config + self.client_id = cfg.get("client_id") + self.client_secret = cfg.get("client_secret") + self.refresh_token = cfg.get("refresh_token") + self.root_folder_id = cfg.get("root_folder_id", "root") + self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307")) + + if not all([self.client_id, self.client_secret, self.refresh_token]): + raise ValueError( + "Google Drive 适配器需要 client_id, client_secret, 和 refresh_token") + + self._access_token: str | None = None + self._token_expiry: datetime | None = None + + def get_effective_root(self, sub_path: str | None) -> str: + """ + 获取有效根路径。 + :param sub_path: 子路径。 + :return: 完整的有效路径。 + """ + if sub_path: + return f"{sub_path.strip('/')}".strip() + return "" + + async def _get_access_token(self) -> str: + """ + 获取或刷新 access token。 + :return: access token。 + """ + if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry: + return self._access_token + + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + } + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post(GOOGLE_OAUTH_URL, data=data) + resp.raise_for_status() + token_data = resp.json() + self._access_token = token_data["access_token"] + self._token_expiry = datetime.now( + timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300) + return self._access_token + + async def _request(self, method: str, endpoint: str, **kwargs): + """ + 向 Google Drive API 发送请求。 + :param method: HTTP 方法。 + :param endpoint: API 端点。 + :param kwargs: 其他请求参数。 + :return: 响应对象。 + """ + token = await self._get_access_token() + headers = {"Authorization": f"Bearer {token}"} + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + + url = f"{GOOGLE_DRIVE_API_URL}{endpoint}" + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.request(method, url, headers=headers, **kwargs) + if resp.status_code == 401: + self._access_token = None + token = await self._get_access_token() + headers["Authorization"] = f"Bearer {token}" + resp = await client.request(method, url, headers=headers, **kwargs) + return resp + + async def _get_folder_id_by_path(self, path: str) -> str: + """ + 通过路径获取文件夹 ID。 + :param path: 路径。 + :return: 文件夹 ID。 + """ + if not path or path == "/": + return self.root_folder_id + + parts = [p for p in path.strip("/").split("/") if p] + current_id = self.root_folder_id + + for part in parts: + query = f"name='{part}' and '{current_id}' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false" + params = {"q": query, "fields": "files(id, name)"} + resp = await self._request("GET", "/files", params=params) + resp.raise_for_status() + data = resp.json() + files = data.get("files", []) + if not files: + raise FileNotFoundError(f"文件夹不存在: {part}") + current_id = files[0]["id"] + + return current_id + + async def _get_file_id_by_path(self, path: str) -> str | None: + """ + 通过路径获取文件 ID。 + :param path: 文件路径。 + :return: 文件 ID 或 None。 + """ + if not path or path == "/": + return self.root_folder_id + + parts = [p for p in path.strip("/").split("/") if p] + parent_id = self.root_folder_id + + for i, part in enumerate(parts): + is_last = i == len(parts) - 1 + mime_filter = "" if is_last else "and mimeType='application/vnd.google-apps.folder'" + query = f"name='{part}' and '{parent_id}' in parents {mime_filter} and trashed=false" + params = {"q": query, "fields": "files(id, name)"} + resp = await self._request("GET", "/files", params=params) + resp.raise_for_status() + data = resp.json() + files = data.get("files", []) + if not files: + return None + parent_id = files[0]["id"] + + return parent_id + + def _format_item(self, item: Dict) -> Dict: + """ + 将 Google Drive API 返回的 item 格式化为统一的格式。 + :param item: Google Drive API 返回的 item 字典。 + :return: 格式化后的字典。 + """ + is_dir = item["mimeType"] == "application/vnd.google-apps.folder" + mtime_str = item.get("modifiedTime", item.get("createdTime", "")) + try: + mtime = int(datetime.fromisoformat(mtime_str.replace("Z", "+00:00")).timestamp()) + except: + mtime = 0 + + return { + "name": item["name"], + "is_dir": is_dir, + "size": 0 if is_dir else int(item.get("size", 0)), + "mtime": mtime, + "type": "dir" if is_dir else "file", + } + + async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]: + """ + 列出目录内容。 + :param root: 根路径。 + :param rel: 相对路径。 + :param page_num: 页码。 + :param page_size: 每页大小。 + :param sort_by: 排序字段 + :param sort_order: 排序顺序 + :return: 文件/目录列表和总数。 + """ + try: + folder_id = await self._get_folder_id_by_path(rel) + except FileNotFoundError: + return [], 0 + + query = f"'{folder_id}' in parents and trashed=false" + params = { + "q": query, + "fields": "files(id, name, mimeType, size, modifiedTime, createdTime)", + "pageSize": 1000, + } + + all_items = [] + page_token = None + + while True: + if page_token: + params["pageToken"] = page_token + + resp = await self._request("GET", "/files", params=params) + if resp.status_code == 404: + return [], 0 + resp.raise_for_status() + + data = resp.json() + all_items.extend(data.get("files", [])) + page_token = data.get("nextPageToken") + + if not page_token: + break + + formatted_items = [self._format_item(item) for item in all_items] + + # 排序 + reverse = sort_order.lower() == "desc" + def get_sort_key(item): + key = (not item["is_dir"],) + sort_field = sort_by.lower() + if sort_field == "name": + key += (item["name"].lower(),) + elif sort_field == "size": + key += (item["size"],) + elif sort_field == "mtime": + key += (item["mtime"],) + else: + key += (item["name"].lower(),) + return key + formatted_items.sort(key=get_sort_key, reverse=reverse) + + total_count = len(formatted_items) + start_idx = (page_num - 1) * page_size + end_idx = start_idx + page_size + + return formatted_items[start_idx:end_idx], total_count + + async def read_file(self, root: str, rel: str) -> bytes: + """ + 读取文件内容。 + :param root: 根路径。 + :param rel: 相对路径。 + :return: 文件内容的字节流。 + """ + file_id = await self._get_file_id_by_path(rel) + if not file_id: + raise FileNotFoundError(rel) + + resp = await self._request("GET", f"/files/{file_id}", params={"alt": "media"}) + if resp.status_code == 404: + raise FileNotFoundError(rel) + resp.raise_for_status() + return resp.content + + async def write_file(self, root: str, rel: str, data: bytes): + """ + 写入文件。 + :param root: 根路径。 + :param rel: 相对路径。 + :param data: 文件内容的字节流。 + """ + parent_path = "/".join(rel.strip("/").split("/")[:-1]) + file_name = rel.strip("/").split("/")[-1] + parent_id = await self._get_folder_id_by_path(parent_path) + + # 检查文件是否已存在 + existing_id = await self._get_file_id_by_path(rel) + + if existing_id: + # 更新现有文件 + async with httpx.AsyncClient(timeout=60.0) as client: + token = await self._get_access_token() + headers = {"Authorization": f"Bearer {token}"} + url = f"https://www.googleapis.com/upload/drive/v3/files/{existing_id}?uploadType=media" + resp = await client.patch(url, headers=headers, content=data) + resp.raise_for_status() + else: + # 创建新文件 + metadata = { + "name": file_name, + "parents": [parent_id] + } + + async with httpx.AsyncClient(timeout=60.0) as client: + token = await self._get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + # 使用 multipart 上传 + import json + boundary = "===============boundary===============" + headers["Content-Type"] = f"multipart/related; boundary={boundary}" + + body = ( + f"--{boundary}\r\n" + f"Content-Type: application/json; charset=UTF-8\r\n\r\n" + f"{json.dumps(metadata)}\r\n" + f"--{boundary}\r\n" + f"Content-Type: application/octet-stream\r\n\r\n" + ).encode() + data + f"\r\n--{boundary}--".encode() + + url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=multipart" + resp = await client.post(url, headers=headers, content=body) + resp.raise_for_status() + + async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]): + """ + 以流式方式写入文件。 + :param root: 根路径。 + :param rel: 相对路径。 + :param data_iter: 文件内容的异步迭代器。 + :return: 文件大小。 + """ + # 先收集所有数据 + chunks = [] + total_size = 0 + async for chunk in data_iter: + chunks.append(chunk) + total_size += len(chunk) + + data = b"".join(chunks) + await self.write_file(root, rel, data) + return total_size + + async def mkdir(self, root: str, rel: str): + """ + 创建目录。 + :param root: 根路径。 + :param rel: 相对路径。 + """ + parent_path = "/".join(rel.strip("/").split("/")[:-1]) + folder_name = rel.strip("/").split("/")[-1] + parent_id = await self._get_folder_id_by_path(parent_path) + + metadata = { + "name": folder_name, + "mimeType": "application/vnd.google-apps.folder", + "parents": [parent_id] + } + + resp = await self._request("POST", "/files", json=metadata) + resp.raise_for_status() + + async def delete(self, root: str, rel: str): + """ + 删除文件或目录。 + :param root: 根路径。 + :param rel: 相对路径。 + """ + file_id = await self._get_file_id_by_path(rel) + if not file_id: + return + + resp = await self._request("DELETE", f"/files/{file_id}") + if resp.status_code not in (204, 404): + resp.raise_for_status() + + async def move(self, root: str, src_rel: str, dst_rel: str): + """ + 移动或重命名文件/目录。 + :param root: 根路径。 + :param src_rel: 源相对路径。 + :param dst_rel: 目标相对路径。 + """ + file_id = await self._get_file_id_by_path(src_rel) + if not file_id: + raise FileNotFoundError(src_rel) + + # 获取当前父文件夹 + resp = await self._request("GET", f"/files/{file_id}", params={"fields": "parents"}) + resp.raise_for_status() + current_parents = resp.json().get("parents", []) + + # 获取目标父文件夹和新名称 + dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1]) + dst_name = dst_rel.strip("/").split("/")[-1] + dst_parent_id = await self._get_folder_id_by_path(dst_parent_path) + + # 更新文件 + params = { + "addParents": dst_parent_id, + "removeParents": ",".join(current_parents) if current_parents else None, + } + metadata = {"name": dst_name} + + resp = await self._request("PATCH", f"/files/{file_id}", params=params, json=metadata) + resp.raise_for_status() + + async def rename(self, root: str, src_rel: str, dst_rel: str): + """ + 重命名文件或目录。 + """ + await self.move(root, src_rel, dst_rel) + + async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False): + """ + 复制文件或目录。 + :param root: 根路径。 + :param src_rel: 源相对路径。 + :param dst_rel: 目标相对路径。 + :param overwrite: 是否覆盖。 + """ + file_id = await self._get_file_id_by_path(src_rel) + if not file_id: + raise FileNotFoundError(src_rel) + + dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1]) + dst_name = dst_rel.strip("/").split("/")[-1] + dst_parent_id = await self._get_folder_id_by_path(dst_parent_path) + + metadata = { + "name": dst_name, + "parents": [dst_parent_id] + } + + resp = await self._request("POST", f"/files/{file_id}/copy", json=metadata) + resp.raise_for_status() + + async def stream_file(self, root: str, rel: str, range_header: str | None): + """ + 流式传输文件(支持范围请求)。 + :param root: 根路径。 + :param rel: 相对路径。 + :param range_header: HTTP Range 头。 + :return: FastAPI StreamingResponse 对象。 + """ + file_id = await self._get_file_id_by_path(rel) + if not file_id: + raise FileNotFoundError(rel) + + # 获取文件元数据 + resp = await self._request("GET", f"/files/{file_id}", params={"fields": "name, size, mimeType"}) + if resp.status_code == 404: + raise FileNotFoundError(rel) + resp.raise_for_status() + item_data = resp.json() + + file_size = int(item_data.get("size", 0)) + content_type = item_data.get("mimeType", "application/octet-stream") + + start = 0 + end = file_size - 1 + status = 200 + headers = { + "Accept-Ranges": "bytes", + "Content-Type": content_type, + "Content-Disposition": f"inline; filename=\"{item_data.get('name')}\"" + } + + if range_header and range_header.startswith("bytes="): + try: + part = range_header.removeprefix("bytes=") + s, e = part.split("-", 1) + if s.strip(): + start = int(s) + if e.strip(): + end = int(e) + if start >= file_size: + raise HTTPException(416, "Requested Range Not Satisfiable") + if end >= file_size: + end = file_size - 1 + status = 206 + except ValueError: + raise HTTPException(400, "Invalid Range header") + + headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" + headers["Content-Length"] = str(end - start + 1) + else: + headers["Content-Length"] = str(file_size) + + async def file_iterator(): + nonlocal start, end + token = await self._get_access_token() + async with httpx.AsyncClient(timeout=60.0) as client: + req_headers = { + 'Authorization': f'Bearer {token}', + 'Range': f'bytes={start}-{end}' + } + url = f"{GOOGLE_DRIVE_API_URL}/files/{file_id}?alt=media" + async with client.stream("GET", url, headers=req_headers) as stream_resp: + stream_resp.raise_for_status() + async for chunk in stream_resp.aiter_bytes(): + yield chunk + + return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type) + + async def stat_file(self, root: str, rel: str): + """ + 获取文件或目录的元数据。 + :param root: 根路径。 + :param rel: 相对路径。 + :return: 格式化后的文件/目录信息。 + """ + file_id = await self._get_file_id_by_path(rel) + if not file_id: + raise FileNotFoundError(rel) + + resp = await self._request("GET", f"/files/{file_id}", params={"fields": "id, name, mimeType, size, modifiedTime, createdTime"}) + if resp.status_code == 404: + raise FileNotFoundError(rel) + resp.raise_for_status() + return self._format_item(resp.json()) + + async def get_direct_download_response(self, root: str, rel: str): + """ + 获取直接下载响应 (307 重定向)。 + :param root: 根路径。 + :param rel: 相对路径。 + :return: 307 重定向响应或 None。 + """ + if not self.enable_redirect_307: + return None + + file_id = await self._get_file_id_by_path(rel) + if not file_id: + raise FileNotFoundError(rel) + + # 获取文件的下载链接 + resp = await self._request("GET", f"/files/{file_id}", params={"fields": "webContentLink"}) + if resp.status_code == 404: + raise FileNotFoundError(rel) + resp.raise_for_status() + + item_data = resp.json() + download_url = item_data.get("webContentLink") + if not download_url: + return None + + return Response(status_code=307, headers={"Location": download_url}) + + async def get_thumbnail(self, root: str, rel: str, size: str = "medium"): + """ + 获取文件的缩略图。 + :param root: 根路径。 + :param rel: 相对路径。 + :param size: 缩略图大小 (暂未使用,Google Drive 自动决定)。 + :return: 缩略图内容的字节流,或在不支持时返回 None。 + """ + file_id = await self._get_file_id_by_path(rel) + if not file_id: + return None + + try: + resp = await self._request("GET", f"/files/{file_id}", params={"fields": "thumbnailLink"}) + if resp.status_code == 200: + item_data = resp.json() + thumbnail_link = item_data.get("thumbnailLink") + if thumbnail_link: + async with httpx.AsyncClient(timeout=30.0) as client: + thumb_resp = await client.get(thumbnail_link) + thumb_resp.raise_for_status() + return thumb_resp.content + return None + except Exception: + return None + + +ADAPTER_TYPE = "GoogleDrive" + +CONFIG_SCHEMA = [ + {"key": "client_id", "label": "Client ID", "type": "string", "required": True}, + {"key": "client_secret", "label": "Client Secret", + "type": "password", "required": True}, + {"key": "refresh_token", "label": "Refresh Token", "type": "password", + "required": True, "help_text": "可以通过 Google OAuth 2.0 Playground 获取"}, + {"key": "root_folder_id", "label": "根文件夹 ID (Root Folder ID)", "type": "string", + "required": False, "placeholder": "默认为根目录 (root)", "default": "root"}, + {"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False}, +] + + +def ADAPTER_FACTORY(rec): return GoogleDriveAdapter(rec)