perf(system): async SSRF check with DNS cache for image proxy (#5832)

This commit is contained in:
InfinityPacer
2026-05-25 15:54:02 +08:00
committed by GitHub
parent 98e3ea4e6f
commit 01c3451679
3 changed files with 633 additions and 113 deletions

View File

@@ -2,10 +2,23 @@ import socket
from unittest import TestCase
from unittest.mock import patch
from app.utils.security import SecurityUtils
from app.utils.security import (
SecurityUtils,
_dns_inflight_locks,
_dns_negative_cache,
_dns_positive_cache,
)
class SecurityUtilsTest(TestCase):
def setUp(self) -> None:
    """
    Reset shared DNS state before every test case.

    Empties the positive and negative DNS TTL caches and the in-flight
    lock registry so that no state leaks from one test case into the next.
    """
    for shared_state in (
        _dns_positive_cache,
        _dns_negative_cache,
        _dns_inflight_locks,
    ):
        shared_state.clear()
def test_signed_url_roundtrip_returns_clean_url(self):
"""
URL 签名验证成功后返回不含签名片段的真实请求地址。
@@ -272,3 +285,268 @@ class SecurityUtilsTest(TestCase):
allowed_private_ranges=["198.18.0.0/15"],
)
)
def test_is_safe_url_async_uses_event_loop_resolver(self):
    """
    The async variant performs its SSRF check through the event loop's
    non-blocking ``getaddrinfo`` and keeps the same semantics as the sync
    variant: a host that resolves to a non-public address is rejected.

    The hosts passed to the patched resolver are recorded and verified
    after the call. Because the expected verdict is ``False``, the test
    would otherwise also pass if the resolver were never consulted, or if
    an assertion raised inside the fake were handled by the code under
    test.
    """
    import asyncio

    expected_host = "internal.example.com"
    resolved_hosts = []

    async def fake_getaddrinfo(host, *_args, **_kwargs):
        # Only record here; all checks happen after the call returns.
        resolved_hosts.append(host)
        return [
            (
                socket.AF_INET,
                socket.SOCK_STREAM,
                0,
                "",
                ("127.0.0.1", 0),
            )
        ]

    async def run() -> bool:
        loop = asyncio.get_running_loop()
        with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
            return await SecurityUtils.is_safe_url_async(
                "http://internal.example.com/secret.png",
                {"example.com"},
                block_private=True,
            )

    self.assertFalse(asyncio.run(run()))
    self.assertTrue(
        resolved_hosts,
        "the event loop resolver was never invoked",
    )
    self.assertEqual(set(resolved_hosts), {expected_host})
def test_is_safe_url_async_hits_dns_cache(self):
    """
    The async and sync variants share the DNS TTL cache: once the sync
    path has warmed it, the async variant must not issue a DNS query.
    """
    import asyncio

    public_record = (
        socket.AF_INET,
        socket.SOCK_STREAM,
        0,
        "",
        ("93.184.216.34", 0),
    )

    # Warm the cache through the synchronous path first.
    sync_resolver = patch(
        "app.utils.security.socket.getaddrinfo",
        return_value=[public_record],
    )
    with sync_resolver:
        warmed = SecurityUtils.is_safe_url(
            "https://assets.example.com/poster.jpg",
            {"example.com"},
            block_private=True,
        )
        self.assertTrue(warmed)

    async def query_with_resolver_forbidden() -> bool:
        # Any call to the loop resolver raises, so a cache miss is visible.
        forbidden_resolver = patch.object(
            asyncio.get_running_loop(),
            "getaddrinfo",
            side_effect=AssertionError("缓存命中后不应再次发起 DNS 查询"),
        )
        with forbidden_resolver:
            return await SecurityUtils.is_safe_url_async(
                "https://assets.example.com/poster.jpg",
                {"example.com"},
                block_private=True,
            )

    verdict = asyncio.run(query_with_resolver_forbidden())
    self.assertTrue(verdict)
def test_is_safe_url_async_allows_public_dns_result(self):
    """
    The async variant allows a URL whose host resolves only to public
    addresses and matches the allowlist.

    The hosts passed to the patched resolver are recorded and verified
    after the call, so the test cannot pass via a real network lookup
    that bypasses the patched event loop resolver.
    """
    import asyncio

    expected_host = "assets.example.com"
    resolved_hosts = []

    async def fake_getaddrinfo(host, *_args, **_kwargs):
        # Only record here; all checks happen after the call returns.
        resolved_hosts.append(host)
        return [
            (
                socket.AF_INET,
                socket.SOCK_STREAM,
                0,
                "",
                ("93.184.216.34", 0),
            )
        ]

    async def run() -> bool:
        loop = asyncio.get_running_loop()
        with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
            return await SecurityUtils.is_safe_url_async(
                "https://assets.example.com/poster.jpg",
                {"example.com"},
                block_private=True,
            )

    self.assertTrue(asyncio.run(run()))
    self.assertTrue(
        resolved_hosts,
        "the event loop resolver was never invoked",
    )
    self.assertEqual(set(resolved_hosts), {expected_host})
def test_dns_resolution_failure_populates_negative_cache(self):
    """
    A failed DNS resolution must be recorded in the negative cache so the
    same target does not trigger ``getaddrinfo`` again shortly afterwards.
    """
    from app.utils.security import _dns_negative_cache as neg_cache

    failing_resolver = patch(
        "app.utils.security.socket.getaddrinfo",
        side_effect=socket.gaierror,
    )
    with failing_resolver as mock_resolve:
        # First lookup: the resolver is consulted once and fails.
        first_verdict = SecurityUtils.is_safe_url(
            "https://assets.example.com/poster.jpg",
            {"example.com"},
            block_private=True,
        )
        self.assertFalse(first_verdict)
        self.assertEqual(mock_resolve.call_count, 1)
        self.assertIn("assets.example.com", neg_cache)

        # Second lookup for the same host (different path): still
        # rejected, and the resolver call count must not grow.
        second_verdict = SecurityUtils.is_safe_url(
            "https://assets.example.com/another.jpg",
            {"example.com"},
            block_private=True,
        )
        self.assertFalse(second_verdict)
        self.assertEqual(
            mock_resolve.call_count,
            1,
            "命中负向缓存后不应再次调用 getaddrinfo",
        )
def test_literal_ip_skips_dns_cache(self):
    """
    A literal IP in the URL takes the fast path: it must neither enter
    the DNS caches nor trigger ``getaddrinfo``.
    """
    from app.utils.security import (
        _dns_negative_cache as neg_cache,
        _dns_positive_cache as pos_cache,
    )

    # Any resolver call raises, making an unexpected lookup visible.
    forbidden_resolver = patch(
        "app.utils.security.socket.getaddrinfo",
        side_effect=AssertionError("字面量 IP 不应触发 getaddrinfo"),
    )
    with forbidden_resolver:
        verdict = SecurityUtils.is_safe_url(
            "http://10.0.0.5:8080/secret.png",
            {"http://10.0.0.5:8080"},
            block_private=True,
        )
        self.assertFalse(verdict)

    for dns_cache in (pos_cache, neg_cache):
        self.assertNotIn("10.0.0.5", dns_cache)
def test_literal_ipv6_in_brackets_is_recognized(self):
    """
    ``urlparse`` already strips the brackets from IPv6 literals;
    ``_literal_ip`` additionally accepts the bracketed form when it is
    passed in directly, and returns ``None`` for a non-IP string.
    """
    # Bracketed form first, then bare form, matching the original order.
    for raw_host in ("[::1]", "::1"):
        parsed = SecurityUtils._literal_ip(raw_host)
        self.assertEqual(str(parsed), "::1")

    not_an_address = SecurityUtils._literal_ip("not-an-ip")
    self.assertIsNone(not_an_address)
def test_is_safe_url_async_dedupes_concurrent_inflight_queries(self):
    """
    Concurrent cache-missing requests for the same hostname must be
    de-duplicated through the in-flight lock so that only a single DNS
    query is issued.
    """
    import asyncio

    call_count = 0

    async def run() -> None:
        gate = asyncio.Event()

        async def slow_getaddrinfo(host, *_args, **_kwargs):
            nonlocal call_count
            call_count += 1
            # Hold the resolver open until the test releases the gate.
            await gate.wait()
            return [
                (
                    socket.AF_INET,
                    socket.SOCK_STREAM,
                    0,
                    "",
                    ("93.184.216.34", 0),
                )
            ]

        def launch_check():
            return asyncio.create_task(
                SecurityUtils.is_safe_url_async(
                    "https://assets.example.com/poster.jpg",
                    {"example.com"},
                    block_private=True,
                )
            )

        running_loop = asyncio.get_running_loop()
        with patch.object(
            running_loop, "getaddrinfo", side_effect=slow_getaddrinfo
        ):
            pending = [launch_check() for _ in range(5)]
            # Yield twice so every task can reach its in-flight wait.
            for _ in range(2):
                await asyncio.sleep(0)
            gate.set()
            results = await asyncio.gather(*pending)

        self.assertTrue(all(results))
        self.assertEqual(call_count, 1, "并发未命中应去重为单次 DNS 查询")

    asyncio.run(run())
def test_sync_cache_access_is_thread_safe(self):
    """
    Threads hitting the DNS cache concurrently through the sync path must
    not raise and must not observe inconsistent results.

    Per the original author's note, TTLCache itself is not thread-safe
    and the implementation relies on the module-level ``_dns_cache_lock``
    to serialize reads and writes (the lock is not visible in this test).
    """
    import threading

    thread_count = 8
    calls_per_thread = 50
    public_record = (
        socket.AF_INET,
        socket.SOCK_STREAM,
        0,
        "",
        ("93.184.216.34", 0),
    )
    results: list = []
    errors: list = []

    def worker() -> None:
        try:
            for _ in range(calls_per_thread):
                verdict = SecurityUtils.is_safe_url(
                    "https://assets.example.com/poster.jpg",
                    {"example.com"},
                    block_private=True,
                )
                results.append(verdict)
        except Exception as exc:  # noqa: BLE001 - the test must capture any exception
            errors.append(exc)

    resolver_patch = patch(
        "app.utils.security.socket.getaddrinfo",
        return_value=[public_record],
    )
    with resolver_patch:
        threads = [
            threading.Thread(target=worker) for _ in range(thread_count)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    self.assertEqual(errors, [])
    self.assertTrue(all(results))
    self.assertEqual(len(results), thread_count * calls_per_thread)