diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 77d7af3d..3ce1b3f8 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -360,7 +360,7 @@ async def fetch_image( allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) # 验证URL安全性 - if not SecurityUtils.is_safe_url(url, allowed_domains): + if not SecurityUtils.is_safe_url(url, allowed_domains, block_private=True): logger.warn(f"Blocked unsafe image URL: {url}") return None diff --git a/app/utils/security.py b/app/utils/security.py index 025fc5f3..e41a3f2f 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -1,3 +1,5 @@ +import ipaddress +import socket from hashlib import sha256 from pathlib import Path from typing import List, Optional, Set, Union @@ -73,13 +75,52 @@ class SecurityUtils: return False @staticmethod - def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = False) -> bool: + def _is_global_hostname(hostname: str) -> bool: + """ + 判断主机名解析结果是否全部为公网地址。 + + 图片代理会访问用户可控的 URL,这里必须在 allowlist 命中前后都排除 + 私有、回环、链路本地、保留地址等非公网目标,避免通过 DNS 或字面量 IP + 绕过域名白名单访问内网服务。 + """ + if not hostname: + return False + try: + return ipaddress.ip_address(hostname).is_global + except ValueError: + pass + + try: + address_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM) + except socket.gaierror: + return False + + if not address_infos: + return False + + for address_info in address_infos: + try: + address = ipaddress.ip_address(address_info[4][0]) + except ValueError: + return False + if not address.is_global: + return False + return True + + @staticmethod + def is_safe_url( + url: str, + allowed_domains: Union[Set[str], List[str]], + strict: bool = False, + block_private: bool = False, + ) -> bool: """ 验证URL是否在允许的域名列表中,包括带有端口的域名 :param url: 需要验证的 URL :param allowed_domains: 允许的域名集合,域名可以包含端口 :param strict: 是否严格匹配一级域名(默认为 False,允许多级域名) + :param block_private: 是否拦截解析到非公网地址的 URL,防止 SSRF :return: 如果URL合法且在允许的域名列表中,返回 True;否则返回 False """ try: @@ -99,6 +140,9 @@ class SecurityUtils: if not netloc: return False + if block_private and not SecurityUtils._is_global_hostname(parsed_url.hostname or ""): + return False + # 检查每个允许的域名 allowed_domains = {d.lower() for d in allowed_domains} for domain in allowed_domains: diff --git a/tests/test_security_utils.py b/tests/test_security_utils.py new file mode 100644 index 00000000..61ae2d53 --- /dev/null +++ b/tests/test_security_utils.py @@ -0,0 +1,126 @@ +import socket +from unittest import TestCase +from unittest.mock import patch + +from app.utils.security import SecurityUtils + + +class SecurityUtilsTest(TestCase): + + def test_is_safe_url_keeps_default_allowlist_behavior(self): + """ + 默认 URL 校验保持历史 allowlist 行为,避免影响非代理调用方。 + """ + self.assertTrue( + SecurityUtils.is_safe_url( + "http://192.168.1.50:8096/secret.png", + {"http://192.168.1.50:8096"}, + ) + ) + + def test_is_safe_url_blocks_private_literal_ip_when_enabled(self): + """ + 启用 SSRF 防护时,即使内网 IP 命中 allowlist 也不能放行。 + """ + self.assertFalse( + SecurityUtils.is_safe_url( + "http://192.168.1.50:8096/secret.png", + {"http://192.168.1.50:8096"}, + block_private=True, + ) + ) + + def test_is_safe_url_blocks_loopback_dns_result_when_enabled(self): + """ + 主机名解析到回环地址时必须拒绝,防止通过域名绕过内网地址拦截。 + """ + with patch( + "app.utils.security.socket.getaddrinfo", + return_value=[ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("127.0.0.1", 0), + ) + ], + ): + self.assertFalse( + SecurityUtils.is_safe_url( + "http://internal.example.com/secret.png", + {"example.com"}, + block_private=True, + ) + ) + + def test_is_safe_url_blocks_mixed_public_and_private_dns_results(self): + """ + 同一域名只要存在任一非公网解析结果,就不能作为图片代理目标。 + """ + with patch( + "app.utils.security.socket.getaddrinfo", + return_value=[ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("93.184.216.34", 0), + ), + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("10.0.0.8", 0), + ), + ], + ): + self.assertFalse( + SecurityUtils.is_safe_url( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) + + def test_is_safe_url_allows_public_dns_result_when_enabled(self): + """ + 域名解析结果全部为公网地址且命中 allowlist 时继续允许访问。 + """ + with patch( + "app.utils.security.socket.getaddrinfo", + return_value=[ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 0, + "", + ("93.184.216.34", 0), + ) + ], + ): + self.assertTrue( + SecurityUtils.is_safe_url( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) + + def test_is_safe_url_rejects_dns_resolution_failure_when_enabled(self): + """ + SSRF 防护无法确认目标地址时按失败处理,避免解析异常时继续请求。 + """ + with patch( + "app.utils.security.socket.getaddrinfo", + side_effect=socket.gaierror, + ): + self.assertFalse( + SecurityUtils.is_safe_url( + "https://assets.example.com/poster.jpg", + {"example.com"}, + block_private=True, + ) + ) diff --git a/tests/test_system_nettest.py b/tests/test_system_nettest.py index 3d62b425..b9dd01fd 100644 --- a/tests/test_system_nettest.py +++ b/tests/test_system_nettest.py @@ -33,6 +33,7 @@ for _module_name in ("pillow_avif", "aiofiles", "psutil"): _stub_module(_module_name) _stub_module("app.helper.sites", SitesHelper=_Dummy) +_stub_module("app.chain.media", MediaChain=_Dummy) _stub_module("app.chain.mediaserver", MediaServerChain=_Dummy) _stub_module("app.chain.search", SearchChain=_Dummy) _stub_module("app.chain.system", SystemChain=_Dummy) @@ -81,6 +82,24 @@ from app.api.endpoints import system as system_endpoint class NettestSecurityTest(unittest.TestCase): + def test_fetch_image_blocks_private_allowed_url_before_request(self): + """ + 图片代理即使拿到内网 allowlist 项,也必须在发起请求前拦截。 + """ + class FailIfCalled: + def __init__(self, *args, **kwargs): + raise AssertionError("fetch_image should block private URLs before fetching") + + with patch.object(system_endpoint, "ImageHelper", FailIfCalled): + resp = asyncio.run( + system_endpoint.fetch_image( + url="http://127.0.0.1:8096/secret.png", + allowed_domains={"http://127.0.0.1:8096"}, + ) + ) + + self.assertIsNone(resp) + def test_nettest_targets_are_served_by_backend(self): resp = asyncio.run(system_endpoint.nettest_targets(_="token"))