feat: exec command hook , close #79

This commit is contained in:
krau
2025-06-20 21:30:50 +08:00
parent 88128ecac2
commit c82c2462bf
14 changed files with 169 additions and 80 deletions

7
config/cache.go Normal file
View File

@@ -0,0 +1,7 @@
package config
type cacheConfig struct {
TTL int64 `toml:"ttl" mapstructure:"ttl" json:"ttl"`
NumCounters int64 `toml:"num_counters" mapstructure:"num_counters" json:"num_counters"`
MaxCost int64 `toml:"max_cost" mapstructure:"max_cost" json:"max_cost"`
}

6
config/db.go Normal file
View File

@@ -0,0 +1,6 @@
package config
type dbConfig struct {
Path string `toml:"path" mapstructure:"path"`
Session string `toml:"session" mapstructure:"session"`
}

22
config/hook.go Normal file
View File

@@ -0,0 +1,22 @@
package config
type hookConfig struct {
Exec hookExecConfig `toml:"exec" mapstructure:"exec" json:"exec"`
}
type hookExecConfig struct {
// command to execute, for all task types
TaskBeforeStart string `toml:"task_before_start" mapstructure:"task_before_start" json:"task_before_start"`
TaskSuccess string `toml:"task_success" mapstructure:"task_success" json:"task_success"`
TaskFail string `toml:"task_fail" mapstructure:"task_fail" json:"task_fail"`
TaskCancel string `toml:"task_cancel" mapstructure:"task_cancel" json:"task_cancel"`
// TaskTypes map[string]hookExecOnTypeConfig `toml:"task_types" mapstructure:"task_types" json:"task_types"` // [TODO]
}
// type hookExecOnTypeConfig struct {
// TaskBeforeStart string `toml:"task_before_start" mapstructure:"task_before_start" json:"task_before_start"`
// TaskSuccess string `toml:"task_success" mapstructure:"task_success" json:"task_success"`
// TaskFail string `toml:"task_fail" mapstructure:"task_fail" json:"task_fail"`
// TaskCancel string `toml:"task_cancel" mapstructure:"task_cancel" json:"task_cancel"`
// }

5
config/temp.go Normal file
View File

@@ -0,0 +1,5 @@
package config
type tempConfig struct {
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
}

20
config/tg.go Normal file
View File

@@ -0,0 +1,20 @@
package config
type telegramConfig struct {
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
Proxy tgProxyConfig `toml:"proxy" mapstructure:"proxy"`
RpcRetry int `toml:"rpc_retry" mapstructure:"rpc_retry" json:"rpc_retry"`
Userbot userbotConfig `toml:"userbot" mapstructure:"userbot" json:"userbot"` // [TODO]
}
type userbotConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
Session string `toml:"session" mapstructure:"session"`
}
type tgProxyConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
URL string `toml:"url" mapstructure:"url"`
}

View File

