From e5d1e143e0c7f3620c612d87bc31f5180d00d17e Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Sat, 23 Aug 2025 14:29:32 +0800 Subject: [PATCH] feat: configurable parser and refactor config --- client/bot/bot.go | 16 +++--- client/bot/handlers/middleware.go | 2 +- client/bot/handlers/register.go | 2 +- client/bot/handlers/utils/shortcut/message.go | 2 +- client/middleware/default.go | 2 +- client/user/userclient.go | 12 ++--- cmd/run.go | 18 +++---- common/cache/ristretto.go | 6 +-- common/utils/netutil/proxy.go | 39 +++++++++++++++ common/utils/tphutil/tph.go | 4 +- config/parser.go | 9 ++++ config/user.go | 6 +-- config/viper.go | 49 +++++++------------ core/core.go | 6 +-- core/tasks/batchtfile/execute.go | 4 +- core/tasks/batchtfile/task.go | 4 +- core/tasks/parsed/execute.go | 6 +-- core/tasks/parsed/task.go | 2 +- core/tasks/telegraph/execute.go | 6 +-- core/tasks/tfile/execute.go | 4 +- core/tasks/tfile/tftask.go | 4 +- database/db.go | 6 +-- parsers/parser.go | 25 +++++++--- parsers/twitter/parser.go | 30 ++++++++++-- pkg/parser/parser.go | 8 ++- pkg/tfile/dler.go | 2 +- storage/load.go | 10 ++-- storage/telegram/telegram.go | 2 +- 28 files changed, 181 insertions(+), 105 deletions(-) diff --git a/client/bot/bot.go b/client/bot/bot.go index 7dffb68..b712c1a 100644 --- a/client/bot/bot.go +++ b/client/bot/bot.go @@ -27,8 +27,8 @@ func Init(ctx context.Context) { }) go func() { var resolver dcs.Resolver - if config.Cfg.Telegram.Proxy.Enable && config.Cfg.Telegram.Proxy.URL != "" { - dialer, err := netutil.NewProxyDialer(config.Cfg.Telegram.Proxy.URL) + if config.C().Telegram.Proxy.Enable && config.C().Telegram.Proxy.URL != "" { + dialer, err := netutil.NewProxyDialer(config.C().Telegram.Proxy.URL) if err != nil { resultChan <- struct { client *gotgproto.Client @@ -43,16 +43,16 @@ func Init(ctx context.Context) { resolver = dcs.DefaultResolver() } client, err := gotgproto.NewClient( - config.Cfg.Telegram.AppID, - config.Cfg.Telegram.AppHash, - gotgproto.ClientTypeBot(config.Cfg.Telegram.Token), + config.C().Telegram.AppID, + config.C().Telegram.AppHash, + gotgproto.ClientTypeBot(config.C().Telegram.Token), &gotgproto.ClientOpts{ - Session: sessionMaker.SqlSession(gormlite.Open(config.Cfg.DB.Session)), + Session: sessionMaker.SqlSession(gormlite.Open(config.C().DB.Session)), DisableCopyright: true, Middlewares: middleware.NewDefaultMiddlewares(ctx, 5*time.Minute), Resolver: resolver, Context: ctx, - MaxRetries: config.Cfg.Telegram.RpcRetry, + MaxRetries: config.C().Telegram.RpcRetry, AutoFetchReply: true, ErrorHandler: func(ctx *ext.Context, u *ext.Update, s string) error { log.FromContext(ctx).Errorf("Unhandled error: %s", s) @@ -79,7 +79,7 @@ func Init(ctx context.Context) { {Command: "dir", Description: "管理存储文件夹"}, {Command: "rule", Description: "管理规则"}, } - if config.Cfg.Telegram.Userbot.Enable { + if config.C().Telegram.Userbot.Enable { commands = append(commands, tg.BotCommand{Command: "watch", Description: "监听聊天"}) commands = append(commands, tg.BotCommand{Command: "unwatch", Description: "取消监听聊天"}) } diff --git a/client/bot/handlers/middleware.go b/client/bot/handlers/middleware.go index 98945d6..c00e73f 100644 --- a/client/bot/handlers/middleware.go +++ b/client/bot/handlers/middleware.go @@ -11,7 +11,7 @@ import ( func checkPermission(ctx *ext.Context, update *ext.Update) error { userID := update.GetUserChat().GetID() - if !slice.Contain(config.Cfg.GetUsersID(), userID) { + if !slice.Contain(config.C().GetUsersID(), userID) { const noPermissionText string = ` 您不在白名单中, 无法使用此 Bot. 您可以部署自己的实例: https://github.com/krau/SaveAny-Bot diff --git a/client/bot/handlers/register.go b/client/bot/handlers/register.go index c4243f7..2c043fb 100644 --- a/client/bot/handlers/register.go +++ b/client/bot/handlers/register.go @@ -56,7 +56,7 @@ func Register(disp dispatcher.Dispatcher) { disp.AddHandler(handlers.NewMessage(filters.Message.Media, handleSilentMode(handleMediaMessage, handleSilentSaveMedia))) disp.AddHandler(handlers.NewMessage(filters.Message.Text, handleSilentMode(handleTextMessage, handleSilentSaveText))) - if config.Cfg.Telegram.Userbot.Enable { + if config.C().Telegram.Userbot.Enable { go listenMediaMessageEvent(userclient.GetMediaMessageCh()) } } diff --git a/client/bot/handlers/utils/shortcut/message.go b/client/bot/handlers/utils/shortcut/message.go index 1563707..363c00a 100644 --- a/client/bot/handlers/utils/shortcut/message.go +++ b/client/bot/handlers/utils/shortcut/message.go @@ -102,7 +102,7 @@ func GetFilesFromUpdateLinkMessageWithReplyEdit(ctx *ext.Context, update *ext.Up } tctx := ctx - if config.Cfg.Telegram.Userbot.Enable { + if config.C().Telegram.Userbot.Enable { tctx = uc.GetCtx() } diff --git a/client/middleware/default.go b/client/middleware/default.go index 90dfa14..8e6dcaf 100644 --- a/client/middleware/default.go +++ b/client/middleware/default.go @@ -16,7 +16,7 @@ import ( func NewDefaultMiddlewares(ctx context.Context, timeout time.Duration) []telegram.Middleware { return []telegram.Middleware{ recovery.New(ctx, newBackoff(timeout)), - retry.New(config.Cfg.Telegram.RpcRetry), + retry.New(config.C().Telegram.RpcRetry), floodwait.NewSimpleWaiter(), } } diff --git a/client/user/userclient.go b/client/user/userclient.go index 1920e74..cfc5349 100644 --- a/client/user/userclient.go +++ b/client/user/userclient.go @@ -54,8 +54,8 @@ func Login(ctx context.Context) (*gotgproto.Client, error) { }) go func() { var resolver dcs.Resolver - if config.Cfg.Telegram.Proxy.Enable && config.Cfg.Telegram.Proxy.URL != "" { - dialer, err := netutil.NewProxyDialer(config.Cfg.Telegram.Proxy.URL) + if config.C().Telegram.Proxy.Enable && config.C().Telegram.Proxy.URL != "" { + dialer, err := netutil.NewProxyDialer(config.C().Telegram.Proxy.URL) if err != nil { res <- struct { client *gotgproto.Client @@ -70,16 +70,16 @@ func Login(ctx context.Context) (*gotgproto.Client, error) { resolver = dcs.DefaultResolver() } tclient, err := gotgproto.NewClient( - config.Cfg.Telegram.AppID, - config.Cfg.Telegram.AppHash, + config.C().Telegram.AppID, + config.C().Telegram.AppHash, gotgproto.ClientTypePhone(""), &gotgproto.ClientOpts{ - Session: sessionMaker.SqlSession(gormlite.Open(config.Cfg.Telegram.Userbot.Session)), + Session: sessionMaker.SqlSession(gormlite.Open(config.C().Telegram.Userbot.Session)), AuthConversator: &terminalAuthConversator{}, Context: ctx, DisableCopyright: true, Resolver: resolver, - MaxRetries: config.Cfg.Telegram.RpcRetry, + MaxRetries: config.C().Telegram.RpcRetry, AutoFetchReply: true, Middlewares: middleware.NewDefaultMiddlewares(ctx, 5*time.Minute), ErrorHandler: func(ctx *ext.Context, u *ext.Update, s string) error { diff --git a/cmd/run.go b/cmd/run.go index 4059f23..024a94d 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -50,12 +50,12 @@ func initAll(ctx context.Context) { } cache.Init() logger := log.FromContext(ctx) - i18n.Init(config.Cfg.Lang) + i18n.Init(config.C().Lang) logger.Info(i18n.T(i18nk.Initing)) database.Init(ctx) storage.LoadStorages(ctx) - if config.Cfg.Parser.PluginEnable { - for _, dir := range config.Cfg.Parser.PluginDirs { + if config.C().Parser.PluginEnable { + for _, dir := range config.C().Parser.PluginDirs { if err := parsers.LoadPlugins(ctx, dir); err != nil { logger.Error("Failed to load parser plugins", "dir", dir, "error", err) } else { @@ -63,7 +63,7 @@ func initAll(ctx context.Context) { } } } - if config.Cfg.Telegram.Userbot.Enable { + if config.C().Telegram.Userbot.Enable { _, err := userclient.Login(ctx) if err != nil { logger.Fatalf("User client login failed: %s", err) @@ -73,13 +73,13 @@ func initAll(ctx context.Context) { } func cleanCache() { - if config.Cfg.NoCleanCache { + if config.C().NoCleanCache { return } - if config.Cfg.Temp.BasePath != "" && !config.Cfg.Stream { - if slices.Contains([]string{"/", ".", "\\", ".."}, filepath.Clean(config.Cfg.Temp.BasePath)) { + if config.C().Temp.BasePath != "" && !config.C().Stream { + if slices.Contains([]string{"/", ".", "\\", ".."}, filepath.Clean(config.C().Temp.BasePath)) { log.Error(i18n.T(i18nk.InvalidCacheDir, map[string]any{ - "Path": config.Cfg.Temp.BasePath, + "Path": config.C().Temp.BasePath, })) return } @@ -90,7 +90,7 @@ func cleanCache() { })) return } - cachePath := filepath.Join(currentDir, config.Cfg.Temp.BasePath) + cachePath := filepath.Join(currentDir, config.C().Temp.BasePath) cachePath, err = filepath.Abs(cachePath) if err != nil { log.Error(i18n.T(i18nk.GetCacheAbsPathFailed, map[string]any{ diff --git a/common/cache/ristretto.go b/common/cache/ristretto.go index 507e7be..5bacbf4 100644 --- a/common/cache/ristretto.go +++ b/common/cache/ristretto.go @@ -16,8 +16,8 @@ func Init() { panic("cache already initialized") } c, err := ristretto.NewCache(&ristretto.Config[string, any]{ - NumCounters: config.Cfg.Cache.NumCounters, - MaxCost: config.Cfg.Cache.MaxCost, + NumCounters: config.C().Cache.NumCounters, + MaxCost: config.C().Cache.MaxCost, BufferItems: 64, OnReject: func(item *ristretto.Item[any]) { log.Warnf("Cache item rejected: key=%d, value=%v", item.Key, item.Value) @@ -30,7 +30,7 @@ func Init() { } func Set(key string, value any) error { - ok := cache.SetWithTTL(key, value, 0, time.Duration(config.Cfg.Cache.TTL)*time.Second) + ok := cache.SetWithTTL(key, value, 0, time.Duration(config.C().Cache.TTL)*time.Second) if !ok { return fmt.Errorf("failed to set value in cache") } diff --git a/common/utils/netutil/proxy.go b/common/utils/netutil/proxy.go index 6833fe4..038702f 100644 --- a/common/utils/netutil/proxy.go +++ b/common/utils/netutil/proxy.go @@ -1,6 +1,10 @@ package netutil import ( + "context" + "fmt" + "net" + "net/http" "net/url" "golang.org/x/net/proxy" @@ -13,3 +17,38 @@ func NewProxyDialer(proxyUrl string) (proxy.Dialer, error) { } return proxy.FromURL(url, proxy.Direct) } + +func NewProxyHTTPClient(proxyUrl string) (*http.Client, error) { + if proxyUrl == "" { + return http.DefaultClient, nil + } + + u, err := url.Parse(proxyUrl) + if err != nil { + return nil, err + } + + switch u.Scheme { + case "http", "https": + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(u), + }, + }, nil + case "socks5": + dialer, err := proxy.SOCKS5("tcp", u.Host, nil, proxy.Direct) + if err != nil { + return nil, err + } + + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported proxy scheme: %s", u.Scheme) + } +} diff --git a/common/utils/tphutil/tph.go b/common/utils/tphutil/tph.go index 86df916..cb907c7 100644 --- a/common/utils/tphutil/tph.go +++ b/common/utils/tphutil/tph.go @@ -13,8 +13,8 @@ func DefaultClient() *telegraph.Client { if tphClient != nil { return tphClient } - if config.Cfg.Telegram.Proxy.Enable && config.Cfg.Telegram.Proxy.URL != "" { - proxyUrl := config.Cfg.Telegram.Proxy.URL + if config.C().Telegram.Proxy.Enable && config.C().Telegram.Proxy.URL != "" { + proxyUrl := config.C().Telegram.Proxy.URL var err error tphClient, err = telegraph.NewClientWithProxy(proxyUrl) if err != nil { diff --git a/config/parser.go b/config/parser.go index 4d5cfc7..cce3ed1 100644 --- a/config/parser.go +++ b/config/parser.go @@ -3,4 +3,13 @@ package config type parserConfig struct { PluginEnable bool `toml:"plugin_enable" mapstructure:"plugin_enable" json:"plugin_enable"` PluginDirs []string `toml:"plugin_dirs" mapstructure:"plugin_dirs" json:"plugin_dirs"` + + ParserCfgs map[string]map[string]any `mapstructure:",remain"` +} + +func (c Config) GetParserConfigByName(name string) map[string]any { + if c.Parser.ParserCfgs == nil { + return nil + } + return c.Parser.ParserCfgs[name] } diff --git a/config/user.go b/config/user.go index 3153ed9..ac97e0e 100644 --- a/config/user.go +++ b/config/user.go @@ -14,7 +14,7 @@ var userIDs []int64 var storages []string var userStorages = make(map[int64][]string) -func (c *Config) GetStorageNamesByUserID(userID int64) []string { +func (c Config) GetStorageNamesByUserID(userID int64) []string { us, ok := userStorages[userID] if ok { return us @@ -22,11 +22,11 @@ func (c *Config) GetStorageNamesByUserID(userID int64) []string { return nil } -func (c *Config) GetUsersID() []int64 { +func (c Config) GetUsersID() []int64 { return userIDs } -func (c *Config) HasStorage(userID int64, storageName string) bool { +func (c Config) HasStorage(userID int64, storageName string) bool { us, ok := userStorages[userID] if !ok { return false diff --git a/config/viper.go b/config/viper.go index df8dde1..eecf6a1 100644 --- a/config/viper.go +++ b/config/viper.go @@ -32,7 +32,11 @@ type Config struct { Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"` } -var Cfg *Config = &Config{} +var cfg = &Config{} + +func C() Config { + return *cfg +} func (c Config) GetStorageByName(name string) storage.StorageConfig { for _, storage := range c.Storages { @@ -95,7 +99,7 @@ func Init(ctx context.Context) error { os.Exit(1) } - if err := viper.Unmarshal(Cfg); err != nil { + if err := viper.Unmarshal(cfg); err != nil { fmt.Println("Error unmarshalling config file, ", err) os.Exit(1) } @@ -104,36 +108,36 @@ func Init(ctx context.Context) error { if err != nil { return fmt.Errorf("error loading storage configs: %w", err) } - Cfg.Storages = storagesConfig + cfg.Storages = storagesConfig storageNames := make(map[string]struct{}) - for _, storage := range Cfg.Storages { + for _, storage := range cfg.Storages { if _, ok := storageNames[storage.GetName()]; ok { - return errors.New(i18n.TWithoutInit(Cfg.Lang, i18nk.ConfigInvalidDuplicateStorageName, map[string]any{ + return errors.New(i18n.TWithoutInit(cfg.Lang, i18nk.ConfigInvalidDuplicateStorageName, map[string]any{ "Name": storage.GetName(), })) } storageNames[storage.GetName()] = struct{}{} } - fmt.Println(i18n.TWithoutInit(Cfg.Lang, i18nk.LoadedStorages, map[string]any{ - "Count": len(Cfg.Storages), + fmt.Println(i18n.TWithoutInit(cfg.Lang, i18nk.LoadedStorages, map[string]any{ + "Count": len(cfg.Storages), })) - for _, storage := range Cfg.Storages { + for _, storage := range cfg.Storages { fmt.Printf(" - %s (%s)\n", storage.GetName(), storage.GetType()) } - if Cfg.Workers < 1 || Cfg.Retry < 1 { - return errors.New(i18n.TWithoutInit(Cfg.Lang, i18nk.ConfigInvalidWorkersOrRetry, map[string]any{ - "Workers": Cfg.Workers, - "Retry": Cfg.Retry, + if cfg.Workers < 1 || cfg.Retry < 1 { + return errors.New(i18n.TWithoutInit(cfg.Lang, i18nk.ConfigInvalidWorkersOrRetry, map[string]any{ + "Workers": cfg.Workers, + "Retry": cfg.Retry, })) } - for _, storage := range Cfg.Storages { + for _, storage := range cfg.Storages { storages = append(storages, storage.GetName()) } - for _, user := range Cfg.Users { + for _, user := range cfg.Users { userIDs = append(userIDs, user.ID) if user.Blacklist { userStorages[user.ID] = slice.Compact(slice.Difference(storages, user.Storages)) @@ -143,20 +147,3 @@ func Init(ctx context.Context) error { } return nil } - -func Set(key string, value any) { - viper.Set(key, value) -} - -func ReloadConfig() error { - if err := viper.WriteConfig(); err != nil { - return err - } - if err := viper.ReadInConfig(); err != nil { - return err - } - if error := viper.Unmarshal(Cfg); error != nil { - return error - } - return nil -} diff --git a/core/core.go b/core/core.go index 1a3f861..2a312c9 100644 --- a/core/core.go +++ b/core/core.go @@ -20,7 +20,7 @@ type Exectable interface { func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) { logger := log.FromContext(ctx) - execHooks := config.Cfg.Hook.Exec + execHooks := config.C().Hook.Exec for { semaphore <- struct{}{} qtask, err := qe.Get() @@ -58,11 +58,11 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan func Run(ctx context.Context) { log.FromContext(ctx).Info("Start processing tasks...") - semaphore := make(chan struct{}, config.Cfg.Workers) + semaphore := make(chan struct{}, config.C().Workers) if queueInstance == nil { queueInstance = queue.NewTaskQueue[Exectable]() } - for range config.Cfg.Workers { + for range config.C().Workers { go worker(ctx, queueInstance, semaphore) } diff --git a/core/tasks/batchtfile/execute.go b/core/tasks/batchtfile/execute.go index 893b9ac..c6c9be7 100644 --- a/core/tasks/batchtfile/execute.go +++ b/core/tasks/batchtfile/execute.go @@ -21,7 +21,7 @@ func (t *Task) Execute(ctx context.Context) error { logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("batch_file[%s]", t.ID)) logger.Info("Starting batch file task") t.Progress.OnStart(ctx, t) - workers := config.Cfg.Workers + workers := config.C().Workers eg, gctx := errgroup.WithContext(ctx) eg.SetLimit(workers) for _, elem := range t.Elems { @@ -124,6 +124,6 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { return err } return nil - }, retry.Context(vctx), retry.RetryTimes(uint(config.Cfg.Retry))) + }, retry.Context(vctx), retry.RetryTimes(uint(config.C().Retry))) return err } diff --git a/core/tasks/batchtfile/task.go b/core/tasks/batchtfile/task.go index cb7de1a..e4c1326 100644 --- a/core/tasks/batchtfile/task.go +++ b/core/tasks/batchtfile/task.go @@ -47,8 +47,8 @@ func NewTaskElement( ) (*TaskElement, error) { id := xid.New().String() _, ok := stor.(storage.StorageCannotStream) - if !config.Cfg.Stream || ok { - cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) + if !config.C().Stream || ok { + cachePath, err := filepath.Abs(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) if err != nil { return nil, fmt.Errorf("failed to get absolute path for cache: %w", err) } diff --git a/core/tasks/parsed/execute.go b/core/tasks/parsed/execute.go index 899921e..f97e6d9 100644 --- a/core/tasks/parsed/execute.go +++ b/core/tasks/parsed/execute.go @@ -26,7 +26,7 @@ func (t *Task) Execute(ctx context.Context) error { t.progress.OnStart(ctx, t) } eg, gctx := errgroup.WithContext(ctx) - eg.SetLimit(config.Cfg.Workers) + eg.SetLimit(config.C().Workers) for _, resource := range t.item.Resources { eg.Go(func() error { t.processingMu.RLock() @@ -96,7 +96,7 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er if t.stream { return t.Stor.Save(ctx, resp.Body, path.Join(t.StorPath, resource.Filename)) } - cacheFile, err := fsutil.CreateFile(filepath.Join(config.Cfg.Temp.BasePath, + cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("resource_%s_%s", t.ID, resource.Filename))) if err != nil { return fmt.Errorf("failed to create cache file for resource %s: %w", resource.URL, err) @@ -131,7 +131,7 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er return fmt.Errorf("failed to seek cache file for resource %s: %w", resource.URL, err) } return t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, resource.Filename)) - }, retry.Context(ctx), retry.RetryTimes(uint(config.Cfg.Retry))) + }, retry.Context(ctx), retry.RetryTimes(uint(config.C().Retry))) if ctx.Err() != nil { return ctx.Err() } diff --git a/core/tasks/parsed/task.go b/core/tasks/parsed/task.go index e6e221d..b7b52f4 100644 --- a/core/tasks/parsed/task.go +++ b/core/tasks/parsed/task.go @@ -54,7 +54,7 @@ func NewTask( }, } _, ok := stor.(storage.StorageCannotStream) - stream := config.Cfg.Stream && !ok + stream := config.C().Stream && !ok return &Task{ ID: id, Ctx: ctx, diff --git a/core/tasks/telegraph/execute.go b/core/tasks/telegraph/execute.go index d9b1375..72d0dfb 100644 --- a/core/tasks/telegraph/execute.go +++ b/core/tasks/telegraph/execute.go @@ -20,7 +20,7 @@ func (t *Task) Execute(ctx context.Context) error { logger.Infof("Starting Telegraph task %s", t.PhPath) t.progress.OnStart(ctx, t) eg, gctx := errgroup.WithContext(ctx) - eg.SetLimit(config.Cfg.Workers) + eg.SetLimit(config.C().Workers) for i, pic := range t.Pics { eg.Go(func() error { err := t.processPic(gctx, pic, i) @@ -46,7 +46,7 @@ func (t *Task) Execute(ctx context.Context) error { func (t *Task) processPic(ctx context.Context, picUrl string, index int) error { retryOpts := []retry.Option{ retry.Context(ctx), - retry.RetryTimes(uint(config.Cfg.Retry)), + retry.RetryTimes(uint(config.C().Retry)), } var lastErr error err := retry.Retry(func() error { @@ -59,7 +59,7 @@ func (t *Task) processPic(ctx context.Context, picUrl string, index int) error { defer body.Close() filename := fmt.Sprintf("%d%s", index+1, path.Ext(picUrl)) if t.cannotStream { - cacheFile, err := fsutil.CreateFile(filepath.Join(config.Cfg.Temp.BasePath, + cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("tph_%s_%s", t.TaskID(), filename), )) if err != nil { diff --git a/core/tasks/tfile/execute.go b/core/tasks/tfile/execute.go index 8418654..305c028 100644 --- a/core/tasks/tfile/execute.go +++ b/core/tasks/tfile/execute.go @@ -57,7 +57,7 @@ func (t *Task) Execute(ctx context.Context) error { return fmt.Errorf("failed to get file stat: %w", err) } vctx := context.WithValue(ctx, ctxkey.ContentLength, fileStat.Size()) - for i := range config.Cfg.Retry + 1 { + for i := range config.C().Retry + 1 { if err = vctx.Err(); err != nil { return fmt.Errorf("context canceled while saving file: %w", err) } @@ -68,7 +68,7 @@ func (t *Task) Execute(ctx context.Context) error { } defer file.Close() if err = t.Storage.Save(vctx, file, t.Path); err != nil { - if i == config.Cfg.Retry { + if i == config.C().Retry { return fmt.Errorf("failed to save file: %w", err) } logger.Errorf("Failed to save file: %s, retrying...", err) diff --git a/core/tasks/tfile/tftask.go b/core/tasks/tfile/tftask.go index 82718d7..e12d7e4 100644 --- a/core/tasks/tfile/tftask.go +++ b/core/tasks/tfile/tftask.go @@ -35,8 +35,8 @@ func NewTGFileTask( progress ProgressTracker, ) (*Task, error) { _, ok := stor.(storage.StorageCannotStream) - if !config.Cfg.Stream || ok { - cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) + if !config.C().Stream || ok { + cachePath, err := filepath.Abs(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) if err != nil { return nil, fmt.Errorf("failed to get absolute path for cache: %w", err) } diff --git a/database/db.go b/database/db.go index 44da636..5800c5c 100644 --- a/database/db.go +++ b/database/db.go @@ -19,11 +19,11 @@ var db *gorm.DB func Init(ctx context.Context) { logger := log.FromContext(ctx) - if err := os.MkdirAll(filepath.Dir(config.Cfg.DB.Path), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(config.C().DB.Path), 0755); err != nil { logger.Fatal("Failed to create data directory: ", err) } var err error - db, err = gorm.Open(gormlite.Open(config.Cfg.DB.Path), &gorm.Config{ + db, err = gorm.Open(gormlite.Open(config.C().DB.Path), &gorm.Config{ Logger: glogger.New(logger, glogger.Config{ Colorful: true, SlowThreshold: time.Second * 5, @@ -60,7 +60,7 @@ func syncUsers(ctx context.Context) error { } cfgUserMap := make(map[int64]struct{}) - for _, u := range config.Cfg.Users { + for _, u := range config.C().Users { cfgUserMap[u.ID] = struct{}{} } diff --git a/parsers/parser.go b/parsers/parser.go index 052775b..f8961b0 100644 --- a/parsers/parser.go +++ b/parsers/parser.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/parsers/twitter" "github.com/krau/SaveAny-Bot/pkg/parser" ) @@ -12,14 +13,9 @@ import ( var ( parsers []parser.Parser parsersMu sync.Mutex + doConfig sync.Once ) -func GetParsers() []parser.Parser { - parsersMu.Lock() - defer parsersMu.Unlock() - return parsers -} - func AddParser(p ...parser.Parser) { parsersMu.Lock() defer parsersMu.Unlock() @@ -35,6 +31,23 @@ var ( ) func ParseWithContext(ctx context.Context, url string) (*parser.Item, error) { + doConfig.Do(func() { + parsersMu.Lock() + defer parsersMu.Unlock() + if len(parsers) == 0 { + return + } + for _, pser := range parsers { + if configurable, ok := pser.(parser.ConfigurableParser); ok { + cfg := config.C().GetParserConfigByName(configurable.Name()) + if cfg != nil { + if err := configurable.Configure(cfg); err != nil { + fmt.Printf("Error configuring parser %s: %v\n", configurable.Name(), err) + } + } + } + } + }) ch := make(chan *parser.Item, 1) errCh := make(chan error, 1) diff --git a/parsers/twitter/parser.go b/parsers/twitter/parser.go index e3aff3c..bb7d5de 100644 --- a/parsers/twitter/parser.go +++ b/parsers/twitter/parser.go @@ -10,18 +10,20 @@ import ( "regexp" "strings" + "github.com/krau/SaveAny-Bot/common/utils/netutil" "github.com/krau/SaveAny-Bot/pkg/parser" ) type TwitterParser struct { - client http.Client + client http.Client + apiDomain string } const ( - FxTwitterApi = "api.fxtwitter.com" + fxTwitterApi = "api.fxtwitter.com" ) -var _ parser.Parser = (*TwitterParser)(nil) +var _ parser.ConfigurableParser = (*TwitterParser)(nil) var ( twitterSourceURLRegexp *regexp.Regexp = regexp.MustCompile(`(?:twitter|x)\.com/([^/]+)/status/(\d+)`) @@ -40,7 +42,7 @@ func (p *TwitterParser) Parse(ctx context.Context, u string) (*parser.Item, erro if id == "" { return nil, errors.New("invalid Twitter URL") } - apiUrl := fmt.Sprintf("https://%s/_/status/%s", FxTwitterApi, id) + apiUrl := fmt.Sprintf("https://%s/_/status/%s", p.apiDomain, id) req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiUrl, nil) if err != nil { return nil, fmt.Errorf("failed to create request to Twitter API: %w", err) @@ -93,3 +95,23 @@ func (p *TwitterParser) Parse(ctx context.Context, u string) (*parser.Item, erro func (p *TwitterParser) CanHandle(u string) bool { return twitterSourceURLRegexp.MatchString(u) } + +func (p *TwitterParser) Name() string { + return "twitter" +} + +func (p *TwitterParser) Configure(config map[string]any) error { + if domain, ok := config["api_domain"].(string); ok && domain != "" { + p.apiDomain = domain + } else { + p.apiDomain = fxTwitterApi + } + if proxyUrl, ok := config["proxy"].(string); ok && proxyUrl != "" { + proxyClient, err := netutil.NewProxyHTTPClient(proxyUrl) + if err != nil { + return fmt.Errorf("failed to create proxy client: %w", err) + } + p.client = *proxyClient + } + return nil +} diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index ca175fe..5ba8881 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -11,12 +11,18 @@ type Parser interface { Parse(ctx context.Context, url string) (*Item, error) } +type ConfigurableParser interface { + Parser + Configure(config map[string]any) error + Name() string +} + // Resource is a single downloadable resource with metadata. type Resource struct { URL string `json:"url"` Filename string `json:"filename"` // with ext MimeType string `json:"mime_type"` - Extension string `json:"extension"` + Extension string `json:"extension"` // e.g. "mp4" Size int64 `json:"size"` // 0 when unknown Hash map[string]string `json:"hash"` // {"md5": "...", "sha256": "..."} Headers map[string]string `json:"headers"` // HTTP headers when downloading diff --git a/pkg/tfile/dler.go b/pkg/tfile/dler.go index dc03c24..bfe8022 100644 --- a/pkg/tfile/dler.go +++ b/pkg/tfile/dler.go @@ -9,5 +9,5 @@ import ( func NewDownloader(file TGFile) *downloader.Builder { return downloader.NewDownloader().WithPartSize(tglimit.MaxPartSize). - Download(file.Dler(), file.Location()).WithThreads(dlutil.BestThreads(file.Size(), config.Cfg.Threads)) + Download(file.Dler(), file.Location()).WithThreads(dlutil.BestThreads(file.Size(), config.C().Threads)) } diff --git a/storage/load.go b/storage/load.go index d37c8f7..4d67f01 100644 --- a/storage/load.go +++ b/storage/load.go @@ -20,7 +20,7 @@ func getStorageByName(ctx context.Context, name string) (Storage, error) { if ok { return storage, nil } - cfg := config.Cfg.GetStorageByName(name) + cfg := config.C().GetStorageByName(name) if cfg == nil { return nil, fmt.Errorf("未找到存储 %s", name) } @@ -39,7 +39,7 @@ func GetStorageByUserIDAndName(ctx context.Context, chatID int64, name string) ( return nil, ErrStorageNameEmpty } - if !config.Cfg.HasStorage(chatID, name) { + if !config.C().HasStorage(chatID, name) { return nil, fmt.Errorf("没有找到用户 %d 的存储 %s", chatID, name) } @@ -54,7 +54,7 @@ func GetUserStorages(ctx context.Context, chatID int64) []Storage { return storages } var storages []Storage - for _, name := range config.Cfg.GetStorageNamesByUserID(chatID) { + for _, name := range config.C().GetStorageNamesByUserID(chatID) { storage, err := getStorageByName(ctx, name) if err != nil { continue @@ -67,14 +67,14 @@ func GetUserStorages(ctx context.Context, chatID int64) []Storage { func LoadStorages(ctx context.Context) { logger := log.FromContext(ctx) logger.Info("加载存储...") - for _, storage := range config.Cfg.Storages { + for _, storage := range config.C().Storages { _, err := getStorageByName(ctx, storage.GetName()) if err != nil { logger.Errorf("加载存储 %s 失败: %v", storage.GetName(), err) } } logger.Infof("成功加载 %d 个存储", len(Storages)) - for user := range config.Cfg.GetUsersID() { + for user := range config.C().GetUsersID() { UserStorages[int64(user)] = GetUserStorages(ctx, int64(user)) } } diff --git a/storage/telegram/telegram.go b/storage/telegram/telegram.go index 55a225c..82943bd 100644 --- a/storage/telegram/telegram.go +++ b/storage/telegram/telegram.go @@ -100,7 +100,7 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er } upler := uploader.NewUploader(tctx.Raw). WithPartSize(tglimit.MaxUploadPartSize). - WithThreads(config.Cfg.Threads) + WithThreads(config.C().Threads) var file tg.InputFileClass size := func() int64 {