Files
MyGoNavi/internal/ai/provider/openai.go
Syngnat a5fdfefa2d 🐛 fix(ai/volcengine): 修复火山引擎兼容路径并拆分双预设
- OpenAI 兼容 URL 归一化改为保留已有 v3 和 v4 版本段,避免火山与智谱地址被错误补 /v1
- 对误填 /chat/completions 和 /models 的地址先回退到 base URL,再拼接目标端点
- 模型列表与连通性检测复用统一端点解析逻辑,修复火山 Coding Plan 等兼容服务请求
- AI 设置页拆分火山方舟与火山 Coding Plan 两个预设,并按完整路径精确匹配回显
- 修正模型下拉默认值行为,未选模型时保持占位态,避免误用动态列表首项
- 补充 provider 与 service 回归测试,并新增需求追踪文档
2026-03-27 12:04:55 +08:00

425 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"GoNavi-Wails/internal/ai"
)
// Default settings for the OpenAI provider. NewOpenAIProvider substitutes
// defaultOpenAIMaxTokens and defaultOpenAITemperature when the config value is <= 0.
// defaultOpenAIBaseURL is not referenced by the code in this file — presumably it is
// consumed by the base-URL normalization helpers elsewhere in the package.
const (
	defaultOpenAIBaseURL     = "https://api.openai.com/v1"
	defaultOpenAIMaxTokens   = 4096
	defaultOpenAITemperature = 0.7
)

