Compare commits

...

2 Commits

24 changed files with 564 additions and 216 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/krau/SaveAny-Bot/pkg/aria2" "github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/parser" "github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"github.com/krau/SaveAny-Bot/pkg/telegraph" "github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid" "github.com/rs/xid"
@@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error { func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
taskID := task.TaskID() taskID := task.TaskID()
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook) info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
err := core.AddTask(f.ctx, NewExecutableWrapper(task)) // Inject the progress sink into the context so the task's Emit calls update
// the API store (and fire the webhook on terminal states) without the task
// knowing about the API.
taskCtx := taskevent.WithSink(f.ctx, info)
err := core.AddTask(taskCtx, task)
if err != nil { if err != nil {
DeleteTask(taskID) DeleteTask(taskID)
return fmt.Errorf("failed to add task: %w", err) return fmt.Errorf("failed to add task: %w", err)

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/krau/SaveAny-Bot/core" "github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
@@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
// 取消任务 // Cancel the task; the terminal status is set via the task event stream.
if err := core.CancelTask(r.Context(), taskID); err != nil { if err := core.CancelTask(r.Context(), taskID); err != nil {
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error()) WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
return return
@@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string {
return parts[3] return parts[3]
} }
// convertTaskProgressToResponse 将任务进度转换为响应格式 // convertTaskProgressToResponse renders a task's current state, computing
// percent and speed from the snapshot taken under the task's mutex.
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse { func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot()
resp := TaskInfoResponse{ resp := TaskInfoResponse{
TaskID: task.TaskID, TaskID: task.TaskID,
Type: tasktype.TaskType(task.Type), Type: tasktype.TaskType(task.Type),
Status: task.Status, Status: status,
Title: task.Title, Title: task.Title,
Storage: task.Storage, Storage: task.Storage,
Path: task.Path, Path: task.Path,
Error: task.Error, Error: errMsg,
CreatedAt: task.CreatedAt, CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt, UpdatedAt: updatedAt,
} }
// 计算进度 var percent float64
if task.TotalBytes > 0 { var speedMBPS float64
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes) if total > 0 {
percent = float64(downloaded) * 100 / float64(total)
} else if totalFiles > 0 {
percent = float64(downloadedFiles) * 100 / float64(totalFiles)
}
if !startedAt.IsZero() {
elapsed := time.Since(startedAt).Seconds()
if elapsed > 0 {
speedMBPS = float64(downloaded) / elapsed / (1024 * 1024)
}
}
if total > 0 || totalFiles > 0 {
resp.Progress = &TaskProgress{ resp.Progress = &TaskProgress{
TotalBytes: task.TotalBytes, TotalBytes: total,
DownloadedBytes: task.DownloadedBytes, DownloadedBytes: downloaded,
TotalFiles: totalFiles,
DownloadedFiles: downloadedFiles,
Percent: percent, Percent: percent,
SpeedMBPS: speedMBPS,
} }
} }

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
) )
// setupTestServer creates a test server with handlers // setupTestServer creates a test server with handlers
@@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) {
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates // TestProgressTrackerConcurrentUpdates tests concurrent progress updates
func TestProgressTrackerConcurrentUpdates(t *testing.T) { func TestProgressTrackerConcurrentUpdates(t *testing.T) {
tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "") info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
tracker.OnStart(10000, 10) info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000})
var wg sync.WaitGroup var wg sync.WaitGroup
numGoroutines := 50 numGoroutines := 50
updatesPerGoroutine := 100 updatesPerGoroutine := 100
// Concurrent progress updates // Concurrent progress updates via the Sink interface
for i := range numGoroutines { for i := range numGoroutines {
wg.Add(1) wg.Add(1)
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
for j := range updatesPerGoroutine { for j := range updatesPerGoroutine {
tracker.OnProgress(int64(id*updatesPerGoroutine+j), j) info.Emit(taskevent.Event{
TaskID: "concurrent-progress",
Phase: taskevent.PhaseProgress,
DownloadedBytes: int64(id*updatesPerGoroutine + j),
TotalBytes: 10000,
})
} }
}(i) }(i)
} }
wg.Wait() wg.Wait()
info := tracker.GetInfo() status, _, downloaded, _, _, _, _, _ := info.snapshot()
if info.Status != TaskStatusRunning { if status != TaskStatusRunning {
t.Errorf("expected status Running after concurrent updates, got %s", info.Status) t.Errorf("expected status Running after concurrent updates, got %s", status)
}
if downloaded <= 0 {
t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded)
} }
// Note: Due to race conditions in the simple implementation,
// we can't reliably check exact values without proper synchronization
} }
// TestTaskFactoryValidation tests TaskFactory parameter validation // TestTaskFactoryValidation tests TaskFactory parameter validation
@@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) {
{ {
name: "Progress tracker with empty webhook", name: "Progress tracker with empty webhook",
fn: func(t *testing.T) { fn: func(t *testing.T) {
tracker := NewProgressTracker("test", "type", "storage", "path", "title", "") info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "")
info := tracker.GetInfo()
if info.Webhook != "" { if info.Webhook != "" {
t.Error("expected empty webhook") t.Error("expected empty webhook")
} }

View File

@@ -2,12 +2,16 @@ package api
import ( import (
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
) )
// TaskProgressInfo 存储任务的进度信息 // TaskProgressInfo stores the progress of an API-submitted task. All fields are
// guarded by mu. It implements taskevent.Sink so the task layer can update it
// without knowing about the API.
type TaskProgressInfo struct { type TaskProgressInfo struct {
mu sync.Mutex
TaskID string TaskID string
Type string Type string
Status TaskStatus Status TaskStatus
@@ -21,20 +25,25 @@ type TaskProgressInfo struct {
Error string Error string
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
StartedAt time.Time
Webhook string Webhook string
webhookNotified bool
} }
// progressStore 存储所有 API 任务的进度信息 // progressStore holds all API tasks. Entries are removed a fixed duration after
// they reach a terminal state to bound memory usage.
type progressStore struct { type progressStore struct {
mu sync.RWMutex mu sync.RWMutex
tasks map[string]*TaskProgressInfo tasks map[string]*TaskProgressInfo
retention time.Duration
} }
var store = &progressStore{ var store = &progressStore{
tasks: make(map[string]*TaskProgressInfo), tasks: make(map[string]*TaskProgressInfo),
retention: 24 * time.Hour,
} }
// RegisterTask 注册一个新的 API 任务 // RegisterTask registers a new API task and returns its progress info.
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo { func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
info := &TaskProgressInfo{ info := &TaskProgressInfo{
TaskID: taskID, TaskID: taskID,
@@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
return info return info
} }
// GetTask 获取任务进度信息 // GetTask returns the progress info for a task.
func GetTask(taskID string) (*TaskProgressInfo, bool) { func GetTask(taskID string) (*TaskProgressInfo, bool) {
store.mu.RLock() store.mu.RLock()
defer store.mu.RUnlock() defer store.mu.RUnlock()
@@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) {
return info, ok return info, ok
} }
// GetAllTasks 获取所有任务 // GetAllTasks returns all tracked tasks.
func GetAllTasks() []*TaskProgressInfo { func GetAllTasks() []*TaskProgressInfo {
store.mu.RLock() store.mu.RLock()
defer store.mu.RUnlock() defer store.mu.RUnlock()
@@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo {
return tasks return tasks
} }
// DeleteTask 删除任务记录 // DeleteTask removes a task record.
func DeleteTask(taskID string) { func DeleteTask(taskID string) {
store.mu.Lock() store.mu.Lock()
defer store.mu.Unlock() defer store.mu.Unlock()
delete(store.tasks, taskID) delete(store.tasks, taskID)
} }
// UpdateStatus 更新任务状态 // CleanupExpired removes tasks that reached a terminal state more than the
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) { // store's retention duration ago. It is safe to call periodically.
t.Status = status func CleanupExpired() {
t.UpdatedAt = time.Now() now := time.Now()
store.mu.Lock()
defer store.mu.Unlock()
for id, info := range store.tasks {
info.mu.Lock()
terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled
stale := terminal && now.Sub(info.UpdatedAt) > store.retention
info.mu.Unlock()
if stale {
delete(store.tasks, id)
}
}
} }
// SetError 设置错误信息 // StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done.
// It should be started once during API server initialization.
func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) {
go func() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
CleanupExpired()
}
}
}()
}
// UpdateStatus sets the task status.
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
t.mu.Lock()
t.Status = status
t.UpdatedAt = time.Now()
if status == TaskStatusRunning && t.StartedAt.IsZero() {
t.StartedAt = t.UpdatedAt
}
t.mu.Unlock()
}
// SetError marks the task failed with an error message.
func (t *TaskProgressInfo) SetError(err string) { func (t *TaskProgressInfo) SetError(err string) {
t.mu.Lock()
t.Error = err t.Error = err
t.Status = TaskStatusFailed t.Status = TaskStatusFailed
t.UpdatedAt = time.Now() t.UpdatedAt = time.Now()
t.mu.Unlock()
} }
// ProgressTracker 用于 API 任务的进度追踪 // snapshot returns a point-in-time copy of the fields needed to render a
type ProgressTracker struct { // response, so callers never touch the mutex directly.
info *TaskProgressInfo func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt
} }
// NewProgressTracker 创建新的进度追踪器 // Emit implements taskevent.Sink. It translates task lifecycle events into
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker { // status/progress updates and fires the webhook on terminal transitions.
info := RegisterTask(taskID, taskType, storage, path, title, webhook) func (t *TaskProgressInfo) Emit(e taskevent.Event) {
return &ProgressTracker{info: info} t.mu.Lock()
} switch e.Phase {
case taskevent.PhaseStart:
// OnStart 任务开始 t.Status = TaskStatusRunning
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) { if t.StartedAt.IsZero() {
p.info.Status = TaskStatusRunning t.StartedAt = time.Now()
p.info.TotalBytes = totalBytes }
p.info.TotalFiles = totalFiles if e.TotalBytes > 0 {
p.info.UpdatedAt = time.Now() t.TotalBytes = e.TotalBytes
} }
case taskevent.PhaseProgress:
// OnProgress 进度更新 t.Status = TaskStatusRunning
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) { if e.TotalBytes > 0 {
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes) t.TotalBytes = e.TotalBytes
p.info.DownloadedFiles = downloadedFiles }
p.info.UpdatedAt = time.Now() t.DownloadedBytes = e.DownloadedBytes
} if e.TotalFiles > 0 {
t.TotalFiles = e.TotalFiles
// OnDone 任务完成 }
func (p *ProgressTracker) OnDone(err error) { if e.DownloadedFiles > 0 {
if err != nil { t.DownloadedFiles = e.DownloadedFiles
p.info.Status = TaskStatusFailed }
p.info.Error = err.Error() case taskevent.PhaseDone:
} else { if e.Err != nil {
p.info.Status = TaskStatusCompleted t.Status = TaskStatusFailed
t.Error = e.Err.Error()
} else {
t.Status = TaskStatusCompleted
}
}
t.UpdatedAt = time.Now()
notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed)
if notify {
t.webhookNotified = true
}
t.mu.Unlock()
if notify {
payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err)
SendWebhook(nil, payload)
} }
p.info.UpdatedAt = time.Now()
} }
// GetInfo 获取任务信息 // ProgressTracker is retained for compatibility but is no longer the primary
func (p *ProgressTracker) GetInfo() *TaskProgressInfo { // progress path; taskevent drives updates now. These methods are safe no-ops
return p.info // when called on a nil receiver.
type ProgressTracker struct{}
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
return &ProgressTracker{}
} }
// UpdateProgressBytes 更新下载字节数 func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {}
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) { func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {}
atomic.StoreInt64(&p.info.DownloadedBytes, bytes) func (p *ProgressTracker) OnDone(err error) {}
p.info.UpdatedAt = time.Now() func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil }
} func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {}
func (p *ProgressTracker) UpdateProgressFiles(files int) {}
// UpdateProgressFiles 更新下载文件数
func (p *ProgressTracker) UpdateProgressFiles(files int) {
p.info.DownloadedFiles = files
p.info.UpdatedAt = time.Now()
}

