From 01c345167952eea44f39e61295d02af9086c14c9 Mon Sep 17 00:00:00 2001 From: InfinityPacer <160988576+InfinityPacer@users.noreply.github.com> Date: Mon, 25 May 2026 15:54:02 +0800 Subject: [PATCH] perf(system): async SSRF check with DNS cache for image proxy (#5832) --- app/api/endpoints/system.py | 2 +- app/utils/security.py | 464 ++++++++++++++++++++++++++--------- tests/test_security_utils.py | 280 ++++++++++++++++++++- 3 files changed, 633 insertions(+), 113 deletions(-) diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index f33d4b9e..d323b4dd 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -360,7 +360,7 @@ async def fetch_image( fetch_url = SecurityUtils.strip_url_signature(url) # 验证URL安全性 - if not SecurityUtils.is_safe_url( + if not await SecurityUtils.is_safe_url_async( url, allowed_domains, block_private=True, diff --git a/app/utils/security.py b/app/utils/security.py index 41e5bee7..e9afbcd4 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -1,18 +1,60 @@ +import asyncio import hmac import ipaddress import socket +import threading import time from hashlib import sha256 from pathlib import Path -from typing import Iterable, List, Optional, Set, Union +from typing import Dict, Iterable, List, Optional, Set, Union from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse from anyio import Path as AsyncPath +from cachetools import TTLCache from app.core.config import settings from app.log import logger +# DNS 解析结果缓存。 +# 正向缓存 TTL 选择 120s,短于常见 CDN / fake-ip 的 DNS TTL,避免长期持有失效 IP; +# 负向缓存 TTL 选择 15s,避免临时解析失败把目标长时间拉黑。 +_DNS_CACHE_MAXSIZE = 1024 +_DNS_CACHE_TTL_POSITIVE = 120 +_DNS_CACHE_TTL_NEGATIVE = 15 +_dns_positive_cache: "TTLCache[str, List[ipaddress._BaseAddress]]" = TTLCache( + maxsize=_DNS_CACHE_MAXSIZE, ttl=_DNS_CACHE_TTL_POSITIVE +) +_dns_negative_cache: "TTLCache[str, bool]" = TTLCache( + maxsize=_DNS_CACHE_MAXSIZE, ttl=_DNS_CACHE_TTL_NEGATIVE +) +# 同步路径下保护 TTLCache 
# `cachetools.TTLCache` is not thread-safe by itself, so every read or write of
# the DNS caches takes this lock.  The lock deliberately does not wrap
# `getaddrinfo`, so the DNS queries themselves are never serialized.
_dns_cache_lock = threading.Lock()
# Per-hostname de-duplication of concurrent async lookups: the first cache miss
# registers a lock for the hostname and later concurrent requests await that
# same lock instead of issuing their own `getaddrinfo` call.
_dns_inflight_locks: Dict[str, asyncio.Lock] = {}
_dns_inflight_meta_lock = threading.Lock()


def _resolve_addrinfo_to_ips(
    address_infos: Iterable,
) -> Optional[List[ipaddress._BaseAddress]]:
    """
    Normalize the result of `socket.getaddrinfo` into a list of IP objects.

    If any entry cannot be parsed as an IP address the whole result is
    discarded and None is returned, so callers treat the target as unsafe
    instead of validating only part of the resolved addresses.  An empty
    result also yields None.
    """
    addresses: List[ipaddress._BaseAddress] = []
    for address_info in address_infos:
        try:
            # address_info[4] is the sockaddr tuple; its first item is the IP.
            addresses.append(ipaddress.ip_address(address_info[4][0]))
        except ValueError:
            return None
    return addresses or None


class SecurityUtils:
    _SIGNED_URL_PURPOSE = "image-proxy"
    _SIGNED_URL_EXPIRE_SECONDS = 86400

    # NOTE(review): class members that are not visible in this excerpt (for
    # example `_parse_ip_networks`, which the code below calls) are unchanged
    # and are not reproduced here.

    @staticmethod
    def _literal_ip(hostname: str) -> Optional[ipaddress._BaseAddress]:
        """
        Return the IP object when `hostname` is a literal IP, otherwise None.

        A bracketed IPv6 literal such as `[::1]` is accepted as well.
        """
        if not hostname:
            return None
        candidate = hostname
        if candidate.startswith("[") and candidate.endswith("]"):
            candidate = candidate[1:-1]
        try:
            return ipaddress.ip_address(candidate)
        except ValueError:
            return None

    @staticmethod
    def _cache_lookup(
        hostname: str,
    ) -> tuple[bool, Optional[List[ipaddress._BaseAddress]]]:
        """
        Look `hostname` up in the DNS caches and return `(hit, value)`.

        `(True, None)` means the negative cache was hit, i.e. an earlier
        resolution of this hostname failed.
        """
        with _dns_cache_lock:
            cached = _dns_positive_cache.get(hostname)
            if cached is not None:
                return True, cached
            if hostname in _dns_negative_cache:
                return True, None
            return False, None

    @staticmethod
    def _cache_store(
        hostname: str, addresses: Optional[List[ipaddress._BaseAddress]]
    ) -> None:
        """
        Record a resolution result: None goes to the negative cache, anything
        else to the positive cache.
        """
        with _dns_cache_lock:
            if addresses is None:
                _dns_negative_cache[hostname] = True
            else:
                _dns_positive_cache[hostname] = addresses

    @staticmethod
    def _hostname_addresses(hostname: str) -> Optional[List[ipaddress._BaseAddress]]:
        """
        Resolve `hostname` synchronously and return all of its IP addresses.

        A literal IP is returned as a one-element list without touching the
        cache or DNS.  Other names go through the shared caches.  A failed or
        malformed resolution returns None (and is cached negatively) so that
        callers treat the target as unsafe.  Async callers should use
        `_hostname_addresses_async` instead.
        """
        if not hostname:
            return None
        literal = SecurityUtils._literal_ip(hostname)
        if literal is not None:
            return [literal]

        hit, value = SecurityUtils._cache_lookup(hostname)
        if hit:
            return value

        try:
            address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
        except (socket.gaierror, UnicodeError):
            # UnicodeError: the name cannot be IDNA-encoded (empty or over-long
            # label); it can never resolve, so handle it like a DNS failure.
            addresses = None
        else:
            addresses = _resolve_addrinfo_to_ips(address_infos)
        SecurityUtils._cache_store(hostname, addresses)
        return addresses

    @staticmethod
    def _get_inflight_lock(hostname: str) -> asyncio.Lock:
        """
        Return the in-flight lock registered for `hostname`, creating it on
        first use.

        The registry is guarded by a `threading.Lock` so that concurrent
        callers cannot register two different locks for the same hostname.
        """
        with _dns_inflight_meta_lock:
            lock = _dns_inflight_locks.get(hostname)
            if lock is None:
                lock = asyncio.Lock()
                _dns_inflight_locks[hostname] = lock
            return lock

    @staticmethod
    def _release_inflight_lock(hostname: str, lock: asyncio.Lock) -> None:
        """
        Drop the registry entry for `hostname` once its lock is idle.

        The entry is removed only when `lock` is still the registered object
        and nobody currently holds it.  Call this AFTER leaving the
        `async with lock` block: while the caller still holds the lock,
        `locked()` is always true and nothing would ever be removed.
        Coroutines already queued on the lock keep their own reference to it
        and re-check the cache after acquiring it, so dropping the registry
        entry does not affect them.
        """
        with _dns_inflight_meta_lock:
            if _dns_inflight_locks.get(hostname) is lock and not lock.locked():
                del _dns_inflight_locks[hostname]

    @staticmethod
    async def _hostname_addresses_async(
        hostname: str,
    ) -> Optional[List[ipaddress._BaseAddress]]:
        """
        Resolve `hostname` without blocking the event loop and return all of
        its IP addresses, sharing the caches with the synchronous variant.

        The lookup runs through `loop.getaddrinfo`.  Concurrent cache misses
        for the same hostname are serialized on a per-hostname lock, so only
        the first one queries DNS and the rest read the cached result.
        """
        if not hostname:
            return None
        literal = SecurityUtils._literal_ip(hostname)
        if literal is not None:
            return [literal]

        hit, value = SecurityUtils._cache_lookup(hostname)
        if hit:
            return value

        lock = SecurityUtils._get_inflight_lock(hostname)
        try:
            async with lock:
                # Re-check after acquiring: the previous holder may already
                # have filled the cache.
                hit, value = SecurityUtils._cache_lookup(hostname)
                if hit:
                    return value

                loop = asyncio.get_running_loop()
                try:
                    address_infos = await loop.getaddrinfo(
                        hostname, None, type=socket.SOCK_STREAM
                    )
                except (socket.gaierror, UnicodeError):
                    # Same policy as the sync path: unresolvable or
                    # un-encodable names are a failed resolution.
                    addresses = None
                else:
                    addresses = _resolve_addrinfo_to_ips(address_infos)
                SecurityUtils._cache_store(hostname, addresses)
                return addresses
        finally:
            # Runs after the lock has been released (and also on cancellation
            # or unexpected errors), so the registry entry is really removed
            # instead of accumulating one lock per hostname forever.
            SecurityUtils._release_inflight_lock(hostname, lock)

    @staticmethod
    def _addresses_all_global(
        addresses: Optional[List[ipaddress._BaseAddress]],
    ) -> bool:
        """
        Return True only when `addresses` is non-empty and every address is a
        global (public) one; None or an empty list counts as not global.
        """
        if not addresses:
            return False
        return all(address.is_global for address in addresses)

    @staticmethod
    def _is_global_hostname(hostname: str) -> bool:
        """
        Tell whether every address `hostname` resolves to is public
        (synchronous variant).

        The image proxy fetches user-controlled URLs, so private, loopback,
        link-local, reserved and other non-public targets must be rejected to
        keep DNS tricks or literal IPs from reaching internal services through
        the domain allowlist.
        """
        return SecurityUtils._addresses_all_global(
            SecurityUtils._hostname_addresses(hostname)
        )

    @staticmethod
    async def _is_global_hostname_async(hostname: str) -> bool:
        """
        Async variant of `_is_global_hostname` with the same semantics.
        """
        return SecurityUtils._addresses_all_global(
            await SecurityUtils._hostname_addresses_async(hostname)
        )

    @staticmethod
    def _match_private_addresses(
        addresses: Optional[List[ipaddress._BaseAddress]],
        networks: List[ipaddress._BaseNetwork],
    ) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
        """
        Match already-resolved addresses against the explicitly allowed
        non-public networks.

        Every address must fall inside at least one allowed network.  If any
        address lies outside all of them, or the addresses are all public,
        None is returned.  On success the addresses and the de-duplicated
        matching networks (in first-match order) are returned.
        """
        if not addresses or not networks:
            return None
        if all(address.is_global for address in addresses):
            return None

        matched_networks: List[ipaddress._BaseNetwork] = []
        for address in addresses:
            matched_for_address = [
                network for network in networks if address in network
            ]
            # NOTE(review): this guard sits in a gap between two diff hunks and
            # was reconstructed from the documented contract -- confirm against
            # the full file.
            if not matched_for_address:
                return None
            matched_networks.extend(matched_for_address)
        return addresses, list(dict.fromkeys(matched_networks))

    @staticmethod
    def _is_allowed_private_hostname(
        hostname: str,
        allowed_private_ranges: Optional[Iterable[str]],
    ) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
        """
        Return the explicitly allowed non-public addresses and networks that
        `hostname` resolves into (synchronous variant), or None.

        This is a controlled exception for the image proxy (for example TUN
        fake-ip ranges or an internal CDN).  It must only be called after
        `is_safe_url` has passed the domain allowlist, otherwise it would turn
        arbitrary user URLs into an SSRF bypass.
        """
        networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
        if not networks:
            return None
        return SecurityUtils._match_private_addresses(
            SecurityUtils._hostname_addresses(hostname), networks
        )

    @staticmethod
    async def _is_allowed_private_hostname_async(
        hostname: str,
        allowed_private_ranges: Optional[Iterable[str]],
    ) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
        """
        Async variant of `_is_allowed_private_hostname` with the same semantics.
        """
        networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
        if not networks:
            return None
        return SecurityUtils._match_private_addresses(
            await SecurityUtils._hostname_addresses_async(hostname), networks
        )

    @staticmethod
    def _check_url_allowlist(
        url: str,
        allowed_domains: Union[Set[str], List[str]],
        strict: bool,
    ) -> Optional[str]:
        """
        Run the "scheme + netloc + domain allowlist" pre-check.

        Returns the URL's hostname when the URL is http(s) and its netloc
        matches an allowlist entry, otherwise None.  No DNS query is made
        here; the SSRF (DNS) check is left to the caller.

        :param url: URL to validate
        :param allowed_domains: allowed domains; an entry may carry a port
        :param strict: True requires an exact netloc match, False also accepts
            sub-domains of an entry
        """
        try:
            parsed_url = urlparse(url)
        except Exception as e:  # noqa: BLE001 - any parse error means "unsafe URL"
            logger.debug(f"Error occurred while validating URL: {e}")
            return None

        # Only http/https URLs with a non-empty netloc are eligible.
        if parsed_url.scheme not in {"http", "https"} or not parsed_url.netloc:
            return None

        # Full netloc (host plus optional port), compared case-insensitively.
        netloc = parsed_url.netloc.lower()

        for domain in {d.lower() for d in allowed_domains}:
            parsed_allowed_url = urlparse(domain)
            allowed_netloc = parsed_allowed_url.netloc or parsed_allowed_url.path
            if not allowed_netloc:
                # An empty entry would reduce the suffix test below to
                # `netloc.endswith(".")` and let any trailing-dot host
                # (e.g. "evil.test.") through the allowlist.
                continue
            if netloc == allowed_netloc:
                return parsed_url.hostname or ""
            if not strict and netloc.endswith("." + allowed_netloc):
                return parsed_url.hostname or ""
        return None

    @staticmethod
    def _log_private_range_allowed(
        url: str,
        match: tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]],
    ) -> None:
        """
        Emit the debug log line recording that the image proxy was allowed to
        reach a configured non-public network, to help operators troubleshoot.
        """
        addresses, matched_networks = match
        logger.debug(
            "图片代理允许访问配置的非公网网段: "
            f"url={url}, ips={','.join(map(str, addresses))}, "
            f"ranges={','.join(map(str, matched_networks))}"
        )

    # NOTE(review): the parameters between `url` and `allowed_private_ranges`
    # are elided in the visible hunk; they are taken from `is_safe_url_async`,
    # whose signature is fully visible -- confirm against the full file.
    @staticmethod
    def is_safe_url(
        url: str,
        allowed_domains: Union[Set[str], List[str]],
        strict: bool = False,
        block_private: bool = False,
        allowed_private_ranges: Optional[Iterable[str]] = None,
    ) -> bool:
        """
        Check that a URL is on the domain allowlist, ports included
        (synchronous variant).

        :param url: URL to validate
        :param allowed_domains: allowed domains; an entry may carry a port
        :param strict: True requires an exact netloc match, False also accepts
            sub-domains of an entry
        :param block_private: reject URLs resolving to non-public addresses
            (SSRF protection)
        :param allowed_private_ranges: non-public IP/CIDR ranges that are still
            accepted once the domain has matched
        :return: True when the URL is valid and allowed, otherwise False

        Note: with `block_private=True` this may call `getaddrinfo`
        synchronously; use `is_safe_url_async` from async code.
        """
        try:
            hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
            if hostname is None:
                return False

            if block_private and not SecurityUtils._is_global_hostname(hostname):
                private_match = SecurityUtils._is_allowed_private_hostname(
                    hostname, allowed_private_ranges
                )
                if private_match:
                    SecurityUtils._log_private_range_allowed(url, private_match)
                    return True
                return False

            return True
        except Exception as e:
            logger.debug(f"Error occurred while validating URL: {e}")
            return False

    @staticmethod
    async def is_safe_url_async(
        url: str,
        allowed_domains: Union[Set[str], List[str]],
        strict: bool = False,
        block_private: bool = False,
        allowed_private_ranges: Optional[Iterable[str]] = None,
    ) -> bool:
        """
        Async variant of `is_safe_url`; parameters and return value mean the
        same.

        DNS resolution goes through the event loop's `getaddrinfo` and reuses
        the shared caches.
        """
        try:
            hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
            if hostname is None:
                return False

            if block_private and not await SecurityUtils._is_global_hostname_async(
                hostname
            ):
                private_match = await SecurityUtils._is_allowed_private_hostname_async(
                    hostname, allowed_private_ranges
                )
                if private_match:
                    SecurityUtils._log_private_range_allowed(url, private_match)
                    return True
                return False

            # NOTE(review): everything from here to the end of the method lies
            # past the visible hunk and mirrors `is_safe_url` -- confirm
            # against the full file.
            return True
        except Exception as e:
            logger.debug(f"Error occurred while validating URL: {e}")
            return False
