mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-07 00:19:58 +08:00
feat: add configuration flags and enhance config initialization
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,6 +14,10 @@ var rootCmd = &cobra.Command{
|
|||||||
Run: Run,
|
Run: Run,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
config.RegisterFlags(rootCmd)
|
||||||
|
}
|
||||||
|
|
||||||
func Execute(ctx context.Context) {
|
func Execute(ctx context.Context) {
|
||||||
if err := rootCmd.ExecuteContext(ctx); err != nil {
|
if err := rootCmd.ExecuteContext(ctx); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func Run(cmd *cobra.Command, _ []string) {
|
|||||||
})
|
})
|
||||||
ctx = log.WithContext(ctx, logger)
|
ctx = log.WithContext(ctx, logger)
|
||||||
|
|
||||||
exitChan, err := initAll(ctx)
|
exitChan, err := initAll(ctx, cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Init failed", "error", err)
|
logger.Fatal("Init failed", "error", err)
|
||||||
}
|
}
|
||||||
@@ -51,8 +51,9 @@ func Run(cmd *cobra.Command, _ []string) {
|
|||||||
cleanCache()
|
cleanCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
func initAll(ctx context.Context) (<-chan struct{}, error) {
|
func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) {
|
||||||
if err := config.Init(ctx); err != nil {
|
configFile := config.GetConfigFile(cmd)
|
||||||
|
if err := config.Init(ctx, configFile); err != nil {
|
||||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
return nil, fmt.Errorf("failed to load config: %w", err)
|
||||||
}
|
}
|
||||||
cache.Init()
|
cache.Init()
|
||||||
|
|||||||
83
config/flags.go
Normal file
83
config/flags.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisterFlags(cmd *cobra.Command) {
|
||||||
|
flags := cmd.Flags()
|
||||||
|
|
||||||
|
// 基础配置
|
||||||
|
flags.StringP("config", "c", "", "config file path")
|
||||||
|
flags.StringP("lang", "l", "", "language (e.g., zh-Hans, en)")
|
||||||
|
flags.IntP("workers", "w", 0, "number of workers")
|
||||||
|
flags.Int("retry", 0, "retry times")
|
||||||
|
flags.Int("threads", 0, "number of threads")
|
||||||
|
flags.Bool("stream", false, "enable stream mode")
|
||||||
|
flags.Bool("no-clean-cache", false, "do not clean cache on exit")
|
||||||
|
flags.String("proxy", "", "proxy URL (http, https, socks5, socks5h)")
|
||||||
|
|
||||||
|
// Telegram 配置
|
||||||
|
flags.String("telegram-token", "", "telegram bot token")
|
||||||
|
flags.Int("telegram-app-id", 0, "telegram app id")
|
||||||
|
flags.String("telegram-app-hash", "", "telegram app hash")
|
||||||
|
flags.Int("telegram-rpc-retry", 0, "telegram rpc retry times")
|
||||||
|
flags.Bool("telegram-userbot-enable", false, "enable userbot")
|
||||||
|
flags.String("telegram-userbot-session", "", "userbot session path")
|
||||||
|
flags.Bool("telegram-proxy-enable", false, "enable telegram proxy")
|
||||||
|
flags.String("telegram-proxy-url", "", "telegram proxy URL")
|
||||||
|
|
||||||
|
// 数据库配置
|
||||||
|
flags.String("db-path", "", "database path")
|
||||||
|
flags.String("db-session", "", "session database path")
|
||||||
|
|
||||||
|
// 临时目录配置
|
||||||
|
flags.String("temp-base-path", "", "temp directory base path")
|
||||||
|
|
||||||
|
// Parser 配置
|
||||||
|
flags.Bool("parser-plugin-enable", false, "enable parser plugins")
|
||||||
|
flags.StringSlice("parser-plugin-dirs", nil, "parser plugin directories")
|
||||||
|
flags.String("parser-proxy", "", "parser proxy URL")
|
||||||
|
|
||||||
|
// 绑定到 viper
|
||||||
|
bindFlags(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindFlags(cmd *cobra.Command) {
|
||||||
|
flags := cmd.Flags()
|
||||||
|
|
||||||
|
viper.BindPFlag("lang", flags.Lookup("lang"))
|
||||||
|
viper.BindPFlag("workers", flags.Lookup("workers"))
|
||||||
|
viper.BindPFlag("retry", flags.Lookup("retry"))
|
||||||
|
viper.BindPFlag("threads", flags.Lookup("threads"))
|
||||||
|
viper.BindPFlag("stream", flags.Lookup("stream"))
|
||||||
|
viper.BindPFlag("no_clean_cache", flags.Lookup("no-clean-cache"))
|
||||||
|
viper.BindPFlag("proxy", flags.Lookup("proxy"))
|
||||||
|
|
||||||
|
// Telegram
|
||||||
|
viper.BindPFlag("telegram.token", flags.Lookup("telegram-token"))
|
||||||
|
viper.BindPFlag("telegram.app_id", flags.Lookup("telegram-app-id"))
|
||||||
|
viper.BindPFlag("telegram.app_hash", flags.Lookup("telegram-app-hash"))
|
||||||
|
viper.BindPFlag("telegram.rpc_retry", flags.Lookup("telegram-rpc-retry"))
|
||||||
|
viper.BindPFlag("telegram.userbot.enable", flags.Lookup("telegram-userbot-enable"))
|
||||||
|
viper.BindPFlag("telegram.userbot.session", flags.Lookup("telegram-userbot-session"))
|
||||||
|
viper.BindPFlag("telegram.proxy.enable", flags.Lookup("telegram-proxy-enable"))
|
||||||
|
viper.BindPFlag("telegram.proxy.url", flags.Lookup("telegram-proxy-url"))
|
||||||
|
|
||||||
|
// database
|
||||||
|
viper.BindPFlag("db.path", flags.Lookup("db-path"))
|
||||||
|
viper.BindPFlag("db.session", flags.Lookup("db-session"))
|
||||||
|
// 临时目录
|
||||||
|
viper.BindPFlag("temp.base_path", flags.Lookup("temp-base-path"))
|
||||||
|
|
||||||
|
// Parser
|
||||||
|
viper.BindPFlag("parser.plugin_enable", flags.Lookup("parser-plugin-enable"))
|
||||||
|
viper.BindPFlag("parser.plugin_dirs", flags.Lookup("parser-plugin-dirs"))
|
||||||
|
viper.BindPFlag("parser.proxy", flags.Lookup("parser-proxy"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConfigFile(cmd *cobra.Command) string {
|
||||||
|
configFile, _ := cmd.Flags().GetString("config")
|
||||||
|
return configFile
|
||||||
|
}
|
||||||
@@ -52,16 +52,39 @@ func (c Config) GetStorageByName(name string) storage.StorageConfig {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Init(ctx context.Context) error {
|
func Init(ctx context.Context, configFile ...string) error {
|
||||||
viper.SetConfigName("config")
|
|
||||||
viper.AddConfigPath(".")
|
|
||||||
viper.AddConfigPath("/etc/saveany/")
|
|
||||||
viper.SetConfigType("toml")
|
viper.SetConfigType("toml")
|
||||||
viper.SetEnvPrefix("SAVEANY")
|
viper.SetEnvPrefix("SAVEANY")
|
||||||
viper.AutomaticEnv()
|
viper.AutomaticEnv()
|
||||||
replacer := strings.NewReplacer(".", "_")
|
replacer := strings.NewReplacer(".", "_")
|
||||||
viper.SetEnvKeyReplacer(replacer)
|
viper.SetEnvKeyReplacer(replacer)
|
||||||
|
|
||||||
|
// 如果指定了配置文件路径,则使用指定的配置文件
|
||||||
|
// 配置文件支持传入一个 http(s) URL 地址
|
||||||
|
if len(configFile) > 0 && configFile[0] != "" {
|
||||||
|
cfg := configFile[0]
|
||||||
|
if strings.HasPrefix(cfg, "http://") || strings.HasPrefix(cfg, "https://") {
|
||||||
|
// 使用远程配置文件
|
||||||
|
resp, err := http.Get(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch remote config file: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("failed to fetch remote config file: status code %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if err := viper.ReadConfig(resp.Body); err != nil {
|
||||||
|
return fmt.Errorf("failed to read remote config file: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
viper.SetConfigFile(cfg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
viper.SetConfigName("config")
|
||||||
|
viper.AddConfigPath(".")
|
||||||
|
viper.AddConfigPath("/etc/saveany/")
|
||||||
|
}
|
||||||
|
|
||||||
defaultConfigs := map[string]any{
|
defaultConfigs := map[string]any{
|
||||||
// 基础配置
|
// 基础配置
|
||||||
"lang": "zh-Hans",
|
"lang": "zh-Hans",
|
||||||
@@ -125,13 +148,6 @@ func Init(ctx context.Context) error {
|
|||||||
storageNames[storage.GetName()] = struct{}{}
|
storageNames[storage.GetName()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(i18n.TWithoutInit(cfg.Lang, i18nk.ConfigLoadedStorages, map[string]any{
|
|
||||||
"Count": len(cfg.Storages),
|
|
||||||
}))
|
|
||||||
for _, storage := range cfg.Storages {
|
|
||||||
fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType())
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Workers < 1 {
|
if cfg.Workers < 1 {
|
||||||
cfg.Workers = 1
|
cfg.Workers = 1
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user