View File

@@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server {
// 404 处理 // 404 处理
mux.HandleFunc("/", NotFoundHandler) mux.HandleFunc("/", NotFoundHandler)
// 应用中间件 // Apply middleware chain.
var handler http.Handler = mux var handler http.Handler = mux
// 添加认证中间件 // Apply auth middleware when a token is configured.
token := cfg.Token token := cfg.Token
if token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
}
if token != "" { if token != "" {
handler = AuthMiddleware()(handler) handler = AuthMiddleware()(handler)
} }
// 添加日志中间件 // Add logging middleware.
handler = loggingMiddleware(handler) handler = loggingMiddleware(handler)
// 添加恢复中间件 // Add recovery middleware.
handler = recoveryMiddleware(handler) handler = recoveryMiddleware(handler)
return &Server{ return &Server{
@@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) {
rw.ResponseWriter.WriteHeader(code) rw.ResponseWriter.WriteHeader(code)
} }
// Start 初始化并启动 API 服务器 // Start initializes and starts the API server. It refuses to start without a
// token, since an open download proxy is a security risk.
func Start(ctx context.Context) error { func Start(ctx context.Context) error {
cfg := config.C().API cfg := config.C().API
@@ -160,9 +158,13 @@ func Start(ctx context.Context) error {
} }
if cfg.Token == "" { if cfg.Token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!") return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely")
} }
server := NewServer(ctx) server := NewServer(ctx)
return server.Start(ctx) if err := server.Start(ctx); err != nil {
return err
}
StartCleanupLoop(ctx)
return nil
} }

