feat: add ai-compatible API endpoints

This commit is contained in:
Sebastian
2026-04-22 16:11:38 +08:00
committed by jxxghp
parent 460b386004
commit 097dff13a3
8 changed files with 1048 additions and 3 deletions

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -30,3 +30,5 @@ api_router.include_router(recommend.router, prefix="/recommend", tags=["recommen
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
api_router.include_router(openai.router, prefix="/openai/v1", tags=["openai"])
api_router.include_router(anthropic.router, prefix="/anthropic/v1", tags=["anthropic"])

View File

@@ -0,0 +1,158 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional
from fastapi import APIRouter, Header, Security
from fastapi.responses import JSONResponse, StreamingResponse
from app import schemas
from app.api.endpoints.openai import (
MODEL_ID,
_CollectingMoviePilotAgent,
_error_response as _openai_error_response,
)
from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id
from app.core.config import settings
from app.core.security import anthropic_api_key_header
from app.schemas.types import MessageChannel
router = APIRouter()
SESSION_PREFIX = "anthropic:"
def _anthropic_error_response(
message: str,
status_code: int,
error_type: str = "invalid_request_error",
) -> JSONResponse:
return JSONResponse(
status_code=status_code,
content=schemas.AnthropicErrorResponse(
error=schemas.AnthropicErrorDetail(type=error_type, message=message)
).model_dump(),
)
def _check_auth(api_key: Optional[str]) -> Optional[JSONResponse]:
if not api_key or api_key != settings.API_TOKEN:
return _anthropic_error_response(
"invalid x-api-key",
401,
error_type="authentication_error",
)
return None
async def _stream_anthropic_response(
agent: _CollectingMoviePilotAgent,
prompt: str,
images: List[str],
) -> AsyncIterator[str]:
event_queue: asyncio.Queue = asyncio.Queue()
if hasattr(agent.stream_handler, "bind_queue"):
agent.stream_handler.bind_queue(event_queue)
message_id = f"msg_{uuid.uuid4().hex}"
async def _run_agent():
try:
await agent.process(prompt, images=images, files=None)
except Exception as exc:
await event_queue.put({"error": str(exc)})
finally:
await event_queue.put(None)
task = asyncio.create_task(_run_agent())
try:
yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': MODEL_ID, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}}, ensure_ascii=False)}\n\n"
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}, ensure_ascii=False)}\n\n"
while True:
item = await event_queue.get()
if item is None:
break
if isinstance(item, dict) and item.get("error"):
raise RuntimeError(str(item["error"]))
text = str(item or "")
if not text:
continue
yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}}, ensure_ascii=False)}\n\n"
yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0}, ensure_ascii=False)}\n\n"
yield f"event: message_delta\ndata: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 0}}, ensure_ascii=False)}\n\n"
yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'}, ensure_ascii=False)}\n\n"
finally:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse)
async def messages(
payload: schemas.AnthropicMessagesRequest,
x_api_key: Optional[str] = Security(anthropic_api_key_header),
anthropic_version: Optional[str] = Header(default=None, alias="anthropic-version"),
):
auth_error = _check_auth(x_api_key)
if auth_error:
return auth_error
if not settings.AI_AGENT_ENABLE:
return _anthropic_error_response(
"MoviePilot AI agent is disabled.",
503,
error_type="api_error",
)
normalized_messages = build_anthropic_messages(payload.system, payload.messages)
try:
prompt, images = build_prompt(normalized_messages, use_server_session=False)
except ValueError as exc:
return _anthropic_error_response(str(exc), 400)
session_seed = anthropic_version or "anthropic"
session_id = build_session_id(f"{session_seed}:{uuid.uuid4().hex}", SESSION_PREFIX)
agent = _CollectingMoviePilotAgent(
session_id=session_id,
user_id=session_id,
channel=MessageChannel.Web.value,
source="anthropic",
username="anthropic-client",
stream_mode=payload.stream,
)
if payload.stream:
return StreamingResponse(
_stream_anthropic_response(agent=agent, prompt=prompt, images=images),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
try:
result = await agent.process(prompt, images=images, files=None)
except Exception as exc:
return _anthropic_error_response(str(exc), 500, error_type="api_error")
content = "\n\n".join(
message.strip()
for message in agent.collected_messages
if message and message.strip()
).strip()
if not content and result:
content = str(result).strip()
if not content:
content = "未获得有效回复。"
return schemas.AnthropicMessagesResponse(
id=f"msg_{uuid.uuid4().hex}",
content=[schemas.AnthropicTextBlock(text=content)],
model=MODEL_ID,
)

426
app/api/endpoints/openai.py Normal file
View File

@@ -0,0 +1,426 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple
from fastapi import APIRouter, Request, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials
from app import schemas
from app.api.openai_utils import (
build_completion_payload,
build_prompt,
build_responses_input,
build_session_id,
)
from app.agent import MoviePilotAgent, StreamingHandler
from app.core.config import settings
from app.core.security import openai_bearer_scheme
from app.schemas.types import MessageChannel
router = APIRouter()
MODEL_ID = "moviepilot-agent"
SESSION_PREFIX = "openai:"
class _CollectingMoviePilotAgent(MoviePilotAgent):
"""
捕获 Agent 最终输出,避免再通过消息渠道二次发送。
"""
def __init__(self, *args, stream_mode: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.collected_messages: List[str] = []
self.stream_mode = stream_mode
if stream_mode:
self.stream_handler = _OpenAIStreamingHandler()
def _should_stream(self) -> bool:
return self.stream_mode
async def send_agent_message(self, message: str, title: str = ""):
text = (message or "").strip()
if title and text:
text = f"{title}\n{text}"
elif title:
text = title.strip()
if text:
self.collected_messages.append(text)
if self.stream_mode:
self.stream_handler.emit(text)
async def _save_agent_message_to_db(self, message: str, title: str = ""):
return None
class _OpenAIStreamingHandler(StreamingHandler):
"""
将 Agent 流式输出转发到 OpenAI SSE 队列,不向站内消息系统落消息。
"""
def __init__(self):
super().__init__()
self._event_queue: Optional[asyncio.Queue] = None
def bind_queue(self, queue: asyncio.Queue):
self._event_queue = queue
def emit(self, token: str):
super().emit(token)
if token and self._event_queue is not None:
self._event_queue.put_nowait(token)
async def start_streaming(
self,
channel: Optional[str] = None,
source: Optional[str] = None,
user_id: Optional[str] = None,
username: Optional[str] = None,
title: str = "",
):
self._channel = channel
self._source = source
self._user_id = user_id
self._username = username
self._title = title
self._streaming_enabled = True
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
self._max_message_length = 0
async def stop_streaming(self) -> Tuple[bool, str]:
if not self._streaming_enabled:
return False, ""
self._streaming_enabled = False
with self._lock:
final_text = self._buffer
self._buffer = ""
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
return True, final_text
def _sse_payload(data: dict) -> str:
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _stream_response(
agent: _CollectingMoviePilotAgent,
prompt: str,
images: List[str],
) -> AsyncIterator[str]:
event_queue: asyncio.Queue = asyncio.Queue()
if isinstance(agent.stream_handler, _OpenAIStreamingHandler):
agent.stream_handler.bind_queue(event_queue)
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
finished = False
async def _run_agent():
try:
await agent.process(prompt, images=images, files=None)
except Exception as exc:
await event_queue.put({"error": str(exc)})
finally:
await event_queue.put(None)
task = asyncio.create_task(_run_agent())
try:
yield _sse_payload(
{
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": MODEL_ID,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
while True:
item = await event_queue.get()
if item is None:
break
if isinstance(item, dict) and item.get("error"):
raise RuntimeError(str(item["error"]))
text = str(item or "")
if not text:
continue
yield _sse_payload(
{
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": MODEL_ID,
"choices": [
{
"index": 0,
"delta": {"content": text},
"finish_reason": None,
}
],
}
)
finished = True
yield _sse_payload(
{
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": MODEL_ID,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "stop",
}
],
}
)
yield "data: [DONE]\n\n"
finally:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
elif finished:
await task
def _error_response(
message: str,
status_code: int,
error_type: str = "invalid_request_error",
code: Optional[str] = None,
) -> JSONResponse:
return JSONResponse(
status_code=status_code,
content=schemas.OpenAIErrorResponse(
error=schemas.OpenAIErrorDetail(
message=message,
type=error_type,
code=code,
)
).model_dump(),
headers={"WWW-Authenticate": "Bearer"},
)
def _check_auth(
credentials: Optional[HTTPAuthorizationCredentials],
) -> Optional[JSONResponse]:
if not credentials or credentials.scheme.lower() != "bearer":
return _error_response(
"Invalid bearer token.",
401,
error_type="authentication_error",
code="invalid_api_key",
)
if credentials.credentials != settings.API_TOKEN:
return _error_response(
"Invalid bearer token.",
401,
error_type="authentication_error",
code="invalid_api_key",
)
return None
@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse)
async def list_models(
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
auth_error = _check_auth(credentials)
if auth_error:
return auth_error
now = int(time.time())
return schemas.OpenAIModelListResponse(
data=[schemas.OpenAIModelInfo(id=MODEL_ID, created=now)]
)
@router.post(
"/chat/completions",
summary="OpenAI compatible chat completions",
response_model=schemas.OpenAIChatCompletionResponse,
)
async def chat_completions(
payload: schemas.OpenAIChatCompletionsRequest,
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
auth_error = _check_auth(credentials)
if auth_error:
return auth_error
if not settings.AI_AGENT_ENABLE:
return _error_response(
"MoviePilot AI agent is disabled.",
503,
error_type="server_error",
code="ai_agent_disabled",
)
if not payload.messages:
return _error_response(
"`messages` must be a non-empty array.",
400,
code="invalid_messages",
)
session_key = (
str(payload.user or "").strip()
or str(request.headers.get("x-session-id") or "").strip()
or str(uuid.uuid4())
)
use_server_session = bool(
str(payload.user or "").strip()
or str(request.headers.get("x-session-id") or "").strip()
)
try:
prompt, images = build_prompt(payload.messages, use_server_session=use_server_session)
except ValueError as exc:
return _error_response(str(exc), 400, code="invalid_messages")
session_id = build_session_id(session_key, SESSION_PREFIX)
username = str(payload.user or "openai-client")
agent = _CollectingMoviePilotAgent(
session_id=session_id,
user_id=session_key,
channel=MessageChannel.Web.value,
source="openai",
username=username,
stream_mode=payload.stream,
)
if payload.stream:
return StreamingResponse(
_stream_response(agent=agent, prompt=prompt, images=images),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
try:
result = await agent.process(prompt, images=images, files=None)
except Exception as exc:
return _error_response(
str(exc),
500,
error_type="server_error",
code="agent_execution_failed",
)
content = "\n\n".join(
message.strip()
for message in agent.collected_messages
if message and message.strip()
).strip()
if not content and result:
content = str(result).strip()
if not content:
content = "未获得有效回复。"
return JSONResponse(content=build_completion_payload(content, MODEL_ID))
@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse)
async def responses(
payload: schemas.OpenAIResponsesRequest,
credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
auth_error = _check_auth(credentials)
if auth_error:
return auth_error
if not settings.AI_AGENT_ENABLE:
return _error_response(
"MoviePilot AI agent is disabled.",
503,
error_type="server_error",
code="ai_agent_disabled",
)
if payload.stream:
return _error_response(
"Streaming is not supported for /responses yet.",
400,
code="unsupported_stream",
)
normalized_messages = build_responses_input(payload.input, instructions=payload.instructions)
if not normalized_messages:
return _error_response(
"`input` must include at least one usable message.",
400,
code="invalid_input",
)
try:
prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user))
except ValueError as exc:
return _error_response(str(exc), 400, code="invalid_input")
session_key = str(payload.user or uuid.uuid4())
session_id = build_session_id(session_key, SESSION_PREFIX)
agent = _CollectingMoviePilotAgent(
session_id=session_id,
user_id=session_key,
channel=MessageChannel.Web.value,
source="openai.responses",
username=str(payload.user or "openai-client"),
stream_mode=False,
)
try:
result = await agent.process(prompt, images=images, files=None)
except Exception as exc:
return _error_response(
str(exc),
500,
error_type="server_error",
code="agent_execution_failed",
)
content = "\n\n".join(
message.strip()
for message in agent.collected_messages
if message and message.strip()
).strip()
if not content and result:
content = str(result).strip()
if not content:
content = "未获得有效回复。"
created_at = int(time.time())
response_id = f"resp_{uuid.uuid4().hex}"
output_message = schemas.OpenAIResponsesOutputMessage(
id=f"msg_{uuid.uuid4().hex}",
content=[schemas.OpenAIResponsesOutputText(text=content)],
)
return schemas.OpenAIResponsesResponse(
id=response_id,
created_at=created_at,
model=MODEL_ID,
output=[output_message],
usage=schemas.OpenAIUsage(),
)

177
app/api/openai_utils.py Normal file
View File

@@ -0,0 +1,177 @@
import hashlib
import time
import uuid
from typing import Any, Dict, List, Tuple
def _get_message_field(message: Any, field: str, default: Any = None) -> Any:
if isinstance(message, dict):
return message.get(field, default)
return getattr(message, field, default)
def extract_text_and_images(content: Any) -> Tuple[str, List[str]]:
if content is None:
return "", []
if isinstance(content, str):
return content.strip(), []
text_parts: List[str] = []
image_urls: List[str] = []
if isinstance(content, list):
for item in content:
if isinstance(item, str):
normalized = item.strip()
if normalized:
text_parts.append(normalized)
continue
if not isinstance(item, dict):
continue
item_type = (item.get("type") or "").lower()
if item_type == "text":
text = item.get("text")
if text and str(text).strip():
text_parts.append(str(text).strip())
elif item_type == "input_text":
text = item.get("text")
if text and str(text).strip():
text_parts.append(str(text).strip())
elif item_type == "image_url":
image_url = item.get("image_url")
url = image_url.get("url") if isinstance(image_url, dict) else image_url
if url and str(url).strip():
image_urls.append(str(url).strip())
elif item_type == "input_image":
url = item.get("image_url")
if url and str(url).strip():
image_urls.append(str(url).strip())
elif item_type == "image":
source = item.get("source") or {}
if isinstance(source, dict) and source.get("type") == "base64":
data = source.get("data")
media_type = source.get("media_type") or "image/png"
if data and str(data).strip():
image_urls.append(f"data:{media_type};base64,{str(data).strip()}")
return "\n".join(text_parts).strip(), image_urls
def build_prompt(messages: List[Any], use_server_session: bool) -> Tuple[str, List[str]]:
system_texts: List[str] = []
transcript: List[str] = []
latest_user_text = ""
latest_user_images: List[str] = []
for message in messages:
role = str(_get_message_field(message, "role", "user") or "user").lower()
if role == "developer":
role = "system"
text, images = extract_text_and_images(_get_message_field(message, "content"))
if role == "system":
if text:
system_texts.append(text)
continue
if role == "user":
if text or images:
latest_user_text = text
latest_user_images = images
if text:
transcript.append(f"user: {text}")
continue
if text:
transcript.append(f"{role}: {text}")
if not latest_user_text and not latest_user_images:
raise ValueError("No usable user message found in messages.")
prompt_parts: List[str] = []
if system_texts:
prompt_parts.append("系统要求:\n" + "\n\n".join(system_texts))
if not use_server_session and transcript:
history = transcript[:-1] if transcript[-1].startswith("user: ") else transcript
if history:
prompt_parts.append("对话上下文:\n" + "\n".join(history[-10:]))
if latest_user_text:
prompt_parts.append("当前用户消息:\n" + latest_user_text)
else:
prompt_parts.append("当前用户消息:\n请结合图片内容回复。")
return "\n\n".join(part for part in prompt_parts if part).strip(), latest_user_images
def build_session_id(session_key: str, prefix: str) -> str:
digest = hashlib.sha256(session_key.encode("utf-8")).hexdigest()
return f"{prefix}{digest[:32]}"
def build_completion_payload(content: str, model_id: str) -> Dict[str, Any]:
created = int(time.time())
return {
"id": f"chatcmpl-{uuid.uuid4().hex}",
"object": "chat.completion",
"created": created,
"model": model_id,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": content,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
def build_responses_input(
input_data: Any, instructions: str | None = None
) -> List[Dict[str, Any]]:
messages: List[Dict[str, Any]] = []
if instructions and str(instructions).strip():
messages.append({"role": "system", "content": str(instructions).strip()})
if isinstance(input_data, str):
normalized = input_data.strip()
if normalized:
messages.append({"role": "user", "content": normalized})
return messages
if isinstance(input_data, list):
for item in input_data:
if not isinstance(item, dict):
continue
item_type = (item.get("type") or "").lower()
if item_type == "message":
role = item.get("role") or "user"
content = item.get("content")
messages.append({"role": role, "content": content})
elif item.get("role") and "content" in item:
messages.append({"role": item.get("role"), "content": item.get("content")})
return messages
if isinstance(input_data, dict) and input_data.get("role") and "content" in input_data:
messages.append({"role": input_data.get("role"), "content": input_data.get("content")})
return messages
def build_anthropic_messages(
system: Any, messages: List[Any]
) -> List[Dict[str, Any]]:
normalized: List[Dict[str, Any]] = []
system_text, _ = extract_text_and_images(system)
if system_text:
normalized.append({"role": "system", "content": system_text})
for message in messages:
role = _get_message_field(message, "role", "user")
content = _get_message_field(message, "content")
normalized.append({"role": role, "content": content})
return normalized

View File

@@ -13,7 +13,7 @@ from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from cryptography.fernet import Fernet
from fastapi import HTTPException, status, Security, Request, Response
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie, HTTPBearer
from passlib.context import CryptContext
from app import schemas
@@ -42,6 +42,12 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
# API KEY 通过 QUERY 认证
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
# OpenAI compatible Bearer Token 认证
openai_bearer_scheme = HTTPBearer(auto_error=False)
# Anthropic compatible API Key 认证
anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False, scheme_name="anthropic_api_key_header")
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None

View File

@@ -11,6 +11,7 @@ from .monitoring import *
from .plugin import *
from .response import *
from .rule import *
from .openai import *
from .servarr import *
from .servcookie import *
from .site import *
@@ -23,4 +24,3 @@ from .transfer import *
from .user import *
from .workflow import *
from .mcp import *

156
app/schemas/openai.py Normal file
View File

@@ -0,0 +1,156 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
class OpenAIModelInfo(BaseModel):
id: str
object: str = "model"
created: int
owned_by: str = "moviepilot"
class OpenAIModelListResponse(BaseModel):
object: str = "list"
data: List[OpenAIModelInfo] = Field(default_factory=list)
class OpenAIChatMessage(BaseModel):
role: str
content: Any
name: Optional[str] = None
model_config = ConfigDict(extra="allow")
class OpenAIChatCompletionsRequest(BaseModel):
model: Optional[str] = None
messages: List[OpenAIChatMessage]
user: Optional[str] = None
stream: bool = False
model_config = ConfigDict(extra="allow")
class OpenAIResponsesRequest(BaseModel):
model: Optional[str] = None
input: Any
instructions: Optional[str] = None
user: Optional[str] = None
stream: bool = False
model_config = ConfigDict(extra="allow")
class OpenAIChatChoiceMessage(BaseModel):
role: str = "assistant"
content: str
class OpenAIChatChoice(BaseModel):
index: int = 0
message: OpenAIChatChoiceMessage
finish_reason: str = "stop"
class OpenAIUsage(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class OpenAIChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[OpenAIChatChoice]
usage: OpenAIUsage
class OpenAIResponsesOutputText(BaseModel):
type: str = "output_text"
text: str
annotations: List[Dict[str, Any]] = Field(default_factory=list)
class OpenAIResponsesOutputMessage(BaseModel):
id: str
type: str = "message"
status: str = "completed"
role: str = "assistant"
content: List[OpenAIResponsesOutputText] = Field(default_factory=list)
class OpenAIResponsesResponse(BaseModel):
id: str
object: str = "response"
created_at: int
status: str = "completed"
model: str
output: List[OpenAIResponsesOutputMessage] = Field(default_factory=list)
error: Optional[Any] = None
incomplete_details: Optional[Any] = None
usage: OpenAIUsage
class OpenAIErrorDetail(BaseModel):
message: str
type: str = "invalid_request_error"
param: Optional[str] = None
code: Optional[str] = None
class OpenAIErrorResponse(BaseModel):
error: OpenAIErrorDetail
OpenAIChatContentPart = Dict[str, Any]
class AnthropicMessage(BaseModel):
role: str
content: Any
model_config = ConfigDict(extra="allow")
class AnthropicMessagesRequest(BaseModel):
model: Optional[str] = None
messages: List[AnthropicMessage]
system: Optional[Any] = None
max_tokens: Optional[int] = 1024
stream: bool = False
model_config = ConfigDict(extra="allow")
class AnthropicTextBlock(BaseModel):
type: str = "text"
text: str
class AnthropicUsage(BaseModel):
input_tokens: int = 0
output_tokens: int = 0
class AnthropicMessagesResponse(BaseModel):
id: str
type: str = "message"
role: str = "assistant"
content: List[AnthropicTextBlock] = Field(default_factory=list)
model: str
stop_reason: str = "end_turn"
stop_sequence: Optional[str] = None
usage: AnthropicUsage = Field(default_factory=AnthropicUsage)
class AnthropicErrorDetail(BaseModel):
type: str = "invalid_request_error"
message: str
class AnthropicErrorResponse(BaseModel):
type: str = "error"
error: AnthropicErrorDetail

120
tests/test_openai_utils.py Normal file
View File

@@ -0,0 +1,120 @@
from unittest import TestCase
from app.api.openai_utils import (
build_anthropic_messages,
build_completion_payload,
build_prompt,
build_responses_input,
build_session_id,
extract_text_and_images,
)
class OpenAIUtilsTest(TestCase):
def test_extract_text_and_images(self):
text, images = extract_text_and_images(
[
{"type": "text", "text": "你好"},
{"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
{"type": "text", "text": "世界"},
]
)
self.assertEqual(text, "你好\n世界")
self.assertEqual(images, ["https://example.com/a.png"])
def test_extract_text_and_images_with_input_image_and_base64_image(self):
text, images = extract_text_and_images(
[
{"type": "input_text", "text": "看图"},
{"type": "input_image", "image_url": "https://example.com/b.png"},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "YWJj",
},
},
]
)
self.assertEqual(text, "看图")
self.assertEqual(
images,
["https://example.com/b.png", "data:image/png;base64,YWJj"],
)
def test_build_prompt_without_server_session_keeps_recent_history(self):
prompt, images = build_prompt(
[
{"role": "system", "content": "回答简短"},
{"role": "user", "content": "第一句"},
{"role": "assistant", "content": "第一答"},
{"role": "user", "content": "第二句"},
],
use_server_session=False,
)
self.assertIn("系统要求:\n回答简短", prompt)
self.assertIn("对话上下文:\nuser: 第一句\nassistant: 第一答", prompt)
self.assertIn("当前用户消息:\n第二句", prompt)
self.assertEqual(images, [])
def test_build_prompt_with_server_session_ignores_history_block(self):
prompt, _ = build_prompt(
[
{"role": "user", "content": "历史问题"},
{"role": "assistant", "content": "历史回答"},
{"role": "user", "content": "当前问题"},
],
use_server_session=True,
)
self.assertNotIn("对话上下文:", prompt)
self.assertIn("当前用户消息:\n当前问题", prompt)
def test_build_prompt_accepts_image_only_user_message(self):
prompt, images = build_prompt(
[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}
],
}
],
use_server_session=True,
)
self.assertIn("请结合图片内容回复", prompt)
self.assertEqual(images, ["https://example.com/a.png"])
def test_build_session_id_is_stable(self):
session_id = build_session_id("user-1", "openai:")
self.assertTrue(session_id.startswith("openai:"))
self.assertEqual(session_id, build_session_id("user-1", "openai:"))
self.assertNotEqual(session_id, build_session_id("user-2", "openai:"))
def test_build_completion_payload(self):
payload = build_completion_payload("你好", "moviepilot-agent")
self.assertEqual(payload["model"], "moviepilot-agent")
self.assertEqual(payload["choices"][0]["message"]["content"], "你好")
self.assertEqual(payload["choices"][0]["finish_reason"], "stop")
def test_build_responses_input(self):
messages = build_responses_input(
[
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "你好"}],
}
],
instructions="你要简短回答",
)
self.assertEqual(messages[0]["role"], "system")
self.assertEqual(messages[1]["role"], "user")
def test_build_anthropic_messages(self):
messages = build_anthropic_messages(
system=[{"type": "text", "text": "你是助手"}],
messages=[{"role": "user", "content": "你好"}],
)
self.assertEqual(messages[0]["role"], "system")
self.assertEqual(messages[1]["role"], "user")