feat(api): implement task management API with handlers for creating, listing, retrieving, and canceling tasks

- Added Handlers struct and methods for task operations
- Implemented task progress tracking and storage
- Created server setup with middleware for logging and recovery
- Added support for Telegram file extraction and Telegraph image extraction
- Introduced webhook functionality for task status updates
- Defined request and response types for API interactions
This commit is contained in:
krau
2026-03-05 19:11:30 +08:00
parent f377ee3ca4
commit 3eb3b6e3c8
13 changed files with 1612 additions and 1 deletions

48
api/auth.go Normal file
View File

@@ -0,0 +1,48 @@
package api
import (
"context"
"crypto/subtle"
"net/http"
"strings"
"github.com/krau/SaveAny-Bot/config"
)
// tokenContextKey 用于在 context 中存储 token
type tokenContextKey struct{}
// AuthMiddleware 返回认证中间件
func AuthMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cfg := config.C().API
// 从请求头获取 token
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
WriteError(w, http.StatusUnauthorized, "unauthorized", "missing authorization header")
return
}
// 提取 Bearer token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid authorization header format")
return
}
token := parts[1]
// 验证 token
if subtle.ConstantTimeCompare([]byte(token), []byte(cfg.Token)) != 1 {
WriteError(w, http.StatusUnauthorized, "unauthorized", "invalid token")
return
}
// 将 token 添加到 context
ctx := context.WithValue(r.Context(), tokenContextKey{}, token)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

355
api/factory.go Normal file
View File