View File

@@ -40,6 +40,8 @@ type CreateTaskResponse struct {
type TaskProgress struct { type TaskProgress struct {
TotalBytes int64 `json:"total_bytes,omitempty"` TotalBytes int64 `json:"total_bytes,omitempty"`
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"` DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
TotalFiles int `json:"total_files,omitempty"`
DownloadedFiles int `json:"downloaded_files,omitempty"`
Percent float64 `json:"percent,omitempty"` Percent float64 `json:"percent,omitempty"`
SpeedMBPS float64 `json:"speed_mbps,omitempty"` SpeedMBPS float64 `json:"speed_mbps,omitempty"`
} }

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
@@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
webhookURL := info.Webhook webhookURL := info.Webhook
// 异步发送 webhook // Async send with retries.
go func() { go func() {
logger := log.FromContext(ctx).With("task_id", payload.TaskID) var logger *log.Logger
if ctx != nil {
logger = log.FromContext(ctx).With("task_id", payload.TaskID)
} else {
logger = log.Default().With("task_id", payload.TaskID)
}
payloadBytes, err := json.Marshal(payload) payloadBytes, err := json.Marshal(payload)
if err != nil { if err != nil {
@@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
}() }()
} }
// CreateWebhookPayload 创建 Webhook 负载 // CreateWebhookPayload creates a Webhook payload.
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload { func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
payload := &WebhookPayload{ payload := &WebhookPayload{
TaskID: taskID, TaskID: taskID,
@@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto
return payload return payload
} }
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
info, ok := GetTask(taskID)
if !ok {
return fmt.Errorf("task not found: %s", taskID)
}
err := fn()
// 确定任务状态
status := TaskStatusCompleted
if err != nil {
if err == context.Canceled {
status = TaskStatusCancelled
} else {
status = TaskStatusFailed
}
}
// 更新任务状态
if err != nil {
info.SetError(err.Error())
} else {
info.UpdateStatus(TaskStatusCompleted)
}
// 发送 webhook
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}

View File

@@ -1,58 +0,0 @@
package api
import (
"context"
"errors"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
type ExecutableWrapper struct {
inner core.Executable
}
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
return &ExecutableWrapper{inner: inner}
}
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
taskID := w.inner.TaskID()
if info, ok := GetTask(taskID); ok {
info.UpdateStatus(TaskStatusRunning)
}
err := w.inner.Execute(ctx)
info, ok := GetTask(taskID)
if !ok {
return err
}
var status TaskStatus
if err != nil {
if errors.Is(err, context.Canceled) {
status = TaskStatusCancelled
info.UpdateStatus(TaskStatusCancelled)
} else {
status = TaskStatusFailed
info.SetError(err.Error())
}
} else {
status = TaskStatusCompleted
info.UpdateStatus(TaskStatusCompleted)
}
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/krau/SaveAny-Bot/common/i18n" "github.com/krau/SaveAny-Bot/common/i18n"
"github.com/krau/SaveAny-Bot/common/i18n/i18nk" "github.com/krau/SaveAny-Bot/common/i18n/i18nk"
"github.com/krau/SaveAny-Bot/common/utils/strutil" "github.com/krau/SaveAny-Bot/common/utils/strutil"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/database" "github.com/krau/SaveAny-Bot/database"
"github.com/krau/SaveAny-Bot/pkg/rule" "github.com/krau/SaveAny-Bot/pkg/rule"
) )
@@ -84,6 +85,46 @@ func handleRuleCmd(ctx *ext.Context, update *ext.Update) error {
return dispatcher.EndGroups return dispatcher.EndGroups
} }
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoCreateRuleSuccess, nil)), nil) ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoCreateRuleSuccess, nil)), nil)
case "preset":
// /rule preset <storage> [base_path]
if len(args) < 3 {
ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildRuleHelpStyling(user.ApplyRule, user.Rules)), nil)
return dispatcher.EndGroups
}
storageName := args[2]
if !config.C().HasStorage(user.ChatID, storageName) {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleErrorStorageNotFound, map[string]any{
"Storage": storageName,
})), nil)
return dispatcher.EndGroups
}
basePath := ""
if len(args) >= 4 {
basePath = args[3]
}
presets := rule.PresetCategories(basePath)
imported := 0
for _, p := range presets {
rd := &database.Rule{
Type: rule.FileNameRegex.String(),
Data: p.Regex,
StorageName: storageName,
DirPath: p.Dir,
UserID: user.ID,
}
if err := database.CreateRule(ctx, rd); err != nil {
logger.Errorf("failed to create preset rule %s: %s", p.Name, err)
continue
}
imported++
}
if imported == 0 {
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleErrorCreateRuleFailed, nil)), nil)
return dispatcher.EndGroups
}
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoPresetImported, map[string]any{
"Count": imported,
})), nil)
case "del": case "del":
// /rule del <id> // /rule del <id>
if len(args) < 3 { if len(args) < 3 {

View File

@@ -24,6 +24,8 @@ func BuildRuleHelpStyling(enabled bool, rules []database.Rule) []styling.StyledT
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpSwitchSuffix, nil)), styling.Plain(i18n.T(i18nk.BotMsgRuleHelpSwitchSuffix, nil)),
styling.Code("add"), styling.Code("add"),
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpAddSuffix, nil)), styling.Plain(i18n.T(i18nk.BotMsgRuleHelpAddSuffix, nil)),
styling.Code("preset"),
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpPresetSuffix, nil)),
styling.Code("del"), styling.Code("del"),
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpDelSuffix, nil)), styling.Plain(i18n.T(i18nk.BotMsgRuleHelpDelSuffix, nil)),
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpExistingRulesPrefix, nil)), styling.Plain(i18n.T(i18nk.BotMsgRuleHelpExistingRulesPrefix, nil)),

