- Added Handlers struct and methods for task operations
- Implemented task progress tracking and storage
- Created server setup with logging and recovery middleware
- Added support for Telegram file extraction and Telegraph image extraction
- Introduced webhook functionality for task status updates
- Defined request and response types for API interactions
223 lines
5.9 KiB
Go
223 lines
5.9 KiB
Go
package api
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"sort"
	"strings"

	"github.com/krau/SaveAny-Bot/core"
	"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
	"github.com/krau/SaveAny-Bot/storage"
)
|
|
|
|
// Handlers 处理器结构体
|
|
type Handlers struct {
|
|
factory *TaskFactory
|
|
}
|
|
|
|
// NewHandlers 创建处理器
|
|
func NewHandlers(factory *TaskFactory) *Handlers {
|
|
return &Handlers{factory: factory}
|
|
}
|
|
|
|
// CreateTaskHandler 创建任务处理器
|
|
func (h *Handlers) CreateTaskHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only POST method is allowed")
|
|
return
|
|
}
|
|
|
|
var req CreateTaskRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
WriteError(w, http.StatusBadRequest, "invalid_request", "failed to decode request body: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// 验证请求
|
|
if req.Type == "" {
|
|
WriteError(w, http.StatusBadRequest, "invalid_request", "task type is required")
|
|
return
|
|
}
|
|
|
|
if req.Storage == "" {
|
|
WriteError(w, http.StatusBadRequest, "invalid_request", "storage is required")
|
|
return
|
|
}
|
|
|
|
// 创建任务
|
|
resp, err := h.factory.CreateTask(&req)
|
|
if err != nil {
|
|
WriteError(w, http.StatusBadRequest, "task_creation_failed", err.Error())
|
|
return
|
|
}
|
|
|
|
WriteJSON(w, http.StatusCreated, resp)
|
|
}
|
|
|
|
// ListTasksHandler 列出任务处理器
|
|
func (h *Handlers) ListTasksHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
|
|
return
|
|
}
|
|
|
|
tasks := GetAllTasks()
|
|
response := make([]TaskInfoResponse, 0, len(tasks))
|
|
|
|
for _, task := range tasks {
|
|
info := convertTaskProgressToResponse(task)
|
|
response = append(response, info)
|
|
}
|
|
|
|
WriteJSON(w, http.StatusOK, TasksListResponse{
|
|
Tasks: response,
|
|
Total: len(response),
|
|
})
|
|
}
|
|
|
|
// GetTaskHandler 获取单个任务处理器
|
|
func (h *Handlers) GetTaskHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
|
|
return
|
|
}
|
|
|
|
taskID := extractTaskIDFromPath(r.URL.Path)
|
|
if taskID == "" {
|
|
WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
|
|
return
|
|
}
|
|
|
|
task, ok := GetTask(taskID)
|
|
if !ok {
|
|
WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
|
|
return
|
|
}
|
|
|
|
resp := convertTaskProgressToResponse(task)
|
|
WriteJSON(w, http.StatusOK, resp)
|
|
}
|
|
|
|
// CancelTaskHandler 取消任务处理器
|
|
func (h *Handlers) CancelTaskHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodDelete {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only DELETE method is allowed")
|
|
return
|
|
}
|
|
|
|
taskID := extractTaskIDFromPath(r.URL.Path)
|
|
if taskID == "" {
|
|
WriteError(w, http.StatusBadRequest, "invalid_request", "task ID is required")
|
|
return
|
|
}
|
|
|
|
task, ok := GetTask(taskID)
|
|
if !ok {
|
|
WriteError(w, http.StatusNotFound, "task_not_found", "task not found: "+taskID)
|
|
return
|
|
}
|
|
|
|
// 取消任务
|
|
if err := core.CancelTask(r.Context(), taskID); err != nil {
|
|
WriteError(w, http.StatusInternalServerError, "cancel_failed", "failed to cancel task: "+err.Error())
|
|
return
|
|
}
|
|
|
|
task.UpdateStatus(TaskStatusCancelled)
|
|
WriteJSON(w, http.StatusOK, map[string]string{"message": "task cancelled successfully"})
|
|
}
|
|
|
|
// ListStoragesHandler 列出存储处理器
|
|
func (h *Handlers) ListStoragesHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
|
|
return
|
|
}
|
|
|
|
storages := make([]StorageInfo, 0, len(storage.Storages))
|
|
for name, stor := range storage.Storages {
|
|
storages = append(storages, StorageInfo{
|
|
Name: name,
|
|
Type: string(stor.Type()),
|
|
})
|
|
}
|
|
|
|
WriteJSON(w, http.StatusOK, StoragesResponse{Storages: storages})
|
|
}
|
|
|
|
// GetTaskTypesHandler 获取支持的任务类型
|
|
func (h *Handlers) GetTaskTypesHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET method is allowed")
|
|
return
|
|
}
|
|
|
|
types := []tasktype.TaskType{
|
|
tasktype.TaskTypeDirectlinks,
|
|
tasktype.TaskTypeYtdlp,
|
|
tasktype.TaskTypeAria2,
|
|
tasktype.TaskTypeParseditem,
|
|
tasktype.TaskTypeTgfiles,
|
|
tasktype.TaskTypeTphpics,
|
|
tasktype.TaskTypeTransfer,
|
|
}
|
|
|
|
WriteJSON(w, http.StatusOK, map[string]any{
|
|
"types": types,
|
|
})
|
|
}
|
|
|
|
// HealthCheckHandler 健康检查处理器
|
|
func (h *Handlers) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
|
WriteJSON(w, http.StatusOK, map[string]string{
|
|
"status": "ok",
|
|
})
|
|
}
|
|
|
|
// extractTaskIDFromPath returns the task ID from a request path of the form
// /api/v1/tasks/:id.
//
// Leading and trailing slashes are ignored and the fourth slash-separated
// segment is returned. The first three segments are not checked against
// "api/v1/tasks", and any segments after the ID are ignored. An empty
// string is returned when the path has fewer than four segments; an empty
// fourth segment (e.g. "tasks//x") also yields "".
func extractTaskIDFromPath(path string) string {
	rest := strings.Trim(path, "/")

	// Drop the three leading segments (api, v1, tasks) one at a time.
	for skipped := 0; skipped < 3; skipped++ {
		_, after, found := strings.Cut(rest, "/")
		if !found {
			return ""
		}
		rest = after
	}

	id, _, _ := strings.Cut(rest, "/")
	return id
}
|
|
|
|
// convertTaskProgressToResponse 将任务进度转换为响应格式
|
|
func convertTaskProgressToResponse(task *TaskProgressInfo) TaskInfoResponse {
|
|
resp := TaskInfoResponse{
|
|
TaskID: task.TaskID,
|
|
Type: tasktype.TaskType(task.Type),
|
|
Status: task.Status,
|
|
Title: task.Title,
|
|
Storage: task.Storage,
|
|
Path: task.Path,
|
|
Error: task.Error,
|
|
CreatedAt: task.CreatedAt,
|
|
UpdatedAt: task.UpdatedAt,
|
|
}
|
|
|
|
// 计算进度
|
|
if task.TotalBytes > 0 {
|
|
percent := float64(task.DownloadedBytes) * 100 / float64(task.TotalBytes)
|
|
resp.Progress = &TaskProgress{
|
|
TotalBytes: task.TotalBytes,
|
|
DownloadedBytes: task.DownloadedBytes,
|
|
Percent: percent,
|
|
}
|
|
}
|
|
|
|
return resp
|
|
}
|
|
|
|
// NotFoundHandler 404 处理器
|
|
func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
|
|
WriteError(w, http.StatusNotFound, "not_found", "endpoint not found: "+r.URL.Path)
|
|
}
|
|
|
|
// MethodNotAllowedHandler 405 处理器
|
|
func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
|
|
WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed: "+r.Method)
|
|
}
|