perf(system): async SSRF check with DNS cache for image proxy (#5832)

This commit is contained in:
InfinityPacer
2026-05-25 15:54:02 +08:00
committed by GitHub
parent 98e3ea4e6f
commit 01c3451679
3 changed files with 633 additions and 113 deletions

View File

@@ -360,7 +360,7 @@ async def fetch_image(
fetch_url = SecurityUtils.strip_url_signature(url)
# 验证URL安全性
if not SecurityUtils.is_safe_url(
if not await SecurityUtils.is_safe_url_async(
url,
allowed_domains,
block_private=True,

View File

@@ -1,18 +1,60 @@
import asyncio
import hmac
import ipaddress
import socket
import threading
import time
from hashlib import sha256
from pathlib import Path
from typing import Iterable, List, Optional, Set, Union
from typing import Dict, Iterable, List, Optional, Set, Union
from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
from anyio import Path as AsyncPath
from cachetools import TTLCache
from app.core.config import settings
from app.log import logger
# DNS 解析结果缓存。
# 正向缓存 TTL 选择 120s短于常见 CDN / fake-ip 的 DNS TTL避免长期持有失效 IP
# 负向缓存 TTL 选择 15s避免临时解析失败把目标长时间拉黑。
_DNS_CACHE_MAXSIZE = 1024
_DNS_CACHE_TTL_POSITIVE = 120
_DNS_CACHE_TTL_NEGATIVE = 15
_dns_positive_cache: "TTLCache[str, List[ipaddress._BaseAddress]]" = TTLCache(
maxsize=_DNS_CACHE_MAXSIZE, ttl=_DNS_CACHE_TTL_POSITIVE
)
_dns_negative_cache: "TTLCache[str, bool]" = TTLCache(
maxsize=_DNS_CACHE_MAXSIZE, ttl=_DNS_CACHE_TTL_NEGATIVE
)
# 同步路径下保护 TTLCache 读写:`cachetools.TTLCache` 本身非线程安全。
# 锁只覆盖缓存读写,不包 `getaddrinfo`,避免把 DNS 查询本身串行化。
_dns_cache_lock = threading.Lock()
# 同 hostname 的并发异步解析去重:同一 hostname 首次未命中时建立锁,
# 后续并发请求 await 同一把锁,避免对同一目标重复发起 `getaddrinfo`。
_dns_inflight_locks: Dict[str, asyncio.Lock] = {}
_dns_inflight_meta_lock = threading.Lock()
def _resolve_addrinfo_to_ips(
address_infos: Iterable,
) -> Optional[List[ipaddress._BaseAddress]]:
"""
将 `socket.getaddrinfo` 返回的结果归一化为 IP 列表。
任一条目无法解析为 IP 即视为异常情况,整体返回 None 让上层按"不安全目标"
处理,避免出现"部分 IP 漏校验"的情况。
"""
addresses: List[ipaddress._BaseAddress] = []
for address_info in address_infos:
try:
addresses.append(ipaddress.ip_address(address_info[4][0]))
except ValueError:
return None
return addresses or None
class SecurityUtils:
_SIGNED_URL_PURPOSE = "image-proxy"
_SIGNED_URL_EXPIRE_SECONDS = 86400
@@ -79,38 +121,176 @@ class SecurityUtils:
logger.debug(f"Error occurred while validating paths: {e}")
return False
@staticmethod
def _literal_ip(hostname: str) -> Optional[ipaddress._BaseAddress]:
"""
若 hostname 是字面量 IP含 IPv6 的 `[::1]` 形式)则返回 IP 对象,否则 None。
"""
if not hostname:
return None
candidate = hostname
if candidate.startswith("[") and candidate.endswith("]"):
candidate = candidate[1:-1]
try:
return ipaddress.ip_address(candidate)
except ValueError:
return None
@staticmethod
def _cache_lookup(hostname: str) -> tuple[bool, Optional[List[ipaddress._BaseAddress]]]:
"""
在 TTL 缓存中查找 hostname返回 (是否命中, 命中值)。
命中值为 `None` 表示命中负向缓存(先前解析失败)。
"""
with _dns_cache_lock:
cached = _dns_positive_cache.get(hostname)
if cached is not None:
return True, cached
if hostname in _dns_negative_cache:
return True, None
return False, None
@staticmethod
def _cache_store(
hostname: str, addresses: Optional[List[ipaddress._BaseAddress]]
) -> None:
"""
将解析结果写入对应的正向/负向缓存。
"""
with _dns_cache_lock:
if addresses is None:
_dns_negative_cache[hostname] = True
else:
_dns_positive_cache[hostname] = addresses
@staticmethod
def _hostname_addresses(hostname: str) -> Optional[List[ipaddress._BaseAddress]]:
"""
同步解析主机名并返回全部 IP 地址,结果走 TTL 缓存。
字面量 IP 直接返回自身DNS 解析失败或结果异常时返回 None由上层按
不安全目标处理。async 调用方应使用 `_hostname_addresses_async`。
"""
if not hostname:
return None
literal = SecurityUtils._literal_ip(hostname)
if literal is not None:
return [literal]
hit, value = SecurityUtils._cache_lookup(hostname)
if hit:
return value
try:
address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
except socket.gaierror:
SecurityUtils._cache_store(hostname, None)
return None
addresses = _resolve_addrinfo_to_ips(address_infos)
SecurityUtils._cache_store(hostname, addresses)
return addresses
@staticmethod
def _get_inflight_lock(hostname: str) -> asyncio.Lock:
"""
取得 hostname 对应的 in-flight 锁,不存在则按需创建。
用 `threading.Lock` 保护字典写入,避免多个事件循环线程并发创建出多把锁
破坏去重语义;锁本身是 `asyncio.Lock`,归属当前事件循环。
"""
with _dns_inflight_meta_lock:
lock = _dns_inflight_locks.get(hostname)
if lock is None:
lock = asyncio.Lock()
_dns_inflight_locks[hostname] = lock
return lock
@staticmethod
def _release_inflight_lock(hostname: str, lock: asyncio.Lock) -> None:
"""
请求结束后清理 in-flight 锁,避免长期持有大量已闲置的 `asyncio.Lock`。
只有当前 lock 仍是字典里登记的那把、且没有其它协程在等待时才删除。
"""
with _dns_inflight_meta_lock:
current = _dns_inflight_locks.get(hostname)
if current is lock and not lock.locked():
_dns_inflight_locks.pop(hostname, None)
@staticmethod
async def _hostname_addresses_async(
hostname: str,
) -> Optional[List[ipaddress._BaseAddress]]:
"""
异步解析主机名并返回全部 IP 地址,与同步版本共用同一份 TTL 缓存。
通过事件循环的默认线程池执行 `getaddrinfo`,不阻塞 asyncio 事件循环;
同 hostname 的并发未命中请求通过 in-flight 锁去重,只发起一次 DNS 查询。
"""
if not hostname:
return None
literal = SecurityUtils._literal_ip(hostname)
if literal is not None:
return [literal]
hit, value = SecurityUtils._cache_lookup(hostname)
if hit:
return value
lock = SecurityUtils._get_inflight_lock(hostname)
async with lock:
# 等到锁后再查一次缓存,前一个持锁者可能已经回填结果
hit, value = SecurityUtils._cache_lookup(hostname)
if hit:
SecurityUtils._release_inflight_lock(hostname, lock)
return value
loop = asyncio.get_running_loop()
try:
address_infos = await loop.getaddrinfo(
hostname, None, type=socket.SOCK_STREAM
)
except socket.gaierror:
SecurityUtils._cache_store(hostname, None)
SecurityUtils._release_inflight_lock(hostname, lock)
return None
addresses = _resolve_addrinfo_to_ips(address_infos)
SecurityUtils._cache_store(hostname, addresses)
SecurityUtils._release_inflight_lock(hostname, lock)
return addresses
@staticmethod
def _addresses_all_global(
addresses: Optional[List[ipaddress._BaseAddress]],
) -> bool:
"""
判断解析结果是否全部为公网地址(空列表/None 视为非公网)。
"""
if not addresses:
return False
return all(address.is_global for address in addresses)
@staticmethod
def _is_global_hostname(hostname: str) -> bool:
"""
判断主机名解析结果是否全部为公网地址。
判断主机名解析结果是否全部为公网地址(同步版本)
图片代理会访问用户可控的 URL这里必须在 allowlist 命中前后都排除
私有、回环、链路本地、保留地址等非公网目标,避免通过 DNS 或字面量 IP
绕过域名白名单访问内网服务。
"""
if not hostname:
return False
try:
return ipaddress.ip_address(hostname).is_global
except ValueError:
pass
return SecurityUtils._addresses_all_global(
SecurityUtils._hostname_addresses(hostname)
)
try:
address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
except socket.gaierror:
return False
if not address_infos:
return False
for address_info in address_infos:
try:
address = ipaddress.ip_address(address_info[4][0])
except ValueError:
return False
if not address.is_global:
return False
return True
@staticmethod
async def _is_global_hostname_async(hostname: str) -> bool:
"""
判断主机名解析结果是否全部为公网地址(异步版本)。语义与 `_is_global_hostname` 一致。
"""
return SecurityUtils._addresses_all_global(
await SecurityUtils._hostname_addresses_async(hostname)
)
@staticmethod
def _parse_ip_networks(ranges: Optional[Iterable[str]]) -> List[ipaddress._BaseNetwork]:
@@ -131,58 +311,22 @@ class SecurityUtils:
return networks
@staticmethod
def _hostname_addresses(hostname: str) -> Optional[List[ipaddress._BaseAddress]]:
"""
解析主机名并返回全部 IP 地址。
字面量 IP 直接返回自身DNS 解析失败或结果异常时返回 None让上层按
不安全目标处理。
"""
if not hostname:
return None
try:
return [ipaddress.ip_address(hostname)]
except ValueError:
pass
try:
address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
except socket.gaierror:
return None
if not address_infos:
return None
addresses = []
for address_info in address_infos:
try:
addresses.append(ipaddress.ip_address(address_info[4][0]))
except ValueError:
return None
return addresses
@staticmethod
def _is_allowed_private_hostname(
hostname: str,
allowed_private_ranges: Optional[Iterable[str]],
def _match_private_addresses(
addresses: Optional[List[ipaddress._BaseAddress]],
networks: List[ipaddress._BaseNetwork],
) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
"""
返回主机名命中的显式允许非公网地址和网段。
在已解析出的地址列表中匹配显式允许非公网网段。
该能力只用于图片代理的受控例外,例如 TUN fake-ip 或内网 CDN。必须由
`is_safe_url` 先完成域名 allowlist 校验后再调用,避免把任意用户 URL
变成 SSRF 绕过入口。
所有解析地址都必须命中至少一个允许网段才放行;只要有一个 IP 落在允许
网段外(或解析结果是全公网),就视为不匹配私网放行规则。
"""
networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
if not networks:
return None
addresses = SecurityUtils._hostname_addresses(hostname)
if not addresses:
if not addresses or not networks:
return None
if all(address.is_global for address in addresses):
return None
matched_networks = []
matched_networks: List[ipaddress._BaseNetwork] = []
for address in addresses:
matched_for_address = [
network for network in networks if address in network
@@ -192,6 +336,40 @@ class SecurityUtils:
matched_networks.extend(matched_for_address)
return addresses, list(dict.fromkeys(matched_networks))
@staticmethod
def _is_allowed_private_hostname(
hostname: str,
allowed_private_ranges: Optional[Iterable[str]],
) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
"""
返回主机名命中的显式允许非公网地址和网段(同步版本)。
该能力只用于图片代理的受控例外,例如 TUN fake-ip 或内网 CDN。必须由
`is_safe_url` 先完成域名 allowlist 校验后再调用,避免把任意用户 URL
变成 SSRF 绕过入口。
"""
networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
if not networks:
return None
return SecurityUtils._match_private_addresses(
SecurityUtils._hostname_addresses(hostname), networks
)
@staticmethod
async def _is_allowed_private_hostname_async(
hostname: str,
allowed_private_ranges: Optional[Iterable[str]],
) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
"""
`_is_allowed_private_hostname` 的异步版本,语义保持一致。
"""
networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
if not networks:
return None
return SecurityUtils._match_private_addresses(
await SecurityUtils._hostname_addresses_async(hostname), networks
)
@staticmethod
def _url_signature_payload(url: str, expires_at: int, purpose: str) -> bytes:
"""
@@ -289,6 +467,72 @@ class SecurityUtils:
return None
return clean_url
@staticmethod
def _check_url_allowlist(
url: str,
allowed_domains: Union[Set[str], List[str]],
strict: bool,
) -> Optional[str]:
"""
执行"协议 + netloc + 域名白名单"前置校验,命中返回 hostname未命中返回 None。
DNS 校验SSRF 防御)由调用方自行接续,本方法不发起 DNS 查询。
"""
try:
parsed_url = urlparse(url)
except Exception as e: # noqa: BLE001 - 任何解析异常都视为不安全 URL
logger.debug(f"Error occurred while validating URL: {e}")
return None
# 如果 URL 没有包含有效的 scheme或者无法从中提取到有效的 netloc则认为该 URL 是无效的
if not parsed_url.scheme or not parsed_url.netloc:
return None
# 仅允许 http 或 https 协议
if parsed_url.scheme not in {"http", "https"}:
return None
# 获取完整的 netloc包括 IP 和端口)并转换为小写
netloc = parsed_url.netloc.lower()
if not netloc:
return None
# 检查每个允许的域名
normalized_allowed = {d.lower() for d in allowed_domains}
domain_allowed = False
for domain in normalized_allowed:
parsed_allowed_url = urlparse(domain)
allowed_netloc = parsed_allowed_url.netloc or parsed_allowed_url.path
if strict:
# 严格模式下,要求完全匹配域名和端口
if netloc == allowed_netloc:
domain_allowed = True
break
else:
# 非严格模式下,允许子域名匹配
if netloc == allowed_netloc or netloc.endswith("." + allowed_netloc):
domain_allowed = True
break
if not domain_allowed:
return None
return parsed_url.hostname or ""
@staticmethod
def _log_private_range_allowed(
url: str,
match: tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]],
) -> None:
"""
记录"图片代理允许访问配置的非公网网段"放行日志,便于运维排查。
"""
addresses, matched_networks = match
logger.debug(
"图片代理允许访问配置的非公网网段: "
f"url={url}, ips={','.join(map(str, addresses))}, "
f"ranges={','.join(map(str, matched_networks))}"
)
@staticmethod
def is_safe_url(
url: str,
@@ -298,7 +542,7 @@ class SecurityUtils:
allowed_private_ranges: Optional[Iterable[str]] = None,
) -> bool:
"""
验证URL是否在允许的域名列表中包括带有端口的域名
验证URL是否在允许的域名列表中包括带有端口的域名(同步版本)
:param url: 需要验证的 URL
:param allowed_domains: 允许的域名集合,域名可以包含端口
@@ -306,57 +550,55 @@ class SecurityUtils:
:param block_private: 是否拦截解析到非公网地址的 URL防止 SSRF
:param allowed_private_ranges: 域名命中后额外允许的非公网 IP/CIDR 网段
:return: 如果URL合法且在允许的域名列表中返回 True否则返回 False
注意:`block_private=True` 时会同步调用 `getaddrinfo`async 上下文请改用
`is_safe_url_async`。
"""
try:
# 解析URL
parsed_url = urlparse(url)
# 如果 URL 没有包含有效的 scheme或者无法从中提取到有效的 netloc则认为该 URL 是无效的
if not parsed_url.scheme or not parsed_url.netloc:
hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
if hostname is None:
return False
# 仅允许 http 或 https 协议
if parsed_url.scheme not in {"http", "https"}:
return False
# 获取完整的 netloc包括 IP 和端口)并转换为小写
netloc = parsed_url.netloc.lower()
if not netloc:
return False
# 检查每个允许的域名
allowed_domains = {d.lower() for d in allowed_domains}
domain_allowed = False
for domain in allowed_domains:
parsed_allowed_url = urlparse(domain)
allowed_netloc = parsed_allowed_url.netloc or parsed_allowed_url.path
if strict:
# 严格模式下,要求完全匹配域名和端口
if netloc == allowed_netloc:
domain_allowed = True
break
else:
# 非严格模式下,允许子域名匹配
if netloc == allowed_netloc or netloc.endswith('.' + allowed_netloc):
domain_allowed = True
break
if not domain_allowed:
return False
hostname = parsed_url.hostname or ""
if block_private and not SecurityUtils._is_global_hostname(hostname):
private_match = SecurityUtils._is_allowed_private_hostname(
hostname, allowed_private_ranges
)
if private_match:
addresses, matched_networks = private_match
logger.debug(
"图片代理允许访问配置的非公网网段: "
f"url={url}, ips={','.join(map(str, addresses))}, "
f"ranges={','.join(map(str, matched_networks))}"
)
SecurityUtils._log_private_range_allowed(url, private_match)
return True
return False
return True
except Exception as e:
logger.debug(f"Error occurred while validating URL: {e}")
return False
@staticmethod
async def is_safe_url_async(
url: str,
allowed_domains: Union[Set[str], List[str]],
strict: bool = False,
block_private: bool = False,
allowed_private_ranges: Optional[Iterable[str]] = None,
) -> bool:
"""
`is_safe_url` 的异步版本,参数与返回值含义不变。
DNS 解析通过事件循环线程池执行,并复用 TTL 缓存。
"""
try:
hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
if hostname is None:
return False
if block_private and not await SecurityUtils._is_global_hostname_async(
hostname
):
private_match = await SecurityUtils._is_allowed_private_hostname_async(
hostname, allowed_private_ranges
)
if private_match:
SecurityUtils._log_private_range_allowed(url, private_match)
return True
return False

View File

@@ -2,10 +2,23 @@ import socket
from unittest import TestCase
from unittest.mock import patch
from app.utils.security import SecurityUtils
from app.utils.security import (
SecurityUtils,
_dns_inflight_locks,
_dns_negative_cache,
_dns_positive_cache,
)
class SecurityUtilsTest(TestCase):
def setUp(self) -> None:
"""
每个用例前清空 DNS TTL 缓存与 in-flight 锁,避免跨用例状态污染。
"""
_dns_positive_cache.clear()
_dns_negative_cache.clear()
_dns_inflight_locks.clear()
def test_signed_url_roundtrip_returns_clean_url(self):
"""
URL 签名验证成功后返回不含签名片段的真实请求地址。
@@ -272,3 +285,268 @@ class SecurityUtilsTest(TestCase):
allowed_private_ranges=["198.18.0.0/15"],
)
)
def test_is_safe_url_async_uses_event_loop_resolver(self):
"""
异步版本通过事件循环的非阻塞 getaddrinfo 完成 SSRF 校验,
且语义与同步版本保持一致:解析到非公网地址时仍然拒绝。
"""
import asyncio
async def fake_getaddrinfo(host, *_args, **_kwargs):
self.assertEqual(host, "internal.example.com")
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("127.0.0.1", 0),
)
]
async def run() -> bool:
loop = asyncio.get_running_loop()
with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
return await SecurityUtils.is_safe_url_async(
"http://internal.example.com/secret.png",
{"example.com"},
block_private=True,
)
self.assertFalse(asyncio.run(run()))
def test_is_safe_url_async_hits_dns_cache(self):
"""
异步与同步版本共享 DNS TTL 缓存:同步预热后,异步版本不应再发起 DNS 查询。
"""
import asyncio
# 先用同步路径预热缓存
with patch(
"app.utils.security.socket.getaddrinfo",
return_value=[
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("93.184.216.34", 0),
)
],
):
self.assertTrue(
SecurityUtils.is_safe_url(
"https://assets.example.com/poster.jpg",
{"example.com"},
block_private=True,
)
)
async def run() -> bool:
loop = asyncio.get_running_loop()
with patch.object(
loop,
"getaddrinfo",
side_effect=AssertionError("缓存命中后不应再次发起 DNS 查询"),
):
return await SecurityUtils.is_safe_url_async(
"https://assets.example.com/poster.jpg",
{"example.com"},
block_private=True,
)
self.assertTrue(asyncio.run(run()))
def test_is_safe_url_async_allows_public_dns_result(self):
"""
异步版本对全公网解析结果且命中 allowlist 时放行。
"""
import asyncio
async def fake_getaddrinfo(host, *_args, **_kwargs):
self.assertEqual(host, "assets.example.com")
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("93.184.216.34", 0),
)
]
async def run() -> bool:
loop = asyncio.get_running_loop()
with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
return await SecurityUtils.is_safe_url_async(
"https://assets.example.com/poster.jpg",
{"example.com"},
block_private=True,
)
self.assertTrue(asyncio.run(run()))
def test_dns_resolution_failure_populates_negative_cache(self):
"""
DNS 解析失败应回填负向缓存,避免短期内对同一目标反复触发 `getaddrinfo`。
"""
from app.utils.security import _dns_negative_cache as neg_cache
with patch(
"app.utils.security.socket.getaddrinfo",
side_effect=socket.gaierror,
) as mock_resolve:
self.assertFalse(
SecurityUtils.is_safe_url(
"https://assets.example.com/poster.jpg",
{"example.com"},
block_private=True,
)
)
self.assertEqual(mock_resolve.call_count, 1)
self.assertIn("assets.example.com", neg_cache)
self.assertFalse(
SecurityUtils.is_safe_url(
"https://assets.example.com/another.jpg",
{"example.com"},
block_private=True,
)
)
self.assertEqual(
mock_resolve.call_count,
1,
"命中负向缓存后不应再次调用 getaddrinfo",
)
def test_literal_ip_skips_dns_cache(self):
"""
URL 中的字面量 IP 走快路径,不应进入 DNS 缓存或触发 `getaddrinfo`。
"""
from app.utils.security import (
_dns_negative_cache as neg_cache,
_dns_positive_cache as pos_cache,
)
with patch(
"app.utils.security.socket.getaddrinfo",
side_effect=AssertionError("字面量 IP 不应触发 getaddrinfo"),
):
self.assertFalse(
SecurityUtils.is_safe_url(
"http://10.0.0.5:8080/secret.png",
{"http://10.0.0.5:8080"},
block_private=True,
)
)
self.assertNotIn("10.0.0.5", pos_cache)
self.assertNotIn("10.0.0.5", neg_cache)
def test_literal_ipv6_in_brackets_is_recognized(self):
"""
`urlparse` 已为 IPv6 字面量脱壳,`_literal_ip` 兼容直接传入带方括号的形式。
"""
self.assertEqual(
str(SecurityUtils._literal_ip("[::1]")),
"::1",
)
self.assertEqual(
str(SecurityUtils._literal_ip("::1")),
"::1",
)
self.assertIsNone(SecurityUtils._literal_ip("not-an-ip"))
def test_is_safe_url_async_dedupes_concurrent_inflight_queries(self):
"""
同 hostname 的并发未命中请求应通过 in-flight 锁去重,只触发一次 DNS 查询。
"""
import asyncio
call_count = 0
async def run() -> None:
nonlocal call_count
loop = asyncio.get_running_loop()
release = asyncio.Event()
async def slow_getaddrinfo(host, *_args, **_kwargs):
nonlocal call_count
call_count += 1
await release.wait()
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("93.184.216.34", 0),
)
]
with patch.object(loop, "getaddrinfo", side_effect=slow_getaddrinfo):
tasks = [
asyncio.create_task(
SecurityUtils.is_safe_url_async(
"https://assets.example.com/poster.jpg",
{"example.com"},
block_private=True,
)
)
for _ in range(5)
]
# 让所有任务都进入 in-flight 等待状态
await asyncio.sleep(0)
await asyncio.sleep(0)
release.set()
results = await asyncio.gather(*tasks)
self.assertTrue(all(results))
self.assertEqual(call_count, 1, "并发未命中应去重为单次 DNS 查询")
asyncio.run(run())
def test_sync_cache_access_is_thread_safe(self):
"""
同步路径下并发线程访问 DNS 缓存不应触发异常或拿到不一致结果。
TTLCache 自身非线程安全,依赖模块级 `_dns_cache_lock` 串行化读写。
"""
import threading
with patch(
"app.utils.security.socket.getaddrinfo",
return_value=[
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("93.184.216.34", 0),
)
],
):
results: list = []
errors: list = []
def worker() -> None:
try:
for _ in range(50):
results.append(
SecurityUtils.is_safe_url(
"https://assets.example.com/poster.jpg",
{"example.com"},
block_private=True,
)
)
except Exception as exc: # noqa: BLE001 - 用例需捕获任意异常
errors.append(exc)
threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(errors, [])
self.assertTrue(all(results))
self.assertEqual(len(results), 8 * 50)