@@ -0,0 +1,355 @@
package api
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/core/tasks/aria2dl"
"github.com/krau/SaveAny-Bot/core/tasks/batchtfile"
"github.com/krau/SaveAny-Bot/core/tasks/directlinks"
"github.com/krau/SaveAny-Bot/core/tasks/parsed"
tphtask "github.com/krau/SaveAny-Bot/core/tasks/telegraph"
"github.com/krau/SaveAny-Bot/core/tasks/tfile"
"github.com/krau/SaveAny-Bot/core/tasks/transfer"
"github.com/krau/SaveAny-Bot/core/tasks/ytdlp"
"github.com/krau/SaveAny-Bot/parsers/parsers"
"github.com/krau/SaveAny-Bot/pkg/aria2"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/pkg/parser"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
"github.com/krau/SaveAny-Bot/storage"
"github.com/rs/xid"
)
// TaskFactory 任务工厂
type TaskFactory struct {
ctx context.Context
}
// NewTaskFactory 创建任务工厂
func NewTaskFactory(ctx context.Context) *TaskFactory {
return &TaskFactory{ctx: ctx}
}
// CreateTask 创建任务
func (f *TaskFactory) CreateTask(req *CreateTaskRequest) (*CreateTaskResponse, error) {
// 验证存储
stor, ok := storage.Storages[req.Storage]
if !ok {
return nil, fmt.Errorf("storage not found: %s", req.Storage)
}
taskID := xid.New().String()
createdAt := time.Now()
switch req.Type {
case tasktype.TaskTypeDirectlinks:
return f.createDirectLinksTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeYtdlp:
return f.createYTDLPTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeAria2:
return f.createAria2Task(taskID, createdAt, req, stor)
case tasktype.TaskTypeParseditem:
return f.createParsedTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeTgfiles:
return f.createTGFilesTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeTphpics:
return f.createTPHPicsTask(taskID, createdAt, req, stor)
case tasktype.TaskTypeTransfer:
return f.createTransferTask(taskID, createdAt, req)
default:
return nil, fmt.Errorf("unsupported task type: %s", req.Type)
}
}
// createDirectLinksTask 创建直链下载任务
func (f *TaskFactory) createDirectLinksTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params DirectLinksParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.URLs) == 0 {
return nil, fmt.Errorf("no URLs provided")
}
task := directlinks.NewTask(taskID, f.ctx, params.URLs, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeDirectlinks,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createYTDLPTask 创建 yt-dlp 任务
func (f *TaskFactory) createYTDLPTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params YTDLPParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.URLs) == 0 {
return nil, fmt.Errorf("no URLs provided")
}
task := ytdlp.NewTask(taskID, f.ctx, params.URLs, params.Flags, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeYtdlp,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createAria2Task 创建 Aria2 任务
func (f *TaskFactory) createAria2Task(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params Aria2Params
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.URLs) == 0 {
return nil, fmt.Errorf("no URLs provided")
}
// 检查 Aria2 是否启用
cfg := config.C().Aria2
if !cfg.Enable {
return nil, fmt.Errorf("aria2 is not enabled")
}
aria2Client, err := aria2.NewClient(cfg.Url, cfg.Secret)
if err != nil {
return nil, fmt.Errorf("failed to create aria2 client: %w", err)
}
// 添加下载任务到 Aria2
gid, err := aria2Client.AddURI(f.ctx, params.URLs, nil)
if err != nil {
return nil, fmt.Errorf("failed to add aria2 task: %w", err)
}
task := aria2dl.NewTask(taskID, f.ctx, gid, params.URLs, aria2Client, stor, req.Path, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeAria2,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createParsedTask 创建解析任务
func (f *TaskFactory) createParsedTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params ParsedParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if params.URL == "" {
return nil, fmt.Errorf("no URL provided")
}
// 查找合适的解析器
var p parser.Parser
for _, parserItem := range parsers.Get() {
if parserItem.CanHandle(params.URL) {
p = parserItem
break
}
}
if p == nil {
return nil, fmt.Errorf("no parser found for URL: %s", params.URL)
}
// 解析 URL
item, err := p.Parse(f.ctx, params.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse URL: %w", err)
}
task := parsed.NewTask(taskID, f.ctx, stor, req.Path, item, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeParseditem,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createTGFilesTask 创建 Telegram 文件下载任务
func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params TGFilesParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if len(params.MessageLinks) == 0 {
return nil, fmt.Errorf("no message links provided")
}
// 提取文件
files, err := ExtractFilesFromLinks(f.ctx, params.MessageLinks)
if err != nil {
return nil, fmt.Errorf("failed to extract files: %w", err)
}
if len(files) == 0 {
return nil, fmt.Errorf("no files found in provided links")
}
if len(files) == 1 {
// 单个文件任务
tfileTask, err := tfile.NewTGFileTask(taskID, f.ctx, files[0], stor, req.Path, nil)
if err != nil {
return nil, fmt.Errorf("failed to create tfile task: %w", err)
}
if err := core.AddTask(f.ctx, tfileTask); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
} else {
// 批量文件任务
elems := make([]batchtfile.TaskElement, 0, len(files))
for _, file := range files {
elem, err := batchtfile.NewTaskElement(stor, req.Path, file)
if err != nil {
return nil, fmt.Errorf("failed to create task element: %w", err)
}
elems = append(elems, *elem)
}
task := batchtfile.NewBatchTGFileTask(taskID, f.ctx, elems, nil, true)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeTgfiles,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createTPHPicsTask 创建 Telegraph 图片下载任务
func (f *TaskFactory) createTPHPicsTask(taskID string, createdAt time.Time, req *CreateTaskRequest, stor storage.Storage) (*CreateTaskResponse, error) {
var params TPHPicsParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
if params.TelegraphURL == "" {
return nil, fmt.Errorf("no telegraph URL provided")
}
// 提取图片
pics, phPath, err := ExtractTelegraphImages(f.ctx, params.TelegraphURL)
if err != nil {
return nil, fmt.Errorf("failed to extract telegraph images: %w", err)
}
if len(pics) == 0 {
return nil, fmt.Errorf("no images found in telegraph page")
}
client := telegraph.NewClient()
task := tphtask.NewTask(taskID, f.ctx, phPath, pics, stor, req.Path, client, nil)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeTphpics,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}
// createTransferTask 创建存储间传输任务
func (f *TaskFactory) createTransferTask(taskID string, createdAt time.Time, req *CreateTaskRequest) (*CreateTaskResponse, error) {
var params TransferParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
// 验证源存储和目标存储
sourceStor, ok := storage.Storages[params.SourceStorage]
if !ok {
return nil, fmt.Errorf("source storage not found: %s", params.SourceStorage)
}
targetStor, ok := storage.Storages[params.TargetStorage]
if !ok {
return nil, fmt.Errorf("target storage not found: %s", params.TargetStorage)
}
// 检查源存储是否可读
sourceReadable, ok := sourceStor.(storage.StorageReadable)
if !ok {
return nil, fmt.Errorf("source storage does not support reading: %s", params.SourceStorage)
}
// 检查源存储是否可列
sourceListable, ok := sourceStor.(storage.StorageListable)
if !ok {
return nil, fmt.Errorf("source storage does not support listing: %s", params.SourceStorage)
}
// 列出源文件
files, err := sourceListable.ListFiles(f.ctx, params.SourcePath)
if err != nil {
return nil, fmt.Errorf("failed to list source files: %w", err)
}
if len(files) == 0 {
return nil, fmt.Errorf("no files found at source path: %s", params.SourcePath)
}
// 创建传输元素
elems := make([]transfer.TaskElement, 0, len(files))
for _, file := range files {
elem := transfer.NewTaskElement(sourceReadable, file, targetStor, params.TargetPath)
elems = append(elems, *elem)
}
task := transfer.NewTransferTask(taskID, f.ctx, elems, nil, true)
if err := core.AddTask(f.ctx, task); err != nil {
return nil, fmt.Errorf("failed to add task: %w", err)
}
return &CreateTaskResponse{
TaskID: taskID,
Type: tasktype.TaskTypeTransfer,
Status: TaskStatusQueued,
CreatedAt: createdAt,
}, nil
}

222
api/handlers.go Normal file
View File

@@ -0,0 +1,222 @@
package api
import (
"encoding/json"
"net/http"
"strings"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/storage"
)
// Handlers 处理器结构体
type Handlers struct {
factory *TaskFactory
}
// NewHandlers 创建处理器
func NewHandlers(factory *TaskFactory) *Handlers {
return &Handlers{factory: factory}
}
// CreateTaskHandler 创建任务处理器
func (h *Handlers) CreateTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only POST method is allowed")
return
}
var req CreateTaskRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
return
}
// 验证请求
if req.Type == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "task type is required")
return
}
if req.Storage == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "storage is required")
return
}
// 创建任务
resp, err := h.factory.CreateTask(&req)
if err != nil {
WriteError(w, http.StatusBadRequest, "task_creation_failed", err.Error())
return
}
WriteJSON(w, http.StatusCreated, resp)
}
// ListTasksHandler 列出任务处理器
func (h *Handlers) ListTasksHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
tasks := GetAllTasks()
response := make([]TaskInfoResponse, 0, len(tasks))
for _, task := range tasks {
info := convertTaskProgressToResponse(task)
response = append(response, info)
}
WriteJSON(w, http.StatusOK, TasksListResponse{
Tasks: response,
Total: len(response),
})
}
// GetTaskHandler 获取单个任务处理器
func (h *Handlers) GetTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
taskID := extractTaskIDFromPath(r.URL.Path)
if taskID == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
return
}
task, ok := GetTask(taskID)
if !ok {
WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
return
}
resp := convertTaskProgressToResponse(task)
WriteJSON(w, http.StatusOK, resp)
}
// CancelTaskHandler 取消任务处理器
func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only DELETE method is allowed")
return
}
taskID := extractTaskIDFromPath(r.URL.Path)
if taskID == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
return
}
task, ok := GetTask(taskID)
if !ok {
WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
return
}
// 取消任务
if err := core.CancelTask(r.Context(), taskID); err != nil {
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
return
}
task.UpdateStatus(TaskStatusCancelled)
WriteJSON(w, http.StatusOK, map[string]string{"message": "task cancelled successfully"})
}
// ListStoragesHandler 列出存储处理器
func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
storages := make([]StorageInfo, 0, len(storage.Storages))
for name, stor := range storage.Storages {
storages = append(storages, StorageInfo{
Name: name,
Type: string(stor.Type()),
})
}
WriteJSON(w, http.StatusOK, StoragesResponse{Storages: storages})
}
// GetTaskTypesHandler 获取支持的任务类型
func (h *Handlers) GetTaskTypesHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
types := []tasktype.TaskType{
tasktype.TaskTypeDirectlinks,
tasktype.TaskTypeYtdlp,
tasktype.TaskTypeAria2,
tasktype.TaskTypeParseditem,
tasktype.TaskTypeTgfiles,
tasktype.TaskTypeTphpics,
tasktype.TaskTypeTransfer,
}
WriteJSON(w, http.StatusOK, map[string]any{
"types": types,
})
}
// HealthCheckHandler 健康检查处理器
func (h *Handlers) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
WriteJSON(w, http.StatusOK, map[string]string{
"status": "ok",
})
}
// extractTaskIDFromPath 从路径中提取任务 ID
// 路径格式: /api/v1/tasks/:id
func extractTaskIDFromPath(path string) string {
parts := strings.Split(strings.Trim(path, "/"), "/")
if len(parts) < 4 {
return ""
}
return parts[3]
}
// convertTaskProgressToResponse 将任务进度转换为响应格式
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
resp := TaskInfoResponse{
TaskID: task.TaskID,
Type: tasktype.TaskType(task.Type),
Status: task.Status,
Title: task.Title,
Storage: task.Storage,
Path: task.Path,
Error: task.Error,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
}
// 计算进度
if task.TotalBytes > 0 {
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
resp.Progress = &TaskProgress{
TotalBytes: task.TotalBytes,
DownloadedBytes: task.DownloadedBytes,
Percent: percent,
}
}
return resp
}
// NotFoundHandler 404 处理器
func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, http.StatusNotFound, "not_found", "endpoint not found: "+r.URL.Path)
}
// MethodNotAllowedHandler 405 处理器
func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed: "+r.Method)
}

