mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-07 16:50:18 +08:00
fix: bound long-lived cache state
This commit is contained in:
@@ -4,7 +4,7 @@ import re
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
@@ -966,6 +966,11 @@ class AgentManager:
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
# 每个会话最后活动时间,用于回收空闲 Agent 实例
|
||||
self._session_last_used: Dict[str, tuple[str, datetime]] = {}
|
||||
self._idle_cleanup_task: Optional[asyncio.Task] = None
|
||||
self._idle_session_ttl = timedelta(hours=24)
|
||||
self._idle_cleanup_interval = 60 * 60
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
@@ -998,33 +1003,85 @@ class AgentManager:
|
||||
)
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
memory_manager.initialize()
|
||||
if self._idle_cleanup_task and not self._idle_cleanup_task.done():
|
||||
return
|
||||
self._idle_cleanup_task = asyncio.create_task(self._cleanup_idle_sessions())
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
if self._idle_cleanup_task:
|
||||
self._idle_cleanup_task.cancel()
|
||||
try:
|
||||
await self._idle_cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._idle_cleanup_task = None
|
||||
await memory_manager.close()
|
||||
# 取消所有会话worker
|
||||
for task in self._session_workers.values():
|
||||
for task in list(self._session_workers.values()):
|
||||
task.cancel()
|
||||
# 等待所有worker结束
|
||||
for session_id, task in self._session_workers.items():
|
||||
for session_id, task in list(self._session_workers.items()):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._session_workers.clear()
|
||||
self._session_queues.clear()
|
||||
for agent in self.active_agents.values():
|
||||
self._session_last_used.clear()
|
||||
for agent in list(self.active_agents.values()):
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
def _record_session_activity(self, session_id: str, user_id: str) -> None:
|
||||
"""
|
||||
记录会话最近活动时间,供空闲会话清理任务判断是否可释放资源。
|
||||
"""
|
||||
self._session_last_used[session_id] = (user_id, datetime.now())
|
||||
|
||||
def _is_session_busy(self, session_id: str) -> bool:
|
||||
"""
|
||||
判断会话是否仍有正在执行的 worker 或待处理消息,避免误清理活跃会话。
|
||||
"""
|
||||
worker = self._session_workers.get(session_id)
|
||||
if worker and not worker.done():
|
||||
return True
|
||||
queue = self._session_queues.get(session_id)
|
||||
return bool(queue and not queue.empty())
|
||||
|
||||
def _expired_idle_sessions(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
收集已经超过空闲时间且当前不忙的会话。
|
||||
"""
|
||||
expire_before = datetime.now() - self._idle_session_ttl
|
||||
expired = []
|
||||
for session_id, (user_id, last_used) in list(self._session_last_used.items()):
|
||||
if last_used < expire_before and not self._is_session_busy(session_id):
|
||||
expired.append((session_id, user_id))
|
||||
return expired
|
||||
|
||||
async def _cleanup_idle_sessions(self) -> None:
|
||||
"""
|
||||
周期性清理长时间没有新消息的 Agent 会话,避免长期运行后实例持续累积。
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._idle_cleanup_interval)
|
||||
for session_id, user_id in self._expired_idle_sessions():
|
||||
await self.clear_session(session_id=session_id, user_id=user_id)
|
||||
logger.info(f"已清理空闲Agent会话: session_id={session_id}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理空闲Agent会话失败: {e}")
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -1056,6 +1113,7 @@ class AgentManager:
|
||||
original_chat_id=original_chat_id,
|
||||
reply_mode=reply_mode,
|
||||
)
|
||||
self._record_session_activity(session_id, user_id)
|
||||
|
||||
# 获取或创建会话队列
|
||||
if session_id not in self._session_queues:
|
||||
@@ -1221,6 +1279,7 @@ class AgentManager:
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
self._session_last_used.pop(session_id, None)
|
||||
# 取消该会话的worker
|
||||
if session_id in self._session_workers:
|
||||
self._session_workers[session_id].cancel()
|
||||
@@ -1228,7 +1287,7 @@ class AgentManager:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self._session_workers.pop(session_id, None)
|
||||
self._session_workers.pop(session_id, None)
|
||||
|
||||
# 清理队列
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
@@ -105,6 +105,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
_MODELS_DEV_URL = "https://models.dev/api.json"
|
||||
_MODELS_DEV_BUNDLED_PATH = Path(__file__).with_name("models.json")
|
||||
_MODELS_DEV_CACHE_TTL = 7 * 24 * 60 * 60
|
||||
_AUTH_SESSION_DONE_RETENTION = 300
|
||||
_CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
_CHATGPT_ISSUER = "https://auth.openai.com"
|
||||
_CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
@@ -183,6 +184,33 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json"
|
||||
)
|
||||
|
||||
def _cleanup_auth_sessions_locked(self, now: Optional[float] = None) -> None:
|
||||
"""
|
||||
清理过期或已完成一段时间的临时授权会话。
|
||||
|
||||
调用方必须已经持有 `_lock`,这样 `_pending_sessions` 与
|
||||
`_oauth_state_index` 能保持一致,避免 state 残留。
|
||||
"""
|
||||
now = time.time() if now is None else now
|
||||
expired_session_ids = []
|
||||
for session_id, session in self._pending_sessions.items():
|
||||
expires_at = session.expires_at or session.created_at + 600
|
||||
if session.status == "pending":
|
||||
if expires_at <= now:
|
||||
expired_session_ids.append(session_id)
|
||||
elif expires_at + self._AUTH_SESSION_DONE_RETENTION <= now:
|
||||
expired_session_ids.append(session_id)
|
||||
|
||||
if not expired_session_ids:
|
||||
return
|
||||
|
||||
expired_session_ids_set = set(expired_session_ids)
|
||||
for session_id in expired_session_ids:
|
||||
self._pending_sessions.pop(session_id, None)
|
||||
for state, session_id in list(self._oauth_state_index.items()):
|
||||
if session_id in expired_session_ids_set:
|
||||
self._oauth_state_index.pop(state, None)
|
||||
|
||||
@staticmethod
|
||||
def _builtin_provider_specs() -> tuple[ProviderSpec, ...]:
|
||||
"""
|
||||
@@ -2001,6 +2029,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
self._oauth_state_index[state] = session.session_id
|
||||
return {
|
||||
@@ -2035,6 +2064,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
@@ -2073,6 +2103,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
self._pending_sessions[session.session_id] = session
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
@@ -2089,6 +2120,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""读取临时授权会话状态。"""
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session = self._pending_sessions.get(session_id)
|
||||
if not session:
|
||||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||||
@@ -2135,6 +2167,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
if error:
|
||||
message = error_description or error
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session_id = self._oauth_state_index.pop(state or "", None)
|
||||
if session_id and session_id in self._pending_sessions:
|
||||
self._mark_session_error(self._pending_sessions[session_id], message)
|
||||
@@ -2144,6 +2177,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return False, "缺少授权码或 state 参数"
|
||||
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session_id = self._oauth_state_index.pop(state, None)
|
||||
session = self._pending_sessions.get(session_id or "")
|
||||
|
||||
@@ -2186,6 +2220,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_auth_sessions_locked()
|
||||
session = self._pending_sessions.get(session_id)
|
||||
if not session:
|
||||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||||
|
||||
@@ -27,6 +27,8 @@ class MemoryManager:
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
if self.cleanup_task and not self.cleanup_task.done():
|
||||
return
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_memories()
|
||||
@@ -46,6 +48,7 @@ class MemoryManager:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.cleanup_task = None
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.db.models import User
|
||||
from app.db.models.message import Message
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.helper.webpush import is_webpush_subscription_gone
|
||||
from app.log import logger
|
||||
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from app.schemas.types import MessageChannel
|
||||
@@ -218,8 +219,7 @@ async def subscribe(
|
||||
客户端webpush通知订阅
|
||||
"""
|
||||
subinfo = subscription.model_dump()
|
||||
if subinfo not in global_vars.get_subscriptions():
|
||||
global_vars.push_subscription(subinfo)
|
||||
global_vars.push_subscription(subinfo)
|
||||
logger.debug(f"通知订阅成功: {subinfo}")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -244,5 +244,7 @@ def send_notification(
|
||||
)
|
||||
except WebPushException as err:
|
||||
logger.error(f"WebPush发送失败: {str(err)}")
|
||||
if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub):
|
||||
logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}")
|
||||
continue
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -84,13 +84,12 @@ class ScrapingOption:
|
||||
class ScrapingConfig:
|
||||
"""媒体刮削配置"""
|
||||
|
||||
_policies: dict[tuple[str], ScrapingOption] = {}
|
||||
|
||||
def __init__(self, config_dict: dict[str, str] = None):
|
||||
"""
|
||||
初始化配置对象
|
||||
:param config_dict: 用户配置字典(扁平化格式),为 None 时使用默认配置
|
||||
"""
|
||||
self._policies: dict[tuple[str, str], ScrapingOption] = {}
|
||||
# 合并用户配置和默认配置
|
||||
if config_dict is None:
|
||||
config_dict = {}
|
||||
|
||||
@@ -47,6 +47,36 @@ class MessageChain(ChainBase):
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 24 * 60
|
||||
|
||||
@staticmethod
|
||||
def _schedule_agent_session_clear(session_id: str, userid: Union[str, int]) -> None:
|
||||
"""
|
||||
异步调度 Agent 会话清理,避免同步消息链阻塞在模型资源释放上。
|
||||
"""
|
||||
if not session_id:
|
||||
return
|
||||
clear_task = None
|
||||
try:
|
||||
clear_task = agent_manager.clear_session(session_id=session_id, user_id=str(userid))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
clear_task,
|
||||
global_vars.loop,
|
||||
)
|
||||
except Exception as e:
|
||||
if clear_task:
|
||||
clear_task.close()
|
||||
logger.warning(f"调度清理智能体会话失败: {e}")
|
||||
|
||||
def _cleanup_expired_user_sessions(self, current_time: datetime) -> None:
|
||||
"""
|
||||
清理超过复用窗口的用户会话映射,并同步释放旧 Agent 实例。
|
||||
"""
|
||||
timeout = timedelta(minutes=self._session_timeout_minutes)
|
||||
for userid, (session_id, last_time) in list(self._user_sessions.items()):
|
||||
if current_time - last_time <= timeout:
|
||||
continue
|
||||
self._user_sessions.pop(userid, None)
|
||||
self._schedule_agent_session_clear(session_id, userid)
|
||||
|
||||
@dataclass
|
||||
class _ProcessingStatus:
|
||||
channel: MessageChannel
|
||||
@@ -919,6 +949,7 @@ class MessageChain(ChainBase):
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
self._cleanup_expired_user_sessions(current_time)
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in self._user_sessions:
|
||||
@@ -946,6 +977,9 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
将用户会话绑定到指定的 session_id,并刷新最后活动时间。
|
||||
"""
|
||||
old_session = self._user_sessions.get(userid)
|
||||
if old_session and old_session[0] != session_id:
|
||||
self._schedule_agent_session_clear(old_session[0], userid)
|
||||
self._user_sessions[userid] = (session_id, datetime.now())
|
||||
|
||||
def _record_user_message(
|
||||
@@ -1005,14 +1039,18 @@ class MessageChain(ChainBase):
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
if session_id:
|
||||
clear_task = None
|
||||
try:
|
||||
clear_task = agent_manager.clear_session(
|
||||
session_id=session_id, user_id=str(userid)
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id, user_id=str(userid)
|
||||
),
|
||||
clear_task,
|
||||
global_vars.loop,
|
||||
)
|
||||
except Exception as e:
|
||||
if clear_task:
|
||||
clear_task.close()
|
||||
logger.warning(f"清除智能体会话记忆失败: {e}")
|
||||
|
||||
self.post_message(
|
||||
|
||||
@@ -1130,6 +1130,8 @@ class GlobalVar(object):
|
||||
STOP_EVENT: threading.Event = threading.Event()
|
||||
# webpush订阅
|
||||
SUBSCRIPTIONS: List[dict] = []
|
||||
# webpush订阅读写锁
|
||||
SUBSCRIPTIONS_LOCK: threading.Lock = threading.Lock()
|
||||
# 需应急停止的工作流
|
||||
EMERGENCY_STOP_WORKFLOWS: List[int] = []
|
||||
# 需应急停止文件整理
|
||||
@@ -1169,13 +1171,37 @@ class GlobalVar(object):
|
||||
"""
|
||||
获取webpush订阅
|
||||
"""
|
||||
return self.SUBSCRIPTIONS
|
||||
with self.SUBSCRIPTIONS_LOCK:
|
||||
return list(self.SUBSCRIPTIONS)
|
||||
|
||||
def push_subscription(self, subscription: dict):
|
||||
"""
|
||||
添加webpush订阅
|
||||
添加或更新webpush订阅。
|
||||
"""
|
||||
self.SUBSCRIPTIONS.append(subscription)
|
||||
endpoint = subscription.get("endpoint") if subscription else None
|
||||
if not endpoint:
|
||||
return
|
||||
with self.SUBSCRIPTIONS_LOCK:
|
||||
for index, current in enumerate(self.SUBSCRIPTIONS):
|
||||
if current.get("endpoint") == endpoint:
|
||||
self.SUBSCRIPTIONS[index] = subscription
|
||||
return
|
||||
self.SUBSCRIPTIONS.append(subscription)
|
||||
|
||||
def remove_subscription(self, subscription: dict) -> bool:
|
||||
"""
|
||||
根据 endpoint 移除webpush订阅,返回是否实际删除。
|
||||
"""
|
||||
endpoint = subscription.get("endpoint") if subscription else None
|
||||
if not endpoint:
|
||||
return False
|
||||
with self.SUBSCRIPTIONS_LOCK:
|
||||
before_count = len(self.SUBSCRIPTIONS)
|
||||
self.SUBSCRIPTIONS[:] = [
|
||||
current for current in self.SUBSCRIPTIONS
|
||||
if current.get("endpoint") != endpoint
|
||||
]
|
||||
return len(self.SUBSCRIPTIONS) != before_count
|
||||
|
||||
def stop_workflow(self, workflow_id: int):
|
||||
"""
|
||||
|
||||
12
app/helper/webpush.py
Normal file
12
app/helper/webpush.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from pywebpush import WebPushException
|
||||
|
||||
|
||||
def is_webpush_subscription_gone(error: WebPushException) -> bool:
|
||||
"""
|
||||
判断 WebPush 订阅是否已经在浏览器或推送服务侧失效。
|
||||
"""
|
||||
response: Any = getattr(error, "response", None)
|
||||
status_code = getattr(response, "status_code", None) or getattr(response, "status", None)
|
||||
return status_code in {404, 410}
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union
|
||||
|
||||
@@ -13,10 +14,24 @@ from app.schemas.types import StorageSchema
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
_folder_locks: dict[str, threading.Lock] = {}
|
||||
_MAX_FOLDER_LOCKS = 4096
|
||||
_folder_locks: OrderedDict[str, threading.Lock] = OrderedDict()
|
||||
_folder_locks_guard = threading.Lock()
|
||||
|
||||
|
||||
def _evict_unused_folder_locks_locked() -> None:
|
||||
"""
|
||||
在持有全局锁表互斥锁时淘汰旧路径锁,避免大量不同目录导致锁表无限增长。
|
||||
"""
|
||||
while len(_folder_locks) >= _MAX_FOLDER_LOCKS:
|
||||
for key, lock in list(_folder_locks.items()):
|
||||
if not lock.locked():
|
||||
_folder_locks.pop(key, None)
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
class Rclone(StorageBase):
|
||||
"""
|
||||
rclone相关操作
|
||||
@@ -144,9 +159,14 @@ class Rclone(StorageBase):
|
||||
"""
|
||||
normalized = Rclone.__normalize_remote_path(path)
|
||||
with _folder_locks_guard:
|
||||
if normalized not in _folder_locks:
|
||||
_folder_locks[normalized] = threading.Lock()
|
||||
return _folder_locks[normalized]
|
||||
lock = _folder_locks.get(normalized)
|
||||
if lock:
|
||||
_folder_locks.move_to_end(normalized)
|
||||
return lock
|
||||
_evict_unused_folder_locks_locked()
|
||||
lock = threading.Lock()
|
||||
_folder_locks[normalized] = lock
|
||||
return lock
|
||||
|
||||
def __wait_for_item(
|
||||
self, path: Path, retries: int = 3, delay: float = 0.2
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Union, Tuple
|
||||
from pywebpush import webpush, WebPushException
|
||||
|
||||
from app.core.config import global_vars, settings
|
||||
from app.helper.webpush import is_webpush_subscription_gone
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import Notification
|
||||
@@ -97,6 +98,8 @@ class WebPushModule(_ModuleBase, _MessageBase):
|
||||
)
|
||||
except WebPushException as err:
|
||||
logger.error(f"WebPush发送失败: {str(err)}")
|
||||
if is_webpush_subscription_gone(err) and global_vars.remove_subscription(sub):
|
||||
logger.info(f"已移除失效WebPush订阅: {sub.get('endpoint')}")
|
||||
|
||||
except Exception as msg_e:
|
||||
logger.error(f"发送消息失败:{msg_e}")
|
||||
|
||||
Reference in New Issue
Block a user