mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-06-21 07:13:55 +08:00
feat: add OpenAI protocol support and enhance AI provider configuration
This commit is contained in:
@@ -22,8 +22,11 @@ from .types import (
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
OPENAI_PROTOCOL_CHAT_COMPLETIONS,
|
||||
OPENAI_PROTOCOL_RESPONSES,
|
||||
VectorDBConfigPayload,
|
||||
normalize_capabilities,
|
||||
normalize_openai_protocol,
|
||||
)
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
@@ -58,6 +61,9 @@ __all__ = [
|
||||
"get_provider_class",
|
||||
"ABILITIES",
|
||||
"normalize_capabilities",
|
||||
"normalize_openai_protocol",
|
||||
"OPENAI_PROTOCOL_CHAT_COMPLETIONS",
|
||||
"OPENAI_PROTOCOL_RESPONSES",
|
||||
"AIDefaultsUpdate",
|
||||
"AIModelCreate",
|
||||
"AIModelUpdate",
|
||||
|
||||
@@ -34,7 +34,7 @@ async def list_providers_endpoint(
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建 AI 提供商",
|
||||
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
|
||||
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url", "extra_config"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.post("/providers")
|
||||
@@ -61,7 +61,7 @@ async def get_provider(
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新 AI 提供商",
|
||||
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
|
||||
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key", "extra_config"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.put("/providers/{provider_id}")
|
||||
|
||||
@@ -6,6 +6,7 @@ from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
from models.database import AIModel, AIProvider
|
||||
from .service import AIProviderService
|
||||
from .types import OPENAI_PROTOCOL_RESPONSES, normalize_openai_protocol
|
||||
|
||||
|
||||
provider_service = AIProviderService
|
||||
@@ -227,6 +228,248 @@ def _is_azure_openai(provider: AIProvider) -> bool:
|
||||
return ".openai.azure.com" in base_url
|
||||
|
||||
|
||||
def _openai_protocol(provider: AIProvider) -> str:
|
||||
extra = provider.extra_config if isinstance(provider.extra_config, dict) else {}
|
||||
return normalize_openai_protocol(extra.get("openai_protocol"))
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: List[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
|
||||
def _openai_content_to_responses_input(content: Any) -> Any:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return _content_to_text(content)
|
||||
|
||||
blocks: List[Dict[str, Any]] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
part_type = part.get("type")
|
||||
if part_type in {"text", "input_text"} and isinstance(part.get("text"), str):
|
||||
blocks.append({"type": "input_text", "text": part["text"]})
|
||||
continue
|
||||
if part_type == "image_url":
|
||||
image_url = part.get("image_url")
|
||||
url = image_url.get("url") if isinstance(image_url, dict) else image_url
|
||||
if not isinstance(url, str) or not url.strip():
|
||||
continue
|
||||
block: Dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
detail = image_url.get("detail") if isinstance(image_url, dict) else None
|
||||
if isinstance(detail, str) and detail.strip():
|
||||
block["detail"] = detail
|
||||
blocks.append(block)
|
||||
continue
|
||||
if part_type == "input_image" and isinstance(part.get("image_url"), str):
|
||||
block = {"type": "input_image", "image_url": part["image_url"]}
|
||||
detail = part.get("detail")
|
||||
if isinstance(detail, str) and detail.strip():
|
||||
block["detail"] = detail
|
||||
blocks.append(block)
|
||||
continue
|
||||
if part_type == "input_file" and isinstance(part.get("file_id") or part.get("filename"), str):
|
||||
blocks.append(dict(part))
|
||||
|
||||
return blocks or ""
|
||||
|
||||
|
||||
def _openai_tool_call_to_responses_item(call: Dict[str, Any], idx: int) -> Dict[str, Any] | None:
|
||||
fn = call.get("function")
|
||||
fn = fn if isinstance(fn, dict) else {}
|
||||
name = fn.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
return None
|
||||
|
||||
raw_args = fn.get("arguments")
|
||||
if isinstance(raw_args, dict):
|
||||
args_text = json.dumps(raw_args, ensure_ascii=False)
|
||||
elif isinstance(raw_args, str):
|
||||
args_text = raw_args
|
||||
else:
|
||||
args_text = ""
|
||||
|
||||
return {
|
||||
"type": "function_call",
|
||||
"call_id": str(call.get("id") or f"call_{idx}"),
|
||||
"name": name,
|
||||
"arguments": args_text,
|
||||
}
|
||||
|
||||
|
||||
def _openai_messages_to_responses_input(messages: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
instructions: List[str] = []
|
||||
input_items: List[Dict[str, Any]] = []
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
role = msg.get("role")
|
||||
if role in {"system", "developer"}:
|
||||
text = _content_to_text(msg.get("content")).strip()
|
||||
if text:
|
||||
instructions.append(text)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
content = msg.get("content")
|
||||
output = content if isinstance(content, str) else _content_to_text(content)
|
||||
call_id = str(msg.get("tool_call_id") or "").strip()
|
||||
if call_id:
|
||||
input_items.append({
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output,
|
||||
})
|
||||
elif output:
|
||||
input_items.append({"role": "user", "content": output})
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append({
|
||||
"role": "user",
|
||||
"content": _openai_content_to_responses_input(msg.get("content")),
|
||||
})
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
text = _content_to_text(msg.get("content"))
|
||||
if text:
|
||||
input_items.append({"role": "assistant", "content": text})
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for idx, call in enumerate(tool_calls):
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
item = _openai_tool_call_to_responses_item(call, idx)
|
||||
if item:
|
||||
input_items.append(item)
|
||||
|
||||
if not input_items:
|
||||
input_items.append({"role": "user", "content": ""})
|
||||
|
||||
return "\n\n".join(instructions).strip(), input_items
|
||||
|
||||
|
||||
def _openai_tools_to_responses(tools: List[Dict[str, Any]] | None) -> List[Dict[str, Any]] | None:
|
||||
if not tools:
|
||||
return None
|
||||
out: List[Dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
if not isinstance(tool, dict) or tool.get("type") != "function":
|
||||
continue
|
||||
fn = tool.get("function")
|
||||
fn = fn if isinstance(fn, dict) else {}
|
||||
name = fn.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
continue
|
||||
entry: Dict[str, Any] = {
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") if isinstance(fn.get("description"), str) else "",
|
||||
"parameters": fn.get("parameters") if isinstance(fn.get("parameters"), dict) else {},
|
||||
}
|
||||
if isinstance(fn.get("strict"), bool):
|
||||
entry["strict"] = fn["strict"]
|
||||
out.append(entry)
|
||||
return out or None
|
||||
|
||||
|
||||
def _openai_tool_choice_to_responses(tool_choice: Any) -> Any | None:
|
||||
if tool_choice is None:
|
||||
return None
|
||||
if isinstance(tool_choice, str):
|
||||
return tool_choice
|
||||
if not isinstance(tool_choice, dict):
|
||||
return None
|
||||
if tool_choice.get("type") == "function":
|
||||
fn = tool_choice.get("function")
|
||||
name = fn.get("name") if isinstance(fn, dict) else tool_choice.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return {"type": "function", "name": name}
|
||||
return None
|
||||
|
||||
|
||||
def _responses_function_call_to_openai(item: Dict[str, Any], idx: int) -> Dict[str, Any] | None:
|
||||
name = item.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
return None
|
||||
raw_args = item.get("arguments")
|
||||
if isinstance(raw_args, dict):
|
||||
args_text = json.dumps(raw_args, ensure_ascii=False)
|
||||
elif isinstance(raw_args, str):
|
||||
args_text = raw_args
|
||||
else:
|
||||
args_text = ""
|
||||
return {
|
||||
"id": str(item.get("call_id") or item.get("id") or f"call_{idx}"),
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": args_text},
|
||||
}
|
||||
|
||||
|
||||
def _responses_content_to_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
parts: List[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") in {"output_text", "text"} and isinstance(block.get("text"), str):
|
||||
parts.append(block["text"])
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _responses_body_to_openai_message(body: Dict[str, Any]) -> Dict[str, Any]:
|
||||
text_parts: List[str] = []
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
output = body.get("output")
|
||||
|
||||
if isinstance(output, list):
|
||||
for idx, item in enumerate(output):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "message":
|
||||
text = _responses_content_to_text(item.get("content"))
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
continue
|
||||
if item_type == "function_call":
|
||||
tool_call = _responses_function_call_to_openai(item, idx)
|
||||
if tool_call:
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
if not text_parts and isinstance(body.get("output_text"), str):
|
||||
text_parts.append(body["output_text"])
|
||||
|
||||
message: Dict[str, Any] = {"role": "assistant", "content": "".join(text_parts)}
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
return message
|
||||
|
||||
|
||||
def _gemini_endpoint(provider: AIProvider, path: str) -> str:
|
||||
base = (provider.base_url or "").rstrip("/")
|
||||
if not base:
|
||||
@@ -808,6 +1051,12 @@ async def _chat_stream_with_ollama(
|
||||
|
||||
|
||||
async def _describe_with_openai(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
|
||||
if _openai_protocol(provider) == OPENAI_PROTOCOL_RESPONSES:
|
||||
return await _describe_with_openai_responses(provider, model, base64_image, detail)
|
||||
return await _describe_with_openai_chat_completions(provider, model, base64_image, detail)
|
||||
|
||||
|
||||
async def _describe_with_openai_chat_completions(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
|
||||
url = _openai_endpoint(provider, "/chat/completions")
|
||||
payload = {
|
||||
"model": model.name,
|
||||
@@ -834,6 +1083,32 @@ async def _describe_with_openai(provider: AIProvider, model: AIModel, base64_ima
|
||||
return body["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
async def _describe_with_openai_responses(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
|
||||
url = _openai_endpoint(provider, "/responses")
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"input": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": detail,
|
||||
},
|
||||
{"type": "input_text", "text": "描述这个图片"},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(url, headers=_openai_headers(provider), json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
message = _responses_body_to_openai_message(body if isinstance(body, dict) else {})
|
||||
return str(message.get("content") or "")
|
||||
|
||||
|
||||
async def _describe_with_anthropic(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
|
||||
url = _anthropic_endpoint(provider, "/messages")
|
||||
detail_text = f"描述这个图片,细节等级:{detail}"
|
||||
@@ -1080,6 +1355,37 @@ async def _chat_with_openai(
|
||||
tool_choice: Any | None,
|
||||
temperature: float | None,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
if _openai_protocol(provider) == OPENAI_PROTOCOL_RESPONSES:
|
||||
return await _chat_with_openai_responses(
|
||||
provider,
|
||||
model,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
)
|
||||
return await _chat_with_openai_chat_completions(
|
||||
provider,
|
||||
model,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
async def _chat_with_openai_chat_completions(
|
||||
provider: AIProvider,
|
||||
model: AIModel,
|
||||
messages: List[Dict[str, Any]],
|
||||
*,
|
||||
tools: List[Dict[str, Any]] | None,
|
||||
tool_choice: Any | None,
|
||||
temperature: float | None,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
url = _openai_endpoint(provider, "/chat/completions")
|
||||
payload: Dict[str, Any] = {
|
||||
@@ -1106,6 +1412,41 @@ async def _chat_with_openai(
|
||||
return message
|
||||
|
||||
|
||||
async def _chat_with_openai_responses(
|
||||
provider: AIProvider,
|
||||
model: AIModel,
|
||||
messages: List[Dict[str, Any]],
|
||||
*,
|
||||
tools: List[Dict[str, Any]] | None,
|
||||
tool_choice: Any | None,
|
||||
temperature: float | None,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
url = _openai_endpoint(provider, "/responses")
|
||||
instructions, input_items = _openai_messages_to_responses_input(messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model.name,
|
||||
"input": input_items,
|
||||
}
|
||||
if instructions:
|
||||
payload["instructions"] = instructions
|
||||
response_tools = _openai_tools_to_responses(tools)
|
||||
if response_tools:
|
||||
payload["tools"] = response_tools
|
||||
payload["tool_choice"] = _openai_tool_choice_to_responses(tool_choice) or "auto"
|
||||
if temperature is not None:
|
||||
payload["temperature"] = float(temperature)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(url, headers=_openai_headers(provider), json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
|
||||
if not isinstance(body, dict):
|
||||
raise RuntimeError("Responses 接口返回格式异常")
|
||||
return _responses_body_to_openai_message(body)
|
||||
|
||||
|
||||
async def chat_completion_stream(
|
||||
messages: List[Dict[str, Any]],
|
||||
*,
|
||||
@@ -1174,6 +1515,40 @@ async def _chat_stream_with_openai(
|
||||
tool_choice: Any | None,
|
||||
temperature: float | None,
|
||||
timeout: float,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
if _openai_protocol(provider) == OPENAI_PROTOCOL_RESPONSES:
|
||||
async for event in _chat_stream_with_openai_responses(
|
||||
provider,
|
||||
model,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
async for event in _chat_stream_with_openai_chat_completions(
|
||||
provider,
|
||||
model,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
async def _chat_stream_with_openai_chat_completions(
|
||||
provider: AIProvider,
|
||||
model: AIModel,
|
||||
messages: List[Dict[str, Any]],
|
||||
*,
|
||||
tools: List[Dict[str, Any]] | None,
|
||||
tool_choice: Any | None,
|
||||
temperature: float | None,
|
||||
timeout: float,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
url = _openai_endpoint(provider, "/chat/completions")
|
||||
payload: Dict[str, Any] = {
|
||||
@@ -1273,3 +1648,100 @@ async def _chat_stream_with_openai(
|
||||
message["tool_calls"] = tool_calls
|
||||
|
||||
yield {"type": "message", "message": message, "finish_reason": finish_reason}
|
||||
|
||||
|
||||
async def _chat_stream_with_openai_responses(
|
||||
provider: AIProvider,
|
||||
model: AIModel,
|
||||
messages: List[Dict[str, Any]],
|
||||
*,
|
||||
tools: List[Dict[str, Any]] | None,
|
||||
tool_choice: Any | None,
|
||||
temperature: float | None,
|
||||
timeout: float,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
url = _openai_endpoint(provider, "/responses")
|
||||
instructions, input_items = _openai_messages_to_responses_input(messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model.name,
|
||||
"input": input_items,
|
||||
"stream": True,
|
||||
}
|
||||
if instructions:
|
||||
payload["instructions"] = instructions
|
||||
response_tools = _openai_tools_to_responses(tools)
|
||||
if response_tools:
|
||||
payload["tools"] = response_tools
|
||||
payload["tool_choice"] = _openai_tool_choice_to_responses(tool_choice) or "auto"
|
||||
if temperature is not None:
|
||||
payload["temperature"] = float(temperature)
|
||||
|
||||
content_parts: List[str] = []
|
||||
output_items: Dict[int, Dict[str, Any]] = {}
|
||||
completed_body: Dict[str, Any] | None = None
|
||||
current_event: str | None = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("POST", url, headers=_openai_headers(provider), json=payload) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("event:"):
|
||||
current_event = line[6:].strip()
|
||||
continue
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
data = line[5:].strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if not isinstance(chunk, dict):
|
||||
continue
|
||||
|
||||
event_type = chunk.get("type") or current_event
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = chunk.get("delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
content_parts.append(delta)
|
||||
yield {"type": "delta", "delta": delta}
|
||||
continue
|
||||
|
||||
if event_type in {"response.output_item.added", "response.output_item.done"}:
|
||||
item = chunk.get("item")
|
||||
idx = chunk.get("output_index")
|
||||
if isinstance(item, dict) and isinstance(idx, int):
|
||||
output_items[idx] = item
|
||||
continue
|
||||
|
||||
if event_type == "response.completed":
|
||||
response_body = chunk.get("response")
|
||||
if isinstance(response_body, dict):
|
||||
completed_body = response_body
|
||||
elif isinstance(chunk.get("output"), list):
|
||||
completed_body = chunk
|
||||
break
|
||||
|
||||
if event_type == "error":
|
||||
error = chunk.get("error") if isinstance(chunk.get("error"), dict) else chunk
|
||||
raise RuntimeError(str(error.get("message") if isinstance(error, dict) else error))
|
||||
|
||||
if completed_body is not None:
|
||||
message = _responses_body_to_openai_message(completed_body)
|
||||
if not message.get("content") and content_parts:
|
||||
message["content"] = "".join(content_parts)
|
||||
yield {"type": "message", "message": message, "finish_reason": None}
|
||||
return
|
||||
|
||||
if output_items:
|
||||
body = {"output": [output_items[idx] for idx in sorted(output_items.keys())]}
|
||||
message = _responses_body_to_openai_message(body)
|
||||
if not message.get("content") and content_parts:
|
||||
message["content"] = "".join(content_parts)
|
||||
yield {"type": "message", "message": message, "finish_reason": None}
|
||||
return
|
||||
|
||||
yield {"type": "message", "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": None}
|
||||
|
||||
@@ -3,6 +3,19 @@ from typing import Any, Dict, Iterable, List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
|
||||
OPENAI_PROTOCOL_CHAT_COMPLETIONS = "chat_completions"
|
||||
OPENAI_PROTOCOL_RESPONSES = "responses"
|
||||
OPENAI_PROTOCOLS = {OPENAI_PROTOCOL_CHAT_COMPLETIONS, OPENAI_PROTOCOL_RESPONSES}
|
||||
OPENAI_PROTOCOL_ALIASES = {
|
||||
"chat": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
|
||||
"chat_completion": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
|
||||
"chat_completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
|
||||
"chat/completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
|
||||
"/chat/completions": OPENAI_PROTOCOL_CHAT_COMPLETIONS,
|
||||
"response": OPENAI_PROTOCOL_RESPONSES,
|
||||
"responses": OPENAI_PROTOCOL_RESPONSES,
|
||||
"/responses": OPENAI_PROTOCOL_RESPONSES,
|
||||
}
|
||||
|
||||
|
||||
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
@@ -16,6 +29,34 @@ def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_openai_protocol(value: Any) -> str:
|
||||
if value is None:
|
||||
return OPENAI_PROTOCOL_CHAT_COMPLETIONS
|
||||
key = str(value).strip().lower().replace("-", "_").replace(".", "_")
|
||||
if not key:
|
||||
return OPENAI_PROTOCOL_CHAT_COMPLETIONS
|
||||
normalized = OPENAI_PROTOCOL_ALIASES.get(key)
|
||||
if normalized:
|
||||
return normalized
|
||||
normalized = OPENAI_PROTOCOL_ALIASES.get(key.replace("_", "/"))
|
||||
if normalized:
|
||||
return normalized
|
||||
if key in OPENAI_PROTOCOLS:
|
||||
return key
|
||||
raise ValueError("openai_protocol must be 'chat_completions' or 'responses'")
|
||||
|
||||
|
||||
def normalize_provider_extra_config(config: Optional[dict]) -> Optional[dict]:
|
||||
if config is None:
|
||||
return None
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError("extra_config must be an object")
|
||||
normalized = dict(config)
|
||||
if "openai_protocol" in normalized:
|
||||
normalized["openai_protocol"] = normalize_openai_protocol(normalized.get("openai_protocol"))
|
||||
return normalized
|
||||
|
||||
|
||||
class AIProviderBase(BaseModel):
|
||||
name: str
|
||||
identifier: str = Field(..., pattern=r"^[a-z0-9_\-\.]+$")
|
||||
@@ -34,6 +75,11 @@ class AIProviderBase(BaseModel):
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
@field_validator("extra_config")
|
||||
@classmethod
|
||||
def normalize_extra_config(cls, value: Optional[dict]) -> Optional[dict]:
|
||||
return normalize_provider_extra_config(value)
|
||||
|
||||
|
||||
class AIProviderCreate(AIProviderBase):
|
||||
pass
|
||||
@@ -58,6 +104,11 @@ class AIProviderUpdate(BaseModel):
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
@field_validator("extra_config")
|
||||
@classmethod
|
||||
def normalize_extra_config(cls, value: Optional[dict]) -> Optional[dict]:
|
||||
return normalize_provider_extra_config(value)
|
||||
|
||||
|
||||
class AIModelBase(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import request from './client';
|
||||
|
||||
export type AIAbility = 'chat' | 'vision' | 'embedding' | 'rerank' | 'voice' | 'tools';
|
||||
export type OpenAIProtocol = 'chat_completions' | 'responses';
|
||||
|
||||
export interface AIProviderPayload {
|
||||
name: string;
|
||||
|
||||
@@ -516,6 +516,7 @@
|
||||
"Enter identifier": "Enter identifier",
|
||||
"Only lowercase letters, numbers, dash, dot and underscore are allowed": "Only lowercase letters, numbers, dash, dot and underscore are allowed",
|
||||
"API Format": "API Format",
|
||||
"OpenAI Protocol": "OpenAI Protocol",
|
||||
"Base URL": "Base URL",
|
||||
"Enter base url": "Enter base URL",
|
||||
"Optional, can also be provided per request": "Optional, can also be provided per request",
|
||||
|
||||
@@ -515,6 +515,7 @@
|
||||
"Enter identifier": "请输入标识符",
|
||||
"Only lowercase letters, numbers, dash, dot and underscore are allowed": "仅允许小写字母、数字、连字符、点和下划线",
|
||||
"API Format": "API 格式",
|
||||
"OpenAI Protocol": "OpenAI 协议",
|
||||
"Base URL": "基础 URL",
|
||||
"Enter base url": "请输入基础 URL",
|
||||
"Optional, can also be provided per request": "可选,也可在请求时提供",
|
||||
|
||||
@@ -48,6 +48,7 @@ import type {
|
||||
AIModelPayload,
|
||||
AIProvider,
|
||||
AIProviderPayload,
|
||||
OpenAIProtocol,
|
||||
} from '../../../api/aiProviders';
|
||||
import {
|
||||
createModel,
|
||||
@@ -90,6 +91,11 @@ interface ProviderTemplate {
|
||||
}
|
||||
|
||||
const abilityOrder: AIAbility[] = ['chat', 'vision', 'embedding', 'rerank', 'voice', 'tools'];
|
||||
const defaultOpenAIProtocol: OpenAIProtocol = 'chat_completions';
|
||||
|
||||
function normalizeOpenAIProtocol(value: unknown): OpenAIProtocol {
|
||||
return value === 'responses' ? 'responses' : defaultOpenAIProtocol;
|
||||
}
|
||||
|
||||
const abilityInfo: Record<AIAbility, { icon: ReactNode; label: string; color: string; description: string }> = {
|
||||
chat: {
|
||||
@@ -242,6 +248,7 @@ type AIProviderFormValues = {
|
||||
name?: string;
|
||||
identifier?: string;
|
||||
api_format: AIProviderPayload['api_format'];
|
||||
openai_protocol?: OpenAIProtocol;
|
||||
base_url?: string;
|
||||
api_key?: string;
|
||||
logo_url?: string;
|
||||
@@ -275,12 +282,19 @@ export default function AiSettingsTab() {
|
||||
const [addingRemoteModels, setAddingRemoteModels] = useState<boolean>(false);
|
||||
const [modelModalTab, setModelModalTab] = useState<'remote' | 'manual'>('remote');
|
||||
const [remoteSearchKeyword, setRemoteSearchKeyword] = useState<string>('');
|
||||
const providerApiFormat = Form.useWatch('api_format', providerForm);
|
||||
const capabilitiesValue = Form.useWatch('capabilities', modelForm);
|
||||
const showEmbeddingDimensions = useMemo(() => {
|
||||
const capabilities = Array.isArray(capabilitiesValue) ? capabilitiesValue : [];
|
||||
return capabilities.includes('embedding') || capabilities.includes('rerank');
|
||||
}, [capabilitiesValue]);
|
||||
|
||||
useEffect(() => {
|
||||
if (providerApiFormat === 'openai' && !providerForm.getFieldValue('openai_protocol')) {
|
||||
providerForm.setFieldValue('openai_protocol', defaultOpenAIProtocol);
|
||||
}
|
||||
}, [providerApiFormat, providerForm]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!showEmbeddingDimensions) {
|
||||
modelForm.setFieldsValue({ embedding_dimensions: null });
|
||||
@@ -338,6 +352,7 @@ export default function AiSettingsTab() {
|
||||
name: existing.name,
|
||||
identifier: existing.identifier,
|
||||
api_format: existing.api_format,
|
||||
openai_protocol: normalizeOpenAIProtocol(existing.extra_config?.openai_protocol),
|
||||
base_url: existing.base_url ?? undefined,
|
||||
api_key: '',
|
||||
logo_url: existing.logo_url ?? undefined,
|
||||
@@ -345,7 +360,7 @@ export default function AiSettingsTab() {
|
||||
});
|
||||
} else {
|
||||
providerForm.resetFields();
|
||||
providerForm.setFieldsValue({ api_format: 'openai' });
|
||||
providerForm.setFieldsValue({ api_format: 'openai', openai_protocol: defaultOpenAIProtocol });
|
||||
setSelectedTemplate(null);
|
||||
setProviderModal({ open: true, step: 1 });
|
||||
}
|
||||
@@ -364,6 +379,7 @@ export default function AiSettingsTab() {
|
||||
name: t(template.nameKey),
|
||||
identifier: template.identifier,
|
||||
api_format: template.api_format,
|
||||
openai_protocol: template.api_format === 'openai' ? defaultOpenAIProtocol : undefined,
|
||||
base_url: template.base_url ?? '',
|
||||
api_key: '',
|
||||
logo_url: template.logo_url ?? '',
|
||||
@@ -375,7 +391,7 @@ export default function AiSettingsTab() {
|
||||
setProviderModal((prev) => ({ ...prev, step: 1, editing: undefined }));
|
||||
setSelectedTemplate(null);
|
||||
providerForm.resetFields();
|
||||
providerForm.setFieldsValue({ api_format: 'openai' });
|
||||
providerForm.setFieldsValue({ api_format: 'openai', openai_protocol: defaultOpenAIProtocol });
|
||||
};
|
||||
|
||||
const handleSubmitProvider = async () => {
|
||||
@@ -384,6 +400,12 @@ export default function AiSettingsTab() {
|
||||
const trimmedApiKey = values.api_key?.trim();
|
||||
const trimmedLogoUrl = values.logo_url?.trim();
|
||||
const trimmedProviderType = values.provider_type?.trim();
|
||||
const extraConfig = { ...(providerModal.editing?.extra_config ?? {}) };
|
||||
if (values.api_format === 'openai') {
|
||||
extraConfig.openai_protocol = normalizeOpenAIProtocol(values.openai_protocol);
|
||||
} else {
|
||||
delete extraConfig.openai_protocol;
|
||||
}
|
||||
const payload: AIProviderPayload = {
|
||||
name: (values.name || '').trim(),
|
||||
identifier: (values.identifier || '').trim(),
|
||||
@@ -391,6 +413,7 @@ export default function AiSettingsTab() {
|
||||
base_url: trimmedBaseUrl ? trimmedBaseUrl : null,
|
||||
logo_url: trimmedLogoUrl ? trimmedLogoUrl : null,
|
||||
provider_type: trimmedProviderType ? trimmedProviderType : null,
|
||||
extra_config: Object.keys(extraConfig).length ? extraConfig : null,
|
||||
};
|
||||
if (trimmedApiKey) {
|
||||
payload.api_key = trimmedApiKey;
|
||||
@@ -1117,16 +1140,30 @@ export default function AiSettingsTab() {
|
||||
label={t('API Format')}
|
||||
rules={[{ required: true }]}
|
||||
>
|
||||
<Select
|
||||
disabled={!allowFormatChange}
|
||||
options={[
|
||||
{ value: 'openai', label: 'OpenAI Compatible' },
|
||||
{ value: 'gemini', label: 'Gemini Compatible' },
|
||||
{ value: 'anthropic', label: 'Anthropic Native' },
|
||||
{ value: 'ollama', label: 'Ollama Native' },
|
||||
]}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Select
|
||||
disabled={!allowFormatChange}
|
||||
options={[
|
||||
{ value: 'openai', label: 'OpenAI Compatible' },
|
||||
{ value: 'gemini', label: 'Gemini Compatible' },
|
||||
{ value: 'anthropic', label: 'Anthropic Native' },
|
||||
{ value: 'ollama', label: 'Ollama Native' },
|
||||
]}
|
||||
/>
|
||||
</Form.Item>
|
||||
{providerApiFormat === 'openai' ? (
|
||||
<Form.Item
|
||||
name="openai_protocol"
|
||||
label={t('OpenAI Protocol')}
|
||||
rules={[{ required: true }]}
|
||||
>
|
||||
<Select
|
||||
options={[
|
||||
{ value: 'chat_completions', label: 'Chat Completions' },
|
||||
{ value: 'responses', label: 'Responses' },
|
||||
]}
|
||||
/>
|
||||
</Form.Item>
|
||||
) : null}
|
||||
<Form.Item name="base_url" label={t('Base URL')} rules={[{ required: true, message: t('Enter base url') }]}>
|
||||
<Input placeholder="https://" />
|
||||
</Form.Item>
|
||||
|
||||
Reference in New Issue
Block a user