// openAIHTTPTimeout is the Timeout of the http.Client created by NewOpenAIProvider.
const openAIHTTPTimeout = 2 * time.Minute
// OpenAIProvider is the Provider implementation for the OpenAI API and
// OpenAI-compatible APIs. Instances are created by NewOpenAIProvider.
type OpenAIProvider struct {
	config  ai.ProviderConfig // normalized copy of the user configuration built by NewOpenAIProvider
	baseURL string            // normalized base URL; same value as config.BaseURL
	client  *http.Client      // HTTP client created with Timeout openAIHTTPTimeout
}
// NewOpenAIProvider creates a Provider for the OpenAI API or any OpenAI-compatible service.
//
// The provider keeps a normalized copy of config:
//   - BaseURL is normalized with NormalizeOpenAICompatibleBaseURL;
//   - Model is trimmed and must not be empty (the only error this function returns);
//   - APIKey is trimmed, so a key pasted with a trailing space or newline does not end up
//     in the Authorization header (Validate already treats a whitespace-only key as empty);
//   - MaxTokens <= 0 becomes defaultOpenAIMaxTokens;
//   - Temperature <= 0 becomes defaultOpenAITemperature, which means an explicit
//     temperature of 0 cannot be requested through the config.
//
// The API key itself is not checked here; Chat and ChatStream call Validate first.
func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) {
	baseURL := NormalizeOpenAICompatibleBaseURL(config.BaseURL)

	model := strings.TrimSpace(config.Model)
	if model == "" {
		return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型")
	}

	maxTokens := config.MaxTokens
	if maxTokens <= 0 {
		maxTokens = defaultOpenAIMaxTokens
	}
	temperature := config.Temperature
	if temperature <= 0 {
		temperature = defaultOpenAITemperature
	}

	normalized := config
	normalized.BaseURL = baseURL
	normalized.Model = model
	normalized.APIKey = strings.TrimSpace(config.APIKey)
	normalized.MaxTokens = maxTokens
	normalized.Temperature = temperature

	return &OpenAIProvider{
		config:  normalized,
		baseURL: baseURL,
		client:  &http.Client{Timeout: openAIHTTPTimeout},
	}, nil
}
// Name returns the configured provider name, or "OpenAI" when the configured name is
// empty or whitespace-only. A non-blank name is returned exactly as configured.
func (p *OpenAIProvider) Name() string {
	name := p.config.Name
	if strings.TrimSpace(name) == "" {
		name = "OpenAI"
	}
	return name
}
// Validate checks that the provider is usable for requests. It returns an error when
// the configured API key is empty or consists only of whitespace, and nil otherwise.
func (p *OpenAIProvider) Validate() error {
	var err error
	if len(strings.TrimSpace(p.config.APIKey)) == 0 {
		err = fmt.Errorf("API Key 不能为空")
	}
	return err
}
// openAIChatRequest is the JSON request body POSTed to the OpenAI-compatible
// chat/completions endpoint by Chat and ChatStream.
type openAIChatRequest struct {
	Model       string              `json:"model"`
	Messages    []openAIChatMessage `json:"messages"`
	Temperature float64             `json:"temperature,omitempty"` // omitted when 0
	// MaxTokens is omitted when 0.
	// NOTE(review): Chat/ChatStream in this file never set it, so max_tokens is not
	// sent even though NewOpenAIProvider normalizes config.MaxTokens — confirm intended.
	MaxTokens int       `json:"max_tokens,omitempty"`
	Stream    bool      `json:"stream,omitempty"` // omitted unless true
	Tools     []ai.Tool `json:"tools,omitempty"`  // omitted when empty
}
// openAIChatMessage is one entry of the request "messages" array.
type openAIChatMessage struct {
	Role string `json:"role"`
	// Content holds either a plain string or, for messages with images, a slice of
	// content-part maps (see buildOpenAIMessages). With encoding/json, omitempty on an
	// interface field only drops a nil interface; an empty string is still sent as "".
	Content    interface{}   `json:"content,omitempty"`
	ToolCalls  []ai.ToolCall `json:"tool_calls,omitempty"`   // set on messages that carry tool calls
	ToolCallID string        `json:"tool_call_id,omitempty"` // set on role "tool" messages
}
// buildOpenAIMessages converts ai.Message values into the OpenAI chat wire format.
//
// Conversion rules, checked in this order for each message:
//   - role "tool": content plus ToolCallID;
//   - messages carrying ToolCalls: content plus the tool calls (images are not sent);
//   - messages with Images: a multi-part content array with one text part followed by
//     one image_url part per image. An empty text is replaced by a default prompt,
//     because some models (e.g. ZhipuAI GLM-4V) require a non-empty text part next to
//     images, and so a strong system prompt does not make the model treat the message
//     as empty;
//   - anything else: plain text content.
//
// modelName is currently unused; it is kept for signature compatibility.
// baseURL is only used to detect the ZhipuAI endpoint ("bigmodel" in the URL). That
// API does not accept the data-URI prefix, so images are sent as bare Base64 there.
func buildOpenAIMessages(reqMessages []ai.Message, modelName string, baseURL string) []openAIChatMessage {
	// The endpoint is the same for every message and image, so decide this once.
	stripDataURIPrefix := strings.Contains(strings.ToLower(baseURL), "bigmodel")

	messages := make([]openAIChatMessage, len(reqMessages))
	for i, m := range reqMessages {
		switch {
		case m.Role == "tool":
			messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content, ToolCallID: m.ToolCallID}
		case len(m.ToolCalls) > 0:
			messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content, ToolCalls: m.ToolCalls}
		case len(m.Images) > 0:
			text := m.Content
			if text == "" {
				text = "请描述和分析这张图片。"
			}
			contentParts := make([]map[string]interface{}, 0, len(m.Images)+1)
			contentParts = append(contentParts, map[string]interface{}{
				"type": "text",
				"text": text,
			})
			for _, img := range m.Images {
				imgURL := img
				if stripDataURIPrefix {
					// Fall back to the original value if it is not a parsable data URI.
					if _, raw, err := ParseDataURI(img); err == nil {
						imgURL = raw
					}
				}
				contentParts = append(contentParts, map[string]interface{}{
					"type": "image_url",
					"image_url": map[string]interface{}{
						"url": imgURL,
					},
				})
			}
			messages[i] = openAIChatMessage{Role: m.Role, Content: contentParts}
		default:
			messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
		}
	}
	return messages
}
// openAIChatResponse is the JSON body of a non-streaming chat/completions response,
// decoded by Chat. Only the first choice is used there.
type openAIChatResponse struct {
	Choices []struct {
		Message struct {
			Content   string        `json:"content"`
			ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	// Usage carries the token accounting copied into ai.ChatResponse.TokensUsed.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	// Error is set when the service reports an error inside a response body;
	// Chat turns a non-empty Message into an error.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
// openAIToolCallDelta is one tool-call fragment inside a streaming delta.
// ChatStream merges fragments that share the same Index into a single ai.ToolCall:
// a non-empty ID replaces the stored one, while Name and Arguments are appended.
type openAIToolCallDelta struct {
	Index    int    `json:"index"`
	ID       string `json:"id,omitempty"`
	Type     string `json:"type,omitempty"`
	Function *struct {
		Name      string `json:"name,omitempty"`
		Arguments string `json:"arguments,omitempty"` // partial JSON text, concatenated across chunks
	} `json:"function,omitempty"`
}
// openAIStreamChunk is the JSON payload of one SSE data line in a streaming
// chat/completions response. ChatStream only looks at the first choice.
type openAIStreamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"` // incremental answer text
			// ReasoningContent is the incremental reasoning text emitted by models such as
			// DeepSeek and Qwen; ChatStream forwards it as StreamChunk.Thinking.
			ReasoningContent string                `json:"reasoning_content"`
			ToolCalls        []openAIToolCallDelta `json:"tool_calls,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason"` // nil when the field is JSON null or absent
	} `json:"choices"`
	// Error is set when the service reports an error inside the stream.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
// Chat sends one non-streaming chat completion request and returns the first choice.
//
// The temperature falls back to the provider's configured value when
// req.Temperature <= 0. If the request carries tools and the server rejects it with a
// status recognized by isHTTP400Error, the request is retried once without tools, so
// models without function-calling support still answer in plain text.
//
// Errors: Validate failures, transport/HTTP errors from doRequest, undecodable bodies,
// an error object in the response, or a response without choices.
// NOTE(review): config.MaxTokens is not sent here — confirm this is intentional.
func (p *OpenAIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
	if err := p.Validate(); err != nil {
		return nil, err
	}

	temp := req.Temperature
	if temp <= 0 {
		temp = p.config.Temperature
	}

	// Stream is left at its zero value (false): this is a blocking request.
	payload := openAIChatRequest{
		Model:       p.config.Model,
		Messages:    buildOpenAIMessages(req.Messages, p.config.Model, p.baseURL),
		Temperature: temp,
		Tools:       req.Tools,
	}

	body, err := p.doRequest(ctx, payload)
	if err != nil && len(req.Tools) > 0 && isHTTP400Error(err) {
		// The model most likely rejects tools; degrade to a plain-text request once.
		fmt.Println("[OpenAI] 模型不支持 Function Calling自动降级为纯文本模式")
		payload.Tools = nil
		body, err = p.doRequest(ctx, payload)
	}
	if err != nil {
		return nil, err
	}
	defer body.Close()

	var decoded openAIChatResponse
	if decodeErr := json.NewDecoder(body).Decode(&decoded); decodeErr != nil {
		return nil, fmt.Errorf("解析 OpenAI 响应失败: %w", decodeErr)
	}
	if apiErr := decoded.Error; apiErr != nil && apiErr.Message != "" {
		return nil, fmt.Errorf("OpenAI API 错误: %s", apiErr.Message)
	}
	if len(decoded.Choices) == 0 {
		return nil, fmt.Errorf("OpenAI 返回空响应")
	}

	first := decoded.Choices[0].Message
	usage := decoded.Usage
	return &ai.ChatResponse{
		Content: first.Content,
		TokensUsed: ai.TokenUsage{
			PromptTokens:     usage.PromptTokens,
			CompletionTokens: usage.CompletionTokens,
			TotalTokens:      usage.TotalTokens,
		},
		ToolCalls: first.ToolCalls,
	}, nil
}
// ChatStream sends a streaming chat completion request and forwards the Server-Sent
// Events stream to callback chunk by chunk.
//
// Request construction and the tools fallback mirror Chat. The temperature falls back to
// the configured value when req.Temperature <= 0. A request with tools that fails with a
// status recognized by isHTTP400Error is retried once without tools.
//
// Stream handling:
//   - each "data:" line is decoded as an openAIStreamChunk. SSE makes the single space
//     after the colon optional, so both "data: {...}" and "data:{...}" are accepted;
//   - "[DONE]" ends the stream;
//   - non-data lines containing "error"/"Error" are reported as a server error; other
//     non-data lines (comments, event:/id: fields, keep-alives) are ignored;
//   - content, reasoning_content and tool-call deltas are forwarded as they arrive.
//     Tool-call fragments are merged by index (negative indices are ignored). Each
//     callback receives a copy of the accumulated calls, so later fragments never mutate
//     a slice the callback already received;
//   - a non-empty finish_reason ends the stream; for "tool_calls" the final chunk also
//     carries the accumulated tool calls.
//
// Errors reported by the API inside the stream are delivered through callback (with
// Done set) and the method returns nil. A non-nil error is returned only for
// validation, request and stream-read failures.
func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
	if err := p.Validate(); err != nil {
		return err
	}
	messages := buildOpenAIMessages(req.Messages, p.config.Model, p.baseURL)
	temperature := req.Temperature
	if temperature <= 0 {
		temperature = p.config.Temperature
	}
	body := openAIChatRequest{
		Model:       p.config.Model,
		Messages:    messages,
		Temperature: temperature,
		Stream:      true,
		Tools:       req.Tools,
	}
	respBody, err := p.doRequest(ctx, body)
	if err != nil && len(req.Tools) > 0 && isHTTP400Error(err) {
		// The model most likely rejects tools; degrade to a plain-text request once.
		fmt.Println("[OpenAI] 模型不支持 Function Calling自动降级为纯文本模式")
		body.Tools = nil
		respBody, err = p.doRequest(ctx, body)
	}
	if err != nil {
		return err
	}
	defer respBody.Close()

	receivedContent := false
	var activeToolCalls []ai.ToolCall
	// snapshotToolCalls copies the accumulated calls so callbacks never share the
	// backing array that later deltas keep mutating (nil stays nil).
	snapshotToolCalls := func() []ai.ToolCall {
		return append([]ai.ToolCall(nil), activeToolCalls...)
	}

	scanner := bufio.NewScanner(respBody)
	// Allow lines up to 1 MiB; long content or tool arguments can exceed the 64 KiB default.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			continue
		}
		if !strings.HasPrefix(line, "data:") {
			// Not an SSE data line: surface it if it looks like an error message.
			if strings.Contains(line, "error") || strings.Contains(line, "Error") {
				callback(ai.StreamChunk{Error: fmt.Sprintf("服务端返回异常: %s", line), Done: true})
				return nil
			}
			continue
		}
		// SSE: strip the field name and at most one leading space from the value.
		data := strings.TrimPrefix(line[len("data:"):], " ")
		if strings.TrimSpace(data) == "[DONE]" {
			callback(ai.StreamChunk{Done: true})
			return nil
		}
		var chunk openAIStreamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed lines
		}
		if chunk.Error != nil && chunk.Error.Message != "" {
			callback(ai.StreamChunk{Error: fmt.Sprintf("API 错误: %s", chunk.Error.Message), Done: true})
			return nil
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		choice := chunk.Choices[0]

		if len(choice.Delta.ToolCalls) > 0 {
			receivedContent = true
			for _, tcDelta := range choice.Delta.ToolCalls {
				if tcDelta.Index < 0 {
					// A negative index cannot address a slot; indexing would panic.
					continue
				}
				for len(activeToolCalls) <= tcDelta.Index {
					activeToolCalls = append(activeToolCalls, ai.ToolCall{Type: "function"})
				}
				tc := &activeToolCalls[tcDelta.Index]
				if tcDelta.ID != "" {
					tc.ID = tcDelta.ID
				}
				if fn := tcDelta.Function; fn != nil {
					tc.Function.Name += fn.Name
					tc.Function.Arguments += fn.Arguments
				}
			}
			// Push the tool-call state parsed so far.
			callback(ai.StreamChunk{ToolCalls: snapshotToolCalls()})
		}
		if content := choice.Delta.Content; content != "" {
			receivedContent = true
			callback(ai.StreamChunk{Content: content})
		}
		// reasoning_content is emitted by DeepSeek, Qwen and similar models.
		if reasoning := choice.Delta.ReasoningContent; reasoning != "" {
			receivedContent = true
			callback(ai.StreamChunk{Thinking: reasoning})
		}
		if fr := choice.FinishReason; fr != nil && *fr != "" {
			if *fr == "tool_calls" {
				callback(ai.StreamChunk{ToolCalls: snapshotToolCalls(), Done: true})
				return nil
			}
			callback(ai.StreamChunk{Done: true})
			return nil
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("读取 OpenAI 流式响应失败: %w", err)
	}
	// The stream ended without a terminator. If nothing useful arrived, the endpoint
	// or the response format is probably incompatible.
	if !receivedContent {
		callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 API 端点和模型是否正确", Done: true})
		return nil
	}
	callback(ai.StreamChunk{Done: true})
	return nil
}
// openAICancelOnCloseBody wraps a successful response body so that closing it also
// releases the per-request context created by doRequest.
type openAICancelOnCloseBody struct {
	io.ReadCloser
	cancel context.CancelFunc
}

