Files
SaveAny-Bot/api/factory.go
krau 3eb3b6e3c8 feat(api): implement task management API with handlers for creating, listing, retrieving, and canceling tasks
- Added Handlers struct and methods for task operations
- Implemented task progress tracking and storage
- Created server setup with middleware for logging and recovery
- Added support for Telegram file extraction and Telegraph image extraction
- Introduced webhook functionality for task status updates
- Defined request and response types for API interactions
2026-03-05 19:11:30 +08:00

356 lines
10 KiB
Go

package api
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
"github.com/krau/SaveAny-Bot/core/tasks/batchtfile"
"github.com/krau/SaveAny-Bot/core/tasks/directlinks"
"github.com/krau/SaveAny-Bot/core/tasks/parsed"
tphtask "github.com/krau/SaveAny-Bot/core/tasks/telegraph"
"github.com/krau/SaveAny-Bot/core/tasks/tfile"
"github.com/krau/SaveAny-Bot/core/tasks/transfer"
"github.com/krau/SaveAny-Bot/core/tasks/ytdlp"
"github.com/krau/SaveAny-Bot/parsers/parsers"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
// TaskFactory 任务工厂
type TaskFactory struct {
ctx context.Context
}
// NewTaskFactory 创建任务工厂
func NewTaskFactory(ctx context.Context) *TaskFactory {
return &TaskFactory{ctx: ctx}
}
// CreateTask 创建任务
func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, error) {
// 验证存储
stor, ok := storage.Storages[req.Storage]
if !ok {
return nil, fmt.Errorf("storage not found: %s", req.Storage)
}
taskID := xid.New().String()
createdAt := time.Now()
switch req.Type {
case tasktype.TaskTypeDirectlinks:
return f.createDirectLinksTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeYtdlp:
return f.createYTDLPTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeAria2:
return f.createAria2Task(taskID, createdAt, req, stor)
case tasktype.TaskTypeParseditem:
return f.createParsedTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeTgfiles:
return f.createTGFilesTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeTphpics:
return f.createTPHPicsTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeTransfer:
return f.createTransferTask(taskID, createdAt, req)
default:
return nil, fmt.Errorf("unsupported task type: %s", req.Type)
}
}
// createDirectLinksTask 创建直链下载任务
func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params DirectLinksParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.URLs) == 0 {
return nil, fmt.Errorf("no URLs provided")
}
task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeDirectlinks,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createYTDLPTask 创建 yt-dlp 任务
func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params YTDLPParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.URLs) == 0 {
return nil, fmt.Errorf("no URLs provided")
}
task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeYtdlp,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createAria2Task 创建 Aria2 任务
func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params Aria2Params
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.URLs) == 0 {
return nil, fmt.Errorf("no URLs provided")
}
// 检查 Aria2 是否启用
cfg := config.C().Aria2
if !cfg.Enable {
return nil, fmt.Errorf("aria2 is not enabled")
}
aria2Client, err := aria2.NewClient(cfg.Url, cfg.Secret)
if err != nil {
return nil, fmt.Errorf("failed to create aria2 client: %w", err)
}
// 添加下载任务到 Aria2
gid, err := aria2Client.AddURI(f.ctx, params.URLs, nil)
if err != nil {
return nil, fmt.Errorf("failed to add aria2 task: %w", err)
}
task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeAria2,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createParsedTask 创建解析任务
func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params ParsedParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if params.URL == "" {
return nil, fmt.Errorf("no URL provided")
}
// 查找合适的解析器
var p parser.Parser
for _, parserItem := range parsers.Get() {
if parserItem.CanHandle(params.URL) {
p = parserItem
break
}
}
if p == nil {
return nil, fmt.Errorf("no parser found for URL: %s", params.URL)
}
// 解析 URL
item, err := p.Parse(f.ctx, params.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse URL: %w", err)
}
task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeParseditem,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createTGFilesTask 创建 Telegram 文件下载任务
func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params TGFilesParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.MessageLinks) == 0 {
return nil, fmt.Errorf("no message links provided")
}
// 提取文件
files, err := ExtractFilesFromLinks(f.ctx, params.MessageLinks)
if err != nil {
return nil, fmt.Errorf("failed to extract files: %w", err)
}
if len(files) == 0 {
return nil, fmt.Errorf("no files found in provided links")
}
if len(files) == 1 {
// 单个文件任务
tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil)
if err != nil {
return nil, fmt.Errorf("failed to create tfile task: %w", err)
}
if err := core.AddTask(f.ctx, tfileTask); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
} else {
// 批量文件任务
elems := make([]batchtfile.TaskElement, 0, len(files))
for _, file := range files {
elem, err := batchtfile.NewTaskElement(stor, req.Path, file)
if err != nil {
return nil, fmt.Errorf("failed to create task element: %w", err)
}
elems = append(elems, *elem)
}
task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeTgfiles,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createTPHPicsTask 创建 Telegraph 图片下载任务
func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params TPHPicsParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if params.TelegraphURL == "" {
return nil, fmt.Errorf("no telegraph URL provided")
}
// 提取图片
pics, phPath, err := ExtractTelegraphImages(f.ctx, params.TelegraphURL)
if err != nil {
return nil, fmt.Errorf("failed to extract telegraph images: %w", err)
}
if len(pics) == 0 {
return nil, fmt.Errorf("no images found in telegraph page")
}
client := telegraph.NewClient()
task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeTphpics,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createTransferTask 创建存储间传输任务
func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req *CreateTaskRequest) (*CreateTaskResponse, error) {
var params TransferParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
// 验证源存储和目标存储
sourceStor, ok := storage.Storages[params.SourceStorage]
if !ok {
return nil, fmt.Errorf("source storage not found: %s", params.SourceStorage)
}
targetStor, ok := storage.Storages[params.TargetStorage]
if !ok {
return nil, fmt.Errorf("target storage not found: %s", params.TargetStorage)
}
// 检查源存储是否可读
sourceReadable, ok := sourceStor.(storage.StorageReadable)
if !ok {
return nil, fmt.Errorf("source storage does not support reading: %s", params.SourceStorage)
}
// 检查源存储是否可列
sourceListable, ok := sourceStor.(storage.StorageListable)
if !ok {
return nil, fmt.Errorf("source storage does not support listing: %s", params.SourceStorage)
}
// 列出源文件
files, err := sourceListable.ListFiles(f.ctx, params.SourcePath)
if err != nil {
return nil, fmt.Errorf("failed to list source files: %w", err)
}
if len(files) == 0 {
return nil, fmt.Errorf("no files found at source path: %s", params.SourcePath)
}
// 创建传输元素
elems := make([]transfer.TaskElement, 0, len(files))
for _, file := range files {
elem := transfer.NewTaskElement(sourceReadable, file, targetStor, params.TargetPath)
elems = append(elems, *elem)
}
task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeTransfer,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}