mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-28 03:02:34 +08:00
fix 协程环境中调用插件同步函数处理
This commit is contained in:
@@ -220,7 +220,7 @@ async def detail(mediaid: str, type_name: str, title: Optional[str] = None, year
|
||||
mediaid=mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
# 使用事件返回的上下文数据
|
||||
if event and event.event_data and event.event_data.media_dict:
|
||||
event_data: MediaRecognizeConvertEventData = event.event_data
|
||||
|
||||
@@ -145,7 +145,7 @@ async def update_subscribe_status(
|
||||
"state": state
|
||||
})
|
||||
# 发送订阅调整事件
|
||||
eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe.id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": subscribe.to_dict(),
|
||||
@@ -224,7 +224,7 @@ async def reset_subscribes(
|
||||
"state": "R"
|
||||
})
|
||||
# 发送订阅调整事件
|
||||
eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe.id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": subscribe.to_dict(),
|
||||
@@ -313,7 +313,7 @@ async def delete_subscribe_by_mediaid(
|
||||
for subscribe in delete_subscribes:
|
||||
await Subscribe.async_delete(db, subscribe.id)
|
||||
# 发送事件
|
||||
eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe.id,
|
||||
"subscribe_info": subscribe.to_dict()
|
||||
})
|
||||
@@ -596,7 +596,7 @@ async def delete_subscribe(
|
||||
if subscribe:
|
||||
await Subscribe.async_delete(db, subscribe_id)
|
||||
# 发送事件
|
||||
eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe.to_dict()
|
||||
})
|
||||
|
||||
@@ -7,6 +7,8 @@ from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, Tuple, List, Set, Union, Dict
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
import aiofiles
|
||||
from aiopath import AsyncPath
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
@@ -216,13 +218,15 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(*args, **kwargs)
|
||||
else:
|
||||
result = func(*args, **kwargs)
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
result = await run_in_threadpool(func, *args, **kwargs)
|
||||
elif isinstance(result, list):
|
||||
# 返回为列表,有多个模块运行结果时进行合并
|
||||
if inspect.iscoroutinefunction(func):
|
||||
temp = await func(*args, **kwargs)
|
||||
else:
|
||||
temp = func(*args, **kwargs)
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
temp = await run_in_threadpool(func, *args, **kwargs)
|
||||
if isinstance(temp, list):
|
||||
result.extend(temp)
|
||||
else:
|
||||
|
||||
@@ -533,6 +533,7 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 512, ttl: Opti
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
# 异步函数的缓存装饰器
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 获取缓存键
|
||||
@@ -554,13 +555,13 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 512, ttl: Opti
|
||||
"""
|
||||
清理缓存区
|
||||
"""
|
||||
# 清理缓存区
|
||||
cache_backend.clear(region=cache_region)
|
||||
|
||||
async_wrapper.cache_region = cache_region
|
||||
async_wrapper.cache_clear = cache_clear
|
||||
return async_wrapper
|
||||
else:
|
||||
# 同步函数的缓存装饰器
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# 获取缓存键
|
||||
@@ -582,7 +583,6 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 512, ttl: Opti
|
||||
"""
|
||||
清理缓存区
|
||||
"""
|
||||
# 清理缓存区
|
||||
cache_backend.clear(region=cache_region)
|
||||
|
||||
wrapper.cache_region = cache_region
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import importlib
|
||||
import inspect
|
||||
import random
|
||||
@@ -7,7 +6,9 @@ import time
|
||||
import traceback
|
||||
import uuid
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
@@ -440,12 +441,8 @@ class EventManager(metaclass=Singleton):
|
||||
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
|
||||
return
|
||||
|
||||
# 根据事件类型判断是否需要深复制
|
||||
is_broadcast_event = isinstance(event.event_type, EventType)
|
||||
event_to_process = copy.deepcopy(event) if is_broadcast_event else event
|
||||
|
||||
try:
|
||||
self.__invoke_handler_by_type_sync(handler, event_to_process, is_broadcast_event)
|
||||
self.__invoke_handler_by_type_sync(handler, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event, handler, e)
|
||||
|
||||
@@ -459,50 +456,67 @@ class EventManager(metaclass=Singleton):
|
||||
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
|
||||
return
|
||||
|
||||
# 链式事件不需要深复制
|
||||
event_to_process = event
|
||||
|
||||
try:
|
||||
await self.__invoke_handler_by_type_async(handler, event_to_process)
|
||||
await self.__invoke_handler_by_type_async(handler, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event, handler, e)
|
||||
|
||||
def __invoke_handler_by_type_sync(self, handler: Callable, event_to_process: Event, is_broadcast_event: bool):
|
||||
def __invoke_handler_by_type_sync(self, handler: Callable, event: Event):
|
||||
"""
|
||||
同步方式根据处理器类型调用相应的方法
|
||||
:param handler: 处理器
|
||||
:param event_to_process: 要处理的事件对象
|
||||
:param is_broadcast_event: 是否为广播事件
|
||||
:param event: 要处理的事件对象
|
||||
"""
|
||||
class_name, method_name = self.__parse_handler_names(handler)
|
||||
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
if class_name in PluginManager().get_plugin_ids():
|
||||
self.__invoke_plugin_method_sync(class_name, method_name, event_to_process, is_broadcast_event)
|
||||
elif class_name in ModuleManager().get_module_ids():
|
||||
self.__invoke_module_method_sync(class_name, method_name, event_to_process, is_broadcast_event)
|
||||
else:
|
||||
self.__invoke_global_method_sync(class_name, method_name, event_to_process, is_broadcast_event)
|
||||
plugin_manager = PluginManager()
|
||||
module_manager = ModuleManager()
|
||||
|
||||
async def __invoke_handler_by_type_async(self, handler: Callable, event_to_process: Event):
|
||||
if class_name in plugin_manager.get_plugin_ids():
|
||||
# 插件处理器
|
||||
plugin_manager.run_plugin_method(class_name, method_name, event)
|
||||
elif class_name in module_manager.get_module_ids():
|
||||
# 模块处理器
|
||||
module = module_manager.get_running_module(class_name)
|
||||
if not module:
|
||||
return
|
||||
method = getattr(module, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
method(event)
|
||||
else:
|
||||
# 全局处理器
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj or not hasattr(class_obj, method_name):
|
||||
return
|
||||
method = getattr(class_obj, method_name)
|
||||
if not method:
|
||||
return
|
||||
method(event)
|
||||
|
||||
async def __invoke_handler_by_type_async(self, handler: Callable, event: Event):
|
||||
"""
|
||||
异步方式根据处理器类型调用相应的方法
|
||||
:param handler: 处理器
|
||||
:param event_to_process: 要处理的事件对象
|
||||
:param event: 要处理的事件对象
|
||||
"""
|
||||
class_name, method_name = self.__parse_handler_names(handler)
|
||||
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
if class_name in PluginManager().get_plugin_ids():
|
||||
await self.__invoke_plugin_method_async(class_name, method_name, event_to_process)
|
||||
elif class_name in ModuleManager().get_module_ids():
|
||||
await self.__invoke_module_method_async(class_name, method_name, event_to_process)
|
||||
plugin_manager = PluginManager()
|
||||
module_manager = ModuleManager()
|
||||
|
||||
if class_name in plugin_manager.get_plugin_ids():
|
||||
await self.__invoke_plugin_method_async(plugin_manager, class_name, method_name, event)
|
||||
elif class_name in module_manager.get_module_ids():
|
||||
await self.__invoke_module_method_async(module_manager, class_name, method_name, event)
|
||||
else:
|
||||
await self.__invoke_global_method_async(class_name, method_name, event_to_process)
|
||||
await self.__invoke_global_method_async(class_name, method_name, event)
|
||||
|
||||
@staticmethod
|
||||
def __parse_handler_names(handler: Callable) -> Tuple[str, str]:
|
||||
@@ -514,65 +528,26 @@ class EventManager(metaclass=Singleton):
|
||||
names = handler.__qualname__.split(".")
|
||||
return names[0], names[1]
|
||||
|
||||
def __invoke_plugin_method_sync(self, class_name: str, method_name: str, event_to_process: Event,
|
||||
is_broadcast_event: bool):
|
||||
"""
|
||||
同步调用插件方法
|
||||
"""
|
||||
from app.core.plugin import PluginManager
|
||||
|
||||
def plugin_callable():
|
||||
PluginManager().run_plugin_method(class_name, method_name, event_to_process)
|
||||
|
||||
if is_broadcast_event:
|
||||
self.__executor.submit(plugin_callable)
|
||||
else:
|
||||
plugin_callable()
|
||||
|
||||
@staticmethod
|
||||
async def __invoke_plugin_method_async(class_name: str, method_name: str, event_to_process: Event):
|
||||
async def __invoke_plugin_method_async(handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用插件方法
|
||||
"""
|
||||
from app.core.plugin import PluginManager
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
plugin = plugin_manager.running_plugins.get(class_name)
|
||||
plugin = handler.running_plugins.get(class_name)
|
||||
if plugin and hasattr(plugin, method_name):
|
||||
method = getattr(plugin, method_name)
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event_to_process)
|
||||
await method(event)
|
||||
else:
|
||||
plugin_manager.run_plugin_method(class_name, method_name, event_to_process)
|
||||
|
||||
def __invoke_module_method_sync(self, class_name: str, method_name: str, event_to_process: Event,
|
||||
is_broadcast_event: bool):
|
||||
"""
|
||||
同步调用模块方法
|
||||
"""
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
module = ModuleManager().get_running_module(class_name)
|
||||
if not module:
|
||||
return
|
||||
|
||||
method = getattr(module, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
|
||||
if is_broadcast_event:
|
||||
self.__executor.submit(method, event_to_process)
|
||||
else:
|
||||
method(event_to_process)
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
await run_in_threadpool(method, event)
|
||||
|
||||
@staticmethod
|
||||
async def __invoke_module_method_async(class_name: str, method_name: str, event_to_process: Event):
|
||||
async def __invoke_module_method_async(handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用模块方法
|
||||
"""
|
||||
from app.core.module import ModuleManager
|
||||
|
||||
module = ModuleManager().get_running_module(class_name)
|
||||
module = handler.get_running_module(class_name)
|
||||
if not module:
|
||||
return
|
||||
|
||||
@@ -581,27 +556,11 @@ class EventManager(metaclass=Singleton):
|
||||
return
|
||||
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event_to_process)
|
||||
await method(event)
|
||||
else:
|
||||
method(event_to_process)
|
||||
method(event)
|
||||
|
||||
def __invoke_global_method_sync(self, class_name: str, method_name: str, event_to_process: Event,
|
||||
is_broadcast_event: bool):
|
||||
"""
|
||||
同步调用全局对象方法
|
||||
"""
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj or not hasattr(class_obj, method_name):
|
||||
return
|
||||
|
||||
method = getattr(class_obj, method_name)
|
||||
|
||||
if is_broadcast_event:
|
||||
self.__executor.submit(method, event_to_process)
|
||||
else:
|
||||
method(event_to_process)
|
||||
|
||||
async def __invoke_global_method_async(self, class_name: str, method_name: str, event_to_process: Event):
|
||||
async def __invoke_global_method_async(self, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用全局对象方法
|
||||
"""
|
||||
@@ -612,9 +571,9 @@ class EventManager(metaclass=Singleton):
|
||||
method = getattr(class_obj, method_name)
|
||||
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event_to_process)
|
||||
await method(event)
|
||||
else:
|
||||
method(event_to_process)
|
||||
method(event)
|
||||
|
||||
@staticmethod
|
||||
def __get_class_instance(class_name: str):
|
||||
|
||||
Reference in New Issue
Block a user