150
api/progress.go Normal file
View File

@@ -0,0 +1,150 @@
package api
import (
"sync"
"sync/atomic"
"time"
)
// TaskProgressInfo 存储任务的进度信息
type TaskProgressInfo struct {
TaskID string
Type string
Status TaskStatus
Title string
TotalBytes int64
DownloadedBytes int64
TotalFiles int
DownloadedFiles int
Storage string
Path string
Error string
CreatedAt time.Time
UpdatedAt time.Time
Webhook string
}
// progressStore 存储所有 API 任务的进度信息
type progressStore struct {
mu sync.RWMutex
tasks map[string]*TaskProgressInfo
}
var store = &progressStore{
tasks: make(map[string]*TaskProgressInfo),
}
// RegisterTask 注册一个新的 API 任务
func RegisterTask(taskID, taskType, storage, path, title, webhook string) *TaskProgressInfo {
info := &TaskProgressInfo{
TaskID: taskID,
Type: taskType,
Status: TaskStatusQueued,
Title: title,
Storage: storage,
Path: path,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Webhook: webhook,
}
store.mu.Lock()
store.tasks[taskID] = info
store.mu.Unlock()
return info
}
// GetTask 获取任务进度信息
func GetTask(taskID string) (*TaskProgressInfo, bool) {
store.mu.RLock()
defer store.mu.RUnlock()
info, ok := store.tasks[taskID]
return info, ok
}
// GetAllTasks 获取所有任务
func GetAllTasks() []*TaskProgressInfo {
store.mu.RLock()
defer store.mu.RUnlock()
tasks := make([]*TaskProgressInfo, 0, len(store.tasks))
for _, info := range store.tasks {
tasks = append(tasks, info)
}
return tasks
}
// DeleteTask 删除任务记录
func DeleteTask(taskID string) {
store.mu.Lock()
defer store.mu.Unlock()
delete(store.tasks, taskID)
}
// UpdateStatus 更新任务状态
func (t *TaskProgressInfo) UpdateStatus(status TaskStatus) {
t.Status = status
t.UpdatedAt = time.Now()
}
// SetError 设置错误信息
func (t *TaskProgressInfo) SetError(err string) {
t.Error = err
t.Status = TaskStatusFailed
t.UpdatedAt = time.Now()
}
// ProgressTracker 用于 API 任务的进度追踪
type ProgressTracker struct {
info *TaskProgressInfo
}
// NewProgressTracker 创建新的进度追踪器
func NewProgressTracker(taskID, taskType, storage, path, title, webhook string) *ProgressTracker {
info := RegisterTask(taskID, taskType, storage, path, title, webhook)
return &ProgressTracker{info: info}
}
// OnStart 任务开始
func (p *ProgressTracker) OnStart(totalBytes int64, totalFiles int) {
p.info.Status = TaskStatusRunning
p.info.TotalBytes = totalBytes
p.info.TotalFiles = totalFiles
p.info.UpdatedAt = time.Now()
}
// OnProgress 进度更新
func (p *ProgressTracker) OnProgress(downloadedBytes int64, downloadedFiles int) {
atomic.StoreInt64(&p.info.DownloadedBytes, downloadedBytes)
p.info.DownloadedFiles = downloadedFiles
p.info.UpdatedAt = time.Now()
}
// OnDone 任务完成
func (p *ProgressTracker) OnDone(err error) {
if err != nil {
p.info.Status = TaskStatusFailed
p.info.Error = err.Error()
} else {
p.info.Status = TaskStatusCompleted
}
p.info.UpdatedAt = time.Now()
}
// GetInfo 获取任务信息
func (p *ProgressTracker) GetInfo() *TaskProgressInfo {
return p.info
}
// UpdateProgressBytes 更新下载字节数
func (p *ProgressTracker) UpdateProgressBytes(bytes int64) {
atomic.StoreInt64(&p.info.DownloadedBytes, bytes)
p.info.UpdatedAt = time.Now()
}
// UpdateProgressFiles 更新下载文件数
func (p *ProgressTracker) UpdateProgressFiles(files int) {
p.info.DownloadedFiles = files
p.info.UpdatedAt = time.Now()
}

