fix(security): release SSRF DNS inflight lock outside async with block (#5834)

This commit is contained in:
InfinityPacer
2026-05-25 16:45:32 +08:00
committed by GitHub
parent d940373f6b
commit d57deb1df1
2 changed files with 125 additions and 20 deletions

View File

@@ -210,7 +210,11 @@ class SecurityUtils:
def _release_inflight_lock(hostname: str, lock: asyncio.Lock) -> None:
"""
请求结束后清理 in-flight 锁,避免长期持有大量已闲置的 `asyncio.Lock`。
只有当前 lock 仍是字典里登记的那把、且没有其它协程在等待时才删除。
仅当字典中登记的仍是当前 lock且 `lock.locked()` 为 False 时才删除。
`asyncio.Lock` 公平 FIFO持有者释放后若仍有等待者锁会立刻被下一个
等待者接走、`locked()` 重新变为 True因此该守卫可同时排除"仍有持有者"
"刚被等待者接走"两种情况,避免误删后续协程仍在使用的字典条目。
"""
with _dns_inflight_meta_lock:
current = _dns_inflight_locks.get(hostname)
@@ -238,26 +242,28 @@ class SecurityUtils:
return value
lock = SecurityUtils._get_inflight_lock(hostname)
async with lock:
# 等到锁后再查一次缓存,前一个持锁者可能已经回填结果
hit, value = SecurityUtils._cache_lookup(hostname)
if hit:
SecurityUtils._release_inflight_lock(hostname, lock)
return value
try:
async with lock:
# 等到锁后再查一次缓存,前一个持锁者可能已经回填结果
hit, value = SecurityUtils._cache_lookup(hostname)
if hit:
return value
loop = asyncio.get_running_loop()
try:
address_infos = await loop.getaddrinfo(
hostname, None, type=socket.SOCK_STREAM
)
except socket.gaierror:
SecurityUtils._cache_store(hostname, None)
SecurityUtils._release_inflight_lock(hostname, lock)
return None
addresses = _resolve_addrinfo_to_ips(address_infos)
SecurityUtils._cache_store(hostname, addresses)
SecurityUtils._release_inflight_lock(hostname, lock)
return addresses
loop = asyncio.get_running_loop()
try:
address_infos = await loop.getaddrinfo(
hostname, None, type=socket.SOCK_STREAM
)
except socket.gaierror:
SecurityUtils._cache_store(hostname, None)
return None
addresses = _resolve_addrinfo_to_ips(address_infos)
SecurityUtils._cache_store(hostname, addresses)
return addresses
finally:
# 必须在 `async with` 释放锁之后再清理字典:`_release_inflight_lock`
# 以 `not lock.locked()` 为清理守卫,持锁状态下调用会跳过 pop。
SecurityUtils._release_inflight_lock(hostname, lock)
@staticmethod
def _addresses_all_global(

View File

@@ -550,3 +550,102 @@ class SecurityUtilsTest(TestCase):
self.assertEqual(errors, [])
self.assertTrue(all(results))
self.assertEqual(len(results), 8 * 50)
def test_async_dns_resolution_failure_releases_inflight_lock(self):
"""
DNS 解析失败后 in-flight 锁字典中必须被清理,避免每个解析失败的 hostname
都在 `_dns_inflight_locks` 里残留一把 `asyncio.Lock`。
"""
import asyncio
async def fail_getaddrinfo(*_args, **_kwargs):
raise socket.gaierror()
async def run() -> None:
loop = asyncio.get_running_loop()
with patch.object(loop, "getaddrinfo", side_effect=fail_getaddrinfo):
result = await SecurityUtils._hostname_addresses_async(
"bad-host.example"
)
self.assertIsNone(result)
asyncio.run(run())
self.assertNotIn(
"bad-host.example",
_dns_inflight_locks,
"解析失败路径必须释放 in-flight 锁字典条目",
)
def test_async_dns_resolution_success_releases_inflight_lock(self):
"""
正常解析完成后 in-flight 锁字典也必须被清理,避免 hostname 累积。
"""
import asyncio
async def fake_getaddrinfo(*_args, **_kwargs):
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("93.184.216.34", 0),
)
]
async def run() -> None:
loop = asyncio.get_running_loop()
with patch.object(loop, "getaddrinfo", side_effect=fake_getaddrinfo):
result = await SecurityUtils._hostname_addresses_async(
"ok-host.example"
)
self.assertIsNotNone(result)
asyncio.run(run())
self.assertNotIn(
"ok-host.example",
_dns_inflight_locks,
"正常解析路径必须释放 in-flight 锁字典条目",
)
def test_async_dns_concurrent_waiters_release_inflight_lock(self):
"""
并发未命中场景下,所有等待者完成后 in-flight 锁字典也必须被清理,
覆盖"等到锁但缓存已被前一个协程回填"的二次返回路径。
"""
import asyncio
async def run() -> None:
loop = asyncio.get_running_loop()
release = asyncio.Event()
async def slow_getaddrinfo(*_args, **_kwargs):
await release.wait()
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
0,
"",
("93.184.216.34", 0),
)
]
with patch.object(loop, "getaddrinfo", side_effect=slow_getaddrinfo):
tasks = [
asyncio.create_task(
SecurityUtils._hostname_addresses_async("multi-host.example")
)
for _ in range(5)
]
await asyncio.sleep(0)
await asyncio.sleep(0)
release.set()
await asyncio.gather(*tasks)
asyncio.run(run())
self.assertNotIn(
"multi-host.example",
_dns_inflight_locks,
"并发等待者全部退出后必须释放 in-flight 锁字典条目",
)