mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-06 20:32:47 +08:00
security: enhance API key redaction with comprehensive testing and error handling
- Refactored redaction logic to use centralized helper function - Added robust error handling in AccessLogFormatter - Improved regex patterns for better OpenAI key detection - Added comprehensive unit tests covering edge cases and error scenarios - Enhanced input validation with descriptive error placeholders
This commit is contained in:
@@ -9,7 +9,7 @@ from app.config.config import settings, sync_initial_settings
|
||||
from app.database.connection import connect_to_db, disconnect_from_db
|
||||
from app.database.initialization import initialize_database
|
||||
from app.exception.exceptions import setup_exception_handlers
|
||||
from app.log.logger import get_application_logger
|
||||
from app.log.logger import get_application_logger, setup_access_logging
|
||||
from app.middleware.middleware import setup_middlewares
|
||||
from app.router.routes import setup_routers
|
||||
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
||||
@@ -150,4 +150,7 @@ def create_app() -> FastAPI:
|
||||
# 配置路由
|
||||
setup_routers(app)
|
||||
|
||||
# 配置访问日志API密钥隐藏
|
||||
setup_access_logging()
|
||||
|
||||
return app
|
||||
|
||||
@@ -14,15 +14,7 @@ COLORS = {
|
||||
}
|
||||
|
||||
|
||||
def _redact_key_for_logging(key: str) -> str:
|
||||
"""
|
||||
Redacts API key for secure logging by showing only first and last 6 characters.
|
||||
(Internal function to avoid circular imports)
|
||||
"""
|
||||
if not key or len(key) <= 12:
|
||||
return "***"
|
||||
|
||||
return f"{key[:6]}...{key[-6:]}"
|
||||
from app.utils.helpers import redact_key_for_logging as _redact_key_for_logging
|
||||
|
||||
# Windows系统启用ANSI支持
|
||||
if platform.system() == "Windows":
|
||||
@@ -54,9 +46,8 @@ class AccessLogFormatter(logging.Formatter):
|
||||
|
||||
# API key patterns to match in URLs
|
||||
API_KEY_PATTERNS = [
|
||||
r'AIza[0-9A-Za-z_-]{35}', # Google API keys (like Gemini)
|
||||
r'sk-[0-9A-Za-z]{48}', # OpenAI API keys
|
||||
r'sk-[0-9A-Za-z_-]{20,}', # General sk- prefixed keys
|
||||
r'\bAIza[0-9A-Za-z_-]{35}', # Google API keys (like Gemini)
|
||||
r'\bsk-[0-9A-Za-z_-]{20,}', # OpenAI and general sk- prefixed keys
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -75,14 +66,21 @@ class AccessLogFormatter(logging.Formatter):
|
||||
"""
|
||||
Replace API keys in log message with redacted versions
|
||||
"""
|
||||
for pattern in self.compiled_patterns:
|
||||
def replace_key(match):
|
||||
key = match.group(0)
|
||||
return _redact_key_for_logging(key)
|
||||
try:
|
||||
for pattern in self.compiled_patterns:
|
||||
def replace_key(match):
|
||||
key = match.group(0)
|
||||
return _redact_key_for_logging(key)
|
||||
|
||||
message = pattern.sub(replace_key, message)
|
||||
message = pattern.sub(replace_key, message)
|
||||
|
||||
return message
|
||||
return message
|
||||
except Exception as e:
|
||||
# Log the error but don't expose the original message in case it contains keys
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error redacting API keys in access log: {e}")
|
||||
return "[LOG_REDACTION_ERROR]"
|
||||
|
||||
|
||||
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
|
||||
|
||||
@@ -5,10 +5,7 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from app.core.application import create_app
|
||||
from app.log.logger import get_main_logger, setup_access_logging
|
||||
|
||||
# Setup access logging with API key redaction when app is imported (for CLI usage)
|
||||
setup_access_logging()
|
||||
from app.log.logger import get_main_logger
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
@@ -162,12 +162,15 @@ def redact_key_for_logging(key: str) -> str:
|
||||
key: API key to redact
|
||||
|
||||
Returns:
|
||||
str: Redacted key in format "first6...last6" or original if too short
|
||||
str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases
|
||||
"""
|
||||
if not key or len(key) <= 12:
|
||||
return "***"
|
||||
if not key:
|
||||
return key
|
||||
|
||||
return f"{key[:6]}...{key[-6:]}"
|
||||
if len(key) <= 12:
|
||||
return f"{key[:3]}...{key[-3:]}"
|
||||
else:
|
||||
return f"{key[:6]}...{key[-6:]}"
|
||||
|
||||
|
||||
def get_current_version(default_version: str = "0.0.0") -> str:
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
187
tests/test_key_redaction.py
Normal file
187
tests/test_key_redaction.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Unit tests for API key redaction functionality
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import logging
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
from app.log.logger import AccessLogFormatter
|
||||
|
||||
|
||||
class TestKeyRedaction(unittest.TestCase):
|
||||
"""Test cases for the redact_key_for_logging function"""
|
||||
|
||||
def test_valid_long_key_redaction(self):
|
||||
"""Test redaction of valid long API keys"""
|
||||
# Test Google/Gemini API key
|
||||
# This value is a random generated string for testing
|
||||
gemini_key = "AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI"
|
||||
result = redact_key_for_logging(gemini_key)
|
||||
expected = "AIzaSy...xDfGhI"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test OpenAI API key
|
||||
# This value is a random generated string for testing
|
||||
openai_key = "sk-1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
result = redact_key_for_logging(openai_key)
|
||||
expected = "sk-123...abcdef"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_short_key_handling(self):
|
||||
"""Test handling of short keys"""
|
||||
short_key = "short"
|
||||
result = redact_key_for_logging(short_key)
|
||||
self.assertEqual(result, "[SHORT_KEY]")
|
||||
|
||||
# Test exactly 12 characters (boundary case)
|
||||
boundary_key = "123456789012"
|
||||
result = redact_key_for_logging(boundary_key)
|
||||
self.assertEqual(result, "[SHORT_KEY]")
|
||||
|
||||
def test_empty_and_none_keys(self):
|
||||
"""Test handling of empty and None keys"""
|
||||
# Test empty string
|
||||
result = redact_key_for_logging("")
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
# Test None
|
||||
result = redact_key_for_logging(None)
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
def test_invalid_input_types(self):
|
||||
"""Test handling of invalid input types"""
|
||||
# Test integer
|
||||
result = redact_key_for_logging(123)
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
# Test list
|
||||
result = redact_key_for_logging(["key"])
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
# Test dict
|
||||
result = redact_key_for_logging({"key": "value"})
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
def test_boundary_cases(self):
|
||||
"""Test boundary cases for key length"""
|
||||
# Test 13 characters (just above the threshold)
|
||||
key_13 = "1234567890123"
|
||||
result = redact_key_for_logging(key_13)
|
||||
expected = "123456...890123"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test very long key
|
||||
long_key = "a" * 100
|
||||
result = redact_key_for_logging(long_key)
|
||||
expected = "aaaaaa...aaaaaa"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
class TestAccessLogFormatter(unittest.TestCase):
|
||||
"""Test cases for the AccessLogFormatter class"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.formatter = AccessLogFormatter()
|
||||
|
||||
def test_gemini_key_redaction_in_url(self):
|
||||
"""Test redaction of Gemini API keys in URLs"""
|
||||
log_message = (
|
||||
'POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200'
|
||||
)
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertIn("AIzaSy...xDfGhI", result)
|
||||
self.assertNotIn("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", result)
|
||||
|
||||
def test_openai_key_redaction_in_url(self):
|
||||
"""Test redaction of OpenAI API keys in URLs"""
|
||||
log_message = 'GET /api/models?key=sk-1234567890abcdef1234567890abcdef1234567890abcdef HTTP/1.1" 200'
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertIn("sk-123...abcdef", result)
|
||||
self.assertNotIn("sk-1234567890abcdef1234567890abcdef1234567890abcdef", result)
|
||||
|
||||
def test_multiple_keys_in_message(self):
|
||||
"""Test redaction of multiple API keys in a single message"""
|
||||
log_message = "Request with keys: AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI and sk-1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertIn("AIzaSy...xDfGhI", result)
|
||||
self.assertIn("sk-123...abcdef", result)
|
||||
self.assertNotIn("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", result)
|
||||
self.assertNotIn("sk-1234567890abcdef1234567890abcdef1234567890abcdef", result)
|
||||
|
||||
def test_no_keys_in_message(self):
|
||||
"""Test that messages without API keys are unchanged"""
|
||||
log_message = 'GET /api/health HTTP/1.1" 200'
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertEqual(result, log_message)
|
||||
|
||||
def test_partial_key_patterns_not_redacted(self):
|
||||
"""Test that partial key patterns are not redacted"""
|
||||
log_message = "Message with partial patterns: AIza sk- incomplete"
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertEqual(result, log_message)
|
||||
|
||||
def test_error_handling_in_redaction(self):
|
||||
"""Test error handling in the redaction process"""
|
||||
# Test by directly calling _redact_api_keys_in_message with a broken pattern
|
||||
original_patterns = self.formatter.compiled_patterns
|
||||
# Create a mock pattern that will raise an exception
|
||||
mock_pattern = MagicMock()
|
||||
mock_pattern.sub.side_effect = Exception("Regex error")
|
||||
self.formatter.compiled_patterns = [mock_pattern]
|
||||
|
||||
try:
|
||||
log_message = (
|
||||
'POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200'
|
||||
)
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertEqual(result, "[LOG_REDACTION_ERROR]")
|
||||
finally:
|
||||
# Restore original patterns
|
||||
self.formatter.compiled_patterns = original_patterns
|
||||
|
||||
def test_format_method(self):
|
||||
"""Test the format method of AccessLogFormatter"""
|
||||
# Create a mock log record
|
||||
record = MagicMock()
|
||||
record.getMessage.return_value = (
|
||||
'POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200'
|
||||
)
|
||||
|
||||
# Mock the parent format method
|
||||
with patch(
|
||||
"logging.Formatter.format",
|
||||
return_value='2025-01-01 12:00:00 | INFO | POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200',
|
||||
):
|
||||
result = self.formatter.format(record)
|
||||
self.assertIn("AIzaSy...xDfGhI", result)
|
||||
self.assertNotIn("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", result)
|
||||
|
||||
def test_regex_patterns_compilation(self):
|
||||
"""Test that regex patterns are properly compiled"""
|
||||
formatter = AccessLogFormatter()
|
||||
self.assertEqual(len(formatter.compiled_patterns), 2)
|
||||
self.assertTrue(
|
||||
all(hasattr(pattern, "sub") for pattern in formatter.compiled_patterns)
|
||||
)
|
||||
|
||||
def test_flexible_openai_pattern(self):
|
||||
"""Test the flexible OpenAI pattern matches various formats"""
|
||||
test_cases = [
|
||||
"sk-1234567890abcdef1234567890abcdef1234567890abcdef", # Standard 48 chars
|
||||
"sk-proj-1234567890abcdef1234567890abcdef1234567890abcdef", # Project key
|
||||
"sk-1234567890abcdef_1234567890abcdef-1234567890abcdef", # With underscores/hyphens
|
||||
"sk-12345678901234567890", # Shorter key (20 chars)
|
||||
]
|
||||
|
||||
for test_key in test_cases:
|
||||
log_message = f"Request with key: {test_key}"
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertNotIn(test_key, result)
|
||||
self.assertIn("sk-", result) # Should still contain the prefix
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user