fix: bound long-lived cache state

This commit is contained in:
jxxghp
2026-05-24 18:03:42 +08:00
parent dc73d61682
commit 79539760da
15 changed files with 380 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
import asyncio
import unittest
from datetime import datetime
from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import patch
@@ -8,11 +8,20 @@ from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
from app.agent.middleware.usage import UsageMiddleware
from app.agent import AgentManager
from app.chain.message import MessageChain
from app.schemas.types import MessageChannel
class TestAgentSessionStatus(unittest.TestCase):
def setUp(self):
"""清理跨用例共享的用户会话状态。"""
MessageChain._user_sessions.clear()
def tearDown(self):
"""清理测试产生的用户会话状态。"""
MessageChain._user_sessions.clear()
def test_usage_middleware_records_usage_metadata(self):
snapshots = []
middleware = UsageMiddleware(on_usage=snapshots.append)
@@ -104,3 +113,34 @@ class TestAgentSessionStatus(unittest.TestCase):
notification = post_message.call_args.args[0]
self.assertEqual(notification.title, "您当前没有活跃的智能体会话")
def test_get_or_create_session_cleans_expired_session(self):
"""用户会话超过复用窗口时应调度清理旧 Agent 会话。"""
chain = MessageChain()
chain._user_sessions.clear()
chain._user_sessions["10001"] = (
"old-session",
datetime.now() - timedelta(minutes=chain._session_timeout_minutes + 1),
)
with patch.object(chain, "_schedule_agent_session_clear") as clear_session:
session_id = chain._get_or_create_session_id("10001")
self.assertNotEqual(session_id, "old-session")
self.assertEqual(chain._user_sessions["10001"][0], session_id)
clear_session.assert_called_once_with("old-session", "10001")
def test_agent_manager_collects_idle_sessions(self):
"""Agent 管理器应只回收超过空闲窗口且未忙碌的会话。"""
manager = AgentManager()
manager._idle_session_ttl = timedelta(seconds=1)
manager._session_last_used["idle-session"] = (
"10001",
datetime.now() - timedelta(seconds=2),
)
manager._session_last_used["fresh-session"] = ("10002", datetime.now())
self.assertEqual(
[("idle-session", "10001")],
manager._expired_idle_sessions(),
)

View File

@@ -39,11 +39,23 @@ _stub_module(
TEMP_PATH="/tmp",
PROXY_HOST=None,
LLM_MAX_CONTEXT_TOKENS=64,
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME=True,
RMT_MEDIAEXT=[".mkv", ".mp4"],
RMT_SUBEXT=[".srt"],
RMT_AUDIOEXT=[".flac"],
),
)
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_DummySystemConfigOper)
_stub_module("app.log", logger=_DummyLogger())
_stub_module("app.schemas.types", SystemConfigKey=SimpleNamespace(AIAgentConfig="agent"))
_stub_module(
"app.schemas.types",
SystemConfigKey=SimpleNamespace(
AIAgentConfig="agent",
CustomReleaseGroups="custom_release_groups",
Customization="customization",
CustomIdentifiers="custom_identifiers",
),
)
provider_path = Path(__file__).resolve().parents[1] / "app" / "agent" / "llm" / "provider.py"
spec = importlib.util.spec_from_file_location("test_llm_provider_module", provider_path)
@@ -54,6 +66,7 @@ spec.loader.exec_module(provider_module)
LLMProviderError = provider_module.LLMProviderError
LLMProviderManager = provider_module.LLMProviderManager
PendingAuthSession = provider_module.PendingAuthSession
class LlmProviderRegistryTest(unittest.TestCase):
@@ -612,6 +625,24 @@ class LlmProviderRegistryTest(unittest.TestCase):
self.assertEqual(models, [])
def test_expired_auth_session_cleanup_removes_state_index(self):
"""过期授权会话应同时移除 session 与 OAuth state 索引。"""
manager = LLMProviderManager()
manager._pending_sessions["session-old"] = PendingAuthSession(
session_id="session-old",
provider_id="chatgpt",
method_id="browser_oauth",
flow_type="oauth",
expires_at=100,
)
manager._oauth_state_index["state-old"] = "session-old"
with manager._lock:
manager._cleanup_auth_sessions_locked(now=101)
self.assertNotIn("session-old", manager._pending_sessions)
self.assertNotIn("state-old", manager._oauth_state_index)
if __name__ == "__main__":
unittest.main()

View File