163
api/server.go Normal file
View File

@@ -0,0 +1,163 @@
package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/config"
)
// Server API 服务器
type Server struct {
httpServer *http.Server
factory *TaskFactory
}
// NewServer 创建新的 API 服务器
func NewServer(ctx context.Context) *Server {
cfg := config.C().API
factory := NewTaskFactory(ctx)
handlers := NewHandlers(factory)
// 设置路由
mux := http.NewServeMux()
// 健康检查
mux.HandleFunc("/health", handlers.HealthCheckHandler)
// API v1 路由
mux.HandleFunc("/api/v1/tasks", handlers.CreateTaskHandler)
mux.HandleFunc("/api/v1/tasks/", func(w http.ResponseWriter, r *http.Request) {
// 根据方法和路径分发
switch r.Method {
case http.MethodGet:
if r.URL.Path == "/api/v1/tasks" {
handlers.ListTasksHandler(w, r)
} else {
handlers.GetTaskHandler(w, r)
}
case http.MethodDelete:
handlers.CancelTaskHandler(w, r)
default:
MethodNotAllowedHandler(w, r)
}
})
mux.HandleFunc("/api/v1/storages", handlers.ListStoragesHandler)
mux.HandleFunc("/api/v1/task-types", handlers.GetTaskTypesHandler)
// 404 处理
mux.HandleFunc("/", NotFoundHandler)
// 应用中间件
var handler http.Handler = mux
// 添加认证中间件
token := cfg.Token
if token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
}
if token != "" {
handler = AuthMiddleware()(handler)
}
// 添加日志中间件
handler = loggingMiddleware(handler)
// 添加恢复中间件
handler = recoveryMiddleware(handler)
return &Server{
httpServer: &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: handler,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
},
factory: factory,
}
}
// Start 启动服务器
func (s *Server) Start(ctx context.Context) error {
logger := log.FromContext(ctx).With("module", "api")
logger.Infof("Starting API server on %s", s.httpServer.Addr)
// 在 goroutine 中启动服务器
go func() {
if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Errorf("API server error: %v", err)
}
}()
// 监听 context 取消
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.httpServer.Shutdown(shutdownCtx); err != nil {
logger.Errorf("API server shutdown error: %v", err)
}
}()
return nil
}
// loggingMiddleware 日志中间件
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 包装 ResponseWriter 以获取状态码
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapped, r)
log.Infof("%s %s %d %s", r.Method, r.URL.Path, wrapped.statusCode, time.Since(start))
})
}
// recoveryMiddleware 恢复中间件
func recoveryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
log.Errorf("Panic recovered: %v", err)
WriteError(w, http.StatusInternalServerError, "internal_error", "internal server error")
}
}()
next.ServeHTTP(w, r)
})
}
// responseWriter 包装 http.ResponseWriter 以捕获状态码
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// Start 初始化并启动 API 服务器
func Start(ctx context.Context) error {
cfg := config.C().API
if !cfg.Enable {
return nil
}
if cfg.Token == "" {
log.FromContext(ctx).Warn("API server is enabled but no token is set, this is insecure!")
}
server := NewServer(ctx)
return server.Start(ctx)
}