View File

@@ -84,8 +84,8 @@ const (
BotMsgCommonPromptSelectDefaultDir Key = "bot.msg.common.prompt_select_default_dir" BotMsgCommonPromptSelectDefaultDir Key = "bot.msg.common.prompt_select_default_dir"
BotMsgCommonPromptSelectDefaultStorage Key = "bot.msg.common.prompt_select_default_storage" BotMsgCommonPromptSelectDefaultStorage Key = "bot.msg.common.prompt_select_default_storage"
BotMsgCommonPromptSelectDir Key = "bot.msg.common.prompt_select_dir" BotMsgCommonPromptSelectDir Key = "bot.msg.common.prompt_select_dir"
BotMsgConfigButtonFilenameStrategy Key = "bot.msg.config.button_filename_strategy"
BotMsgConfigButtonConflictStrategy Key = "bot.msg.config.button_conflict_strategy" BotMsgConfigButtonConflictStrategy Key = "bot.msg.config.button_conflict_strategy"
BotMsgConfigButtonFilenameStrategy Key = "bot.msg.config.button_filename_strategy"
BotMsgConfigConflictStrategyAsk Key = "bot.msg.config.conflict_strategy_ask" BotMsgConfigConflictStrategyAsk Key = "bot.msg.config.conflict_strategy_ask"
BotMsgConfigConflictStrategyOverwrite Key = "bot.msg.config.conflict_strategy_overwrite" BotMsgConfigConflictStrategyOverwrite Key = "bot.msg.config.conflict_strategy_overwrite"
BotMsgConfigConflictStrategyRename Key = "bot.msg.config.conflict_strategy_rename" BotMsgConfigConflictStrategyRename Key = "bot.msg.config.conflict_strategy_rename"
@@ -93,8 +93,8 @@ const (
BotMsgConfigErrorInvalidCallbackData Key = "bot.msg.config.error_invalid_callback_data" BotMsgConfigErrorInvalidCallbackData Key = "bot.msg.config.error_invalid_callback_data"
BotMsgConfigErrorInvalidTemplate Key = "bot.msg.config.error_invalid_template" BotMsgConfigErrorInvalidTemplate Key = "bot.msg.config.error_invalid_template"
BotMsgConfigFnametmplHelp Key = "bot.msg.config.fnametmpl_help" BotMsgConfigFnametmplHelp Key = "bot.msg.config.fnametmpl_help"
BotMsgConfigInfoCurrentTemplatePrefix Key = "bot.msg.config.info_current_template_prefix"
BotMsgConfigInfoConflictStrategySet Key = "bot.msg.config.info_conflict_strategy_set" BotMsgConfigInfoConflictStrategySet Key = "bot.msg.config.info_conflict_strategy_set"
BotMsgConfigInfoCurrentTemplatePrefix Key = "bot.msg.config.info_current_template_prefix"
BotMsgConfigInfoFilenameStrategySet Key = "bot.msg.config.info_filename_strategy_set" BotMsgConfigInfoFilenameStrategySet Key = "bot.msg.config.info_filename_strategy_set"
BotMsgConfigInfoTemplateUpdated Key = "bot.msg.config.info_template_updated" BotMsgConfigInfoTemplateUpdated Key = "bot.msg.config.info_template_updated"
BotMsgConfigPromptSelectConflictStrategy Key = "bot.msg.config.prompt_select_conflict_strategy" BotMsgConfigPromptSelectConflictStrategy Key = "bot.msg.config.prompt_select_conflict_strategy"
@@ -200,6 +200,7 @@ const (
BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed" BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed"
BotMsgRuleErrorInvalidRuleId Key = "bot.msg.rule.error_invalid_rule_id" BotMsgRuleErrorInvalidRuleId Key = "bot.msg.rule.error_invalid_rule_id"
BotMsgRuleErrorInvalidRuleType Key = "bot.msg.rule.error_invalid_rule_type" BotMsgRuleErrorInvalidRuleType Key = "bot.msg.rule.error_invalid_rule_type"
BotMsgRuleErrorStorageNotFound Key = "bot.msg.rule.error_storage_not_found"
BotMsgRuleErrorUpdateUserFailed Key = "bot.msg.rule.error_update_user_failed" BotMsgRuleErrorUpdateUserFailed Key = "bot.msg.rule.error_update_user_failed"
BotMsgRuleHelpAddSuffix Key = "bot.msg.rule.help_add_suffix" BotMsgRuleHelpAddSuffix Key = "bot.msg.rule.help_add_suffix"
BotMsgRuleHelpAvailableOps Key = "bot.msg.rule.help_available_ops" BotMsgRuleHelpAvailableOps Key = "bot.msg.rule.help_available_ops"
@@ -207,13 +208,16 @@ const (
BotMsgRuleHelpCurrentModeEnabled Key = "bot.msg.rule.help_current_mode_enabled" BotMsgRuleHelpCurrentModeEnabled Key = "bot.msg.rule.help_current_mode_enabled"
BotMsgRuleHelpDelSuffix Key = "bot.msg.rule.help_del_suffix" BotMsgRuleHelpDelSuffix Key = "bot.msg.rule.help_del_suffix"
BotMsgRuleHelpExistingRulesPrefix Key = "bot.msg.rule.help_existing_rules_prefix" BotMsgRuleHelpExistingRulesPrefix Key = "bot.msg.rule.help_existing_rules_prefix"
BotMsgRuleHelpPresetSuffix Key = "bot.msg.rule.help_preset_suffix"
BotMsgRuleHelpSwitchSuffix Key = "bot.msg.rule.help_switch_suffix" BotMsgRuleHelpSwitchSuffix Key = "bot.msg.rule.help_switch_suffix"
BotMsgRuleHelpUsage Key = "bot.msg.rule.help_usage" BotMsgRuleHelpUsage Key = "bot.msg.rule.help_usage"
BotMsgRuleInfoCreateRuleSuccess Key = "bot.msg.rule.info_create_rule_success" BotMsgRuleInfoCreateRuleSuccess Key = "bot.msg.rule.info_create_rule_success"
BotMsgRuleInfoDeleteRuleSuccess Key = "bot.msg.rule.info_delete_rule_success" BotMsgRuleInfoDeleteRuleSuccess Key = "bot.msg.rule.info_delete_rule_success"
BotMsgRuleInfoPresetImported Key = "bot.msg.rule.info_preset_imported"
BotMsgRuleInfoRuleModeDisabled Key = "bot.msg.rule.info_rule_mode_disabled" BotMsgRuleInfoRuleModeDisabled Key = "bot.msg.rule.info_rule_mode_disabled"
BotMsgRuleInfoRuleModeEnabled Key = "bot.msg.rule.info_rule_mode_enabled" BotMsgRuleInfoRuleModeEnabled Key = "bot.msg.rule.info_rule_mode_enabled"
BotMsgRulePromptProvideRuleId Key = "bot.msg.rule.prompt_provide_rule_id" BotMsgRulePromptProvideRuleId Key = "bot.msg.rule.prompt_provide_rule_id"
BotMsgRulePromptProvideStorageName Key = "bot.msg.rule.prompt_provide_storage_name"
BotMsgSaveErrorInvalidIdOrUsername Key = "bot.msg.save.error_invalid_id_or_username" BotMsgSaveErrorInvalidIdOrUsername Key = "bot.msg.save.error_invalid_id_or_username"
BotMsgSaveHelpText Key = "bot.msg.save_help_text" BotMsgSaveHelpText Key = "bot.msg.save_help_text"
BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix" BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix"

View File

@@ -196,7 +196,11 @@ bot:
help_switch_suffix: " - Toggle rule mode\n" help_switch_suffix: " - Toggle rule mode\n"
help_add_suffix: " <type> <data> <storage_name> <path> - Add rule\n" help_add_suffix: " <type> <data> <storage_name> <path> - Add rule\n"
help_del_suffix: " <rule_id> - Delete rule\n" help_del_suffix: " <rule_id> - Delete rule\n"
help_preset_suffix: " <storage_name> [base_path] - Import built-in filetype rules (video/image/audio/document/archive)\n"
help_existing_rules_prefix: "\nCurrent rules:\n" help_existing_rules_prefix: "\nCurrent rules:\n"
prompt_provide_storage_name: "Please provide a storage name"
error_storage_not_found: "Storage not found: {{.Storage}}"
info_preset_imported: "Imported {{.Count}} built-in classification rules into storage {{.Storage}}"
dir: dir:
error_get_user_dirs_failed: "Failed to get user directories" error_get_user_dirs_failed: "Failed to get user directories"
error_get_user_failed: "Failed to get user" error_get_user_failed: "Failed to get user"

View File

@@ -197,7 +197,11 @@ bot:
help_switch_suffix: " - 开关规则模式\n" help_switch_suffix: " - 开关规则模式\n"
help_add_suffix: " <类型> <数据> <存储名> <路径> - 添加规则\n" help_add_suffix: " <类型> <数据> <存储名> <路径> - 添加规则\n"
help_del_suffix: " <规则ID> - 删除规则\n" help_del_suffix: " <规则ID> - 删除规则\n"
help_preset_suffix: " <存储名> [基础路径] - 导入内置文件类型分类规则(视频/图片/音频/文档/压缩包)\n"
help_existing_rules_prefix: "\n当前已添加的规则:\n" help_existing_rules_prefix: "\n当前已添加的规则:\n"
prompt_provide_storage_name: "请提供存储名称"
error_storage_not_found: "未找到存储: {{.Storage}}"
info_preset_imported: "已导入 {{.Count}} 条内置分类规则到存储 {{.Storage}}"
dir: dir:
error_get_user_dirs_failed: "获取用户文件夹失败" error_get_user_dirs_failed: "获取用户文件夹失败"
error_get_user_failed: "获取用户失败" error_get_user_failed: "获取用户失败"

View File

@@ -8,6 +8,7 @@ import (
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/queue" "github.com/krau/SaveAny-Bot/pkg/queue"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
) )
var queueInstance *queue.TaskQueue[Executable] var queueInstance *queue.TaskQueue[Executable]
@@ -30,11 +31,14 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
break // queue closed and empty break // queue closed and empty
} }
exe := qtask.Data exe := qtask.Data
taskCtx := qtask.Context()
logger.Infof("Processing task: %s", exe.TaskID()) logger.Infof("Processing task: %s", exe.TaskID())
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil { taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseStart})
if err := ExecCommandString(taskCtx, execHooks.TaskBeforeStart); err != nil {
logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err) logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err)
} }
if err := exe.Execute(qtask.Context()); err != nil { err = exe.Execute(taskCtx)
if err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
logger.Infof("Task %s was canceled", exe.TaskID()) logger.Infof("Task %s was canceled", exe.TaskID())
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil { if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
@@ -52,6 +56,7 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err) logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err)
} }
} }
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseDone, Err: err})
qe.Done(qtask.ID) qe.Done(qtask.ID)
<-semaphore <-semaphore
} }

