mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-27 18:21:35 +08:00
Compare commits
2 Commits
refactor/p
...
v0.58.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bc460c609 | ||
|
|
f02860ff3f |
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/parser"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
|
||||
|
||||
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
|
||||
taskID := task.TaskID()
|
||||
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
||||
info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
||||
|
||||
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
|
||||
// Inject the progress sink into the context so the task's Emit calls update
|
||||
// the API store (and fire the webhook on terminal states) without the task
|
||||
// knowing about the API.
|
||||
taskCtx := taskevent.WithSink(f.ctx, info)
|
||||
|
||||
err := core.AddTask(taskCtx, task)
|
||||
if err != nil {
|
||||
DeleteTask(taskID)
|
||||
return fmt.Errorf("failed to add task: %w", err)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
@@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 取消任务
|
||||
// Cancel the task; the terminal status is set via the task event stream.
|
||||
if err := core.CancelTask(r.Context(), taskID); err != nil {
|
||||
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
|
||||
return
|
||||
@@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string {
|
||||
return parts[3]
|
||||
}
|
||||
|
||||
// convertTaskProgressToResponse 将任务进度转换为响应格式
|
||||
// convertTaskProgressToResponse renders a task's current state, computing
|
||||
// percent and speed from the snapshot taken under the task's mutex.
|
||||
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
|
||||
status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot()
|
||||
|
||||
resp := TaskInfoResponse{
|
||||
TaskID: task.TaskID,
|
||||
Type: tasktype.TaskType(task.Type),
|
||||
Status: task.Status,
|
||||
Status: status,
|
||||
Title: task.Title,
|
||||
Storage: task.Storage,
|
||||
Path: task.Path,
|
||||
Error: task.Error,
|
||||
Error: errMsg,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
// 计算进度
|
||||
if task.TotalBytes > 0 {
|
||||
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
|
||||
var percent float64
|
||||
var speedMBPS float64
|
||||
if total > 0 {
|
||||
percent = float64(downloaded) * 100 / float64(total)
|
||||
} else if totalFiles > 0 {
|
||||
percent = float64(downloadedFiles) * 100 / float64(totalFiles)
|
||||
}
|
||||
if !startedAt.IsZero() {
|
||||
elapsed := time.Since(startedAt).Seconds()
|
||||
if elapsed > 0 {
|
||||
speedMBPS = float64(downloaded) / elapsed / (1024 * 1024)
|
||||
}
|
||||
}
|
||||
|
||||
if total > 0 || totalFiles > 0 {
|
||||
resp.Progress = &TaskProgress{
|
||||
TotalBytes: task.TotalBytes,
|
||||
DownloadedBytes: task.DownloadedBytes,
|
||||
TotalBytes: total,
|
||||
DownloadedBytes: downloaded,
|
||||
TotalFiles: totalFiles,
|
||||
DownloadedFiles: downloadedFiles,
|
||||
Percent: percent,
|
||||
SpeedMBPS: speedMBPS,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
// setupTestServer creates a test server with handlers
|
||||
@@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) {
|
||||
|
||||
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates
|
||||
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
|
||||
tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
|
||||
tracker.OnStart(10000, 10)
|
||||
info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
|
||||
info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
updatesPerGoroutine := 100
|
||||
|
||||
// Concurrent progress updates
|
||||
// Concurrent progress updates via the Sink interface
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := range updatesPerGoroutine {
|
||||
tracker.OnProgress(int64(id*updatesPerGoroutine+j), j)
|
||||
info.Emit(taskevent.Event{
|
||||
TaskID: "concurrent-progress",
|
||||
Phase: taskevent.PhaseProgress,
|
||||
DownloadedBytes: int64(id*updatesPerGoroutine + j),
|
||||
TotalBytes: 10000,
|
||||
})
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
info := tracker.GetInfo()
|
||||
if info.Status != TaskStatusRunning {
|
||||
t.Errorf("expected status Running after concurrent updates, got %s", info.Status)
|
||||
status, _, downloaded, _, _, _, _, _ := info.snapshot()
|
||||
if status != TaskStatusRunning {
|
||||
t.Errorf("expected status Running after concurrent updates, got %s", status)
|
||||
}
|
||||
if downloaded <= 0 {
|
||||
t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded)
|
||||
}
|
||||
// Note: Due to race conditions in the simple implementation,
|
||||
// we can't reliably check exact values without proper synchronization
|
||||
}
|
||||
|
||||
// TestTaskFactoryValidation tests TaskFactory parameter validation
|
||||
@@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) {
|
||||
{
|
||||
name: "Progress tracker with empty webhook",
|
||||
fn: func(t *testing.T) {
|
||||
tracker := NewProgressTracker("test", "type", "storage", "path", "title", "")
|
||||
info := tracker.GetInfo()
|
||||
info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "")
|
||||
if info.Webhook != "" {
|
||||
t.Error("expected empty webhook")
|
||||
}
|
||||
|
||||
216
api/progress.go
216
api/progress.go
@@ -2,39 +2,48 @@ package api
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
// TaskProgressInfo 存储任务的进度信息
|
||||
// TaskProgressInfo stores the progress of an API-submitted task. All fields are
|
||||
// guarded by mu. It implements taskevent.Sink so the task layer can update it
|
||||
// without knowing about the API.
|
||||
type TaskProgressInfo struct {
|
||||
TaskID string
|
||||
Type string
|
||||
Status TaskStatus
|
||||
Title string
|
||||
TotalBytes int64
|
||||
DownloadedBytes int64
|
||||
TotalFiles int
|
||||
DownloadedFiles int
|
||||
Storage string
|
||||
Path string
|
||||
Error string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Webhook string
|
||||
mu sync.Mutex
|
||||
TaskID string
|
||||
Type string
|
||||
Status TaskStatus
|
||||
Title string
|
||||
TotalBytes int64
|
||||
DownloadedBytes int64
|
||||
TotalFiles int
|
||||
DownloadedFiles int
|
||||
Storage string
|
||||
Path string
|
||||
Error string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
StartedAt time.Time
|
||||
Webhook string
|
||||
webhookNotified bool
|
||||
}
|
||||
|
||||
// progressStore 存储所有 API 任务的进度信息
|
||||
// progressStore holds all API tasks. Entries are removed a fixed duration after
|
||||
// they reach a terminal state to bound memory usage.
|
||||
type progressStore struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*TaskProgressInfo
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*TaskProgressInfo
|
||||
retention time.Duration
|
||||
}
|
||||
|
||||
var store = &progressStore{
|
||||
tasks: make(map[string]*TaskProgressInfo),
|
||||
tasks: make(map[string]*TaskProgressInfo),
|
||||
retention: 24 * time.Hour,
|
||||
}
|
||||
|
||||
// RegisterTask 注册一个新的 API 任务
|
||||
// RegisterTask registers a new API task and returns its progress info.
|
||||
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
|
||||
info := &TaskProgressInfo{
|
||||
TaskID: taskID,
|
||||
@@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
|
||||
return info
|
||||
}
|
||||
|
||||
// GetTask 获取任务进度信息
|
||||
// GetTask returns the progress info for a task.
|
||||
func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
@@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// GetAllTasks 获取所有任务
|
||||
// GetAllTasks returns all tracked tasks.
|
||||
func GetAllTasks() []*TaskProgressInfo {
|
||||
store.mu.RLock()
|
||||
defer store.mu.RUnlock()
|
||||
@@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// DeleteTask 删除任务记录
|
||||
// DeleteTask removes a task record.
|
||||
func DeleteTask(taskID string) {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
delete(store.tasks, taskID)
|
||||
}
|
||||
|
||||
// UpdateStatus 更新任务状态
|
||||
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
|
||||
t.Status = status
|
||||
t.UpdatedAt = time.Now()
|
||||
// CleanupExpired removes tasks that reached a terminal state more than the
|
||||
// store's retention duration ago. It is safe to call periodically.
|
||||
func CleanupExpired() {
|
||||
now := time.Now()
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
for id, info := range store.tasks {
|
||||
info.mu.Lock()
|
||||
terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled
|
||||
stale := terminal && now.Sub(info.UpdatedAt) > store.retention
|
||||
info.mu.Unlock()
|
||||
if stale {
|
||||
delete(store.tasks, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetError 设置错误信息
|
||||
// StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done.
|
||||
// It should be started once during API server initialization.
|
||||
func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
CleanupExpired()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// UpdateStatus sets the task status.
|
||||
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
|
||||
t.mu.Lock()
|
||||
t.Status = status
|
||||
t.UpdatedAt = time.Now()
|
||||
if status == TaskStatusRunning && t.StartedAt.IsZero() {
|
||||
t.StartedAt = t.UpdatedAt
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetError marks the task failed with an error message.
|
||||
func (t *TaskProgressInfo) SetError(err string) {
|
||||
t.mu.Lock()
|
||||
t.Error = err
|
||||
t.Status = TaskStatusFailed
|
||||
t.UpdatedAt = time.Now()
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// ProgressTracker 用于 API 任务的进度追踪
|
||||
type ProgressTracker struct {
|
||||
info *TaskProgressInfo
|
||||
// snapshot returns a point-in-time copy of the fields needed to render a
|
||||
// response, so callers never touch the mutex directly.
|
||||
func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt
|
||||
}
|
||||
|
||||
// NewProgressTracker 创建新的进度追踪器
|
||||
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
|
||||
info := RegisterTask(taskID, taskType, storage, path, title, webhook)
|
||||
return &ProgressTracker{info: info}
|
||||
}
|
||||
|
||||
// OnStart 任务开始
|
||||
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
|
||||
p.info.Status = TaskStatusRunning
|
||||
p.info.TotalBytes = totalBytes
|
||||
p.info.TotalFiles = totalFiles
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// OnProgress 进度更新
|
||||
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {
|
||||
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes)
|
||||
p.info.DownloadedFiles = downloadedFiles
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// OnDone 任务完成
|
||||
func (p *ProgressTracker) OnDone(err error) {
|
||||
if err != nil {
|
||||
p.info.Status = TaskStatusFailed
|
||||
p.info.Error = err.Error()
|
||||
} else {
|
||||
p.info.Status = TaskStatusCompleted
|
||||
// Emit implements taskevent.Sink. It translates task lifecycle events into
|
||||
// status/progress updates and fires the webhook on terminal transitions.
|
||||
func (t *TaskProgressInfo) Emit(e taskevent.Event) {
|
||||
t.mu.Lock()
|
||||
switch e.Phase {
|
||||
case taskevent.PhaseStart:
|
||||
t.Status = TaskStatusRunning
|
||||
if t.StartedAt.IsZero() {
|
||||
t.StartedAt = time.Now()
|
||||
}
|
||||
if e.TotalBytes > 0 {
|
||||
t.TotalBytes = e.TotalBytes
|
||||
}
|
||||
case taskevent.PhaseProgress:
|
||||
t.Status = TaskStatusRunning
|
||||
if e.TotalBytes > 0 {
|
||||
t.TotalBytes = e.TotalBytes
|
||||
}
|
||||
t.DownloadedBytes = e.DownloadedBytes
|
||||
if e.TotalFiles > 0 {
|
||||
t.TotalFiles = e.TotalFiles
|
||||
}
|
||||
if e.DownloadedFiles > 0 {
|
||||
t.DownloadedFiles = e.DownloadedFiles
|
||||
}
|
||||
case taskevent.PhaseDone:
|
||||
if e.Err != nil {
|
||||
t.Status = TaskStatusFailed
|
||||
t.Error = e.Err.Error()
|
||||
} else {
|
||||
t.Status = TaskStatusCompleted
|
||||
}
|
||||
}
|
||||
t.UpdatedAt = time.Now()
|
||||
notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed)
|
||||
if notify {
|
||||
t.webhookNotified = true
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
if notify {
|
||||
payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err)
|
||||
SendWebhook(nil, payload)
|
||||
}
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// GetInfo 获取任务信息
|
||||
func (p *ProgressTracker) GetInfo() *TaskProgressInfo {
|
||||
return p.info
|
||||
// ProgressTracker is retained for compatibility but is no longer the primary
|
||||
// progress path; taskevent drives updates now. These methods are safe no-ops
|
||||
// when called on a nil receiver.
|
||||
type ProgressTracker struct{}
|
||||
|
||||
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
|
||||
return &ProgressTracker{}
|
||||
}
|
||||
|
||||
// UpdateProgressBytes 更新下载字节数
|
||||
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {
|
||||
atomic.StoreInt64(&p.info.DownloadedBytes, bytes)
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// UpdateProgressFiles 更新下载文件数
|
||||
func (p *ProgressTracker) UpdateProgressFiles(files int) {
|
||||
p.info.DownloadedFiles = files
|
||||
p.info.UpdatedAt = time.Now()
|
||||
}
|
||||
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {}
|
||||
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {}
|
||||
func (p *ProgressTracker) OnDone(err error) {}
|
||||
func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil }
|
||||
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {}
|
||||
func (p *ProgressTracker) UpdateProgressFiles(files int) {}
|
||||
|
||||
@@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server {
|
||||
// 404 处理
|
||||
mux.HandleFunc("/", NotFoundHandler)
|
||||
|
||||
// 应用中间件
|
||||
// Apply middleware chain.
|
||||
var handler http.Handler = mux
|
||||
|
||||
// 添加认证中间件
|
||||
// Apply auth middleware when a token is configured.
|
||||
token := cfg.Token
|
||||
if token == "" {
|
||||
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
|
||||
}
|
||||
if token != "" {
|
||||
handler = AuthMiddleware()(handler)
|
||||
}
|
||||
|
||||
// 添加日志中间件
|
||||
// Add logging middleware.
|
||||
handler = loggingMiddleware(handler)
|
||||
|
||||
// 添加恢复中间件
|
||||
// Add recovery middleware.
|
||||
handler = recoveryMiddleware(handler)
|
||||
|
||||
return &Server{
|
||||
@@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Start 初始化并启动 API 服务器
|
||||
// Start initializes and starts the API server. It refuses to start without a
|
||||
// token, since an open download proxy is a security risk.
|
||||
func Start(ctx context.Context) error {
|
||||
cfg := config.C().API
|
||||
|
||||
@@ -160,9 +158,13 @@ func Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if cfg.Token == "" {
|
||||
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
|
||||
return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely")
|
||||
}
|
||||
|
||||
server := NewServer(ctx)
|
||||
return server.Start(ctx)
|
||||
if err := server.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
StartCleanupLoop(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ type CreateTaskResponse struct {
|
||||
type TaskProgress struct {
|
||||
TotalBytes int64 `json:"total_bytes,omitempty"`
|
||||
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
|
||||
TotalFiles int `json:"total_files,omitempty"`
|
||||
DownloadedFiles int `json:"downloaded_files,omitempty"`
|
||||
Percent float64 `json:"percent,omitempty"`
|
||||
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
|
||||
|
||||
webhookURL := info.Webhook
|
||||
|
||||
// 异步发送 webhook
|
||||
// Async send with retries.
|
||||
go func() {
|
||||
logger := log.FromContext(ctx).With("task_id", payload.TaskID)
|
||||
var logger *log.Logger
|
||||
if ctx != nil {
|
||||
logger = log.FromContext(ctx).With("task_id", payload.TaskID)
|
||||
} else {
|
||||
logger = log.Default().With("task_id", payload.TaskID)
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
|
||||
}()
|
||||
}
|
||||
|
||||
// CreateWebhookPayload 创建 Webhook 负载
|
||||
// CreateWebhookPayload creates a Webhook payload.
|
||||
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
|
||||
payload := &WebhookPayload{
|
||||
TaskID: taskID,
|
||||
@@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
|
||||
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
|
||||
info, ok := GetTask(taskID)
|
||||
if !ok {
|
||||
return fmt.Errorf("task not found: %s", taskID)
|
||||
}
|
||||
|
||||
err := fn()
|
||||
|
||||
// 确定任务状态
|
||||
status := TaskStatusCompleted
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
status = TaskStatusCancelled
|
||||
} else {
|
||||
status = TaskStatusFailed
|
||||
}
|
||||
}
|
||||
|
||||
// 更新任务状态
|
||||
if err != nil {
|
||||
info.SetError(err.Error())
|
||||
} else {
|
||||
info.UpdateStatus(TaskStatusCompleted)
|
||||
}
|
||||
|
||||
// 发送 webhook
|
||||
if info.Webhook != "" {
|
||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
||||
SendWebhook(ctx, payload)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/core"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
)
|
||||
|
||||
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
|
||||
type ExecutableWrapper struct {
|
||||
inner core.Executable
|
||||
}
|
||||
|
||||
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
|
||||
return &ExecutableWrapper{inner: inner}
|
||||
}
|
||||
|
||||
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
|
||||
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
|
||||
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
|
||||
|
||||
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
|
||||
taskID := w.inner.TaskID()
|
||||
|
||||
if info, ok := GetTask(taskID); ok {
|
||||
info.UpdateStatus(TaskStatusRunning)
|
||||
}
|
||||
|
||||
err := w.inner.Execute(ctx)
|
||||
|
||||
info, ok := GetTask(taskID)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
var status TaskStatus
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
status = TaskStatusCancelled
|
||||
info.UpdateStatus(TaskStatusCancelled)
|
||||
} else {
|
||||
status = TaskStatusFailed
|
||||
info.SetError(err.Error())
|
||||
}
|
||||
} else {
|
||||
status = TaskStatusCompleted
|
||||
info.UpdateStatus(TaskStatusCompleted)
|
||||
}
|
||||
|
||||
if info.Webhook != "" {
|
||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
||||
SendWebhook(ctx, payload)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/strutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/database"
|
||||
"github.com/krau/SaveAny-Bot/pkg/rule"
|
||||
)
|
||||
@@ -84,6 +85,46 @@ func handleRuleCmd(ctx *ext.Context, update *ext.Update) error {
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoCreateRuleSuccess, nil)), nil)
|
||||
case "preset":
|
||||
// /rule preset <storage> [base_path]
|
||||
if len(args) < 3 {
|
||||
ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildRuleHelpStyling(user.ApplyRule, user.Rules)), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
storageName := args[2]
|
||||
if !config.C().HasStorage(user.ChatID, storageName) {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleErrorStorageNotFound, map[string]any{
|
||||
"Storage": storageName,
|
||||
})), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
basePath := ""
|
||||
if len(args) >= 4 {
|
||||
basePath = args[3]
|
||||
}
|
||||
presets := rule.PresetCategories(basePath)
|
||||
imported := 0
|
||||
for _, p := range presets {
|
||||
rd := &database.Rule{
|
||||
Type: rule.FileNameRegex.String(),
|
||||
Data: p.Regex,
|
||||
StorageName: storageName,
|
||||
DirPath: p.Dir,
|
||||
UserID: user.ID,
|
||||
}
|
||||
if err := database.CreateRule(ctx, rd); err != nil {
|
||||
logger.Errorf("failed to create preset rule %s: %s", p.Name, err)
|
||||
continue
|
||||
}
|
||||
imported++
|
||||
}
|
||||
if imported == 0 {
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleErrorCreateRuleFailed, nil)), nil)
|
||||
return dispatcher.EndGroups
|
||||
}
|
||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoPresetImported, map[string]any{
|
||||
"Count": imported,
|
||||
})), nil)
|
||||
case "del":
|
||||
// /rule del <id>
|
||||
if len(args) < 3 {
|
||||
|
||||
@@ -24,6 +24,8 @@ func BuildRuleHelpStyling(enabled bool, rules []database.Rule) []styling.StyledT
|
||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpSwitchSuffix, nil)),
|
||||
styling.Code("add"),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpAddSuffix, nil)),
|
||||
styling.Code("preset"),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpPresetSuffix, nil)),
|
||||
styling.Code("del"),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpDelSuffix, nil)),
|
||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpExistingRulesPrefix, nil)),
|
||||
|
||||
@@ -84,8 +84,8 @@ const (
|
||||
BotMsgCommonPromptSelectDefaultDir Key = "bot.msg.common.prompt_select_default_dir"
|
||||
BotMsgCommonPromptSelectDefaultStorage Key = "bot.msg.common.prompt_select_default_storage"
|
||||
BotMsgCommonPromptSelectDir Key = "bot.msg.common.prompt_select_dir"
|
||||
BotMsgConfigButtonFilenameStrategy Key = "bot.msg.config.button_filename_strategy"
|
||||
BotMsgConfigButtonConflictStrategy Key = "bot.msg.config.button_conflict_strategy"
|
||||
BotMsgConfigButtonFilenameStrategy Key = "bot.msg.config.button_filename_strategy"
|
||||
BotMsgConfigConflictStrategyAsk Key = "bot.msg.config.conflict_strategy_ask"
|
||||
BotMsgConfigConflictStrategyOverwrite Key = "bot.msg.config.conflict_strategy_overwrite"
|
||||
BotMsgConfigConflictStrategyRename Key = "bot.msg.config.conflict_strategy_rename"
|
||||
@@ -93,8 +93,8 @@ const (
|
||||
BotMsgConfigErrorInvalidCallbackData Key = "bot.msg.config.error_invalid_callback_data"
|
||||
BotMsgConfigErrorInvalidTemplate Key = "bot.msg.config.error_invalid_template"
|
||||
BotMsgConfigFnametmplHelp Key = "bot.msg.config.fnametmpl_help"
|
||||
BotMsgConfigInfoCurrentTemplatePrefix Key = "bot.msg.config.info_current_template_prefix"
|
||||
BotMsgConfigInfoConflictStrategySet Key = "bot.msg.config.info_conflict_strategy_set"
|
||||
BotMsgConfigInfoCurrentTemplatePrefix Key = "bot.msg.config.info_current_template_prefix"
|
||||
BotMsgConfigInfoFilenameStrategySet Key = "bot.msg.config.info_filename_strategy_set"
|
||||
BotMsgConfigInfoTemplateUpdated Key = "bot.msg.config.info_template_updated"
|
||||
BotMsgConfigPromptSelectConflictStrategy Key = "bot.msg.config.prompt_select_conflict_strategy"
|
||||
@@ -200,6 +200,7 @@ const (
|
||||
BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed"
|
||||
BotMsgRuleErrorInvalidRuleId Key = "bot.msg.rule.error_invalid_rule_id"
|
||||
BotMsgRuleErrorInvalidRuleType Key = "bot.msg.rule.error_invalid_rule_type"
|
||||
BotMsgRuleErrorStorageNotFound Key = "bot.msg.rule.error_storage_not_found"
|
||||
BotMsgRuleErrorUpdateUserFailed Key = "bot.msg.rule.error_update_user_failed"
|
||||
BotMsgRuleHelpAddSuffix Key = "bot.msg.rule.help_add_suffix"
|
||||
BotMsgRuleHelpAvailableOps Key = "bot.msg.rule.help_available_ops"
|
||||
@@ -207,13 +208,16 @@ const (
|
||||
BotMsgRuleHelpCurrentModeEnabled Key = "bot.msg.rule.help_current_mode_enabled"
|
||||
BotMsgRuleHelpDelSuffix Key = "bot.msg.rule.help_del_suffix"
|
||||
BotMsgRuleHelpExistingRulesPrefix Key = "bot.msg.rule.help_existing_rules_prefix"
|
||||
BotMsgRuleHelpPresetSuffix Key = "bot.msg.rule.help_preset_suffix"
|
||||
BotMsgRuleHelpSwitchSuffix Key = "bot.msg.rule.help_switch_suffix"
|
||||
BotMsgRuleHelpUsage Key = "bot.msg.rule.help_usage"
|
||||
BotMsgRuleInfoCreateRuleSuccess Key = "bot.msg.rule.info_create_rule_success"
|
||||
BotMsgRuleInfoDeleteRuleSuccess Key = "bot.msg.rule.info_delete_rule_success"
|
||||
BotMsgRuleInfoPresetImported Key = "bot.msg.rule.info_preset_imported"
|
||||
BotMsgRuleInfoRuleModeDisabled Key = "bot.msg.rule.info_rule_mode_disabled"
|
||||
BotMsgRuleInfoRuleModeEnabled Key = "bot.msg.rule.info_rule_mode_enabled"
|
||||
BotMsgRulePromptProvideRuleId Key = "bot.msg.rule.prompt_provide_rule_id"
|
||||
BotMsgRulePromptProvideStorageName Key = "bot.msg.rule.prompt_provide_storage_name"
|
||||
BotMsgSaveErrorInvalidIdOrUsername Key = "bot.msg.save.error_invalid_id_or_username"
|
||||
BotMsgSaveHelpText Key = "bot.msg.save_help_text"
|
||||
BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix"
|
||||
|
||||
@@ -196,7 +196,11 @@ bot:
|
||||
help_switch_suffix: " - Toggle rule mode\n"
|
||||
help_add_suffix: " <type> <data> <storage_name> <path> - Add rule\n"
|
||||
help_del_suffix: " <rule_id> - Delete rule\n"
|
||||
help_preset_suffix: " <storage_name> [base_path] - Import built-in filetype rules (video/image/audio/document/archive)\n"
|
||||
help_existing_rules_prefix: "\nCurrent rules:\n"
|
||||
prompt_provide_storage_name: "Please provide a storage name"
|
||||
error_storage_not_found: "Storage not found: {{.Storage}}"
|
||||
info_preset_imported: "Imported {{.Count}} built-in classification rules into storage {{.Storage}}"
|
||||
dir:
|
||||
error_get_user_dirs_failed: "Failed to get user directories"
|
||||
error_get_user_failed: "Failed to get user"
|
||||
|
||||
@@ -197,7 +197,11 @@ bot:
|
||||
help_switch_suffix: " - 开关规则模式\n"
|
||||
help_add_suffix: " <类型> <数据> <存储名> <路径> - 添加规则\n"
|
||||
help_del_suffix: " <规则ID> - 删除规则\n"
|
||||
help_preset_suffix: " <存储名> [基础路径] - 导入内置文件类型分类规则(视频/图片/音频/文档/压缩包)\n"
|
||||
help_existing_rules_prefix: "\n当前已添加的规则:\n"
|
||||
prompt_provide_storage_name: "请提供存储名称"
|
||||
error_storage_not_found: "未找到存储: {{.Storage}}"
|
||||
info_preset_imported: "已导入 {{.Count}} 条内置分类规则到存储 {{.Storage}}"
|
||||
dir:
|
||||
error_get_user_dirs_failed: "获取用户文件夹失败"
|
||||
error_get_user_failed: "获取用户失败"
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
var queueInstance *queue.TaskQueue[Executable]
|
||||
@@ -30,11 +31,14 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
|
||||
break // queue closed and empty
|
||||
}
|
||||
exe := qtask.Data
|
||||
taskCtx := qtask.Context()
|
||||
logger.Infof("Processing task: %s", exe.TaskID())
|
||||
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil {
|
||||
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseStart})
|
||||
if err := ExecCommandString(taskCtx, execHooks.TaskBeforeStart); err != nil {
|
||||
logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err)
|
||||
}
|
||||
if err := exe.Execute(qtask.Context()); err != nil {
|
||||
err = exe.Execute(taskCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.Infof("Task %s was canceled", exe.TaskID())
|
||||
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
|
||||
@@ -52,6 +56,7 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
|
||||
logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err)
|
||||
}
|
||||
}
|
||||
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseDone, Err: err})
|
||||
qe.Done(qtask.ID)
|
||||
<-semaphore
|
||||
}
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
// Execute implements core.Executable.
|
||||
@@ -77,6 +79,12 @@ func (t *Task) waitForDownload(ctx context.Context) error {
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t, status)
|
||||
}
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: parseInt64(status.TotalLength),
|
||||
DownloadedBytes: parseInt64(status.CompletedLength),
|
||||
})
|
||||
|
||||
// Check if download is complete
|
||||
if status.IsDownloadComplete() {
|
||||
@@ -248,3 +256,16 @@ func (t *Task) cancelAria2Download() {
|
||||
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseInt64 parses an aria2 status string (decimal bytes) into int64,
|
||||
// returning 0 on failure so it can be used directly in progress events.
|
||||
func parseInt64(s string) int64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
n, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -62,8 +63,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
return elem.Storage.Save(uploadCtx, pr, elem.Path)
|
||||
})
|
||||
wr := ioutil.NewProgressWriter(pw, func(n int) {
|
||||
t.downloaded.Add(int64(n))
|
||||
downloaded := t.downloaded.Add(int64(n))
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalSize,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
errg.Go(func() error {
|
||||
defer pw.Close()
|
||||
@@ -92,8 +99,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
}
|
||||
}()
|
||||
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
|
||||
t.downloaded.Add(int64(n))
|
||||
downloaded := t.downloaded.Add(int64(n))
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalSize,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
_, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt)
|
||||
if err != nil {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -143,10 +144,16 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
|
||||
}
|
||||
}()
|
||||
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
||||
t.downloadedBytes.Add(int64(n))
|
||||
downloaded := t.downloadedBytes.Add(int64(n))
|
||||
if t.Progress != nil {
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
}
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalBytes,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
|
||||
copyResultCh := make(chan error, 1)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/parser"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -107,10 +108,16 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
|
||||
}
|
||||
}()
|
||||
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
||||
t.downloadedBytes.Add(int64(n))
|
||||
downloaded := t.downloadedBytes.Add(int64(n))
|
||||
if t.progress != nil {
|
||||
t.progress.OnProgress(ctx, t)
|
||||
}
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalBytes,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
})
|
||||
|
||||
copyResultCh := make(chan error, 1)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/duke-git/lancet/v2/retry"
|
||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,14 @@ func (t *Task) Execute(ctx context.Context) error {
|
||||
logger.Errorf("Error processing picture %s: %v", pic, err)
|
||||
return fmt.Errorf("failed to process picture %s: %w", pic, err)
|
||||
}
|
||||
t.downloaded.Add(1)
|
||||
downloaded := t.downloaded.Add(1)
|
||||
t.progress.OnProgress(gctx, t)
|
||||
taskevent.Emit(gctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalFiles: t.totalpics,
|
||||
DownloadedFiles: int(downloaded),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
)
|
||||
|
||||
type ProgressWriterAt struct {
|
||||
@@ -20,9 +22,16 @@ func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
downloaded := w.downloaded.Add(int64(at))
|
||||
if w.progress != nil {
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
|
||||
}
|
||||
taskevent.Emit(w.ctx, taskevent.Event{
|
||||
TaskID: w.info.TaskID(),
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: w.total,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
return at, nil
|
||||
}
|
||||
|
||||
@@ -56,9 +65,16 @@ func (w *ProgressWriter) Write(p []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
downloaded := w.downloaded.Add(int64(at))
|
||||
if w.progress != nil {
|
||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
||||
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
|
||||
}
|
||||
taskevent.Emit(w.ctx, taskevent.Event{
|
||||
TaskID: w.info.TaskID(),
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: w.total,
|
||||
DownloadedBytes: downloaded,
|
||||
})
|
||||
return at, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||
"github.com/krau/SaveAny-Bot/storage"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -116,6 +117,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
||||
|
||||
t.uploaded.Add(size)
|
||||
t.Progress.OnProgress(ctx, t)
|
||||
taskevent.Emit(ctx, taskevent.Event{
|
||||
TaskID: t.ID,
|
||||
Phase: taskevent.PhaseProgress,
|
||||
TotalBytes: t.totalSize,
|
||||
DownloadedBytes: t.uploaded.Load(),
|
||||
})
|
||||
|
||||
logger.Info("File uploaded successfully")
|
||||
return nil
|
||||
|
||||
55
pkg/rule/preset.go
Normal file
55
pkg/rule/preset.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package rule
|
||||
|
||||
import "path"
|
||||
|
||||
// PresetCategory describes a built-in filetype classification: files whose name
|
||||
// matches Regex are routed into the Dir subdirectory (joined with a user base path).
|
||||
type PresetCategory struct {
|
||||
// Name is a stable identifier for the category (used in logs/messages).
|
||||
Name string
|
||||
// Regex is a FILENAME-REGEX rule data string matching this category's extensions.
|
||||
Regex string
|
||||
// Dir is the default subdirectory name for this category.
|
||||
Dir string
|
||||
}
|
||||
|
||||
// presetCategories holds the default filetype classification rules.
|
||||
// Regexes are case-insensitive and match common file extensions.
|
||||
var presetCategories = []PresetCategory{
|
||||
{
|
||||
Name: "video",
|
||||
Regex: `(?i)\.(mp4|mkv|ts|avi|flv|mov|webm|wmv|rmvb|m2ts)$`,
|
||||
Dir: "视频",
|
||||
},
|
||||
{
|
||||
Name: "image",
|
||||
Regex: `(?i)\.(jpg|jpeg|png|gif|webp|bmp)$`,
|
||||
Dir: "图片",
|
||||
},
|
||||
{
|
||||
Name: "audio",
|
||||
Regex: `(?i)\.(mp3|flac|wav|aac|m4a|ogg)$`,
|
||||
Dir: "音频",
|
||||
},
|
||||
{
|
||||
Name: "document",
|
||||
Regex: `(?i)\.(pdf|doc|docx|xls|xlsx|ppt|pptx|txt|md|csv|epub|mobi|azw3|chm)$`,
|
||||
Dir: "文档",
|
||||
},
|
||||
{
|
||||
Name: "archive",
|
||||
Regex: `(?i)\.(zip|rar|7z|tar|gz|bz2|xz|r\d{1,3}|z\d{1,3}|\d{3}|part\d+\.rar|7z\.\d{3})$`,
|
||||
Dir: "压缩包",
|
||||
},
|
||||
}
|
||||
|
||||
// PresetCategories returns the built-in filetype classification rules with each
|
||||
// category's directory joined under basePath. basePath may be empty.
|
||||
func PresetCategories(basePath string) []PresetCategory {
|
||||
out := make([]PresetCategory, len(presetCategories))
|
||||
for i, c := range presetCategories {
|
||||
c.Dir = path.Join(basePath, c.Dir)
|
||||
out[i] = c
|
||||
}
|
||||
return out
|
||||
}
|
||||
55
pkg/rule/preset_test.go
Normal file
55
pkg/rule/preset_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPresetCategoriesCompile(t *testing.T) {
|
||||
for _, c := range PresetCategories("") {
|
||||
if _, err := regexp.Compile(c.Regex); err != nil {
|
||||
t.Errorf("preset %q has invalid regex %q: %v", c.Name, c.Regex, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresetCategoriesMatch(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"video": "movie.MP4",
|
||||
"image": "photo.jpg",
|
||||
"audio": "song.flac",
|
||||
"document": "report.pdf",
|
||||
"archive": "backup.zip",
|
||||
}
|
||||
|
||||
byName := make(map[string]*regexp.Regexp)
|
||||
for _, c := range PresetCategories("") {
|
||||
byName[c.Name] = regexp.MustCompile(c.Regex)
|
||||
}
|
||||
|
||||
for name, filename := range cases {
|
||||
re, ok := byName[name]
|
||||
if !ok {
|
||||
t.Errorf("missing preset category %q", name)
|
||||
continue
|
||||
}
|
||||
if !re.MatchString(filename) {
|
||||
t.Errorf("preset %q did not match %q", name, filename)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresetCategoriesBasePath(t *testing.T) {
|
||||
presets := PresetCategories("/media")
|
||||
for _, c := range presets {
|
||||
if c.Dir == "" || c.Dir[0] != '/' {
|
||||
t.Errorf("preset %q dir %q not joined under base path", c.Name, c.Dir)
|
||||
}
|
||||
}
|
||||
// Empty base path must not prefix a separator.
|
||||
for _, c := range PresetCategories("") {
|
||||
if c.Dir == "" || c.Dir[0] == '/' {
|
||||
t.Errorf("preset %q dir %q should be relative when base path empty", c.Name, c.Dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
88
pkg/taskevent/taskevent.go
Normal file
88
pkg/taskevent/taskevent.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Package taskevent provides a decoupled, context-scoped event bus for task
|
||||
// lifecycle progress. Producers (task implementations) emit events via Emit;
|
||||
// consumers (e.g. the API progress store, the Telegram message editor) register
|
||||
// as Sinks and are injected through context. This keeps the task layer free of
|
||||
// any concrete progress-display dependency, so new task types gain progress
|
||||
// reporting for free and new observers can be added without touching tasks.
|
||||
package taskevent
|
||||
|
||||
import "context"
|
||||
|
||||
// Phase marks a stage in a task's lifecycle.
|
||||
type Phase int
|
||||
|
||||
const (
|
||||
PhaseStart Phase = iota
|
||||
PhaseProgress
|
||||
PhaseDone
|
||||
)
|
||||
|
||||
func (p Phase) String() string {
|
||||
switch p {
|
||||
case PhaseStart:
|
||||
return "start"
|
||||
case PhaseProgress:
|
||||
return "progress"
|
||||
case PhaseDone:
|
||||
return "done"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Event describes a single progress observation for a task. Byte fields are
|
||||
// populated by byte-stream tasks; file-count fields by count-based tasks. A
|
||||
// task may fill whichever subset it has; observers ignore zero values.
|
||||
type Event struct {
|
||||
TaskID string
|
||||
Phase Phase
|
||||
TotalBytes int64
|
||||
DownloadedBytes int64
|
||||
TotalFiles int
|
||||
DownloadedFiles int
|
||||
Err error
|
||||
}
|
||||
|
||||
// Sink receives task events. Implementations must be safe for concurrent use.
|
||||
type Sink interface {
|
||||
Emit(Event)
|
||||
}
|
||||
|
||||
// SinkFunc is a function adapter for Sink.
|
||||
type SinkFunc func(Event)
|
||||
|
||||
func (f SinkFunc) Emit(e Event) { f(e) }
|
||||
|
||||
type sinkKey struct{}
|
||||
|
||||
// WithSink returns a ctx carrying the given sinks. Multiple sinks can be passed
|
||||
// and all will receive every emitted event. Sinks already present in ctx are
|
||||
// preserved.
|
||||
func WithSink(ctx context.Context, sinks ...Sink) context.Context {
|
||||
if len(sinks) == 0 {
|
||||
return ctx
|
||||
}
|
||||
var existing []Sink
|
||||
if v, ok := ctx.Value(sinkKey{}).([]Sink); ok {
|
||||
existing = v
|
||||
}
|
||||
merged := make([]Sink, 0, len(existing)+len(sinks))
|
||||
merged = append(merged, existing...)
|
||||
merged = append(merged, sinks...)
|
||||
return context.WithValue(ctx, sinkKey{}, merged)
|
||||
}
|
||||
|
||||
// Emit broadcasts an event to all sinks carried by ctx. It is a no-op when no
|
||||
// sink is attached, so producers can call it unconditionally.
|
||||
func Emit(ctx context.Context, e Event) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
sinks, ok := ctx.Value(sinkKey{}).([]Sink)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, s := range sinks {
|
||||
s.Emit(e)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user