From c82c2462bf6b82e49adf26c7490da736fe698b91 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Fri, 20 Jun 2025 21:30:50 +0800 Subject: [PATCH] feat: exec command hook , close #79 --- config/cache.go | 7 +++ config/db.go | 6 +++ config/hook.go | 22 +++++++++ config/temp.go | 5 ++ config/tg.go | 20 ++++++++ config/viper.go | 101 ++++++++++++++------------------------- core/batchtftask/task.go | 5 ++ core/core.go | 28 +++++++++-- core/hookutil.go | 23 +++++++++ core/tftask/execute.go | 2 +- core/tftask/stream.go | 2 +- core/tftask/taskinfo.go | 10 ++-- core/tftask/tftask.go | 13 +++-- core/tphtask/task.go | 5 ++ 14 files changed, 169 insertions(+), 80 deletions(-) create mode 100644 config/cache.go create mode 100644 config/db.go create mode 100644 config/hook.go create mode 100644 config/temp.go create mode 100644 config/tg.go create mode 100644 core/hookutil.go diff --git a/config/cache.go b/config/cache.go new file mode 100644 index 0000000..9e5dab0 --- /dev/null +++ b/config/cache.go @@ -0,0 +1,7 @@ +package config + +type cacheConfig struct { + TTL int64 `toml:"ttl" mapstructure:"ttl" json:"ttl"` + NumCounters int64 `toml:"num_counters" mapstructure:"num_counters" json:"num_counters"` + MaxCost int64 `toml:"max_cost" mapstructure:"max_cost" json:"max_cost"` +} diff --git a/config/db.go b/config/db.go new file mode 100644 index 0000000..ba7b2a6 --- /dev/null +++ b/config/db.go @@ -0,0 +1,6 @@ +package config + +type dbConfig struct { + Path string `toml:"path" mapstructure:"path"` + Session string `toml:"session" mapstructure:"session"` +} diff --git a/config/hook.go b/config/hook.go new file mode 100644 index 0000000..ab01131 --- /dev/null +++ b/config/hook.go @@ -0,0 +1,22 @@ +package config + +type hookConfig struct { + Exec hookExecConfig `toml:"exec" mapstructure:"exec" json:"exec"` +} + +type hookExecConfig struct { + // command to execute, for all task types + TaskBeforeStart string `toml:"task_before_start" mapstructure:"task_before_start" json:"task_before_start"` + TaskSuccess string `toml:"task_success" mapstructure:"task_success" json:"task_success"` + TaskFail string `toml:"task_fail" mapstructure:"task_fail" json:"task_fail"` + TaskCancel string `toml:"task_cancel" mapstructure:"task_cancel" json:"task_cancel"` + + // TaskTypes map[string]hookExecOnTypeConfig `toml:"task_types" mapstructure:"task_types" json:"task_types"` // [TODO] +} + +// type hookExecOnTypeConfig struct { +// TaskBeforeStart string `toml:"task_before_start" mapstructure:"task_before_start" json:"task_before_start"` +// TaskSuccess string `toml:"task_success" mapstructure:"task_success" json:"task_success"` +// TaskFail string `toml:"task_fail" mapstructure:"task_fail" json:"task_fail"` +// TaskCancel string `toml:"task_cancel" mapstructure:"task_cancel" json:"task_cancel"` +// } diff --git a/config/temp.go b/config/temp.go new file mode 100644 index 0000000..059867b --- /dev/null +++ b/config/temp.go @@ -0,0 +1,5 @@ +package config + +type tempConfig struct { + BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"` +} diff --git a/config/tg.go b/config/tg.go new file mode 100644 index 0000000..ba62174 --- /dev/null +++ b/config/tg.go @@ -0,0 +1,20 @@ +package config + +type telegramConfig struct { + Token string `toml:"token" mapstructure:"token"` + AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"` + AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"` + Proxy tgProxyConfig `toml:"proxy" mapstructure:"proxy"` + RpcRetry int `toml:"rpc_retry" mapstructure:"rpc_retry" json:"rpc_retry"` + Userbot userbotConfig `toml:"userbot" mapstructure:"userbot" json:"userbot"` // [TODO] +} + +type userbotConfig struct { + Enable bool `toml:"enable" mapstructure:"enable"` + Session string `toml:"session" mapstructure:"session"` +} + +type tgProxyConfig struct { + Enable bool `toml:"enable" mapstructure:"enable"` + URL string `toml:"url" mapstructure:"url"` +} diff --git a/config/viper.go b/config/viper.go index bf74f2a..2049905 100644 --- a/config/viper.go +++ b/config/viper.go @@ -15,59 +15,23 @@ import ( ) type Config struct { - Lang string `toml:"lang" mapstructure:"lang" json:"lang"` - Workers int `toml:"workers" mapstructure:"workers"` - Retry int `toml:"retry" mapstructure:"retry"` - NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"` - Threads int `toml:"threads" mapstructure:"threads" json:"threads"` - Stream bool `toml:"stream" mapstructure:"stream" json:"stream"` - Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"` - - Users []userConfig `toml:"users" mapstructure:"users" json:"users"` + Lang string `toml:"lang" mapstructure:"lang" json:"lang"` + Workers int `toml:"workers" mapstructure:"workers"` + Retry int `toml:"retry" mapstructure:"retry"` + NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"` + Threads int `toml:"threads" mapstructure:"threads" json:"threads"` + Stream bool `toml:"stream" mapstructure:"stream" json:"stream"` + Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"` + Users []userConfig `toml:"users" mapstructure:"users" json:"users"` Temp tempConfig `toml:"temp" mapstructure:"temp"` DB dbConfig `toml:"db" mapstructure:"db"` Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"` Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"` + Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"` } -type cacheConfig struct { - TTL int64 `toml:"ttl" mapstructure:"ttl" json:"ttl"` - NumCounters int64 `toml:"num_counters" mapstructure:"num_counters" json:"num_counters"` - MaxCost int64 `toml:"max_cost" mapstructure:"max_cost" json:"max_cost"` -} - -type tempConfig struct { - BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"` - CacheTTL int64 `toml:"cache_ttl" mapstructure:"cache_ttl" json:"cache_ttl"` -} - -type dbConfig struct { - Path string `toml:"path" mapstructure:"path"` - Session string `toml:"session" mapstructure:"session"` -} - -type telegramConfig struct { - Token string `toml:"token" mapstructure:"token"` - AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"` - AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"` - Timeout int `toml:"timeout" mapstructure:"timeout" json:"timeout"` - Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"` - RpcRetry int `toml:"rpc_retry" mapstructure:"rpc_retry" json:"rpc_retry"` - Userbot userbotConfig `toml:"userbot" mapstructure:"userbot" json:"userbot"` -} - -type userbotConfig struct { - Enable bool `toml:"enable" mapstructure:"enable"` - Session string `toml:"session" mapstructure:"session"` -} - -type proxyConfig struct { - Enable bool `toml:"enable" mapstructure:"enable"` - URL string `toml:"url" mapstructure:"url"` -} - -var Cfg *Config +var Cfg *Config = &Config{} func (c Config) GetStorageByName(name string) storage.StorageConfig { for _, storage := range c.Storages { @@ -88,28 +52,36 @@ func Init(ctx context.Context) error { replacer := strings.NewReplacer(".", "_") viper.SetEnvKeyReplacer(replacer) - viper.SetDefault("lang", "zh-Hans") + defaultConfigs := map[string]any{ + // 基础配置 + "lang": "zh-Hans", + "workers": 3, + "retry": 3, + "threads": 4, - viper.SetDefault("workers", 3) - viper.SetDefault("retry", 3) - viper.SetDefault("threads", 4) + // 缓存配置 + "cache.ttl": 86400, + "cache.num_counters": 1e5, + "cache.max_cost": 1e6, - viper.SetDefault("cache.ttl", 86400) - viper.SetDefault("cache.num_counters", 1e5) - viper.SetDefault("cache.max_cost", 1e6) + // Telegram + "telegram.app_id": 1025907, + "telegram.app_hash": "452b0359b988148995f22ff0f4229750", + "telegram.rpc_retry": 5, + "telegram.userbot.enable": false, + "telegram.userbot.session": "data/usersession.db", - viper.SetDefault("telegram.app_id", 1025907) - viper.SetDefault("telegram.app_hash", "452b0359b988148995f22ff0f4229750") - viper.SetDefault("telegram.timeout", 60) - viper.SetDefault("telegram.flood_retry", 5) - viper.SetDefault("telegram.rpc_retry", 5) - viper.SetDefault("telegram.userbot.enable", false) - viper.SetDefault("telegram.userbot.session", "data/usersession.db") + // 临时目录 + "temp.base_path": "cache/", - viper.SetDefault("temp.base_path", "cache/") + // 数据库 + "db.path": "data/saveany.db", + "db.session": "data/session.db", + } - viper.SetDefault("db.path", "data/saveany.db") - viper.SetDefault("db.session", "data/session.db") + for key, value := range defaultConfigs { + viper.SetDefault(key, value) + } if err := viper.SafeWriteConfigAs("config.toml"); err != nil { if _, ok := err.(viper.ConfigFileAlreadyExistsError); !ok { @@ -122,8 +94,6 @@ func Init(ctx context.Context) error { os.Exit(1) } - Cfg = &Config{} - if err := viper.Unmarshal(Cfg); err != nil { fmt.Println("Error unmarshalling config file, ", err) os.Exit(1) @@ -170,7 +140,6 @@ func Init(ctx context.Context) error { userStorages[user.ID] = user.Storages } } - return nil } diff --git a/core/batchtftask/task.go b/core/batchtftask/task.go index 6edfce3..10320ad 100644 --- a/core/batchtftask/task.go +++ b/core/batchtftask/task.go @@ -8,6 +8,7 @@ import ( "github.com/krau/SaveAny-Bot/common/tdler" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/tfile" "github.com/krau/SaveAny-Bot/storage" "github.com/rs/xid" @@ -35,6 +36,10 @@ type Task struct { failed map[string]error // errors for each element } +func (t *Task) Type() tasktype.TaskType { + return tasktype.TaskTypeTgfiles +} + func NewTaskElement( stor storage.Storage, path string, diff --git a/core/core.go b/core/core.go index 588b13e..1a3f861 100644 --- a/core/core.go +++ b/core/core.go @@ -2,32 +2,54 @@ package core import ( "context" + "errors" "github.com/charmbracelet/log" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/queue" ) var queueInstance *queue.TaskQueue[Exectable] type Exectable interface { + Type() tasktype.TaskType TaskID() string Execute(ctx context.Context) error } func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) { + logger := log.FromContext(ctx) + execHooks := config.Cfg.Hook.Exec for { semaphore <- struct{}{} qtask, err := qe.Get() if err != nil { + logger.Error("Failed to get task from queue:", err) break // queue closed and empty } - log.FromContext(ctx).Infof("Processing task: %s", qtask.ID) task := qtask.Data + logger.Infof("Processing task: %s", task.TaskID()) + if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil { + logger.Errorf("Failed to execute before start hook for task %s: %v", task.TaskID(), err) + } if err := task.Execute(qtask.Context()); err != nil { - log.FromContext(ctx).Errorf("Failed to execute task %s: %v", qtask.ID, err) + if errors.Is(err, context.Canceled) { + logger.Infof("Task %s was canceled", task.TaskID()) + if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil { + logger.Errorf("Failed to execute cancel hook for task %s: %v", task.TaskID(), err) + } + } else { + logger.Errorf("Failed to execute task %s: %v", task.TaskID(), err) + if err := ExecCommandString(ctx, execHooks.TaskFail); err != nil { + logger.Errorf("Failed to execute fail hook for task %s: %v", task.TaskID(), err) + } + } } else { - log.FromContext(ctx).Infof("Task %s completed successfully", qtask.ID) + logger.Infof("Task %s completed successfully", task.TaskID()) + if err := ExecCommandString(ctx, execHooks.TaskSuccess); err != nil { + logger.Errorf("Failed to execute success hook for task %s: %v", task.TaskID(), err) + } } qe.Done(qtask.ID) <-semaphore diff --git a/core/hookutil.go b/core/hookutil.go new file mode 100644 index 0000000..69f26aa --- /dev/null +++ b/core/hookutil.go @@ -0,0 +1,23 @@ +package core + +import ( + "context" + "os" + "os/exec" + "runtime" +) + +func ExecCommandString(ctx context.Context, cmd string) error { + if cmd == "" { + return nil + } + var execCmd *exec.Cmd + if runtime.GOOS == "windows" { + execCmd = exec.CommandContext(ctx, "cmd.exe", "/C", cmd) + } else { + execCmd = exec.CommandContext(ctx, "sh", "-c", cmd) + } + execCmd.Stdout = os.Stdout + execCmd.Stderr = os.Stderr + return execCmd.Run() +} diff --git a/core/tftask/execute.go b/core/tftask/execute.go index 4002e61..a51cb52 100644 --- a/core/tftask/execute.go +++ b/core/tftask/execute.go @@ -14,7 +14,7 @@ import ( "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" ) -func (t *TGFileTask) Execute(ctx context.Context) error { +func (t *Task) Execute(ctx context.Context) error { logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name())) t.Progress.OnStart(ctx, t) if t.stream { diff --git a/core/tftask/stream.go b/core/tftask/stream.go index 854db27..1d56b25 100644 --- a/core/tftask/stream.go +++ b/core/tftask/stream.go @@ -10,7 +10,7 @@ import ( "golang.org/x/sync/errgroup" ) -func executeStream(ctx context.Context, task *TGFileTask) error { +func executeStream(ctx context.Context, task *Task) error { logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name())) pr, pw := io.Pipe() diff --git a/core/tftask/taskinfo.go b/core/tftask/taskinfo.go index abcae29..891580f 100644 --- a/core/tftask/taskinfo.go +++ b/core/tftask/taskinfo.go @@ -8,22 +8,22 @@ type TaskInfo interface { StorageName() string } -func (t *TGFileTask) TaskID() string { +func (t *Task) TaskID() string { return t.ID } -func (t *TGFileTask) FileName() string { +func (t *Task) FileName() string { return t.File.Name() } -func (t *TGFileTask) FileSize() int64 { +func (t *Task) FileSize() int64 { return t.File.Size() } -func (t *TGFileTask) StoragePath() string { +func (t *Task) StoragePath() string { return t.Path } -func (t *TGFileTask) StorageName() string { +func (t *Task) StorageName() string { return t.Storage.Name() } diff --git a/core/tftask/tftask.go b/core/tftask/tftask.go index f83aad6..943ab1f 100644 --- a/core/tftask/tftask.go +++ b/core/tftask/tftask.go @@ -7,11 +7,12 @@ import ( "github.com/krau/SaveAny-Bot/common/tdler" "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/tfile" "github.com/krau/SaveAny-Bot/storage" ) -type TGFileTask struct { +type Task struct { ID string Ctx context.Context File tfile.TGFile @@ -23,6 +24,10 @@ type TGFileTask struct { localPath string } +func (t *Task) Type() tasktype.TaskType { + return tasktype.TaskTypeTgfiles +} + func NewTGFileTask( id string, ctx context.Context, @@ -31,14 +36,14 @@ func NewTGFileTask( stor storage.Storage, path string, progress ProgressTracker, -) (*TGFileTask, error) { +) (*Task, error) { _, ok := stor.(storage.StorageCannotStream) if !config.Cfg.Stream || ok { cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) if err != nil { return nil, fmt.Errorf("failed to get absolute path for cache: %w", err) } - tftask := &TGFileTask{ + tftask := &Task{ ID: id, Ctx: ctx, client: client, @@ -50,7 +55,7 @@ func NewTGFileTask( } return tftask, nil } - tfileTask := &TGFileTask{ + tfileTask := &Task{ ID: id, Ctx: ctx, client: client, diff --git a/core/tphtask/task.go b/core/tphtask/task.go index cff20d9..d232440 100644 --- a/core/tphtask/task.go +++ b/core/tphtask/task.go @@ -4,6 +4,7 @@ import ( "context" "sync/atomic" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/telegraph" "github.com/krau/SaveAny-Bot/storage" ) @@ -23,6 +24,10 @@ type Task struct { downloaded atomic.Int64 } +func (t *Task) Type() tasktype.TaskType { + return tasktype.TaskTypeTphpics +} + func NewTask( id string, ctx context.Context,