272
api/tgfiles.go Normal file
View File

@@ -0,0 +1,272 @@
package api
import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
"github.com/krau/SaveAny-Bot/client/bot"
userclient "github.com/krau/SaveAny-Bot/client/user"
"github.com/krau/SaveAny-Bot/common/utils/tgutil"
"github.com/krau/SaveAny-Bot/pkg/tfile"
)
// MessageContext 保存消息和获取它所用的 context
type MessageContext struct {
Message *tg.Message
Client *ext.Context
}
// getClientContext 获取可用的客户端上下文
// 优先使用 Bot失败后回退到 Userbot
func getClientContext() (*ext.Context, error) {
// 首先尝试获取 Bot context
if botCtx := bot.ExtContext(); botCtx != nil {
return botCtx, nil
}
// 回退到 Userbot
if uc := userclient.GetCtx(); uc != nil {
return uc, nil
}
return nil, fmt.Errorf("no client available (bot and userbot are not initialized)")
}
// resolveChatID 解析聊天 ID
func resolveChatID(_ context.Context, idOrUsername string) (int64, error) {
// 如果是数字 ID
if id, err := strconv.ParseInt(idOrUsername, 10, 64); err == nil {
// 私有频道 ID 需要加上 -100 前缀
if id > 0 {
return -1000000000000 - id, nil
}
return id, nil
}
// 获取可用的客户端上下文
clientCtx, err := getClientContext()
if err != nil {
return 0, err
}
// 使用 tgutil 的 ParseChatID
return tgutil.ParseChatID(clientCtx, idOrUsername)
}
// ParseMessageLink 解析 Telegram 消息链接
// 支持格式:
// - https://t.me/username/123
// - https://t.me/c/123456789/123
// - https://t.me/c/123456789/111/456 (topic id)
// - https://t.me/username/123?comment=2 (评论)
func ParseMessageLink(ctx context.Context, link string) (int64, int, error) {
u, err := url.Parse(link)
if err != nil {
return 0, 0, fmt.Errorf("invalid URL: %w", err)
}
paths := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
if cmt := u.Query().Get("comment"); cmt != "" {
// 频道评论的消息链接
if len(paths) < 1 {
return 0, 0, fmt.Errorf("invalid message link format: %s", link)
}
// 简化处理:返回错误,提示不支持评论链接
return 0, 0, fmt.Errorf("comment links are not supported")
}
switch len(paths) {
case 2: // https://t.me/username/123
chatID, err := resolveChatID(ctx, paths[0])
if err != nil {
return 0, 0, fmt.Errorf("failed to resolve chat ID: %w", err)
}
msgID, err := strconv.Atoi(paths[1])
if err != nil {
return 0, 0, fmt.Errorf("failed to parse message ID: %w", err)
}
return chatID, msgID, nil
case 3:
// https://t.me/c/123456789/123
// https://t.me/username/123/456 , 123: topic id
chatPart, msgPart := paths[1], paths[2]
if paths[0] != "c" {
chatPart = paths[0]
}
chatID, err := resolveChatID(ctx, chatPart)
if err != nil {
return 0, 0, fmt.Errorf("failed to resolve chat ID: %w", err)
}
msgID, err := strconv.Atoi(msgPart)
if err != nil {
return 0, 0, fmt.Errorf("failed to parse message ID: %w", err)
}
return chatID, msgID, nil
case 4:
// https://t.me/c/123456789/111/456 111: topic id
if paths[0] != "c" {
return 0, 0, fmt.Errorf("invalid message link format: %s", link)
}
chatID, err := resolveChatID(ctx, paths[1])
if err != nil {
return 0, 0, fmt.Errorf("failed to resolve chat ID: %w", err)
}
msgID, err := strconv.Atoi(paths[3])
if err != nil {
return 0, 0, fmt.Errorf("failed to parse message ID: %w", err)
}
return chatID, msgID, nil
}
return 0, 0, fmt.Errorf("invalid message link format: %s", link)
}
// getMessageWithContext 通过 ID 获取消息,返回消息和使用的 context
// 确保消息获取和后续文件创建使用同一个 context
func getMessageWithContext(_ context.Context, chatID int64, msgID int) (*MessageContext, error) {
// 首先尝试使用 Bot
if botCtx := bot.ExtContext(); botCtx != nil {
msg, err := tgutil.GetMessageByID(botCtx, chatID, msgID)
if err == nil {
return &MessageContext{Message: msg, Client: botCtx}, nil
}
}
// 回退到 Userbot
uc := userclient.GetCtx()
if uc == nil {
return nil, fmt.Errorf("userbot not initialized and bot cannot access this message")
}
msg, err := tgutil.GetMessageByID(uc, chatID, msgID)
if err != nil {
return nil, err
}
return &MessageContext{Message: msg, Client: uc}, nil
}
// getGroupedMessagesWithContext 获取媒体组消息,返回消息列表和使用的 context
// 确保消息获取和后续文件创建使用同一个 context
func getGroupedMessagesWithContext(ctx *MessageContext, chatID int64) ([]*tg.Message, error) {
msg := ctx.Message
clientCtx := ctx.Client
groupID, ok := msg.GetGroupedID()
if !ok || groupID == 0 {
return []*tg.Message{msg}, nil
}
// 使用获取原始消息的同一个 client 获取媒体组
msgs, err := tgutil.GetGroupedMessages(clientCtx, chatID, msg)
if err != nil || len(msgs) == 0 {
// 如果获取失败,至少返回原始消息
return []*tg.Message{msg}, nil
}
return msgs, nil
}
// ExtractFilesFromLinks 从消息链接中提取文件
// 每个文件的处理流程:解析链接 -> 获取消息 -> 获取媒体组 -> 创建文件对象
// 对于单个文件,全程使用同一个 client context不会交叉
func ExtractFilesFromLinks(ctx context.Context, links []string) ([]tfile.TGFileMessage, error) {
logger := log.FromContext(ctx)
var files []tfile.TGFileMessage
for _, link := range links {
link = strings.TrimSpace(link)
if link == "" {
continue
}
// 验证链接格式
if !isValidMessageLink(link) {
logger.Errorf("Invalid message link format: %s", link)
continue
}
chatID, msgID, err := ParseMessageLink(ctx, link)
if err != nil {
logger.Errorf("Failed to parse message link %s: %v", link, err)
continue
}
// 解析链接 URL 检查是否有 single 参数
u, _ := url.Parse(link)
single := u != nil && u.Query().Has("single")
// 获取消息和使用的 contextBot 优先,失败回退 Userbot
msgCtx, err := getMessageWithContext(ctx, chatID, msgID)
if err != nil {
logger.Errorf("Failed to get message %d from chat %d: %v", msgID, chatID, err)
continue
}
msg := msgCtx.Message
clientCtx := msgCtx.Client
if msg.Media == nil {
logger.Warnf("Message %d has no media", msgID)
continue
}
media, ok := msg.GetMedia()
if !ok {
logger.Warnf("Failed to get media from message %d", msgID)
continue
}
// 检查是否是媒体组
groupID, isGroup := msg.GetGroupedID()
if isGroup && groupID != 0 && !single {
// 使用同一个 client context 获取媒体组
groupMsgs, err := getGroupedMessagesWithContext(msgCtx, chatID)
if err != nil {
logger.Errorf("Failed to get grouped messages: %v", err)
} else {
for _, gmsg := range groupMsgs {
if gmsg.Media == nil {
continue
}
gmedia, ok := gmsg.GetMedia()
if !ok {
continue
}
// 使用获取消息时使用的同一个 client context 创建文件
file, err := tfile.FromMediaMessage(gmedia, clientCtx.Raw, gmsg)
if err != nil {
logger.Errorf("Failed to create file from media: %v", err)
continue
}
files = append(files, file)
}
continue
}
}
// 单个文件 - 使用获取消息时使用的同一个 client context 创建文件
file, err := tfile.FromMediaMessage(media, clientCtx.Raw, msg)
if err != nil {
logger.Errorf("Failed to create file from media: %v", err)
continue
}
files = append(files, file)
}
if len(files) == 0 {
return nil, fmt.Errorf("no files found in provided links")
}
return files, nil
}
// isValidMessageLink 检查是否是有效的 Telegram 消息链接
func isValidMessageLink(link string) bool {
return strings.HasPrefix(link, "https://t.me/") || strings.HasPrefix(link, "http://t.me/")
}

