feat: exec command hook , close #79
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -35,6 +36,10 @@ type Task struct {
|
||||
failed map[string]error // errors for each element
|
||||
}
|
||||
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeTgfiles
|
||||
}
|
||||
|
||||
func NewTaskElement(
|
||||
stor storage.Storage,
|
||||
path string,
|
||||
|
||||
28
core/core.go
28
core/core.go
@@ -2,32 +2,54 @@ package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||
)
|
||||
|
||||
var queueInstance *queue.TaskQueue[Exectable]
|
||||
|
||||
type Exectable interface {
|
||||
Type() tasktype.TaskType
|
||||
TaskID() string
|
||||
Execute(ctx context.Context) error
|
||||
}
|
||||
|
||||
func worker(ctx context.Context, qe *queue.TaskQueue[Exectable], semaphore chan struct{}) {
|
||||
logger := log.FromContext(ctx)
|
||||
execHooks := config.Cfg.Hook.Exec
|
||||
for {
|
||||
semaphore <- struct{}{}
|
||||
qtask, err := qe.Get()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get task from queue:", err)
|
||||
break // queue closed and empty
|
||||
}
|
||||
log.FromContext(ctx).Infof("Processing task: %s", qtask.ID)
|
||||
task := qtask.Data
|
||||
logger.Infof("Processing task: %s", task.TaskID())
|
||||
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil {
|
||||
logger.Errorf("Failed to execute before start hook for task %s: %v", task.TaskID(), err)
|
||||
}
|
||||
if err := task.Execute(qtask.Context()); err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to execute task %s: %v", qtask.ID, err)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("Task %s was canceled", task.TaskID())
|
||||
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
|
||||
logger.Errorf("Failed to execute cancel hook for task %s: %v", task.TaskID(), err)
|
||||
}
|
||||
} else {
|
||||
logger.Errorf("Failed to execute task %s: %v", task.TaskID(), err)
|
||||
if err := ExecCommandString(ctx, execHooks.TaskFail); err != nil {
|
||||
logger.Errorf("Failed to execute fail hook for task %s: %v", task.TaskID(), err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.FromContext(ctx).Infof("Task %s completed successfully", qtask.ID)
|
||||
logger.Infof("Task %s completed successfully", task.TaskID())
|
||||
if err := ExecCommandString(ctx, execHooks.TaskSuccess); err != nil {
|
||||
logger.Errorf("Failed to execute success hook for task %s: %v", task.TaskID(), err)
|
||||
}
|
||||
}
|
||||
qe.Done(qtask.ID)
|
||||
<-semaphore
|
||||
|
||||
23
core/hookutil.go
Normal file
23
core/hookutil.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func ExecCommandString(ctx context.Context, cmd string) error {
|
||||
if cmd == "" {
|
||||
return nil
|
||||
}
|
||||
var execCmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
execCmd = exec.CommandContext(ctx, "cmd.exe", "/C", cmd)
|
||||
} else {
|
||||
execCmd = exec.CommandContext(ctx, "sh", "-c", cmd)
|
||||
}
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
return execCmd.Run()
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
)
|
||||
|
||||
func (t *TGFileTask) Execute(ctx context.Context) error {
|
||||
func (t *Task) Execute(ctx context.Context) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", t.File.Name()))
|
||||
t.Progress.OnStart(ctx, t)
|
||||
if t.stream {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func executeStream(ctx context.Context, task *TGFileTask) error {
|
||||
func executeStream(ctx context.Context, task *Task) error {
|
||||
logger := log.FromContext(ctx).WithPrefix(fmt.Sprintf("file[%s]", task.File.Name()))
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
@@ -8,22 +8,22 @@ type TaskInfo interface {
|
||||
StorageName() string
|
||||
}
|
||||
|
||||
func (t *TGFileTask) TaskID() string {
|
||||
func (t *Task) TaskID() string {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *TGFileTask) FileName() string {
|
||||
func (t *Task) FileName() string {
|
||||
return t.File.Name()
|
||||
}
|
||||
|
||||
func (t *TGFileTask) FileSize() int64 {
|
||||
func (t *Task) FileSize() int64 {
|
||||
return t.File.Size()
|
||||
}
|
||||
|
||||
func (t *TGFileTask) StoragePath() string {
|
||||
func (t *Task) StoragePath() string {
|
||||
return t.Path
|
||||
}
|
||||
|
||||
func (t *TGFileTask) StorageName() string {
|
||||
func (t *Task) StorageName() string {
|
||||
return t.Storage.Name()
|
||||
}
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
|
||||
"github.com/krau/SaveAny-Bot/common/tdler"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/tfile"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
|
||||
type TGFileTask struct {
|
||||
type Task struct {
|
||||
ID string
|
||||
Ctx context.Context
|
||||
File tfile.TGFile
|
||||
@@ -23,6 +24,10 @@ type TGFileTask struct {
|
||||
localPath string
|
||||
}
|
||||
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeTgfiles
|
||||
}
|
||||
|
||||
func NewTGFileTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
@@ -31,14 +36,14 @@ func NewTGFileTask(
|
||||
stor storage.Storage,
|
||||
path string,
|
||||
progress ProgressTracker,
|
||||
) (*TGFileTask, error) {
|
||||
) (*Task, error) {
|
||||
_, ok := stor.(storage.StorageCannotStream)
|
||||
if !config.Cfg.Stream || ok {
|
||||
cachePath, err := filepath.Abs(filepath.Join(config.Cfg.Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path for cache: %w", err)
|
||||
}
|
||||
tftask := &TGFileTask{
|
||||
tftask := &Task{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
@@ -50,7 +55,7 @@ func NewTGFileTask(
|
||||
}
|
||||
return tftask, nil
|
||||
}
|
||||
tfileTask := &TGFileTask{
|
||||
tfileTask := &Task{
|
||||
ID: id,
|
||||
Ctx: ctx,
|
||||
client: client,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
)
|
||||
@@ -23,6 +24,10 @@ type Task struct {
|
||||
downloaded atomic.Int64
|
||||
}
|
||||
|
||||
func (t *Task) Type() tasktype.TaskType {
|
||||
return tasktype.TaskTypeTphpics
|
||||
}
|
||||
|
||||
func NewTask(
|
||||
id string,
|
||||
ctx context.Context,
|
||||
|
||||
Reference in New Issue
Block a user