feat: add OpenAI protocol support and enhance AI provider configuration

This commit is contained in:
shiyu
2026-06-20 20:16:32 +08:00
parent 64fe02c23a
commit c8b43dbf4d
8 changed files with 583 additions and 14 deletions

View File

@@ -3,6 +3,19 @@ from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field, field_validator
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
OPENAI_PROTOCOL_CHAT_COMPLETIONS = "chat_completions"
OPENAI_PROTOCOL_RESPONSES = "responses"
OPENAI_PROTOCOLS = {OPENAI_PROTOCOL_CHAT_COMPLETIONS, OPENAI_PROTOCOL_RESPONSES}
OPENAI_PROTOCOL_ALIASES = {
"chat": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
"chat_completion": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
"chat_completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
"chat/completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
"/chat/completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
"response": OPENAI_PROTOCOL_RESPONSES,
"responses": OPENAI_PROTOCOL_RESPONSES,
"/responses": OPENAI_PROTOCOL_RESPONSES,
}
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
@@ -16,6 +29,34 @@ def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
return normalized
def normalize_openai_protocol(value: Any) -> str:
if value is None:
return OPENAI_PROTOCOL_CHAT_COMPLETIONS
key = str(value).strip().lower().replace("-", "_").replace(".", "_")
if not key:
return OPENAI_PROTOCOL_CHAT_COMPLETIONS
normalized = OPENAI_PROTOCOL_ALIASES.get(key)
if normalized:
return normalized
normalized = OPENAI_PROTOCOL_ALIASES.get(key.replace("_", "/"))
if normalized:
return normalized
if key in OPENAI_PROTOCOLS:
return key
raise ValueError("openai_protocol must be 'chat_completions' or 'responses'")
def normalize_provider_extra_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
if not isinstance(config, dict):
raise ValueError("extra_config must be an object")
normalized = dict(config)
if "openai_protocol" in normalized:
normalized["openai_protocol"] = normalize_openai_protocol(normalized.get("openai_protocol"))
return normalized
class AIProviderBase(BaseModel):
name: str
identifier: str = Field(..., pattern=r"^[a-z0-9_\-\.]+$")
@@ -34,6 +75,11 @@ class AIProviderBase(BaseModel):
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
return fmt
@field_validator("extra_config")
@classmethod
def normalize_extra_config(cls, value: Optional[dict]) -> Optional[dict]:
return normalize_provider_extra_config(value)
class AIProviderCreate(AIProviderBase):
pass
@@ -58,6 +104,11 @@ class AIProviderUpdate(BaseModel):
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
return fmt
@field_validator("extra_config")
@classmethod
def normalize_extra_config(cls, value: Optional[dict]) -> Optional[dict]:
return normalize_provider_extra_config(value)
class AIModelBase(BaseModel):
name: str