From 097dff13a3b301ebf64269cc4c7c07b35774da34 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Wed, 22 Apr 2026 16:11:38 +0800 Subject: [PATCH] feat: add ai-compatible API endpoints --- app/api/apiv1.py | 4 +- app/api/endpoints/anthropic.py | 158 ++++++++++++ app/api/endpoints/openai.py | 426 +++++++++++++++++++++++++++++++++ app/api/openai_utils.py | 177 ++++++++++++++ app/core/security.py | 8 +- app/schemas/__init__.py | 2 +- app/schemas/openai.py | 156 ++++++++++++ tests/test_openai_utils.py | 120 ++++++++++ 8 files changed, 1048 insertions(+), 3 deletions(-) create mode 100644 app/api/endpoints/anthropic.py create mode 100644 app/api/endpoints/openai.py create mode 100644 app/api/openai_utils.py create mode 100644 app/schemas/openai.py create mode 100644 tests/test_openai_utils.py diff --git a/app/api/apiv1.py b/app/api/apiv1.py index b519acb9..3744d2f1 100644 --- a/app/api/apiv1.py +++ b/app/api/apiv1.py @@ -2,7 +2,7 @@ from fastapi import APIRouter from app.api.endpoints import login, user, webhook, message, site, subscribe, \ media, douban, search, plugin, tmdb, history, system, download, dashboard, \ - transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa + transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic api_router = APIRouter() api_router.include_router(login.router, prefix="/login", tags=["login"]) @@ -30,3 +30,5 @@ api_router.include_router(recommend.router, prefix="/recommend", tags=["recommen api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"]) api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"]) api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"]) +api_router.include_router(openai.router, prefix="/openai/v1", tags=["openai"]) +api_router.include_router(anthropic.router, prefix="/anthropic/v1", tags=["anthropic"]) diff --git a/app/api/endpoints/anthropic.py b/app/api/endpoints/anthropic.py new file mode 100644 index 00000000..dde2d192 --- /dev/null +++ b/app/api/endpoints/anthropic.py @@ -0,0 +1,158 @@ +import asyncio +import json +import time +import uuid +from typing import AsyncIterator, List, Optional + +from fastapi import APIRouter, Header, Security +from fastapi.responses import JSONResponse, StreamingResponse + +from app import schemas +from app.api.endpoints.openai import ( + MODEL_ID, + _CollectingMoviePilotAgent, + _error_response as _openai_error_response, +) +from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id +from app.core.config import settings +from app.core.security import anthropic_api_key_header +from app.schemas.types import MessageChannel + +router = APIRouter() + +SESSION_PREFIX = "anthropic:" + + +def _anthropic_error_response( + message: str, + status_code: int, + error_type: str = "invalid_request_error", +) -> JSONResponse: + return JSONResponse( + status_code=status_code, + content=schemas.AnthropicErrorResponse( + error=schemas.AnthropicErrorDetail(type=error_type, message=message) + ).model_dump(), + ) + + +def _check_auth(api_key: Optional[str]) -> Optional[JSONResponse]: + if not api_key or api_key != settings.API_TOKEN: + return _anthropic_error_response( + "invalid x-api-key", + 401, + error_type="authentication_error", + ) + return None + + +async def _stream_anthropic_response( + agent: _CollectingMoviePilotAgent, + prompt: str, + images: List[str], +) -> AsyncIterator[str]: + event_queue: asyncio.Queue = asyncio.Queue() + if hasattr(agent.stream_handler, "bind_queue"): + agent.stream_handler.bind_queue(event_queue) + + message_id = f"msg_{uuid.uuid4().hex}" + + async def _run_agent(): + try: + await agent.process(prompt, images=images, files=None) + except Exception as exc: + await event_queue.put({"error": str(exc)}) + finally: + await event_queue.put(None) + + task = asyncio.create_task(_run_agent()) + try: + yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': MODEL_ID, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}}, ensure_ascii=False)}\n\n" + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}, ensure_ascii=False)}\n\n" + while True: + item = await event_queue.get() + if item is None: + break + if isinstance(item, dict) and item.get("error"): + raise RuntimeError(str(item["error"])) + text = str(item or "") + if not text: + continue + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}}, ensure_ascii=False)}\n\n" + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0}, ensure_ascii=False)}\n\n" + yield f"event: message_delta\ndata: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 0}}, ensure_ascii=False)}\n\n" + yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'}, ensure_ascii=False)}\n\n" + finally: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse) +async def messages( + payload: schemas.AnthropicMessagesRequest, + x_api_key: Optional[str] = Security(anthropic_api_key_header), + anthropic_version: Optional[str] = Header(default=None, alias="anthropic-version"), +): + auth_error = _check_auth(x_api_key) + if auth_error: + return auth_error + + if not settings.AI_AGENT_ENABLE: + return _anthropic_error_response( + "MoviePilot AI agent is disabled.", + 503, + error_type="api_error", + ) + + normalized_messages = build_anthropic_messages(payload.system, payload.messages) + try: + prompt, images = build_prompt(normalized_messages, use_server_session=False) + except ValueError as exc: + return _anthropic_error_response(str(exc), 400) + + session_seed = anthropic_version or "anthropic" + session_id = build_session_id(f"{session_seed}:{uuid.uuid4().hex}", SESSION_PREFIX) + agent = _CollectingMoviePilotAgent( + session_id=session_id, + user_id=session_id, + channel=MessageChannel.Web.value, + source="anthropic", + username="anthropic-client", + stream_mode=payload.stream, + ) + + if payload.stream: + return StreamingResponse( + _stream_anthropic_response(agent=agent, prompt=prompt, images=images), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + try: + result = await agent.process(prompt, images=images, files=None) + except Exception as exc: + return _anthropic_error_response(str(exc), 500, error_type="api_error") + + content = "\n\n".join( + message.strip() + for message in agent.collected_messages + if message and message.strip() + ).strip() + if not content and result: + content = str(result).strip() + if not content: + content = "未获得有效回复。" + + return schemas.AnthropicMessagesResponse( + id=f"msg_{uuid.uuid4().hex}", + content=[schemas.AnthropicTextBlock(text=content)], + model=MODEL_ID, + ) diff --git a/app/api/endpoints/openai.py b/app/api/endpoints/openai.py new file mode 100644 index 00000000..01a30185 --- /dev/null +++ b/app/api/endpoints/openai.py @@ -0,0 +1,426 @@ +import asyncio +import json +import time +import uuid +from typing import AsyncIterator, List, Optional, Tuple + +from fastapi import APIRouter, Request, Security +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import HTTPAuthorizationCredentials + +from app import schemas +from app.api.openai_utils import ( + build_completion_payload, + build_prompt, + build_responses_input, + build_session_id, +) +from app.agent import MoviePilotAgent, StreamingHandler +from app.core.config import settings +from app.core.security import openai_bearer_scheme +from app.schemas.types import MessageChannel + +router = APIRouter() + +MODEL_ID = "moviepilot-agent" +SESSION_PREFIX = "openai:" + + +class _CollectingMoviePilotAgent(MoviePilotAgent): + """ + 捕获 Agent 最终输出,避免再通过消息渠道二次发送。 + """ + + def __init__(self, *args, stream_mode: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.collected_messages: List[str] = [] + self.stream_mode = stream_mode + if stream_mode: + self.stream_handler = _OpenAIStreamingHandler() + + def _should_stream(self) -> bool: + return self.stream_mode + + async def send_agent_message(self, message: str, title: str = ""): + text = (message or "").strip() + if title and text: + text = f"{title}\n{text}" + elif title: + text = title.strip() + if text: + self.collected_messages.append(text) + if self.stream_mode: + self.stream_handler.emit(text) + + async def _save_agent_message_to_db(self, message: str, title: str = ""): + return None + + +class _OpenAIStreamingHandler(StreamingHandler): + """ + 将 Agent 流式输出转发到 OpenAI SSE 队列,不向站内消息系统落消息。 + """ + + def __init__(self): + super().__init__() + self._event_queue: Optional[asyncio.Queue] = None + + def bind_queue(self, queue: asyncio.Queue): + self._event_queue = queue + + def emit(self, token: str): + super().emit(token) + if token and self._event_queue is not None: + self._event_queue.put_nowait(token) + + async def start_streaming( + self, + channel: Optional[str] = None, + source: Optional[str] = None, + user_id: Optional[str] = None, + username: Optional[str] = None, + title: str = "", + ): + self._channel = channel + self._source = source + self._user_id = user_id + self._username = username + self._title = title + self._streaming_enabled = True + self._sent_text = "" + self._message_response = None + self._msg_start_offset = 0 + self._max_message_length = 0 + + async def stop_streaming(self) -> Tuple[bool, str]: + if not self._streaming_enabled: + return False, "" + self._streaming_enabled = False + with self._lock: + final_text = self._buffer + self._buffer = "" + self._sent_text = "" + self._message_response = None + self._msg_start_offset = 0 + return True, final_text + + +def _sse_payload(data: dict) -> str: + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + + +async def _stream_response( + agent: _CollectingMoviePilotAgent, + prompt: str, + images: List[str], +) -> AsyncIterator[str]: + event_queue: asyncio.Queue = asyncio.Queue() + if isinstance(agent.stream_handler, _OpenAIStreamingHandler): + agent.stream_handler.bind_queue(event_queue) + + created = int(time.time()) + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + finished = False + + async def _run_agent(): + try: + await agent.process(prompt, images=images, files=None) + except Exception as exc: + await event_queue.put({"error": str(exc)}) + finally: + await event_queue.put(None) + + task = asyncio.create_task(_run_agent()) + + try: + yield _sse_payload( + { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": MODEL_ID, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + } + ], + } + ) + + while True: + item = await event_queue.get() + if item is None: + break + if isinstance(item, dict) and item.get("error"): + raise RuntimeError(str(item["error"])) + text = str(item or "") + if not text: + continue + yield _sse_payload( + { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": MODEL_ID, + "choices": [ + { + "index": 0, + "delta": {"content": text}, + "finish_reason": None, + } + ], + } + ) + + finished = True + yield _sse_payload( + { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": MODEL_ID, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + ) + yield "data: [DONE]\n\n" + finally: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + elif finished: + await task + + +def _error_response( + message: str, + status_code: int, + error_type: str = "invalid_request_error", + code: Optional[str] = None, +) -> JSONResponse: + return JSONResponse( + status_code=status_code, + content=schemas.OpenAIErrorResponse( + error=schemas.OpenAIErrorDetail( + message=message, + type=error_type, + code=code, + ) + ).model_dump(), + headers={"WWW-Authenticate": "Bearer"}, + ) + + +def _check_auth( + credentials: Optional[HTTPAuthorizationCredentials], +) -> Optional[JSONResponse]: + if not credentials or credentials.scheme.lower() != "bearer": + return _error_response( + "Invalid bearer token.", + 401, + error_type="authentication_error", + code="invalid_api_key", + ) + if credentials.credentials != settings.API_TOKEN: + return _error_response( + "Invalid bearer token.", + 401, + error_type="authentication_error", + code="invalid_api_key", + ) + return None + + +@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse) +async def list_models( + credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme), +): + auth_error = _check_auth(credentials) + if auth_error: + return auth_error + now = int(time.time()) + return schemas.OpenAIModelListResponse( + data=[schemas.OpenAIModelInfo(id=MODEL_ID, created=now)] + ) + + +@router.post( + "/chat/completions", + summary="OpenAI compatible chat completions", + response_model=schemas.OpenAIChatCompletionResponse, +) +async def chat_completions( + payload: schemas.OpenAIChatCompletionsRequest, + request: Request, + credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme), +): + auth_error = _check_auth(credentials) + if auth_error: + return auth_error + + if not settings.AI_AGENT_ENABLE: + return _error_response( + "MoviePilot AI agent is disabled.", + 503, + error_type="server_error", + code="ai_agent_disabled", + ) + + if not payload.messages: + return _error_response( + "`messages` must be a non-empty array.", + 400, + code="invalid_messages", + ) + + session_key = ( + str(payload.user or "").strip() + or str(request.headers.get("x-session-id") or "").strip() + or str(uuid.uuid4()) + ) + use_server_session = bool( + str(payload.user or "").strip() + or str(request.headers.get("x-session-id") or "").strip() + ) + + try: + prompt, images = build_prompt(payload.messages, use_server_session=use_server_session) + except ValueError as exc: + return _error_response(str(exc), 400, code="invalid_messages") + + session_id = build_session_id(session_key, SESSION_PREFIX) + username = str(payload.user or "openai-client") + agent = _CollectingMoviePilotAgent( + session_id=session_id, + user_id=session_key, + channel=MessageChannel.Web.value, + source="openai", + username=username, + stream_mode=payload.stream, + ) + + if payload.stream: + return StreamingResponse( + _stream_response(agent=agent, prompt=prompt, images=images), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + try: + result = await agent.process(prompt, images=images, files=None) + except Exception as exc: + return _error_response( + str(exc), + 500, + error_type="server_error", + code="agent_execution_failed", + ) + + content = "\n\n".join( + message.strip() + for message in agent.collected_messages + if message and message.strip() + ).strip() + if not content and result: + content = str(result).strip() + if not content: + content = "未获得有效回复。" + + return JSONResponse(content=build_completion_payload(content, MODEL_ID)) + + +@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse) +async def responses( + payload: schemas.OpenAIResponsesRequest, + credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme), +): + auth_error = _check_auth(credentials) + if auth_error: + return auth_error + + if not settings.AI_AGENT_ENABLE: + return _error_response( + "MoviePilot AI agent is disabled.", + 503, + error_type="server_error", + code="ai_agent_disabled", + ) + + if payload.stream: + return _error_response( + "Streaming is not supported for /responses yet.", + 400, + code="unsupported_stream", + ) + + normalized_messages = build_responses_input(payload.input, instructions=payload.instructions) + if not normalized_messages: + return _error_response( + "`input` must include at least one usable message.", + 400, + code="invalid_input", + ) + + try: + prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user)) + except ValueError as exc: + return _error_response(str(exc), 400, code="invalid_input") + + session_key = str(payload.user or uuid.uuid4()) + session_id = build_session_id(session_key, SESSION_PREFIX) + agent = _CollectingMoviePilotAgent( + session_id=session_id, + user_id=session_key, + channel=MessageChannel.Web.value, + source="openai.responses", + username=str(payload.user or "openai-client"), + stream_mode=False, + ) + + try: + result = await agent.process(prompt, images=images, files=None) + except Exception as exc: + return _error_response( + str(exc), + 500, + error_type="server_error", + code="agent_execution_failed", + ) + + content = "\n\n".join( + message.strip() + for message in agent.collected_messages + if message and message.strip() + ).strip() + if not content and result: + content = str(result).strip() + if not content: + content = "未获得有效回复。" + + created_at = int(time.time()) + response_id = f"resp_{uuid.uuid4().hex}" + output_message = schemas.OpenAIResponsesOutputMessage( + id=f"msg_{uuid.uuid4().hex}", + content=[schemas.OpenAIResponsesOutputText(text=content)], + ) + return schemas.OpenAIResponsesResponse( + id=response_id, + created_at=created_at, + model=MODEL_ID, + output=[output_message], + usage=schemas.OpenAIUsage(), + ) diff --git a/app/api/openai_utils.py b/app/api/openai_utils.py new file mode 100644 index 00000000..4fcbbbe8 --- /dev/null +++ b/app/api/openai_utils.py @@ -0,0 +1,177 @@ +import hashlib +import time +import uuid +from typing import Any, Dict, List, Tuple + + +def _get_message_field(message: Any, field: str, default: Any = None) -> Any: + if isinstance(message, dict): + return message.get(field, default) + return getattr(message, field, default) + + +def extract_text_and_images(content: Any) -> Tuple[str, List[str]]: + if content is None: + return "", [] + if isinstance(content, str): + return content.strip(), [] + + text_parts: List[str] = [] + image_urls: List[str] = [] + if isinstance(content, list): + for item in content: + if isinstance(item, str): + normalized = item.strip() + if normalized: + text_parts.append(normalized) + continue + if not isinstance(item, dict): + continue + item_type = (item.get("type") or "").lower() + if item_type == "text": + text = item.get("text") + if text and str(text).strip(): + text_parts.append(str(text).strip()) + elif item_type == "input_text": + text = item.get("text") + if text and str(text).strip(): + text_parts.append(str(text).strip()) + elif item_type == "image_url": + image_url = item.get("image_url") + url = image_url.get("url") if isinstance(image_url, dict) else image_url + if url and str(url).strip(): + image_urls.append(str(url).strip()) + elif item_type == "input_image": + url = item.get("image_url") + if url and str(url).strip(): + image_urls.append(str(url).strip()) + elif item_type == "image": + source = item.get("source") or {} + if isinstance(source, dict) and source.get("type") == "base64": + data = source.get("data") + media_type = source.get("media_type") or "image/png" + if data and str(data).strip(): + image_urls.append(f"data:{media_type};base64,{str(data).strip()}") + return "\n".join(text_parts).strip(), image_urls + + +def build_prompt(messages: List[Any], use_server_session: bool) -> Tuple[str, List[str]]: + system_texts: List[str] = [] + transcript: List[str] = [] + latest_user_text = "" + latest_user_images: List[str] = [] + + for message in messages: + role = str(_get_message_field(message, "role", "user") or "user").lower() + if role == "developer": + role = "system" + text, images = extract_text_and_images(_get_message_field(message, "content")) + if role == "system": + if text: + system_texts.append(text) + continue + if role == "user": + if text or images: + latest_user_text = text + latest_user_images = images + if text: + transcript.append(f"user: {text}") + continue + if text: + transcript.append(f"{role}: {text}") + + if not latest_user_text and not latest_user_images: + raise ValueError("No usable user message found in messages.") + + prompt_parts: List[str] = [] + if system_texts: + prompt_parts.append("系统要求:\n" + "\n\n".join(system_texts)) + + if not use_server_session and transcript: + history = transcript[:-1] if transcript[-1].startswith("user: ") else transcript + if history: + prompt_parts.append("对话上下文:\n" + "\n".join(history[-10:])) + + if latest_user_text: + prompt_parts.append("当前用户消息:\n" + latest_user_text) + else: + prompt_parts.append("当前用户消息:\n请结合图片内容回复。") + + return "\n\n".join(part for part in prompt_parts if part).strip(), latest_user_images + + +def build_session_id(session_key: str, prefix: str) -> str: + digest = hashlib.sha256(session_key.encode("utf-8")).hexdigest() + return f"{prefix}{digest[:32]}" + + +def build_completion_payload(content: str, model_id: str) -> Dict[str, Any]: + created = int(time.time()) + return { + "id": f"chatcmpl-{uuid.uuid4().hex}", + "object": "chat.completion", + "created": created, + "model": model_id, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + +def build_responses_input( + input_data: Any, instructions: str | None = None +) -> List[Dict[str, Any]]: + messages: List[Dict[str, Any]] = [] + if instructions and str(instructions).strip(): + messages.append({"role": "system", "content": str(instructions).strip()}) + + if isinstance(input_data, str): + normalized = input_data.strip() + if normalized: + messages.append({"role": "user", "content": normalized}) + return messages + + if isinstance(input_data, list): + for item in input_data: + if not isinstance(item, dict): + continue + item_type = (item.get("type") or "").lower() + if item_type == "message": + role = item.get("role") or "user" + content = item.get("content") + messages.append({"role": role, "content": content}) + elif item.get("role") and "content" in item: + messages.append({"role": item.get("role"), "content": item.get("content")}) + return messages + + if isinstance(input_data, dict) and input_data.get("role") and "content" in input_data: + messages.append({"role": input_data.get("role"), "content": input_data.get("content")}) + + return messages + + +def build_anthropic_messages( + system: Any, messages: List[Any] +) -> List[Dict[str, Any]]: + normalized: List[Dict[str, Any]] = [] + system_text, _ = extract_text_and_images(system) + if system_text: + normalized.append({"role": "system", "content": system_text}) + + for message in messages: + role = _get_message_field(message, "role", "user") + content = _get_message_field(message, "content") + normalized.append({"role": role, "content": content}) + return normalized diff --git a/app/core/security.py b/app/core/security.py index b9596c43..ca64f094 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -13,7 +13,7 @@ from Crypto.Cipher import AES from Crypto.Util.Padding import pad from cryptography.fernet import Fernet from fastapi import HTTPException, status, Security, Request, Response -from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie +from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie, HTTPBearer from passlib.context import CryptContext from app import schemas @@ -42,6 +42,12 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a # API KEY 通过 QUERY 认证 api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query") +# OpenAI compatible Bearer Token 认证 +openai_bearer_scheme = HTTPBearer(auto_error=False) + +# Anthropic compatible API Key 认证 +anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False, scheme_name="anthropic_api_key_header") + def __get_api_token( token_query: Annotated[str | None, Security(api_token_query)] = None diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index d04cdbb8..1fcbf9c6 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -11,6 +11,7 @@ from .monitoring import * from .plugin import * from .response import * from .rule import * +from .openai import * from .servarr import * from .servcookie import * from .site import * @@ -23,4 +24,3 @@ from .transfer import * from .user import * from .workflow import * from .mcp import * - diff --git a/app/schemas/openai.py b/app/schemas/openai.py new file mode 100644 index 00000000..b46d23cb --- /dev/null +++ b/app/schemas/openai.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class OpenAIModelInfo(BaseModel): + id: str + object: str = "model" + created: int + owned_by: str = "moviepilot" + + +class OpenAIModelListResponse(BaseModel): + object: str = "list" + data: List[OpenAIModelInfo] = Field(default_factory=list) + + +class OpenAIChatMessage(BaseModel): + role: str + content: Any + name: Optional[str] = None + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatCompletionsRequest(BaseModel): + model: Optional[str] = None + messages: List[OpenAIChatMessage] + user: Optional[str] = None + stream: bool = False + + model_config = ConfigDict(extra="allow") + + +class OpenAIResponsesRequest(BaseModel): + model: Optional[str] = None + input: Any + instructions: Optional[str] = None + user: Optional[str] = None + stream: bool = False + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatChoiceMessage(BaseModel): + role: str = "assistant" + content: str + + +class OpenAIChatChoice(BaseModel): + index: int = 0 + message: OpenAIChatChoiceMessage + finish_reason: str = "stop" + + +class OpenAIUsage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class OpenAIChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int + model: str + choices: List[OpenAIChatChoice] + usage: OpenAIUsage + + +class OpenAIResponsesOutputText(BaseModel): + type: str = "output_text" + text: str + annotations: List[Dict[str, Any]] = Field(default_factory=list) + + +class OpenAIResponsesOutputMessage(BaseModel): + id: str + type: str = "message" + status: str = "completed" + role: str = "assistant" + content: List[OpenAIResponsesOutputText] = Field(default_factory=list) + + +class OpenAIResponsesResponse(BaseModel): + id: str + object: str = "response" + created_at: int + status: str = "completed" + model: str + output: List[OpenAIResponsesOutputMessage] = Field(default_factory=list) + error: Optional[Any] = None + incomplete_details: Optional[Any] = None + usage: OpenAIUsage + + +class OpenAIErrorDetail(BaseModel): + message: str + type: str = "invalid_request_error" + param: Optional[str] = None + code: Optional[str] = None + + +class OpenAIErrorResponse(BaseModel): + error: OpenAIErrorDetail + + +OpenAIChatContentPart = Dict[str, Any] + + +class AnthropicMessage(BaseModel): + role: str + content: Any + + model_config = ConfigDict(extra="allow") + + +class AnthropicMessagesRequest(BaseModel): + model: Optional[str] = None + messages: List[AnthropicMessage] + system: Optional[Any] = None + max_tokens: Optional[int] = 1024 + stream: bool = False + + model_config = ConfigDict(extra="allow") + + +class AnthropicTextBlock(BaseModel): + type: str = "text" + text: str + + +class AnthropicUsage(BaseModel): + input_tokens: int = 0 + output_tokens: int = 0 + + +class AnthropicMessagesResponse(BaseModel): + id: str + type: str = "message" + role: str = "assistant" + content: List[AnthropicTextBlock] = Field(default_factory=list) + model: str + stop_reason: str = "end_turn" + stop_sequence: Optional[str] = None + usage: AnthropicUsage = Field(default_factory=AnthropicUsage) + + +class AnthropicErrorDetail(BaseModel): + type: str = "invalid_request_error" + message: str + + +class AnthropicErrorResponse(BaseModel): + type: str = "error" + error: AnthropicErrorDetail diff --git a/tests/test_openai_utils.py b/tests/test_openai_utils.py new file mode 100644 index 00000000..b9730c78 --- /dev/null +++ b/tests/test_openai_utils.py @@ -0,0 +1,120 @@ +from unittest import TestCase + +from app.api.openai_utils import ( + build_anthropic_messages, + build_completion_payload, + build_prompt, + build_responses_input, + build_session_id, + extract_text_and_images, +) + + +class OpenAIUtilsTest(TestCase): + def test_extract_text_and_images(self): + text, images = extract_text_and_images( + [ + {"type": "text", "text": "你好"}, + {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}, + {"type": "text", "text": "世界"}, + ] + ) + self.assertEqual(text, "你好\n世界") + self.assertEqual(images, ["https://example.com/a.png"]) + + def test_extract_text_and_images_with_input_image_and_base64_image(self): + text, images = extract_text_and_images( + [ + {"type": "input_text", "text": "看图"}, + {"type": "input_image", "image_url": "https://example.com/b.png"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "YWJj", + }, + }, + ] + ) + self.assertEqual(text, "看图") + self.assertEqual( + images, + ["https://example.com/b.png", "data:image/png;base64,YWJj"], + ) + + def test_build_prompt_without_server_session_keeps_recent_history(self): + prompt, images = build_prompt( + [ + {"role": "system", "content": "回答简短"}, + {"role": "user", "content": "第一句"}, + {"role": "assistant", "content": "第一答"}, + {"role": "user", "content": "第二句"}, + ], + use_server_session=False, + ) + self.assertIn("系统要求:\n回答简短", prompt) + self.assertIn("对话上下文:\nuser: 第一句\nassistant: 第一答", prompt) + self.assertIn("当前用户消息:\n第二句", prompt) + self.assertEqual(images, []) + + def test_build_prompt_with_server_session_ignores_history_block(self): + prompt, _ = build_prompt( + [ + {"role": "user", "content": "历史问题"}, + {"role": "assistant", "content": "历史回答"}, + {"role": "user", "content": "当前问题"}, + ], + use_server_session=True, + ) + self.assertNotIn("对话上下文:", prompt) + self.assertIn("当前用户消息:\n当前问题", prompt) + + def test_build_prompt_accepts_image_only_user_message(self): + prompt, images = build_prompt( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}} + ], + } + ], + use_server_session=True, + ) + self.assertIn("请结合图片内容回复", prompt) + self.assertEqual(images, ["https://example.com/a.png"]) + + def test_build_session_id_is_stable(self): + session_id = build_session_id("user-1", "openai:") + self.assertTrue(session_id.startswith("openai:")) + self.assertEqual(session_id, build_session_id("user-1", "openai:")) + self.assertNotEqual(session_id, build_session_id("user-2", "openai:")) + + def test_build_completion_payload(self): + payload = build_completion_payload("你好", "moviepilot-agent") + self.assertEqual(payload["model"], "moviepilot-agent") + self.assertEqual(payload["choices"][0]["message"]["content"], "你好") + self.assertEqual(payload["choices"][0]["finish_reason"], "stop") + + def test_build_responses_input(self): + messages = build_responses_input( + [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "你好"}], + } + ], + instructions="你要简短回答", + ) + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[1]["role"], "user") + + def test_build_anthropic_messages(self): + messages = build_anthropic_messages( + system=[{"type": "text", "text": "你是助手"}], + messages=[{"role": "user", "content": "你好"}], + ) + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[1]["role"], "user")