80
api/tphpics.go Normal file
View File

@@ -0,0 +1,80 @@
package api
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/common/utils/tphutil"
"github.com/krau/SaveAny-Bot/pkg/telegraph"
)
// ExtractTelegraphImages 从 Telegraph URL 提取图片
func ExtractTelegraphImages(ctx context.Context, pageURL string) ([]string, string, error) {
logger := log.FromContext(ctx)
// 验证 URL 格式
if !isValidTelegraphURL(pageURL) {
return nil, "", fmt.Errorf("invalid telegraph URL format: %s", pageURL)
}
// 解析 URL 获取页面路径
pagepath, err := parseTelegraphPath(pageURL)
if err != nil {
return nil, "", err
}
logger.Debugf("Fetching telegraph page: %s", pagepath)
client := telegraph.NewClient()
page, err := client.GetPage(ctx, pagepath)
if err != nil {
return nil, "", fmt.Errorf("failed to get telegraph page: %w", err)
}
var imgs []string
for _, elem := range page.Content {
imgs = append(imgs, tphutil.GetNodeImages(elem)...)
}
if len(imgs) == 0 {
return nil, "", fmt.Errorf("no images found in telegraph page")
}
return imgs, pagepath, nil
}
// parseTelegraphPath 解析 Telegraph URL 获取页面路径
func parseTelegraphPath(pageURL string) (string, error) {
u, err := url.Parse(pageURL)
if err != nil {
return "", fmt.Errorf("invalid telegraph URL: %w", err)
}
if !strings.HasSuffix(u.Host, "telegra.ph") && !strings.HasSuffix(u.Host, "telegraph.co") {
return "", fmt.Errorf("invalid telegraph URL host: %s", u.Host)
}
paths := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
if len(paths) == 0 || paths[0] == "" {
return "", fmt.Errorf("invalid telegraph URL path: %s", u.Path)
}
pagepath := paths[len(paths)-1]
pagepath, err = url.PathUnescape(pagepath)
if err != nil {
return "", fmt.Errorf("failed to unescape telegraph path: %w", err)
}
return strings.TrimSpace(pagepath), nil
}
// isValidTelegraphURL 检查是否是有效的 Telegraph URL
func isValidTelegraphURL(url string) bool {
return strings.HasPrefix(url, "https://telegra.ph/") ||
strings.HasPrefix(url, "http://telegra.ph/") ||
strings.HasPrefix(url, "https://telegraph.co/") ||
strings.HasPrefix(url, "http://telegraph.co/")
}

