mirror of
https://github.com/krau/SaveAny-Bot.git
synced 2026-06-27 10:11:34 +08:00
Compare commits
2 Commits
refactor/p
...
v0.58.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bc460c609 | ||
|
|
f02860ff3f |
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/parser"
|
"github.com/krau/SaveAny-Bot/pkg/parser"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
||||||
"github.com/krau/SaveAny-Bot/storage"
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
@@ -68,9 +69,14 @@ func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, e
|
|||||||
|
|
||||||
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
|
func (f *TaskFactory) registerAndEnqueueTask(task core.Executable, taskType tasktype.TaskType, storageName, path, webhook string) error {
|
||||||
taskID := task.TaskID()
|
taskID := task.TaskID()
|
||||||
RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
info := RegisterTask(taskID, string(taskType), storageName, path, task.Title(), webhook)
|
||||||
|
|
||||||
err := core.AddTask(f.ctx, NewExecutableWrapper(task))
|
// Inject the progress sink into the context so the task's Emit calls update
|
||||||
|
// the API store (and fire the webhook on terminal states) without the task
|
||||||
|
// knowing about the API.
|
||||||
|
taskCtx := taskevent.WithSink(f.ctx, info)
|
||||||
|
|
||||||
|
err := core.AddTask(taskCtx, task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
DeleteTask(taskID)
|
DeleteTask(taskID)
|
||||||
return fmt.Errorf("failed to add task: %w", err)
|
return fmt.Errorf("failed to add task: %w", err)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/core"
|
"github.com/krau/SaveAny-Bot/core"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||||
@@ -117,7 +118,7 @@ func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 取消任务
|
// Cancel the task; the terminal status is set via the task event stream.
|
||||||
if err := core.CancelTask(r.Context(), taskID); err != nil {
|
if err := core.CancelTask(r.Context(), taskID); err != nil {
|
||||||
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
|
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
|
||||||
return
|
return
|
||||||
@@ -184,27 +185,45 @@ func extractTaskIDFromPath(path string) string {
|
|||||||
return parts[3]
|
return parts[3]
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertTaskProgressToResponse 将任务进度转换为响应格式
|
// convertTaskProgressToResponse renders a task's current state, computing
|
||||||
|
// percent and speed from the snapshot taken under the task's mutex.
|
||||||
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
|
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
|
||||||
|
status, total, downloaded, totalFiles, downloadedFiles, startedAt, errMsg, updatedAt := task.snapshot()
|
||||||
|
|
||||||
resp := TaskInfoResponse{
|
resp := TaskInfoResponse{
|
||||||
TaskID: task.TaskID,
|
TaskID: task.TaskID,
|
||||||
Type: tasktype.TaskType(task.Type),
|
Type: tasktype.TaskType(task.Type),
|
||||||
Status: task.Status,
|
Status: status,
|
||||||
Title: task.Title,
|
Title: task.Title,
|
||||||
Storage: task.Storage,
|
Storage: task.Storage,
|
||||||
Path: task.Path,
|
Path: task.Path,
|
||||||
Error: task.Error,
|
Error: errMsg,
|
||||||
CreatedAt: task.CreatedAt,
|
CreatedAt: task.CreatedAt,
|
||||||
UpdatedAt: task.UpdatedAt,
|
UpdatedAt: updatedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算进度
|
var percent float64
|
||||||
if task.TotalBytes > 0 {
|
var speedMBPS float64
|
||||||
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
|
if total > 0 {
|
||||||
|
percent = float64(downloaded) * 100 / float64(total)
|
||||||
|
} else if totalFiles > 0 {
|
||||||
|
percent = float64(downloadedFiles) * 100 / float64(totalFiles)
|
||||||
|
}
|
||||||
|
if !startedAt.IsZero() {
|
||||||
|
elapsed := time.Since(startedAt).Seconds()
|
||||||
|
if elapsed > 0 {
|
||||||
|
speedMBPS = float64(downloaded) / elapsed / (1024 * 1024)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if total > 0 || totalFiles > 0 {
|
||||||
resp.Progress = &TaskProgress{
|
resp.Progress = &TaskProgress{
|
||||||
TotalBytes: task.TotalBytes,
|
TotalBytes: total,
|
||||||
DownloadedBytes: task.DownloadedBytes,
|
DownloadedBytes: downloaded,
|
||||||
|
TotalFiles: totalFiles,
|
||||||
|
DownloadedFiles: downloadedFiles,
|
||||||
Percent: percent,
|
Percent: percent,
|
||||||
|
SpeedMBPS: speedMBPS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupTestServer creates a test server with handlers
|
// setupTestServer creates a test server with handlers
|
||||||
@@ -403,32 +404,38 @@ func TestConcurrentProgressStore(t *testing.T) {
|
|||||||
|
|
||||||
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates
|
// TestProgressTrackerConcurrentUpdates tests concurrent progress updates
|
||||||
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
|
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
|
||||||
tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
|
info := RegisterTask("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
|
||||||
tracker.OnStart(10000, 10)
|
info.Emit(taskevent.Event{TaskID: "concurrent-progress", Phase: taskevent.PhaseStart, TotalBytes: 10000})
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
numGoroutines := 50
|
numGoroutines := 50
|
||||||
updatesPerGoroutine := 100
|
updatesPerGoroutine := 100
|
||||||
|
|
||||||
// Concurrent progress updates
|
// Concurrent progress updates via the Sink interface
|
||||||
for i := range numGoroutines {
|
for i := range numGoroutines {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for j := range updatesPerGoroutine {
|
for j := range updatesPerGoroutine {
|
||||||
tracker.OnProgress(int64(id*updatesPerGoroutine+j), j)
|
info.Emit(taskevent.Event{
|
||||||
|
TaskID: "concurrent-progress",
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
DownloadedBytes: int64(id*updatesPerGoroutine + j),
|
||||||
|
TotalBytes: 10000,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
info := tracker.GetInfo()
|
status, _, downloaded, _, _, _, _, _ := info.snapshot()
|
||||||
if info.Status != TaskStatusRunning {
|
if status != TaskStatusRunning {
|
||||||
t.Errorf("expected status Running after concurrent updates, got %s", info.Status)
|
t.Errorf("expected status Running after concurrent updates, got %s", status)
|
||||||
|
}
|
||||||
|
if downloaded <= 0 {
|
||||||
|
t.Errorf("expected downloaded bytes > 0 after concurrent updates, got %d", downloaded)
|
||||||
}
|
}
|
||||||
// Note: Due to race conditions in the simple implementation,
|
|
||||||
// we can't reliably check exact values without proper synchronization
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestTaskFactoryValidation tests TaskFactory parameter validation
|
// TestTaskFactoryValidation tests TaskFactory parameter validation
|
||||||
@@ -526,8 +533,7 @@ func TestEdgeCases(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Progress tracker with empty webhook",
|
name: "Progress tracker with empty webhook",
|
||||||
fn: func(t *testing.T) {
|
fn: func(t *testing.T) {
|
||||||
tracker := NewProgressTracker("test", "type", "storage", "path", "title", "")
|
info := RegisterTask("test-empty-webhook", "type", "storage", "path", "title", "")
|
||||||
info := tracker.GetInfo()
|
|
||||||
if info.Webhook != "" {
|
if info.Webhook != "" {
|
||||||
t.Error("expected empty webhook")
|
t.Error("expected empty webhook")
|
||||||
}
|
}
|
||||||
|
|||||||
182
api/progress.go
182
api/progress.go
@@ -2,12 +2,16 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TaskProgressInfo 存储任务的进度信息
|
// TaskProgressInfo stores the progress of an API-submitted task. All fields are
|
||||||
|
// guarded by mu. It implements taskevent.Sink so the task layer can update it
|
||||||
|
// without knowing about the API.
|
||||||
type TaskProgressInfo struct {
|
type TaskProgressInfo struct {
|
||||||
|
mu sync.Mutex
|
||||||
TaskID string
|
TaskID string
|
||||||
Type string
|
Type string
|
||||||
Status TaskStatus
|
Status TaskStatus
|
||||||
@@ -21,20 +25,25 @@ type TaskProgressInfo struct {
|
|||||||
Error string
|
Error string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
StartedAt time.Time
|
||||||
Webhook string
|
Webhook string
|
||||||
|
webhookNotified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// progressStore 存储所有 API 任务的进度信息
|
// progressStore holds all API tasks. Entries are removed a fixed duration after
|
||||||
|
// they reach a terminal state to bound memory usage.
|
||||||
type progressStore struct {
|
type progressStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
tasks map[string]*TaskProgressInfo
|
tasks map[string]*TaskProgressInfo
|
||||||
|
retention time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var store = &progressStore{
|
var store = &progressStore{
|
||||||
tasks: make(map[string]*TaskProgressInfo),
|
tasks: make(map[string]*TaskProgressInfo),
|
||||||
|
retention: 24 * time.Hour,
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterTask 注册一个新的 API 任务
|
// RegisterTask registers a new API task and returns its progress info.
|
||||||
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
|
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
|
||||||
info := &TaskProgressInfo{
|
info := &TaskProgressInfo{
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
@@ -55,7 +64,7 @@ func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskP
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTask 获取任务进度信息
|
// GetTask returns the progress info for a task.
|
||||||
func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
||||||
store.mu.RLock()
|
store.mu.RLock()
|
||||||
defer store.mu.RUnlock()
|
defer store.mu.RUnlock()
|
||||||
@@ -63,7 +72,7 @@ func GetTask(taskID string) (*TaskProgressInfo, bool) {
|
|||||||
return info, ok
|
return info, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllTasks 获取所有任务
|
// GetAllTasks returns all tracked tasks.
|
||||||
func GetAllTasks() []*TaskProgressInfo {
|
func GetAllTasks() []*TaskProgressInfo {
|
||||||
store.mu.RLock()
|
store.mu.RLock()
|
||||||
defer store.mu.RUnlock()
|
defer store.mu.RUnlock()
|
||||||
@@ -75,76 +84,133 @@ func GetAllTasks() []*TaskProgressInfo {
|
|||||||
return tasks
|
return tasks
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteTask 删除任务记录
|
// DeleteTask removes a task record.
|
||||||
func DeleteTask(taskID string) {
|
func DeleteTask(taskID string) {
|
||||||
store.mu.Lock()
|
store.mu.Lock()
|
||||||
defer store.mu.Unlock()
|
defer store.mu.Unlock()
|
||||||
delete(store.tasks, taskID)
|
delete(store.tasks, taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateStatus 更新任务状态
|
// CleanupExpired removes tasks that reached a terminal state more than the
|
||||||
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
|
// store's retention duration ago. It is safe to call periodically.
|
||||||
t.Status = status
|
func CleanupExpired() {
|
||||||
t.UpdatedAt = time.Now()
|
now := time.Now()
|
||||||
|
store.mu.Lock()
|
||||||
|
defer store.mu.Unlock()
|
||||||
|
for id, info := range store.tasks {
|
||||||
|
info.mu.Lock()
|
||||||
|
terminal := info.Status == TaskStatusCompleted || info.Status == TaskStatusFailed || info.Status == TaskStatusCancelled
|
||||||
|
stale := terminal && now.Sub(info.UpdatedAt) > store.retention
|
||||||
|
info.mu.Unlock()
|
||||||
|
if stale {
|
||||||
|
delete(store.tasks, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetError 设置错误信息
|
// StartCleanupLoop runs CleanupExpired on a fixed interval until ctx is done.
|
||||||
|
// It should be started once during API server initialization.
|
||||||
|
func StartCleanupLoop(ctx interface{ Done() <-chan struct{} }) {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(10 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
CleanupExpired()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus sets the task status.
|
||||||
|
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
|
||||||
|
t.mu.Lock()
|
||||||
|
t.Status = status
|
||||||
|
t.UpdatedAt = time.Now()
|
||||||
|
if status == TaskStatusRunning && t.StartedAt.IsZero() {
|
||||||
|
t.StartedAt = t.UpdatedAt
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetError marks the task failed with an error message.
|
||||||
func (t *TaskProgressInfo) SetError(err string) {
|
func (t *TaskProgressInfo) SetError(err string) {
|
||||||
|
t.mu.Lock()
|
||||||
t.Error = err
|
t.Error = err
|
||||||
t.Status = TaskStatusFailed
|
t.Status = TaskStatusFailed
|
||||||
t.UpdatedAt = time.Now()
|
t.UpdatedAt = time.Now()
|
||||||
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgressTracker 用于 API 任务的进度追踪
|
// snapshot returns a point-in-time copy of the fields needed to render a
|
||||||
type ProgressTracker struct {
|
// response, so callers never touch the mutex directly.
|
||||||
info *TaskProgressInfo
|
func (t *TaskProgressInfo) snapshot() (status TaskStatus, total, downloaded int64, totalFiles, downloadedFiles int, startedAt time.Time, err string, updatedAt time.Time) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
return t.Status, t.TotalBytes, t.DownloadedBytes, t.TotalFiles, t.DownloadedFiles, t.StartedAt, t.Error, t.UpdatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProgressTracker 创建新的进度追踪器
|
// Emit implements taskevent.Sink. It translates task lifecycle events into
|
||||||
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
|
// status/progress updates and fires the webhook on terminal transitions.
|
||||||
info := RegisterTask(taskID, taskType, storage, path, title, webhook)
|
func (t *TaskProgressInfo) Emit(e taskevent.Event) {
|
||||||
return &ProgressTracker{info: info}
|
t.mu.Lock()
|
||||||
}
|
switch e.Phase {
|
||||||
|
case taskevent.PhaseStart:
|
||||||
// OnStart 任务开始
|
t.Status = TaskStatusRunning
|
||||||
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
|
if t.StartedAt.IsZero() {
|
||||||
p.info.Status = TaskStatusRunning
|
t.StartedAt = time.Now()
|
||||||
p.info.TotalBytes = totalBytes
|
}
|
||||||
p.info.TotalFiles = totalFiles
|
if e.TotalBytes > 0 {
|
||||||
p.info.UpdatedAt = time.Now()
|
t.TotalBytes = e.TotalBytes
|
||||||
}
|
}
|
||||||
|
case taskevent.PhaseProgress:
|
||||||
// OnProgress 进度更新
|
t.Status = TaskStatusRunning
|
||||||
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {
|
if e.TotalBytes > 0 {
|
||||||
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes)
|
t.TotalBytes = e.TotalBytes
|
||||||
p.info.DownloadedFiles = downloadedFiles
|
}
|
||||||
p.info.UpdatedAt = time.Now()
|
t.DownloadedBytes = e.DownloadedBytes
|
||||||
}
|
if e.TotalFiles > 0 {
|
||||||
|
t.TotalFiles = e.TotalFiles
|
||||||
// OnDone 任务完成
|
}
|
||||||
func (p *ProgressTracker) OnDone(err error) {
|
if e.DownloadedFiles > 0 {
|
||||||
if err != nil {
|
t.DownloadedFiles = e.DownloadedFiles
|
||||||
p.info.Status = TaskStatusFailed
|
}
|
||||||
p.info.Error = err.Error()
|
case taskevent.PhaseDone:
|
||||||
} else {
|
if e.Err != nil {
|
||||||
p.info.Status = TaskStatusCompleted
|
t.Status = TaskStatusFailed
|
||||||
|
t.Error = e.Err.Error()
|
||||||
|
} else {
|
||||||
|
t.Status = TaskStatusCompleted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.UpdatedAt = time.Now()
|
||||||
|
notify := t.Webhook != "" && !t.webhookNotified && (t.Status == TaskStatusCompleted || t.Status == TaskStatusFailed)
|
||||||
|
if notify {
|
||||||
|
t.webhookNotified = true
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
if notify {
|
||||||
|
payload := CreateWebhookPayload(t.TaskID, t.Type, t.Status, t.Storage, t.Path, e.Err)
|
||||||
|
SendWebhook(nil, payload)
|
||||||
}
|
}
|
||||||
p.info.UpdatedAt = time.Now()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetInfo 获取任务信息
|
// ProgressTracker is retained for compatibility but is no longer the primary
|
||||||
func (p *ProgressTracker) GetInfo() *TaskProgressInfo {
|
// progress path; taskevent drives updates now. These methods are safe no-ops
|
||||||
return p.info
|
// when called on a nil receiver.
|
||||||
|
type ProgressTracker struct{}
|
||||||
|
|
||||||
|
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
|
||||||
|
return &ProgressTracker{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProgressBytes 更新下载字节数
|
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {}
|
||||||
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {
|
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {}
|
||||||
atomic.StoreInt64(&p.info.DownloadedBytes, bytes)
|
func (p *ProgressTracker) OnDone(err error) {}
|
||||||
p.info.UpdatedAt = time.Now()
|
func (p *ProgressTracker) GetInfo() *TaskProgressInfo { return nil }
|
||||||
}
|
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {}
|
||||||
|
func (p *ProgressTracker) UpdateProgressFiles(files int) {}
|
||||||
// UpdateProgressFiles 更新下载文件数
|
|
||||||
func (p *ProgressTracker) UpdateProgressFiles(files int) {
|
|
||||||
p.info.DownloadedFiles = files
|
|
||||||
p.info.UpdatedAt = time.Now()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -57,22 +57,19 @@ func NewServer(ctx context.Context) *Server {
|
|||||||
// 404 处理
|
// 404 处理
|
||||||
mux.HandleFunc("/", NotFoundHandler)
|
mux.HandleFunc("/", NotFoundHandler)
|
||||||
|
|
||||||
// 应用中间件
|
// Apply middleware chain.
|
||||||
var handler http.Handler = mux
|
var handler http.Handler = mux
|
||||||
|
|
||||||
// 添加认证中间件
|
// Apply auth middleware when a token is configured.
|
||||||
token := cfg.Token
|
token := cfg.Token
|
||||||
if token == "" {
|
|
||||||
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
|
|
||||||
}
|
|
||||||
if token != "" {
|
if token != "" {
|
||||||
handler = AuthMiddleware()(handler)
|
handler = AuthMiddleware()(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加日志中间件
|
// Add logging middleware.
|
||||||
handler = loggingMiddleware(handler)
|
handler = loggingMiddleware(handler)
|
||||||
|
|
||||||
// 添加恢复中间件
|
// Add recovery middleware.
|
||||||
handler = recoveryMiddleware(handler)
|
handler = recoveryMiddleware(handler)
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
@@ -151,7 +148,8 @@ func (rw *responseWriter) WriteHeader(code int) {
|
|||||||
rw.ResponseWriter.WriteHeader(code)
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start 初始化并启动 API 服务器
|
// Start initializes and starts the API server. It refuses to start without a
|
||||||
|
// token, since an open download proxy is a security risk.
|
||||||
func Start(ctx context.Context) error {
|
func Start(ctx context.Context) error {
|
||||||
cfg := config.C().API
|
cfg := config.C().API
|
||||||
|
|
||||||
@@ -160,9 +158,13 @@ func Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Token == "" {
|
if cfg.Token == "" {
|
||||||
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
|
return fmt.Errorf("API server is enabled but no token is set; refusing to start insecurely")
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(ctx)
|
server := NewServer(ctx)
|
||||||
return server.Start(ctx)
|
if err := server.Start(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
StartCleanupLoop(ctx)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ type CreateTaskResponse struct {
|
|||||||
type TaskProgress struct {
|
type TaskProgress struct {
|
||||||
TotalBytes int64 `json:"total_bytes,omitempty"`
|
TotalBytes int64 `json:"total_bytes,omitempty"`
|
||||||
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
|
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
|
||||||
|
TotalFiles int `json:"total_files,omitempty"`
|
||||||
|
DownloadedFiles int `json:"downloaded_files,omitempty"`
|
||||||
Percent float64 `json:"percent,omitempty"`
|
Percent float64 `json:"percent,omitempty"`
|
||||||
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
|
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,9 +29,14 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
|
|||||||
|
|
||||||
webhookURL := info.Webhook
|
webhookURL := info.Webhook
|
||||||
|
|
||||||
// 异步发送 webhook
|
// Async send with retries.
|
||||||
go func() {
|
go func() {
|
||||||
logger := log.FromContext(ctx).With("task_id", payload.TaskID)
|
var logger *log.Logger
|
||||||
|
if ctx != nil {
|
||||||
|
logger = log.FromContext(ctx).With("task_id", payload.TaskID)
|
||||||
|
} else {
|
||||||
|
logger = log.Default().With("task_id", payload.TaskID)
|
||||||
|
}
|
||||||
|
|
||||||
payloadBytes, err := json.Marshal(payload)
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -72,7 +76,7 @@ func SendWebhook(ctx context.Context, payload *WebhookPayload) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateWebhookPayload 创建 Webhook 负载
|
// CreateWebhookPayload creates a Webhook payload.
|
||||||
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
|
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
|
||||||
payload := &WebhookPayload{
|
payload := &WebhookPayload{
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
@@ -93,38 +97,3 @@ func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, sto
|
|||||||
|
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
|
|
||||||
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
|
|
||||||
info, ok := GetTask(taskID)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("task not found: %s", taskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := fn()
|
|
||||||
|
|
||||||
// 确定任务状态
|
|
||||||
status := TaskStatusCompleted
|
|
||||||
if err != nil {
|
|
||||||
if err == context.Canceled {
|
|
||||||
status = TaskStatusCancelled
|
|
||||||
} else {
|
|
||||||
status = TaskStatusFailed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务状态
|
|
||||||
if err != nil {
|
|
||||||
info.SetError(err.Error())
|
|
||||||
} else {
|
|
||||||
info.UpdateStatus(TaskStatusCompleted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送 webhook
|
|
||||||
if info.Webhook != "" {
|
|
||||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
|
||||||
SendWebhook(ctx, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/krau/SaveAny-Bot/core"
|
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ExecutableWrapper wraps core.Executable to track task status in the API store and send webhooks.
|
|
||||||
type ExecutableWrapper struct {
|
|
||||||
inner core.Executable
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewExecutableWrapper(inner core.Executable) *ExecutableWrapper {
|
|
||||||
return &ExecutableWrapper{inner: inner}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ExecutableWrapper) Type() tasktype.TaskType { return w.inner.Type() }
|
|
||||||
func (w *ExecutableWrapper) Title() string { return w.inner.Title() }
|
|
||||||
func (w *ExecutableWrapper) TaskID() string { return w.inner.TaskID() }
|
|
||||||
|
|
||||||
func (w *ExecutableWrapper) Execute(ctx context.Context) error {
|
|
||||||
taskID := w.inner.TaskID()
|
|
||||||
|
|
||||||
if info, ok := GetTask(taskID); ok {
|
|
||||||
info.UpdateStatus(TaskStatusRunning)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := w.inner.Execute(ctx)
|
|
||||||
|
|
||||||
info, ok := GetTask(taskID)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var status TaskStatus
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
status = TaskStatusCancelled
|
|
||||||
info.UpdateStatus(TaskStatusCancelled)
|
|
||||||
} else {
|
|
||||||
status = TaskStatusFailed
|
|
||||||
info.SetError(err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
status = TaskStatusCompleted
|
|
||||||
info.UpdateStatus(TaskStatusCompleted)
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Webhook != "" {
|
|
||||||
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
|
|
||||||
SendWebhook(ctx, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/common/i18n"
|
"github.com/krau/SaveAny-Bot/common/i18n"
|
||||||
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
"github.com/krau/SaveAny-Bot/common/i18n/i18nk"
|
||||||
"github.com/krau/SaveAny-Bot/common/utils/strutil"
|
"github.com/krau/SaveAny-Bot/common/utils/strutil"
|
||||||
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/database"
|
"github.com/krau/SaveAny-Bot/database"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/rule"
|
"github.com/krau/SaveAny-Bot/pkg/rule"
|
||||||
)
|
)
|
||||||
@@ -84,6 +85,46 @@ func handleRuleCmd(ctx *ext.Context, update *ext.Update) error {
|
|||||||
return dispatcher.EndGroups
|
return dispatcher.EndGroups
|
||||||
}
|
}
|
||||||
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoCreateRuleSuccess, nil)), nil)
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoCreateRuleSuccess, nil)), nil)
|
||||||
|
case "preset":
|
||||||
|
// /rule preset <storage> [base_path]
|
||||||
|
if len(args) < 3 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextStyledTextArray(msgelem.BuildRuleHelpStyling(user.ApplyRule, user.Rules)), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
storageName := args[2]
|
||||||
|
if !config.C().HasStorage(user.ChatID, storageName) {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleErrorStorageNotFound, map[string]any{
|
||||||
|
"Storage": storageName,
|
||||||
|
})), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
basePath := ""
|
||||||
|
if len(args) >= 4 {
|
||||||
|
basePath = args[3]
|
||||||
|
}
|
||||||
|
presets := rule.PresetCategories(basePath)
|
||||||
|
imported := 0
|
||||||
|
for _, p := range presets {
|
||||||
|
rd := &database.Rule{
|
||||||
|
Type: rule.FileNameRegex.String(),
|
||||||
|
Data: p.Regex,
|
||||||
|
StorageName: storageName,
|
||||||
|
DirPath: p.Dir,
|
||||||
|
UserID: user.ID,
|
||||||
|
}
|
||||||
|
if err := database.CreateRule(ctx, rd); err != nil {
|
||||||
|
logger.Errorf("failed to create preset rule %s: %s", p.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imported++
|
||||||
|
}
|
||||||
|
if imported == 0 {
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleErrorCreateRuleFailed, nil)), nil)
|
||||||
|
return dispatcher.EndGroups
|
||||||
|
}
|
||||||
|
ctx.Reply(update, ext.ReplyTextString(i18n.T(i18nk.BotMsgRuleInfoPresetImported, map[string]any{
|
||||||
|
"Count": imported,
|
||||||
|
})), nil)
|
||||||
case "del":
|
case "del":
|
||||||
// /rule del <id>
|
// /rule del <id>
|
||||||
if len(args) < 3 {
|
if len(args) < 3 {
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ func BuildRuleHelpStyling(enabled bool, rules []database.Rule) []styling.StyledT
|
|||||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpSwitchSuffix, nil)),
|
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpSwitchSuffix, nil)),
|
||||||
styling.Code("add"),
|
styling.Code("add"),
|
||||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpAddSuffix, nil)),
|
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpAddSuffix, nil)),
|
||||||
|
styling.Code("preset"),
|
||||||
|
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpPresetSuffix, nil)),
|
||||||
styling.Code("del"),
|
styling.Code("del"),
|
||||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpDelSuffix, nil)),
|
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpDelSuffix, nil)),
|
||||||
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpExistingRulesPrefix, nil)),
|
styling.Plain(i18n.T(i18nk.BotMsgRuleHelpExistingRulesPrefix, nil)),
|
||||||
|
|||||||
@@ -84,8 +84,8 @@ const (
|
|||||||
BotMsgCommonPromptSelectDefaultDir Key = "bot.msg.common.prompt_select_default_dir"
|
BotMsgCommonPromptSelectDefaultDir Key = "bot.msg.common.prompt_select_default_dir"
|
||||||
BotMsgCommonPromptSelectDefaultStorage Key = "bot.msg.common.prompt_select_default_storage"
|
BotMsgCommonPromptSelectDefaultStorage Key = "bot.msg.common.prompt_select_default_storage"
|
||||||
BotMsgCommonPromptSelectDir Key = "bot.msg.common.prompt_select_dir"
|
BotMsgCommonPromptSelectDir Key = "bot.msg.common.prompt_select_dir"
|
||||||
BotMsgConfigButtonFilenameStrategy Key = "bot.msg.config.button_filename_strategy"
|
|
||||||
BotMsgConfigButtonConflictStrategy Key = "bot.msg.config.button_conflict_strategy"
|
BotMsgConfigButtonConflictStrategy Key = "bot.msg.config.button_conflict_strategy"
|
||||||
|
BotMsgConfigButtonFilenameStrategy Key = "bot.msg.config.button_filename_strategy"
|
||||||
BotMsgConfigConflictStrategyAsk Key = "bot.msg.config.conflict_strategy_ask"
|
BotMsgConfigConflictStrategyAsk Key = "bot.msg.config.conflict_strategy_ask"
|
||||||
BotMsgConfigConflictStrategyOverwrite Key = "bot.msg.config.conflict_strategy_overwrite"
|
BotMsgConfigConflictStrategyOverwrite Key = "bot.msg.config.conflict_strategy_overwrite"
|
||||||
BotMsgConfigConflictStrategyRename Key = "bot.msg.config.conflict_strategy_rename"
|
BotMsgConfigConflictStrategyRename Key = "bot.msg.config.conflict_strategy_rename"
|
||||||
@@ -93,8 +93,8 @@ const (
|
|||||||
BotMsgConfigErrorInvalidCallbackData Key = "bot.msg.config.error_invalid_callback_data"
|
BotMsgConfigErrorInvalidCallbackData Key = "bot.msg.config.error_invalid_callback_data"
|
||||||
BotMsgConfigErrorInvalidTemplate Key = "bot.msg.config.error_invalid_template"
|
BotMsgConfigErrorInvalidTemplate Key = "bot.msg.config.error_invalid_template"
|
||||||
BotMsgConfigFnametmplHelp Key = "bot.msg.config.fnametmpl_help"
|
BotMsgConfigFnametmplHelp Key = "bot.msg.config.fnametmpl_help"
|
||||||
BotMsgConfigInfoCurrentTemplatePrefix Key = "bot.msg.config.info_current_template_prefix"
|
|
||||||
BotMsgConfigInfoConflictStrategySet Key = "bot.msg.config.info_conflict_strategy_set"
|
BotMsgConfigInfoConflictStrategySet Key = "bot.msg.config.info_conflict_strategy_set"
|
||||||
|
BotMsgConfigInfoCurrentTemplatePrefix Key = "bot.msg.config.info_current_template_prefix"
|
||||||
BotMsgConfigInfoFilenameStrategySet Key = "bot.msg.config.info_filename_strategy_set"
|
BotMsgConfigInfoFilenameStrategySet Key = "bot.msg.config.info_filename_strategy_set"
|
||||||
BotMsgConfigInfoTemplateUpdated Key = "bot.msg.config.info_template_updated"
|
BotMsgConfigInfoTemplateUpdated Key = "bot.msg.config.info_template_updated"
|
||||||
BotMsgConfigPromptSelectConflictStrategy Key = "bot.msg.config.prompt_select_conflict_strategy"
|
BotMsgConfigPromptSelectConflictStrategy Key = "bot.msg.config.prompt_select_conflict_strategy"
|
||||||
@@ -200,6 +200,7 @@ const (
|
|||||||
BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed"
|
BotMsgRuleErrorGetUserRulesFailed Key = "bot.msg.rule.error_get_user_rules_failed"
|
||||||
BotMsgRuleErrorInvalidRuleId Key = "bot.msg.rule.error_invalid_rule_id"
|
BotMsgRuleErrorInvalidRuleId Key = "bot.msg.rule.error_invalid_rule_id"
|
||||||
BotMsgRuleErrorInvalidRuleType Key = "bot.msg.rule.error_invalid_rule_type"
|
BotMsgRuleErrorInvalidRuleType Key = "bot.msg.rule.error_invalid_rule_type"
|
||||||
|
BotMsgRuleErrorStorageNotFound Key = "bot.msg.rule.error_storage_not_found"
|
||||||
BotMsgRuleErrorUpdateUserFailed Key = "bot.msg.rule.error_update_user_failed"
|
BotMsgRuleErrorUpdateUserFailed Key = "bot.msg.rule.error_update_user_failed"
|
||||||
BotMsgRuleHelpAddSuffix Key = "bot.msg.rule.help_add_suffix"
|
BotMsgRuleHelpAddSuffix Key = "bot.msg.rule.help_add_suffix"
|
||||||
BotMsgRuleHelpAvailableOps Key = "bot.msg.rule.help_available_ops"
|
BotMsgRuleHelpAvailableOps Key = "bot.msg.rule.help_available_ops"
|
||||||
@@ -207,13 +208,16 @@ const (
|
|||||||
BotMsgRuleHelpCurrentModeEnabled Key = "bot.msg.rule.help_current_mode_enabled"
|
BotMsgRuleHelpCurrentModeEnabled Key = "bot.msg.rule.help_current_mode_enabled"
|
||||||
BotMsgRuleHelpDelSuffix Key = "bot.msg.rule.help_del_suffix"
|
BotMsgRuleHelpDelSuffix Key = "bot.msg.rule.help_del_suffix"
|
||||||
BotMsgRuleHelpExistingRulesPrefix Key = "bot.msg.rule.help_existing_rules_prefix"
|
BotMsgRuleHelpExistingRulesPrefix Key = "bot.msg.rule.help_existing_rules_prefix"
|
||||||
|
BotMsgRuleHelpPresetSuffix Key = "bot.msg.rule.help_preset_suffix"
|
||||||
BotMsgRuleHelpSwitchSuffix Key = "bot.msg.rule.help_switch_suffix"
|
BotMsgRuleHelpSwitchSuffix Key = "bot.msg.rule.help_switch_suffix"
|
||||||
BotMsgRuleHelpUsage Key = "bot.msg.rule.help_usage"
|
BotMsgRuleHelpUsage Key = "bot.msg.rule.help_usage"
|
||||||
BotMsgRuleInfoCreateRuleSuccess Key = "bot.msg.rule.info_create_rule_success"
|
BotMsgRuleInfoCreateRuleSuccess Key = "bot.msg.rule.info_create_rule_success"
|
||||||
BotMsgRuleInfoDeleteRuleSuccess Key = "bot.msg.rule.info_delete_rule_success"
|
BotMsgRuleInfoDeleteRuleSuccess Key = "bot.msg.rule.info_delete_rule_success"
|
||||||
|
BotMsgRuleInfoPresetImported Key = "bot.msg.rule.info_preset_imported"
|
||||||
BotMsgRuleInfoRuleModeDisabled Key = "bot.msg.rule.info_rule_mode_disabled"
|
BotMsgRuleInfoRuleModeDisabled Key = "bot.msg.rule.info_rule_mode_disabled"
|
||||||
BotMsgRuleInfoRuleModeEnabled Key = "bot.msg.rule.info_rule_mode_enabled"
|
BotMsgRuleInfoRuleModeEnabled Key = "bot.msg.rule.info_rule_mode_enabled"
|
||||||
BotMsgRulePromptProvideRuleId Key = "bot.msg.rule.prompt_provide_rule_id"
|
BotMsgRulePromptProvideRuleId Key = "bot.msg.rule.prompt_provide_rule_id"
|
||||||
|
BotMsgRulePromptProvideStorageName Key = "bot.msg.rule.prompt_provide_storage_name"
|
||||||
BotMsgSaveErrorInvalidIdOrUsername Key = "bot.msg.save.error_invalid_id_or_username"
|
BotMsgSaveErrorInvalidIdOrUsername Key = "bot.msg.save.error_invalid_id_or_username"
|
||||||
BotMsgSaveHelpText Key = "bot.msg.save_help_text"
|
BotMsgSaveHelpText Key = "bot.msg.save_help_text"
|
||||||
BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix"
|
BotMsgStorageInfoFilenamePrefix Key = "bot.msg.storage.info_filename_prefix"
|
||||||
|
|||||||
@@ -196,7 +196,11 @@ bot:
|
|||||||
help_switch_suffix: " - Toggle rule mode\n"
|
help_switch_suffix: " - Toggle rule mode\n"
|
||||||
help_add_suffix: " <type> <data> <storage_name> <path> - Add rule\n"
|
help_add_suffix: " <type> <data> <storage_name> <path> - Add rule\n"
|
||||||
help_del_suffix: " <rule_id> - Delete rule\n"
|
help_del_suffix: " <rule_id> - Delete rule\n"
|
||||||
|
help_preset_suffix: " <storage_name> [base_path] - Import built-in filetype rules (video/image/audio/document/archive)\n"
|
||||||
help_existing_rules_prefix: "\nCurrent rules:\n"
|
help_existing_rules_prefix: "\nCurrent rules:\n"
|
||||||
|
prompt_provide_storage_name: "Please provide a storage name"
|
||||||
|
error_storage_not_found: "Storage not found: {{.Storage}}"
|
||||||
|
info_preset_imported: "Imported {{.Count}} built-in classification rules into storage {{.Storage}}"
|
||||||
dir:
|
dir:
|
||||||
error_get_user_dirs_failed: "Failed to get user directories"
|
error_get_user_dirs_failed: "Failed to get user directories"
|
||||||
error_get_user_failed: "Failed to get user"
|
error_get_user_failed: "Failed to get user"
|
||||||
|
|||||||
@@ -197,7 +197,11 @@ bot:
|
|||||||
help_switch_suffix: " - 开关规则模式\n"
|
help_switch_suffix: " - 开关规则模式\n"
|
||||||
help_add_suffix: " <类型> <数据> <存储名> <路径> - 添加规则\n"
|
help_add_suffix: " <类型> <数据> <存储名> <路径> - 添加规则\n"
|
||||||
help_del_suffix: " <规则ID> - 删除规则\n"
|
help_del_suffix: " <规则ID> - 删除规则\n"
|
||||||
|
help_preset_suffix: " <存储名> [基础路径] - 导入内置文件类型分类规则(视频/图片/音频/文档/压缩包)\n"
|
||||||
help_existing_rules_prefix: "\n当前已添加的规则:\n"
|
help_existing_rules_prefix: "\n当前已添加的规则:\n"
|
||||||
|
prompt_provide_storage_name: "请提供存储名称"
|
||||||
|
error_storage_not_found: "未找到存储: {{.Storage}}"
|
||||||
|
info_preset_imported: "已导入 {{.Count}} 条内置分类规则到存储 {{.Storage}}"
|
||||||
dir:
|
dir:
|
||||||
error_get_user_dirs_failed: "获取用户文件夹失败"
|
error_get_user_dirs_failed: "获取用户文件夹失败"
|
||||||
error_get_user_failed: "获取用户失败"
|
error_get_user_failed: "获取用户失败"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/queue"
|
"github.com/krau/SaveAny-Bot/pkg/queue"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
)
|
)
|
||||||
|
|
||||||
var queueInstance *queue.TaskQueue[Executable]
|
var queueInstance *queue.TaskQueue[Executable]
|
||||||
@@ -30,11 +31,14 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
|
|||||||
break // queue closed and empty
|
break // queue closed and empty
|
||||||
}
|
}
|
||||||
exe := qtask.Data
|
exe := qtask.Data
|
||||||
|
taskCtx := qtask.Context()
|
||||||
logger.Infof("Processing task: %s", exe.TaskID())
|
logger.Infof("Processing task: %s", exe.TaskID())
|
||||||
if err := ExecCommandString(qtask.Context(), execHooks.TaskBeforeStart); err != nil {
|
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseStart})
|
||||||
|
if err := ExecCommandString(taskCtx, execHooks.TaskBeforeStart); err != nil {
|
||||||
logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err)
|
logger.Errorf("Failed to execute before start hook for task %s: %v", exe.TaskID(), err)
|
||||||
}
|
}
|
||||||
if err := exe.Execute(qtask.Context()); err != nil {
|
err = exe.Execute(taskCtx)
|
||||||
|
if err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
logger.Infof("Task %s was canceled", exe.TaskID())
|
logger.Infof("Task %s was canceled", exe.TaskID())
|
||||||
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
|
if err := ExecCommandString(ctx, execHooks.TaskCancel); err != nil {
|
||||||
@@ -52,6 +56,7 @@ func worker(ctx context.Context, qe *queue.TaskQueue[Executable], semaphore chan
|
|||||||
logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err)
|
logger.Errorf("Failed to execute success hook for task %s: %v", exe.TaskID(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
taskevent.Emit(taskCtx, taskevent.Event{TaskID: exe.TaskID(), Phase: taskevent.PhaseDone, Err: err})
|
||||||
qe.Done(qtask.ID)
|
qe.Done(qtask.ID)
|
||||||
<-semaphore
|
<-semaphore
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Execute implements core.Executable.
|
// Execute implements core.Executable.
|
||||||
@@ -77,6 +79,12 @@ func (t *Task) waitForDownload(ctx context.Context) error {
|
|||||||
if t.Progress != nil {
|
if t.Progress != nil {
|
||||||
t.Progress.OnProgress(ctx, t, status)
|
t.Progress.OnProgress(ctx, t, status)
|
||||||
}
|
}
|
||||||
|
taskevent.Emit(ctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: parseInt64(status.TotalLength),
|
||||||
|
DownloadedBytes: parseInt64(status.CompletedLength),
|
||||||
|
})
|
||||||
|
|
||||||
// Check if download is complete
|
// Check if download is complete
|
||||||
if status.IsDownloadComplete() {
|
if status.IsDownloadComplete() {
|
||||||
@@ -248,3 +256,16 @@ func (t *Task) cancelAria2Download() {
|
|||||||
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
|
logger.Debugf("Failed to remove download result for %s: %v", t.GID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseInt64 parses an aria2 status string (decimal bytes) into int64,
|
||||||
|
// returning 0 on failure so it can be used directly in progress events.
|
||||||
|
func parseInt64(s string) int64 {
|
||||||
|
if s == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,8 +63,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
|||||||
return elem.Storage.Save(uploadCtx, pr, elem.Path)
|
return elem.Storage.Save(uploadCtx, pr, elem.Path)
|
||||||
})
|
})
|
||||||
wr := ioutil.NewProgressWriter(pw, func(n int) {
|
wr := ioutil.NewProgressWriter(pw, func(n int) {
|
||||||
t.downloaded.Add(int64(n))
|
downloaded := t.downloaded.Add(int64(n))
|
||||||
t.Progress.OnProgress(ctx, t)
|
t.Progress.OnProgress(ctx, t)
|
||||||
|
taskevent.Emit(ctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: t.totalSize,
|
||||||
|
DownloadedBytes: downloaded,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
errg.Go(func() error {
|
errg.Go(func() error {
|
||||||
defer pw.Close()
|
defer pw.Close()
|
||||||
@@ -92,8 +99,14 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
|
wrAt := ioutil.NewProgressWriterAt(localFile, func(n int) {
|
||||||
t.downloaded.Add(int64(n))
|
downloaded := t.downloaded.Add(int64(n))
|
||||||
t.Progress.OnProgress(ctx, t)
|
t.Progress.OnProgress(ctx, t)
|
||||||
|
taskevent.Emit(ctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: t.totalSize,
|
||||||
|
DownloadedBytes: downloaded,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
_, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt)
|
_, err = tdler.NewDownloader(elem.File).Parallel(ctx, wrAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
"github.com/krau/SaveAny-Bot/common/utils/ioutil"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,10 +144,16 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
||||||
t.downloadedBytes.Add(int64(n))
|
downloaded := t.downloadedBytes.Add(int64(n))
|
||||||
if t.Progress != nil {
|
if t.Progress != nil {
|
||||||
t.Progress.OnProgress(ctx, t)
|
t.Progress.OnProgress(ctx, t)
|
||||||
}
|
}
|
||||||
|
taskevent.Emit(ctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: t.totalBytes,
|
||||||
|
DownloadedBytes: downloaded,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
copyResultCh := make(chan error, 1)
|
copyResultCh := make(chan error, 1)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/parser"
|
"github.com/krau/SaveAny-Bot/pkg/parser"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -107,10 +108,16 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
|
||||||
t.downloadedBytes.Add(int64(n))
|
downloaded := t.downloadedBytes.Add(int64(n))
|
||||||
if t.progress != nil {
|
if t.progress != nil {
|
||||||
t.progress.OnProgress(ctx, t)
|
t.progress.OnProgress(ctx, t)
|
||||||
}
|
}
|
||||||
|
taskevent.Emit(ctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: t.totalBytes,
|
||||||
|
DownloadedBytes: downloaded,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
copyResultCh := make(chan error, 1)
|
copyResultCh := make(chan error, 1)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/duke-git/lancet/v2/retry"
|
"github.com/duke-git/lancet/v2/retry"
|
||||||
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
"github.com/krau/SaveAny-Bot/common/utils/fsutil"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,8 +28,14 @@ func (t *Task) Execute(ctx context.Context) error {
|
|||||||
logger.Errorf("Error processing picture %s: %v", pic, err)
|
logger.Errorf("Error processing picture %s: %v", pic, err)
|
||||||
return fmt.Errorf("failed to process picture %s: %w", pic, err)
|
return fmt.Errorf("failed to process picture %s: %w", pic, err)
|
||||||
}
|
}
|
||||||
t.downloaded.Add(1)
|
downloaded := t.downloaded.Add(1)
|
||||||
t.progress.OnProgress(gctx, t)
|
t.progress.OnProgress(gctx, t)
|
||||||
|
taskevent.Emit(gctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalFiles: t.totalpics,
|
||||||
|
DownloadedFiles: int(downloaded),
|
||||||
|
})
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProgressWriterAt struct {
|
type ProgressWriterAt struct {
|
||||||
@@ -20,9 +22,16 @@ func (w *ProgressWriterAt) WriteAt(p []byte, off int64) (int, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
downloaded := w.downloaded.Add(int64(at))
|
||||||
if w.progress != nil {
|
if w.progress != nil {
|
||||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
|
||||||
}
|
}
|
||||||
|
taskevent.Emit(w.ctx, taskevent.Event{
|
||||||
|
TaskID: w.info.TaskID(),
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: w.total,
|
||||||
|
DownloadedBytes: downloaded,
|
||||||
|
})
|
||||||
return at, nil
|
return at, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,9 +65,16 @@ func (w *ProgressWriter) Write(p []byte) (int, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
downloaded := w.downloaded.Add(int64(at))
|
||||||
if w.progress != nil {
|
if w.progress != nil {
|
||||||
w.progress.OnProgress(w.ctx, w.info, w.downloaded.Add(int64(at)), w.total)
|
w.progress.OnProgress(w.ctx, w.info, downloaded, w.total)
|
||||||
}
|
}
|
||||||
|
taskevent.Emit(w.ctx, taskevent.Event{
|
||||||
|
TaskID: w.info.TaskID(),
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: w.total,
|
||||||
|
DownloadedBytes: downloaded,
|
||||||
|
})
|
||||||
return at, nil
|
return at, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
"github.com/krau/SaveAny-Bot/config"
|
"github.com/krau/SaveAny-Bot/config"
|
||||||
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
|
||||||
|
"github.com/krau/SaveAny-Bot/pkg/taskevent"
|
||||||
"github.com/krau/SaveAny-Bot/storage"
|
"github.com/krau/SaveAny-Bot/storage"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
@@ -116,6 +117,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
|
|||||||
|
|
||||||
t.uploaded.Add(size)
|
t.uploaded.Add(size)
|
||||||
t.Progress.OnProgress(ctx, t)
|
t.Progress.OnProgress(ctx, t)
|
||||||
|
taskevent.Emit(ctx, taskevent.Event{
|
||||||
|
TaskID: t.ID,
|
||||||
|
Phase: taskevent.PhaseProgress,
|
||||||
|
TotalBytes: t.totalSize,
|
||||||
|
DownloadedBytes: t.uploaded.Load(),
|
||||||
|
})
|
||||||
|
|
||||||
logger.Info("File uploaded successfully")
|
logger.Info("File uploaded successfully")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
55
pkg/rule/preset.go
Normal file
55
pkg/rule/preset.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package rule
|
||||||
|
|
||||||
|
import "path"
|
||||||
|
|
||||||
|
// PresetCategory describes a built-in filetype classification: files whose name
|
||||||
|
// matches Regex are routed into the Dir subdirectory (joined with a user base path).
|
||||||
|
type PresetCategory struct {
|
||||||
|
// Name is a stable identifier for the category (used in logs/messages).
|
||||||
|
Name string
|
||||||
|
// Regex is a FILENAME-REGEX rule data string matching this category's extensions.
|
||||||
|
Regex string
|
||||||
|
// Dir is the default subdirectory name for this category.
|
||||||
|
Dir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// presetCategories holds the default filetype classification rules.
|
||||||
|
// Regexes are case-insensitive and match common file extensions.
|
||||||
|
var presetCategories = []PresetCategory{
|
||||||
|
{
|
||||||
|
Name: "video",
|
||||||
|
Regex: `(?i)\.(mp4|mkv|ts|avi|flv|mov|webm|wmv|rmvb|m2ts)$`,
|
||||||
|
Dir: "视频",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "image",
|
||||||
|
Regex: `(?i)\.(jpg|jpeg|png|gif|webp|bmp)$`,
|
||||||
|
Dir: "图片",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "audio",
|
||||||
|
Regex: `(?i)\.(mp3|flac|wav|aac|m4a|ogg)$`,
|
||||||
|
Dir: "音频",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "document",
|
||||||
|
Regex: `(?i)\.(pdf|doc|docx|xls|xlsx|ppt|pptx|txt|md|csv|epub|mobi|azw3|chm)$`,
|
||||||
|
Dir: "文档",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "archive",
|
||||||
|
Regex: `(?i)\.(zip|rar|7z|tar|gz|bz2|xz|r\d{1,3}|z\d{1,3}|\d{3}|part\d+\.rar|7z\.\d{3})$`,
|
||||||
|
Dir: "压缩包",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// PresetCategories returns the built-in filetype classification rules with each
|
||||||
|
// category's directory joined under basePath. basePath may be empty.
|
||||||
|
func PresetCategories(basePath string) []PresetCategory {
|
||||||
|
out := make([]PresetCategory, len(presetCategories))
|
||||||
|
for i, c := range presetCategories {
|
||||||
|
c.Dir = path.Join(basePath, c.Dir)
|
||||||
|
out[i] = c
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
55
pkg/rule/preset_test.go
Normal file
55
pkg/rule/preset_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package rule
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPresetCategoriesCompile(t *testing.T) {
|
||||||
|
for _, c := range PresetCategories("") {
|
||||||
|
if _, err := regexp.Compile(c.Regex); err != nil {
|
||||||
|
t.Errorf("preset %q has invalid regex %q: %v", c.Name, c.Regex, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPresetCategoriesMatch(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"video": "movie.MP4",
|
||||||
|
"image": "photo.jpg",
|
||||||
|
"audio": "song.flac",
|
||||||
|
"document": "report.pdf",
|
||||||
|
"archive": "backup.zip",
|
||||||
|
}
|
||||||
|
|
||||||
|
byName := make(map[string]*regexp.Regexp)
|
||||||
|
for _, c := range PresetCategories("") {
|
||||||
|
byName[c.Name] = regexp.MustCompile(c.Regex)
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, filename := range cases {
|
||||||
|
re, ok := byName[name]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("missing preset category %q", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !re.MatchString(filename) {
|
||||||
|
t.Errorf("preset %q did not match %q", name, filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPresetCategoriesBasePath(t *testing.T) {
|
||||||
|
presets := PresetCategories("/media")
|
||||||
|
for _, c := range presets {
|
||||||
|
if c.Dir == "" || c.Dir[0] != '/' {
|
||||||
|
t.Errorf("preset %q dir %q not joined under base path", c.Name, c.Dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Empty base path must not prefix a separator.
|
||||||
|
for _, c := range PresetCategories("") {
|
||||||
|
if c.Dir == "" || c.Dir[0] == '/' {
|
||||||
|
t.Errorf("preset %q dir %q should be relative when base path empty", c.Name, c.Dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
88
pkg/taskevent/taskevent.go
Normal file
88
pkg/taskevent/taskevent.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Package taskevent provides a decoupled, context-scoped event bus for task
|
||||||
|
// lifecycle progress. Producers (task implementations) emit events via Emit;
|
||||||
|
// consumers (e.g. the API progress store, the Telegram message editor) register
|
||||||
|
// as Sinks and are injected through context. This keeps the task layer free of
|
||||||
|
// any concrete progress-display dependency, so new task types gain progress
|
||||||
|
// reporting for free and new observers can be added without touching tasks.
|
||||||
|
package taskevent
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Phase marks a stage in a task's lifecycle.
|
||||||
|
type Phase int
|
||||||
|
|
||||||
|
const (
|
||||||
|
PhaseStart Phase = iota
|
||||||
|
PhaseProgress
|
||||||
|
PhaseDone
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p Phase) String() string {
|
||||||
|
switch p {
|
||||||
|
case PhaseStart:
|
||||||
|
return "start"
|
||||||
|
case PhaseProgress:
|
||||||
|
return "progress"
|
||||||
|
case PhaseDone:
|
||||||
|
return "done"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Event describes a single progress observation for a task. Byte fields are
|
||||||
|
// populated by byte-stream tasks; file-count fields by count-based tasks. A
|
||||||
|
// task may fill whichever subset it has; observers ignore zero values.
|
||||||
|
type Event struct {
|
||||||
|
TaskID string
|
||||||
|
Phase Phase
|
||||||
|
TotalBytes int64
|
||||||
|
DownloadedBytes int64
|
||||||
|
TotalFiles int
|
||||||
|
DownloadedFiles int
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sink receives task events. Implementations must be safe for concurrent use.
|
||||||
|
type Sink interface {
|
||||||
|
Emit(Event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SinkFunc is a function adapter for Sink.
|
||||||
|
type SinkFunc func(Event)
|
||||||
|
|
||||||
|
func (f SinkFunc) Emit(e Event) { f(e) }
|
||||||
|
|
||||||
|
type sinkKey struct{}
|
||||||
|
|
||||||
|
// WithSink returns a ctx carrying the given sinks. Multiple sinks can be passed
|
||||||
|
// and all will receive every emitted event. Sinks already present in ctx are
|
||||||
|
// preserved.
|
||||||
|
func WithSink(ctx context.Context, sinks ...Sink) context.Context {
|
||||||
|
if len(sinks) == 0 {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
var existing []Sink
|
||||||
|
if v, ok := ctx.Value(sinkKey{}).([]Sink); ok {
|
||||||
|
existing = v
|
||||||
|
}
|
||||||
|
merged := make([]Sink, 0, len(existing)+len(sinks))
|
||||||
|
merged = append(merged, existing...)
|
||||||
|
merged = append(merged, sinks...)
|
||||||
|
return context.WithValue(ctx, sinkKey{}, merged)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit broadcasts an event to all sinks carried by ctx. It is a no-op when no
|
||||||
|
// sink is attached, so producers can call it unconditionally.
|
||||||
|
func Emit(ctx context.Context, e Event) {
|
||||||
|
if ctx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sinks, ok := ctx.Value(sinkKey{}).([]Sink)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, s := range sinks {
|
||||||
|
s.Emit(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user