diff --git a/.gitignore b/.gitignore index f1ebe3d..412da25 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ temp/ playwright/ testplugins/ *.exe -tmp-* \ No newline at end of file +tmp-* +saveany-bot \ No newline at end of file diff --git a/api/auth.go b/api/auth.go new file mode 100644 index 0000000..236bec3 --- /dev/null +++ b/api/auth.go @@ -0,0 +1,48 @@ +package api + +import ( + "context" + "crypto/subtle" + "net/http" + "strings" + + "github.com/krau/SaveAny-Bot/config" +) + +// tokenContextKey 用于在 context 中存储 token +type tokenContextKey struct{} + +// AuthMiddleware 返回认证中间件 +func AuthMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cfg := config.C().API + + // 从请求头获取 token + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + WriteError(w, http.StatusUnauthorized, "unauthorized", "missing authorization header") + return + } + + // 提取 Bearer token + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid authorization header format") + return + } + + token := parts[1] + + // 验证 token + if subtle.ConstantTimeCompare([]byte(token), []byte(cfg.Token)) != 1 { + WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid token") + return + } + + // 将 token 添加到 context + ctx := context.WithValue(r.Context(), tokenContextKey{}, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/api/factory.go b/api/factory.go new file mode 100644 index 0000000..7020173 --- /dev/null +++ b/api/factory.go @@ -0,0 +1,355 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/krau/SaveAny-Bot/config" + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/core/tasks/aria2dl" + "github.com/krau/SaveAny-Bot/core/tasks/batchtfile" + "github.com/krau/SaveAny-Bot/core/tasks/directlinks" + "github.com/krau/SaveAny-Bot/core/tasks/parsed" + tphtask "github.com/krau/SaveAny-Bot/core/tasks/telegraph" + "github.com/krau/SaveAny-Bot/core/tasks/tfile" + "github.com/krau/SaveAny-Bot/core/tasks/transfer" + "github.com/krau/SaveAny-Bot/core/tasks/ytdlp" + "github.com/krau/SaveAny-Bot/parsers/parsers" + "github.com/krau/SaveAny-Bot/pkg/aria2" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/pkg/parser" + "github.com/krau/SaveAny-Bot/pkg/telegraph" + "github.com/krau/SaveAny-Bot/storage" + "github.com/rs/xid" +) + +// TaskFactory 任务工厂 +type TaskFactory struct { + ctx context.Context +} + +// NewTaskFactory 创建任务工厂 +func NewTaskFactory(ctx context.Context) *TaskFactory { + return &TaskFactory{ctx: ctx} +} + +// CreateTask 创建任务 +func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, error) { + // 验证存储 + stor, ok := storage.Storages[req.Storage] + if !ok { + return nil, fmt.Errorf("storage not found: %s", req.Storage) + } + + taskID := xid.New().String() + createdAt := time.Now() + + switch req.Type { + case tasktype.TaskTypeDirectlinks: + return f.createDirectLinksTask(taskID, createdAt, req, stor) + case tasktype.TaskTypeYtdlp: + return f.createYTDLPTask(taskID, createdAt, req, stor) + case tasktype.TaskTypeAria2: + return f.createAria2Task(taskID, createdAt, req, stor) + case tasktype.TaskTypeParseditem: + return f.createParsedTask(taskID, createdAt, req, stor) + case tasktype.TaskTypeTgfiles: + return f.createTGFilesTask(taskID, createdAt, req, stor) + case tasktype.TaskTypeTphpics: + return f.createTPHPicsTask(taskID, createdAt, req, stor) + case tasktype.TaskTypeTransfer: + return f.createTransferTask(taskID, createdAt, req) + default: + return nil, fmt.Errorf("unsupported task type: %s", req.Type) + } +} + +// createDirectLinksTask 创建直链下载任务 +func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { + var params DirectLinksParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + if len(params.URLs) == 0 { + return nil, fmt.Errorf("no URLs provided") + } + + task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil) + + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeDirectlinks, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} + +// createYTDLPTask 创建 yt-dlp 任务 +func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { + var params YTDLPParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + if len(params.URLs) == 0 { + return nil, fmt.Errorf("no URLs provided") + } + + task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil) + + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeYtdlp, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} + +// createAria2Task 创建 Aria2 任务 +func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { + var params Aria2Params + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + if len(params.URLs) == 0 { + return nil, fmt.Errorf("no URLs provided") + } + + // 检查 Aria2 是否启用 + cfg := config.C().Aria2 + if !cfg.Enable { + return nil, fmt.Errorf("aria2 is not enabled") + } + + aria2Client, err := aria2.NewClient(cfg.Url, cfg.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create aria2 client: %w", err) + } + + // 添加下载任务到 Aria2 + gid, err := aria2Client.AddURI(f.ctx, params.URLs, nil) + if err != nil { + return nil, fmt.Errorf("failed to add aria2 task: %w", err) + } + + task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil) + + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeAria2, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} + +// createParsedTask 创建解析任务 +func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { + var params ParsedParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + if params.URL == "" { + return nil, fmt.Errorf("no URL provided") + } + + // 查找合适的解析器 + var p parser.Parser + for _, parserItem := range parsers.Get() { + if parserItem.CanHandle(params.URL) { + p = parserItem + break + } + } + + if p == nil { + return nil, fmt.Errorf("no parser found for URL: %s", params.URL) + } + + // 解析 URL + item, err := p.Parse(f.ctx, params.URL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + + task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil) + + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeParseditem, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} + +// createTGFilesTask 创建 Telegram 文件下载任务 +func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { + var params TGFilesParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + if len(params.MessageLinks) == 0 { + return nil, fmt.Errorf("no message links provided") + } + + // 提取文件 + files, err := ExtractFilesFromLinks(f.ctx, params.MessageLinks) + if err != nil { + return nil, fmt.Errorf("failed to extract files: %w", err) + } + + if len(files) == 0 { + return nil, fmt.Errorf("no files found in provided links") + } + + if len(files) == 1 { + // 单个文件任务 + tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil) + if err != nil { + return nil, fmt.Errorf("failed to create tfile task: %w", err) + } + if err := core.AddTask(f.ctx, tfileTask); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + } else { + // 批量文件任务 + elems := make([]batchtfile.TaskElement, 0, len(files)) + for _, file := range files { + elem, err := batchtfile.NewTaskElement(stor, req.Path, file) + if err != nil { + return nil, fmt.Errorf("failed to create task element: %w", err) + } + elems = append(elems, *elem) + } + + task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true) + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeTgfiles, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} + +// createTPHPicsTask 创建 Telegraph 图片下载任务 +func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) { + var params TPHPicsParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + if params.TelegraphURL == "" { + return nil, fmt.Errorf("no telegraph URL provided") + } + + // 提取图片 + pics, phPath, err := ExtractTelegraphImages(f.ctx, params.TelegraphURL) + if err != nil { + return nil, fmt.Errorf("failed to extract telegraph images: %w", err) + } + + if len(pics) == 0 { + return nil, fmt.Errorf("no images found in telegraph page") + } + + client := telegraph.NewClient() + task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil) + + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeTphpics, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} + +// createTransferTask 创建存储间传输任务 +func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req *CreateTaskRequest) (*CreateTaskResponse, error) { + var params TransferParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, fmt.Errorf("invalid params: %w", err) + } + + // 验证源存储和目标存储 + sourceStor, ok := storage.Storages[params.SourceStorage] + if !ok { + return nil, fmt.Errorf("source storage not found: %s", params.SourceStorage) + } + + targetStor, ok := storage.Storages[params.TargetStorage] + if !ok { + return nil, fmt.Errorf("target storage not found: %s", params.TargetStorage) + } + + // 检查源存储是否可读 + sourceReadable, ok := sourceStor.(storage.StorageReadable) + if !ok { + return nil, fmt.Errorf("source storage does not support reading: %s", params.SourceStorage) + } + + // 检查源存储是否可列 + sourceListable, ok := sourceStor.(storage.StorageListable) + if !ok { + return nil, fmt.Errorf("source storage does not support listing: %s", params.SourceStorage) + } + + // 列出源文件 + files, err := sourceListable.ListFiles(f.ctx, params.SourcePath) + if err != nil { + return nil, fmt.Errorf("failed to list source files: %w", err) + } + + if len(files) == 0 { + return nil, fmt.Errorf("no files found at source path: %s", params.SourcePath) + } + + // 创建传输元素 + elems := make([]transfer.TaskElement, 0, len(files)) + for _, file := range files { + elem := transfer.NewTaskElement(sourceReadable, file, targetStor, params.TargetPath) + elems = append(elems, *elem) + } + + task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true) + + if err := core.AddTask(f.ctx, task); err != nil { + return nil, fmt.Errorf("failed to add task: %w", err) + } + + return &CreateTaskResponse{ + TaskID: taskID, + Type: tasktype.TaskTypeTransfer, + Status: TaskStatusQueued, + CreatedAt: createdAt, + }, nil +} diff --git a/api/handlers.go b/api/handlers.go new file mode 100644 index 0000000..81946d0 --- /dev/null +++ b/api/handlers.go @@ -0,0 +1,222 @@ +package api + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/krau/SaveAny-Bot/core" + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" + "github.com/krau/SaveAny-Bot/storage" +) + +// Handlers 处理器结构体 +type Handlers struct { + factory *TaskFactory +} + +// NewHandlers 创建处理器 +func NewHandlers(factory *TaskFactory) *Handlers { + return &Handlers{factory: factory} +} + +// CreateTaskHandler 创建任务处理器 +func (h *Handlers) CreateTaskHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only POST method is allowed") + return + } + + var req CreateTaskRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error()) + return + } + + // 验证请求 + if req.Type == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "task type is required") + return + } + + if req.Storage == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "storage is required") + return + } + + // 创建任务 + resp, err := h.factory.CreateTask(&req) + if err != nil { + WriteError(w, http.StatusBadRequest, "task_creation_failed", err.Error()) + return + } + + WriteJSON(w, http.StatusCreated, resp) +} + +// ListTasksHandler 列出任务处理器 +func (h *Handlers) ListTasksHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed") + return + } + + tasks := GetAllTasks() + response := make([]TaskInfoResponse, 0, len(tasks)) + + for _, task := range tasks { + info := convertTaskProgressToResponse(task) + response = append(response, info) + } + + WriteJSON(w, http.StatusOK, TasksListResponse{ + Tasks: response, + Total: len(response), + }) +} + +// GetTaskHandler 获取单个任务处理器 +func (h *Handlers) GetTaskHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed") + return + } + + taskID := extractTaskIDFromPath(r.URL.Path) + if taskID == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required") + return + } + + task, ok := GetTask(taskID) + if !ok { + WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID) + return + } + + resp := convertTaskProgressToResponse(task) + WriteJSON(w, http.StatusOK, resp) +} + +// CancelTaskHandler 取消任务处理器 +func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only DELETE method is allowed") + return + } + + taskID := extractTaskIDFromPath(r.URL.Path) + if taskID == "" { + WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required") + return + } + + task, ok := GetTask(taskID) + if !ok { + WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID) + return + } + + // 取消任务 + if err := core.CancelTask(r.Context(), taskID); err != nil { + WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error()) + return + } + + task.UpdateStatus(TaskStatusCancelled) + WriteJSON(w, http.StatusOK, map[string]string{"message": "task cancelled successfully"}) +} + +// ListStoragesHandler 列出存储处理器 +func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed") + return + } + + storages := make([]StorageInfo, 0, len(storage.Storages)) + for name, stor := range storage.Storages { + storages = append(storages, StorageInfo{ + Name: name, + Type: string(stor.Type()), + }) + } + + WriteJSON(w, http.StatusOK, StoragesResponse{Storages: storages}) +} + +// GetTaskTypesHandler 获取支持的任务类型 +func (h *Handlers) GetTaskTypesHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed") + return + } + + types := []tasktype.TaskType{ + tasktype.TaskTypeDirectlinks, + tasktype.TaskTypeYtdlp, + tasktype.TaskTypeAria2, + tasktype.TaskTypeParseditem, + tasktype.TaskTypeTgfiles, + tasktype.TaskTypeTphpics, + tasktype.TaskTypeTransfer, + } + + WriteJSON(w, http.StatusOK, map[string]any{ + "types": types, + }) +} + +// HealthCheckHandler 健康检查处理器 +func (h *Handlers) HealthCheckHandler(w http.ResponseWriter, r *http.Request) { + WriteJSON(w, http.StatusOK, map[string]string{ + "status": "ok", + }) +} + +// extractTaskIDFromPath 从路径中提取任务 ID +// 路径格式: /api/v1/tasks/:id +func extractTaskIDFromPath(path string) string { + parts := strings.Split(strings.Trim(path, "/"), "/") + if len(parts) < 4 { + return "" + } + return parts[3] +} + +// convertTaskProgressToResponse 将任务进度转换为响应格式 +func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse { + resp := TaskInfoResponse{ + TaskID: task.TaskID, + Type: tasktype.TaskType(task.Type), + Status: task.Status, + Title: task.Title, + Storage: task.Storage, + Path: task.Path, + Error: task.Error, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, + } + + // 计算进度 + if task.TotalBytes > 0 { + percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes) + resp.Progress = &TaskProgress{ + TotalBytes: task.TotalBytes, + DownloadedBytes: task.DownloadedBytes, + Percent: percent, + } + } + + return resp +} + +// NotFoundHandler 404 处理器 +func NotFoundHandler(w http.ResponseWriter, r *http.Request) { + WriteError(w, http.StatusNotFound, "not_found", "endpoint not found: "+r.URL.Path) +} + +// MethodNotAllowedHandler 405 处理器 +func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) { + WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed: "+r.Method) +} diff --git a/api/progress.go b/api/progress.go new file mode 100644 index 0000000..adac371 --- /dev/null +++ b/api/progress.go @@ -0,0 +1,150 @@ +package api + +import ( + "sync" + "sync/atomic" + "time" +) + +// TaskProgressInfo 存储任务的进度信息 +type TaskProgressInfo struct { + TaskID string + Type string + Status TaskStatus + Title string + TotalBytes int64 + DownloadedBytes int64 + TotalFiles int + DownloadedFiles int + Storage string + Path string + Error string + CreatedAt time.Time + UpdatedAt time.Time + Webhook string +} + +// progressStore 存储所有 API 任务的进度信息 +type progressStore struct { + mu sync.RWMutex + tasks map[string]*TaskProgressInfo +} + +var store = &progressStore{ + tasks: make(map[string]*TaskProgressInfo), +} + +// RegisterTask 注册一个新的 API 任务 +func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo { + info := &TaskProgressInfo{ + TaskID: taskID, + Type: taskType, + Status: TaskStatusQueued, + Title: title, + Storage: storage, + Path: path, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + Webhook: webhook, + } + + store.mu.Lock() + store.tasks[taskID] = info + store.mu.Unlock() + + return info +} + +// GetTask 获取任务进度信息 +func GetTask(taskID string) (*TaskProgressInfo, bool) { + store.mu.RLock() + defer store.mu.RUnlock() + info, ok := store.tasks[taskID] + return info, ok +} + +// GetAllTasks 获取所有任务 +func GetAllTasks() []*TaskProgressInfo { + store.mu.RLock() + defer store.mu.RUnlock() + + tasks := make([]*TaskProgressInfo, 0, len(store.tasks)) + for _, info := range store.tasks { + tasks = append(tasks, info) + } + return tasks +} + +// DeleteTask 删除任务记录 +func DeleteTask(taskID string) { + store.mu.Lock() + defer store.mu.Unlock() + delete(store.tasks, taskID) +} + +// UpdateStatus 更新任务状态 +func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) { + t.Status = status + t.UpdatedAt = time.Now() +} + +// SetError 设置错误信息 +func (t *TaskProgressInfo) SetError(err string) { + t.Error = err + t.Status = TaskStatusFailed + t.UpdatedAt = time.Now() +} + +// ProgressTracker 用于 API 任务的进度追踪 +type ProgressTracker struct { + info *TaskProgressInfo +} + +// NewProgressTracker 创建新的进度追踪器 +func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker { + info := RegisterTask(taskID, taskType, storage, path, title, webhook) + return &ProgressTracker{info: info} +} + +// OnStart 任务开始 +func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) { + p.info.Status = TaskStatusRunning + p.info.TotalBytes = totalBytes + p.info.TotalFiles = totalFiles + p.info.UpdatedAt = time.Now() +} + +// OnProgress 进度更新 +func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) { + atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes) + p.info.DownloadedFiles = downloadedFiles + p.info.UpdatedAt = time.Now() +} + +// OnDone 任务完成 +func (p *ProgressTracker) OnDone(err error) { + if err != nil { + p.info.Status = TaskStatusFailed + p.info.Error = err.Error() + } else { + p.info.Status = TaskStatusCompleted + } + p.info.UpdatedAt = time.Now() +} + +// GetInfo 获取任务信息 +func (p *ProgressTracker) GetInfo() *TaskProgressInfo { + return p.info +} + +// UpdateProgressBytes 更新下载字节数 +func (p *ProgressTracker) UpdateProgressBytes(bytes int64) { + atomic.StoreInt64(&p.info.DownloadedBytes, bytes) + p.info.UpdatedAt = time.Now() +} + +// UpdateProgressFiles 更新下载文件数 +func (p *ProgressTracker) UpdateProgressFiles(files int) { + p.info.DownloadedFiles = files + p.info.UpdatedAt = time.Now() +} diff --git a/api/server.go b/api/server.go new file mode 100644 index 0000000..5e65e29 --- /dev/null +++ b/api/server.go @@ -0,0 +1,163 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/config" +) + +// Server API 服务器 +type Server struct { + httpServer *http.Server + factory *TaskFactory +} + +// NewServer 创建新的 API 服务器 +func NewServer(ctx context.Context) *Server { + cfg := config.C().API + + factory := NewTaskFactory(ctx) + handlers := NewHandlers(factory) + + // 设置路由 + mux := http.NewServeMux() + + // 健康检查 + mux.HandleFunc("/health", handlers.HealthCheckHandler) + + // API v1 路由 + mux.HandleFunc("/api/v1/tasks", handlers.CreateTaskHandler) + mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) { + // 根据方法和路径分发 + switch r.Method { + case http.MethodGet: + if r.URL.Path == "/api/v1/tasks" { + handlers.ListTasksHandler(w, r) + } else { + handlers.GetTaskHandler(w, r) + } + case http.MethodDelete: + handlers.CancelTaskHandler(w, r) + default: + MethodNotAllowedHandler(w, r) + } + }) + mux.HandleFunc("/api/v1/storages", handlers.ListStoragesHandler) + mux.HandleFunc("/api/v1/task-types", handlers.GetTaskTypesHandler) + + // 404 处理 + mux.HandleFunc("/", NotFoundHandler) + + // 应用中间件 + var handler http.Handler = mux + + // 添加认证中间件 + token := cfg.Token + if token == "" { + log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!") + } + if token != "" { + handler = AuthMiddleware()(handler) + } + + // 添加日志中间件 + handler = loggingMiddleware(handler) + + // 添加恢复中间件 + handler = recoveryMiddleware(handler) + + return &Server{ + httpServer: &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Handler: handler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + }, + factory: factory, + } +} + +// Start 启动服务器 +func (s *Server) Start(ctx context.Context) error { + logger := log.FromContext(ctx).With("module", "api") + + logger.Infof("Starting API server on %s", s.httpServer.Addr) + + // 在 goroutine 中启动服务器 + go func() { + if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Errorf("API server error: %v", err) + } + }() + + // 监听 context 取消 + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.httpServer.Shutdown(shutdownCtx); err != nil { + logger.Errorf("API server shutdown error: %v", err) + } + }() + + return nil +} + +// loggingMiddleware 日志中间件 +func loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // 包装 ResponseWriter 以获取状态码 + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + next.ServeHTTP(wrapped, r) + + log.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapped.statusCode, time.Since(start)) + }) +} + +// recoveryMiddleware 恢复中间件 +func recoveryMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Errorf("Panic recovered: %v", err) + WriteError(w, http.StatusInternalServerError, "internal_error", "internal server error") + } + }() + next.ServeHTTP(w, r) + }) +} + +// responseWriter 包装 http.ResponseWriter 以捕获状态码 +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +// Start 初始化并启动 API 服务器 +func Start(ctx context.Context) error { + cfg := config.C().API + + if !cfg.Enable { + return nil + } + + if cfg.Token == "" { + log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!") + } + + server := NewServer(ctx) + return server.Start(ctx) +} diff --git a/api/tgfiles.go b/api/tgfiles.go new file mode 100644 index 0000000..0338457 --- /dev/null +++ b/api/tgfiles.go @@ -0,0 +1,272 @@ +package api + +import ( + "context" + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/client/bot" + userclient "github.com/krau/SaveAny-Bot/client/user" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" + "github.com/krau/SaveAny-Bot/pkg/tfile" +) + +// MessageContext 保存消息和获取它所用的 context +type MessageContext struct { + Message *tg.Message + Client *ext.Context +} + +// getClientContext 获取可用的客户端上下文 +// 优先使用 Bot,失败后回退到 Userbot +func getClientContext() (*ext.Context, error) { + // 首先尝试获取 Bot context + if botCtx := bot.ExtContext(); botCtx != nil { + return botCtx, nil + } + + // 回退到 Userbot + if uc := userclient.GetCtx(); uc != nil { + return uc, nil + } + + return nil, fmt.Errorf("no client available (bot and userbot are not initialized)") +} + +// resolveChatID 解析聊天 ID +func resolveChatID(_ context.Context, idOrUsername string) (int64, error) { + // 如果是数字 ID + if id, err := strconv.ParseInt(idOrUsername, 10, 64); err == nil { + // 私有频道 ID 需要加上 -100 前缀 + if id > 0 { + return -1000000000000 - id, nil + } + return id, nil + } + + // 获取可用的客户端上下文 + clientCtx, err := getClientContext() + if err != nil { + return 0, err + } + + // 使用 tgutil 的 ParseChatID + return tgutil.ParseChatID(clientCtx, idOrUsername) +} + +// ParseMessageLink 解析 Telegram 消息链接 +// 支持格式: +// - https://t.me/username/123 +// - https://t.me/c/123456789/123 +// - https://t.me/c/123456789/111/456 (topic id) +// - https://t.me/username/123?comment=2 (评论) +func ParseMessageLink(ctx context.Context, link string) (int64, int, error) { + u, err := url.Parse(link) + if err != nil { + return 0, 0, fmt.Errorf("invalid URL: %w", err) + } + paths := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + + if cmt := u.Query().Get("comment"); cmt != "" { + // 频道评论的消息链接 + if len(paths) < 1 { + return 0, 0, fmt.Errorf("invalid message link format: %s", link) + } + // 简化处理:返回错误,提示不支持评论链接 + return 0, 0, fmt.Errorf("comment links are not supported") + } + + switch len(paths) { + case 2: // https://t.me/username/123 + chatID, err := resolveChatID(ctx, paths[0]) + if err != nil { + return 0, 0, fmt.Errorf("failed to resolve chat ID: %w", err) + } + msgID, err := strconv.Atoi(paths[1]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse message ID: %w", err) + } + return chatID, msgID, nil + case 3: + // https://t.me/c/123456789/123 + // https://t.me/username/123/456 , 123: topic id + chatPart, msgPart := paths[1], paths[2] + if paths[0] != "c" { + chatPart = paths[0] + } + chatID, err := resolveChatID(ctx, chatPart) + if err != nil { + return 0, 0, fmt.Errorf("failed to resolve chat ID: %w", err) + } + msgID, err := strconv.Atoi(msgPart) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse message ID: %w", err) + } + return chatID, msgID, nil + case 4: + // https://t.me/c/123456789/111/456 111: topic id + if paths[0] != "c" { + return 0, 0, fmt.Errorf("invalid message link format: %s", link) + } + chatID, err := resolveChatID(ctx, paths[1]) + if err != nil { + return 0, 0, fmt.Errorf("failed to resolve chat ID: %w", err) + } + msgID, err := strconv.Atoi(paths[3]) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse message ID: %w", err) + } + return chatID, msgID, nil + } + return 0, 0, fmt.Errorf("invalid message link format: %s", link) +} + +// getMessageWithContext 通过 ID 获取消息,返回消息和使用的 context +// 确保消息获取和后续文件创建使用同一个 context +func getMessageWithContext(_ context.Context, chatID int64, msgID int) (*MessageContext, error) { + // 首先尝试使用 Bot + if botCtx := bot.ExtContext(); botCtx != nil { + msg, err := tgutil.GetMessageByID(botCtx, chatID, msgID) + if err == nil { + return &MessageContext{Message: msg, Client: botCtx}, nil + } + } + + // 回退到 Userbot + uc := userclient.GetCtx() + if uc == nil { + return nil, fmt.Errorf("userbot not initialized and bot cannot access this message") + } + + msg, err := tgutil.GetMessageByID(uc, chatID, msgID) + if err != nil { + return nil, err + } + + return &MessageContext{Message: msg, Client: uc}, nil +} + +// getGroupedMessagesWithContext 获取媒体组消息,返回消息列表和使用的 context +// 确保消息获取和后续文件创建使用同一个 context +func getGroupedMessagesWithContext(ctx *MessageContext, chatID int64) ([]*tg.Message, error) { + msg := ctx.Message + clientCtx := ctx.Client + + groupID, ok := msg.GetGroupedID() + if !ok || groupID == 0 { + return []*tg.Message{msg}, nil + } + + // 使用获取原始消息的同一个 client 获取媒体组 + msgs, err := tgutil.GetGroupedMessages(clientCtx, chatID, msg) + if err != nil || len(msgs) == 0 { + // 如果获取失败,至少返回原始消息 + return []*tg.Message{msg}, nil + } + + return msgs, nil +} + +// ExtractFilesFromLinks 从消息链接中提取文件 +// 每个文件的处理流程:解析链接 -> 获取消息 -> 获取媒体组 -> 创建文件对象 +// 对于单个文件,全程使用同一个 client context,不会交叉 +func ExtractFilesFromLinks(ctx context.Context, links []string) ([]tfile.TGFileMessage, error) { + logger := log.FromContext(ctx) + var files []tfile.TGFileMessage + + for _, link := range links { + link = strings.TrimSpace(link) + if link == "" { + continue + } + + // 验证链接格式 + if !isValidMessageLink(link) { + logger.Errorf("Invalid message link format: %s", link) + continue + } + + chatID, msgID, err := ParseMessageLink(ctx, link) + if err != nil { + logger.Errorf("Failed to parse message link %s: %v", link, err) + continue + } + + // 解析链接 URL 检查是否有 single 参数 + u, _ := url.Parse(link) + single := u != nil && u.Query().Has("single") + + // 获取消息和使用的 context(Bot 优先,失败回退 Userbot) + msgCtx, err := getMessageWithContext(ctx, chatID, msgID) + if err != nil { + logger.Errorf("Failed to get message %d from chat %d: %v", msgID, chatID, err) + continue + } + + msg := msgCtx.Message + clientCtx := msgCtx.Client + + if msg.Media == nil { + logger.Warnf("Message %d has no media", msgID) + continue + } + + media, ok := msg.GetMedia() + if !ok { + logger.Warnf("Failed to get media from message %d", msgID) + continue + } + + // 检查是否是媒体组 + groupID, isGroup := msg.GetGroupedID() + if isGroup && groupID != 0 && !single { + // 使用同一个 client context 获取媒体组 + groupMsgs, err := getGroupedMessagesWithContext(msgCtx, chatID) + if err != nil { + logger.Errorf("Failed to get grouped messages: %v", err) + } else { + for _, gmsg := range groupMsgs { + if gmsg.Media == nil { + continue + } + gmedia, ok := gmsg.GetMedia() + if !ok { + continue + } + // 使用获取消息时使用的同一个 client context 创建文件 + file, err := tfile.FromMediaMessage(gmedia, clientCtx.Raw, gmsg) + if err != nil { + logger.Errorf("Failed to create file from media: %v", err) + continue + } + files = append(files, file) + } + continue + } + } + + // 单个文件 - 使用获取消息时使用的同一个 client context 创建文件 + file, err := tfile.FromMediaMessage(media, clientCtx.Raw, msg) + if err != nil { + logger.Errorf("Failed to create file from media: %v", err) + continue + } + files = append(files, file) + } + + if len(files) == 0 { + return nil, fmt.Errorf("no files found in provided links") + } + + return files, nil +} + +// isValidMessageLink 检查是否是有效的 Telegram 消息链接 +func isValidMessageLink(link string) bool { + return strings.HasPrefix(link, "https://t.me/") || strings.HasPrefix(link, "http://t.me/") +} diff --git a/api/tphpics.go b/api/tphpics.go new file mode 100644 index 0000000..3e95537 --- /dev/null +++ b/api/tphpics.go @@ -0,0 +1,80 @@ +package api + +import ( + "context" + "fmt" + "net/url" + "strings" + + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/common/utils/tphutil" + "github.com/krau/SaveAny-Bot/pkg/telegraph" +) + +// ExtractTelegraphImages 从 Telegraph URL 提取图片 +func ExtractTelegraphImages(ctx context.Context, pageURL string) ([]string, string, error) { + logger := log.FromContext(ctx) + + // 验证 URL 格式 + if !isValidTelegraphURL(pageURL) { + return nil, "", fmt.Errorf("invalid telegraph URL format: %s", pageURL) + } + + // 解析 URL 获取页面路径 + pagepath, err := parseTelegraphPath(pageURL) + if err != nil { + return nil, "", err + } + + logger.Debugf("Fetching telegraph page: %s", pagepath) + + client := telegraph.NewClient() + page, err := client.GetPage(ctx, pagepath) + if err != nil { + return nil, "", fmt.Errorf("failed to get telegraph page: %w", err) + } + + var imgs []string + for _, elem := range page.Content { + imgs = append(imgs, tphutil.GetNodeImages(elem)...) + } + + if len(imgs) == 0 { + return nil, "", fmt.Errorf("no images found in telegraph page") + } + + return imgs, pagepath, nil +} + +// parseTelegraphPath 解析 Telegraph URL 获取页面路径 +func parseTelegraphPath(pageURL string) (string, error) { + u, err := url.Parse(pageURL) + if err != nil { + return "", fmt.Errorf("invalid telegraph URL: %w", err) + } + + if !strings.HasSuffix(u.Host, "telegra.ph") && !strings.HasSuffix(u.Host, "telegraph.co") { + return "", fmt.Errorf("invalid telegraph URL host: %s", u.Host) + } + + paths := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + if len(paths) == 0 || paths[0] == "" { + return "", fmt.Errorf("invalid telegraph URL path: %s", u.Path) + } + + pagepath := paths[len(paths)-1] + pagepath, err = url.PathUnescape(pagepath) + if err != nil { + return "", fmt.Errorf("failed to unescape telegraph path: %w", err) + } + + return strings.TrimSpace(pagepath), nil +} + +// isValidTelegraphURL 检查是否是有效的 Telegraph URL +func isValidTelegraphURL(url string) bool { + return strings.HasPrefix(url, "https://telegra.ph/") || + strings.HasPrefix(url, "http://telegra.ph/") || + strings.HasPrefix(url, "https://telegraph.co/") || + strings.HasPrefix(url, "http://telegraph.co/") +} diff --git a/api/types.go b/api/types.go new file mode 100644 index 0000000..5462f6c --- /dev/null +++ b/api/types.go @@ -0,0 +1,161 @@ +package api + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" +) + +// TaskStatus 表示任务状态 +type TaskStatus string + +const ( + TaskStatusQueued TaskStatus = "queued" + TaskStatusRunning TaskStatus = "running" + TaskStatusCompleted TaskStatus = "completed" + TaskStatusFailed TaskStatus = "failed" + TaskStatusCancelled TaskStatus = "cancelled" +) + +// CreateTaskRequest 创建任务请求 +type CreateTaskRequest struct { + Type tasktype.TaskType `json:"type"` + Storage string `json:"storage"` + Path string `json:"path"` + Webhook string `json:"webhook,omitempty"` + Params json.RawMessage `json:"params"` +} + +// CreateTaskResponse 创建任务响应 +type CreateTaskResponse struct { + TaskID string `json:"task_id"` + Type tasktype.TaskType `json:"type"` + Status TaskStatus `json:"status"` + CreatedAt time.Time `json:"created_at"` +} + +// TaskProgress 任务进度 +type TaskProgress struct { + TotalBytes int64 `json:"total_bytes,omitempty"` + DownloadedBytes int64 `json:"downloaded_bytes,omitempty"` + Percent float64 `json:"percent,omitempty"` + SpeedMBPS float64 `json:"speed_mbps,omitempty"` +} + +// TaskInfoResponse 任务信息响应 +type TaskInfoResponse struct { + TaskID string `json:"task_id"` + Type tasktype.TaskType `json:"type"` + Status TaskStatus `json:"status"` + Title string `json:"title"` + Progress *TaskProgress `json:"progress,omitempty"` + Storage string `json:"storage"` + Path string `json:"path"` + Error string `json:"error,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TasksListResponse 任务列表响应 +type TasksListResponse struct { + Tasks []TaskInfoResponse `json:"tasks"` + Total int `json:"total"` +} + +// StoragesResponse 存储列表响应 +type StoragesResponse struct { + Storages []StorageInfo `json:"storages"` +} + +// StorageInfo 存储信息 +type StorageInfo struct { + Name string `json:"name"` + Type string `json:"type"` +} + +// WebhookPayload Webhook 回调负载 +type WebhookPayload struct { + TaskID string `json:"task_id"` + Type string `json:"type"` + Status TaskStatus `json:"status"` + Storage string `json:"storage"` + Path string `json:"path"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + Error string `json:"error,omitempty"` +} + +// ErrorResponse 错误响应 +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message,omitempty"` +} + +// APIError API 错误 +type APIError struct { + StatusCode int + ErrorCode string + Message string +} + +func (e *APIError) Error() string { + return e.Message +} + +// WriteJSON 写入 JSON 响应 +func WriteJSON(w http.ResponseWriter, statusCode int, data any) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + return json.NewEncoder(w).Encode(data) +} + +// WriteError 写入错误响应 +func WriteError(w http.ResponseWriter, statusCode int, errCode, message string) error { + return WriteJSON(w, statusCode, ErrorResponse{ + Error: errCode, + Message: message, + }) +} + +// Task 参数结构体 + +// DirectLinksParams directlinks 任务参数 +type DirectLinksParams struct { + URLs []string `json:"urls"` +} + +// YTDLPParams ytdlp 任务参数 +type YTDLPParams struct { + URLs []string `json:"urls"` + Flags []string `json:"flags,omitempty"` +} + +// Aria2Params aria2 任务参数 +type Aria2Params struct { + URLs []string `json:"urls"` + Options map[string]string `json:"options,omitempty"` +} + +// ParsedParams parsed 任务参数 +type ParsedParams struct { + URL string `json:"url"` +} + +// TransferParams transfer 任务参数 +type TransferParams struct { + SourceStorage string `json:"source_storage"` + SourcePath string `json:"source_path"` + TargetStorage string `json:"target_storage"` + TargetPath string `json:"target_path"` +} + +// TGFilesParams tgfiles 任务参数 +type TGFilesParams struct { + MessageLinks []string `json:"message_links"` +} + +// TPHPicsParams tphpics 任务参数 +type TPHPicsParams struct { + TelegraphURL string `json:"telegraph_url"` +} diff --git a/api/webhook.go b/api/webhook.go new file mode 100644 index 0000000..16a50c4 --- /dev/null +++ b/api/webhook.go @@ -0,0 +1,130 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/charmbracelet/log" +) + +// webhookClient Webhook 客户端 +var webhookClient = &http.Client{ + Timeout: 30 * time.Second, +} + +// SendWebhook 发送 Webhook 回调 +func SendWebhook(ctx context.Context, payload *WebhookPayload) { + if payload == nil || payload.TaskID == "" { + return + } + + // 获取任务信息以获取 webhook URL + info, ok := GetTask(payload.TaskID) + if !ok || info.Webhook == "" { + return + } + + webhookURL := info.Webhook + + // 异步发送 webhook + go func() { + logger := log.FromContext(ctx).With("task_id", payload.TaskID) + + payloadBytes, err := json.Marshal(payload) + if err != nil { + logger.Errorf("Failed to marshal webhook payload: %v", err) + return + } + + // 重试 3 次 + for i := range 3 { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes)) + if err != nil { + logger.Errorf("Failed to create webhook request: %v", err) + return + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "SaveAny-Bot/1.0") + + resp, err := webhookClient.Do(req) + if err != nil { + logger.Warnf("Webhook request failed (attempt %d/3): %v", i+1, err) + time.Sleep(time.Second * time.Duration(i+1)) + continue + } + resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + logger.Debugf("Webhook sent successfully: %s", webhookURL) + return + } + + logger.Warnf("Webhook returned non-2xx status (attempt %d/3): %d", i+1, resp.StatusCode) + time.Sleep(time.Second * time.Duration(i+1)) + } + + logger.Errorf("Failed to send webhook after 3 attempts") + }() +} + +// CreateWebhookPayload 创建 Webhook 负载 +func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload { + payload := &WebhookPayload{ + TaskID: taskID, + Type: taskType, + Status: status, + Storage: storage, + Path: path, + } + + if status == TaskStatusCompleted || status == TaskStatusFailed { + now := time.Now() + payload.CompletedAt = &now + } + + if err != nil { + payload.Error = err.Error() + } + + return payload +} + +// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调 +func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error { + info, ok := GetTask(taskID) + if !ok { + return fmt.Errorf("task not found: %s", taskID) + } + + err := fn() + + // 确定任务状态 + status := TaskStatusCompleted + if err != nil { + if err == context.Canceled { + status = TaskStatusCancelled + } else { + status = TaskStatusFailed + } + } + + // 更新任务状态 + if err != nil { + info.SetError(err.Error()) + } else { + info.UpdateStatus(TaskStatusCompleted) + } + + // 发送 webhook + if info.Webhook != "" { + payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err) + SendWebhook(ctx, payload) + } + + return err +} diff --git a/cmd/run.go b/cmd/run.go index a0537de..3030411 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -10,6 +10,7 @@ import ( "slices" "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/api" "github.com/krau/SaveAny-Bot/client/bot" userclient "github.com/krau/SaveAny-Bot/client/user" "github.com/krau/SaveAny-Bot/common/cache" @@ -76,6 +77,9 @@ func initAll(ctx context.Context, cmd *cobra.Command) (<-chan struct{}, error) { logger.Fatal("User login failed", "error", err) } } + if err := api.Start(ctx); err != nil { + logger.Error("Failed to start API server", "error", err) + } return bot.Init(ctx), nil } diff --git a/config.example.toml b/config.example.toml index bf83165..9abc9c3 100644 --- a/config.example.toml +++ b/config.example.toml @@ -29,6 +29,17 @@ secret = "" # 转存完成后删除 Aria2 下载的本地文件 remove_after_transfer = true +# HTTP API 配置 +[api] +# 启用 HTTP API +enable = false +# 监听地址 +host = "0.0.0.0" +# 监听端口 +port = 8080 +# 认证 Token (必需) +token = "" + # 存储列表 [[storages]] # 标识名, 需要唯一 diff --git a/config/viper.go b/config/viper.go index c7756da..ad0c0f3 100644 --- a/config/viper.go +++ b/config/viper.go @@ -24,6 +24,7 @@ type Config struct { Stream bool `toml:"stream" mapstructure:"stream" json:"stream"` Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"` Aria2 aria2Config `toml:"aria2" mapstructure:"aria2" json:"aria2"` + API apiConfig `toml:"api" mapstructure:"api" json:"api"` Cache cacheConfig `toml:"cache" mapstructure:"cache" json:"cache"` Users []userConfig `toml:"users" mapstructure:"users" json:"users"` @@ -42,6 +43,13 @@ type aria2Config struct { KeepFile bool `toml:"keep_file" mapstructure:"keep_file" json:"keep_file"` } +type apiConfig struct { + Enable bool `toml:"enable" mapstructure:"enable" json:"enable"` + Host string `toml:"host" mapstructure:"host" json:"host"` + Port int `toml:"port" mapstructure:"port" json:"port"` + Token string `toml:"token" mapstructure:"token" json:"token"` +} + var cfg = &Config{} func C() Config { @@ -115,6 +123,12 @@ func Init(ctx context.Context, configFile ...string) error { // 数据库 "db.path": "data/saveany.db", "db.session": "data/session.db", + + // API + "api.enable": false, + "api.host": "0.0.0.0", + "api.port": 8080, + "api.token": "", } for key, value := range defaultConfigs {