package ai

import (
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"sync"
	"time"

	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/httprunner/httprunner/v5/code"
	"github.com/joho/godotenv"
	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"
)

const (
	// defaultTimeout is the timeout applied to the HTTP client built by
	// GetModelConfig.
	defaultTimeout = 60 * time.Second
)

// OpenAIInitConfig is the shape of the optional JSON document read from the
// OPENAI_INIT_CONFIG_JSON environment variable.
type OpenAIInitConfig struct {
	// ReportURL is decoded from the "REPORT_SERVER_URL" key; it is not
	// consumed anywhere in this file.
	ReportURL string `json:"REPORT_SERVER_URL"`
	// Headers is decoded from the "defaultHeaders" key and attached to every
	// outgoing model request via CustomTransport.
	Headers map[string]string `json:"defaultHeaders"`
}

// Names of the environment variables read by this package.
const (
	EnvOpenAIBaseURL        = "OPENAI_BASE_URL"
	EnvOpenAIAPIKey         = "OPENAI_API_KEY"
	EnvModelName            = "LLM_MODEL_NAME"
	EnvOpenAIInitConfigJSON = "OPENAI_INIT_CONFIG_JSON"
)

// once guards loadEnv so the .env lookup runs a single time per process.
var once sync.Once

// loadEnv loads environment variables from .env file
// it will search for .env file from current working directory upward recursively
//
// The first .env file found wins and overrides variables that are already
// set. The function panics if the working directory cannot be determined and
// terminates the process (log.Fatal) if the located file cannot be loaded.
func loadEnv() {
	once.Do(func() {
		// get current working directory
		cwd, err := os.Getwd()
		if err != nil {
			panic(err)
		}

		// locate .env file from current working directory upward recursively
		envPath := cwd
		for {
			envFile := filepath.Join(envPath, ".env")
			if _, err := os.Stat(envFile); err == nil {
				// found .env file
				// override existing env variables
				err = godotenv.Overload(envFile)
				if err != nil {
					log.Fatal().Err(err).
						Str("path", envFile).Msg("overload env file failed")
				}
				log.Info().Str("path", envFile).Msg("overload env success")
				return
			}

			// reached root directory: filepath.Dir is a fixed point there
			parent := filepath.Dir(envPath)
			if parent == envPath {
				log.Info().Msg("no .env file found from current directory to root")
				return
			}
			envPath = parent
		}
	})
}

// checkEnvLLM verifies that the base URL, API key and model name environment
// variables are all set, logging each value (the API key is masked).
// It returns code.LLMEnvMissedError wrapped with the name of the first
// missing variable, or nil when all three are present.
func checkEnvLLM() error {
	loadEnv()

	openaiBaseURL := os.Getenv(EnvOpenAIBaseURL)
	if openaiBaseURL == "" {
		return errors.Wrap(code.LLMEnvMissedError, EnvOpenAIBaseURL+" missed")
	}
	log.Info().Str(EnvOpenAIBaseURL, openaiBaseURL).Msg("get env")

	openaiAPIKey := os.Getenv(EnvOpenAIAPIKey)
	if openaiAPIKey == "" {
		return errors.Wrap(code.LLMEnvMissedError, EnvOpenAIAPIKey+" missed")
	}
	log.Info().Str(EnvOpenAIAPIKey, maskAPIKey(openaiAPIKey)).Msg("get env")

	modelName := os.Getenv(EnvModelName)
	if modelName == "" {
		return errors.Wrap(code.LLMEnvMissedError, EnvModelName+" missed")
	}
	log.Info().Str(EnvModelName, modelName).Msg("get env")

	return nil
}

// GetEnvConfig returns the value of the environment variable named key, or
// an empty string when it is unset.
func GetEnvConfig(key string) string {
	return os.Getenv(key)
}

// GetEnvConfigInJSON decodes the environment variable named key as a JSON
// object. It returns (nil, nil) when the variable is unset or empty and the
// decoding error when the value is not valid JSON for a map.
func GetEnvConfigInJSON(key string) (map[string]interface{}, error) {
	value := GetEnvConfig(key)
	if value == "" {
		return nil, nil
	}

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(value), &result); err != nil {
		return nil, err
	}
	return result, nil
}

// GetEnvConfigInBool parses the environment variable named key with
// strconv.ParseBool. It returns false when the variable is unset, empty, or
// not a valid boolean.
func GetEnvConfigInBool(key string) bool {
	value := GetEnvConfig(key)
	if value == "" {
		return false
	}

	boolValue, err := strconv.ParseBool(value)
	if err != nil {
		// An unparsable value is deliberately treated the same as "not set".
		return false
	}
	return boolValue
}

// GetEnvConfigOrDefault get env config or default value
func GetEnvConfigOrDefault(key, defaultValue string) string {
	value := GetEnvConfig(key)
	if value == "" {
		return defaultValue
	}
	return value
}