@@ -15,59 +15,23 @@ import (
)
type Config struct {
Lang string `toml:"lang" mapstructure:"lang" json:"lang"`
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"`
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
Lang string `toml:"lang" mapstructure:"lang" json:"lang"`
Workers int `toml:"workers" mapstructure:"workers"`
Retry int `toml:"retry" mapstructure:"retry"`
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"`
Users []userConfig `toml:"users" mapstructure:"users" json:"users"`
Temp tempConfig `toml:"temp" mapstructure:"temp"`
DB dbConfig `toml:"db" mapstructure:"db"`
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram"`
Storages []storage.StorageConfig `toml:"-" mapstructure:"-" json:"storages"`
Hook hookConfig `toml:"hook" mapstructure:"hook" json:"hook"`
}
type cacheConfig struct {
TTL int64 `toml:"ttl" mapstructure:"ttl" json:"ttl"`
NumCounters int64 `toml:"num_counters" mapstructure:"num_counters" json:"num_counters"`
MaxCost int64 `toml:"max_cost" mapstructure:"max_cost" json:"max_cost"`
}
type tempConfig struct {
BasePath string `toml:"base_path" mapstructure:"base_path" json:"base_path"`
CacheTTL int64 `toml:"cache_ttl" mapstructure:"cache_ttl" json:"cache_ttl"`
}
type dbConfig struct {
Path string `toml:"path" mapstructure:"path"`
Session string `toml:"session" mapstructure:"session"`
}
type telegramConfig struct {
Token string `toml:"token" mapstructure:"token"`
AppID int `toml:"app_id" mapstructure:"app_id" json:"app_id"`
AppHash string `toml:"app_hash" mapstructure:"app_hash" json:"app_hash"`
Timeout int `toml:"timeout" mapstructure:"timeout" json:"timeout"`
Proxy proxyConfig `toml:"proxy" mapstructure:"proxy"`
RpcRetry int `toml:"rpc_retry" mapstructure:"rpc_retry" json:"rpc_retry"`
Userbot userbotConfig `toml:"userbot" mapstructure:"userbot" json:"userbot"`
}
type userbotConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
Session string `toml:"session" mapstructure:"session"`
}
type proxyConfig struct {
Enable bool `toml:"enable" mapstructure:"enable"`
URL string `toml:"url" mapstructure:"url"`
}
var Cfg *Config
var Cfg *Config = &Config{}
func (c Config) GetStorageByName(name string) storage.StorageConfig {
for _, storage := range c.Storages {
@@ -88,28 +52,36 @@ func Init(ctx context.Context) error {
replacer := strings.NewReplacer(".", "_")
viper.SetEnvKeyReplacer(replacer)
viper.SetDefault("lang", "zh-Hans")
defaultConfigs := map[string]any{
// 基础配置
"lang": "zh-Hans",
"workers": 3,
"retry": 3,
"threads": 4,
viper.SetDefault("workers", 3)
viper.SetDefault("retry", 3)
viper.SetDefault("threads", 4)
// 缓存配置
"cache.ttl": 86400,
"cache.num_counters": 1e5,
"cache.max_cost": 1e6,
viper.SetDefault("cache.ttl", 86400)
viper.SetDefault("cache.num_counters", 1e5)
viper.SetDefault("cache.max_cost", 1e6)
// Telegram
"telegram.app_id": 1025907,
"telegram.app_hash": "452b0359b988148995f22ff0f4229750",
"telegram.rpc_retry": 5,
"telegram.userbot.enable": false,
"telegram.userbot.session": "data/usersession.db",
viper.SetDefault("telegram.app_id", 1025907)
viper.SetDefault("telegram.app_hash", "452b0359b988148995f22ff0f4229750")
viper.SetDefault("telegram.timeout", 60)
viper.SetDefault("telegram.flood_retry", 5)
viper.SetDefault("telegram.rpc_retry", 5)
viper.SetDefault("telegram.userbot.enable", false)
viper.SetDefault("telegram.userbot.session", "data/usersession.db")
// 临时目录
"temp.base_path": "cache/",
viper.SetDefault("temp.base_path", "cache/")
// 数据库
"db.path": "data/saveany.db",
"db.session": "data/session.db",
}
viper.SetDefault("db.path", "data/saveany.db")
viper.SetDefault("db.session", "data/session.db")
for key, value := range defaultConfigs {
viper.SetDefault(key, value)
}
if err := viper.SafeWriteConfigAs("config.toml"); err != nil {
if _, ok := err.(viper.ConfigFileAlreadyExistsError); !ok {
@@ -122,8 +94,6 @@ func Init(ctx context.Context) error {
os.Exit(1)
}
Cfg = &Config{}
if err := viper.Unmarshal(Cfg); err != nil {
fmt.Println("Error unmarshalling config file, ", err)
os.Exit(1)
@@ -170,7 +140,6 @@ func Init(ctx context.Context) error {
userStorages[user.ID] = user.Storages
}
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
@@ -35,6 +36,10 @@ type Task struct {
failed map[string]error // errors for each element
}
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeTgfiles
}
func NewTaskElement(
stor storage.Storage,
path string,

View File

@@ -2,32 +2,54 @@ package core
import (
"context"
"errors"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/queue"
)
var queueInstance *queue.TaskQueue[Exectable]
type Exectable interface {
Type() tasktype.TaskType
TaskID() string
Execute(ctx context.Context) error
}
func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) {
logger := log.FromContext(ctx)
execHooks := config.Cfg.Hook.Exec
for {
semaphore <- struct{}{}
qtask, err := qe.Get()
if err != nil {
logger.Error("Failed to get task from queue:", err)
break // queue closed and empty
}
log.FromContext(ctx).Infof("Processing task: %s", qtask.ID)
task := qtask.Data
logger.Infof("Processing task: %s", task.TaskID())
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil {
logger.Errorf("Failed to execute before start hook for task %s: %v", task.TaskID(), err)
}
if err := task.Execute(qtask.Context()); err != nil {
log.FromContext(ctx).Errorf("Failed to execute task %s: %v", qtask.ID, err)
if errors.Is(err, context.Canceled) {
logger.Infof("Task %s was canceled", task.TaskID())
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
logger.Errorf("Failed to execute cancel hook for task %s: %v", task.TaskID(), err)
}
} else {
logger.Errorf("Failed to execute task %s: %v", task.TaskID(), err)
if err := ExecCommandString(ctx, execHooks.TaskFail); err != nil {
logger.Errorf("Failed to execute fail hook for task %s: %v", task.TaskID(), err)
}
}
} else {
log.FromContext(ctx).Infof("Task %s completed successfully", qtask.ID)
logger.Infof("Task %s completed successfully", task.TaskID())
if err := ExecCommandString(ctx, execHooks.TaskSuccess); err != nil {
logger.Errorf("Failed to execute success hook for task %s: %v", task.TaskID(), err)
}
}
qe.Done(qtask.ID)
<-semaphore

23
core/hookutil.go Normal file
View File

@@ -0,0 +1,23 @@
package core
import (
"context"
"os"
"os/exec"
"runtime"
)
func ExecCommandString(ctx context.Context, cmd string) error {
if cmd == "" {
return nil
}
var execCmd *exec.Cmd
if runtime.GOOS == "windows" {
execCmd = exec.CommandContext(ctx, "cmd.exe", "/C", cmd)
} else {
execCmd = exec.CommandContext(ctx, "sh", "-c", cmd)
}
execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr
return execCmd.Run()
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
)
func (t *TGFileTask) Execute(ctx context.Context) error {
func (t *Task) Execute(ctx context.Context) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name()))
t.Progress.OnStart(ctx, t)
if t.stream {

View File

@@ -10,7 +10,7 @@ import (
"golang.org/x/sync/errgroup"
)
func executeStream(ctx context.Context, task *TGFileTask) error {
func executeStream(ctx context.Context, task *Task) error {
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name()))
pr, pw := io.Pipe()

View File

@@ -8,22 +8,22 @@ type TaskInfo interface {
StorageName() string
}
func (t *TGFileTask) TaskID() string {
func (t *Task) TaskID() string {
return t.ID
}
func (t *TGFileTask) FileName() string {
func (t *Task) FileName() string {
return t.File.Name()
}
func (t *TGFileTask) FileSize() int64 {
func (t *Task) FileSize() int64 {
return t.File.Size()
}
func (t *TGFileTask) StoragePath() string {
func (t *Task) StoragePath() string {
return t.Path
}
func (t *TGFileTask) StorageName() string {
func (t *Task) StorageName() string {
return t.Storage.Name()
}

View File

@@ -7,11 +7,12 @@ import (
"github.com/krau/SaveAny-Bot/common/tdler"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/tfile"
"github.com/krau/SaveAny-Bot/storage"
)
type TGFileTask struct {
type Task struct {
ID string
Ctx context.Context
File tfile.TGFile
@@ -23,6 +24,10 @@ type TGFileTask struct {
localPath string
}
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeTgfiles
}
func NewTGFileTask(
id string,
ctx context.Context,
@@ -31,14 +36,14 @@ func NewTGFileTask(
stor storage.Storage,
path string,
progress ProgressTracker,
) (*TGFileTask, error) {
) (*Task, error) {
_, ok := stor.(storage.StorageCannotStream)
if !config.Cfg.Stream || ok {
cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
if err != nil {
return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
}
tftask := &TGFileTask{
tftask := &Task{
ID: id,
Ctx: ctx,
client: client,
@@ -50,7 +55,7 @@ func NewTGFileTask(
}
return tftask, nil
}
tfileTask := &TGFileTask{
tfileTask := &Task{
ID: id,
Ctx: ctx,
client: client,

View File

@@ -4,6 +4,7 @@ import (
"context"
"sync/atomic"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage"
)
@@ -23,6 +24,10 @@ type Task struct {
downloaded atomic.Int64
}
func (t *Task) Type() tasktype.TaskType {
return tasktype.TaskTypeTphpics
}
func NewTask(
id string,
ctx context.Context,