mirror of https://github.com/jxxghp/MoviePilot.git

feat: add ai-compatible API endpoints
app/api/__init__.py

@@ -2,7 +2,7 @@ from fastapi import APIRouter
 from app.api.endpoints import login, user, webhook, message, site, subscribe, \
     media, douban, search, plugin, tmdb, history, system, download, dashboard, \
-    transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
+    transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic

 api_router = APIRouter()
 api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -30,3 +30,5 @@ api_router.include_router(recommend.router, prefix="/recommend", tags=["recommen
 api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
 api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
 api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
+api_router.include_router(openai.router, prefix="/openai/v1", tags=["openai"])
+api_router.include_router(anthropic.router, prefix="/anthropic/v1", tags=["anthropic"])
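The two new routers expose OpenAI- and Anthropic-compatible surfaces under the existing versioned API router. As a quick smoke test, the model listing can be queried with any HTTP client; a minimal sketch, assuming the router is mounted at /api/v1 on the default port and that the token matches settings.API_TOKEN (both assumptions, not shown in this diff):

import requests

BASE_URL = "http://localhost:3001/api/v1"   # assumed mount point and port
API_TOKEN = "your-moviepilot-api-token"     # assumed; must equal settings.API_TOKEN

# List the single advertised model (should return "moviepilot-agent")
resp = requests.get(
    f"{BASE_URL}/openai/v1/models",
    headers={"Authorization": f"Bearer {API_TOKEN}"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["data"][0]["id"])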
app/api/endpoints/anthropic.py (new file, 158 lines)

@@ -0,0 +1,158 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional

from fastapi import APIRouter, Header, Security
from fastapi.responses import JSONResponse, StreamingResponse

from app import schemas
from app.api.endpoints.openai import (
    MODEL_ID,
    _CollectingMoviePilotAgent,
    _error_response as _openai_error_response,
)
from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id
from app.core.config import settings
from app.core.security import anthropic_api_key_header
from app.schemas.types import MessageChannel

router = APIRouter()

SESSION_PREFIX = "anthropic:"


def _anthropic_error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
) -> JSONResponse:
    return JSONResponse(
        status_code=status_code,
        content=schemas.AnthropicErrorResponse(
            error=schemas.AnthropicErrorDetail(type=error_type, message=message)
        ).model_dump(),
    )


def _check_auth(api_key: Optional[str]) -> Optional[JSONResponse]:
    if not api_key or api_key != settings.API_TOKEN:
        return _anthropic_error_response(
            "invalid x-api-key",
            401,
            error_type="authentication_error",
        )
    return None


async def _stream_anthropic_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    event_queue: asyncio.Queue = asyncio.Queue()
    if hasattr(agent.stream_handler, "bind_queue"):
        agent.stream_handler.bind_queue(event_queue)

    message_id = f"msg_{uuid.uuid4().hex}"

    async def _run_agent():
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            await event_queue.put({"error": str(exc)})
        finally:
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())
    try:
        yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': MODEL_ID, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}}, ensure_ascii=False)}\n\n"
        yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}, ensure_ascii=False)}\n\n"
        while True:
            item = await event_queue.get()
            if item is None:
                break
            if isinstance(item, dict) and item.get("error"):
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}}, ensure_ascii=False)}\n\n"
        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0}, ensure_ascii=False)}\n\n"
        yield f"event: message_delta\ndata: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 0}}, ensure_ascii=False)}\n\n"
        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'}, ensure_ascii=False)}\n\n"
    finally:
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass


@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse)
async def messages(
    payload: schemas.AnthropicMessagesRequest,
    x_api_key: Optional[str] = Security(anthropic_api_key_header),
    anthropic_version: Optional[str] = Header(default=None, alias="anthropic-version"),
):
    auth_error = _check_auth(x_api_key)
    if auth_error:
        return auth_error

    if not settings.AI_AGENT_ENABLE:
        return _anthropic_error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="api_error",
        )

    normalized_messages = build_anthropic_messages(payload.system, payload.messages)
    try:
        prompt, images = build_prompt(normalized_messages, use_server_session=False)
    except ValueError as exc:
        return _anthropic_error_response(str(exc), 400)

    session_seed = anthropic_version or "anthropic"
    session_id = build_session_id(f"{session_seed}:{uuid.uuid4().hex}", SESSION_PREFIX)
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_id,
        channel=MessageChannel.Web.value,
        source="anthropic",
        username="anthropic-client",
        stream_mode=payload.stream,
    )

    if payload.stream:
        return StreamingResponse(
            _stream_anthropic_response(agent=agent, prompt=prompt, images=images),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _anthropic_error_response(str(exc), 500, error_type="api_error")

    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"  # fallback: "No valid reply was received."

    return schemas.AnthropicMessagesResponse(
        id=f"msg_{uuid.uuid4().hex}",
        content=[schemas.AnthropicTextBlock(text=content)],
        model=MODEL_ID,
    )
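For reference, a non-streaming request to this endpoint looks like the sketch below; the base URL and token are assumptions, and the x-api-key value must equal settings.API_TOKEN:

import requests

BASE_URL = "http://localhost:3001/api/v1"   # assumed mount point
resp = requests.post(
    f"{BASE_URL}/anthropic/v1/messages",
    headers={
        "x-api-key": "your-moviepilot-api-token",  # checked against settings.API_TOKEN
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    },
    json={
        "model": "moviepilot-agent",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "帮我找一部科幻电影"}],  # illustrative
    },
    timeout=120,
)
print(resp.json()["content"][0]["text"])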
app/api/endpoints/openai.py (new file, 426 lines)

@@ -0,0 +1,426 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple

from fastapi import APIRouter, Request, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials

from app import schemas
from app.api.openai_utils import (
    build_completion_payload,
    build_prompt,
    build_responses_input,
    build_session_id,
)
from app.agent import MoviePilotAgent, StreamingHandler
from app.core.config import settings
from app.core.security import openai_bearer_scheme
from app.schemas.types import MessageChannel

router = APIRouter()

MODEL_ID = "moviepilot-agent"
SESSION_PREFIX = "openai:"


class _CollectingMoviePilotAgent(MoviePilotAgent):
    """
    Capture the agent's final output so it is not sent a second time
    through the regular message channels.
    """

    def __init__(self, *args, stream_mode: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.collected_messages: List[str] = []
        self.stream_mode = stream_mode
        if stream_mode:
            self.stream_handler = _OpenAIStreamingHandler()

    def _should_stream(self) -> bool:
        return self.stream_mode

    async def send_agent_message(self, message: str, title: str = ""):
        text = (message or "").strip()
        if title and text:
            text = f"{title}\n{text}"
        elif title:
            text = title.strip()
        if text:
            self.collected_messages.append(text)
            if self.stream_mode:
                self.stream_handler.emit(text)

    async def _save_agent_message_to_db(self, message: str, title: str = ""):
        return None


class _OpenAIStreamingHandler(StreamingHandler):
    """
    Forward the agent's streaming output to the OpenAI SSE queue without
    writing messages to the in-app message system.
    """

    def __init__(self):
        super().__init__()
        self._event_queue: Optional[asyncio.Queue] = None

    def bind_queue(self, queue: asyncio.Queue):
        self._event_queue = queue

    def emit(self, token: str):
        super().emit(token)
        if token and self._event_queue is not None:
            self._event_queue.put_nowait(token)

    async def start_streaming(
        self,
        channel: Optional[str] = None,
        source: Optional[str] = None,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
        title: str = "",
    ):
        self._channel = channel
        self._source = source
        self._user_id = user_id
        self._username = username
        self._title = title
        self._streaming_enabled = True
        self._sent_text = ""
        self._message_response = None
        self._msg_start_offset = 0
        self._max_message_length = 0

    async def stop_streaming(self) -> Tuple[bool, str]:
        if not self._streaming_enabled:
            return False, ""
        self._streaming_enabled = False
        with self._lock:
            final_text = self._buffer
            self._buffer = ""
            self._sent_text = ""
            self._message_response = None
            self._msg_start_offset = 0
        return True, final_text


def _sse_payload(data: dict) -> str:
    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"


async def _stream_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    event_queue: asyncio.Queue = asyncio.Queue()
    if isinstance(agent.stream_handler, _OpenAIStreamingHandler):
        agent.stream_handler.bind_queue(event_queue)

    created = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    finished = False

    async def _run_agent():
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            await event_queue.put({"error": str(exc)})
        finally:
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())

    try:
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant"},
                        "finish_reason": None,
                    }
                ],
            }
        )

        while True:
            item = await event_queue.get()
            if item is None:
                break
            if isinstance(item, dict) and item.get("error"):
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield _sse_payload(
                {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "finish_reason": None,
                        }
                    ],
                }
            )

        finished = True
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {},
                        "finish_reason": "stop",
                    }
                ],
            }
        )
        yield "data: [DONE]\n\n"
    finally:
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        elif finished:
            await task


def _error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
    code: Optional[str] = None,
) -> JSONResponse:
    return JSONResponse(
        status_code=status_code,
        content=schemas.OpenAIErrorResponse(
            error=schemas.OpenAIErrorDetail(
                message=message,
                type=error_type,
                code=code,
            )
        ).model_dump(),
        headers={"WWW-Authenticate": "Bearer"},
    )


def _check_auth(
    credentials: Optional[HTTPAuthorizationCredentials],
) -> Optional[JSONResponse]:
    if not credentials or credentials.scheme.lower() != "bearer":
        return _error_response(
            "Invalid bearer token.",
            401,
            error_type="authentication_error",
            code="invalid_api_key",
        )
    if credentials.credentials != settings.API_TOKEN:
        return _error_response(
            "Invalid bearer token.",
            401,
            error_type="authentication_error",
            code="invalid_api_key",
        )
    return None


@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse)
async def list_models(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error
    now = int(time.time())
    return schemas.OpenAIModelListResponse(
        data=[schemas.OpenAIModelInfo(id=MODEL_ID, created=now)]
    )


@router.post(
    "/chat/completions",
    summary="OpenAI compatible chat completions",
    response_model=schemas.OpenAIChatCompletionResponse,
)
async def chat_completions(
    payload: schemas.OpenAIChatCompletionsRequest,
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error

    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )

    if not payload.messages:
        return _error_response(
            "`messages` must be a non-empty array.",
            400,
            code="invalid_messages",
        )

    session_key = (
        str(payload.user or "").strip()
        or str(request.headers.get("x-session-id") or "").strip()
        or str(uuid.uuid4())
    )
    use_server_session = bool(
        str(payload.user or "").strip()
        or str(request.headers.get("x-session-id") or "").strip()
    )

    try:
        prompt, images = build_prompt(payload.messages, use_server_session=use_server_session)
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_messages")

    session_id = build_session_id(session_key, SESSION_PREFIX)
    username = str(payload.user or "openai-client")
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai",
        username=username,
        stream_mode=payload.stream,
    )

    if payload.stream:
        return StreamingResponse(
            _stream_response(agent=agent, prompt=prompt, images=images),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )

    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"  # fallback: "No valid reply was received."

    return JSONResponse(content=build_completion_payload(content, MODEL_ID))


@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse)
async def responses(
    payload: schemas.OpenAIResponsesRequest,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error

    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )

    if payload.stream:
        return _error_response(
            "Streaming is not supported for /responses yet.",
            400,
            code="unsupported_stream",
        )

    normalized_messages = build_responses_input(payload.input, instructions=payload.instructions)
    if not normalized_messages:
        return _error_response(
            "`input` must include at least one usable message.",
            400,
            code="invalid_input",
        )

    try:
        prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user))
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_input")

    session_key = str(payload.user or uuid.uuid4())
    session_id = build_session_id(session_key, SESSION_PREFIX)
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai.responses",
        username=str(payload.user or "openai-client"),
        stream_mode=False,
    )

    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )

    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"  # fallback: "No valid reply was received."

    created_at = int(time.time())
    response_id = f"resp_{uuid.uuid4().hex}"
    output_message = schemas.OpenAIResponsesOutputMessage(
        id=f"msg_{uuid.uuid4().hex}",
        content=[schemas.OpenAIResponsesOutputText(text=content)],
    )
    return schemas.OpenAIResponsesResponse(
        id=response_id,
        created_at=created_at,
        model=MODEL_ID,
        output=[output_message],
        usage=schemas.OpenAIUsage(),
    )
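A streaming chat completion can be consumed by reading the SSE lines directly; a minimal sketch, assuming the same base URL and token as above:

import json
import requests

BASE_URL = "http://localhost:3001/api/v1"   # assumed mount point
resp = requests.post(
    f"{BASE_URL}/openai/v1/chat/completions",
    headers={"Authorization": "Bearer your-moviepilot-api-token"},  # assumed token
    json={
        "model": "moviepilot-agent",
        "messages": [{"role": "user", "content": "推荐三部高分电影"}],  # illustrative
        "stream": True,
    },
    stream=True,
    timeout=300,
)
for line in resp.iter_lines(decode_unicode=True):
    # Each SSE event is a "data: {...}" line; the stream ends with "data: [DONE]"
    if not line or not line.startswith("data: "):
        continue
    data = line[len("data: "):]
    if data == "[DONE]":
        break
    chunk = json.loads(data)
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)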
app/api/openai_utils.py (new file, 177 lines)

@@ -0,0 +1,177 @@
import hashlib
import time
import uuid
from typing import Any, Dict, List, Tuple


def _get_message_field(message: Any, field: str, default: Any = None) -> Any:
    if isinstance(message, dict):
        return message.get(field, default)
    return getattr(message, field, default)


def extract_text_and_images(content: Any) -> Tuple[str, List[str]]:
    if content is None:
        return "", []
    if isinstance(content, str):
        return content.strip(), []

    text_parts: List[str] = []
    image_urls: List[str] = []
    if isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                normalized = item.strip()
                if normalized:
                    text_parts.append(normalized)
                continue
            if not isinstance(item, dict):
                continue
            item_type = (item.get("type") or "").lower()
            if item_type == "text":
                text = item.get("text")
                if text and str(text).strip():
                    text_parts.append(str(text).strip())
            elif item_type == "input_text":
                text = item.get("text")
                if text and str(text).strip():
                    text_parts.append(str(text).strip())
            elif item_type == "image_url":
                image_url = item.get("image_url")
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                if url and str(url).strip():
                    image_urls.append(str(url).strip())
            elif item_type == "input_image":
                url = item.get("image_url")
                if url and str(url).strip():
                    image_urls.append(str(url).strip())
            elif item_type == "image":
                source = item.get("source") or {}
                if isinstance(source, dict) and source.get("type") == "base64":
                    data = source.get("data")
                    media_type = source.get("media_type") or "image/png"
                    if data and str(data).strip():
                        image_urls.append(f"data:{media_type};base64,{str(data).strip()}")
    return "\n".join(text_parts).strip(), image_urls


def build_prompt(messages: List[Any], use_server_session: bool) -> Tuple[str, List[str]]:
    system_texts: List[str] = []
    transcript: List[str] = []
    latest_user_text = ""
    latest_user_images: List[str] = []

    for message in messages:
        role = str(_get_message_field(message, "role", "user") or "user").lower()
        if role == "developer":
            role = "system"
        text, images = extract_text_and_images(_get_message_field(message, "content"))
        if role == "system":
            if text:
                system_texts.append(text)
            continue
        if role == "user":
            if text or images:
                latest_user_text = text
                latest_user_images = images
            if text:
                transcript.append(f"user: {text}")
            continue
        if text:
            transcript.append(f"{role}: {text}")

    if not latest_user_text and not latest_user_images:
        raise ValueError("No usable user message found in messages.")

    prompt_parts: List[str] = []
    if system_texts:
        prompt_parts.append("系统要求:\n" + "\n\n".join(system_texts))  # "System requirements:"

    if not use_server_session and transcript:
        history = transcript[:-1] if transcript[-1].startswith("user: ") else transcript
        if history:
            prompt_parts.append("对话上下文:\n" + "\n".join(history[-10:]))  # "Conversation context:"

    if latest_user_text:
        prompt_parts.append("当前用户消息:\n" + latest_user_text)  # "Current user message:"
    else:
        prompt_parts.append("当前用户消息:\n请结合图片内容回复。")  # "Please reply based on the image content."

    return "\n\n".join(part for part in prompt_parts if part).strip(), latest_user_images


def build_session_id(session_key: str, prefix: str) -> str:
    digest = hashlib.sha256(session_key.encode("utf-8")).hexdigest()
    return f"{prefix}{digest[:32]}"


def build_completion_payload(content: str, model_id: str) -> Dict[str, Any]:
    created = int(time.time())
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": created,
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }


def build_responses_input(
    input_data: Any, instructions: str | None = None
) -> List[Dict[str, Any]]:
    messages: List[Dict[str, Any]] = []
    if instructions and str(instructions).strip():
        messages.append({"role": "system", "content": str(instructions).strip()})

    if isinstance(input_data, str):
        normalized = input_data.strip()
        if normalized:
            messages.append({"role": "user", "content": normalized})
        return messages

    if isinstance(input_data, list):
        for item in input_data:
            if not isinstance(item, dict):
                continue
            item_type = (item.get("type") or "").lower()
            if item_type == "message":
                role = item.get("role") or "user"
                content = item.get("content")
                messages.append({"role": role, "content": content})
            elif item.get("role") and "content" in item:
                messages.append({"role": item.get("role"), "content": item.get("content")})
        return messages

    if isinstance(input_data, dict) and input_data.get("role") and "content" in input_data:
        messages.append({"role": input_data.get("role"), "content": input_data.get("content")})

    return messages


def build_anthropic_messages(
    system: Any, messages: List[Any]
) -> List[Dict[str, Any]]:
    normalized: List[Dict[str, Any]] = []
    system_text, _ = extract_text_and_images(system)
    if system_text:
        normalized.append({"role": "system", "content": system_text})

    for message in messages:
        role = _get_message_field(message, "role", "user")
        content = _get_message_field(message, "content")
        normalized.append({"role": role, "content": content})
    return normalized
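The helpers above flatten an OpenAI-style message list into a single agent prompt; for illustration (this mirrors the behavior exercised by the tests further down):

from app.api.openai_utils import build_prompt, build_session_id

prompt, images = build_prompt(
    [
        {"role": "system", "content": "回答简短"},
        {"role": "user", "content": "第一句"},
        {"role": "assistant", "content": "第一答"},
        {"role": "user", "content": "第二句"},
    ],
    use_server_session=False,
)
# prompt now holds the system block, up to the last 10 transcript lines
# (excluding the final user turn), and the current user message.
print(build_session_id("user-1", "openai:"))  # stable, prefix + 32 hex chars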
app/core/security.py

@@ -13,7 +13,7 @@ from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad
 from cryptography.fernet import Fernet
 from fastapi import HTTPException, status, Security, Request, Response
-from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie
+from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie, HTTPBearer
 from passlib.context import CryptContext

 from app import schemas
@@ -42,6 +42,12 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
 # API KEY authentication via query parameter
 api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")

+# OpenAI-compatible Bearer token authentication
+openai_bearer_scheme = HTTPBearer(auto_error=False)
+
+# Anthropic-compatible API key authentication
+anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False, scheme_name="anthropic_api_key_header")
+

 def __get_api_token(
         token_query: Annotated[str | None, Security(api_token_query)] = None
app/schemas/__init__.py

@@ -11,6 +11,7 @@ from .monitoring import *
 from .plugin import *
 from .response import *
 from .rule import *
+from .openai import *
 from .servarr import *
 from .servcookie import *
 from .site import *
@@ -23,4 +24,3 @@ from .transfer import *
 from .user import *
 from .workflow import *
 from .mcp import *
-
app/schemas/openai.py (new file, 156 lines)

@@ -0,0 +1,156 @@
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field


class OpenAIModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "moviepilot"


class OpenAIModelListResponse(BaseModel):
    object: str = "list"
    data: List[OpenAIModelInfo] = Field(default_factory=list)


class OpenAIChatMessage(BaseModel):
    role: str
    content: Any
    name: Optional[str] = None

    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionsRequest(BaseModel):
    model: Optional[str] = None
    messages: List[OpenAIChatMessage]
    user: Optional[str] = None
    stream: bool = False

    model_config = ConfigDict(extra="allow")


class OpenAIResponsesRequest(BaseModel):
    model: Optional[str] = None
    input: Any
    instructions: Optional[str] = None
    user: Optional[str] = None
    stream: bool = False

    model_config = ConfigDict(extra="allow")


class OpenAIChatChoiceMessage(BaseModel):
    role: str = "assistant"
    content: str


class OpenAIChatChoice(BaseModel):
    index: int = 0
    message: OpenAIChatChoiceMessage
    finish_reason: str = "stop"


class OpenAIUsage(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class OpenAIChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[OpenAIChatChoice]
    usage: OpenAIUsage


class OpenAIResponsesOutputText(BaseModel):
    type: str = "output_text"
    text: str
    annotations: List[Dict[str, Any]] = Field(default_factory=list)


class OpenAIResponsesOutputMessage(BaseModel):
    id: str
    type: str = "message"
    status: str = "completed"
    role: str = "assistant"
    content: List[OpenAIResponsesOutputText] = Field(default_factory=list)


class OpenAIResponsesResponse(BaseModel):
    id: str
    object: str = "response"
    created_at: int
    status: str = "completed"
    model: str
    output: List[OpenAIResponsesOutputMessage] = Field(default_factory=list)
    error: Optional[Any] = None
    incomplete_details: Optional[Any] = None
    usage: OpenAIUsage


class OpenAIErrorDetail(BaseModel):
    message: str
    type: str = "invalid_request_error"
    param: Optional[str] = None
    code: Optional[str] = None


class OpenAIErrorResponse(BaseModel):
    error: OpenAIErrorDetail


OpenAIChatContentPart = Dict[str, Any]


class AnthropicMessage(BaseModel):
    role: str
    content: Any

    model_config = ConfigDict(extra="allow")


class AnthropicMessagesRequest(BaseModel):
    model: Optional[str] = None
    messages: List[AnthropicMessage]
    system: Optional[Any] = None
    max_tokens: Optional[int] = 1024
    stream: bool = False

    model_config = ConfigDict(extra="allow")


class AnthropicTextBlock(BaseModel):
    type: str = "text"
    text: str


class AnthropicUsage(BaseModel):
    input_tokens: int = 0
    output_tokens: int = 0


class AnthropicMessagesResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[AnthropicTextBlock] = Field(default_factory=list)
    model: str
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None
    usage: AnthropicUsage = Field(default_factory=AnthropicUsage)


class AnthropicErrorDetail(BaseModel):
    type: str = "invalid_request_error"
    message: str


class AnthropicErrorResponse(BaseModel):
    type: str = "error"
    error: AnthropicErrorDetail
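These pydantic v2 models serialize directly into both wire formats; for illustration (values below are made up):

from app import schemas

resp = schemas.AnthropicMessagesResponse(
    id="msg_0123456789abcdef",  # illustrative id
    content=[schemas.AnthropicTextBlock(text="你好")],
    model="moviepilot-agent",
)
# model_dump() yields the JSON structure an Anthropic client expects,
# including the defaulted type/role/stop_reason/usage fields.
print(resp.model_dump())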
tests/test_openai_utils.py (new file, 120 lines)

@@ -0,0 +1,120 @@
from unittest import TestCase

from app.api.openai_utils import (
    build_anthropic_messages,
    build_completion_payload,
    build_prompt,
    build_responses_input,
    build_session_id,
    extract_text_and_images,
)


class OpenAIUtilsTest(TestCase):
    def test_extract_text_and_images(self):
        text, images = extract_text_and_images(
            [
                {"type": "text", "text": "你好"},
                {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
                {"type": "text", "text": "世界"},
            ]
        )
        self.assertEqual(text, "你好\n世界")
        self.assertEqual(images, ["https://example.com/a.png"])

    def test_extract_text_and_images_with_input_image_and_base64_image(self):
        text, images = extract_text_and_images(
            [
                {"type": "input_text", "text": "看图"},
                {"type": "input_image", "image_url": "https://example.com/b.png"},
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": "YWJj",
                    },
                },
            ]
        )
        self.assertEqual(text, "看图")
        self.assertEqual(
            images,
            ["https://example.com/b.png", "data:image/png;base64,YWJj"],
        )

    def test_build_prompt_without_server_session_keeps_recent_history(self):
        prompt, images = build_prompt(
            [
                {"role": "system", "content": "回答简短"},
                {"role": "user", "content": "第一句"},
                {"role": "assistant", "content": "第一答"},
                {"role": "user", "content": "第二句"},
            ],
            use_server_session=False,
        )
        self.assertIn("系统要求:\n回答简短", prompt)
        self.assertIn("对话上下文:\nuser: 第一句\nassistant: 第一答", prompt)
        self.assertIn("当前用户消息:\n第二句", prompt)
        self.assertEqual(images, [])

    def test_build_prompt_with_server_session_ignores_history_block(self):
        prompt, _ = build_prompt(
            [
                {"role": "user", "content": "历史问题"},
                {"role": "assistant", "content": "历史回答"},
                {"role": "user", "content": "当前问题"},
            ],
            use_server_session=True,
        )
        self.assertNotIn("对话上下文:", prompt)
        self.assertIn("当前用户消息:\n当前问题", prompt)

    def test_build_prompt_accepts_image_only_user_message(self):
        prompt, images = build_prompt(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}
                    ],
                }
            ],
            use_server_session=True,
        )
        self.assertIn("请结合图片内容回复", prompt)
        self.assertEqual(images, ["https://example.com/a.png"])

    def test_build_session_id_is_stable(self):
        session_id = build_session_id("user-1", "openai:")
        self.assertTrue(session_id.startswith("openai:"))
        self.assertEqual(session_id, build_session_id("user-1", "openai:"))
        self.assertNotEqual(session_id, build_session_id("user-2", "openai:"))

    def test_build_completion_payload(self):
        payload = build_completion_payload("你好", "moviepilot-agent")
        self.assertEqual(payload["model"], "moviepilot-agent")
        self.assertEqual(payload["choices"][0]["message"]["content"], "你好")
        self.assertEqual(payload["choices"][0]["finish_reason"], "stop")

    def test_build_responses_input(self):
        messages = build_responses_input(
            [
                {
                    "type": "message",
                    "role": "user",
                    "content": [{"type": "input_text", "text": "你好"}],
                }
            ],
            instructions="你要简短回答",
        )
        self.assertEqual(messages[0]["role"], "system")
        self.assertEqual(messages[1]["role"], "user")

    def test_build_anthropic_messages(self):
        messages = build_anthropic_messages(
            system=[{"type": "text", "text": "你是助手"}],
            messages=[{"role": "user", "content": "你好"}],
        )
        self.assertEqual(messages[0]["role"], "system")
        self.assertEqual(messages[1]["role"], "user")
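The suite uses only the standard unittest runner, so it can be executed from the repo root (assuming the app package is importable there):

    python -m unittest tests.test_openai_utils -v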