feat: add configuration flags and enhance config initialization

This commit is contained in:
krau
2025-12-19 13:49:13 +08:00
parent d3cc56c8e6
commit df64ec3069
4 changed files with 119 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/krau/SaveAny-Bot/config"
"github.com/spf13/cobra"
)
@@ -13,6 +14,10 @@ var rootCmd = &cobra.Command{
Run: Run,
}
func init() {
config.RegisterFlags(rootCmd)
}
func Execute(ctx context.Context) {
if err := rootCmd.ExecuteContext(ctx); err != nil {
fmt.Println(err)

View File

@@ -34,7 +34,7 @@ func Run(cmd *cobra.Command, _ []string) {
})
ctx = log.WithContext(ctx, logger)
exitChan, err := initAll(ctx)
exitChan, err := initAll(ctx, cmd)
if err != nil {
logger.Fatal("Init failed", "error", err)
}
@@ -51,8 +51,9 @@ func Run(cmd *cobra.Command, _ []string) {
cleanCache()
}
func initAll(ctx context.Context) (<-chan struct{}, error) {
if err := config.Init(ctx); err != nil {
func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) {
configFile := config.GetConfigFile(cmd)
if err := config.Init(ctx, configFile); err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
cache.Init()

83
config/flags.go Normal file
View File

@@ -0,0 +1,83 @@
package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
func RegisterFlags(cmd *cobra.Command) {
flags := cmd.Flags()
// 基础配置
flags.StringP("config", "c", "", "config file path")
flags.StringP("lang", "l", "", "language (e.g., zh-Hans, en)")
flags.IntP("workers", "w", 0, "number of workers")
flags.Int("retry", 0, "retry times")
flags.Int("threads", 0, "number of threads")
flags.Bool("stream", false, "enable stream mode")
flags.Bool("no-clean-cache", false, "do not clean cache on exit")
flags.String("proxy", "", "proxy URL (http, https, socks5, socks5h)")
// Telegram 配置
flags.String("telegram-token", "", "telegram bot token")
flags.Int("telegram-app-id", 0, "telegram app id")
flags.String("telegram-app-hash", "", "telegram app hash")
flags.Int("telegram-rpc-retry", 0, "telegram rpc retry times")
flags.Bool("telegram-userbot-enable", false, "enable userbot")
flags.String("telegram-userbot-session", "", "userbot session path")
flags.Bool("telegram-proxy-enable", false, "enable telegram proxy")
flags.String("telegram-proxy-url", "", "telegram proxy URL")
// 数据库配置
flags.String("db-path", "", "database path")
flags.String("db-session", "", "session database path")
// 临时目录配置
flags.String("temp-base-path", "", "temp directory base path")
// Parser 配置
flags.Bool("parser-plugin-enable", false, "enable parser plugins")
flags.StringSlice("parser-plugin-dirs", nil, "parser plugin directories")
flags.String("parser-proxy", "", "parser proxy URL")
// 绑定到 viper
bindFlags(cmd)
}
func bindFlags(cmd *cobra.Command) {
flags := cmd.Flags()
viper.BindPFlag("lang", flags.Lookup("lang"))
viper.BindPFlag("workers", flags.Lookup("workers"))
viper.BindPFlag("retry", flags.Lookup("retry"))
viper.BindPFlag("threads", flags.Lookup("threads"))
viper.BindPFlag("stream", flags.Lookup("stream"))
viper.BindPFlag("no_clean_cache", flags.Lookup("no-clean-cache"))
viper.BindPFlag("proxy", flags.Lookup("proxy"))
// Telegram
viper.BindPFlag("telegram.token", flags.Lookup("telegram-token"))
viper.BindPFlag("telegram.app_id", flags.Lookup("telegram-app-id"))
viper.BindPFlag("telegram.app_hash", flags.Lookup("telegram-app-hash"))
viper.BindPFlag("telegram.rpc_retry", flags.Lookup("telegram-rpc-retry"))
viper.BindPFlag("telegram.userbot.enable", flags.Lookup("telegram-userbot-enable"))
viper.BindPFlag("telegram.userbot.session", flags.Lookup("telegram-userbot-session"))
viper.BindPFlag("telegram.proxy.enable", flags.Lookup("telegram-proxy-enable"))
viper.BindPFlag("telegram.proxy.url", flags.Lookup("telegram-proxy-url"))
// database
viper.BindPFlag("db.path", flags.Lookup("db-path"))
viper.BindPFlag("db.session", flags.Lookup("db-session"))
// 临时目录
viper.BindPFlag("temp.base_path", flags.Lookup("temp-base-path"))
// Parser
viper.BindPFlag("parser.plugin_enable", flags.Lookup("parser-plugin-enable"))
viper.BindPFlag("parser.plugin_dirs", flags.Lookup("parser-plugin-dirs"))
viper.BindPFlag("parser.proxy", flags.Lookup("parser-proxy"))
}
func GetConfigFile(cmd *cobra.Command) string {
configFile, _ := cmd.Flags().GetString("config")
return configFile
}

View File

@@ -52,16 +52,39 @@ func (c Config) GetStorageByName(name string) storage.StorageConfig {
return nil
}
func Init(ctx context.Context) error {
viper.SetConfigName("config")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/saveany/")
func Init(ctx context.Context, configFile ...string) error {
viper.SetConfigType("toml")
viper.SetEnvPrefix("SAVEANY")
viper.AutomaticEnv()
replacer := strings.NewReplacer(".", "_")
viper.SetEnvKeyReplacer(replacer)
// 如果指定了配置文件路径,则使用指定的配置文件
// 配置文件支持传入一个 http(s) URL 地址
if len(configFile) > 0 && configFile[0] != "" {
cfg := configFile[0]
if strings.HasPrefix(cfg, "http://") || strings.HasPrefix(cfg, "https://") {
// 使用远程配置文件
resp, err := http.Get(cfg)
if err != nil {
return fmt.Errorf("failed to fetch remote config file: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to fetch remote config file: status code %d", resp.StatusCode)
}
if err := viper.ReadConfig(resp.Body); err != nil {
return fmt.Errorf("failed to read remote config file: %w", err)
}
} else {
viper.SetConfigFile(cfg)
}
} else {
viper.SetConfigName("config")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/saveany/")
}
defaultConfigs := map[string]any{
// 基础配置
"lang": "zh-Hans",
@@ -125,13 +148,6 @@ func Init(ctx context.Context) error {
storageNames[storage.GetName()] = struct{}{}
}
fmt.Println(i18n.TWithoutInit(cfg.Lang, i18nk.ConfigLoadedStorages, map[string]any{
"Count": len(cfg.Storages),
}))
for _, storage := range cfg.Storages {
fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType())
}
if cfg.Workers < 1 {
cfg.Workers = 1
}