- Added Handlers struct and methods for task operations - Implemented task progress tracking and storage - Created server setup with middleware for logging and recovery - Added support for Telegram file extraction and Telegraph image extraction - Introduced webhook functionality for task status updates - Defined request and response types for API interactions
356 lines
10 KiB
Go
356 lines
10 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/krau/SaveAny-Bot/config"
|
|
"github.com/krau/SaveAny-Bot/core"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/batchtfile"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/directlinks"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/parsed"
|
|
tphtask "github.com/krau/SaveAny-Bot/core/tasks/telegraph"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/tfile"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/transfer"
|
|
"github.com/krau/SaveAny-Bot/core/tasks/ytdlp"
|
|
"github.com/krau/SaveAny-Bot/parsers/parsers"
|
|
"github.com/krau/SaveAny-Bot/pkg/aria2"
|
|
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
|
|
"github.com/krau/SaveAny-Bot/pkg/parser"
|
|
"github.com/krau/SaveAny-Bot/pkg/telegraph"
|
|
"github.com/krau/SaveAny-Bot/storage"
|
|
"github.com/rs/xid"
|
|
)
|
|
|
|
// TaskFactory 任务工厂
|
|
type TaskFactory struct {
|
|
ctx context.Context
|
|
}
|
|
|
|
// NewTaskFactory 创建任务工厂
|
|
func NewTaskFactory(ctx context.Context) *TaskFactory {
|
|
return &TaskFactory{ctx: ctx}
|
|
}
|
|
|
|
// CreateTask 创建任务
|
|
func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, error) {
|
|
// 验证存储
|
|
stor, ok := storage.Storages[req.Storage]
|
|
if !ok {
|
|
return nil, fmt.Errorf("storage not found: %s", req.Storage)
|
|
}
|
|
|
|
taskID := xid.New().String()
|
|
createdAt := time.Now()
|
|
|
|
switch req.Type {
|
|
case tasktype.TaskTypeDirectlinks:
|
|
return f.createDirectLinksTask(taskID, createdAt, req, stor)
|
|
case tasktype.TaskTypeYtdlp:
|
|
return f.createYTDLPTask(taskID, createdAt, req, stor)
|
|
case tasktype.TaskTypeAria2:
|
|
return f.createAria2Task(taskID, createdAt, req, stor)
|
|
case tasktype.TaskTypeParseditem:
|
|
return f.createParsedTask(taskID, createdAt, req, stor)
|
|
case tasktype.TaskTypeTgfiles:
|
|
return f.createTGFilesTask(taskID, createdAt, req, stor)
|
|
case tasktype.TaskTypeTphpics:
|
|
return f.createTPHPicsTask(taskID, createdAt, req, stor)
|
|
case tasktype.TaskTypeTransfer:
|
|
return f.createTransferTask(taskID, createdAt, req)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported task type: %s", req.Type)
|
|
}
|
|
}
|
|
|
|
// createDirectLinksTask 创建直链下载任务
|
|
func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
|
var params DirectLinksParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
if len(params.URLs) == 0 {
|
|
return nil, fmt.Errorf("no URLs provided")
|
|
}
|
|
|
|
task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil)
|
|
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeDirectlinks,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// createYTDLPTask 创建 yt-dlp 任务
|
|
func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
|
var params YTDLPParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
if len(params.URLs) == 0 {
|
|
return nil, fmt.Errorf("no URLs provided")
|
|
}
|
|
|
|
task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil)
|
|
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeYtdlp,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// createAria2Task 创建 Aria2 任务
|
|
func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
|
var params Aria2Params
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
if len(params.URLs) == 0 {
|
|
return nil, fmt.Errorf("no URLs provided")
|
|
}
|
|
|
|
// 检查 Aria2 是否启用
|
|
cfg := config.C().Aria2
|
|
if !cfg.Enable {
|
|
return nil, fmt.Errorf("aria2 is not enabled")
|
|
}
|
|
|
|
aria2Client, err := aria2.NewClient(cfg.Url, cfg.Secret)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create aria2 client: %w", err)
|
|
}
|
|
|
|
// 添加下载任务到 Aria2
|
|
gid, err := aria2Client.AddURI(f.ctx, params.URLs, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to add aria2 task: %w", err)
|
|
}
|
|
|
|
task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil)
|
|
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeAria2,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// createParsedTask 创建解析任务
|
|
func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
|
var params ParsedParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
if params.URL == "" {
|
|
return nil, fmt.Errorf("no URL provided")
|
|
}
|
|
|
|
// 查找合适的解析器
|
|
var p parser.Parser
|
|
for _, parserItem := range parsers.Get() {
|
|
if parserItem.CanHandle(params.URL) {
|
|
p = parserItem
|
|
break
|
|
}
|
|
}
|
|
|
|
if p == nil {
|
|
return nil, fmt.Errorf("no parser found for URL: %s", params.URL)
|
|
}
|
|
|
|
// 解析 URL
|
|
item, err := p.Parse(f.ctx, params.URL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse URL: %w", err)
|
|
}
|
|
|
|
task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil)
|
|
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeParseditem,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// createTGFilesTask 创建 Telegram 文件下载任务
|
|
func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
|
var params TGFilesParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
if len(params.MessageLinks) == 0 {
|
|
return nil, fmt.Errorf("no message links provided")
|
|
}
|
|
|
|
// 提取文件
|
|
files, err := ExtractFilesFromLinks(f.ctx, params.MessageLinks)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to extract files: %w", err)
|
|
}
|
|
|
|
if len(files) == 0 {
|
|
return nil, fmt.Errorf("no files found in provided links")
|
|
}
|
|
|
|
if len(files) == 1 {
|
|
// 单个文件任务
|
|
tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create tfile task: %w", err)
|
|
}
|
|
if err := core.AddTask(f.ctx, tfileTask); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
} else {
|
|
// 批量文件任务
|
|
elems := make([]batchtfile.TaskElement, 0, len(files))
|
|
for _, file := range files {
|
|
elem, err := batchtfile.NewTaskElement(stor, req.Path, file)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create task element: %w", err)
|
|
}
|
|
elems = append(elems, *elem)
|
|
}
|
|
|
|
task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeTgfiles,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// createTPHPicsTask 创建 Telegraph 图片下载任务
|
|
func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
|
|
var params TPHPicsParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
if params.TelegraphURL == "" {
|
|
return nil, fmt.Errorf("no telegraph URL provided")
|
|
}
|
|
|
|
// 提取图片
|
|
pics, phPath, err := ExtractTelegraphImages(f.ctx, params.TelegraphURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to extract telegraph images: %w", err)
|
|
}
|
|
|
|
if len(pics) == 0 {
|
|
return nil, fmt.Errorf("no images found in telegraph page")
|
|
}
|
|
|
|
client := telegraph.NewClient()
|
|
task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil)
|
|
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeTphpics,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// createTransferTask 创建存储间传输任务
|
|
func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req *CreateTaskRequest) (*CreateTaskResponse, error) {
|
|
var params TransferParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
return nil, fmt.Errorf("invalid params: %w", err)
|
|
}
|
|
|
|
// 验证源存储和目标存储
|
|
sourceStor, ok := storage.Storages[params.SourceStorage]
|
|
if !ok {
|
|
return nil, fmt.Errorf("source storage not found: %s", params.SourceStorage)
|
|
}
|
|
|
|
targetStor, ok := storage.Storages[params.TargetStorage]
|
|
if !ok {
|
|
return nil, fmt.Errorf("target storage not found: %s", params.TargetStorage)
|
|
}
|
|
|
|
// 检查源存储是否可读
|
|
sourceReadable, ok := sourceStor.(storage.StorageReadable)
|
|
if !ok {
|
|
return nil, fmt.Errorf("source storage does not support reading: %s", params.SourceStorage)
|
|
}
|
|
|
|
// 检查源存储是否可列
|
|
sourceListable, ok := sourceStor.(storage.StorageListable)
|
|
if !ok {
|
|
return nil, fmt.Errorf("source storage does not support listing: %s", params.SourceStorage)
|
|
}
|
|
|
|
// 列出源文件
|
|
files, err := sourceListable.ListFiles(f.ctx, params.SourcePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list source files: %w", err)
|
|
}
|
|
|
|
if len(files) == 0 {
|
|
return nil, fmt.Errorf("no files found at source path: %s", params.SourcePath)
|
|
}
|
|
|
|
// 创建传输元素
|
|
elems := make([]transfer.TaskElement, 0, len(files))
|
|
for _, file := range files {
|
|
elem := transfer.NewTaskElement(sourceReadable, file, targetStor, params.TargetPath)
|
|
elems = append(elems, *elem)
|
|
}
|
|
|
|
task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true)
|
|
|
|
if err := core.AddTask(f.ctx, task); err != nil {
|
|
return nil, fmt.Errorf("failed to add task: %w", err)
|
|
}
|
|
|
|
return &CreateTaskResponse{
|
|
TaskID: taskID,
|
|
Type: tasktype.TaskTypeTransfer,
|
|
Status: TaskStatusQueued,
|
|
CreatedAt: createdAt,
|
|
}, nil
|
|
}
|