mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-03 14:39:56 +08:00
perf(system): async SSRF check with DNS cache for image proxy (#5832)
This commit is contained in:
@@ -360,7 +360,7 @@ async def fetch_image(
|
||||
|
||||
fetch_url = SecurityUtils.strip_url_signature(url)
|
||||
# 验证URL安全性
|
||||
if not SecurityUtils.is_safe_url(
|
||||
if not await SecurityUtils.is_safe_url_async(
|
||||
url,
|
||||
allowed_domains,
|
||||
block_private=True,
|
||||
|
||||
@@ -1,18 +1,60 @@
|
||||
import asyncio
|
||||
import hmac
|
||||
import ipaddress
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Set, Union
|
||||
from typing import Dict, Iterable, List, Optional, Set, Union
|
||||
from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from cachetools import TTLCache
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
# DNS 解析结果缓存。
|
||||
# 正向缓存 TTL 选择 120s,短于常见 CDN / fake-ip 的 DNS TTL,避免长期持有失效 IP;
|
||||
# 负向缓存 TTL 选择 15s,避免临时解析失败把目标长时间拉黑。
|
||||
_DNS_CACHE_MAXSIZE = 1024
|
||||
_DNS_CACHE_TTL_POSITIVE = 120
|
||||
_DNS_CACHE_TTL_NEGATIVE = 15
|
||||
_dns_positive_cache: "TTLCache[str, List[ipaddress._BaseAddress]]" = TTLCache(
|
||||
maxsize=_DNS_CACHE_MAXSIZE, ttl=_DNS_CACHE_TTL_POSITIVE
|
||||
)
|
||||
_dns_negative_cache: "TTLCache[str, bool]" = TTLCache(
|
||||
maxsize=_DNS_CACHE_MAXSIZE, ttl=_DNS_CACHE_TTL_NEGATIVE
|
||||
)
|
||||
# 同步路径下保护 TTLCache 读写:`cachetools.TTLCache` 本身非线程安全。
|
||||
# 锁只覆盖缓存读写,不包 `getaddrinfo`,避免把 DNS 查询本身串行化。
|
||||
_dns_cache_lock = threading.Lock()
|
||||
# 同 hostname 的并发异步解析去重:同一 hostname 首次未命中时建立锁,
|
||||
# 后续并发请求 await 同一把锁,避免对同一目标重复发起 `getaddrinfo`。
|
||||
_dns_inflight_locks: Dict[str, asyncio.Lock] = {}
|
||||
_dns_inflight_meta_lock = threading.Lock()
|
||||
|
||||
|
||||
def _resolve_addrinfo_to_ips(
|
||||
address_infos: Iterable,
|
||||
) -> Optional[List[ipaddress._BaseAddress]]:
|
||||
"""
|
||||
将 `socket.getaddrinfo` 返回的结果归一化为 IP 列表。
|
||||
|
||||
任一条目无法解析为 IP 即视为异常情况,整体返回 None 让上层按"不安全目标"
|
||||
处理,避免出现"部分 IP 漏校验"的情况。
|
||||
"""
|
||||
addresses: List[ipaddress._BaseAddress] = []
|
||||
for address_info in address_infos:
|
||||
try:
|
||||
addresses.append(ipaddress.ip_address(address_info[4][0]))
|
||||
except ValueError:
|
||||
return None
|
||||
return addresses or None
|
||||
|
||||
|
||||
class SecurityUtils:
|
||||
_SIGNED_URL_PURPOSE = "image-proxy"
|
||||
_SIGNED_URL_EXPIRE_SECONDS = 86400
|
||||
@@ -79,38 +121,176 @@ class SecurityUtils:
|
||||
logger.debug(f"Error occurred while validating paths: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _literal_ip(hostname: str) -> Optional[ipaddress._BaseAddress]:
|
||||
"""
|
||||
若 hostname 是字面量 IP(含 IPv6 的 `[::1]` 形式)则返回 IP 对象,否则 None。
|
||||
"""
|
||||
if not hostname:
|
||||
return None
|
||||
candidate = hostname
|
||||
if candidate.startswith("[") and candidate.endswith("]"):
|
||||
candidate = candidate[1:-1]
|
||||
try:
|
||||
return ipaddress.ip_address(candidate)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _cache_lookup(hostname: str) -> tuple[bool, Optional[List[ipaddress._BaseAddress]]]:
|
||||
"""
|
||||
在 TTL 缓存中查找 hostname,返回 (是否命中, 命中值)。
|
||||
|
||||
命中值为 `None` 表示命中负向缓存(先前解析失败)。
|
||||
"""
|
||||
with _dns_cache_lock:
|
||||
cached = _dns_positive_cache.get(hostname)
|
||||
if cached is not None:
|
||||
return True, cached
|
||||
if hostname in _dns_negative_cache:
|
||||
return True, None
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def _cache_store(
|
||||
hostname: str, addresses: Optional[List[ipaddress._BaseAddress]]
|
||||
) -> None:
|
||||
"""
|
||||
将解析结果写入对应的正向/负向缓存。
|
||||
"""
|
||||
with _dns_cache_lock:
|
||||
if addresses is None:
|
||||
_dns_negative_cache[hostname] = True
|
||||
else:
|
||||
_dns_positive_cache[hostname] = addresses
|
||||
|
||||
@staticmethod
|
||||
def _hostname_addresses(hostname: str) -> Optional[List[ipaddress._BaseAddress]]:
|
||||
"""
|
||||
同步解析主机名并返回全部 IP 地址,结果走 TTL 缓存。
|
||||
|
||||
字面量 IP 直接返回自身;DNS 解析失败或结果异常时返回 None,由上层按
|
||||
不安全目标处理。async 调用方应使用 `_hostname_addresses_async`。
|
||||
"""
|
||||
if not hostname:
|
||||
return None
|
||||
literal = SecurityUtils._literal_ip(hostname)
|
||||
if literal is not None:
|
||||
return [literal]
|
||||
|
||||
hit, value = SecurityUtils._cache_lookup(hostname)
|
||||
if hit:
|
||||
return value
|
||||
|
||||
try:
|
||||
address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
SecurityUtils._cache_store(hostname, None)
|
||||
return None
|
||||
addresses = _resolve_addrinfo_to_ips(address_infos)
|
||||
SecurityUtils._cache_store(hostname, addresses)
|
||||
return addresses
|
||||
|
||||
@staticmethod
|
||||
def _get_inflight_lock(hostname: str) -> asyncio.Lock:
|
||||
"""
|
||||
取得 hostname 对应的 in-flight 锁,不存在则按需创建。
|
||||
|
||||
用 `threading.Lock` 保护字典写入,避免多个事件循环线程并发创建出多把锁
|
||||
破坏去重语义;锁本身是 `asyncio.Lock`,归属当前事件循环。
|
||||
"""
|
||||
with _dns_inflight_meta_lock:
|
||||
lock = _dns_inflight_locks.get(hostname)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_dns_inflight_locks[hostname] = lock
|
||||
return lock
|
||||
|
||||
@staticmethod
|
||||
def _release_inflight_lock(hostname: str, lock: asyncio.Lock) -> None:
|
||||
"""
|
||||
请求结束后清理 in-flight 锁,避免长期持有大量已闲置的 `asyncio.Lock`。
|
||||
只有当前 lock 仍是字典里登记的那把、且没有其它协程在等待时才删除。
|
||||
"""
|
||||
with _dns_inflight_meta_lock:
|
||||
current = _dns_inflight_locks.get(hostname)
|
||||
if current is lock and not lock.locked():
|
||||
_dns_inflight_locks.pop(hostname, None)
|
||||
|
||||
@staticmethod
|
||||
async def _hostname_addresses_async(
|
||||
hostname: str,
|
||||
) -> Optional[List[ipaddress._BaseAddress]]:
|
||||
"""
|
||||
异步解析主机名并返回全部 IP 地址,与同步版本共用同一份 TTL 缓存。
|
||||
|
||||
通过事件循环的默认线程池执行 `getaddrinfo`,不阻塞 asyncio 事件循环;
|
||||
同 hostname 的并发未命中请求通过 in-flight 锁去重,只发起一次 DNS 查询。
|
||||
"""
|
||||
if not hostname:
|
||||
return None
|
||||
literal = SecurityUtils._literal_ip(hostname)
|
||||
if literal is not None:
|
||||
return [literal]
|
||||
|
||||
hit, value = SecurityUtils._cache_lookup(hostname)
|
||||
if hit:
|
||||
return value
|
||||
|
||||
lock = SecurityUtils._get_inflight_lock(hostname)
|
||||
async with lock:
|
||||
# 等到锁后再查一次缓存,前一个持锁者可能已经回填结果
|
||||
hit, value = SecurityUtils._cache_lookup(hostname)
|
||||
if hit:
|
||||
SecurityUtils._release_inflight_lock(hostname, lock)
|
||||
return value
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
address_infos = await loop.getaddrinfo(
|
||||
hostname, None, type=socket.SOCK_STREAM
|
||||
)
|
||||
except socket.gaierror:
|
||||
SecurityUtils._cache_store(hostname, None)
|
||||
SecurityUtils._release_inflight_lock(hostname, lock)
|
||||
return None
|
||||
addresses = _resolve_addrinfo_to_ips(address_infos)
|
||||
SecurityUtils._cache_store(hostname, addresses)
|
||||
SecurityUtils._release_inflight_lock(hostname, lock)
|
||||
return addresses
|
||||
|
||||
@staticmethod
|
||||
def _addresses_all_global(
|
||||
addresses: Optional[List[ipaddress._BaseAddress]],
|
||||
) -> bool:
|
||||
"""
|
||||
判断解析结果是否全部为公网地址(空列表/None 视为非公网)。
|
||||
"""
|
||||
if not addresses:
|
||||
return False
|
||||
return all(address.is_global for address in addresses)
|
||||
|
||||
@staticmethod
|
||||
def _is_global_hostname(hostname: str) -> bool:
|
||||
"""
|
||||
判断主机名解析结果是否全部为公网地址。
|
||||
判断主机名解析结果是否全部为公网地址(同步版本)。
|
||||
|
||||
图片代理会访问用户可控的 URL,这里必须在 allowlist 命中前后都排除
|
||||
私有、回环、链路本地、保留地址等非公网目标,避免通过 DNS 或字面量 IP
|
||||
绕过域名白名单访问内网服务。
|
||||
"""
|
||||
if not hostname:
|
||||
return False
|
||||
try:
|
||||
return ipaddress.ip_address(hostname).is_global
|
||||
except ValueError:
|
||||
pass
|
||||
return SecurityUtils._addresses_all_global(
|
||||
SecurityUtils._hostname_addresses(hostname)
|
||||
)
|
||||
|
||||
try:
|
||||
address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
return False
|
||||
|
||||
if not address_infos:
|
||||
return False
|
||||
|
||||
for address_info in address_infos:
|
||||
try:
|
||||
address = ipaddress.ip_address(address_info[4][0])
|
||||
except ValueError:
|
||||
return False
|
||||
if not address.is_global:
|
||||
return False
|
||||
return True
|
||||
@staticmethod
|
||||
async def _is_global_hostname_async(hostname: str) -> bool:
|
||||
"""
|
||||
判断主机名解析结果是否全部为公网地址(异步版本)。语义与 `_is_global_hostname` 一致。
|
||||
"""
|
||||
return SecurityUtils._addresses_all_global(
|
||||
await SecurityUtils._hostname_addresses_async(hostname)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_ip_networks(ranges: Optional[Iterable[str]]) -> List[ipaddress._BaseNetwork]:
|
||||
@@ -131,58 +311,22 @@ class SecurityUtils:
|
||||
return networks
|
||||
|
||||
@staticmethod
|
||||
def _hostname_addresses(hostname: str) -> Optional[List[ipaddress._BaseAddress]]:
|
||||
"""
|
||||
解析主机名并返回全部 IP 地址。
|
||||
|
||||
字面量 IP 直接返回自身;DNS 解析失败或结果异常时返回 None,让上层按
|
||||
不安全目标处理。
|
||||
"""
|
||||
if not hostname:
|
||||
return None
|
||||
try:
|
||||
return [ipaddress.ip_address(hostname)]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
return None
|
||||
|
||||
if not address_infos:
|
||||
return None
|
||||
|
||||
addresses = []
|
||||
for address_info in address_infos:
|
||||
try:
|
||||
addresses.append(ipaddress.ip_address(address_info[4][0]))
|
||||
except ValueError:
|
||||
return None
|
||||
return addresses
|
||||
|
||||
@staticmethod
|
||||
def _is_allowed_private_hostname(
|
||||
hostname: str,
|
||||
allowed_private_ranges: Optional[Iterable[str]],
|
||||
def _match_private_addresses(
|
||||
addresses: Optional[List[ipaddress._BaseAddress]],
|
||||
networks: List[ipaddress._BaseNetwork],
|
||||
) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
|
||||
"""
|
||||
返回主机名命中的显式允许非公网地址和网段。
|
||||
在已解析出的地址列表中匹配显式允许的非公网网段。
|
||||
|
||||
该能力只用于图片代理的受控例外,例如 TUN fake-ip 或内网 CDN。必须由
|
||||
`is_safe_url` 先完成域名 allowlist 校验后再调用,避免把任意用户 URL
|
||||
变成 SSRF 绕过入口。
|
||||
所有解析地址都必须命中至少一个允许网段才放行;只要有一个 IP 落在允许
|
||||
网段外(或解析结果是全公网),就视为不匹配私网放行规则。
|
||||
"""
|
||||
networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
|
||||
if not networks:
|
||||
return None
|
||||
addresses = SecurityUtils._hostname_addresses(hostname)
|
||||
if not addresses:
|
||||
if not addresses or not networks:
|
||||
return None
|
||||
if all(address.is_global for address in addresses):
|
||||
return None
|
||||
|
||||
matched_networks = []
|
||||
matched_networks: List[ipaddress._BaseNetwork] = []
|
||||
for address in addresses:
|
||||
matched_for_address = [
|
||||
network for network in networks if address in network
|
||||
@@ -192,6 +336,40 @@ class SecurityUtils:
|
||||
matched_networks.extend(matched_for_address)
|
||||
return addresses, list(dict.fromkeys(matched_networks))
|
||||
|
||||
@staticmethod
|
||||
def _is_allowed_private_hostname(
|
||||
hostname: str,
|
||||
allowed_private_ranges: Optional[Iterable[str]],
|
||||
) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
|
||||
"""
|
||||
返回主机名命中的显式允许非公网地址和网段(同步版本)。
|
||||
|
||||
该能力只用于图片代理的受控例外,例如 TUN fake-ip 或内网 CDN。必须由
|
||||
`is_safe_url` 先完成域名 allowlist 校验后再调用,避免把任意用户 URL
|
||||
变成 SSRF 绕过入口。
|
||||
"""
|
||||
networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
|
||||
if not networks:
|
||||
return None
|
||||
return SecurityUtils._match_private_addresses(
|
||||
SecurityUtils._hostname_addresses(hostname), networks
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _is_allowed_private_hostname_async(
|
||||
hostname: str,
|
||||
allowed_private_ranges: Optional[Iterable[str]],
|
||||
) -> Optional[tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]]]:
|
||||
"""
|
||||
`_is_allowed_private_hostname` 的异步版本,语义保持一致。
|
||||
"""
|
||||
networks = SecurityUtils._parse_ip_networks(allowed_private_ranges)
|
||||
if not networks:
|
||||
return None
|
||||
return SecurityUtils._match_private_addresses(
|
||||
await SecurityUtils._hostname_addresses_async(hostname), networks
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _url_signature_payload(url: str, expires_at: int, purpose: str) -> bytes:
|
||||
"""
|
||||
@@ -289,6 +467,72 @@ class SecurityUtils:
|
||||
return None
|
||||
return clean_url
|
||||
|
||||
@staticmethod
|
||||
def _check_url_allowlist(
|
||||
url: str,
|
||||
allowed_domains: Union[Set[str], List[str]],
|
||||
strict: bool,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
执行"协议 + netloc + 域名白名单"前置校验,命中返回 hostname,未命中返回 None。
|
||||
|
||||
DNS 校验(SSRF 防御)由调用方自行接续,本方法不发起 DNS 查询。
|
||||
"""
|
||||
try:
|
||||
parsed_url = urlparse(url)
|
||||
except Exception as e: # noqa: BLE001 - 任何解析异常都视为不安全 URL
|
||||
logger.debug(f"Error occurred while validating URL: {e}")
|
||||
return None
|
||||
|
||||
# 如果 URL 没有包含有效的 scheme,或者无法从中提取到有效的 netloc,则认为该 URL 是无效的
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
return None
|
||||
# 仅允许 http 或 https 协议
|
||||
if parsed_url.scheme not in {"http", "https"}:
|
||||
return None
|
||||
|
||||
# 获取完整的 netloc(包括 IP 和端口)并转换为小写
|
||||
netloc = parsed_url.netloc.lower()
|
||||
if not netloc:
|
||||
return None
|
||||
|
||||
# 检查每个允许的域名
|
||||
normalized_allowed = {d.lower() for d in allowed_domains}
|
||||
domain_allowed = False
|
||||
for domain in normalized_allowed:
|
||||
parsed_allowed_url = urlparse(domain)
|
||||
allowed_netloc = parsed_allowed_url.netloc or parsed_allowed_url.path
|
||||
|
||||
if strict:
|
||||
# 严格模式下,要求完全匹配域名和端口
|
||||
if netloc == allowed_netloc:
|
||||
domain_allowed = True
|
||||
break
|
||||
else:
|
||||
# 非严格模式下,允许子域名匹配
|
||||
if netloc == allowed_netloc or netloc.endswith("." + allowed_netloc):
|
||||
domain_allowed = True
|
||||
break
|
||||
|
||||
if not domain_allowed:
|
||||
return None
|
||||
return parsed_url.hostname or ""
|
||||
|
||||
@staticmethod
|
||||
def _log_private_range_allowed(
|
||||
url: str,
|
||||
match: tuple[List[ipaddress._BaseAddress], List[ipaddress._BaseNetwork]],
|
||||
) -> None:
|
||||
"""
|
||||
记录"图片代理允许访问配置的非公网网段"放行日志,便于运维排查。
|
||||
"""
|
||||
addresses, matched_networks = match
|
||||
logger.debug(
|
||||
"图片代理允许访问配置的非公网网段: "
|
||||
f"url={url}, ips={','.join(map(str, addresses))}, "
|
||||
f"ranges={','.join(map(str, matched_networks))}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_safe_url(
|
||||
url: str,
|
||||
@@ -298,7 +542,7 @@ class SecurityUtils:
|
||||
allowed_private_ranges: Optional[Iterable[str]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证URL是否在允许的域名列表中,包括带有端口的域名
|
||||
验证URL是否在允许的域名列表中,包括带有端口的域名(同步版本)
|
||||
|
||||
:param url: 需要验证的 URL
|
||||
:param allowed_domains: 允许的域名集合,域名可以包含端口
|
||||
@@ -306,57 +550,55 @@ class SecurityUtils:
|
||||
:param block_private: 是否拦截解析到非公网地址的 URL,防止 SSRF
|
||||
:param allowed_private_ranges: 域名命中后额外允许的非公网 IP/CIDR 网段
|
||||
:return: 如果URL合法且在允许的域名列表中,返回 True;否则返回 False
|
||||
|
||||
注意:`block_private=True` 时会同步调用 `getaddrinfo`;async 上下文请改用
|
||||
`is_safe_url_async`。
|
||||
"""
|
||||
try:
|
||||
# 解析URL
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
# 如果 URL 没有包含有效的 scheme,或者无法从中提取到有效的 netloc,则认为该 URL 是无效的
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
|
||||
if hostname is None:
|
||||
return False
|
||||
|
||||
# 仅允许 http 或 https 协议
|
||||
if parsed_url.scheme not in {"http", "https"}:
|
||||
return False
|
||||
|
||||
# 获取完整的 netloc(包括 IP 和端口)并转换为小写
|
||||
netloc = parsed_url.netloc.lower()
|
||||
if not netloc:
|
||||
return False
|
||||
|
||||
# 检查每个允许的域名
|
||||
allowed_domains = {d.lower() for d in allowed_domains}
|
||||
domain_allowed = False
|
||||
for domain in allowed_domains:
|
||||
parsed_allowed_url = urlparse(domain)
|
||||
allowed_netloc = parsed_allowed_url.netloc or parsed_allowed_url.path
|
||||
|
||||
if strict:
|
||||
# 严格模式下,要求完全匹配域名和端口
|
||||
if netloc == allowed_netloc:
|
||||
domain_allowed = True
|
||||
break
|
||||
else:
|
||||
# 非严格模式下,允许子域名匹配
|
||||
if netloc == allowed_netloc or netloc.endswith('.' + allowed_netloc):
|
||||
domain_allowed = True
|
||||
break
|
||||
|
||||
if not domain_allowed:
|
||||
return False
|
||||
|
||||
hostname = parsed_url.hostname or ""
|
||||
if block_private and not SecurityUtils._is_global_hostname(hostname):
|
||||
private_match = SecurityUtils._is_allowed_private_hostname(
|
||||
hostname, allowed_private_ranges
|
||||
)
|
||||
if private_match:
|
||||
addresses, matched_networks = private_match
|
||||
logger.debug(
|
||||
"图片代理允许访问配置的非公网网段: "
|
||||
f"url={url}, ips={','.join(map(str, addresses))}, "
|
||||
f"ranges={','.join(map(str, matched_networks))}"
|
||||
)
|
||||
SecurityUtils._log_private_range_allowed(url, private_match)
|
||||
return True
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Error occurred while validating URL: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def is_safe_url_async(
|
||||
url: str,
|
||||
allowed_domains: Union[Set[str], List[str]],
|
||||
strict: bool = False,
|
||||
block_private: bool = False,
|
||||
allowed_private_ranges: Optional[Iterable[str]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
`is_safe_url` 的异步版本,参数与返回值含义不变。
|
||||
|
||||
DNS 解析通过事件循环线程池执行,并复用 TTL 缓存。
|
||||
"""
|
||||
try:
|
||||
hostname = SecurityUtils._check_url_allowlist(url, allowed_domains, strict)
|
||||
if hostname is None:
|
||||
return False
|
||||
|
||||
if block_private and not await SecurityUtils._is_global_hostname_async(
|
||||
hostname
|
||||
):
|
||||
private_match = await SecurityUtils._is_allowed_private_hostname_async(
|
||||
hostname, allowed_private_ranges
|
||||
)
|
||||
if private_match:
|
||||
SecurityUtils._log_private_range_allowed(url, private_match)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -2,10 +2,23 @@ import socket
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.security import (
|
||||
SecurityUtils,
|
||||
_dns_inflight_locks,
|
||||
_dns_negative_cache,
|
||||
_dns_positive_cache,
|
||||
)
|
||||
|
||||
|
||||
class SecurityUtilsTest(TestCase):
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
每个用例前清空 DNS TTL 缓存与 in-flight 锁,避免跨用例状态污染。
|
||||
"""
|
||||
_dns_positive_cache.clear()
|
||||
_dns_negative_cache.clear()
|
||||
_dns_inflight_locks.clear()
|
||||
|
||||
def test_signed_url_roundtrip_returns_clean_url(self):
|
||||
"""
|
||||
URL 签名验证成功后返回不含签名片段的真实请求地址。
|
||||
@@ -272,3 +285,268 @@ class SecurityUtilsTest(TestCase):
|
||||
allowed_private_ranges=["198.18.0.0/15"],
|
||||
)
|
||||
)
|
||||
|
||||
def test_is_safe_url_async_uses_event_loop_resolver(self):
|
||||
"""
|
||||
异步版本通过事件循环的非阻塞 getaddrinfo 完成 SSRF 校验,
|
||||
且语义与同步版本保持一致:解析到非公网地址时仍然拒绝。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
async def fake_getaddrinfo(host, *_args, **_kwargs):
|
||||
self.assertEqual(host, "internal.example.com")
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
"",
|
||||
("127.0.0.1", 0),
|
||||
)
|
||||
]
|
||||
|
||||
async def run() -> bool:
|
||||
loop = asyncio.get_running_loop()
|
||||
with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
|
||||
return await SecurityUtils.is_safe_url_async(
|
||||
"http://internal.example.com/secret.png",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
|
||||
self.assertFalse(asyncio.run(run()))
|
||||
|
||||
def test_is_safe_url_async_hits_dns_cache(self):
|
||||
"""
|
||||
异步与同步版本共享 DNS TTL 缓存:同步预热后,异步版本不应再发起 DNS 查询。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 先用同步路径预热缓存
|
||||
with patch(
|
||||
"app.utils.security.socket.getaddrinfo",
|
||||
return_value=[
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
"",
|
||||
("93.184.216.34", 0),
|
||||
)
|
||||
],
|
||||
):
|
||||
self.assertTrue(
|
||||
SecurityUtils.is_safe_url(
|
||||
"https://assets.example.com/poster.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
)
|
||||
|
||||
async def run() -> bool:
|
||||
loop = asyncio.get_running_loop()
|
||||
with patch.object(
|
||||
loop,
|
||||
"getaddrinfo",
|
||||
side_effect=AssertionError("缓存命中后不应再次发起 DNS 查询"),
|
||||
):
|
||||
return await SecurityUtils.is_safe_url_async(
|
||||
"https://assets.example.com/poster.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
|
||||
self.assertTrue(asyncio.run(run()))
|
||||
|
||||
def test_is_safe_url_async_allows_public_dns_result(self):
|
||||
"""
|
||||
异步版本对全公网解析结果且命中 allowlist 时放行。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
async def fake_getaddrinfo(host, *_args, **_kwargs):
|
||||
self.assertEqual(host, "assets.example.com")
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
"",
|
||||
("93.184.216.34", 0),
|
||||
)
|
||||
]
|
||||
|
||||
async def run() -> bool:
|
||||
loop = asyncio.get_running_loop()
|
||||
with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
|
||||
return await SecurityUtils.is_safe_url_async(
|
||||
"https://assets.example.com/poster.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
|
||||
self.assertTrue(asyncio.run(run()))
|
||||
|
||||
def test_dns_resolution_failure_populates_negative_cache(self):
|
||||
"""
|
||||
DNS 解析失败应回填负向缓存,避免短期内对同一目标反复触发 `getaddrinfo`。
|
||||
"""
|
||||
from app.utils.security import _dns_negative_cache as neg_cache
|
||||
|
||||
with patch(
|
||||
"app.utils.security.socket.getaddrinfo",
|
||||
side_effect=socket.gaierror,
|
||||
) as mock_resolve:
|
||||
self.assertFalse(
|
||||
SecurityUtils.is_safe_url(
|
||||
"https://assets.example.com/poster.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(mock_resolve.call_count, 1)
|
||||
self.assertIn("assets.example.com", neg_cache)
|
||||
|
||||
self.assertFalse(
|
||||
SecurityUtils.is_safe_url(
|
||||
"https://assets.example.com/another.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
mock_resolve.call_count,
|
||||
1,
|
||||
"命中负向缓存后不应再次调用 getaddrinfo",
|
||||
)
|
||||
|
||||
def test_literal_ip_skips_dns_cache(self):
|
||||
"""
|
||||
URL 中的字面量 IP 走快路径,不应进入 DNS 缓存或触发 `getaddrinfo`。
|
||||
"""
|
||||
from app.utils.security import (
|
||||
_dns_negative_cache as neg_cache,
|
||||
_dns_positive_cache as pos_cache,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.utils.security.socket.getaddrinfo",
|
||||
side_effect=AssertionError("字面量 IP 不应触发 getaddrinfo"),
|
||||
):
|
||||
self.assertFalse(
|
||||
SecurityUtils.is_safe_url(
|
||||
"http://10.0.0.5:8080/secret.png",
|
||||
{"http://10.0.0.5:8080"},
|
||||
block_private=True,
|
||||
)
|
||||
)
|
||||
self.assertNotIn("10.0.0.5", pos_cache)
|
||||
self.assertNotIn("10.0.0.5", neg_cache)
|
||||
|
||||
def test_literal_ipv6_in_brackets_is_recognized(self):
|
||||
"""
|
||||
`urlparse` 已为 IPv6 字面量脱壳,`_literal_ip` 兼容直接传入带方括号的形式。
|
||||
"""
|
||||
self.assertEqual(
|
||||
str(SecurityUtils._literal_ip("[::1]")),
|
||||
"::1",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(SecurityUtils._literal_ip("::1")),
|
||||
"::1",
|
||||
)
|
||||
self.assertIsNone(SecurityUtils._literal_ip("not-an-ip"))
|
||||
|
||||
def test_is_safe_url_async_dedupes_concurrent_inflight_queries(self):
|
||||
"""
|
||||
同 hostname 的并发未命中请求应通过 in-flight 锁去重,只触发一次 DNS 查询。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def run() -> None:
|
||||
nonlocal call_count
|
||||
loop = asyncio.get_running_loop()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_getaddrinfo(host, *_args, **_kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await release.wait()
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
"",
|
||||
("93.184.216.34", 0),
|
||||
)
|
||||
]
|
||||
|
||||
with patch.object(loop, "getaddrinfo", side_effect=slow_getaddrinfo):
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
SecurityUtils.is_safe_url_async(
|
||||
"https://assets.example.com/poster.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
# 让所有任务都进入 in-flight 等待状态
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
release.set()
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
self.assertTrue(all(results))
|
||||
self.assertEqual(call_count, 1, "并发未命中应去重为单次 DNS 查询")
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_sync_cache_access_is_thread_safe(self):
|
||||
"""
|
||||
同步路径下并发线程访问 DNS 缓存不应触发异常或拿到不一致结果。
|
||||
TTLCache 自身非线程安全,依赖模块级 `_dns_cache_lock` 串行化读写。
|
||||
"""
|
||||
import threading
|
||||
|
||||
with patch(
|
||||
"app.utils.security.socket.getaddrinfo",
|
||||
return_value=[
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
"",
|
||||
("93.184.216.34", 0),
|
||||
)
|
||||
],
|
||||
):
|
||||
results: list = []
|
||||
errors: list = []
|
||||
|
||||
def worker() -> None:
|
||||
try:
|
||||
for _ in range(50):
|
||||
results.append(
|
||||
SecurityUtils.is_safe_url(
|
||||
"https://assets.example.com/poster.jpg",
|
||||
{"example.com"},
|
||||
block_private=True,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - 用例需捕获任意异常
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=worker) for _ in range(8)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
self.assertEqual(errors, [])
|
||||
self.assertTrue(all(results))
|
||||
self.assertEqual(len(results), 8 * 50)
|
||||
|
||||
Reference in New Issue
Block a user