@@ -8,7 +8,7 @@ sys.modules['app.db.systemconfig_oper'] = MagicMock()
sys.modules['app.db.systemconfig_oper'].SystemConfigOper.return_value.get.return_value = None
from app import schemas
from app.chain.media import MediaChain, ScrapingOption
from app.chain.media import MediaChain, ScrapingConfig, ScrapingOption
from app.core.context import MediaInfo
from app.core.event import Event
from app.core.metainfo import MetaInfo
@@ -42,6 +42,20 @@ class TestMediaScrapingPaths(unittest.TestCase):
self.assertEqual(target_item, parent_item)
self.assertEqual(target_path, Path("/movies/avatar.nfo"))
def test_scraping_config_does_not_share_policy_state_between_instances(self):
"""刮削配置实例之间不应共享已删除或覆盖过的策略。"""
first_config = ScrapingConfig({"movie_nfo": ScrapingPolicy.SKIP})
second_config = ScrapingConfig({})
self.assertEqual(
ScrapingPolicy.SKIP,
first_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy,
)
self.assertEqual(
ScrapingPolicy.MISSINGONLY,
second_config.option(ScrapingTarget.MOVIE, ScrapingMetadata.NFO).policy,
)
def test_movie_dir_nfo_path(self):
fileitem = schemas.FileItem(path="/movies/Avatar (2009)", name="Avatar (2009)", type="dir", storage="local")

View File

@@ -200,6 +200,19 @@ class RcloneStorageTest(unittest.TestCase):
self.assertEqual("/Show/", folder.path)
run_mock.assert_called_once()
def test_folder_lock_table_evicts_old_unlocked_paths(self):
"""路径锁表超过上限时应优先淘汰未占用的旧锁。"""
with patch.object(rclone_module, "_MAX_FOLDER_LOCKS", 2):
first_lock = Rclone._Rclone__get_path_lock(Path("/A"))
second_lock = Rclone._Rclone__get_path_lock(Path("/B"))
third_lock = Rclone._Rclone__get_path_lock(Path("/C"))
self.assertNotIn("/A", rclone_module._folder_locks)
self.assertIn("/B", rclone_module._folder_locks)
self.assertIn("/C", rclone_module._folder_locks)
self.assertIsNot(first_lock, third_lock)
self.assertIs(second_lock, rclone_module._folder_locks["/B"])
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,61 @@
import unittest
from types import SimpleNamespace
from app.core.config import global_vars
from app.helper.webpush import is_webpush_subscription_gone
class WebPushSubscriptionTest(unittest.TestCase):
def setUp(self):
"""清理跨用例共享的 WebPush 订阅。"""
with global_vars.SUBSCRIPTIONS_LOCK:
global_vars.SUBSCRIPTIONS.clear()
def tearDown(self):
"""清理测试产生的 WebPush 订阅。"""
with global_vars.SUBSCRIPTIONS_LOCK:
global_vars.SUBSCRIPTIONS.clear()
def test_push_subscription_upserts_by_endpoint(self):
"""相同 endpoint 的 WebPush 订阅应更新而不是重复追加。"""
global_vars.push_subscription(
{"endpoint": "https://push.example/a", "keys": {"p256dh": "old"}}
)
global_vars.push_subscription(
{"endpoint": "https://push.example/a", "keys": {"p256dh": "new"}}
)
subscriptions = global_vars.get_subscriptions()
self.assertEqual(1, len(subscriptions))
self.assertEqual("new", subscriptions[0]["keys"]["p256dh"])
def test_remove_subscription_deletes_by_endpoint(self):
"""失效订阅应能按 endpoint 从全局订阅表删除。"""
subscription = {"endpoint": "https://push.example/a", "keys": {}}
global_vars.push_subscription(subscription)
self.assertTrue(global_vars.remove_subscription(subscription))
self.assertEqual([], global_vars.get_subscriptions())
def test_is_webpush_subscription_gone_matches_404_and_410(self):
"""推送服务返回 404/410 时应识别为订阅已失效。"""
self.assertTrue(
is_webpush_subscription_gone(
SimpleNamespace(response=SimpleNamespace(status_code=410))
)
)
self.assertTrue(
is_webpush_subscription_gone(
SimpleNamespace(response=SimpleNamespace(status=404))
)
)
self.assertFalse(
is_webpush_subscription_gone(
SimpleNamespace(response=SimpleNamespace(status_code=500))
)
)
if __name__ == "__main__":
unittest.main()