Files
SaveAny-Bot/api/handlers.go
krau 3eb3b6e3c8 feat(api): implement task management API with handlers for creating, listing, retrieving, and canceling tasks
- Added Handlers struct and methods for task operations
- Implemented task progress tracking and storage
- Created server setup with middleware for logging and recovery
- Added support for Telegram file extraction and Telegraph image extraction
- Introduced webhook functionality for task status updates
- Defined request and response types for API interactions
2026-03-05 19:11:30 +08:00

223 lines
5.9 KiB
Go

package api
import (
"encoding/json"
"net/http"
"strings"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
"github.com/krau/SaveAny-Bot/storage"
)
// Handlers 处理器结构体
type Handlers struct {
factory *TaskFactory
}
// NewHandlers 创建处理器
func NewHandlers(factory *TaskFactory) *Handlers {
return &Handlers{factory: factory}
}
// CreateTaskHandler 创建任务处理器
func (h *Handlers) CreateTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only POST method is allowed")
return
}
var req CreateTaskRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
return
}
// 验证请求
if req.Type == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "task type is required")
return
}
if req.Storage == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "storage is required")
return
}
// 创建任务
resp, err := h.factory.CreateTask(&req)
if err != nil {
WriteError(w, http.StatusBadRequest, "task_creation_failed", err.Error())
return
}
WriteJSON(w, http.StatusCreated, resp)
}
// ListTasksHandler 列出任务处理器
func (h *Handlers) ListTasksHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
tasks := GetAllTasks()
response := make([]TaskInfoResponse, 0, len(tasks))
for _, task := range tasks {
info := convertTaskProgressToResponse(task)
response = append(response, info)
}
WriteJSON(w, http.StatusOK, TasksListResponse{
Tasks: response,
Total: len(response),
})
}
// GetTaskHandler 获取单个任务处理器
func (h *Handlers) GetTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
taskID := extractTaskIDFromPath(r.URL.Path)
if taskID == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
return
}
task, ok := GetTask(taskID)
if !ok {
WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
return
}
resp := convertTaskProgressToResponse(task)
WriteJSON(w, http.StatusOK, resp)
}
// CancelTaskHandler 取消任务处理器
func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only DELETE method is allowed")
return
}
taskID := extractTaskIDFromPath(r.URL.Path)
if taskID == "" {
WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
return
}
task, ok := GetTask(taskID)
if !ok {
WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
return
}
// 取消任务
if err := core.CancelTask(r.Context(), taskID); err != nil {
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
return
}
task.UpdateStatus(TaskStatusCancelled)
WriteJSON(w, http.StatusOK, map[string]string{"message": "task cancelled successfully"})
}
// ListStoragesHandler 列出存储处理器
func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
storages := make([]StorageInfo, 0, len(storage.Storages))
for name, stor := range storage.Storages {
storages = append(storages, StorageInfo{
Name: name,
Type: string(stor.Type()),
})
}
WriteJSON(w, http.StatusOK, StoragesResponse{Storages: storages})
}
// GetTaskTypesHandler 获取支持的任务类型
func (h *Handlers) GetTaskTypesHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
return
}
types := []tasktype.TaskType{
tasktype.TaskTypeDirectlinks,
tasktype.TaskTypeYtdlp,
tasktype.TaskTypeAria2,
tasktype.TaskTypeParseditem,
tasktype.TaskTypeTgfiles,
tasktype.TaskTypeTphpics,
tasktype.TaskTypeTransfer,
}
WriteJSON(w, http.StatusOK, map[string]any{
"types": types,
})
}
// HealthCheckHandler 健康检查处理器
func (h *Handlers) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
WriteJSON(w, http.StatusOK, map[string]string{
"status": "ok",
})
}
// extractTaskIDFromPath 从路径中提取任务 ID
// 路径格式: /api/v1/tasks/:id
func extractTaskIDFromPath(path string) string {
parts := strings.Split(strings.Trim(path, "/"), "/")
if len(parts) < 4 {
return ""
}
return parts[3]
}
// convertTaskProgressToResponse 将任务进度转换为响应格式
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
resp := TaskInfoResponse{
TaskID: task.TaskID,
Type: tasktype.TaskType(task.Type),
Status: task.Status,
Title: task.Title,
Storage: task.Storage,
Path: task.Path,
Error: task.Error,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
}
// 计算进度
if task.TotalBytes > 0 {
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
resp.Progress = &TaskProgress{
TotalBytes: task.TotalBytes,
DownloadedBytes: task.DownloadedBytes,
Percent: percent,
}
}
return resp
}
// NotFoundHandler 404 处理器
func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, http.StatusNotFound, "not_found", "endpoint not found: "+r.URL.Path)
}
// MethodNotAllowedHandler 405 处理器
func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed: "+r.Method)
}