View File

@@ -6,12 +6,14 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time" "time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/aria2" "github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
) )
// Execute implements core.Executable. // Execute implements core.Executable.
@@ -77,6 +79,12 @@ func (t *Task) waitForDownload(ctx context.Context) error {
if t.Progress != nil { if t.Progress != nil {
t.Progress.OnProgress(ctx, t, status) t.Progress.OnProgress(ctx, t, status)
} }
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: parseInt64(status.TotalLength),
DownloadedBytes: parseInt64(status.CompletedLength),
})
// Check if download is complete // Check if download is complete
if status.IsDownloadComplete() { if status.IsDownloadComplete() {
@@ -248,3 +256,16 @@ func (t *Task) cancelAria2Download() {
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err) logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
} }
} }
// parseInt64 parses an aria2 status string (decimal bytes) into int64,
// returning 0 on failure so it can be used directly in progress events.
func parseInt64(s string) int64 {
if s == "" {
return 0
}
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0
}
return n
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/krau/SaveAny-Bot/common/utils/ioutil" "github.com/krau/SaveAny-Bot/common/utils/ioutil"
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@@ -62,8 +63,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
return elem.Storage.Save(uploadCtx, pr, elem.Path) return elem.Storage.Save(uploadCtx, pr, elem.Path)
}) })
wr := ioutil.NewProgressWriter(pw, func(n int) { wr := ioutil.NewProgressWriter(pw, func(n int) {
t.downloaded.Add(int64(n)) downloaded := t.downloaded.Add(int64(n))
t.Progress.OnProgress(ctx, t) t.Progress.OnProgress(ctx, t)
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalSize,
DownloadedBytes: downloaded,
})
}) })
errg.Go(func() error { errg.Go(func() error {
defer pw.Close() defer pw.Close()
@@ -92,8 +99,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
} }
}() }()
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) { wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
t.downloaded.Add(int64(n)) downloaded := t.downloaded.Add(int64(n))
t.Progress.OnProgress(ctx, t) t.Progress.OnProgress(ctx, t)
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalSize,
DownloadedBytes: downloaded,
})
}) })
_, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt) _, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt)
if err != nil { if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"github.com/krau/SaveAny-Bot/common/utils/ioutil" "github.com/krau/SaveAny-Bot/common/utils/ioutil"
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@@ -143,10 +144,16 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
} }
}() }()
wr := ioutil.NewProgressWriter(cacheFile, func(n int) { wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
t.downloadedBytes.Add(int64(n)) downloaded := t.downloadedBytes.Add(int64(n))
if t.Progress != nil { if t.Progress != nil {
t.Progress.OnProgress(ctx, t) t.Progress.OnProgress(ctx, t)
} }
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalBytes,
DownloadedBytes: downloaded,
})
}) })
copyResultCh := make(chan error, 1) copyResultCh := make(chan error, 1)

