diff --git a/api/handlers.go b/api/handlers.go index 017ac62..1d9ae98 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -89,8 +89,8 @@ func handleCreateTask(w http.ResponseWriter, r *http.Request) { respondError(w, "telegram_url is required", http.StatusBadRequest) return } - if req.UserID == 0 { - respondError(w, "user_id is required", http.StatusBadRequest) + if req.UserID <= 0 { + respondError(w, "user_id is required and must be positive", http.StatusBadRequest) return } @@ -329,6 +329,8 @@ func sendWebhook(taskID, status, errorMsg string) { return } + logger := log.WithPrefix("webhook") + payload := TaskStatusResponse{ TaskID: ts.ID, Status: status, @@ -339,7 +341,7 @@ func sendWebhook(taskID, status, errorMsg string) { body, err := json.Marshal(payload) if err != nil { - log.Errorf("Failed to marshal webhook payload: %v", err) + logger.Errorf("Failed to marshal webhook payload: %v", err) return } @@ -348,7 +350,7 @@ func sendWebhook(taskID, status, errorMsg string) { req, err := http.NewRequestWithContext(ctx, "POST", cfg.API.WebhookURL, bytes.NewReader(body)) if err != nil { - log.Errorf("Failed to create webhook request: %v", err) + logger.Errorf("Failed to create webhook request: %v", err) return } @@ -359,7 +361,7 @@ func sendWebhook(taskID, status, errorMsg string) { resp, err := http.DefaultClient.Do(req) if err != nil { - log.Errorf("Failed to send webhook: %v", err) + logger.Errorf("Failed to send webhook: %v", err) return } defer resp.Body.Close() @@ -367,9 +369,9 @@ func sendWebhook(taskID, status, errorMsg string) { if resp.StatusCode >= 400 { body, err := io.ReadAll(resp.Body) if err != nil { - log.Errorf("Webhook returned error status %d, failed to read response body: %v", resp.StatusCode, err) + logger.Errorf("Webhook returned error status %d, failed to read response body: %v", resp.StatusCode, err) } else { - log.Errorf("Webhook returned error status %d: %s", resp.StatusCode, string(body)) + logger.Errorf("Webhook returned error status %d: %s", resp.StatusCode, string(body)) } } } diff --git a/api/middleware.go b/api/middleware.go index feeb533..6545764 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -93,9 +93,12 @@ func getClientIP(r *http.Request) string { // Fall back to RemoteAddr ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { - // If SplitHostPort fails, RemoteAddr might not have a port - // In this case, just return RemoteAddr as is - return r.RemoteAddr + // If SplitHostPort fails, try to parse RemoteAddr as IP directly + if parsedIP := net.ParseIP(r.RemoteAddr); parsedIP != nil { + return r.RemoteAddr + } + // If all else fails, return empty string (will fail IP check) + return "" } return ip } @@ -112,7 +115,8 @@ func isIPAllowed(clientIP string, allowedIPs []string) bool { if err != nil { continue } - if ipNet.Contains(net.ParseIP(clientIP)) { + ip := net.ParseIP(clientIP) + if ip != nil && ipNet.Contains(ip) { return true } } diff --git a/test_api.sh b/test_api.sh index aca1b7c..3c7dc7b 100755 --- a/test_api.sh +++ b/test_api.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash -euo pipefail # API Test Script for SaveAny-Bot HTTP API