b/tests/test_security_utils.py @@ -2,10 +2,23 @@ import socket from unittest import TestCase from unittest.mock import patch -from app.utils.security import SecurityUtils +from app.utils.security import ( + SecurityUtils, + _dns_inflight_locks, + _dns_negative_cache, + _dns_positive_cache, +) class SecurityUtilsTest(TestCase): + def setUp(self) -> None: + """ + 每个用例前清空 DNS TTL 缓存与 in-flight 锁,避免跨用例状态污染。 + """ + _dns_positive_cache.clear() + _dns_negative_cache.clear() + _dns_inflight_locks.clear() + def test_signed_url_roundtrip_returns_clean_url(self): """ URL 签名验证成功后返回不含签名片段的真实请求地址。 @@ -272,3 +285,268 @@ class SecurityUtilsTest(TestCase): allowed_private_ranges=["198.18.0.0/15"], ) ) + + def test_is_safe_url_async_uses_event_loop_resolver(self): + """ + 异步版本通过事件循环的非阻塞 getaddrinfo 完成 SSRF 校验, + 且语义与同步版本保持一致:解析到非公网地址时仍然拒绝。 + """ + import asyncio + + async def fake_getaddrinfo(host, *_args, **_kwargs): + self.assertEqual(host, "internal.example.com") + return [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("127.0.0.1", 0), + ) + ] + + async def run() -> bool: + loop = asyncio.get_running_loop() + with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo): + return await SecurityUtils.is_safe_url_async( + "http://internal.example.com/secret.png", + {"example.com"}, + block_private=True, + ) + + self.assertFalse(asyncio.run(run())) + + def test_is_safe_url_async_hits_dns_cache(self): + """ + 异步与同步版本共享 DNS TTL 缓存:同步预热后,异步版本不应再发起 DNS 查询。 + """ + import asyncio + + # 先用同步路径预热缓存 + with patch( + "app.utils.security.socket.getaddrinfo", + return_value=[ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("93.184.216.34", 0), + ) + ], + ): + self.assertTrue( + SecurityUtils.is_safe_url( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) + + async def run() -> bool: + loop = asyncio.get_running_loop() + with patch.object( + loop, + "getaddrinfo", + side_effect=AssertionError("缓存命中后不应再次发起 DNS 查询"), + ): + 
return await SecurityUtils.is_safe_url_async( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + + self.assertTrue(asyncio.run(run())) + + def test_is_safe_url_async_allows_public_dns_result(self): + """ + 异步版本对全公网解析结果且命中 allowlist 时放行。 + """ + import asyncio + + async def fake_getaddrinfo(host, *_args, **_kwargs): + self.assertEqual(host, "assets.example.com") + return [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("93.184.216.34", 0), + ) + ] + + async def run() -> bool: + loop = asyncio.get_running_loop() + with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo): + return await SecurityUtils.is_safe_url_async( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + + self.assertTrue(asyncio.run(run())) + + def test_dns_resolution_failure_populates_negative_cache(self): + """ + DNS 解析失败应回填负向缓存,避免短期内对同一目标反复触发 `getaddrinfo`。 + """ + from app.utils.security import _dns_negative_cache as neg_cache + + with patch( + "app.utils.security.socket.getaddrinfo", + side_effect=socket.gaierror, + ) as mock_resolve: + self.assertFalse( + SecurityUtils.is_safe_url( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) + self.assertEqual(mock_resolve.call_count, 1) + self.assertIn("assets.example.com", neg_cache) + + self.assertFalse( + SecurityUtils.is_safe_url( + "https://assets.example.com/another.jpg", + {"example.com"}, + block_private=True, + ) + ) + self.assertEqual( + mock_resolve.call_count, + 1, + "命中负向缓存后不应再次调用 getaddrinfo", + ) + + def test_literal_ip_skips_dns_cache(self): + """ + URL 中的字面量 IP 走快路径,不应进入 DNS 缓存或触发 `getaddrinfo`。 + """ + from app.utils.security import ( + _dns_negative_cache as neg_cache, + _dns_positive_cache as pos_cache, + ) + + with patch( + "app.utils.security.socket.getaddrinfo", + side_effect=AssertionError("字面量 IP 不应触发 getaddrinfo"), + ): + self.assertFalse( + SecurityUtils.is_safe_url( + 
"http://10.0.0.5:8080/secret.png", + {"http://10.0.0.5:8080"}, + block_private=True, + ) + ) + self.assertNotIn("10.0.0.5", pos_cache) + self.assertNotIn("10.0.0.5", neg_cache) + + def test_literal_ipv6_in_brackets_is_recognized(self): + """ + `urlparse` 已为 IPv6 字面量脱壳,`_literal_ip` 兼容直接传入带方括号的形式。 + """ + self.assertEqual( + str(SecurityUtils._literal_ip("[::1]")), + "::1", + ) + self.assertEqual( + str(SecurityUtils._literal_ip("::1")), + "::1", + ) + self.assertIsNone(SecurityUtils._literal_ip("not-an-ip")) + + def test_is_safe_url_async_dedupes_concurrent_inflight_queries(self): + """ + 同 hostname 的并发未命中请求应通过 in-flight 锁去重,只触发一次 DNS 查询。 + """ + import asyncio + + call_count = 0 + + async def run() -> None: + nonlocal call_count + loop = asyncio.get_running_loop() + release = asyncio.Event() + + async def slow_getaddrinfo(host, *_args, **_kwargs): + nonlocal call_count + call_count += 1 + await release.wait() + return [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("93.184.216.34", 0), + ) + ] + + with patch.object(loop, "getaddrinfo", side_effect=slow_getaddrinfo): + tasks = [ + asyncio.create_task( + SecurityUtils.is_safe_url_async( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) + for _ in range(5) + ] + # 让所有任务都进入 in-flight 等待状态 + await asyncio.sleep(0) + await asyncio.sleep(0) + release.set() + results = await asyncio.gather(*tasks) + + self.assertTrue(all(results)) + self.assertEqual(call_count, 1, "并发未命中应去重为单次 DNS 查询") + + asyncio.run(run()) + + def test_sync_cache_access_is_thread_safe(self): + """ + 同步路径下并发线程访问 DNS 缓存不应触发异常或拿到不一致结果。 + TTLCache 自身非线程安全,依赖模块级 `_dns_cache_lock` 串行化读写。 + """ + import threading + + with patch( + "app.utils.security.socket.getaddrinfo", + return_value=[ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("93.184.216.34", 0), + ) + ], + ): + results: list = [] + errors: list = [] + + def worker() -> None: + try: + for _ in range(50): + results.append( + 
SecurityUtils.is_safe_url( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) + except Exception as exc: # noqa: BLE001 - 用例需捕获任意异常 + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(errors, []) + self.assertTrue(all(results)) + self.assertEqual(len(results), 8 * 50)