View File

@@ -16,6 +16,7 @@ import (
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/parser" "github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@@ -107,10 +108,16 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
} }
}() }()
wr := ioutil.NewProgressWriter(cacheFile, func(n int) { wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
t.downloadedBytes.Add(int64(n)) downloaded := t.downloadedBytes.Add(int64(n))
if t.progress != nil { if t.progress != nil {
t.progress.OnProgress(ctx, t) t.progress.OnProgress(ctx, t)
} }
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalBytes,
DownloadedBytes: downloaded,
})
}) })
copyResultCh := make(chan error, 1) copyResultCh := make(chan error, 1)

View File

@@ -11,6 +11,7 @@ import (
"github.com/duke-git/lancet/v2/retry" "github.com/duke-git/lancet/v2/retry"
"github.com/krau/SaveAny-Bot/common/utils/fsutil" "github.com/krau/SaveAny-Bot/common/utils/fsutil"
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@@ -27,8 +28,14 @@ func (t *Task) Execute(ctx context.Context) error {
logger.Errorf("Error processing picture %s: %v", pic, err) logger.Errorf("Error processing picture %s: %v", pic, err)
return fmt.Errorf("failed to process picture %s: %w", pic, err) return fmt.Errorf("failed to process picture %s: %w", pic, err)
} }
t.downloaded.Add(1) downloaded := t.downloaded.Add(1)
t.progress.OnProgress(gctx, t) t.progress.OnProgress(gctx, t)
taskevent.Emit(gctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalFiles: t.totalpics,
DownloadedFiles: int(downloaded),
})
return nil return nil
}) })
} }

