Files
SaveAny-Bot/api/handlers_test.go
2026-03-05 19:40:25 +08:00

690 lines
17 KiB
Go

package api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
)
// setupTestServer builds a TaskFactory bound to the test's context and wraps
// it in a Handlers value. Both are returned so callers can use either one.
func setupTestServer(t *testing.T) (*Handlers, *TaskFactory) {
	taskFactory := NewTaskFactory(t.Context())
	return NewHandlers(taskFactory), taskFactory
}
// TestCreateTaskHandler drives CreateTaskHandler with a table of malformed or
// unsatisfiable requests and checks the HTTP status of each response. For
// cases flagged wantErr, the response body must also decode as an
// ErrorResponse.
func TestCreateTaskHandler(t *testing.T) {
	h, _ := setupTestServer(t)

	type testCase struct {
		name       string
		method     string
		body       any // a string is sent verbatim, nil sends an empty body, anything else is JSON-encoded
		wantStatus int
		wantErr    bool
	}
	cases := []testCase{
		{
			name:       "Method not allowed",
			method:     http.MethodGet,
			body:       nil,
			wantStatus: http.StatusMethodNotAllowed,
			wantErr:    true,
		},
		{
			name:       "Invalid JSON body",
			method:     http.MethodPost,
			body:       "invalid json",
			wantStatus: http.StatusBadRequest,
			wantErr:    true,
		},
		{
			name:   "Empty request body",
			method: http.MethodPost,
			body:   CreateTaskRequest{},
			// Expected to be rejected because no task type is set.
			wantStatus: http.StatusBadRequest,
			wantErr:    true,
		},
		{
			name:   "Missing type",
			method: http.MethodPost,
			body: CreateTaskRequest{
				Storage: "test-storage",
				Path:    "downloads",
			},
			wantStatus: http.StatusBadRequest,
			wantErr:    true,
		},
		{
			name:   "Missing storage",
			method: http.MethodPost,
			body: CreateTaskRequest{
				Type: tasktype.TaskTypeDirectlinks,
				Path: "downloads",
			},
			wantStatus: http.StatusBadRequest,
			wantErr:    true,
		},
		{
			name:   "Storage not found",
			method: http.MethodPost,
			body: CreateTaskRequest{
				Type:    tasktype.TaskTypeDirectlinks,
				Storage: "non-existent-storage",
				Path:    "downloads",
				Params:  json.RawMessage(`{"urls":["https://example.com/file.zip"]}`),
			},
			wantStatus: http.StatusBadRequest,
			wantErr:    true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Raw strings bypass JSON encoding so that syntactically invalid
			// payloads reach the handler unchanged.
			var payload []byte
			if raw, isString := tc.body.(string); isString {
				payload = []byte(raw)
			} else if tc.body != nil {
				encoded, err := json.Marshal(tc.body)
				if err != nil {
					t.Fatalf("failed to marshal body: %v", err)
				}
				payload = encoded
			}

			req := httptest.NewRequest(tc.method, "/api/v1/tasks", bytes.NewReader(payload))
			if tc.body != nil {
				req.Header.Set("Content-Type", "application/json")
			}

			rec := httptest.NewRecorder()
			h.CreateTaskHandler(rec, req)

			if got := rec.Code; got != tc.wantStatus {
				t.Errorf("expected status %d, got %d", tc.wantStatus, got)
			}
			if !tc.wantErr {
				return
			}

			// Error cases only need a body that parses as an ErrorResponse.
			var errResp ErrorResponse
			if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil {
				t.Errorf("expected error response, got: %s", rec.Body.String())
			}
		})
	}
}
// TestListTasksHandler checks that ListTasksHandler answers non-GET requests
// with 405 and that a GET yields 200 with a body that decodes into a
// TasksListResponse whose Tasks field is non-nil.
func TestListTasksHandler(t *testing.T) {
	handlers, _ := setupTestServer(t)
	tests := []struct {
		name       string
		method     string
		wantStatus int
	}{
		{
			name:       "Method not allowed",
			method:     http.MethodPost,
			wantStatus: http.StatusMethodNotAllowed,
		},
		{
			name:       "Success",
			method:     http.MethodGet,
			wantStatus: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, "/api/v1/tasks", nil)
			rr := httptest.NewRecorder()
			handlers.ListTasksHandler(rr, req)
			if rr.Code != tt.wantStatus {
				t.Errorf("expected status %d, got %d", tt.wantStatus, rr.Code)
			}
			// Only the GET case has a list payload to inspect.
			if tt.method != http.MethodGet {
				return
			}
			var resp TasksListResponse
			// Abort on a decode failure: resp would still be the zero value,
			// so the nil check below would only add a misleading second error.
			if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
				t.Fatalf("failed to unmarshal response: %v", err)
			}
			if resp.Tasks == nil {
				t.Error("expected non-nil tasks array")
			}
		})
	}
}
// TestGetTaskHandler registers one task and then checks GetTaskHandler's
// responses for a disallowed method, a path without a task ID, an unknown
// task ID, and the registered task ID. In the last case the decoded
// TaskInfoResponse must carry the registered ID.
func TestGetTaskHandler(t *testing.T) {
	handlers, _ := setupTestServer(t)
	// Register a task for the "Task found" case and remove it when the test
	// function returns.
	testTaskID := "test-get-task"
	RegisterTask(testTaskID, "directlinks", "local", "downloads", "Test", "")
	defer DeleteTask(testTaskID)
	tests := []struct {
		name       string
		method     string
		path       string
		wantStatus int
		wantFound  bool
	}{
		{
			name:       "Method not allowed",
			method:     http.MethodPost,
			path:       "/api/v1/tasks/test-id",
			wantStatus: http.StatusMethodNotAllowed,
		},
		{
			name:       "Missing task ID",
			method:     http.MethodGet,
			path:       "/api/v1/tasks",
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "Task not found",
			method:     http.MethodGet,
			path:       "/api/v1/tasks/non-existent-task",
			wantStatus: http.StatusNotFound,
		},
		{
			name:       "Task found",
			method:     http.MethodGet,
			path:       "/api/v1/tasks/" + testTaskID,
			wantStatus: http.StatusOK,
			wantFound:  true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, tt.path, nil)
			rr := httptest.NewRecorder()
			handlers.GetTaskHandler(rr, req)
			if rr.Code != tt.wantStatus {
				t.Errorf("expected status %d, got %d", tt.wantStatus, rr.Code)
			}
			if !tt.wantFound {
				return
			}
			var resp TaskInfoResponse
			// Abort on a decode failure: comparing the zero-value TaskID
			// afterwards would only report a second, misleading mismatch.
			if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
				t.Fatalf("failed to unmarshal response: %v", err)
			}
			if resp.TaskID != testTaskID {
				t.Errorf("expected task ID %s, got %s", testTaskID, resp.TaskID)
			}
		})
	}
}
// TestCancelTaskHandler covers CancelTaskHandler's responses to a disallowed
// method, a path without a task ID, and an unknown task ID. The case that
// cancels a registered task is declared but skipped, because it is marked as
// needing an initialized core task queue.
func TestCancelTaskHandler(t *testing.T) {
	h, _ := setupTestServer(t)

	// The registered task gives the skipped "Cancel task" case a real ID.
	const registeredID = "test-cancel-task"
	RegisterTask(registeredID, "directlinks", "local", "downloads", "Test", "")
	defer DeleteTask(registeredID)

	type testCase struct {
		name       string
		method     string
		path       string
		wantStatus int
		skipCore   bool // true when the case needs an initialized core
	}
	cases := []testCase{
		{
			name:       "Method not allowed",
			method:     http.MethodGet,
			path:       "/api/v1/tasks/test-id",
			wantStatus: http.StatusMethodNotAllowed,
		},
		{
			name:       "Missing task ID",
			method:     http.MethodDelete,
			path:       "/api/v1/tasks",
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "Task not found",
			method:     http.MethodDelete,
			path:       "/api/v1/tasks/non-existent-task",
			wantStatus: http.StatusNotFound,
		},
		{
			name:       "Cancel task",
			method:     http.MethodDelete,
			path:       "/api/v1/tasks/" + registeredID,
			wantStatus: http.StatusOK,
			skipCore:   true, // depends on the core task queue being initialized
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.skipCore {
				t.Skip("Skipping test: requires initialized core")
			}

			rec := httptest.NewRecorder()
			h.CancelTaskHandler(rec, httptest.NewRequest(tc.method, tc.path, nil))

			if got := rec.Code; got != tc.wantStatus {
				t.Errorf("expected status %d, got %d", tc.wantStatus, got)
			}
		})
	}
}
// TestListStoragesHandler checks that ListStoragesHandler answers non-GET
// requests with 405 and that a GET yields 200 with a body that decodes into a
// StoragesResponse whose Storages field is non-nil.
func TestListStoragesHandler(t *testing.T) {
	handlers, _ := setupTestServer(t)
	tests := []struct {
		name       string
		method     string
		wantStatus int
	}{
		{
			name:       "Method not allowed",
			method:     http.MethodPost,
			wantStatus: http.StatusMethodNotAllowed,
		},
		{
			name:       "Success",
			method:     http.MethodGet,
			wantStatus: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, "/api/v1/storages", nil)
			rr := httptest.NewRecorder()
			handlers.ListStoragesHandler(rr, req)
			if rr.Code != tt.wantStatus {
				t.Errorf("expected status %d, got %d", tt.wantStatus, rr.Code)
			}
			// Only the GET case has a list payload to inspect.
			if tt.method != http.MethodGet {
				return
			}
			var resp StoragesResponse
			// Abort on a decode failure: resp would still be the zero value,
			// so the nil check below would only add a misleading second error.
			if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
				t.Fatalf("failed to unmarshal response: %v", err)
			}
			if resp.Storages == nil {
				t.Error("expected non-nil storages array")
			}
		})
	}
}
// TestConcurrentProgressStore accesses the task store from many goroutines at
// once, registering, reading and updating the same set of task IDs. It then
// verifies that every registered task can still be looked up.
func TestConcurrentProgressStore(t *testing.T) {
	// Registered as a cleanup, so this runs after the test finishes and
	// removes every stored task whose ID carries this test's prefix.
	t.Cleanup(func() {
		for _, task := range GetAllTasks() {
			if strings.HasPrefix(task.TaskID, "concurrent-test-") {
				DeleteTask(task.TaskID)
			}
		}
	})

	workers := 100
	taskIDFor := func(n int) string {
		return fmt.Sprintf("concurrent-test-%d", n)
	}

	var wg sync.WaitGroup
	// launch starts one goroutine per worker index; each runs op on the task
	// ID derived from its index.
	launch := func(op func(taskID string)) {
		for n := range workers {
			wg.Add(1)
			go func() {
				defer wg.Done()
				op(taskIDFor(n))
			}()
		}
	}

	// Registrations, reads and updates are started back to back without
	// waiting, so they overlap. A read or update may run before its task is
	// registered, which is why the update checks the lookup result.
	launch(func(taskID string) {
		RegisterTask(taskID, "directlinks", "local", "downloads", "Test", "")
	})
	launch(func(taskID string) {
		GetTask(taskID)
	})
	launch(func(taskID string) {
		if info, ok := GetTask(taskID); ok {
			info.UpdateStatus(TaskStatusRunning)
		}
	})
	wg.Wait()

	// Every registration has completed by now, so each ID must resolve.
	for n := range workers {
		taskID := taskIDFor(n)
		if _, found := GetTask(taskID); !found {
			t.Errorf("task %s not found after concurrent operations", taskID)
		}
	}
}
// TestProgressTrackerConcurrentUpdates calls OnProgress on a single tracker
// from many goroutines and checks that the tracker still reports the Running
// status afterwards.
func TestProgressTrackerConcurrentUpdates(t *testing.T) {
	const (
		workers          = 50
		updatesPerWorker = 100
	)

	tracker := NewProgressTracker("concurrent-progress", "directlinks", "local", "downloads", "Test", "")
	tracker.OnStart(10000, 10)

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := range workers {
		go func() {
			defer wg.Done()
			for step := range updatesPerWorker {
				tracker.OnProgress(int64(w*updatesPerWorker+step), step)
			}
		}()
	}
	wg.Wait()

	// Only the status is asserted. The final progress values depend on how
	// the goroutines interleave, so exact numbers cannot be checked reliably.
	if info := tracker.GetInfo(); info.Status != TaskStatusRunning {
		t.Errorf("expected status Running after concurrent updates, got %s", info.Status)
	}
}
// TestTaskFactoryValidation calls TaskFactory.CreateTask directly and checks
// both whether an error is returned and, when an expected fragment is given,
// that the lower-cased error text contains it.
func TestTaskFactoryValidation(t *testing.T) {
	factory := NewTaskFactory(context.Background())

	type testCase struct {
		name    string
		request *CreateTaskRequest
		wantErr bool
		errMsg  string // lower-case fragment expected in the error; empty disables the check
	}
	cases := []testCase{
		{
			name: "Storage not found",
			request: &CreateTaskRequest{
				Type:    tasktype.TaskTypeDirectlinks,
				Storage: "non-existent",
				Path:    "downloads",
				Params:  json.RawMessage(`{"urls":["https://example.com/file.zip"]}`),
			},
			wantErr: true,
			errMsg:  "storage not found",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := factory.CreateTask(tc.request)

			failed := err != nil
			if failed != tc.wantErr {
				t.Errorf("CreateTask() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !failed || tc.errMsg == "" {
				return
			}

			// The comparison is case-insensitive on the error side only.
			lowered := strings.ToLower(err.Error())
			if !strings.Contains(lowered, tc.errMsg) {
				t.Errorf("error message %q does not contain %q", err.Error(), tc.errMsg)
			}
		})
	}
}
// TestEdgeCases groups small checks of unusual inputs: a request with no
// body, a very long task ID, task-ID extraction from awkward paths, and a
// progress tracker created without a webhook.
func TestEdgeCases(t *testing.T) {
	t.Run("Empty request body", func(t *testing.T) {
		h, _ := setupTestServer(t)
		rec := httptest.NewRecorder()
		h.CreateTaskHandler(rec, httptest.NewRequest(http.MethodPost, "/api/v1/tasks", nil))
		if got := rec.Code; got != http.StatusBadRequest {
			t.Errorf("expected %d, got %d", http.StatusBadRequest, got)
		}
	})

	t.Run("Very long task ID in path", func(t *testing.T) {
		h, _ := setupTestServer(t)
		target := "/api/v1/tasks/" + strings.Repeat("a", 1000)
		rec := httptest.NewRecorder()
		h.GetTaskHandler(rec, httptest.NewRequest(http.MethodGet, target, nil))
		if got := rec.Code; got != http.StatusNotFound {
			t.Errorf("expected %d, got %d", http.StatusNotFound, got)
		}
	})

	t.Run("Path with special characters", func(t *testing.T) {
		// The ID is expected back still percent-encoded and cut at the next slash.
		const want = "test%20id"
		if got := extractTaskIDFromPath("/api/v1/tasks/test%20id/with/slashes"); got != want {
			t.Errorf("expected %q, got %q", want, got)
		}
	})

	t.Run("Double slashes in path", func(t *testing.T) {
		// An empty segment right after "tasks/" is expected to yield an empty ID.
		const want = ""
		if got := extractTaskIDFromPath("/api/v1/tasks//task-id"); got != want {
			t.Errorf("expected %q, got %q", want, got)
		}
	})

	t.Run("Progress tracker with empty webhook", func(t *testing.T) {
		info := NewProgressTracker("test", "type", "storage", "path", "title", "").GetInfo()
		if info.Webhook != "" {
			t.Error("expected empty webhook")
		}
	})
}
// TestHealthCheckHandler sends a GET to HealthCheckHandler and expects a 200
// response whose JSON body has "status" set to "ok".
func TestHealthCheckHandler(t *testing.T) {
	h, _ := setupTestServer(t)

	rec := httptest.NewRecorder()
	h.HealthCheckHandler(rec, httptest.NewRequest(http.MethodGet, "/health", nil))

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, got)
	}

	var body map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if status := body["status"]; status != "ok" {
		t.Errorf("expected status 'ok', got %q", status)
	}
}
// TestGetTaskTypesHandler checks that GetTaskTypesHandler answers non-GET
// requests with 405 and that a GET yields 200 with a JSON object containing a
// "types" field.
func TestGetTaskTypesHandler(t *testing.T) {
	handlers, _ := setupTestServer(t)
	tests := []struct {
		name       string
		method     string
		wantStatus int
	}{
		{
			name:       "Method not allowed",
			method:     http.MethodPost,
			wantStatus: http.StatusMethodNotAllowed,
		},
		{
			name:       "Success",
			method:     http.MethodGet,
			wantStatus: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, "/api/v1/task-types", nil)
			rr := httptest.NewRecorder()
			handlers.GetTaskTypesHandler(rr, req)
			if rr.Code != tt.wantStatus {
				t.Errorf("expected status %d, got %d", tt.wantStatus, rr.Code)
			}
			// Only the GET case has a payload to inspect.
			if tt.method != http.MethodGet {
				return
			}
			var resp map[string]any
			// Abort on a decode failure: resp would still be a nil map, so
			// the field lookup below would only add a misleading second error.
			if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
				t.Fatalf("failed to unmarshal response: %v", err)
			}
			if _, ok := resp["types"]; !ok {
				t.Error("expected 'types' field in response")
			}
		})
	}
}
// TestNotFoundHandler expects NotFoundHandler to answer with 404 and an
// ErrorResponse body whose Error field is "not_found".
func TestNotFoundHandler(t *testing.T) {
	rec := httptest.NewRecorder()
	NotFoundHandler(rec, httptest.NewRequest(http.MethodGet, "/non-existent-path", nil))

	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected status %d, got %d", http.StatusNotFound, got)
	}

	var errResp ErrorResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if got := errResp.Error; got != "not_found" {
		t.Errorf("expected error 'not_found', got %q", got)
	}
}
// TestMethodNotAllowedHandler expects MethodNotAllowedHandler to answer with
// 405 and an ErrorResponse body whose Error field is "method_not_allowed".
func TestMethodNotAllowedHandler(t *testing.T) {
	rec := httptest.NewRecorder()
	MethodNotAllowedHandler(rec, httptest.NewRequest(http.MethodPost, "/api/v1/tasks", nil))

	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, got)
	}

	var errResp ErrorResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if got := errResp.Error; got != "method_not_allowed" {
		t.Errorf("expected error 'method_not_allowed', got %q", got)
	}
}
// TestTaskProgressInfoTimeUpdate verifies that calling UpdateStatus moves a
// task's UpdatedAt timestamp forward.
func TestTaskProgressInfoTimeUpdate(t *testing.T) {
	const taskID = "time-test"

	info := RegisterTask(taskID, "directlinks", "local", "downloads", "Test", "")
	defer DeleteTask(taskID)

	before := info.UpdatedAt
	// Pause briefly so the two timestamps cannot coincide.
	time.Sleep(10 * time.Millisecond)
	info.UpdateStatus(TaskStatusRunning)

	if info.UpdatedAt.After(before) {
		return
	}
	t.Error("expected UpdatedAt to be updated")
}
// TestWebhookPayloadWithNilCompletedAt marshals a WebhookPayload whose
// CompletedAt is nil and checks that the resulting JSON object has no
// "completed_at" key.
func TestWebhookPayloadWithNilCompletedAt(t *testing.T) {
	encoded, err := json.Marshal(WebhookPayload{
		TaskID:      "test-id",
		Type:        "directlinks",
		Status:      TaskStatusRunning,
		Storage:     "local",
		Path:        "downloads/file.zip",
		CompletedAt: nil,
		Error:       "",
	})
	if err != nil {
		t.Fatalf("failed to marshal: %v", err)
	}

	// Decode into a generic map so key presence can be tested directly.
	var fields map[string]any
	if err := json.Unmarshal(encoded, &fields); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}

	_, present := fields["completed_at"]
	if present {
		t.Error("expected completed_at to be omitted when nil")
	}
}