// GetEnvConfigInInt parses the environment variable named key as an int.
// It returns defaultValue when the variable is unset, empty, or not a valid
// integer.
func GetEnvConfigInInt(key string, defaultValue int) int {
	value := GetEnvConfig(key)
	if value == "" {
		return defaultValue
	}

	intValue, err := strconv.Atoi(value)
	if err != nil {
		return defaultValue
	}
	return intValue
}

// CustomTransport is a custom RoundTripper that adds headers to every request
type CustomTransport struct {
	// Transport performs the actual round trip; http.DefaultTransport is used
	// when it is nil.
	Transport http.RoundTripper
	// Headers are set on every outgoing request, replacing any existing
	// values for the same keys.
	Headers map[string]string
}

// RoundTrip executes a single HTTP transaction and adds custom headers
//
// The headers are applied to a clone of req: the http.RoundTripper contract
// forbids modifying the caller's request.
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := c.Transport
	if base == nil {
		base = http.DefaultTransport
	}

	// Nothing to inject: forward the request untouched and skip the clone.
	if len(c.Headers) == 0 {
		return base.RoundTrip(req)
	}

	outReq := req.Clone(req.Context())
	if outReq.Header == nil {
		// Clone keeps a nil header map nil; writing to it would panic.
		outReq.Header = make(http.Header, len(c.Headers))
	}
	for key, value := range c.Headers {
		outReq.Header.Set(key, value)
	}
	return base.RoundTrip(outReq)
}

// OutputFormat holds a thought, an action and an optional error with their
// JSON field names. In this file it is only referenced by the commented-out
// structured response schema in GetModelConfig.
type OutputFormat struct {
	Thought string `json:"thought"`
	Action  string `json:"action"`
	Error   string `json:"error,omitempty"`
}

// GetModelConfig get OpenAI config
//
// It loads the .env file (once), reads optional default headers from the
// OPENAI_INIT_CONFIG_JSON variable, and requires OPENAI_BASE_URL,
// OPENAI_API_KEY and LLM_MODEL_NAME to be set. The returned config carries an
// HTTP client with defaultTimeout whose transport injects the default
// headers. An error is returned when the JSON config is malformed or any of
// the three required variables is missing.
func GetModelConfig() (*openai.ChatModelConfig, error) {
	loadEnv()

	envConfig := &OpenAIInitConfig{
		Headers: make(map[string]string),
	}

	// read from JSON config first
	if jsonStr := GetEnvConfig(EnvOpenAIInitConfigJSON); jsonStr != "" {
		if err := json.Unmarshal([]byte(jsonStr), envConfig); err != nil {
			return nil, errors.Wrapf(err, "parse env %s", EnvOpenAIInitConfigJSON)
		}
	}

	baseURL := GetEnvConfig(EnvOpenAIBaseURL)
	if baseURL == "" {
		return nil, fmt.Errorf("miss env %s", EnvOpenAIBaseURL)
	}
	apiKey := GetEnvConfig(EnvOpenAIAPIKey)
	if apiKey == "" {
		return nil, fmt.Errorf("miss env %s", EnvOpenAIAPIKey)
	}
	modelName := GetEnvConfig(EnvModelName)
	if modelName == "" {
		return nil, fmt.Errorf("miss env %s", EnvModelName)
	}

	// outputFormatSchema, err := openapi3gen.NewSchemaRefForValue(&OutputFormat{}, nil)
	// if err != nil {
	// 	log.Fatal().Err(err).Msg("NewSchemaRefForValue failed")
	// }

	config := &openai.ChatModelConfig{
		BaseURL: baseURL,
		APIKey:  apiKey,
		Model:   modelName,
		HTTPClient: &http.Client{
			Timeout: defaultTimeout,
			Transport: &CustomTransport{
				Transport: http.DefaultTransport,
				Headers:   envConfig.Headers,
			},
		},
		// TODO: set structured response format
		// https://github.com/cloudwego/eino-ext/blob/main/components/model/openai/examples/structured/structured.go
		// ResponseFormat: &openai2.ChatCompletionResponseFormat{
		// 	Type: openai2.ChatCompletionResponseFormatTypeJSONSchema,
		// 	JSONSchema: &openai2.ChatCompletionResponseFormatJSONSchema{
		// 		Name:        "thought_and_action",
		// 		Description: "data that describes planning thought and action",
		// 		Schema:      outputFormatSchema.Value,
		// 		Strict:      false,
		// 	},
		// },
	}

	// log config info
	log.Info().Str("model", config.Model).
		Str("baseURL", config.BaseURL).
		Str("apiKey", maskAPIKey(config.APIKey)).
		Str("timeout", defaultTimeout.String()).
		Msg("get model config")

	return config, nil
}

// maskAPIKey masks the API key
//
// Keys of 8 bytes or fewer are fully replaced by "******"; longer keys keep
// their first and last 4 bytes around the mask.
func maskAPIKey(key string) string {
	if len(key) <= 8 {
		return "******"
	}
	return key[:4] + "******" + key[len(key)-4:]
}