// Close closes the underlying response body and then cancels the request context.
func (b *openAICancelOnCloseBody) Close() error {
	err := b.ReadCloser.Close()
	b.cancel()
	return err
}

// doRequest POSTs body as JSON to the provider's chat/completions endpoint. On HTTP 200
// it returns the response body, which the caller must Close.
//
// Headers set, in order:
//   - Content-Type;
//   - Bearer authorization from the configured API key;
//   - SSE headers for streaming requests, to discourage proxy buffering;
//   - the configured custom headers, applied last so they can override earlier ones.
//
// Timeouts: non-streaming requests are bounded by the client's Timeout as a whole.
// http.Client.Timeout also covers reading the response body, so for streaming requests it
// would abort any SSE stream that runs longer than that. For streams the overall timeout
// is therefore lifted: the same duration only bounds the wait for response headers, and
// the rest of the stream is governed by ctx.
//
// Any non-200 status is returned as an error "OpenAI API 返回错误 (HTTP <code>): <body>"
// (isHTTP400Error matches on this text); at most 64 KiB of the error body is included.
func (p *OpenAIProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
	const maxErrorBodyBytes = 64 << 10

	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	var stream bool
	switch b := body.(type) {
	case openAIChatRequest:
		stream = b.Stream
	case *openAIChatRequest:
		stream = b != nil && b.Stream
	default:
		// Unknown body type: fall back to inspecting the serialized JSON.
		s := string(jsonBody)
		stream = strings.Contains(s, `"stream":true`) || strings.Contains(s, `"stream": true`)
	}

	endpoint := ResolveOpenAICompatibleEndpoint(p.baseURL, "chat/completions")
	reqCtx, cancel := context.WithCancel(ctx)
	httpReq, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		cancel()
		return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
	if stream {
		// Declare SSE explicitly so intermediate proxies do not buffer the stream.
		httpReq.Header.Set("Accept", "text/event-stream")
		httpReq.Header.Set("Cache-Control", "no-cache")
		httpReq.Header.Set("Connection", "keep-alive")
	}
	// Custom headers adapt the request to specific OpenAI-compatible services.
	for k, v := range p.config.Headers {
		httpReq.Header.Set(k, v)
	}

	client := p.client
	var headerTimer *time.Timer
	if stream && client.Timeout > 0 {
		// Lift the whole-exchange timeout for this stream and bound only the wait for
		// response headers; the timer is stopped as soon as Do returns.
		streamClient := *client
		streamClient.Timeout = 0
		headerTimer = time.AfterFunc(client.Timeout, cancel)
		client = &streamClient
	}

	resp, err := client.Do(httpReq)
	if headerTimer != nil {
		headerTimer.Stop()
	}
	if err != nil {
		cancel()
		return nil, fmt.Errorf("发送请求到 %s 失败: %w", endpoint, err)
	}
	if resp.StatusCode != http.StatusOK {
		defer cancel()
		defer resp.Body.Close()
		// Best effort: if reading fails, the status code still identifies the failure.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("OpenAI API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
	}
	return &openAICancelOnCloseBody{ReadCloser: resp.Body, cancel: cancel}, nil
}
// isHTTP400Error reports whether err carries one of the HTTP statuses 400, 404 or 422,
// as formatted by doRequest ("(HTTP <code>)"). Such statuses usually mean the model or
// endpoint rejected a request parameter (e.g. tools/functions), which lets callers retry
// with a simpler request. A nil error never matches.
func isHTTP400Error(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range []string{"(HTTP 400)", "(HTTP 422)", "(HTTP 404)"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}