View File

@@ -4,6 +4,8 @@ import (
"context" "context"
"io" "io"
"sync/atomic" "sync/atomic"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
) )
type ProgressWriterAt struct { type ProgressWriterAt struct {
@@ -20,9 +22,16 @@ func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
downloaded := w.downloaded.Add(int64(at))
if w.progress != nil { if w.progress != nil {
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total) w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
} }
taskevent.Emit(w.ctx, taskevent.Event{
TaskID: w.info.TaskID(),
Phase: taskevent.PhaseProgress,
TotalBytes: w.total,
DownloadedBytes: downloaded,
})
return at, nil return at, nil
} }
@@ -56,9 +65,16 @@ func (w *ProgressWriter) Write(p []byte) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
downloaded := w.downloaded.Add(int64(at))
if w.progress != nil { if w.progress != nil {
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total) w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
} }
taskevent.Emit(w.ctx, taskevent.Event{
TaskID: w.info.TaskID(),
Phase: taskevent.PhaseProgress,
TotalBytes: w.total,
DownloadedBytes: downloaded,
})
return at, nil return at, nil
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
"github.com/krau/SaveAny-Bot/pkg/taskevent"
"github.com/krau/SaveAny-Bot/storage" "github.com/krau/SaveAny-Bot/storage"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@@ -116,6 +117,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
t.uploaded.Add(size) t.uploaded.Add(size)
t.Progress.OnProgress(ctx, t) t.Progress.OnProgress(ctx, t)
taskevent.Emit(ctx, taskevent.Event{
TaskID: t.ID,
Phase: taskevent.PhaseProgress,
TotalBytes: t.totalSize,
DownloadedBytes: t.uploaded.Load(),
})
logger.Info("File uploaded successfully") logger.Info("File uploaded successfully")
return nil return nil

55
pkg/rule/preset.go Normal file
View File

