diff --git a/domain/ai/__init__.py b/domain/ai/__init__.py index dffdb77..5f3961e 100644 --- a/domain/ai/__init__.py +++ b/domain/ai/__init__.py @@ -22,8 +22,11 @@ from .types import ( AIModelUpdate, AIProviderCreate, AIProviderUpdate, + OPENAI_PROTOCOL_CHAT_COMPLETIONS, + OPENAI_PROTOCOL_RESPONSES, VectorDBConfigPayload, normalize_capabilities, + normalize_openai_protocol, ) from .vector_providers import ( BaseVectorProvider, @@ -58,6 +61,9 @@ __all__ = [ "get_provider_class", "ABILITIES", "normalize_capabilities", + "normalize_openai_protocol", + "OPENAI_PROTOCOL_CHAT_COMPLETIONS", + "OPENAI_PROTOCOL_RESPONSES", "AIDefaultsUpdate", "AIModelCreate", "AIModelUpdate", diff --git a/domain/ai/api.py b/domain/ai/api.py index bbd8af7..4d0674d 100644 --- a/domain/ai/api.py +++ b/domain/ai/api.py @@ -34,7 +34,7 @@ async def list_providers_endpoint( @audit( action=AuditAction.CREATE, description="创建 AI 提供商", - body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"], + body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url", "extra_config"], redact_fields=["api_key"], ) @router_ai.post("/providers") @@ -61,7 +61,7 @@ async def get_provider( @audit( action=AuditAction.UPDATE, description="更新 AI 提供商", - body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"], + body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key", "extra_config"], redact_fields=["api_key"], ) @router_ai.put("/providers/{provider_id}") diff --git a/domain/ai/inference.py b/domain/ai/inference.py index 2f11a50..2d30d68 100644 --- a/domain/ai/inference.py +++ b/domain/ai/inference.py @@ -6,6 +6,7 @@ from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from models.database import AIModel, AIProvider from .service import AIProviderService +from .types import OPENAI_PROTOCOL_RESPONSES, normalize_openai_protocol provider_service = AIProviderService @@ -227,6 +228,248 @@ def _is_azure_openai(provider: AIProvider) -> bool: return ".openai.azure.com" in base_url +def _openai_protocol(provider: AIProvider) -> str: + extra = provider.extra_config if isinstance(provider.extra_config, dict) else {} + return normalize_openai_protocol(extra.get("openai_protocol")) + + +def _content_to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if not isinstance(part, dict): + continue + text = part.get("text") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + try: + return json.dumps(content, ensure_ascii=False) + except TypeError: + return str(content) + + +def _openai_content_to_responses_input(content: Any) -> Any: + if content is None: + return "" + if isinstance(content, str): + return content + if not isinstance(content, list): + return _content_to_text(content) + + blocks: List[Dict[str, Any]] = [] + for part in content: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type in {"text", "input_text"} and isinstance(part.get("text"), str): + blocks.append({"type": "input_text", "text": part["text"]}) + continue + if part_type == "image_url": + image_url = part.get("image_url") + url = image_url.get("url") if isinstance(image_url, dict) else image_url + if not isinstance(url, str) or not url.strip(): + continue + block: Dict[str, Any] = {"type": "input_image", "image_url": url} + detail = image_url.get("detail") if isinstance(image_url, dict) else None + if isinstance(detail, str) and detail.strip(): + block["detail"] = detail + blocks.append(block) + continue + if part_type == "input_image" and isinstance(part.get("image_url"), str): + block = {"type": "input_image", "image_url": part["image_url"]} + detail = part.get("detail") + if isinstance(detail, str) and detail.strip(): + block["detail"] = detail + blocks.append(block) + continue + if part_type == "input_file" and isinstance(part.get("file_id") or part.get("filename"), str): + blocks.append(dict(part)) + + return blocks or "" + + +def _openai_tool_call_to_responses_item(call: Dict[str, Any], idx: int) -> Dict[str, Any] | None: + fn = call.get("function") + fn = fn if isinstance(fn, dict) else {} + name = fn.get("name") + if not isinstance(name, str) or not name.strip(): + return None + + raw_args = fn.get("arguments") + if isinstance(raw_args, dict): + args_text = json.dumps(raw_args, ensure_ascii=False) + elif isinstance(raw_args, str): + args_text = raw_args + else: + args_text = "" + + return { + "type": "function_call", + "call_id": str(call.get("id") or f"call_{idx}"), + "name": name, + "arguments": args_text, + } + + +def _openai_messages_to_responses_input(messages: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]]]: + instructions: List[str] = [] + input_items: List[Dict[str, Any]] = [] + + for msg in messages: + if not isinstance(msg, dict): + continue + role = msg.get("role") + if role in {"system", "developer"}: + text = _content_to_text(msg.get("content")).strip() + if text: + instructions.append(text) + continue + + if role == "tool": + content = msg.get("content") + output = content if isinstance(content, str) else _content_to_text(content) + call_id = str(msg.get("tool_call_id") or "").strip() + if call_id: + input_items.append({ + "type": "function_call_output", + "call_id": call_id, + "output": output, + }) + elif output: + input_items.append({"role": "user", "content": output}) + continue + + if role == "user": + input_items.append({ + "role": "user", + "content": _openai_content_to_responses_input(msg.get("content")), + }) + continue + + if role == "assistant": + text = _content_to_text(msg.get("content")) + if text: + input_items.append({"role": "assistant", "content": text}) + tool_calls = msg.get("tool_calls") + if isinstance(tool_calls, list): + for idx, call in enumerate(tool_calls): + if not isinstance(call, dict): + continue + item = _openai_tool_call_to_responses_item(call, idx) + if item: + input_items.append(item) + + if not input_items: + input_items.append({"role": "user", "content": ""}) + + return "\n\n".join(instructions).strip(), input_items + + +def _openai_tools_to_responses(tools: List[Dict[str, Any]] | None) -> List[Dict[str, Any]] | None: + if not tools: + return None + out: List[Dict[str, Any]] = [] + for tool in tools: + if not isinstance(tool, dict) or tool.get("type") != "function": + continue + fn = tool.get("function") + fn = fn if isinstance(fn, dict) else {} + name = fn.get("name") + if not isinstance(name, str) or not name.strip(): + continue + entry: Dict[str, Any] = { + "type": "function", + "name": name, + "description": fn.get("description") if isinstance(fn.get("description"), str) else "", + "parameters": fn.get("parameters") if isinstance(fn.get("parameters"), dict) else {}, + } + if isinstance(fn.get("strict"), bool): + entry["strict"] = fn["strict"] + out.append(entry) + return out or None + + +def _openai_tool_choice_to_responses(tool_choice: Any) -> Any | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + return tool_choice + if not isinstance(tool_choice, dict): + return None + if tool_choice.get("type") == "function": + fn = tool_choice.get("function") + name = fn.get("name") if isinstance(fn, dict) else tool_choice.get("name") + if isinstance(name, str) and name.strip(): + return {"type": "function", "name": name} + return None + + +def _responses_function_call_to_openai(item: Dict[str, Any], idx: int) -> Dict[str, Any] | None: + name = item.get("name") + if not isinstance(name, str) or not name.strip(): + return None + raw_args = item.get("arguments") + if isinstance(raw_args, dict): + args_text = json.dumps(raw_args, ensure_ascii=False) + elif isinstance(raw_args, str): + args_text = raw_args + else: + args_text = "" + return { + "id": str(item.get("call_id") or item.get("id") or f"call_{idx}"), + "type": "function", + "function": {"name": name, "arguments": args_text}, + } + + +def _responses_content_to_text(content: Any) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + parts: List[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") in {"output_text", "text"} and isinstance(block.get("text"), str): + parts.append(block["text"]) + return "".join(parts) + + +def _responses_body_to_openai_message(body: Dict[str, Any]) -> Dict[str, Any]: + text_parts: List[str] = [] + tool_calls: List[Dict[str, Any]] = [] + output = body.get("output") + + if isinstance(output, list): + for idx, item in enumerate(output): + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type == "message": + text = _responses_content_to_text(item.get("content")) + if text: + text_parts.append(text) + continue + if item_type == "function_call": + tool_call = _responses_function_call_to_openai(item, idx) + if tool_call: + tool_calls.append(tool_call) + + if not text_parts and isinstance(body.get("output_text"), str): + text_parts.append(body["output_text"]) + + message: Dict[str, Any] = {"role": "assistant", "content": "".join(text_parts)} + if tool_calls: + message["tool_calls"] = tool_calls + return message + + def _gemini_endpoint(provider: AIProvider, path: str) -> str: base = (provider.base_url or "").rstrip("/") if not base: @@ -808,6 +1051,12 @@ async def _chat_stream_with_ollama( async def _describe_with_openai(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str: + if _openai_protocol(provider) == OPENAI_PROTOCOL_RESPONSES: + return await _describe_with_openai_responses(provider, model, base64_image, detail) + return await _describe_with_openai_chat_completions(provider, model, base64_image, detail) + + +async def _describe_with_openai_chat_completions(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str: url = _openai_endpoint(provider, "/chat/completions") payload = { "model": model.name, @@ -834,6 +1083,32 @@ async def _describe_with_openai(provider: AIProvider, model: AIModel, base64_ima return body["choices"][0]["message"]["content"] +async def _describe_with_openai_responses(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str: + url = _openai_endpoint(provider, "/responses") + payload = { + "model": model.name, + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_image}", + "detail": detail, + }, + {"type": "input_text", "text": "描述这个图片"}, + ], + } + ], + } + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post(url, headers=_openai_headers(provider), json=payload) + response.raise_for_status() + body = response.json() + message = _responses_body_to_openai_message(body if isinstance(body, dict) else {}) + return str(message.get("content") or "") + + async def _describe_with_anthropic(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str: url = _anthropic_endpoint(provider, "/messages") detail_text = f"描述这个图片,细节等级:{detail}" @@ -1080,6 +1355,37 @@ async def _chat_with_openai( tool_choice: Any | None, temperature: float | None, timeout: float, +) -> Dict[str, Any]: + if _openai_protocol(provider) == OPENAI_PROTOCOL_RESPONSES: + return await _chat_with_openai_responses( + provider, + model, + messages, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + timeout=timeout, + ) + return await _chat_with_openai_chat_completions( + provider, + model, + messages, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + timeout=timeout, + ) + + +async def _chat_with_openai_chat_completions( + provider: AIProvider, + model: AIModel, + messages: List[Dict[str, Any]], + *, + tools: List[Dict[str, Any]] | None, + tool_choice: Any | None, + temperature: float | None, + timeout: float, ) -> Dict[str, Any]: url = _openai_endpoint(provider, "/chat/completions") payload: Dict[str, Any] = { @@ -1106,6 +1412,41 @@ async def _chat_with_openai( return message +async def _chat_with_openai_responses( + provider: AIProvider, + model: AIModel, + messages: List[Dict[str, Any]], + *, + tools: List[Dict[str, Any]] | None, + tool_choice: Any | None, + temperature: float | None, + timeout: float, +) -> Dict[str, Any]: + url = _openai_endpoint(provider, "/responses") + instructions, input_items = _openai_messages_to_responses_input(messages) + payload: Dict[str, Any] = { + "model": model.name, + "input": input_items, + } + if instructions: + payload["instructions"] = instructions + response_tools = _openai_tools_to_responses(tools) + if response_tools: + payload["tools"] = response_tools + payload["tool_choice"] = _openai_tool_choice_to_responses(tool_choice) or "auto" + if temperature is not None: + payload["temperature"] = float(temperature) + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post(url, headers=_openai_headers(provider), json=payload) + response.raise_for_status() + body = response.json() + + if not isinstance(body, dict): + raise RuntimeError("Responses 接口返回格式异常") + return _responses_body_to_openai_message(body) + + async def chat_completion_stream( messages: List[Dict[str, Any]], *, @@ -1174,6 +1515,40 @@ async def _chat_stream_with_openai( tool_choice: Any | None, temperature: float | None, timeout: float, +) -> AsyncIterator[Dict[str, Any]]: + if _openai_protocol(provider) == OPENAI_PROTOCOL_RESPONSES: + async for event in _chat_stream_with_openai_responses( + provider, + model, + messages, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + timeout=timeout, + ): + yield event + return + async for event in _chat_stream_with_openai_chat_completions( + provider, + model, + messages, + tools=tools, + tool_choice=tool_choice, + temperature=temperature, + timeout=timeout, + ): + yield event + + +async def _chat_stream_with_openai_chat_completions( + provider: AIProvider, + model: AIModel, + messages: List[Dict[str, Any]], + *, + tools: List[Dict[str, Any]] | None, + tool_choice: Any | None, + temperature: float | None, + timeout: float, ) -> AsyncIterator[Dict[str, Any]]: url = _openai_endpoint(provider, "/chat/completions") payload: Dict[str, Any] = { @@ -1273,3 +1648,100 @@ async def _chat_stream_with_openai( message["tool_calls"] = tool_calls yield {"type": "message", "message": message, "finish_reason": finish_reason} + + +async def _chat_stream_with_openai_responses( + provider: AIProvider, + model: AIModel, + messages: List[Dict[str, Any]], + *, + tools: List[Dict[str, Any]] | None, + tool_choice: Any | None, + temperature: float | None, + timeout: float, +) -> AsyncIterator[Dict[str, Any]]: + url = _openai_endpoint(provider, "/responses") + instructions, input_items = _openai_messages_to_responses_input(messages) + payload: Dict[str, Any] = { + "model": model.name, + "input": input_items, + "stream": True, + } + if instructions: + payload["instructions"] = instructions + response_tools = _openai_tools_to_responses(tools) + if response_tools: + payload["tools"] = response_tools + payload["tool_choice"] = _openai_tool_choice_to_responses(tool_choice) or "auto" + if temperature is not None: + payload["temperature"] = float(temperature) + + content_parts: List[str] = [] + output_items: Dict[int, Dict[str, Any]] = {} + completed_body: Dict[str, Any] | None = None + current_event: str | None = None + + async with httpx.AsyncClient(timeout=timeout) as client: + async with client.stream("POST", url, headers=_openai_headers(provider), json=payload) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line: + continue + if line.startswith("event:"): + current_event = line[6:].strip() + continue + if not line.startswith("data:"): + continue + data = line[5:].strip() + if not data or data == "[DONE]": + continue + try: + chunk = json.loads(data) + except json.JSONDecodeError: + continue + if not isinstance(chunk, dict): + continue + + event_type = chunk.get("type") or current_event + if event_type == "response.output_text.delta": + delta = chunk.get("delta") + if isinstance(delta, str) and delta: + content_parts.append(delta) + yield {"type": "delta", "delta": delta} + continue + + if event_type in {"response.output_item.added", "response.output_item.done"}: + item = chunk.get("item") + idx = chunk.get("output_index") + if isinstance(item, dict) and isinstance(idx, int): + output_items[idx] = item + continue + + if event_type == "response.completed": + response_body = chunk.get("response") + if isinstance(response_body, dict): + completed_body = response_body + elif isinstance(chunk.get("output"), list): + completed_body = chunk + break + + if event_type == "error": + error = chunk.get("error") if isinstance(chunk.get("error"), dict) else chunk + raise RuntimeError(str(error.get("message") if isinstance(error, dict) else error)) + + if completed_body is not None: + message = _responses_body_to_openai_message(completed_body) + if not message.get("content") and content_parts: + message["content"] = "".join(content_parts) + yield {"type": "message", "message": message, "finish_reason": None} + return + + if output_items: + body = {"output": [output_items[idx] for idx in sorted(output_items.keys())]} + message = _responses_body_to_openai_message(body) + if not message.get("content") and content_parts: + message["content"] = "".join(content_parts) + yield {"type": "message", "message": message, "finish_reason": None} + return + + yield {"type": "message", "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": None} diff --git a/domain/ai/types.py b/domain/ai/types.py index 71be85a..ec8f53a 100644 --- a/domain/ai/types.py +++ b/domain/ai/types.py @@ -3,6 +3,19 @@ from typing import Any, Dict, Iterable, List, Optional from pydantic import BaseModel, Field, field_validator ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"] +OPENAI_PROTOCOL_CHAT_COMPLETIONS = "chat_completions" +OPENAI_PROTOCOL_RESPONSES = "responses" +OPENAI_PROTOCOLS = {OPENAI_PROTOCOL_CHAT_COMPLETIONS, OPENAI_PROTOCOL_RESPONSES} +OPENAI_PROTOCOL_ALIASES = { + "chat": OPENAI_PROTOCOL_CHAT_COMPLETIONS, + "chat_completion": OPENAI_PROTOCOL_CHAT_COMPLETIONS, + "chat_completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS, + "chat/completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS, + "/chat/completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS, + "response": OPENAI_PROTOCOL_RESPONSES, + "responses": OPENAI_PROTOCOL_RESPONSES, + "/responses": OPENAI_PROTOCOL_RESPONSES, +} def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]: @@ -16,6 +29,34 @@ def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]: return normalized +def normalize_openai_protocol(value: Any) -> str: + if value is None: + return OPENAI_PROTOCOL_CHAT_COMPLETIONS + key = str(value).strip().lower().replace("-", "_").replace(".", "_") + if not key: + return OPENAI_PROTOCOL_CHAT_COMPLETIONS + normalized = OPENAI_PROTOCOL_ALIASES.get(key) + if normalized: + return normalized + normalized = OPENAI_PROTOCOL_ALIASES.get(key.replace("_", "/")) + if normalized: + return normalized + if key in OPENAI_PROTOCOLS: + return key + raise ValueError("openai_protocol must be 'chat_completions' or 'responses'") + + +def normalize_provider_extra_config(config: Optional[dict]) -> Optional[dict]: + if config is None: + return None + if not isinstance(config, dict): + raise ValueError("extra_config must be an object") + normalized = dict(config) + if "openai_protocol" in normalized: + normalized["openai_protocol"] = normalize_openai_protocol(normalized.get("openai_protocol")) + return normalized + + class AIProviderBase(BaseModel): name: str identifier: str = Field(..., pattern=r"^[a-z0-9_\-\.]+$") @@ -34,6 +75,11 @@ class AIProviderBase(BaseModel): raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'") return fmt + @field_validator("extra_config") + @classmethod + def normalize_extra_config(cls, value: Optional[dict]) -> Optional[dict]: + return normalize_provider_extra_config(value) + class AIProviderCreate(AIProviderBase): pass @@ -58,6 +104,11 @@ class AIProviderUpdate(BaseModel): raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'") return fmt + @field_validator("extra_config") + @classmethod + def normalize_extra_config(cls, value: Optional[dict]) -> Optional[dict]: + return normalize_provider_extra_config(value) + class AIModelBase(BaseModel): name: str diff --git a/web/src/api/aiProviders.ts b/web/src/api/aiProviders.ts index d16e9ac..506cf36 100644 --- a/web/src/api/aiProviders.ts +++ b/web/src/api/aiProviders.ts @@ -1,6 +1,7 @@ import request from './client'; export type AIAbility = 'chat' | 'vision' | 'embedding' | 'rerank' | 'voice' | 'tools'; +export type OpenAIProtocol = 'chat_completions' | 'responses'; export interface AIProviderPayload { name: string; diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index a12a83c..f3f597d 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -516,6 +516,7 @@ "Enter identifier": "Enter identifier", "Only lowercase letters, numbers, dash, dot and underscore are allowed": "Only lowercase letters, numbers, dash, dot and underscore are allowed", "API Format": "API Format", + "OpenAI Protocol": "OpenAI Protocol", "Base URL": "Base URL", "Enter base url": "Enter base URL", "Optional, can also be provided per request": "Optional, can also be provided per request", diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index 42618c5..30bc00e 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -515,6 +515,7 @@ "Enter identifier": "请输入标识符", "Only lowercase letters, numbers, dash, dot and underscore are allowed": "仅允许小写字母、数字、连字符、点和下划线", "API Format": "API 格式", + "OpenAI Protocol": "OpenAI 协议", "Base URL": "基础 URL", "Enter base url": "请输入基础 URL", "Optional, can also be provided per request": "可选,也可在请求时提供", diff --git a/web/src/pages/SystemSettingsPage/components/AiSettingsTab.tsx b/web/src/pages/SystemSettingsPage/components/AiSettingsTab.tsx index 355b318..866dfc6 100644 --- a/web/src/pages/SystemSettingsPage/components/AiSettingsTab.tsx +++ b/web/src/pages/SystemSettingsPage/components/AiSettingsTab.tsx @@ -48,6 +48,7 @@ import type { AIModelPayload, AIProvider, AIProviderPayload, + OpenAIProtocol, } from '../../../api/aiProviders'; import { createModel, @@ -90,6 +91,11 @@ interface ProviderTemplate { } const abilityOrder: AIAbility[] = ['chat', 'vision', 'embedding', 'rerank', 'voice', 'tools']; +const defaultOpenAIProtocol: OpenAIProtocol = 'chat_completions'; + +function normalizeOpenAIProtocol(value: unknown): OpenAIProtocol { + return value === 'responses' ? 'responses' : defaultOpenAIProtocol; +} const abilityInfo: Record = { chat: { @@ -242,6 +248,7 @@ type AIProviderFormValues = { name?: string; identifier?: string; api_format: AIProviderPayload['api_format']; + openai_protocol?: OpenAIProtocol; base_url?: string; api_key?: string; logo_url?: string; @@ -275,12 +282,19 @@ export default function AiSettingsTab() { const [addingRemoteModels, setAddingRemoteModels] = useState(false); const [modelModalTab, setModelModalTab] = useState<'remote' | 'manual'>('remote'); const [remoteSearchKeyword, setRemoteSearchKeyword] = useState(''); + const providerApiFormat = Form.useWatch('api_format', providerForm); const capabilitiesValue = Form.useWatch('capabilities', modelForm); const showEmbeddingDimensions = useMemo(() => { const capabilities = Array.isArray(capabilitiesValue) ? capabilitiesValue : []; return capabilities.includes('embedding') || capabilities.includes('rerank'); }, [capabilitiesValue]); + useEffect(() => { + if (providerApiFormat === 'openai' && !providerForm.getFieldValue('openai_protocol')) { + providerForm.setFieldValue('openai_protocol', defaultOpenAIProtocol); + } + }, [providerApiFormat, providerForm]); + useEffect(() => { if (!showEmbeddingDimensions) { modelForm.setFieldsValue({ embedding_dimensions: null }); @@ -338,6 +352,7 @@ export default function AiSettingsTab() { name: existing.name, identifier: existing.identifier, api_format: existing.api_format, + openai_protocol: normalizeOpenAIProtocol(existing.extra_config?.openai_protocol), base_url: existing.base_url ?? undefined, api_key: '', logo_url: existing.logo_url ?? undefined, @@ -345,7 +360,7 @@ export default function AiSettingsTab() { }); } else { providerForm.resetFields(); - providerForm.setFieldsValue({ api_format: 'openai' }); + providerForm.setFieldsValue({ api_format: 'openai', openai_protocol: defaultOpenAIProtocol }); setSelectedTemplate(null); setProviderModal({ open: true, step: 1 }); } @@ -364,6 +379,7 @@ export default function AiSettingsTab() { name: t(template.nameKey), identifier: template.identifier, api_format: template.api_format, + openai_protocol: template.api_format === 'openai' ? defaultOpenAIProtocol : undefined, base_url: template.base_url ?? '', api_key: '', logo_url: template.logo_url ?? '', @@ -375,7 +391,7 @@ export default function AiSettingsTab() { setProviderModal((prev) => ({ ...prev, step: 1, editing: undefined })); setSelectedTemplate(null); providerForm.resetFields(); - providerForm.setFieldsValue({ api_format: 'openai' }); + providerForm.setFieldsValue({ api_format: 'openai', openai_protocol: defaultOpenAIProtocol }); }; const handleSubmitProvider = async () => { @@ -384,6 +400,12 @@ export default function AiSettingsTab() { const trimmedApiKey = values.api_key?.trim(); const trimmedLogoUrl = values.logo_url?.trim(); const trimmedProviderType = values.provider_type?.trim(); + const extraConfig = { ...(providerModal.editing?.extra_config ?? {}) }; + if (values.api_format === 'openai') { + extraConfig.openai_protocol = normalizeOpenAIProtocol(values.openai_protocol); + } else { + delete extraConfig.openai_protocol; + } const payload: AIProviderPayload = { name: (values.name || '').trim(), identifier: (values.identifier || '').trim(), @@ -391,6 +413,7 @@ export default function AiSettingsTab() { base_url: trimmedBaseUrl ? trimmedBaseUrl : null, logo_url: trimmedLogoUrl ? trimmedLogoUrl : null, provider_type: trimmedProviderType ? trimmedProviderType : null, + extra_config: Object.keys(extraConfig).length ? extraConfig : null, }; if (trimmedApiKey) { payload.api_key = trimmedApiKey; @@ -1117,16 +1140,30 @@ export default function AiSettingsTab() { label={t('API Format')} rules={[{ required: true }]} > - + + {providerApiFormat === 'openai' ? ( + +