161
api/types.go Normal file
View File

@@ -0,0 +1,161 @@
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
// TaskStatus 表示任务状态
type TaskStatus string
const (
TaskStatusQueued TaskStatus = "queued"
TaskStatusRunning TaskStatus = "running"
TaskStatusCompleted TaskStatus = "completed"
TaskStatusFailed TaskStatus = "failed"
TaskStatusCancelled TaskStatus = "cancelled"
)
// CreateTaskRequest 创建任务请求
type CreateTaskRequest struct {
Type tasktype.TaskType `json:"type"`
Storage string `json:"storage"`
Path string `json:"path"`
Webhook string `json:"webhook,omitempty"`
Params json.RawMessage `json:"params"`
}
// CreateTaskResponse 创建任务响应
type CreateTaskResponse struct {
TaskID string `json:"task_id"`
Type tasktype.TaskType `json:"type"`
Status TaskStatus `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// TaskProgress 任务进度
type TaskProgress struct {
TotalBytes int64 `json:"total_bytes,omitempty"`
DownloadedBytes int64 `json:"downloaded_bytes,omitempty"`
Percent float64 `json:"percent,omitempty"`
SpeedMBPS float64 `json:"speed_mbps,omitempty"`
}
// TaskInfoResponse 任务信息响应
type TaskInfoResponse struct {
TaskID string `json:"task_id"`
Type tasktype.TaskType `json:"type"`
Status TaskStatus `json:"status"`
Title string `json:"title"`
Progress *TaskProgress `json:"progress,omitempty"`
Storage string `json:"storage"`
Path string `json:"path"`
Error string `json:"error,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TasksListResponse 任务列表响应
type TasksListResponse struct {
Tasks []TaskInfoResponse `json:"tasks"`
Total int `json:"total"`
}
// StoragesResponse 存储列表响应
type StoragesResponse struct {
Storages []StorageInfo `json:"storages"`
}
// StorageInfo 存储信息
type StorageInfo struct {
Name string `json:"name"`
Type string `json:"type"`
}
// WebhookPayload Webhook 回调负载
type WebhookPayload struct {
TaskID string `json:"task_id"`
Type string `json:"type"`
Status TaskStatus `json:"status"`
Storage string `json:"storage"`
Path string `json:"path"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
Error string `json:"error,omitempty"`
}
// ErrorResponse 错误响应
type ErrorResponse struct {
Error string `json:"error"`
Message string `json:"message,omitempty"`
}
// APIError API 错误
type APIError struct {
StatusCode int
ErrorCode string
Message string
}
func (e *APIError) Error() string {
return e.Message
}
// WriteJSON 写入 JSON 响应
func WriteJSON(w http.ResponseWriter, statusCode int, data any) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
return json.NewEncoder(w).Encode(data)
}
// WriteError 写入错误响应
func WriteError(w http.ResponseWriter, statusCode int, errCode, message string) error {
return WriteJSON(w, statusCode, ErrorResponse{
Error: errCode,
Message: message,
})
}
// Task 参数结构体
// DirectLinksParams directlinks 任务参数
type DirectLinksParams struct {
URLs []string `json:"urls"`
}
// YTDLPParams ytdlp 任务参数
type YTDLPParams struct {
URLs []string `json:"urls"`
Flags []string `json:"flags,omitempty"`
}
// Aria2Params aria2 任务参数
type Aria2Params struct {
URLs []string `json:"urls"`
Options map[string]string `json:"options,omitempty"`
}
// ParsedParams parsed 任务参数
type ParsedParams struct {
URL string `json:"url"`
}
// TransferParams transfer 任务参数
type TransferParams struct {
SourceStorage string `json:"source_storage"`
SourcePath string `json:"source_path"`
TargetStorage string `json:"target_storage"`
TargetPath string `json:"target_path"`
}
// TGFilesParams tgfiles 任务参数
type TGFilesParams struct {
MessageLinks []string `json:"message_links"`
}
// TPHPicsParams tphpics 任务参数
type TPHPicsParams struct {
TelegraphURL string `json:"telegraph_url"`
}

130
api/webhook.go Normal file
View File

@@ -0,0 +1,130 @@
package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/charmbracelet/log"
)
// webhookClient Webhook 客户端
var webhookClient = &http.Client{
Timeout: 30 * time.Second,
}
// SendWebhook 发送 Webhook 回调
func SendWebhook(ctx context.Context, payload *WebhookPayload) {
if payload == nil || payload.TaskID == "" {
return
}
// 获取任务信息以获取 webhook URL
info, ok := GetTask(payload.TaskID)
if !ok || info.Webhook == "" {
return
}
webhookURL := info.Webhook
// 异步发送 webhook
go func() {
logger := log.FromContext(ctx).With("task_id", payload.TaskID)
payloadBytes, err := json.Marshal(payload)
if err != nil {
logger.Errorf("Failed to marshal webhook payload: %v", err)
return
}
// 重试 3 次
for i := range 3 {
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
logger.Errorf("Failed to create webhook request: %v", err)
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "SaveAny-Bot/1.0")
resp, err := webhookClient.Do(req)
if err != nil {
logger.Warnf("Webhook request failed (attempt %d/3): %v", i+1, err)
time.Sleep(time.Second * time.Duration(i+1))
continue
}
resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
logger.Debugf("Webhook sent successfully: %s", webhookURL)
return
}
logger.Warnf("Webhook returned non-2xx status (attempt %d/3): %d", i+1, resp.StatusCode)
time.Sleep(time.Second * time.Duration(i+1))
}
logger.Errorf("Failed to send webhook after 3 attempts")
}()
}
// CreateWebhookPayload 创建 Webhook 负载
func CreateWebhookPayload(taskID string, taskType string, status TaskStatus, storage, path string, err error) *WebhookPayload {
payload := &WebhookPayload{
TaskID: taskID,
Type: taskType,
Status: status,
Storage: storage,
Path: path,
}
if status == TaskStatusCompleted || status == TaskStatusFailed {
now := time.Now()
payload.CompletedAt = &now
}
if err != nil {
payload.Error = err.Error()
}
return payload
}
// WrapTaskWithWebhook 包装任务执行,添加 webhook 回调
func WrapTaskWithWebhook(ctx context.Context, taskID string, fn func() error) error {
info, ok := GetTask(taskID)
if !ok {
return fmt.Errorf("task not found: %s", taskID)
}
err := fn()
// 确定任务状态
status := TaskStatusCompleted
if err != nil {
if err == context.Canceled {
status = TaskStatusCancelled
} else {
status = TaskStatusFailed
}
}
// 更新任务状态
if err != nil {
info.SetError(err.Error())
} else {
info.UpdateStatus(TaskStatusCompleted)
}
// 发送 webhook
if info.Webhook != "" {
payload := CreateWebhookPayload(taskID, info.Type, status, info.Storage, info.Path, err)
SendWebhook(ctx, payload)
}
return err
}