@@ -0,0 +1,55 @@
package rule
import "path"
// PresetCategory describes a built-in filetype classification: files whose name
// matches Regex are routed into the Dir subdirectory (joined with a user base path).
type PresetCategory struct {
// Name is a stable identifier for the category (used in logs/messages).
Name string
// Regex is a FILENAME-REGEX rule data string matching this category's extensions.
Regex string
// Dir is the default subdirectory name for this category.
Dir string
}
// presetCategories holds the default filetype classification rules.
// Regexes are case-insensitive and match common file extensions.
var presetCategories = []PresetCategory{
{
Name: "video",
Regex: `(?i)\.(mp4|mkv|ts|avi|flv|mov|webm|wmv|rmvb|m2ts)$`,
Dir: "视频",
},
{
Name: "image",
Regex: `(?i)\.(jpg|jpeg|png|gif|webp|bmp)$`,
Dir: "图片",
},
{
Name: "audio",
Regex: `(?i)\.(mp3|flac|wav|aac|m4a|ogg)$`,
Dir: "音频",
},
{
Name: "document",
Regex: `(?i)\.(pdf|doc|docx|xls|xlsx|ppt|pptx|txt|md|csv|epub|mobi|azw3|chm)$`,
Dir: "文档",
},
{
Name: "archive",
Regex: `(?i)\.(zip|rar|7z|tar|gz|bz2|xz|r\d{1,3}|z\d{1,3}|\d{3}|part\d+\.rar|7z\.\d{3})$`,
Dir: "压缩包",
},
}
// PresetCategories returns the built-in filetype classification rules with each
// category's directory joined under basePath. basePath may be empty.
func PresetCategories(basePath string) []PresetCategory {
out := make([]PresetCategory, len(presetCategories))
for i, c := range presetCategories {
c.Dir = path.Join(basePath, c.Dir)
out[i] = c
}
return out
}

55
pkg/rule/preset_test.go Normal file
View File

@@ -0,0 +1,55 @@
package rule
import (
"regexp"
"testing"
)
func TestPresetCategoriesCompile(t *testing.T) {
for _, c := range PresetCategories("") {
if _, err := regexp.Compile(c.Regex); err != nil {
t.Errorf("preset %q has invalid regex %q: %v", c.Name, c.Regex, err)
}
}
}
func TestPresetCategoriesMatch(t *testing.T) {
cases := map[string]string{
"video": "movie.MP4",
"image": "photo.jpg",
"audio": "song.flac",
"document": "report.pdf",
"archive": "backup.zip",
}
byName := make(map[string]*regexp.Regexp)
for _, c := range PresetCategories("") {
byName[c.Name] = regexp.MustCompile(c.Regex)
}
for name, filename := range cases {
re, ok := byName[name]
if !ok {
t.Errorf("missing preset category %q", name)
continue
}
if !re.MatchString(filename) {
t.Errorf("preset %q did not match %q", name, filename)
}
}
}
func TestPresetCategoriesBasePath(t *testing.T) {
presets := PresetCategories("/media")
for _, c := range presets {
if c.Dir == "" || c.Dir[0] != '/' {
t.Errorf("preset %q dir %q not joined under base path", c.Name, c.Dir)
}
}
// Empty base path must not prefix a separator.
for _, c := range PresetCategories("") {
if c.Dir == "" || c.Dir[0] == '/' {
t.Errorf("preset %q dir %q should be relative when base path empty", c.Name, c.Dir)
}
}
}

View File

@@ -0,0 +1,88 @@
// Package taskevent provides a decoupled, context-scoped event bus for task
// lifecycle progress. Producers (task implementations) emit events via Emit;
// consumers (e.g. the API progress store, the Telegram message editor) register
// as Sinks and are injected through context. This keeps the task layer free of
// any concrete progress-display dependency, so new task types gain progress
// reporting for free and new observers can be added without touching tasks.
package taskevent
import "context"
// Phase marks a stage in a task's lifecycle.
type Phase int
const (
PhaseStart Phase = iota
PhaseProgress
PhaseDone
)
func (p Phase) String() string {
switch p {
case PhaseStart:
return "start"
case PhaseProgress:
return "progress"
case PhaseDone:
return "done"
default:
return "unknown"
}
}
// Event describes a single progress observation for a task. Byte fields are
// populated by byte-stream tasks; file-count fields by count-based tasks. A
// task may fill whichever subset it has; observers ignore zero values.
type Event struct {
TaskID string
Phase Phase
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Err error
}
// Sink receives task events. Implementations must be safe for concurrent use.
type Sink interface {
Emit(Event)
}
// SinkFunc is a function adapter for Sink.
type SinkFunc func(Event)
func (f SinkFunc) Emit(e Event) { f(e) }
type sinkKey struct{}
// WithSink returns a ctx carrying the given sinks. Multiple sinks can be passed
// and all will receive every emitted event. Sinks already present in ctx are
// preserved.
func WithSink(ctx context.Context, sinks ...Sink) context.Context {
if len(sinks) == 0 {
return ctx
}
var existing []Sink
if v, ok := ctx.Value(sinkKey{}).([]Sink); ok {
existing = v
}
merged := make([]Sink, 0, len(existing)+len(sinks))
merged = append(merged, existing...)
merged = append(merged, sinks...)
return context.WithValue(ctx, sinkKey{}, merged)
}
// Emit broadcasts an event to all sinks carried by ctx. It is a no-op when no
// sink is attached, so producers can call it unconditionally.
func Emit(ctx context.Context, e Event) {
if ctx == nil {
return
}
sinks, ok := ctx.Value(sinkKey{}).([]Sink)
if !ok {
return
}
for _, s := range sinks {
s.Emit(e)
}
}