Remove IP whitelist mechanism, keep only token authentication
Co-authored-by: krau <71133316+krau@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -10,7 +9,7 @@ import (
|
||||
"github.com/krau/SaveAny-Bot/config"
|
||||
)
|
||||
|
||||
// authMiddleware validates API token and IP restrictions
|
||||
// authMiddleware validates API token
|
||||
func authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip auth for health check
|
||||
@@ -21,15 +20,6 @@ func authMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
cfg := config.C()
|
||||
|
||||
// Check IP whitelist if configured
|
||||
if len(cfg.API.TrustedIPs) > 0 {
|
||||
clientIP := getClientIP(r)
|
||||
if !isIPAllowed(clientIP, cfg.API.TrustedIPs) {
|
||||
http.Error(w, `{"error":"forbidden: IP not allowed"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check token if configured
|
||||
if cfg.API.Token != "" {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
@@ -74,52 +64,3 @@ func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// getClientIP extracts the real client IP from the request
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
// If SplitHostPort fails, try to parse RemoteAddr as IP directly
|
||||
if parsedIP := net.ParseIP(r.RemoteAddr); parsedIP != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
// If all else fails, return empty string (will fail IP check)
|
||||
return ""
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isIPAllowed checks if the client IP is in the allowed list
|
||||
func isIPAllowed(clientIP string, allowedIPs []string) bool {
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if clientIP == allowedIP || allowedIP == "*" {
|
||||
return true
|
||||
}
|
||||
// Support CIDR notation
|
||||
if strings.Contains(allowedIP, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip != nil && ipNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -6,68 +6,6 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsIPAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientIP string
|
||||
allowedIPs []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
clientIP: "192.168.1.100",
|
||||
allowedIPs: []string{"192.168.1.100"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
clientIP: "192.168.1.100",
|
||||
allowedIPs: []string{"192.168.1.101"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard",
|
||||
clientIP: "192.168.1.100",
|
||||
allowedIPs: []string{"*"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR match",
|
||||
clientIP: "192.168.1.100",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR no match",
|
||||
clientIP: "192.168.2.100",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple IPs with match",
|
||||
clientIP: "192.168.1.100",
|
||||
allowedIPs: []string{"10.0.0.1", "192.168.1.100", "172.16.0.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
clientIP: "127.0.0.1",
|
||||
allowedIPs: []string{"127.0.0.1"},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isIPAllowed(tt.clientIP, tt.allowedIPs)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isIPAllowed(%q, %v) = %v, want %v",
|
||||
tt.clientIP, tt.allowedIPs, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_NoAuth(t *testing.T) {
|
||||
// Create a test handler
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -107,62 +45,3 @@ func TestAuthMiddleware_HealthCheck(t *testing.T) {
|
||||
t.Errorf("Expected status 200 for health check, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr only",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
expectedIP: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For single",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For multiple",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
xForwardedFor: "10.0.0.1, 10.0.0.2, 10.0.0.3",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
xRealIP: "10.0.0.1",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For takes precedence",
|
||||
remoteAddr: "192.168.1.100:12345",
|
||||
xForwardedFor: "10.0.0.1",
|
||||
xRealIP: "10.0.0.2",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
if tt.xRealIP != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||
}
|
||||
|
||||
result := getClientIP(req)
|
||||
if result != tt.expectedIP {
|
||||
t.Errorf("getClientIP() = %q, want %q", result, tt.expectedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,8 +39,6 @@ port = 8080
|
||||
token = ""
|
||||
# 任务完成回调 Webhook URL (留空则不回调)
|
||||
webhook_url = ""
|
||||
# 可信任的 IP 地址列表 (留空则不限制), 支持单个 IP 或 CIDR 格式
|
||||
# trusted_ips = ["127.0.0.1", "192.168.1.0/24"]
|
||||
|
||||
# 存储列表
|
||||
[[storages]]
|
||||
|
||||
@@ -44,11 +44,10 @@ type aria2Config struct {
|
||||
}
|
||||
|
||||
type apiConfig struct {
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
Port int `toml:"port" mapstructure:"port" json:"port"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
WebhookURL string `toml:"webhook_url" mapstructure:"webhook_url" json:"webhook_url"`
|
||||
TrustedIPs []string `toml:"trusted_ips" mapstructure:"trusted_ips" json:"trusted_ips"`
|
||||
Enable bool `toml:"enable" mapstructure:"enable" json:"enable"`
|
||||
Port int `toml:"port" mapstructure:"port" json:"port"`
|
||||
Token string `toml:"token" mapstructure:"token" json:"token"`
|
||||
WebhookURL string `toml:"webhook_url" mapstructure:"webhook_url" json:"webhook_url"`
|
||||
}
|
||||
|
||||
var cfg = &Config{}
|
||||
|
||||
@@ -21,8 +21,6 @@ port = 8080
|
||||
token = "your-secret-token-here"
|
||||
# Task completion callback webhook URL (leave empty to disable)
|
||||
webhook_url = "https://your-server.com/webhook"
|
||||
# Trusted IP addresses (leave empty to allow all), supports single IP or CIDR notation
|
||||
trusted_ips = ["127.0.0.1", "192.168.1.0/24"]
|
||||
```
|
||||
|
||||
## Authentication
|
||||
@@ -33,8 +31,6 @@ If `token` is configured, all API requests (except `/health`) must include an `A
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
If `trusted_ips` is configured, requests will only be accepted from specified IP addresses.
|
||||
|
||||
## Endpoints
|
||||
|
||||
### Health Check
|
||||
@@ -256,7 +252,6 @@ print(f"Task status: {status['status']}")
|
||||
## Security Recommendations
|
||||
|
||||
1. **Always use a strong token** for production environments
|
||||
2. **Enable IP whitelist** (`trusted_ips`) to restrict access
|
||||
3. **Use HTTPS** in production by placing the API behind a reverse proxy (e.g., Nginx, Caddy)
|
||||
4. **Keep logs secure** as they may contain sensitive information
|
||||
5. **Validate user permissions** - ensure `user_id` in requests corresponds to authorized users in your config
|
||||
2. **Use HTTPS** in production by placing the API behind a reverse proxy (e.g., Nginx, Caddy)
|
||||
3. **Keep logs secure** as they may contain sensitive information
|
||||
4. **Validate user permissions** - ensure `user_id` in requests corresponds to authorized users in your config
|
||||
|
||||
@@ -21,8 +21,6 @@ port = 8080
|
||||
token = "your-secret-token-here"
|
||||
# 任务完成回调 Webhook URL (留空则不回调)
|
||||
webhook_url = "https://your-server.com/webhook"
|
||||
# 可信任的 IP 地址列表 (留空则不限制), 支持单个 IP 或 CIDR 格式
|
||||
trusted_ips = ["127.0.0.1", "192.168.1.0/24"]
|
||||
```
|
||||
|
||||
## 认证
|
||||
@@ -33,8 +31,6 @@ trusted_ips = ["127.0.0.1", "192.168.1.0/24"]
|
||||
Authorization: Bearer your-secret-token-here
|
||||
```
|
||||
|
||||
如果配置了 `trusted_ips`,请求只会从指定的 IP 地址被接受。
|
||||
|
||||
## 端点
|
||||
|
||||
### 健康检查
|
||||
@@ -256,7 +252,6 @@ print(f"任务状态: {status['status']}")
|
||||
## 安全建议
|
||||
|
||||
1. **生产环境始终使用强令牌**
|
||||
2. **启用 IP 白名单** (`trusted_ips`) 限制访问
|
||||
3. **生产环境使用 HTTPS**,通过反向代理(如 Nginx、Caddy)放置 API
|
||||
4. **保护日志安全**,因为它们可能包含敏感信息
|
||||
5. **验证用户权限** - 确保请求中的 `user_id` 对应于配置中的授权用户
|
||||
2. **生产环境使用 HTTPS**,通过反向代理(如 Nginx、Caddy)放置 API
|
||||
3. **保护日志安全**,因为它们可能包含敏感信息
|
||||
4. **验证用户权限** - 确保请求中的 `user_id` 对应于配置中的授权用户
|
||||
|
||||
Reference in New Issue
Block a user