diff --git a/cmd/root.go b/cmd/root.go index 4f4f334..99957a3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/krau/SaveAny-Bot/config" "github.com/spf13/cobra" ) @@ -13,6 +14,10 @@ var rootCmd = &cobra.Command{ Run: Run, } +func init() { + config.RegisterFlags(rootCmd) +} + func Execute(ctx context.Context) { if err := rootCmd.ExecuteContext(ctx); err != nil { fmt.Println(err) diff --git a/cmd/run.go b/cmd/run.go index cf99073..3852610 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -34,7 +34,7 @@ func Run(cmd *cobra.Command, _ []string) { }) ctx = log.WithContext(ctx, logger) - exitChan, err := initAll(ctx) + exitChan, err := initAll(ctx, cmd) if err != nil { logger.Fatal("Init failed", "error", err) } @@ -51,8 +51,9 @@ func Run(cmd *cobra.Command, _ []string) { cleanCache() } -func initAll(ctx context.Context) (<-chan struct{}, error) { - if err := config.Init(ctx); err != nil { +func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) { + configFile := config.GetConfigFile(cmd) + if err := config.Init(ctx, configFile); err != nil { return nil, fmt.Errorf("failed to load config: %w", err) } cache.Init() diff --git a/config/flags.go b/config/flags.go new file mode 100644 index 0000000..3320649 --- /dev/null +++ b/config/flags.go @@ -0,0 +1,83 @@ +package config + +import ( + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func RegisterFlags(cmd *cobra.Command) { + flags := cmd.Flags() + + // 基础配置 + flags.StringP("config", "c", "", "config file path") + flags.StringP("lang", "l", "", "language (e.g., zh-Hans, en)") + flags.IntP("workers", "w", 0, "number of workers") + flags.Int("retry", 0, "retry times") + flags.Int("threads", 0, "number of threads") + flags.Bool("stream", false, "enable stream mode") + flags.Bool("no-clean-cache", false, "do not clean cache on exit") + flags.String("proxy", "", "proxy URL (http, https, socks5, socks5h)") + + // Telegram 配置 + flags.String("telegram-token", "", "telegram bot token") + flags.Int("telegram-app-id", 0, "telegram app id") + flags.String("telegram-app-hash", "", "telegram app hash") + flags.Int("telegram-rpc-retry", 0, "telegram rpc retry times") + flags.Bool("telegram-userbot-enable", false, "enable userbot") + flags.String("telegram-userbot-session", "", "userbot session path") + flags.Bool("telegram-proxy-enable", false, "enable telegram proxy") + flags.String("telegram-proxy-url", "", "telegram proxy URL") + + // 数据库配置 + flags.String("db-path", "", "database path") + flags.String("db-session", "", "session database path") + + // 临时目录配置 + flags.String("temp-base-path", "", "temp directory base path") + + // Parser 配置 + flags.Bool("parser-plugin-enable", false, "enable parser plugins") + flags.StringSlice("parser-plugin-dirs", nil, "parser plugin directories") + flags.String("parser-proxy", "", "parser proxy URL") + + // 绑定到 viper + bindFlags(cmd) +} + +func bindFlags(cmd *cobra.Command) { + flags := cmd.Flags() + + viper.BindPFlag("lang", flags.Lookup("lang")) + viper.BindPFlag("workers", flags.Lookup("workers")) + viper.BindPFlag("retry", flags.Lookup("retry")) + viper.BindPFlag("threads", flags.Lookup("threads")) + viper.BindPFlag("stream", flags.Lookup("stream")) + viper.BindPFlag("no_clean_cache", flags.Lookup("no-clean-cache")) + viper.BindPFlag("proxy", flags.Lookup("proxy")) + + // Telegram + viper.BindPFlag("telegram.token", flags.Lookup("telegram-token")) + viper.BindPFlag("telegram.app_id", flags.Lookup("telegram-app-id")) + viper.BindPFlag("telegram.app_hash", flags.Lookup("telegram-app-hash")) + viper.BindPFlag("telegram.rpc_retry", flags.Lookup("telegram-rpc-retry")) + viper.BindPFlag("telegram.userbot.enable", flags.Lookup("telegram-userbot-enable")) + viper.BindPFlag("telegram.userbot.session", flags.Lookup("telegram-userbot-session")) + viper.BindPFlag("telegram.proxy.enable", flags.Lookup("telegram-proxy-enable")) + viper.BindPFlag("telegram.proxy.url", flags.Lookup("telegram-proxy-url")) + + // database + viper.BindPFlag("db.path", flags.Lookup("db-path")) + viper.BindPFlag("db.session", flags.Lookup("db-session")) + // 临时目录 + viper.BindPFlag("temp.base_path", flags.Lookup("temp-base-path")) + + // Parser + viper.BindPFlag("parser.plugin_enable", flags.Lookup("parser-plugin-enable")) + viper.BindPFlag("parser.plugin_dirs", flags.Lookup("parser-plugin-dirs")) + viper.BindPFlag("parser.proxy", flags.Lookup("parser-proxy")) +} + +func GetConfigFile(cmd *cobra.Command) string { + configFile, _ := cmd.Flags().GetString("config") + return configFile +} diff --git a/config/viper.go b/config/viper.go index f66cebb..1b3fc89 100644 --- a/config/viper.go +++ b/config/viper.go @@ -52,16 +52,39 @@ func (c Config) GetStorageByName(name string) storage.StorageConfig { return nil } -func Init(ctx context.Context) error { - viper.SetConfigName("config") - viper.AddConfigPath(".") - viper.AddConfigPath("/etc/saveany/") +func Init(ctx context.Context, configFile ...string) error { viper.SetConfigType("toml") viper.SetEnvPrefix("SAVEANY") viper.AutomaticEnv() replacer := strings.NewReplacer(".", "_") viper.SetEnvKeyReplacer(replacer) + // 如果指定了配置文件路径,则使用指定的配置文件 + // 配置文件支持传入一个 http(s) URL 地址 + if len(configFile) > 0 && configFile[0] != "" { + cfg := configFile[0] + if strings.HasPrefix(cfg, "http://") || strings.HasPrefix(cfg, "https://") { + // 使用远程配置文件 + resp, err := http.Get(cfg) + if err != nil { + return fmt.Errorf("failed to fetch remote config file: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to fetch remote config file: status code %d", resp.StatusCode) + } + if err := viper.ReadConfig(resp.Body); err != nil { + return fmt.Errorf("failed to read remote config file: %w", err) + } + } else { + viper.SetConfigFile(cfg) + } + } else { + viper.SetConfigName("config") + viper.AddConfigPath(".") + viper.AddConfigPath("/etc/saveany/") + } + defaultConfigs := map[string]any{ // 基础配置 "lang": "zh-Hans", @@ -125,13 +148,6 @@ func Init(ctx context.Context) error { storageNames[storage.GetName()] = struct{}{} } - fmt.Println(i18n.TWithoutInit(cfg.Lang, i18nk.ConfigLoadedStorages, map[string]any{ - "Count": len(cfg.Storages), - })) - for _, storage := range cfg.Storages { - fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType()) - } - if cfg.Workers < 1 { cfg.Workers = 1 }