mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-25 09:34:19 +08:00
Compare commits
209 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c055e2482 | ||
|
|
28718094e4 | ||
|
|
9b23265c3b | ||
|
|
9f61bce039 | ||
|
|
1f49f9b454 | ||
|
|
51229204c9 | ||
|
|
2831eecbeb | ||
|
|
b2a18f9ae4 | ||
|
|
5a06e7b8bc | ||
|
|
f303d9e576 | ||
|
|
b76c4edc4a | ||
|
|
41da9b62c2 | ||
|
|
9128955bf9 | ||
|
|
f50773711e | ||
|
|
23784f614b | ||
|
|
7b27b7fd16 | ||
|
|
6834d8b2c7 | ||
|
|
4322f8a3c1 | ||
|
|
0f3a4e4c15 | ||
|
|
f4423e121e | ||
|
|
e5b67438d9 | ||
|
|
7b1ece8b83 | ||
|
|
6d5cda5d51 | ||
|
|
1af3a0ef59 | ||
|
|
5a585839ba | ||
|
|
fcf6e14ac9 | ||
|
|
0959c4ace4 | ||
|
|
f0bc1bd681 | ||
|
|
f8d096f476 | ||
|
|
b24127e66f | ||
|
|
35eb8c51a9 | ||
|
|
669ca713cf | ||
|
|
f2fd28bf4d | ||
|
|
3852c0e43e | ||
|
|
6fb6996d81 | ||
|
|
4c16704ca2 | ||
|
|
f017eaedcc | ||
|
|
19526146c5 | ||
|
|
7b4cb2097b | ||
|
|
b6062a9ce2 | ||
|
|
ea8a90aa0a | ||
|
|
fa939dfbe6 | ||
|
|
77aa65bfdc | ||
|
|
d86d24fc4f | ||
|
|
0989439d25 | ||
|
|
a46ce24691 | ||
|
|
57bb67e547 | ||
|
|
5e5c257b75 | ||
|
|
624862dfc6 | ||
|
|
b172a6d08f | ||
|
|
116465b6d8 | ||
|
|
cfb6448060 | ||
|
|
10a9e7293a | ||
|
|
fc2c77fbf1 | ||
|
|
e4721fef0c | ||
|
|
2c45831714 | ||
|
|
9068280f6d | ||
|
|
ea88f272a6 | ||
|
|
ac090af606 | ||
|
|
1c17c0b07e | ||
|
|
db6321d032 | ||
|
|
d6270dfb81 | ||
|
|
cc52bdaaf3 | ||
|
|
cbc8592b49 | ||
|
|
14d648445e | ||
|
|
87777343d2 | ||
|
|
26aa49f323 | ||
|
|
ad8b6473fc | ||
|
|
c32df7446d | ||
|
|
05b34b9c26 | ||
|
|
99fbeecfa1 | ||
|
|
41477601c7 | ||
|
|
a6ab9b76c1 | ||
|
|
a62b6b6fd5 | ||
|
|
75a52ad751 | ||
|
|
a2fa8d6f28 | ||
|
|
ed9116d81e | ||
|
|
6db1dd2067 | ||
|
|
0fb11880a4 | ||
|
|
b7fc5b0203 | ||
|
|
1b2433f7c2 | ||
|
|
c745616495 | ||
|
|
888ccfcfc2 | ||
|
|
3c9228c2f8 | ||
|
|
3776422634 | ||
|
|
5021b2c86f | ||
|
|
412e10972f | ||
|
|
d0b1b3d7f0 | ||
|
|
f5fea25b41 | ||
|
|
68706d3d5b | ||
|
|
b768ed8fed | ||
|
|
c4d3d28491 | ||
|
|
1862a7ab4b | ||
|
|
adb7aa6aa9 | ||
|
|
79eb128196 | ||
|
|
4d132c424a | ||
|
|
c52327c248 | ||
|
|
1d97f2e043 | ||
|
|
ee9ea54ab7 | ||
|
|
4027ae2641 | ||
|
|
bc6c61bc45 | ||
|
|
cd5e693302 | ||
|
|
ac11b303b3 | ||
|
|
a7823fb4d1 | ||
|
|
45d47d32f8 | ||
|
|
893b8eba86 | ||
|
|
f9b987c3ef | ||
|
|
4ef8b0ba99 | ||
|
|
268414fb11 | ||
|
|
bedab9ab92 | ||
|
|
94d7e4385e | ||
|
|
64b4de3900 | ||
|
|
a59afe4cc9 | ||
|
|
7b6047accf | ||
|
|
e217d1aa05 | ||
|
|
52e15b51db | ||
|
|
0dab3f087d | ||
|
|
e4c5a4f232 | ||
|
|
a729307d30 | ||
|
|
98347669ea | ||
|
|
9e4020c617 | ||
|
|
2f231fe632 | ||
|
|
14b366a648 | ||
|
|
0a0d5e6da2 | ||
|
|
3dbb68627f | ||
|
|
f157b61dfa | ||
|
|
44f975baf4 | ||
|
|
28ec4a6ac0 | ||
|
|
1140a85402 | ||
|
|
c6d95cd006 | ||
|
|
c9931aa948 | ||
|
|
ec4f13dd79 | ||
|
|
d43ef610c7 | ||
|
|
05d720d81f | ||
|
|
2d2c2a01eb | ||
|
|
226f9c9318 | ||
|
|
b77b5a21c5 | ||
|
|
82b637532e | ||
|
|
c2c9950bb1 | ||
|
|
ffbe348d66 | ||
|
|
6d7b0733af | ||
|
|
49a51cca25 | ||
|
|
06197144c0 | ||
|
|
62541ffe43 | ||
|
|
c762628217 | ||
|
|
caf615f3bd | ||
|
|
27436757a0 | ||
|
|
924d54dfd3 | ||
|
|
39f9550f86 | ||
|
|
367ecafbbb | ||
|
|
10467244e0 | ||
|
|
cb6dcc6a2e | ||
|
|
43c421b0bb | ||
|
|
45d0891502 | ||
|
|
76c5f54465 | ||
|
|
bcf8116172 | ||
|
|
1f889596b7 | ||
|
|
04443fcfba | ||
|
|
5d7a7fd301 | ||
|
|
4d0a722b09 | ||
|
|
db6dc926cf | ||
|
|
4bb4f5aeb5 | ||
|
|
58e25fe900 | ||
|
|
03f6b9bc96 | ||
|
|
6fdda3a570 | ||
|
|
100eaec38f | ||
|
|
b129508304 | ||
|
|
53bf81aede | ||
|
|
afcc071d07 | ||
|
|
2ea617655c | ||
|
|
0583495548 | ||
|
|
516aea6312 | ||
|
|
2d412cae1c | ||
|
|
45f5326fb4 | ||
|
|
2ccea2da39 | ||
|
|
53f6897d62 | ||
|
|
28a2386f2f | ||
|
|
abda9d3212 | ||
|
|
34e7c4ac14 | ||
|
|
b228107a25 | ||
|
|
2375508616 | ||
|
|
baebd0ed1a | ||
|
|
6532c60a3c | ||
|
|
11478faff3 | ||
|
|
e9291cec6a | ||
|
|
7586a2cd42 | ||
|
|
ef5bd29759 | ||
|
|
7ab643d34a | ||
|
|
0b7505a604 | ||
|
|
460d716512 | ||
|
|
b6f0ef99ab | ||
|
|
af35101774 | ||
|
|
9ed5018cc2 | ||
|
|
7299733960 | ||
|
|
bd5c3d848c | ||
|
|
38c48fa4ce | ||
|
|
b7749c44fd | ||
|
|
e4a7333b79 | ||
|
|
4b27b7bc42 | ||
|
|
c91e87115a | ||
|
|
4a3cc5ee18 | ||
|
|
54d6c2ad4a | ||
|
|
090dcacd30 | ||
|
|
344280cd61 | ||
|
|
2c7fb5786c | ||
|
|
6b9790026c | ||
|
|
6c70531967 | ||
|
|
bcc321eb70 | ||
|
|
2ff1cd1045 |
1
.cursorrules
Normal file
1
.cursorrules
Normal file
@@ -0,0 +1 @@
|
||||
AGENTS.md
|
||||
1
.github/copilot-instructions.md
vendored
Normal file
1
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
AGENTS.md
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -37,3 +37,5 @@ pylint-report.json
|
||||
|
||||
# AI
|
||||
.claude/
|
||||
!.claude/*.json
|
||||
.claude/settings.local.json
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs=0
|
||||
|
||||
# 禁用大部分警告、约定和重构建议,只保留错误和重要警告
|
||||
disable=all
|
||||
enable=error,
|
||||
enable=E,
|
||||
syntax-error,
|
||||
undefined-variable,
|
||||
used-before-assignment,
|
||||
|
||||
152
AGENTS.md
Normal file
152
AGENTS.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# MoviePilot AI Agent Guide
|
||||
|
||||
This file defines the default behavior for AI agents working in the MoviePilot repository. Unless a deeper directory provides another `AGENTS.md`, these rules apply to the entire repo.
|
||||
|
||||
## 1. Project Scope
|
||||
|
||||
- This repository contains the MoviePilot backend, CLI, MCP/API, Docker assets, and AI skills.
|
||||
- The backend is based on FastAPI, with most code under `app/`.
|
||||
- Frontend source code is not in this repository. The frontend source repository is `MoviePilot-Frontend`.
|
||||
- This repository also includes the local CLI, database migrations, developer docs, tests, Docker scripts, and AI skills.
|
||||
|
||||
## 2. Working Principles
|
||||
|
||||
- Read the relevant implementation, tests, and docs before changing code. Do not infer behavior from directory names alone.
|
||||
- Prefer the smallest correct change. Reuse existing functions, patterns, and naming whenever possible.
|
||||
- Do not perform unrelated large refactors, mass renames, or formatting-only cleanup.
|
||||
- Before adding a new abstraction, check whether it is actually reusable. If the logic fits well inside an existing function, class, or flow, keep it there.
|
||||
- The worktree may contain user changes. Do not revert, overwrite, or reorganize changes you do not fully understand.
|
||||
- Default to writing conclusions, validation results, and risk notes in Chinese unless the user asks otherwise.
|
||||
|
||||
## 3. Key Directories
|
||||
|
||||
- `app/api/endpoints/`: HTTP entrypoints. Handles auth, parameters, responses, and simple CRUD.
|
||||
- `app/chain/`: Business orchestration layer for search, recognition, subscriptions, downloads, messaging flows, and similar use cases.
|
||||
- `app/modules/`: Dynamically loaded system modules. Encapsulates pluggable downloaders, media servers, message channels, and other backend capabilities.
|
||||
- `app/helper/`: Reusable low-level helper logic. Not a place for full business orchestration.
|
||||
- `app/core/config.py`: Environment variables, deployment parameters, and startup-level settings.
|
||||
- `app/schemas/types.py`: Shared enums and types such as `SystemConfigKey` and module categories.
|
||||
- `app/db/`: Database models, sessions, and `*_oper.py` data access wrappers.
|
||||
- `moviepilot`: Local CLI entrypoint and help text.
|
||||
- `database/versions/`: Alembic migration scripts.
|
||||
- `docs/`: CLI, MCP/API, and development workflow documentation.
|
||||
- `skills/`: AI agent skills and related scripts.
|
||||
- `tests/`: Pytest tests and a few manual test scripts.
|
||||
- `config/`, `.moviepilot.env`, and `*.db`: Local config or runtime data. Do not modify or commit them unless the user explicitly asks for it.
|
||||
|
||||
## 4. Layering And Access Boundaries
|
||||
|
||||
### API / Endpoint Layer
|
||||
|
||||
- Endpoints should only handle HTTP concerns: auth, parameter parsing, response models, streaming adaptation, and simple input validation.
|
||||
- Simple list, detail, toggle, settings read/write, and pure CRUD endpoints may directly call `app/db/` or an existing `helper`.
|
||||
- If the logic coordinates multiple modules, triggers events, touches caches, or combines search, recognition, subscription, or download workflows, move it into `chain`.
|
||||
- Prefer adding new endpoints to an existing domain file. Create a new endpoint file only when introducing a new top-level resource domain.
|
||||
- After adding a new endpoint, register it in `app/api/apiv1.py`.
|
||||
|
||||
### Chain Layer
|
||||
|
||||
- `chain` is the business orchestration layer shared by API, CLI, message interaction, agents, schedulers, and similar entrypoints.
|
||||
- `chain` is responsible for composing `module`, `helper`, `db`, events, caches, and other stable `chain` capabilities.
|
||||
- Inside `chain`, prefer calling module capabilities through `run_module()` or `async_run_module()`. Only use `ModuleManager` or similar helpers directly when you truly need to enumerate modules, inspect instances, or run health checks.
|
||||
- `chain` should focus on use cases and workflows. It should not hold low-level protocol details, HTTP request objects, or page-specific parameter assembly.
|
||||
- Before adding a new `chain`, ask whether this is a reusable business use case shared by multiple entrypoints, or a flow that coordinates multiple modules or resources. If it is just short logic for one endpoint, do not create a new `chain`.
|
||||
- `chain` may call other `chain` classes when reusing stable domain logic, but avoid introducing new circular dependencies.
|
||||
|
||||
### Module Layer
|
||||
|
||||
- `module` is the pluggable capability layer discovered and loaded by `ModuleManager`.
|
||||
- Put logic in `module` when it represents a new downloader, media server, message channel, recognition backend, filtering backend, file-management backend, or any other capability that needs lifecycle management, priority, configuration switches, or independent testing.
|
||||
- New modules should follow the existing base-class contract and implement or align with `init_module()`, `init_setting()`, `get_name()`, `get_type()`, `get_subtype()`, `get_priority()`, `test()`, and `stop()`.
|
||||
- A `module` should focus on one backend or one capability implementation. It should return domain results, not HTTP responses, and should not depend on endpoint auth or FastAPI request objects.
|
||||
- `chain -> module` is the intended main direction. The repository contains a small number of historical `module -> chain` usages. Do not expand that pattern in new code. If a module needs shared business logic, prefer moving that logic up into `chain` or down into `helper`.
|
||||
- Do not add direct `module -> module` coupling for new code. Cross-module orchestration should be handled by `chain`.
|
||||
|
||||
### Helper Layer
|
||||
|
||||
- `helper` is for reusable low-level support logic such as path handling, config aggregation, site index loading, protocol wrappers, rate limiting, cache helpers, and page parsing.
|
||||
- Add a new `helper` only when the logic is reused in multiple places, or when it is clearly a standalone low-level concern.
|
||||
- If logic is used only by a single `chain` or a single `module`, prefer keeping it in the original file instead of turning `helper` into a dumping ground.
|
||||
- If the code needs configuration switches, runtime loading, priorities, independent test entrypoints, or multi-implementation dispatch, it is probably a `module`, not a `helper`.
|
||||
- `helper` must not become another orchestration layer. Full business workflows still belong in `chain`.
|
||||
|
||||
### Preferred Call Directions
|
||||
|
||||
- Preferred direction: `endpoint/CLI/agent/command -> chain -> module/helper/db`
|
||||
- Allowed direction: `chain -> chain`, as long as the reused logic is stable and does not introduce cycles.
|
||||
- Cautious direction: `endpoint -> db/model/oper/helper`, only for simple queries, simple CRUD, or input normalization.
|
||||
- Avoid for new code: `module -> chain`, `module -> module`, `helper -> chain`, `helper -> endpoint`.
|
||||
|
||||
## 5. Where New Capabilities Should Go
|
||||
|
||||
- Scenario: adding a new business workflow such as search, recognition, subscription, download orchestration, or message interaction.
|
||||
Action: prefer `app/chain/` so API, CLI, agents, and schedulers can share the same orchestration logic.
|
||||
- Scenario: adding a new downloader, media server, message channel, or other pluggable backend integration.
|
||||
Action: put it in `app/modules/`. If this introduces a new module category or subtype, also check `app/schemas/types.py` and related schemas.
|
||||
- Scenario: adding a new public HTTP API.
|
||||
Action: put it in `app/api/endpoints/`, register it in `app/api/apiv1.py`, and add auth, schemas, docs, and tests. Move complex logic into `chain`.
|
||||
- Scenario: adding a new low-level utility, parser, config reader, or protocol wrapper.
|
||||
Action: put it in `app/helper/`, but only if it is not a one-off implementation and not a full business use case.
|
||||
- Scenario: adding a deployment-level, environment-level, or startup-time config such as ports, paths, proxies, switches, keys, or third-party service addresses.
|
||||
Action: put it in `ConfigModel` or `Settings` inside `app/core/config.py`.
|
||||
- Scenario: adding a runtime business config, user-editable rule, or persistent system option.
|
||||
Action: prefer `SystemConfigKey` plus `SystemConfigOper`. Do not scatter raw string keys.
|
||||
- Scenario: a config change should automatically reload a long-lived object.
|
||||
Action: add `CONFIG_WATCH`, `on_config_changed()`, and `get_reload_name()` where appropriate on the related `chain`, `module`, `helper`, or manager class.
|
||||
- Scenario: adding a few dozen lines of private logic inside one `chain` or `module`.
|
||||
Action: prefer a private function or private method in the same file. Do not create a new `helper` by default.
|
||||
|
||||
## 6. Code And Comment Requirements
|
||||
|
||||
- Preserve the existing code style. Do not introduce a new abstraction layer without a clear payoff.
|
||||
- The repository already uses short docstrings for many public classes and methods. For new public classes and methods, follow the local style of the surrounding file.
|
||||
- Comments and docstrings should default to Chinese. If the surrounding file is already consistently in English, match the local style.
|
||||
- Comments should explain why the code is written that way and what non-obvious constraints exist, such as edge cases, compatibility reasons, call ordering, cache or reload semantics, and external system limitations.
|
||||
- Do not write line-by-line translation comments. Do not comment obvious assignments, branches, or straightforward calls.
|
||||
- For complex notes, place the comment above the code block instead of using long end-of-line comments.
|
||||
- When changing code, update or remove stale comments so the documentation stays aligned with the implementation.
|
||||
- Do not add TODO or FIXME without context. Only keep one if it is genuinely useful and cannot be addressed as part of the current task.
|
||||
- Do not add noisy comments like "change starts here", "change ends here", or "this is important".
|
||||
|
||||
## 7. Dependency And Environment Conventions
|
||||
|
||||
- Target Python version is `3.11+`. Current CI uses Python `3.12`.
|
||||
- The dependency source file is `requirements.in`.
|
||||
- `requirements.txt` is the lock file generated by `pip-compile requirements.in`. Do not maintain it manually.
|
||||
- Install dependencies with `pip install -r requirements.txt`.
|
||||
- When adding or upgrading dependencies:
|
||||
1. Update `requirements.in`
|
||||
2. Run `pip-compile requirements.in`
|
||||
3. Run the relevant tests and security checks
|
||||
|
||||
## 8. Coupled Updates
|
||||
|
||||
- When fixing a bug, prefer adding a test that reproduces it. When adding a feature, prefer the smallest useful test coverage.
|
||||
- When changing CLI behavior, also check and update `moviepilot`, `docs/cli.md`, and related tests.
|
||||
- When changing MCP or REST API behavior, exposed tools, or AI interaction behavior, also check and update `docs/mcp-api.md`, related `skills/*/SKILL.md` files or scripts, and related tests.
|
||||
- When changing development workflow, dependency management, or security-check procedures, also update `docs/development-setup.md`.
|
||||
- When changing database structure, add an Alembic migration under `database/versions/`. Do not update models without a migration.
|
||||
- When changing user-visible config, defaults, or initialization flow, also check related docs, help text, setup or init flows, and tests.
|
||||
- When adding a new skill, follow the existing `skills/<name>/SKILL.md` structure, keep the YAML front matter, and prefer script paths relative to the `SKILL.md` file.
|
||||
|
||||
## 9. Validation Requirements
|
||||
|
||||
- Run at least the tests directly related to the change, for example `pytest tests/test_xxx.py`.
|
||||
- If the change affects common modules, startup flow, CLI, or agent runtime behavior, expand the validation scope.
|
||||
- After Python code changes, at minimum ensure the change does not introduce new error-level issues in `pylint app/`.
|
||||
- When changing CLI behavior, validate the relevant help output such as `moviepilot help` or the specific subcommand help.
|
||||
- When changing dependencies, also run `pip-compile requirements.in` and `safety check -r requirements.txt --policy-file=safety.policy.yml`.
|
||||
- If the task only changes documentation, explicitly say that tests were not run. Do not claim checks that were not executed.
|
||||
|
||||
## 10. Commit And Release Conventions
|
||||
|
||||
- Only create a commit when the user explicitly asks for one.
|
||||
- Prefer Conventional Commits such as `feat: ...`, `fix: ...`, and `docs: ...`.
|
||||
- This is not just stylistic. The release workflow uses Conventional Commits to categorize changelog entries.
|
||||
- Do not casually change version numbers, release settings, or Docker release flow unless the task explicitly involves them.
|
||||
|
||||
## 11. Output Requirements
|
||||
|
||||
- Result summaries should focus on three things: what changed, how it was validated, and what risks remain.
|
||||
- Do not write vague summaries. Do not describe unexecuted checks as completed.
|
||||
- If there is compatibility impact, config migration risk, or user-data risk, call it out explicitly.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ from app.schemas.message import (
|
||||
ChannelCapabilityManager,
|
||||
ChannelCapability,
|
||||
)
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
|
||||
|
||||
class _StreamChain(ChainBase):
|
||||
@@ -61,10 +61,22 @@ class StreamingHandler:
|
||||
self._source: Optional[str] = None
|
||||
self._user_id: Optional[str] = None
|
||||
self._username: Optional[str] = None
|
||||
self._original_message_id: Optional[str] = None
|
||||
self._original_chat_id: Optional[str] = None
|
||||
self._title: str = ""
|
||||
self._allow_dispatch_without_context = False
|
||||
# 非啰嗦模式下的待输出工具统计,等下一段文本到来时再统一补一句摘要
|
||||
self._pending_tool_stats: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def set_dispatch_policy(
|
||||
self, allow_dispatch_without_context: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
设置在缺少渠道上下文时是否仍允许向默认通知渠道分发消息。
|
||||
后台 DISPATCH 任务允许,CAPTURE_ONLY 必须禁止。
|
||||
"""
|
||||
self._allow_dispatch_without_context = allow_dispatch_without_context
|
||||
|
||||
def emit(self, token: str) -> str:
|
||||
"""
|
||||
接收 LLM 流式 token,积累到缓冲区。
|
||||
@@ -137,6 +149,8 @@ class StreamingHandler:
|
||||
source: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
title: str = "",
|
||||
):
|
||||
"""
|
||||
@@ -148,11 +162,15 @@ class StreamingHandler:
|
||||
:param user_id: 用户ID
|
||||
:param username: 用户名
|
||||
:param title: 消息标题
|
||||
:param original_message_id: 原始消息ID(如果是回复消息)
|
||||
:param original_chat_id: 原始聊天ID(如果是回复消息)
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._user_id = user_id
|
||||
self._username = username
|
||||
self._original_message_id = original_message_id
|
||||
self._original_chat_id = original_chat_id
|
||||
self._title = title
|
||||
|
||||
self._streaming_enabled = True
|
||||
@@ -201,6 +219,13 @@ class StreamingHandler:
|
||||
# 执行最后一次刷新
|
||||
await self._flush()
|
||||
|
||||
message_response = self._message_response
|
||||
if message_response:
|
||||
await run_in_threadpool(
|
||||
_StreamChain().finalize_message,
|
||||
message_response,
|
||||
)
|
||||
|
||||
# 检查是否所有缓冲内容都已发送
|
||||
with self._lock:
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
@@ -354,7 +379,7 @@ class StreamingHandler:
|
||||
last_char = visible_buffer[-1:] if visible_buffer.strip() else ""
|
||||
prefix = ""
|
||||
if self._buffer and last_char != "\n":
|
||||
prefix = "\n"
|
||||
prefix = "\n\n"
|
||||
return f"{prefix}{summary}\n\n"
|
||||
|
||||
@staticmethod
|
||||
@@ -412,15 +437,23 @@ class StreamingHandler:
|
||||
|
||||
async def _cancel_flush_task(self):
|
||||
"""
|
||||
取消当前的定时刷新任务
|
||||
停止当前的定时刷新任务。
|
||||
|
||||
停止流式输出时,刷新任务可能已经在线程池里发出了首条消息。
|
||||
这里先等待该轮刷新自然完成,确保 message_id 等返回信息能落回本地状态;
|
||||
否则最终刷新会误以为尚未发送过消息,从而再次发送一条新消息。
|
||||
"""
|
||||
if self._flush_task and not self._flush_task.done():
|
||||
self._flush_task.cancel()
|
||||
current_task = asyncio.current_task()
|
||||
if (
|
||||
self._flush_task
|
||||
and not self._flush_task.done()
|
||||
and self._flush_task is not current_task
|
||||
):
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._flush_task = None
|
||||
self._flush_task = None
|
||||
|
||||
async def _flush(self):
|
||||
"""
|
||||
@@ -435,6 +468,12 @@ class StreamingHandler:
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
if (
|
||||
(not self._channel or not self._source)
|
||||
and not self._allow_dispatch_without_context
|
||||
):
|
||||
logger.debug("流式输出缺少渠道上下文,当前模式禁止外发消息")
|
||||
return
|
||||
|
||||
chain = _StreamChain()
|
||||
|
||||
@@ -446,8 +485,11 @@ class StreamingHandler:
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
original_message_id=self._original_message_id,
|
||||
original_chat_id=self._original_chat_id,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
),
|
||||
@@ -488,8 +530,11 @@ class StreamingHandler:
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
original_message_id=self._original_message_id,
|
||||
original_chat_id=self._original_chat_id,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
),
|
||||
@@ -519,6 +564,7 @@ class StreamingHandler:
|
||||
chat_id=self._message_response.chat_id,
|
||||
text=current_text,
|
||||
title=self._title,
|
||||
metadata=self._message_response.metadata,
|
||||
)
|
||||
if success:
|
||||
with self._lock:
|
||||
|
||||
19
app/agent/defaults/CURRENT_PERSONA.md
Normal file
19
app/agent/defaults/CURRENT_PERSONA.md
Normal file
@@ -0,0 +1,19 @@
|
||||
---
|
||||
version: 3
|
||||
active_persona: default
|
||||
extra_context_files: []
|
||||
deprecated_phrases: []
|
||||
---
|
||||
# CURRENT_PERSONA
|
||||
|
||||
当前激活人格:`default`
|
||||
|
||||
运行时加载顺序固定如下:
|
||||
|
||||
1. 核心系统提示词(程序内置,不可运行时覆盖)
|
||||
2. `personas/<active_persona>/PERSONA.md`
|
||||
3. `extra_context_files`
|
||||
4. `memory/*.md`
|
||||
5. `activity/*.md`
|
||||
|
||||
`memory` 中的长期偏好可以细化回复方式,但不应覆盖系统核心身份、目标和安全边界。
|
||||
22
app/agent/defaults/personas/aloof/PERSONA.md
Normal file
22
app/agent/defaults/personas/aloof/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: aloof
|
||||
label: 高冷
|
||||
description: 冷静、克制、低温度,话少但不失礼。
|
||||
aliases:
|
||||
- 冷淡
|
||||
- 冷感
|
||||
- 冷艳
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: cool, distant, and composed.
|
||||
- Keep emotional temperature low and transitions short.
|
||||
- Be brief and efficient, but do not become rude or contemptuous.
|
||||
- Prefer understatement over enthusiasm.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Lead with the answer or the action result.
|
||||
- Keep explanations minimal unless the user explicitly asks for detail.
|
||||
- Avoid extra reassurance, hype, or emotional softening.
|
||||
22
app/agent/defaults/personas/anime/PERSONA.md
Normal file
22
app/agent/defaults/personas/anime/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: anime
|
||||
label: 二次元
|
||||
description: 带一点 ACG 语感和戏剧化表达,但仍然以任务完成和清晰沟通为主。
|
||||
aliases:
|
||||
- 动漫风
|
||||
- ACG
|
||||
- 宅系
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: lively, stylized, and lightly dramatic, with a small amount of anime-flavored wording.
|
||||
- Keep the actual task handling grounded and practical; the style should stay mostly in phrasing.
|
||||
- You may occasionally use short ACG-like interjections, but do not flood the reply with memes, kaomoji, or niche jargon.
|
||||
- Stay readable first. If the task is serious, reduce the stylistic flavor automatically.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Prefer short paragraphs or compact lists.
|
||||
- A light playful closing line is acceptable after the real result is already clear.
|
||||
- Do not let the style make operational instructions vague.
|
||||
22
app/agent/defaults/personas/catgirl/PERSONA.md
Normal file
22
app/agent/defaults/personas/catgirl/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: catgirl
|
||||
label: 猫娘
|
||||
description: 带一点猫系拟人风格,轻松可爱,但不过度角色扮演。
|
||||
aliases:
|
||||
- 猫猫
|
||||
- 喵系
|
||||
- 猫耳
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: playful, cat-like, and cute, with occasional feline wording.
|
||||
- You may occasionally use a light "喵" style suffix or cat metaphor, but only sparingly.
|
||||
- Do not turn the reply into full roleplay; task clarity remains the primary goal.
|
||||
- If the content is operational, keep the answer direct first and add only a thin layer of style.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Keep answers compact and readable.
|
||||
- Use only a very small amount of repeated verbal tic.
|
||||
- The result or action status should always appear before any playful flourish.
|
||||
23
app/agent/defaults/personas/concise/PERSONA.md
Normal file
23
app/agent/defaults/personas/concise/PERSONA.md
Normal file
@@ -0,0 +1,23 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: concise
|
||||
label: 极简
|
||||
description: 更短、更硬朗,优先结论和动作,不主动展开背景解释。
|
||||
aliases:
|
||||
- 简洁
|
||||
- 干脆
|
||||
- 极简人格
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: terse, decisive, and highly compressed.
|
||||
- Prefer the shortest complete answer that still moves the task forward.
|
||||
- Default to one sentence when possible. Only use lists when they materially improve readability.
|
||||
- Avoid extra context, caveats, or teaching unless the user explicitly asks for explanation.
|
||||
- Keep transitions minimal and skip conversational softening.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Lead with the conclusion or result.
|
||||
- For option lists, keep each item very short.
|
||||
- Do not repeat already-known context back to the user unless it is needed to disambiguate the action.
|
||||
22
app/agent/defaults/personas/cute/PERSONA.md
Normal file
22
app/agent/defaults/personas/cute/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: cute
|
||||
label: 可爱
|
||||
description: 语气更亲和、更柔软、更讨喜,但不做重度角色扮演。
|
||||
aliases:
|
||||
- 软萌
|
||||
- 甜系
|
||||
- 亲和
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: warm, cheerful, and gently cute.
|
||||
- Sound approachable and pleasant, but keep the answer concise and useful.
|
||||
- Avoid baby talk, excessive repetition, or exaggerated emotive punctuation.
|
||||
- If the user asks for directness, keep the cute flavor minimal.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Prefer friendly short paragraphs.
|
||||
- For lists, keep each item short and easy to read.
|
||||
- When something fails, explain it gently but clearly.
|
||||
24
app/agent/defaults/personas/default/PERSONA.md
Normal file
24
app/agent/defaults/personas/default/PERSONA.md
Normal file
@@ -0,0 +1,24 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: default
|
||||
label: 默认
|
||||
description: 专业、克制、简洁,适合大多数日常媒体管理场景。
|
||||
aliases:
|
||||
- 专业
|
||||
- 默认人格
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: professional, concise, restrained.
|
||||
- Be direct. No unnecessary preamble, no repeating the user's words, no narrating internal reasoning.
|
||||
- Do not flatter the user, praise the question, or add emotional cushioning.
|
||||
- Do not use emojis, exclamation marks, cute language, or excessive apology.
|
||||
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles and paths.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Keep confirmations short.
|
||||
- For search or comparison results, prefer a brief list over a long paragraph.
|
||||
- Skip filler phrases like "Let me help you", "Here are the results", or "I found...".
|
||||
- When an error occurs, briefly state the blocker and the next best action.
|
||||
22
app/agent/defaults/personas/disdain/PERSONA.md
Normal file
22
app/agent/defaults/personas/disdain/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: disdain
|
||||
label: 不屑
|
||||
description: 带一点嫌弃感和轻微毒舌,但必须保持可控和不越界。
|
||||
aliases:
|
||||
- 嫌弃
|
||||
- 毒舌
|
||||
- 鄙视链
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: dry, skeptical, and faintly dismissive.
|
||||
- Mild sarcasm is acceptable, but it must stay controlled and should never turn into direct insult or humiliation.
|
||||
- Prioritize sharp phrasing and low patience, while still giving the user the actual answer.
|
||||
- If the task is sensitive or the user is clearly frustrated, reduce the bite automatically.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Keep answers crisp and pointed.
|
||||
- Use short, cutting observations only when they improve the style without harming clarity.
|
||||
- Always include the concrete result, instruction, or blocker.
|
||||
22
app/agent/defaults/personas/guide/PERSONA.md
Normal file
22
app/agent/defaults/personas/guide/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: guide
|
||||
label: 说明型
|
||||
description: 在复杂问题上更愿意解释原因和步骤,但仍保持克制,不会无节制展开。
|
||||
aliases:
|
||||
- 讲解
|
||||
- 解释型
|
||||
- 教学
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: clear, structured, and mildly explanatory.
|
||||
- When the task is simple, stay concise. When the task is complex or the user asks why/how, provide a short explanation with visible structure.
|
||||
- Keep explanations practical and tied to the current decision, not generic theory.
|
||||
- Remain restrained: do not become chatty, cute, or overly warm.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- For non-trivial tasks, prefer short sections or a compact numbered list.
|
||||
- When describing tradeoffs, keep them concrete and action-oriented.
|
||||
- End with the actual outcome or next step, not a generic summary.
|
||||
23
app/agent/defaults/personas/moe/PERSONA.md
Normal file
23
app/agent/defaults/personas/moe/PERSONA.md
Normal file
@@ -0,0 +1,23 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: moe
|
||||
label: 萌系
|
||||
description: 更轻小说感、更元气、更可爱,但仍然保持边界和专业度。
|
||||
aliases:
|
||||
- 萝莉风
|
||||
- 轻小说风
|
||||
- 元气少女
|
||||
- 萌萌
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: soft, upbeat, cute, and lightly playful.
|
||||
- Keep the personality in wording only; do not imitate a child, emphasize age, or use any sexualized framing.
|
||||
- Use cute particles or soft wording sparingly so the answer still feels useful instead of noisy.
|
||||
- When the task is urgent or technical, reduce the fluff and keep the result clear.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Prefer short, bright sentences.
|
||||
- A small amount of cute phrasing is acceptable, but the final answer must still be easy to scan.
|
||||
- Do not bury the actual conclusion under roleplay language.
|
||||
33
app/agent/llm/__init__.py
Normal file
33
app/agent/llm/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Agent 内部使用的 LLM 适配层。"""
|
||||
|
||||
from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.agent.llm.capability import (
|
||||
AgentCapabilityManager,
|
||||
AgentCapabilityProvider,
|
||||
AudioCapabilityProvider,
|
||||
MiMoAudioProvider,
|
||||
OpenAIChatAudioProvider,
|
||||
OpenAIAudioProvider,
|
||||
)
|
||||
from app.agent.llm.provider import (
|
||||
LLMProviderAuthError,
|
||||
LLMProviderError,
|
||||
LLMProviderManager,
|
||||
render_auth_result_html,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LLMHelper",
|
||||
"AgentCapabilityManager",
|
||||
"AgentCapabilityProvider",
|
||||
"AudioCapabilityProvider",
|
||||
"LLMProviderAuthError",
|
||||
"LLMProviderError",
|
||||
"LLMProviderManager",
|
||||
"LLMTestError",
|
||||
"LLMTestTimeout",
|
||||
"MiMoAudioProvider",
|
||||
"OpenAIChatAudioProvider",
|
||||
"OpenAIAudioProvider",
|
||||
"render_auth_result_html",
|
||||
]
|
||||
528
app/agent/llm/capability.py
Normal file
528
app/agent/llm/capability.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Agent 多模态能力 provider 与调度入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import shutil
|
||||
import subprocess
|
||||
from abc import ABC
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentCapabilityProvider(ABC):
|
||||
"""Agent 能力 provider 基类,后续图片等能力可继续扩展到这里。"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class AudioCapabilityProvider(AgentCapabilityProvider):
|
||||
"""音频输入/输出能力 provider。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
"""是否可用于音频输入转写。"""
|
||||
return False
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
"""是否可用于语音合成输出。"""
|
||||
return False
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将音频字节转成文字。"""
|
||||
raise NotImplementedError
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""将文字合成为可发送的音频文件。"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAIAudioProvider(AudioCapabilityProvider):
|
||||
"""OpenAI / OpenAI-compatible 音频 provider。"""
|
||||
|
||||
name = "openai"
|
||||
|
||||
@staticmethod
|
||||
def _build_client(api_key: str, base_url: Optional[str]):
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
@staticmethod
|
||||
def _input_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL
|
||||
|
||||
@staticmethod
|
||||
def _output_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
api_key, _ = self._input_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
api_key, _ = self._output_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
api_key, base_url = self._input_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输入 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
audio_file = BytesIO(content)
|
||||
audio_file.name = filename
|
||||
response = client.audio.transcriptions.create(
|
||||
model=settings.AUDIO_INPUT_MODEL,
|
||||
file=audio_file,
|
||||
language=settings.AUDIO_INPUT_LANGUAGE or "zh",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
text = getattr(response, "text", None)
|
||||
return text.strip() if text else None
|
||||
except Exception as err:
|
||||
logger.error(f"音频输入转写失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
response = client.audio.speech.create(
|
||||
model=settings.AUDIO_OUTPUT_MODEL,
|
||||
voice=settings.AUDIO_OUTPUT_VOICE,
|
||||
input=text,
|
||||
response_format="opus",
|
||||
)
|
||||
response.write_to_file(output_path)
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIChatAudioProvider(AudioCapabilityProvider):
|
||||
"""通过 OpenAI Chat Completions 兼容接口传入/返回音频的 provider。"""
|
||||
|
||||
name = "openai_chat_audio"
|
||||
DISPLAY_NAME = "OpenAI Chat Audio"
|
||||
DEFAULT_BASE_URL: Optional[str] = None
|
||||
DEFAULT_STT_MODEL: Optional[str] = None
|
||||
DEFAULT_TTS_MODEL: Optional[str] = None
|
||||
DEFAULT_VOICE = "alloy"
|
||||
AUDIO_RESPONSE_FORMAT = "wav"
|
||||
AUDIO_INPUT_DATA_URL = False
|
||||
INCLUDE_AUDIO_MODALITIES = True
|
||||
TTS_MESSAGE_ROLE = "user"
|
||||
SUPPORTED_STT_MODELS: Optional[frozenset[str]] = None
|
||||
SUPPORTED_TTS_MODELS: Optional[frozenset[str]] = None
|
||||
UNSUPPORTED_TTS_MODELS = frozenset()
|
||||
SUPPORTED_AUDIO_MIME_TYPES = {
|
||||
".flac": "audio/flac",
|
||||
".m4a": "audio/mp4",
|
||||
".mp3": "audio/mpeg",
|
||||
".ogg": "audio/ogg",
|
||||
".opus": "audio/ogg",
|
||||
".wav": "audio/wav",
|
||||
}
|
||||
|
||||
def _build_client(self, api_key: str, base_url: Optional[str]):
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or self.DEFAULT_BASE_URL,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _input_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_INPUT_API_KEY, settings.AUDIO_INPUT_BASE_URL
|
||||
|
||||
@staticmethod
|
||||
def _output_credentials() -> tuple[Optional[str], Optional[str]]:
|
||||
return settings.AUDIO_OUTPUT_API_KEY, settings.AUDIO_OUTPUT_BASE_URL
|
||||
|
||||
def _normalize_stt_model(self) -> str:
|
||||
return self._normalize_model(
|
||||
model=settings.AUDIO_INPUT_MODEL,
|
||||
supported_models=self.SUPPORTED_STT_MODELS,
|
||||
default_model=self.DEFAULT_STT_MODEL,
|
||||
)
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
return self._normalize_model(
|
||||
model=settings.AUDIO_OUTPUT_MODEL,
|
||||
supported_models=self.SUPPORTED_TTS_MODELS,
|
||||
default_model=self.DEFAULT_TTS_MODEL,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model(
|
||||
model: Optional[str],
|
||||
supported_models: Optional[frozenset[str]],
|
||||
default_model: Optional[str],
|
||||
) -> str:
|
||||
model = (model or "").strip()
|
||||
if not model:
|
||||
return default_model or ""
|
||||
if supported_models is None:
|
||||
return model
|
||||
model_key = model.lower()
|
||||
if model_key in supported_models:
|
||||
return model_key
|
||||
return default_model or model
|
||||
|
||||
def _is_supported_tts_model(self) -> bool:
|
||||
model = self._normalize_tts_model()
|
||||
if not model:
|
||||
return False
|
||||
model_key = model.lower()
|
||||
if model_key in self.UNSUPPORTED_TTS_MODELS:
|
||||
return False
|
||||
return self.SUPPORTED_TTS_MODELS is None or model_key in self.SUPPORTED_TTS_MODELS
|
||||
|
||||
@classmethod
|
||||
def _guess_audio_mime_type(cls, filename: str) -> str:
|
||||
suffix = Path(filename or "").suffix.lower()
|
||||
if suffix in cls.SUPPORTED_AUDIO_MIME_TYPES:
|
||||
return cls.SUPPORTED_AUDIO_MIME_TYPES[suffix]
|
||||
mime_type, _ = mimetypes.guess_type(filename or "")
|
||||
return mime_type or "audio/ogg"
|
||||
|
||||
@staticmethod
|
||||
def _guess_audio_format(filename: str) -> str:
|
||||
suffix = Path(filename or "").suffix.lower().lstrip(".")
|
||||
if suffix == "opus":
|
||||
return "ogg"
|
||||
return suffix or "ogg"
|
||||
|
||||
def _build_audio_input_payload(self, content: bytes, filename: str) -> dict:
|
||||
"""按不同 Chat Audio 兼容形态构造 input_audio 内容。"""
|
||||
audio_data = base64.b64encode(content).decode("utf-8")
|
||||
if self.AUDIO_INPUT_DATA_URL:
|
||||
mime_type = self._guess_audio_mime_type(filename)
|
||||
return {"data": f"data:{mime_type};base64,{audio_data}"}
|
||||
return {
|
||||
"data": audio_data,
|
||||
"format": self._guess_audio_format(filename),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_text(message) -> Optional[str]:
|
||||
"""兼容音频理解响应可能放在 content 或 reasoning_content 的情况。"""
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str) and content.strip():
|
||||
return content.strip()
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None)
|
||||
if isinstance(reasoning_content, str) and reasoning_content.strip():
|
||||
return reasoning_content.strip()
|
||||
|
||||
extra = getattr(message, "model_extra", None)
|
||||
if isinstance(extra, dict):
|
||||
for key in ("content", "reasoning_content"):
|
||||
value = extra.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_audio_data(message) -> Optional[str]:
|
||||
audio = getattr(message, "audio", None)
|
||||
if isinstance(audio, dict):
|
||||
return audio.get("data")
|
||||
if audio is not None:
|
||||
return getattr(audio, "data", None)
|
||||
|
||||
extra = getattr(message, "model_extra", None)
|
||||
if isinstance(extra, dict) and isinstance(extra.get("audio"), dict):
|
||||
return extra["audio"].get("data")
|
||||
return None
|
||||
|
||||
def _convert_wav_to_opus(self, wav_path: Path) -> Optional[Path]:
|
||||
"""将 Chat Audio 返回的 WAV 转成 OGG/Opus,便于各通知渠道发送语音。"""
|
||||
if not shutil.which("ffmpeg"):
|
||||
return None
|
||||
|
||||
output_path = wav_path.with_suffix(".opus")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(wav_path),
|
||||
"-ar",
|
||||
"48000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-c:a",
|
||||
"libopus",
|
||||
str(output_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0 or not output_path.exists():
|
||||
logger.warning(
|
||||
"%s TTS 音频转 Opus 失败,将使用 WAV 原文件: returncode=%s, stderr=%s",
|
||||
self.DISPLAY_NAME,
|
||||
result.returncode,
|
||||
(result.stderr or "").strip()[:500],
|
||||
)
|
||||
return None
|
||||
return output_path
|
||||
|
||||
def is_available_for_audio_input(self) -> bool:
|
||||
api_key, _ = self._input_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_audio_output(self) -> bool:
|
||||
api_key, _ = self._output_credentials()
|
||||
return bool(api_key) and self._is_supported_tts_model()
|
||||
|
||||
def transcribe_audio(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
api_key, base_url = self._input_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输入 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
language = (settings.AUDIO_INPUT_LANGUAGE or "").strip()
|
||||
prompt = "请将这段音频完整转写为文字,只输出转写结果,不要添加解释。"
|
||||
if language:
|
||||
prompt += f"音频主要语言是 {language}。"
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=self._normalize_stt_model(),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": self._build_audio_input_payload(
|
||||
content=content, filename=filename
|
||||
),
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_completion_tokens=2048,
|
||||
)
|
||||
return self._extract_message_text(completion.choices[0].message)
|
||||
except Exception as err:
|
||||
logger.error(f"音频输入转写失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
if not self._is_supported_tts_model():
|
||||
logger.error(
|
||||
"%s TTS 当前不支持该模型或模型未配置: %s",
|
||||
self.DISPLAY_NAME,
|
||||
settings.AUDIO_OUTPUT_MODEL,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
api_key, base_url = self._output_credentials()
|
||||
if not api_key:
|
||||
raise ValueError("音频输出 provider 未配置 API Key")
|
||||
client = self._build_client(api_key=api_key, base_url=base_url)
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
wav_path = voice_dir / f"{uuid4().hex}.wav"
|
||||
request = {
|
||||
"model": self._normalize_tts_model(),
|
||||
"messages": [
|
||||
{
|
||||
"role": self.TTS_MESSAGE_ROLE,
|
||||
"content": text,
|
||||
}
|
||||
],
|
||||
"audio": {
|
||||
"format": self.AUDIO_RESPONSE_FORMAT,
|
||||
"voice": settings.AUDIO_OUTPUT_VOICE or self.DEFAULT_VOICE,
|
||||
},
|
||||
}
|
||||
if self.INCLUDE_AUDIO_MODALITIES:
|
||||
request["modalities"] = ["text", "audio"]
|
||||
completion = client.chat.completions.create(**request)
|
||||
audio_data = self._extract_audio_data(completion.choices[0].message)
|
||||
if not audio_data:
|
||||
raise ValueError(f"{self.DISPLAY_NAME} TTS 响应中没有音频数据")
|
||||
|
||||
wav_path.write_bytes(base64.b64decode(audio_data))
|
||||
return self._convert_wav_to_opus(wav_path) or wav_path
|
||||
except Exception as err:
|
||||
logger.error(f"音频输出合成失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class MiMoAudioProvider(OpenAIChatAudioProvider):
|
||||
"""Xiaomi MiMo Chat Audio 预设,仅接入普通 STT/TTS 能力。"""
|
||||
|
||||
name = "mimo"
|
||||
DISPLAY_NAME = "Xiaomi MiMo"
|
||||
DEFAULT_BASE_URL = "https://api.xiaomimimo.com/v1"
|
||||
DEFAULT_STT_MODEL = "mimo-v2.5"
|
||||
DEFAULT_TTS_MODEL = "mimo-v2.5-tts"
|
||||
DEFAULT_VOICE = "mimo_default"
|
||||
AUDIO_INPUT_DATA_URL = True
|
||||
INCLUDE_AUDIO_MODALITIES = False
|
||||
TTS_MESSAGE_ROLE = "assistant"
|
||||
SUPPORTED_STT_MODELS = frozenset({"mimo-v2.5", "mimo-v2-omni"})
|
||||
SUPPORTED_TTS_MODELS = frozenset({DEFAULT_TTS_MODEL})
|
||||
UNSUPPORTED_TTS_MODELS = frozenset(
|
||||
{
|
||||
"mimo-v2.5-tts-voiceclone",
|
||||
"mimo-v2.5-tts-voicedesign",
|
||||
}
|
||||
)
|
||||
|
||||
def _normalize_tts_model(self) -> str:
|
||||
model = (settings.AUDIO_OUTPUT_MODEL or "").strip().lower()
|
||||
if not model or not model.startswith("mimo-"):
|
||||
return self.DEFAULT_TTS_MODEL
|
||||
return model
|
||||
|
||||
|
||||
class AgentCapabilityManager:
|
||||
"""Agent 能力统一入口。"""
|
||||
|
||||
REPLY_MODE_NATIVE = "native_voice"
|
||||
REPLY_MODE_TEXT = "text"
|
||||
_audio_providers: Dict[str, AudioCapabilityProvider] = {
|
||||
OpenAIAudioProvider.name: OpenAIAudioProvider(),
|
||||
OpenAIChatAudioProvider.name: OpenAIChatAudioProvider(),
|
||||
MiMoAudioProvider.name: MiMoAudioProvider(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_audio_provider(cls, provider: AudioCapabilityProvider) -> None:
|
||||
"""注册新的音频 provider。"""
|
||||
cls._audio_providers[provider.name.lower()] = provider
|
||||
|
||||
@classmethod
|
||||
def get_registered_audio_providers(cls) -> list[str]:
|
||||
"""返回已注册的音频 provider 名称。"""
|
||||
return sorted(cls._audio_providers.keys())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider_name(provider: Optional[str]) -> str:
|
||||
return (provider or "openai").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def get_audio_provider(cls, mode: str) -> Optional[AudioCapabilityProvider]:
|
||||
provider_name = cls._normalize_provider_name(
|
||||
settings.AUDIO_INPUT_PROVIDER
|
||||
if (mode or "").lower() == "input"
|
||||
else settings.AUDIO_OUTPUT_PROVIDER
|
||||
)
|
||||
provider = cls._audio_providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
logger.warning("未注册音频 provider: mode=%s, provider=%s", mode, provider_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def supports_image_input() -> bool:
|
||||
"""当前 Agent 是否启用图片输入能力。"""
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_input() -> bool:
|
||||
"""当前 Agent 是否启用音频输入能力。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_output() -> bool:
|
||||
"""当前 Agent 是否启用音频输出能力。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_OUTPUT)
|
||||
|
||||
@classmethod
|
||||
def is_audio_input_available(cls) -> bool:
|
||||
if not cls.supports_audio_input():
|
||||
return False
|
||||
provider = cls.get_audio_provider("input")
|
||||
return bool(provider and provider.is_available_for_audio_input())
|
||||
|
||||
@classmethod
|
||||
def is_audio_output_available(cls) -> bool:
|
||||
if not cls.supports_audio_output():
|
||||
return False
|
||||
provider = cls.get_audio_provider("output")
|
||||
return bool(provider and provider.is_available_for_audio_output())
|
||||
|
||||
@classmethod
|
||||
def transcribe_audio(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
provider = cls.get_audio_provider("input")
|
||||
if not provider or not cls.is_audio_input_available():
|
||||
return None
|
||||
return provider.transcribe_audio(content=content, filename=filename)
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
provider = cls.get_audio_provider("output")
|
||||
if not provider or not cls.is_audio_output_available():
|
||||
return None
|
||||
return provider.synthesize_speech(text=text)
|
||||
|
||||
@classmethod
|
||||
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
|
||||
"""仅在支持原生语音回复的渠道上发送音频,其余渠道回退文字。"""
|
||||
if cls.supports_native_voice_reply(channel=channel, source=source):
|
||||
return cls.REPLY_MODE_NATIVE
|
||||
return cls.REPLY_MODE_TEXT
|
||||
|
||||
@classmethod
|
||||
def supports_native_voice_reply(
|
||||
cls, channel: Optional[str], source: Optional[str]
|
||||
) -> bool:
|
||||
"""判断当前渠道是否支持原生语音消息发送。"""
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
try:
|
||||
channel_enum = MessageChannel(channel)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if channel_enum == MessageChannel.Telegram:
|
||||
return True
|
||||
if channel_enum != MessageChannel.Wechat:
|
||||
return False
|
||||
|
||||
# 企业微信 bot 模式不支持发送语音,只有应用模式可用。
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
from functools import wraps
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -142,9 +142,15 @@ def _patch_deepseek_reasoning_content_support():
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
# Resolve original messages so we can extract reasoning_content from
|
||||
# additional_kwargs. The parent's payload builder does not propagate
|
||||
# this DeepSeek-specific field.
|
||||
extra_body = (getattr(self, "model_kwargs", None) or {}).get("extra_body")
|
||||
if not _is_deepseek_thinking_enabled(
|
||||
getattr(self, "model_name", None) or getattr(self, "model", None),
|
||||
extra_body,
|
||||
):
|
||||
return payload
|
||||
|
||||
# 从原始 LangChain 消息中取回 reasoning_content。上游 payload 构造器
|
||||
# 不会自动透传这个 DeepSeek 扩展字段。
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
|
||||
for i, message in enumerate(payload["messages"]):
|
||||
@@ -152,9 +158,8 @@ def _patch_deepseek_reasoning_content_support():
|
||||
message["content"] = json.dumps(message["content"])
|
||||
elif message["role"] == "assistant":
|
||||
if isinstance(message["content"], list):
|
||||
# DeepSeek API expects assistant content to be a string,
|
||||
# not a list. Extract text blocks and join them, or use
|
||||
# empty string if none exist.
|
||||
# DeepSeek API 要求 assistant content 为字符串;工具场景下
|
||||
# LangChain 可能保留为内容块列表,这里只拼回可见文本块。
|
||||
text_parts = [
|
||||
block.get("text", "")
|
||||
for block in message["content"]
|
||||
@@ -162,10 +167,8 @@ def _patch_deepseek_reasoning_content_support():
|
||||
]
|
||||
message["content"] = "".join(text_parts) if text_parts else ""
|
||||
|
||||
# DeepSeek reasoning models require every assistant message to
|
||||
# carry a reasoning_content field (even when empty). The value
|
||||
# is stored in AIMessage.additional_kwargs by
|
||||
# _create_chat_result(); re-inject it into the API payload.
|
||||
# DeepSeek thinking mode 要求历史 assistant 消息携带
|
||||
# reasoning_content,即便本地只保存到了 additional_kwargs。
|
||||
if (
|
||||
"reasoning_content" not in message
|
||||
and i < len(messages)
|
||||
@@ -182,6 +185,176 @@ def _patch_deepseek_reasoning_content_support():
|
||||
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_interleaved_reasoning_content_support():
|
||||
"""
|
||||
修补 OpenAI-compatible 模型的 interleaved reasoning 内容回传。
|
||||
|
||||
小米 MiMo、部分 Kimi/GLM 等兼容端点会把思考内容放在响应顶层
|
||||
`reasoning_content` 字段;如果下一轮请求没有把它随历史 assistant
|
||||
消息带回,工具调用后续请求会被服务端以 400 拒绝。
|
||||
|
||||
这里不按 provider 白名单判断,而是只在历史 AIMessage 真实保存过
|
||||
`reasoning_content` 时回传,避免以后每接入一个同类模型都要单独适配。
|
||||
"""
|
||||
try:
|
||||
import langchain_openai.chat_models.base as _openai_base
|
||||
from langchain_openai import ChatOpenAI
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-openai reasoning_content 修补:{err}")
|
||||
return
|
||||
|
||||
if not getattr(_openai_base, "_moviepilot_reasoning_response_patched", False):
|
||||
original_convert_dict = getattr(_openai_base, "_convert_dict_to_message", None)
|
||||
original_convert_delta = getattr(
|
||||
_openai_base, "_convert_delta_to_message_chunk", None
|
||||
)
|
||||
|
||||
if callable(original_convert_dict):
|
||||
@wraps(original_convert_dict)
|
||||
def _patched_convert_dict_to_message(message_dict):
|
||||
message = original_convert_dict(message_dict)
|
||||
if (
|
||||
isinstance(message, AIMessage)
|
||||
and "reasoning_content" in message_dict
|
||||
):
|
||||
message.additional_kwargs["reasoning_content"] = (
|
||||
message_dict.get("reasoning_content") or ""
|
||||
)
|
||||
return message
|
||||
|
||||
_openai_base._convert_dict_to_message = _patched_convert_dict_to_message
|
||||
|
||||
if callable(original_convert_delta):
|
||||
@wraps(original_convert_delta)
|
||||
def _patched_convert_delta_to_message_chunk(delta, default_class):
|
||||
chunk = original_convert_delta(delta, default_class)
|
||||
if (
|
||||
isinstance(chunk, AIMessageChunk)
|
||||
and "reasoning_content" in delta
|
||||
):
|
||||
chunk.additional_kwargs["reasoning_content"] = (
|
||||
delta.get("reasoning_content") or ""
|
||||
)
|
||||
return chunk
|
||||
|
||||
_openai_base._convert_delta_to_message_chunk = (
|
||||
_patched_convert_delta_to_message_chunk
|
||||
)
|
||||
|
||||
_openai_base._moviepilot_reasoning_response_patched = True
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_interleaved_reasoning_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatOpenAI, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-openai 缺少 _get_request_payload,无法修补 reasoning_content")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
if "messages" not in payload:
|
||||
return payload
|
||||
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
for index, payload_message in enumerate(payload["messages"]):
|
||||
if (
|
||||
payload_message.get("role") != "assistant"
|
||||
or index >= len(messages)
|
||||
or not isinstance(messages[index], AIMessage)
|
||||
or "reasoning_content" in payload_message
|
||||
):
|
||||
continue
|
||||
|
||||
reasoning_content = messages[index].additional_kwargs.get(
|
||||
"reasoning_content"
|
||||
)
|
||||
if reasoning_content is not None:
|
||||
# 只回传模型真实返回过的思考字段。普通模型没有该字段时,
|
||||
# payload 保持原样,不额外塞未知参数。
|
||||
payload_message["reasoning_content"] = reasoning_content
|
||||
|
||||
return payload
|
||||
|
||||
ChatOpenAI._get_request_payload = _patched_get_request_payload
|
||||
ChatOpenAI._moviepilot_interleaved_reasoning_patched = True
|
||||
logger.debug("已修补 langchain-openai interleaved reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_responses_instructions_support():
|
||||
"""
|
||||
修补 langchain-openai 在使用 use_responses_api=True 时,
|
||||
提取 system 消息为顶层 instructions 字段。
|
||||
由于 Codex 等模型 (Responses API) 强依赖 instructions 字段,
|
||||
如果没有该字段会报 400 "Instructions are required"。
|
||||
"""
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-openai instructions 修补:{err}")
|
||||
return
|
||||
|
||||
_patch_openai_interleaved_reasoning_content_support()
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_responses_instructions_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatOpenAI, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-openai 缺少 _get_request_payload,无法修补 instructions")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
base_url = str(getattr(self, "openai_api_base", "") or "").lower()
|
||||
|
||||
# 处理 GitHub Copilot 端点兼容性
|
||||
if "githubcopilot.com" in base_url:
|
||||
payload.pop("stream_options", None)
|
||||
payload.pop("metadata", None)
|
||||
|
||||
# 处理 ChatGPT 官方 Responses API (Codex) 端点兼容性
|
||||
is_codex = "chatgpt.com/backend-api/codex" in base_url
|
||||
|
||||
if is_codex and (getattr(self, "use_responses_api", False) or "input" in payload):
|
||||
instructions = payload.get("instructions", "")
|
||||
inputs = payload.get("input", [])
|
||||
new_inputs = []
|
||||
|
||||
for msg in inputs:
|
||||
if isinstance(msg, dict) and msg.get("role") == "system":
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
if instructions:
|
||||
instructions += "\n\n" + content
|
||||
else:
|
||||
instructions = content
|
||||
else:
|
||||
new_inputs.append(msg)
|
||||
|
||||
payload["input"] = new_inputs
|
||||
payload["instructions"] = instructions or "You are a helpful assistant."
|
||||
payload["store"] = False
|
||||
|
||||
# Codex 端点不支持的部分常见补全参数,统一清理避免 400 报错
|
||||
unsupported_keys = [
|
||||
"presence_penalty", "frequency_penalty", "top_p", "n", "user",
|
||||
"stop", "metadata", "logit_bias", "logprobs", "top_logprobs",
|
||||
"stream_options", "temperature"
|
||||
]
|
||||
for key in unsupported_keys:
|
||||
payload.pop(key, None)
|
||||
|
||||
return payload
|
||||
|
||||
ChatOpenAI._get_request_payload = _patched_get_request_payload
|
||||
ChatOpenAI._moviepilot_responses_instructions_patched = True
|
||||
logger.debug("已修补 langchain-openai responses API 的 instructions 兼容性")
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@@ -342,7 +515,7 @@ class LLMHelper:
|
||||
return {}
|
||||
|
||||
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
|
||||
if provider_name == "openai" and model_name.startswith(
|
||||
if provider_name in {"openai", "chatgpt"} and model_name.startswith(
|
||||
("gpt-5", "o1", "o3", "o4")
|
||||
):
|
||||
openai_effort = cls._normalize_openai_reasoning_effort(
|
||||
@@ -366,13 +539,84 @@ class LLMHelper:
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def get_llm(
|
||||
def _build_legacy_runtime(
|
||||
provider_name: str,
|
||||
model_name: str | None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
在 provider 目录不可用时回退到旧的直接构造逻辑。
|
||||
|
||||
这主要用于单测 stub 环境以及极端的最小运行环境,正常生产路径仍优先
|
||||
走 `LLMProviderManager.resolve_runtime()`。
|
||||
"""
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
if not api_key_value:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
runtime_name = (
|
||||
provider_name
|
||||
if provider_name in {"google", "deepseek"}
|
||||
else "openai_compatible"
|
||||
)
|
||||
return {
|
||||
"provider_id": provider_name,
|
||||
"runtime": runtime_name,
|
||||
"model_id": model_name,
|
||||
"api_key": api_key_value,
|
||||
"base_url": base_url_value,
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
"model_record": None,
|
||||
"model_metadata": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _resolve_thinking_level(
|
||||
cls,
|
||||
thinking_level: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
统一兼容新旧 thinking 参数。
|
||||
"""
|
||||
|
||||
def _normalize(value: str | None) -> str | None:
|
||||
normalized = str(value or "").strip().lower()
|
||||
if not normalized:
|
||||
return None
|
||||
alias_map = {
|
||||
"none": "off",
|
||||
"disabled": "off",
|
||||
"disable": "off",
|
||||
"enabled": "auto",
|
||||
"enable": "auto",
|
||||
"default": "auto",
|
||||
"dynamic": "auto",
|
||||
}
|
||||
normalized = alias_map.get(normalized, normalized)
|
||||
if normalized in cls._SUPPORTED_THINKING_LEVELS:
|
||||
return normalized
|
||||
logger.warning(f"忽略不支持的思考级别: {value}")
|
||||
return None
|
||||
|
||||
normalized_thinking_level = _normalize(thinking_level)
|
||||
if normalized_thinking_level:
|
||||
return normalized_thinking_level
|
||||
|
||||
return "off"
|
||||
|
||||
@classmethod
|
||||
async def get_llm(
|
||||
cls,
|
||||
streaming: bool = False,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
thinking_level: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
base_url_preset: str | None = None,
|
||||
):
|
||||
"""
|
||||
获取LLM实例
|
||||
@@ -383,28 +627,49 @@ class LLMHelper:
|
||||
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal"、"low"、"medium"、"high"、"max"/"xhigh"(最大)。
|
||||
不同模型对思考模式的支持和表现不同,具体映射关系请
|
||||
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
|
||||
:param api_key: API Key,默认为
|
||||
配置项LLM_API_KEY。对于某些提供商(
|
||||
如 DeepSeek),可能需要同时提供 base_url。
|
||||
:param base_url: API Base URL,默认为配置项LLM_BASE_URL。
|
||||
:param api_key: API Key。未显式传入时使用当前配置项 LLM_API_KEY。对于某些提供商(如 DeepSeek),可能需要同时提供 base_url。
|
||||
:param base_url: API Base URL。未显式传入时使用当前配置项 LLM_BASE_URL。
|
||||
:param base_url_preset: Base URL 预设。未显式传入时使用当前配置项 LLM_BASE_URL_PRESET。
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider_name = str(
|
||||
provider if provider is not None else settings.LLM_PROVIDER
|
||||
).lower()
|
||||
provider_name = str(provider if provider is not None else settings.LLM_PROVIDER).lower()
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
thinking_kwargs = LLMHelper._build_thinking_kwargs(
|
||||
base_url_preset_value = (
|
||||
base_url_preset if base_url_preset is not None else settings.LLM_BASE_URL_PRESET
|
||||
)
|
||||
normalized_thinking_level = cls._resolve_thinking_level(
|
||||
thinking_level=thinking_level,
|
||||
)
|
||||
try:
|
||||
# 延迟导入,避免单测在最小 stub 环境下 import `llm.py` 时被 provider
|
||||
# 目录依赖链拖住。
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
runtime = await LLMProviderManager().resolve_runtime(
|
||||
provider_id=provider_name,
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
base_url_preset_id=base_url_preset_value,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"LLM provider 目录不可用,回退到旧运行时逻辑: {err}")
|
||||
runtime = cls._build_legacy_runtime(
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
)
|
||||
model_name = runtime.get("model_id") or model_name
|
||||
thinking_kwargs = cls._build_thinking_kwargs(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
thinking_level=thinking_level
|
||||
thinking_level=normalized_thinking_level,
|
||||
)
|
||||
|
||||
if not api_key_value:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider_name == "google":
|
||||
if runtime["runtime"] == "google":
|
||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||
_patch_gemini_thought_signature()
|
||||
|
||||
@@ -420,49 +685,82 @@ class LLMHelper:
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_key=runtime["api_key"],
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
client_args=client_args,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif provider_name == "deepseek":
|
||||
elif runtime["runtime"] == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
_patch_deepseek_reasoning_content_support()
|
||||
model = ChatDeepSeek(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_base=base_url_value,
|
||||
api_key=runtime["api_key"],
|
||||
api_base=runtime["base_url"],
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif runtime["runtime"] in {"anthropic_compatible", "copilot_anthropic"}:
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
model = ChatAnthropic(
|
||||
model=model_name,
|
||||
api_key=runtime["api_key"],
|
||||
base_url=runtime["base_url"],
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
anthropic_proxy=settings.PROXY_HOST,
|
||||
default_headers=runtime.get("default_headers"),
|
||||
**thinking_kwargs,
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
_patch_openai_responses_instructions_support()
|
||||
|
||||
# ChatGPT Codex 端点强制要求 stream: True
|
||||
if runtime.get("use_responses_api") and "chatgpt.com/backend-api/codex" in str(runtime.get("base_url") or ""):
|
||||
streaming = True
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_key=runtime["api_key"],
|
||||
max_retries=3,
|
||||
base_url=base_url_value,
|
||||
base_url=runtime.get("base_url"),
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
default_headers=runtime.get("default_headers"),
|
||||
use_responses_api=runtime.get("use_responses_api"),
|
||||
**thinking_kwargs,
|
||||
)
|
||||
|
||||
# 检查是否有profile
|
||||
if hasattr(model, "profile") and model.profile:
|
||||
# 优先使用 provider / models.dev 目录中的上下文上限,减少用户手填成本。
|
||||
model_profile = getattr(model, "profile", None)
|
||||
if model_profile:
|
||||
logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
else:
|
||||
model_record = runtime.get("model_record") or {}
|
||||
model_metadata = runtime.get("model_metadata") or {}
|
||||
metadata_limit = model_metadata.get("limit") or {}
|
||||
max_input_tokens = (
|
||||
model_record.get("input_tokens")
|
||||
or model_record.get("context_tokens")
|
||||
or metadata_limit.get("input")
|
||||
or metadata_limit.get("context")
|
||||
or settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||
)
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
||||
* 1000, # 转换为token单位
|
||||
"max_input_tokens": int(max_input_tokens),
|
||||
}
|
||||
|
||||
return model
|
||||
@@ -516,22 +814,22 @@ class LLMHelper:
|
||||
thinking_level: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
base_url_preset: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
provider_name = provider if provider is not None else settings.LLM_PROVIDER
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
start = time.perf_counter()
|
||||
llm = LLMHelper.get_llm(
|
||||
llm = await LLMHelper.get_llm(
|
||||
streaming=False,
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
thinking_level=thinking_level,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset=base_url_preset,
|
||||
)
|
||||
try:
|
||||
response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
|
||||
@@ -556,18 +854,62 @@ class LLMHelper:
|
||||
data["reply_preview"] = reply_text[:120]
|
||||
return data
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
async def get_models(
|
||||
self,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
base_url_preset: str | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> List[dict[str, Any]]:
|
||||
"""
|
||||
获取模型列表。
|
||||
|
||||
返回值会带上 context/supports_reasoning 等元数据,供前端直接渲染并自动
|
||||
回填上下文大小。
|
||||
"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
if provider == "google":
|
||||
return self._get_google_models(api_key)
|
||||
else:
|
||||
return self._get_openai_compatible_models(provider, api_key, base_url)
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
return await LLMProviderManager().list_models(
|
||||
provider_id=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"LLM provider 目录不可用,回退旧模型列表逻辑: {err}")
|
||||
if provider == "google":
|
||||
return [
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_google_models(api_key or "")
|
||||
]
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
model_list_base_url = (
|
||||
LLMProviderManager().resolve_model_list_base_url(
|
||||
provider_id=provider,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset,
|
||||
)
|
||||
or base_url
|
||||
)
|
||||
except Exception:
|
||||
model_list_base_url = base_url
|
||||
return [
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_openai_compatible_models(
|
||||
provider,
|
||||
api_key or "",
|
||||
model_list_base_url,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_google_models(api_key: str) -> List[str]:
|
||||
async def _get_google_models(api_key: str) -> List[str]:
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
from google import genai
|
||||
@@ -583,29 +925,32 @@ class LLMHelper:
|
||||
)
|
||||
|
||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
models = client.models.list()
|
||||
return [
|
||||
models = await client.aio.models.list()
|
||||
result = [
|
||||
m.name
|
||||
for m in models
|
||||
for m in models.page
|
||||
if m.supported_actions and "generateContent" in m.supported_actions
|
||||
]
|
||||
await client.aio.aclose()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"获取Google模型列表失败:{e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(
|
||||
async def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
if provider == "deepseek":
|
||||
base_url = base_url or "https://api.deepseek.com"
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
models = client.models.list()
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
models = await client.models.list()
|
||||
await client.close()
|
||||
return [model.id for model in models.data]
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {provider} 模型列表失败:{e}")
|
||||
1
app/agent/llm/models.json
Normal file
1
app/agent/llm/models.json
Normal file
File diff suppressed because one or more lines are too long
2503
app/agent/llm/provider.py
Normal file
2503
app/agent/llm/provider.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -158,9 +158,9 @@ async def _summarize_with_llm(conversation_text: str) -> str | None:
|
||||
LLM 生成的摘要字符串,失败时返回 None。
|
||||
"""
|
||||
try:
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.agent.llm import LLMHelper
|
||||
|
||||
llm = LLMHelper.get_llm(streaming=False)
|
||||
llm = await LLMHelper.get_llm(streaming=False)
|
||||
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
||||
response = await llm.ainvoke(prompt)
|
||||
summary = response.content.strip()
|
||||
@@ -355,7 +355,7 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将活动日志注入系统消息。"""
|
||||
contents = request.state.get("activity_log_contents", {})
|
||||
contents = request.state.get("activity_log_contents", {}) # noqa
|
||||
activity_log_prompt = self._format_activity_log(contents)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""结构化 Agent hooks 中间件。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
|
||||
|
||||
class HooksState(AgentState):
|
||||
"""hooks 中间件状态。"""
|
||||
|
||||
hooks_prompt: NotRequired[Annotated[str, PrivateStateAttr]]
|
||||
|
||||
|
||||
class HooksStateUpdate(TypedDict):
|
||||
"""hooks 状态更新。"""
|
||||
|
||||
hooks_prompt: str
|
||||
|
||||
|
||||
class AgentHooksMiddleware(AgentMiddleware[HooksState, ContextT, ResponseT]): # noqa
|
||||
"""在固定生命周期点注入结构化 pre/in/post hooks。"""
|
||||
|
||||
state_schema = HooksState
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self, state: HooksState, runtime: Runtime, config: RunnableConfig
|
||||
) -> HooksStateUpdate | None:
|
||||
if "hooks_prompt" in state:
|
||||
return None
|
||||
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
return HooksStateUpdate(hooks_prompt=runtime_config.render_hooks_prompt())
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: # noqa
|
||||
hooks_prompt = request.state.get("hooks_prompt", "") # noqa
|
||||
if not hooks_prompt:
|
||||
return request
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, hooks_prompt
|
||||
)
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
return await handler(self.modify_request(request))
|
||||
|
||||
|
||||
__all__ = ["AgentHooksMiddleware"]
|
||||
@@ -21,6 +21,7 @@ from app.log import logger
|
||||
|
||||
# JOB.md 文件最大限制为 1MB
|
||||
MAX_JOB_FILE_SIZE = 1 * 1024 * 1024
|
||||
ACTIVE_JOB_STATUSES = ("pending", "in_progress")
|
||||
|
||||
|
||||
class JobMetadata(TypedDict):
|
||||
@@ -143,6 +144,9 @@ async def _alist_jobs(source_path: AsyncPath) -> list[JobMetadata]:
|
||||
if not job_dirs:
|
||||
return []
|
||||
|
||||
# 显式按目录名排序,避免文件系统返回顺序不稳定时破坏提示词缓存命中。
|
||||
job_dirs.sort(key=lambda p: p.name.casefold())
|
||||
|
||||
# 解析 JOB.md
|
||||
for job_path in job_dirs:
|
||||
job_md_path = job_path / "JOB.md"
|
||||
@@ -161,6 +165,31 @@ async def _alist_jobs(source_path: AsyncPath) -> list[JobMetadata]:
|
||||
return jobs
|
||||
|
||||
|
||||
def filter_active_jobs(jobs_metadata: list[JobMetadata]) -> list[JobMetadata]:
|
||||
"""筛选需要参与心跳检查的活跃任务。
|
||||
|
||||
这里严格以任务状态为准,只保留 `pending` / `in_progress`。
|
||||
`recurring` 任务执行完成后按约定应回写为 `pending`,因此无需再额外放宽
|
||||
到 `completed`,避免已结束任务被重复注入后台心跳。
|
||||
"""
|
||||
return [
|
||||
job for job in jobs_metadata if job.get("status") in ACTIVE_JOB_STATUSES
|
||||
]
|
||||
|
||||
|
||||
async def load_jobs_metadata(source_paths: list[str]) -> list[JobMetadata]:
|
||||
"""按顺序加载多个 jobs 目录下的任务元数据。"""
|
||||
all_jobs: list[JobMetadata] = []
|
||||
for source_path_str in source_paths:
|
||||
source_path = AsyncPath(source_path_str)
|
||||
if not await source_path.exists():
|
||||
await source_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
source_jobs = await _alist_jobs(source_path)
|
||||
all_jobs.extend(source_jobs)
|
||||
return all_jobs
|
||||
|
||||
|
||||
JOBS_SYSTEM_PROMPT = """
|
||||
<jobs_system>
|
||||
You have a **scheduled jobs** system that allows you to track and execute long-running or recurring tasks.
|
||||
@@ -289,13 +318,8 @@ class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
|
||||
"""将任务文档注入模型请求的系统消息中。"""
|
||||
jobs_metadata = request.state.get("jobs_metadata", []) # noqa
|
||||
|
||||
# 过滤:只展示活跃任务(pending / in_progress / recurring)
|
||||
active_jobs = [
|
||||
j
|
||||
for j in jobs_metadata
|
||||
if j["status"] in ("pending", "in_progress")
|
||||
or (j["schedule"] == "recurring" and j["status"] not in ("cancelled",))
|
||||
]
|
||||
# 仅注入真正活跃的任务,避免把已完成任务继续塞进心跳上下文。
|
||||
active_jobs = filter_active_jobs(jobs_metadata)
|
||||
|
||||
jobs_list = self._format_jobs_list(active_jobs)
|
||||
jobs_location = self.sources[0] if self.sources else ""
|
||||
@@ -322,18 +346,9 @@ class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
|
||||
if "jobs_metadata" in state:
|
||||
return None
|
||||
|
||||
all_jobs: list[JobMetadata] = []
|
||||
|
||||
# 遍历源加载任务
|
||||
for source_path_str in self.sources:
|
||||
source_path = AsyncPath(source_path_str)
|
||||
if not await source_path.exists():
|
||||
await source_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
source_jobs = await _alist_jobs(source_path)
|
||||
all_jobs.extend(source_jobs)
|
||||
|
||||
return JobsStateUpdate(jobs_metadata=all_jobs)
|
||||
return JobsStateUpdate(
|
||||
jobs_metadata=await load_jobs_metadata(self.sources)
|
||||
)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -347,4 +362,10 @@ class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
|
||||
return await handler(modified_request)
|
||||
|
||||
|
||||
__all__ = ["JobMetadata", "JobsMiddleware"]
|
||||
__all__ = [
|
||||
"ACTIVE_JOB_STATUSES",
|
||||
"JobMetadata",
|
||||
"JobsMiddleware",
|
||||
"filter_active_jobs",
|
||||
"load_jobs_metadata",
|
||||
]
|
||||
|
||||
@@ -57,8 +57,8 @@ You can create, edit, or organize any `.md` files in this directory to manage yo
|
||||
|
||||
**Memory file organization:**
|
||||
- All `.md` files in `{memory_dir}` are automatically loaded as memory.
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- You may create additional `.md` files to organize knowledge by topic (e.g., `MEDIA_RULES.md`, `DOWNLOAD_PREFERENCES.md`, `SITE_CONFIGS.md`, etc.).
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences, communication style, and durable working rules.
|
||||
- You may create additional `.md` files to organize knowledge by topic (e.g., `MEDIA_RULES.md`, `COMMUNICATION_PREFERENCES.md`, `DOWNLOAD_PREFERENCES.md`, `SITE_CONFIGS.md`, etc.).
|
||||
- Keep each file focused on a specific domain or topic for better organization.
|
||||
- Subdirectories are NOT scanned — only `.md` files directly in `{memory_dir}`.
|
||||
|
||||
@@ -78,11 +78,11 @@ You can create, edit, or organize any `.md` files in this directory to manage yo
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something (e.g., "remember my email", "save this preference")
|
||||
- When the user describes your role or how you should behave (e.g., "you are a web researcher", "always do X")
|
||||
- When the user gives durable communication or reply-format preferences (e.g., "be more concise", "prefer tables", "use JSON when summarizing")
|
||||
- When the user gives feedback on your work - capture what was wrong and how to improve
|
||||
- When the user provides information required for tool use (e.g., slack channel ID, email addresses)
|
||||
- When the user provides context useful for future tasks, such as how to use tools, or which actions to take in a particular situation
|
||||
- When you discover new patterns or preferences (coding styles, conventions, workflows)
|
||||
- When you discover new user-specific patterns or preferences (communication style, formatting, workflows)
|
||||
|
||||
**When to NOT update memories:**
|
||||
- When the information is temporary or transient (e.g., "I'm running late", "I'm on my phone right now")
|
||||
@@ -90,6 +90,8 @@ You can create, edit, or organize any `.md` files in this directory to manage yo
|
||||
- When the information is a simple question that doesn't reveal lasting preferences (e.g., "What day is it?", "Can you explain X?")
|
||||
- When the information is an acknowledgment or small talk (e.g., "Sounds good!", "Hello", "Thanks for that")
|
||||
- When the information is stale or irrelevant in future conversations
|
||||
- Memory may refine user-facing style, but it must NOT redefine the agent's core identity, safety boundaries, or global system-task rules.
|
||||
- If the user wants a built-in speaking style/persona, prefer the dedicated persona-switching tools instead of rewriting memory as a substitute.
|
||||
- Never store API keys, access tokens, passwords, or any other credentials in any file, memory, or system prompt.
|
||||
- If the user asks where to put API keys or provides an API key, do NOT echo or save it.
|
||||
- Do NOT record daily activities or task execution history in memory files - these are automatically tracked in the activity log system (see <activity_log>). Memory files are only for long-term knowledge, preferences, and patterns.
|
||||
@@ -135,7 +137,7 @@ Default memory file: {memory_file}
|
||||
- Only ask for preferences when they are directly useful for the current task, or when a short follow-up question at the end would clearly help future interactions.
|
||||
|
||||
**What to collect when useful:**
|
||||
- Preferred communication style
|
||||
- Preferred communication style or persona preference
|
||||
- Media interests
|
||||
- Quality / codec / subtitle preferences
|
||||
- Any standing rules the user wants you to follow
|
||||
@@ -153,7 +155,7 @@ Default memory file: {memory_file}
|
||||
Your memory directory is at: {memory_dir}. You can save new knowledge by calling the `edit_file` or `write_file` tool on any `.md` file in this directory.
|
||||
|
||||
**Memory file organization:**
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- `MEMORY.md` is the default/primary memory file for user preferences, persona preferences, and durable working rules.
|
||||
- You may create additional `.md` files to organize knowledge by topic.
|
||||
- All `.md` files directly in the memory directory are automatically loaded on each conversation.
|
||||
|
||||
@@ -166,15 +168,17 @@ Default memory file: {memory_file}
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something
|
||||
- When the user describes your role or how you should behave
|
||||
- When the user gives durable communication or reply-format preferences
|
||||
- When the user gives feedback on your work
|
||||
- When the user provides information required for tool use
|
||||
- When you discover new patterns or preferences
|
||||
- When you discover new user-specific patterns or preferences
|
||||
|
||||
**When to NOT update memories:**
|
||||
- Temporary/transient information
|
||||
- One-time task requests
|
||||
- Simple questions, acknowledgments, or small talk
|
||||
- Memory may refine user-facing style, but it must NOT redefine the agent's core identity, safety boundaries, or global system-task rules
|
||||
- If the user wants a built-in speaking style/persona, prefer the dedicated persona-switching tools instead of rewriting memory as a substitute
|
||||
- Never store API keys, access tokens, passwords, or credentials
|
||||
- Do NOT record daily activities in memory files — those go to the activity log
|
||||
</memory_guidelines>
|
||||
@@ -189,7 +193,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
|
||||
参数:
|
||||
memory_dir: 记忆文件目录路径。建议使用独立的 `config/agent/memory`
|
||||
目录,避免与 persona/workflow 等根层配置混写。
|
||||
目录,避免与核心规则或人格定义混写。
|
||||
"""
|
||||
|
||||
state_schema = MemoryState
|
||||
@@ -289,7 +293,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
|
||||
return md_files
|
||||
|
||||
async def abefore_agent(
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
state: MemoryState,
|
||||
runtime: Runtime, # noqa
|
||||
|
||||
42
app/agent/middleware/runtime_config.py
Normal file
42
app/agent/middleware/runtime_config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""动态注入 Agent 根层运行时配置的中间件。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
|
||||
|
||||
class RuntimeConfigMiddleware(AgentMiddleware[dict, ContextT, ResponseT]): # noqa
|
||||
"""在每次模型调用前动态加载运行时配置。
|
||||
|
||||
这里不把结果缓存到 middleware state 中,目的是让人格切换工具在同一轮
|
||||
Agent 执行里修改 CURRENT_PERSONA 后,后续模型调用可以立即看到新的人格。
|
||||
"""
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: # noqa
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
runtime_sections = runtime_config.render_prompt_sections()
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, runtime_sections
|
||||
)
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
return await handler(self.modify_request(request))
|
||||
|
||||
|
||||
__all__ = ["RuntimeConfigMiddleware"]
|
||||
@@ -227,6 +227,9 @@ async def _alist_skills(source_path: AsyncPath) -> list[SkillMetadata]:
|
||||
if not skill_dirs:
|
||||
return []
|
||||
|
||||
# 显式按目录名排序,避免文件系统返回顺序不稳定时破坏提示词缓存命中。
|
||||
skill_dirs.sort(key=lambda p: p.name.casefold())
|
||||
|
||||
# 解析已下载的 SKILL.md
|
||||
for skill_path in skill_dirs:
|
||||
skill_md_path = skill_path / "SKILL.md"
|
||||
@@ -310,7 +313,8 @@ def _extract_version(skill_md: Path) -> int:
|
||||
"""从 SKILL.md 文件中快速提取 version 字段,无法提取时返回 0。"""
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
except Exception as err:
|
||||
print(err)
|
||||
return 0
|
||||
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
549
app/agent/middleware/tool_selection.py
Normal file
549
app/agent/middleware/tool_selection.py
Normal file
@@ -0,0 +1,549 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal, Union, NotRequired
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
PrivateStateAttr, # noqa
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import TypedDict # noqa
|
||||
|
||||
from app.log import logger
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"Your goal is to select the most relevant tools for answering the user's query."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SelectionRequest:
|
||||
"""Prepared inputs for tool selection."""
|
||||
|
||||
available_tools: list[BaseTool]
|
||||
system_message: str
|
||||
last_user_message: HumanMessage
|
||||
model: BaseChatModel
|
||||
valid_tool_names: list[str]
|
||||
|
||||
|
||||
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
"""Create a structured output schema for tool selection.
|
||||
|
||||
Args:
|
||||
tools: Available tools to include in the schema.
|
||||
|
||||
Returns:
|
||||
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
||||
description.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `tools` is empty.
|
||||
"""
|
||||
if not tools:
|
||||
msg = "Invalid usage: tools must be non-empty"
|
||||
raise AssertionError(msg)
|
||||
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)]
|
||||
for tool in tools # noqa
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
description = "Tools to use. Place the most relevant tools first."
|
||||
|
||||
class ToolSelectionResponse(TypedDict):
|
||||
"""Use to select relevant tools."""
|
||||
|
||||
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
|
||||
|
||||
return TypeAdapter(ToolSelectionResponse)
|
||||
|
||||
|
||||
def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
"""Format tools as markdown list.
|
||||
|
||||
Args:
|
||||
tools: Tools to format.
|
||||
|
||||
Returns:
|
||||
Markdown string with each tool on a new line.
|
||||
"""
|
||||
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
|
||||
selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]]
|
||||
"""当前这条用户请求首轮筛选得到的工具名列表。"""
|
||||
|
||||
|
||||
class ToolSelectionStateUpdate(TypedDict):
|
||||
"""工具筛选中间件状态更新项。"""
|
||||
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class ToolSelectorMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
LangChain 默认会通过 `with_structured_output()` 走 OpenAI 的
|
||||
`response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档
|
||||
仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段
|
||||
提前触发 400,导致 Agent 还没真正开始执行工具就失败。
|
||||
|
||||
因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式:
|
||||
1. 使用 `response_format={"type": "json_object"}`;
|
||||
2. 在提示词中明确约束返回 JSON 结构;
|
||||
3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。
|
||||
|
||||
另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求
|
||||
的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务,
|
||||
这会重复消耗一次额外的 LLM 调用。这里改成:
|
||||
- `abefore_agent()`:在本轮 Agent 执行开始时筛选一次;
|
||||
- `awrap_model_call()`:从 `request.state` 读取首轮筛选结果并复用。
|
||||
"""
|
||||
|
||||
state_schema = ToolSelectionState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _prepare_selection_request(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> _SelectionRequest | None:
|
||||
"""Prepare inputs for tool selection.
|
||||
|
||||
Args:
|
||||
request: the model request.
|
||||
|
||||
Returns:
|
||||
`SelectionRequest` with prepared inputs, or `None` if no selection is
|
||||
needed.
|
||||
|
||||
Raises:
|
||||
ValueError: If tools in `always_include` are not found in the request.
|
||||
AssertionError: If no user message is found in the request messages.
|
||||
"""
|
||||
# If no tools available, return None
|
||||
if not request.tools or len(request.tools) == 0:
|
||||
return None
|
||||
|
||||
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
|
||||
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
# Validate that always_include tools exist
|
||||
if self.always_include:
|
||||
available_tool_names = {tool.name for tool in base_tools}
|
||||
missing_tools = [
|
||||
name for name in self.always_include if name not in available_tool_names
|
||||
]
|
||||
if missing_tools:
|
||||
msg = (
|
||||
f"Tools in always_include not found in request: {missing_tools}. "
|
||||
f"Available tools: {sorted(available_tool_names)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Separate tools that are always included from those available for selection
|
||||
available_tools = [
|
||||
tool for tool in base_tools if tool.name not in self.always_include
|
||||
]
|
||||
|
||||
# If no tools available for selection, return None
|
||||
if not available_tools:
|
||||
return None
|
||||
|
||||
system_message = self.system_prompt
|
||||
# If there's a max_tools limit, append instructions to the system prompt
|
||||
if self.max_tools is not None:
|
||||
system_message += (
|
||||
f"\nIMPORTANT: List the tool names in order of relevance, "
|
||||
f"with the most relevant first. "
|
||||
f"If you exceed the maximum number of tools, "
|
||||
f"only the first {self.max_tools} will be used."
|
||||
)
|
||||
|
||||
# Get the last user message from the conversation history
|
||||
last_user_message: HumanMessage
|
||||
for message in reversed(request.messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
last_user_message = message
|
||||
break
|
||||
else:
|
||||
msg = "No user message found in request messages"
|
||||
raise AssertionError(msg)
|
||||
|
||||
model = self.model or request.model
|
||||
valid_tool_names = [tool.name for tool in available_tools]
|
||||
|
||||
return _SelectionRequest(
|
||||
available_tools=available_tools,
|
||||
system_message=system_message,
|
||||
last_user_message=last_user_message,
|
||||
model=model,
|
||||
valid_tool_names=valid_tool_names,
|
||||
)
|
||||
|
||||
def _process_selection_response(
|
||||
self,
|
||||
response: dict[str, Any],
|
||||
available_tools: list[BaseTool],
|
||||
valid_tool_names: list[str],
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""Process the selection response and return filtered `ModelRequest`."""
|
||||
selected_tool_names: list[str] = []
|
||||
invalid_tool_selections = []
|
||||
|
||||
for tool_name in response["tools"]:
|
||||
if tool_name not in valid_tool_names:
|
||||
invalid_tool_selections.append(tool_name)
|
||||
continue
|
||||
|
||||
# Only add if not already selected and within max_tools limit
|
||||
if tool_name not in selected_tool_names and (
|
||||
self.max_tools is None or len(selected_tool_names) < self.max_tools
|
||||
):
|
||||
selected_tool_names.append(tool_name)
|
||||
|
||||
if invalid_tool_selections:
|
||||
msg = f"Model selected invalid tools: {invalid_tool_selections}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter tools based on selection and append always-included tools
|
||||
if selected_tool_names:
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
else:
|
||||
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
|
||||
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
|
||||
selected_tools = available_tools
|
||||
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
selected_tools.extend(always_included_tools)
|
||||
|
||||
# Also preserve any provider-specific tool dicts from the original request
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
"""
|
||||
判断当前模型是否应当走 DeepSeek JSON 兼容分支。
|
||||
|
||||
除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible
|
||||
配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名
|
||||
和 Base URL,避免只靠单一条件漏判。
|
||||
"""
|
||||
module_name = type(model).__module__.lower()
|
||||
model_name = (
|
||||
str(getattr(model, "model_name", "") or getattr(model, "model", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
base_url = (
|
||||
str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
|
||||
return (
|
||||
"deepseek" in module_name
|
||||
or model_name.startswith("deepseek-")
|
||||
or "api.deepseek.com" in base_url
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str:
|
||||
"""
|
||||
从模型响应中提取纯文本。
|
||||
|
||||
这里不依赖上层 LLMHelper,避免中间件与 LLM 构造逻辑互相耦合。
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
continue
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text" and isinstance(
|
||||
block.get("text"), str
|
||||
):
|
||||
text_parts.append(block["text"])
|
||||
continue
|
||||
if not block.get("type") and isinstance(block.get("text"), str):
|
||||
text_parts.append(block["text"])
|
||||
return "".join(text_parts)
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "text" and isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
if not content.get("type") and isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_json_object(text: str) -> dict[str, Any]:
|
||||
"""
|
||||
解析模型返回的 JSON。
|
||||
|
||||
DeepSeek 在 JSON 模式下通常会返回纯 JSON,但这里仍做一层兜底,
|
||||
兼容模型偶发输出围栏或前后说明文本的情况。
|
||||
"""
|
||||
stripped_text = text.strip()
|
||||
if not stripped_text:
|
||||
raise ValueError("工具筛选返回了空响应")
|
||||
|
||||
try:
|
||||
payload = json.loads(stripped_text)
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
start = stripped_text.find("{")
|
||||
end = stripped_text.rfind("}")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}")
|
||||
|
||||
payload = json.loads(stripped_text[start: end + 1])
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("工具筛选 JSON 顶层必须是对象")
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _render_tool_list(available_tools: list[Any]) -> str:
|
||||
"""把工具名和描述渲染成稳定的文本列表。"""
|
||||
return "\n".join(
|
||||
f"- {tool.name}: {tool.description}" for tool in available_tools
|
||||
)
|
||||
|
||||
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
|
||||
"""
|
||||
为 DeepSeek 生成显式 JSON 输出提示。
|
||||
|
||||
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
|
||||
约束,否则兼容端点可能返回空内容或无意义输出。
|
||||
"""
|
||||
limit_instruction = ""
|
||||
if self.max_tools:
|
||||
limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED."
|
||||
|
||||
return (
|
||||
f"{selection_request.system_message}\n\n"
|
||||
"Return the answer in JSON only.\n"
|
||||
'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n'
|
||||
"Rules:\n"
|
||||
"- The `tools` field must be a JSON array of strings.\n"
|
||||
"- Only use tool names from the allowed list below.\n"
|
||||
"- Order tools by relevance, with the most relevant first.\n"
|
||||
f"{limit_instruction}\n"
|
||||
"- Do not add explanations, markdown, or extra keys.\n\n"
|
||||
"Allowed tools:\n"
|
||||
f"{self._render_tool_list(selection_request.available_tools)}"
|
||||
)
|
||||
|
||||
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
|
||||
"""
|
||||
解析并标准化 DeepSeek JSON 模式的工具筛选结果。
|
||||
"""
|
||||
content = getattr(response, "content", response)
|
||||
text = self._extract_text_content(content)
|
||||
logger.debug(f"工具筛选原始响应: {text}")
|
||||
payload = self._parse_json_object(text)
|
||||
|
||||
tools = payload.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}")
|
||||
|
||||
normalized_tools = [
|
||||
tool_name for tool_name in tools if isinstance(tool_name, str)
|
||||
]
|
||||
logger.debug(f"工具筛选标准化结果: {normalized_tools}")
|
||||
return {"tools": normalized_tools}
|
||||
|
||||
async def _aselect_tools_with_deepseek(
|
||||
self, selection_request: Any
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。
|
||||
"""
|
||||
logger.debug("工具筛选走 DeepSeek JSON 兼容分支")
|
||||
structured_model = selection_request.model.bind(
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
response = await structured_model.ainvoke(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": self._build_deepseek_selection_prompt(selection_request),
|
||||
},
|
||||
selection_request.last_user_message,
|
||||
]
|
||||
)
|
||||
return self._normalize_selection_response(response)
|
||||
|
||||
@staticmethod
|
||||
def _extract_selected_tool_names(request: ModelRequest) -> list[str]:
|
||||
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
|
||||
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
@staticmethod
|
||||
def _apply_selected_tools(
|
||||
request: ModelRequest[ContextT],
|
||||
selected_tool_names: list[str],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
将已筛选出的工具集应用到当前模型请求。
|
||||
|
||||
这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然
|
||||
原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。
|
||||
"""
|
||||
if not selected_tool_names:
|
||||
return request
|
||||
|
||||
current_tools_by_name = {
|
||||
tool.name: tool for tool in request.tools if not isinstance(tool, dict)
|
||||
}
|
||||
selected_tools = [
|
||||
current_tools_by_name[tool_name]
|
||||
for tool_name in selected_tool_names
|
||||
if tool_name in current_tools_by_name
|
||||
]
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
async def _aselect_request_once(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
执行一次真实工具筛选,并返回筛选后的请求对象。
|
||||
|
||||
这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖
|
||||
“首轮筛选,后续复用”的行为。
|
||||
"""
|
||||
selection_request = self._prepare_selection_request(request)
|
||||
if selection_request is None:
|
||||
return request
|
||||
|
||||
if not self._is_deepseek_compatible_model(selection_request.model):
|
||||
captured_request: ModelRequest[ContextT] = request
|
||||
|
||||
async def _capture_handler(
|
||||
updated_request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
nonlocal captured_request
|
||||
captured_request = updated_request
|
||||
return updated_request
|
||||
|
||||
await super().awrap_model_call(request, _capture_handler)
|
||||
return captured_request
|
||||
|
||||
response = await self._aselect_tools_with_deepseek(selection_request)
|
||||
return self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
state: ToolSelectionState,
|
||||
runtime: Runtime, # noqa
|
||||
config: RunnableConfig,
|
||||
) -> ToolSelectionStateUpdate | None: # ty: ignore[invalid-method-override]
|
||||
"""
|
||||
在本轮 Agent 执行开始前完成一次真实工具筛选。
|
||||
|
||||
这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果,
|
||||
不会为每次模型回合重复追加一笔 selector LLM 开销。
|
||||
"""
|
||||
if "selected_tool_names" in state:
|
||||
return None
|
||||
|
||||
if not self.selection_tools or self.model is None:
|
||||
return ToolSelectionStateUpdate(selected_tool_names=None)
|
||||
|
||||
selection_request = ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self.selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
)
|
||||
modified_request = await self._aselect_request_once(selection_request)
|
||||
selected_tool_names = self._extract_selected_tool_names(modified_request)
|
||||
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""
|
||||
从 state 中读取首次筛选结果,并应用到每次模型回合。
|
||||
"""
|
||||
selected_tool_names = request.state.get("selected_tool_names") # noqa
|
||||
|
||||
# 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底,
|
||||
# 兼容直接单测或未来某些绕过 before_agent 的调用场景。
|
||||
if (
|
||||
selected_tool_names is None
|
||||
and self.selection_tools
|
||||
and self.model is not None
|
||||
):
|
||||
request = await self._aselect_request_once(request)
|
||||
selected_tool_names = self._extract_selected_tool_names(request) or None
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
if selected_tool_names:
|
||||
request = self._apply_selected_tools(request, selected_tool_names)
|
||||
|
||||
return await handler(request)
|
||||
@@ -1,12 +1,72 @@
|
||||
You are the MoviePilot agent runtime. Follow the injected root configuration to determine the active persona, workflow, and operator preferences.
|
||||
You are the MoviePilot agent runtime. Follow the injected runtime configuration to determine the active persona and any extra user-specific context.
|
||||
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
<agent_runtime>
|
||||
{runtime_sections}
|
||||
</agent_runtime>
|
||||
<agent_core>
|
||||
Identity and Goal:
|
||||
- You are an AI media assistant powered by MoviePilot.
|
||||
- Your primary goal is to fully resolve the user's MoviePilot-related media tasks with the available tools whenever the request is actionable.
|
||||
- Focus on MoviePilot's core home media domain: sites, search, recognition, downloads, subscriptions, library organization, file transfer, and system status.
|
||||
- Treat sites as a first-class system capability, not background detail. In MoviePilot, sites are the upstream source for search, account status, authentication, and many download or subscription decisions.
|
||||
- Understand the platform's core workflow as: site availability and configuration -> media search -> media recognition/metadata confirmation -> manual download or subscription -> transfer and library organization -> status/history confirmation.
|
||||
- Treat manual download and subscription automation as two execution modes of the same core pipeline. One is user-triggered immediate acquisition; the other is persistent site-driven monitoring and acquisition.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
|
||||
Behavior Model:
|
||||
- Prioritize task progress over conversation.
|
||||
- Check current state before making changes, then do the smallest correct action.
|
||||
- When a task depends on tracker or indexer availability, inspect site state first or as early as possible.
|
||||
- Do not stop for approval on read-only operations. Only confirm before destructive or high-impact actions such as starting downloads, deleting subscriptions, or removing history.
|
||||
- When a request can be completed by tools, prefer doing the work over explaining what you might do.
|
||||
- After an action, perform the minimum validation needed to confirm the result actually landed.
|
||||
- Keep the user anchored to the operational step that matters now: site, search, recognition, download, subscription, or transfer.
|
||||
- If the user explicitly asks to change the speaking style or persona, use the dedicated persona tools instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
|
||||
Core Capabilities:
|
||||
1. Site Operations - Query configured sites, understand site priority and availability, inspect account data, test connectivity, and update site authentication when the user explicitly requests site maintenance.
|
||||
2. Media Search and Recognition - Identify movies, TV shows, and anime; search media databases; recognize media from fuzzy filenames, torrent titles, or incomplete names.
|
||||
3. Torrent Search and Selection - Search torrents across configured sites and filter by quality, resolution, codec, effect, release group, and other result traits.
|
||||
4. Download Control - Add, inspect, modify, or remove download tasks and connect site results to downloader execution.
|
||||
5. Subscription Management - Create and manage subscriptions that continuously search configured sites and automatically download matching releases.
|
||||
6. Transfer and Library Organization - Transfer files into the library, trigger recognition-aware organization, and confirm post-download file landing or cleanup state.
|
||||
7. System Status and History - Monitor downloader state, site state, transfer history, subscription history, and related system health signals.
|
||||
8. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
9. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
10. Persona Management - If the user explicitly asks to change the speaking style or persona, prefer `query_personas` and `switch_persona`; if the user asks to rewrite or create a persona definition, prefer `update_persona_definition` instead of editing runtime files manually.
|
||||
|
||||
Core Workflow:
|
||||
1. Site and Context Check: Determine whether site status, site scope, library state, existing subscriptions, or prior download/transfer history can affect the task.
|
||||
2. Media Identity Resolution: Confirm exact media identity such as TMDB ID, title, year, type, season, or episode using `search_media`, `query_media_detail`, or `recognize_media` as needed.
|
||||
3. Resource Discovery: Use the appropriate search path for the task. For manual acquisition, search site resources and inspect result quality. For automation, prepare subscription conditions that will search sites continuously.
|
||||
4. Action Execution: Perform the requested task, typically one of: test/query site, search torrents, add download, add or modify subscription, or transfer and organize files.
|
||||
5. Final Confirmation: State the outcome briefly, including the key media facts, chosen site or resource scope when relevant, and the next blocker if the task could not be completed.
|
||||
|
||||
Tool Calling Strategy:
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- Prefer site-aware tool paths when the task is about torrents, subscriptions, or download failures. `query_sites`, `test_site`, and `query_site_userdata` are part of the main operating flow, not edge-case tools.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- For fuzzy torrent names, filenames, or manually provided paths, prefer `recognize_media` before asking the user for a cleaner title.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when automated paths are exhausted.
|
||||
- If torrent search yields no useful result, check site scope, site health, and recognition quality before concluding that the resource is unavailable.
|
||||
- Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
|
||||
Media Management Rules:
|
||||
1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations.
|
||||
2. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
3. Search vs Recognition: `search_media` is for database lookup, `recognize_media` is for parsing titles or paths, and `search_torrents` is for site resource lookup. Do not confuse these roles.
|
||||
4. Subscription Logic: Check for the best matching quality profile, filter groups, and site scope based on user history or defaults.
|
||||
5. Library Awareness: Check if content already exists in the library to avoid duplicates before downloading, subscribing, or transferring.
|
||||
6. Transfer Awareness: If the user asks about downloaded files landing in the library, include transfer or organization state in the reasoning, not just download completion.
|
||||
7. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative or the next best operational step.
|
||||
8. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
</agent_core>
|
||||
|
||||
<communication_runtime>
|
||||
{verbose_spec}
|
||||
@@ -18,15 +78,6 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
|
||||
- If the current channel supports file sending and you need to return a local image or file for the user to download, use `send_local_file`.
|
||||
</communication_runtime>
|
||||
|
||||
<core_capabilities>
|
||||
1. Media Search and Recognition - Identify movies, TV shows, and anime; recognize media from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management - Create rules for automated downloading and monitor trending content.
|
||||
3. Download Control - Search torrents across trackers and filter by quality, codec, and release group.
|
||||
4. System Status and Organization - Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
</core_capabilities>
|
||||
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
---
|
||||
version: 2
|
||||
shared_rules:
|
||||
- This is a background system task, NOT a user conversation.
|
||||
- Your final response will be broadcast as a notification.
|
||||
- Your final response will be consumed by the system. Keep it concise and task-focused.
|
||||
- Do NOT include greetings, explanations, or conversational text.
|
||||
- Respond in Chinese (中文).
|
||||
task_types:
|
||||
@@ -96,13 +95,45 @@ task_types:
|
||||
- "Do NOT reorganize blindly when media identity is uncertain."
|
||||
- "If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`."
|
||||
- "Keep the final response short and focused on outcome."
|
||||
---
|
||||
# SYSTEM_TASKS
|
||||
|
||||
这是后台系统任务的唯一定义源。
|
||||
|
||||
- `shared_rules` 负责统一口径。
|
||||
- `task_types.<type>.context_lines` 负责定义上下文字段展示。
|
||||
- `task_types.<type>.steps` 负责定义任务执行步骤。
|
||||
- `task_types.<type>.task_rules` 负责定义该任务独有的补充约束。
|
||||
- 代码侧只负责触发任务并提供模板变量,不再保存具体行为提示词。
|
||||
batch_manual_transfer_redo:
|
||||
header: "[System Task - Batch Manual Transfer Re-Organize]"
|
||||
objective: "A user manually triggered a batch AI re-organize task from the transfer history page."
|
||||
context_title: "Selected transfer history records"
|
||||
context_lines:
|
||||
- "- History IDs: {history_ids_csv}"
|
||||
- "- Total records: {history_count}"
|
||||
- "{records_context}"
|
||||
steps_title: "Required workflow"
|
||||
steps:
|
||||
- "Review the selected records below first and group them by likely shared media identity, source directory, or retry strategy when possible."
|
||||
- "Use the provided record context as the primary source of truth. Call `query_transfer_history` only when you need extra confirmation."
|
||||
- "For each group, decide whether the current recognition is trustworthy."
|
||||
- "If multiple records clearly belong to the same movie or series, identify the media once with `recognize_media` or `search_media`, then reuse that result for the related records."
|
||||
- "If a source file no longer exists or cannot be safely processed, skip that record and note the reason."
|
||||
- "Before re-organizing a record, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file."
|
||||
- "Then use `transfer_file` to organize the source path directly."
|
||||
- "When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid or doubanid, and media_type."
|
||||
- "If a record is already correct and no re-organize is needed, do not perform destructive actions; simply mark it as skipped."
|
||||
- "Report only the aggregate outcome, including how many records succeeded, skipped, and failed."
|
||||
task_rules:
|
||||
- "Do NOT assume every selected record belongs to the same media."
|
||||
- "When several records obviously share the same media identity, avoid repeated `recognize_media` or `search_media` calls."
|
||||
- "Process every selected record exactly once."
|
||||
- "Keep the final response short and focused on the aggregate outcome."
|
||||
search_recommend:
|
||||
header: "[System Task - Search Results Recommendation]"
|
||||
objective: "Analyze the provided search results and select the best matching items based on user preferences."
|
||||
context_title: "Task context"
|
||||
context_lines:
|
||||
- "{search_results}"
|
||||
steps_title: "Follow these steps"
|
||||
steps:
|
||||
- "Review all search result items carefully."
|
||||
- "Evaluate each item based on the user preference criteria."
|
||||
- "Select the top items that best match the preferences."
|
||||
- "Return ONLY a JSON array of item indices."
|
||||
task_rules:
|
||||
- "Return ONLY a JSON array of index numbers, e.g., [0, 3, 1]."
|
||||
- "Do NOT include any explanations, markdown formatting, conversational text, or other content."
|
||||
- "Do NOT call any tools. Simply analyze and return the JSON result directly."
|
||||
- "Respond in JSON format only."
|
||||
@@ -1,13 +1,17 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from string import Formatter
|
||||
from time import strftime
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.schemas import (
|
||||
ChannelCapability,
|
||||
ChannelCapabilities,
|
||||
@@ -16,6 +20,37 @@ from app.schemas import (
|
||||
)
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
SYSTEM_TASKS_FILE = "System Tasks.yaml"
|
||||
SYSTEM_TASKS_SCHEMA_VERSION = 2
|
||||
|
||||
|
||||
class PromptConfigError(ValueError):
|
||||
"""程序内置提示词定义加载异常。"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemTaskTypeDefinition:
|
||||
"""单个后台系统任务定义。"""
|
||||
|
||||
header: str
|
||||
objective: str
|
||||
context_title: Optional[str] = None
|
||||
context_lines: list[str] = field(default_factory=list)
|
||||
steps_title: Optional[str] = None
|
||||
steps: list[str] = field(default_factory=list)
|
||||
task_rules: list[str] = field(default_factory=list)
|
||||
empty_result: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemTasksDefinition:
|
||||
"""程序内置后台系统任务定义。"""
|
||||
|
||||
path: Path
|
||||
version: int
|
||||
shared_rules: list[str]
|
||||
task_types: dict[str, SystemTaskTypeDefinition]
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""
|
||||
@@ -28,6 +63,8 @@ class PromptManager:
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
self._system_tasks_cache: Optional[SystemTasksDefinition] = None
|
||||
self._system_tasks_signature: Optional[tuple[int, int]] = None
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""
|
||||
@@ -51,20 +88,15 @@ class PromptManager:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(
|
||||
self, channel: str = None, prefer_voice_reply: bool = False
|
||||
) -> str:
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:param prefer_voice_reply: 是否优先使用语音回复
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 根层运行时配置由独立装配器负责,避免人格/工作流继续硬编码在单文件 prompt 中。
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
runtime_sections = runtime_config.render_prompt_sections()
|
||||
|
||||
# 基础提示词只保留 MoviePilot 运行时和渠道能力相关约束。
|
||||
# 根层运行时配置由 RuntimeConfigMiddleware 在每次模型调用前动态注入,
|
||||
# 这样人格切换可以在同一轮 Agent 执行里立即生效。
|
||||
base_prompt = self.load_prompt("System Core Prompt.txt")
|
||||
|
||||
# 识别渠道
|
||||
@@ -98,9 +130,7 @@ class PromptManager:
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
voice_reply_spec = self._generate_voice_reply_instructions(
|
||||
prefer_voice_reply=prefer_voice_reply
|
||||
)
|
||||
voice_reply_spec = self._generate_voice_reply_instructions()
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.format(
|
||||
@@ -109,11 +139,119 @@ class PromptManager:
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
button_choice_spec=button_choice_spec,
|
||||
runtime_sections=runtime_sections,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
def load_system_tasks_definition(self) -> SystemTasksDefinition:
|
||||
"""加载程序内置的后台系统任务定义。"""
|
||||
system_tasks_path = self.prompts_dir / SYSTEM_TASKS_FILE
|
||||
try:
|
||||
stat = system_tasks_path.stat()
|
||||
except FileNotFoundError as err:
|
||||
logger.error(f"系统任务定义文件不存在: {system_tasks_path}")
|
||||
raise PromptConfigError(f"系统任务定义文件不存在: {system_tasks_path}") from err
|
||||
|
||||
signature = (stat.st_mtime_ns, stat.st_size)
|
||||
if (
|
||||
self._system_tasks_signature == signature
|
||||
and self._system_tasks_cache is not None
|
||||
):
|
||||
return self._system_tasks_cache
|
||||
|
||||
try:
|
||||
content = system_tasks_path.read_text(encoding="utf-8")
|
||||
except Exception as err: # noqa: BLE001
|
||||
logger.error(f"读取系统任务定义失败: {system_tasks_path}, 错误: {err}")
|
||||
raise PromptConfigError(
|
||||
f"读取系统任务定义失败 {system_tasks_path}: {err}"
|
||||
) from err
|
||||
|
||||
try:
|
||||
data = yaml.safe_load(content) or {}
|
||||
except yaml.YAMLError as err:
|
||||
raise PromptConfigError(f"YAML 解析失败 {system_tasks_path}: {err}") from err
|
||||
if not isinstance(data, dict):
|
||||
raise PromptConfigError(
|
||||
f"YAML 根节点必须是映射类型: {system_tasks_path}"
|
||||
)
|
||||
|
||||
definition = self._parse_system_tasks_definition(system_tasks_path, data)
|
||||
self._system_tasks_signature = signature
|
||||
self._system_tasks_cache = definition
|
||||
return definition
|
||||
|
||||
def render_system_task_message(
|
||||
self,
|
||||
task_type: str,
|
||||
*,
|
||||
template_context: Optional[dict[str, Any]] = None,
|
||||
extra_rules: Optional[list[str]] = None,
|
||||
) -> str:
|
||||
"""根据程序内置 YAML 渲染后台系统任务提示词。"""
|
||||
system_tasks = self.load_system_tasks_definition()
|
||||
task_definition = system_tasks.task_types.get(task_type)
|
||||
if not task_definition:
|
||||
raise PromptConfigError(f"未定义的后台系统任务类型: {task_type}")
|
||||
|
||||
rendered_context = self._render_template_lines(
|
||||
task_definition.context_lines,
|
||||
template_context,
|
||||
task_type,
|
||||
"context_lines",
|
||||
)
|
||||
rendered_steps = self._render_template_lines(
|
||||
task_definition.steps,
|
||||
template_context,
|
||||
task_type,
|
||||
"steps",
|
||||
)
|
||||
rendered_task_rules = self._render_template_lines(
|
||||
task_definition.task_rules,
|
||||
template_context,
|
||||
task_type,
|
||||
"task_rules",
|
||||
)
|
||||
|
||||
sections = [
|
||||
self._render_template_text(
|
||||
task_definition.header,
|
||||
template_context,
|
||||
task_type,
|
||||
"header",
|
||||
).strip(),
|
||||
self._render_template_text(
|
||||
task_definition.objective,
|
||||
template_context,
|
||||
task_type,
|
||||
"objective",
|
||||
).strip(),
|
||||
]
|
||||
if rendered_context:
|
||||
sections.append(
|
||||
self._format_titled_lines(
|
||||
task_definition.context_title or "Task context",
|
||||
rendered_context,
|
||||
)
|
||||
)
|
||||
if rendered_steps:
|
||||
sections.append(
|
||||
self._format_titled_lines(
|
||||
task_definition.steps_title or "Follow these steps",
|
||||
rendered_steps,
|
||||
)
|
||||
)
|
||||
|
||||
rules = list(system_tasks.shared_rules)
|
||||
if task_definition.empty_result:
|
||||
rules.append(task_definition.empty_result)
|
||||
rules.extend(rendered_task_rules)
|
||||
if extra_rules:
|
||||
rules.extend(rule.strip() for rule in extra_rules if rule and rule.strip())
|
||||
if rules:
|
||||
sections.append(self._format_numbered_rules("IMPORTANT", rules))
|
||||
return "\n\n".join(section for section in sections if section).strip()
|
||||
|
||||
@staticmethod
|
||||
def _get_moviepilot_info() -> str:
|
||||
"""
|
||||
@@ -144,10 +282,15 @@ class PromptManager:
|
||||
db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
|
||||
else:
|
||||
db_password = settings.DB_POSTGRESQL_PASSWORD or ""
|
||||
db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
db_info = (
|
||||
f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@"
|
||||
f"{settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
)
|
||||
|
||||
# 保留日期用于提供“今天是哪天”的稳定上下文,但不再注入秒级时间,
|
||||
# 避免每次请求都生成不同的 system prompt,影响 provider 侧 cache 命中率。
|
||||
info_lines = [
|
||||
f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
f"- 当前日期: {strftime('%Y-%m-%d')}",
|
||||
f"- 运行环境: {SystemUtils.platform} {'docker' if SystemUtils.is_docker() else ''}",
|
||||
f"- 主机名: {hostname}",
|
||||
f"- IP地址: {ip_address}",
|
||||
@@ -170,7 +313,7 @@ class PromptManager:
|
||||
根据渠道能力动态生成格式指令
|
||||
"""
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
if ChannelCapability.MARKDOWN not in caps.capabilities:
|
||||
instructions.append(
|
||||
"- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown."
|
||||
)
|
||||
@@ -184,17 +327,13 @@ class PromptManager:
|
||||
return "\n".join(instructions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
|
||||
if not prefer_voice_reply:
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
|
||||
)
|
||||
def _generate_voice_reply_instructions() -> str:
|
||||
if not AgentCapabilityManager.supports_audio_output():
|
||||
return "Audio output is disabled; do not call `send_voice_message`."
|
||||
return (
|
||||
"- Current message context: The user sent a voice message.\n"
|
||||
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
|
||||
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
"Use normal text replies by default. Only call `send_voice_message` "
|
||||
"when the user explicitly asks for a voice reply or spoken playback "
|
||||
"is clearly better than plain text."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -214,11 +353,172 @@ class PromptManager:
|
||||
)
|
||||
return "- User questions: When you truly need user input, ask briefly in plain text."
|
||||
|
||||
def _parse_system_tasks_definition(
|
||||
self,
|
||||
path: Path,
|
||||
data: dict[str, Any],
|
||||
) -> SystemTasksDefinition:
|
||||
"""把 YAML 结构转换成系统任务定义对象。"""
|
||||
version = self._normalize_positive_int(data.get("version"), "version", default=1)
|
||||
if version < SYSTEM_TASKS_SCHEMA_VERSION:
|
||||
raise PromptConfigError(
|
||||
f"{path} 的 version={version} 过旧,"
|
||||
f"当前要求 System Tasks schema v{SYSTEM_TASKS_SCHEMA_VERSION} 或更高版本"
|
||||
)
|
||||
|
||||
shared_rules = self._normalize_string_list(data.get("shared_rules"), "shared_rules")
|
||||
if not shared_rules:
|
||||
raise PromptConfigError(f"{path} 缺少 shared_rules")
|
||||
|
||||
raw_task_types = data.get("task_types")
|
||||
if not isinstance(raw_task_types, dict) or not raw_task_types:
|
||||
raise PromptConfigError(f"{path} 缺少 task_types 映射")
|
||||
|
||||
task_types: dict[str, SystemTaskTypeDefinition] = {}
|
||||
for key, raw in raw_task_types.items():
|
||||
if not isinstance(raw, dict):
|
||||
raise PromptConfigError(f"task_types.{key} 必须是映射")
|
||||
|
||||
header = str(raw.get("header") or "").strip()
|
||||
objective = str(raw.get("objective") or "").strip()
|
||||
if not header or not objective:
|
||||
raise PromptConfigError(f"task_types.{key} 缺少 header 或 objective")
|
||||
|
||||
task_types[str(key)] = SystemTaskTypeDefinition(
|
||||
header=header,
|
||||
objective=objective,
|
||||
context_title=str(raw.get("context_title") or "").strip() or None,
|
||||
context_lines=self._normalize_string_list(
|
||||
raw.get("context_lines"),
|
||||
f"task_types.{key}.context_lines",
|
||||
),
|
||||
steps_title=str(raw.get("steps_title") or "").strip() or None,
|
||||
steps=self._normalize_string_list(
|
||||
raw.get("steps"),
|
||||
f"task_types.{key}.steps",
|
||||
),
|
||||
task_rules=self._normalize_string_list(
|
||||
raw.get("task_rules"),
|
||||
f"task_types.{key}.task_rules",
|
||||
),
|
||||
empty_result=str(raw.get("empty_result") or "").strip() or None,
|
||||
)
|
||||
return SystemTasksDefinition(
|
||||
path=path,
|
||||
version=version,
|
||||
shared_rules=shared_rules,
|
||||
task_types=task_types,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _render_template_text(
|
||||
cls,
|
||||
text: str,
|
||||
template_context: Optional[dict[str, Any]],
|
||||
task_type: str,
|
||||
field_name: str,
|
||||
) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
formatter = Formatter()
|
||||
required_fields = {
|
||||
placeholder_name
|
||||
for _, placeholder_name, _, _ in formatter.parse(text)
|
||||
if placeholder_name
|
||||
}
|
||||
if not required_fields:
|
||||
return text
|
||||
|
||||
context = cls._normalize_template_context(template_context)
|
||||
missing_fields = sorted(f for f in required_fields if f not in context)
|
||||
if missing_fields:
|
||||
raise PromptConfigError(
|
||||
f"系统任务定义 `{task_type}` 的 `{field_name}` 缺少变量: "
|
||||
+ ", ".join(f"`{f}`" for f in missing_fields)
|
||||
)
|
||||
|
||||
# 这里统一做字符串替换,让 YAML 成为后台任务文案的唯一行为来源。
|
||||
return text.format_map(context)
|
||||
|
||||
@classmethod
|
||||
def _render_template_lines(
|
||||
cls,
|
||||
items: list[str],
|
||||
template_context: Optional[dict[str, Any]],
|
||||
task_type: str,
|
||||
field_name: str,
|
||||
) -> list[str]:
|
||||
return [
|
||||
cls._render_template_text(
|
||||
item,
|
||||
template_context,
|
||||
task_type,
|
||||
f"{field_name}[{index}]",
|
||||
).rstrip()
|
||||
for index, item in enumerate(items, start=1)
|
||||
if item and item.rstrip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_template_context(
|
||||
template_context: Optional[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
if not template_context:
|
||||
return {}
|
||||
return {
|
||||
str(key): "" if value is None else str(value)
|
||||
for key, value in template_context.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_numbered_rules(title: str, items: list[str]) -> str:
|
||||
return "\n".join(
|
||||
[f"{title}:"] + [f"{index}. {item}" for index, item in enumerate(items, start=1)]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_titled_lines(title: str, items: list[str]) -> str:
|
||||
cleaned = [item.rstrip() for item in items if item and item.rstrip()]
|
||||
return "\n".join([f"{title}:"] + cleaned)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_positive_int(
|
||||
value: Any,
|
||||
field_name: str,
|
||||
*,
|
||||
default: int,
|
||||
) -> int:
|
||||
if value in (None, ""):
|
||||
return default
|
||||
try:
|
||||
normalized = int(value)
|
||||
except (TypeError, ValueError) as err:
|
||||
raise PromptConfigError(f"{field_name} 必须是正整数") from err
|
||||
if normalized <= 0:
|
||||
raise PromptConfigError(f"{field_name} 必须是正整数")
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _normalize_string_list(values: Any, field_name: str) -> list[str]:
|
||||
if values is None:
|
||||
return []
|
||||
if not isinstance(values, list):
|
||||
raise PromptConfigError(f"{field_name} 必须是字符串数组")
|
||||
normalized: list[str] = []
|
||||
for value in values:
|
||||
text = str(value).strip()
|
||||
if text:
|
||||
normalized.append(text)
|
||||
return normalized
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
self._system_tasks_cache = None
|
||||
self._system_tasks_signature = None
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
active_persona: default
|
||||
profile: personas/default/AGENT_PROFILE.md
|
||||
workflow: personas/default/AGENT_WORKFLOW.md
|
||||
hooks: personas/default/AGENT_HOOKS.md
|
||||
user_preferences: USER_PREFERENCES.md
|
||||
system_tasks: system_tasks/SYSTEM_TASKS.md
|
||||
extra_context_files: []
|
||||
deprecated_phrases: []
|
||||
---
|
||||
# CURRENT_PERSONA
|
||||
|
||||
当前激活人格:`default`
|
||||
|
||||
加载顺序固定如下:
|
||||
|
||||
1. `AGENT_PROFILE.md`
|
||||
2. `AGENT_WORKFLOW.md`
|
||||
3. `AGENT_HOOKS.md`
|
||||
4. `USER_PREFERENCES.md`
|
||||
5. `SYSTEM_TASKS.md`
|
||||
|
||||
如果需要扩展额外上下文,请使用 `extra_context_files` 显式声明,而不是把额外规则散落到 memory 中。
|
||||
@@ -1,10 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
---
|
||||
# USER_PREFERENCES
|
||||
|
||||
这是根层的运维偏好文件,不是用户长期记忆。
|
||||
|
||||
- 这里只放稳定的系统级输出规则或部署方偏好。
|
||||
- 用户在对话中形成的长期习惯,仍应写入 `config/agent/memory/*.md`。
|
||||
- 默认保持精简,避免与 `AGENT_PROFILE.md` 或 `AGENT_WORKFLOW.md` 重复。
|
||||
@@ -1,26 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
pre_task:
|
||||
- Identify whether the request is a normal user conversation or a background system task before choosing a workflow.
|
||||
- Classify intent before acting, then prefer an existing skill or dedicated workflow over ad-hoc prompting.
|
||||
- Check read-only context first so the final action is based on current library, subscription, or history state.
|
||||
- Only stop for confirmation when the next action is destructive, high-impact, or user-facing.
|
||||
- Keep the final delivery target explicit before calling tools.
|
||||
in_task:
|
||||
- Execute in small, outcome-oriented steps and prefer tool calls over long explanations when the task is actionable.
|
||||
- Reuse known media identity, prior tool results, and shared context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
- Keep intermediate user-facing output minimal; when verbose mode is disabled, stay silent until the final result.
|
||||
- Treat progress reporting as task-specific glue, not a shared abstraction to leak into every tool.
|
||||
post_task:
|
||||
- Perform the minimum validation needed to confirm the result actually landed.
|
||||
- Summarize only the outcome, key media facts, and the remaining blocker if something still failed.
|
||||
- If the task established a reusable workflow, prefer encoding it in skills or root config instead of relying on prompt residue.
|
||||
---
|
||||
# AGENT_HOOKS
|
||||
|
||||
这些 hooks 由运行时结构化加载,不依赖自由文本约定。
|
||||
|
||||
- `pre_task` 对应开始执行前的统一检查点。
|
||||
- `in_task` 对应工具调用和失败降级阶段。
|
||||
- `post_task` 对应最小验证与收口阶段。
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
---
|
||||
# AGENT_PROFILE
|
||||
|
||||
- Identity: You are an AI media assistant powered by MoviePilot. You specialize in managing home media ecosystems: searching for movies and TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
- Tone: professional, concise, restrained.
|
||||
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
|
||||
- Prioritize task progress over conversation. Answer only what is necessary to move the task forward.
|
||||
- Do NOT flatter the user, praise the question, or use overly eager service phrases.
|
||||
- Do NOT use emojis, exclamation marks, cute language, or excessive apology.
|
||||
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles and paths.
|
||||
- Include key details such as year, rating, and resolution, but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions such as starting downloads or deleting subscriptions.
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
|
||||
# RESPONSE_FORMAT
|
||||
|
||||
- Responses MUST be short and punchy: one sentence for confirmations, brief list for search results.
|
||||
- NO filler phrases like "Let me help you", "Here are the results", "I found..." - skip all unnecessary preamble.
|
||||
- NO repeating what user said.
|
||||
- NO narrating your internal reasoning.
|
||||
- NO praise, emotional cushioning, or unnecessary politeness padding.
|
||||
- After task completion: one line summary only.
|
||||
- When error occurs: brief acknowledgment plus suggestion, then move on.
|
||||
@@ -1,25 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
---
|
||||
# AGENT_WORKFLOW
|
||||
|
||||
## FLOW
|
||||
|
||||
1. Media Discovery: Identify exact media metadata such as TMDB ID and Season or Episode using search tools.
|
||||
2. Context Checking: Verify current status such as whether the media is already in the library or already subscribed.
|
||||
3. Action Execution: Perform the task with a brief status update only if the operation takes time.
|
||||
4. Final Confirmation: State the result concisely.
|
||||
|
||||
## TOOL_CALLING_STRATEGY
|
||||
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when all automated methods are exhausted.
|
||||
|
||||
## MEDIA_MANAGEMENT_RULES
|
||||
|
||||
1. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
5. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
@@ -4,7 +4,7 @@ import threading
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, ClassVar, Optional
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
@@ -23,6 +23,56 @@ class ToolChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
# 单个工具结果的兜底上限。各工具仍应优先在自身逻辑中分页或摘要化;
|
||||
# 这里用于拦截遗漏路径,避免超大结果直接进入模型上下文。
|
||||
DEFAULT_TOOL_RESULT_MAX_CHARS = 64 * 1024
|
||||
MIN_TOOL_RESULT_PREVIEW_CHARS = 512
|
||||
|
||||
|
||||
def serialize_tool_result_for_agent(result: Any) -> str:
|
||||
"""将工具返回值稳定转换为 Agent 可消费的字符串。"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if isinstance(result, (int, float)):
|
||||
return str(result)
|
||||
try:
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
except Exception as e:
|
||||
logger.warning(f"工具结果转换为JSON失败: {e}, 使用字符串表示")
|
||||
return str(result)
|
||||
|
||||
|
||||
def format_tool_result_for_agent(
|
||||
result: Any,
|
||||
*,
|
||||
tool_name: Optional[str] = None,
|
||||
max_chars: Optional[int] = DEFAULT_TOOL_RESULT_MAX_CHARS,
|
||||
) -> str:
|
||||
"""
|
||||
统一格式化工具结果,并在超长时返回结构化预览。
|
||||
|
||||
具体工具可以通过 `result_max_chars` 覆盖上限;传入 None 或 <=0 表示不截断。
|
||||
"""
|
||||
formatted_result = serialize_tool_result_for_agent(result)
|
||||
if not max_chars or max_chars <= 0 or len(formatted_result) <= max_chars:
|
||||
return formatted_result
|
||||
|
||||
preview_limit = max(MIN_TOOL_RESULT_PREVIEW_CHARS, max_chars)
|
||||
preview = formatted_result[:preview_limit]
|
||||
payload = {
|
||||
"tool_result_truncated": True,
|
||||
"tool_name": tool_name,
|
||||
"total_chars": len(formatted_result),
|
||||
"returned_chars": len(preview),
|
||||
"content_preview": preview,
|
||||
"message": (
|
||||
f"工具返回内容超过 {max_chars} 字符,已截断为预览;"
|
||||
"请使用更精确的筛选条件、分页参数或专用查询参数继续获取。"
|
||||
),
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
# 将常见的阻塞调用按能力域拆分到独立线程池,避免外部慢 IO 抢占同一批 worker。
|
||||
_BLOCKING_BUCKET_LIMITS = {
|
||||
"default": 4,
|
||||
@@ -66,6 +116,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
MoviePilot专用工具基类(LangChain v1 / langchain_core)
|
||||
"""
|
||||
|
||||
result_max_chars: ClassVar[Optional[int]] = DEFAULT_TOOL_RESULT_MAX_CHARS
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: Optional[str] = PrivateAttr(default=None)
|
||||
@@ -113,16 +165,37 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
allow_dispatch_without_context = self._agent_context.get(
|
||||
"should_dispatch_reply", False
|
||||
)
|
||||
if self._channel and self._source:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
elif allow_dispatch_without_context:
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 后台 capture 流程没有渠道上下文,不能把工具提示回灌到默认通知渠道。
|
||||
self._stream_handler.record_tool_call(
|
||||
tool_name=self.name,
|
||||
tool_message=tool_message,
|
||||
tool_kwargs=kwargs,
|
||||
)
|
||||
else:
|
||||
# 非VERBOSE:不逐条回显工具调用,转为在下一段文本前补一句聚合摘要
|
||||
self._stream_handler.record_tool_call(
|
||||
@@ -139,21 +212,16 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
# 执行具体工具逻辑
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f"Tool {self.name} executed with result: {result}")
|
||||
result_len = len(str(result)) if result is not None else 0
|
||||
logger.debug(f"Tool {self.name} executed, raw result length: {result_len}")
|
||||
except Exception as e:
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
|
||||
result = error_message
|
||||
|
||||
# 格式化结果
|
||||
if isinstance(result, str):
|
||||
formatted_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formatted_result = str(result)
|
||||
else:
|
||||
formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
return formatted_result
|
||||
return format_tool_result_for_agent(
|
||||
result, tool_name=self.name, max_chars=self.result_max_chars
|
||||
)
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
@@ -235,6 +303,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
MessageChannel.Telegram: "telegram",
|
||||
MessageChannel.Discord: "discord",
|
||||
MessageChannel.Wechat: "wechat",
|
||||
MessageChannel.Feishu: "feishu",
|
||||
MessageChannel.WechatClawBot: "wechatclawbot",
|
||||
MessageChannel.Slack: "slack",
|
||||
MessageChannel.VoceChat: "vocechat",
|
||||
MessageChannel.SynologyChat: "synologychat",
|
||||
@@ -254,6 +324,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"telegram": "TELEGRAM_ADMINS",
|
||||
"discord": "DISCORD_ADMINS",
|
||||
"wechat": "WECHAT_ADMINS",
|
||||
"feishu": "FEISHU_ADMINS",
|
||||
"wechatclawbot": "WECHATCLAWBOT_ADMINS",
|
||||
"slack": "SLACK_ADMINS",
|
||||
"vocechat": "VOCECHAT_ADMINS",
|
||||
"synologychat": "SYNOLOGYCHAT_ADMINS",
|
||||
@@ -264,6 +336,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"telegram": "TELEGRAM_CHAT_ID",
|
||||
"vocechat": "VOCECHAT_CHANNEL_ID",
|
||||
"wechat": "WECHAT_BOT_CHAT_ID",
|
||||
"feishu": "FEISHU_OPEN_ID",
|
||||
"wechatclawbot": "WECHATCLAWBOT_DEFAULT_TARGET",
|
||||
}
|
||||
|
||||
admin_key = admin_key_map.get(channel_type)
|
||||
|
||||
@@ -16,6 +16,14 @@ from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||
from app.agent.tools.impl.query_rule_groups import QueryRuleGroupsTool
|
||||
from app.agent.tools.impl.query_builtin_filter_rules import QueryBuiltinFilterRulesTool
|
||||
from app.agent.tools.impl.query_custom_filter_rules import QueryCustomFilterRulesTool
|
||||
from app.agent.tools.impl.add_custom_filter_rule import AddCustomFilterRuleTool
|
||||
from app.agent.tools.impl.update_custom_filter_rule import UpdateCustomFilterRuleTool
|
||||
from app.agent.tools.impl.delete_custom_filter_rule import DeleteCustomFilterRuleTool
|
||||
from app.agent.tools.impl.add_rule_group import AddRuleGroupTool
|
||||
from app.agent.tools.impl.update_rule_group import UpdateRuleGroupTool
|
||||
from app.agent.tools.impl.delete_rule_group import DeleteRuleGroupTool
|
||||
from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
@@ -37,6 +45,9 @@ from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.query_personas import QueryPersonasTool
|
||||
from app.agent.tools.impl.switch_persona import SwitchPersonaTool
|
||||
from app.agent.tools.impl.update_persona_definition import UpdatePersonaDefinitionTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool
|
||||
@@ -52,11 +63,21 @@ from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
|
||||
from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool
|
||||
from app.agent.tools.impl.query_market_plugins import QueryMarketPluginsTool
|
||||
from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool
|
||||
from app.agent.tools.impl.query_plugin_config import QueryPluginConfigTool
|
||||
from app.agent.tools.impl.update_plugin_config import UpdatePluginConfigTool
|
||||
from app.agent.tools.impl.reload_plugin import ReloadPluginTool
|
||||
from app.agent.tools.impl.query_plugin_data import QueryPluginDataTool
|
||||
from app.agent.tools.impl.install_plugin import InstallPluginTool
|
||||
from app.agent.tools.impl.uninstall_plugin import UninstallPluginTool
|
||||
from app.agent.tools.impl.run_slash_command import RunSlashCommandTool
|
||||
from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool
|
||||
from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.agent.tools.impl.query_system_settings import QuerySystemSettingsTool
|
||||
from app.agent.tools.impl.update_system_settings import UpdateSystemSettingsTool
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -69,6 +90,18 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
# 这些通用工具需要始终保留,避免大工具集裁剪后让 Agent 丢失基础的
|
||||
# 文件系统、命令执行或交互确认能力。AskUserChoiceTool 仅在支持按钮
|
||||
# 的渠道中才会实际注入,因此后续会再按已加载工具做一次求交集。
|
||||
TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES = (
|
||||
"list_directory",
|
||||
"write_file",
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
"ask_user_choice",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
@@ -81,6 +114,25 @@ class MoviePilotToolFactory:
|
||||
message_channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(message_channel)
|
||||
|
||||
@classmethod
|
||||
def get_tool_selector_always_include_names(
|
||||
cls, tools: List[MoviePilotTool]
|
||||
) -> List[str]:
|
||||
"""
|
||||
返回当前实际已加载且需要绕过工具筛选的工具名。
|
||||
|
||||
`LLMToolSelectorMiddleware` 会校验 `always_include` 中的工具名是否
|
||||
存在于当前请求里,因此这里必须根据运行时工具列表做交集过滤。
|
||||
"""
|
||||
available_tool_names = {
|
||||
tool.name for tool in tools if getattr(tool, "name", None)
|
||||
}
|
||||
return [
|
||||
tool_name
|
||||
for tool_name in cls.TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES
|
||||
if tool_name in available_tool_names
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
@@ -90,6 +142,7 @@ class MoviePilotToolFactory:
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
agent_context: dict = None,
|
||||
allow_message_tools: bool = True,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
@@ -113,7 +166,15 @@ class MoviePilotToolFactory:
|
||||
QuerySubscribesTool,
|
||||
QuerySubscribeSharesTool,
|
||||
QueryPopularSubscribesTool,
|
||||
QueryBuiltinFilterRulesTool,
|
||||
QueryCustomFilterRulesTool,
|
||||
QueryRuleGroupsTool,
|
||||
AddCustomFilterRuleTool,
|
||||
UpdateCustomFilterRuleTool,
|
||||
DeleteCustomFilterRuleTool,
|
||||
AddRuleGroupTool,
|
||||
UpdateRuleGroupTool,
|
||||
DeleteRuleGroupTool,
|
||||
QuerySubscribeHistoryTool,
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
@@ -139,29 +200,40 @@ class MoviePilotToolFactory:
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool,
|
||||
QueryPersonasTool,
|
||||
SwitchPersonaTool,
|
||||
UpdatePersonaDefinitionTool,
|
||||
ExecuteCommandTool,
|
||||
EditFileTool,
|
||||
WriteFileTool,
|
||||
ReadFileTool,
|
||||
BrowseWebpageTool,
|
||||
QueryInstalledPluginsTool,
|
||||
QueryMarketPluginsTool,
|
||||
QueryPluginCapabilitiesTool,
|
||||
QueryPluginConfigTool,
|
||||
UpdatePluginConfigTool,
|
||||
ReloadPluginTool,
|
||||
QueryPluginDataTool,
|
||||
InstallPluginTool,
|
||||
UninstallPluginTool,
|
||||
RunSlashCommandTool,
|
||||
ListSlashCommandsTool,
|
||||
QueryCustomIdentifiersTool,
|
||||
UpdateCustomIdentifiersTool,
|
||||
QuerySystemSettingsTool,
|
||||
UpdateSystemSettingsTool,
|
||||
]
|
||||
if MoviePilotToolFactory._should_enable_choice_tool(channel):
|
||||
tool_definitions.append(AskUserChoiceTool)
|
||||
tool_definitions.extend(
|
||||
[
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
]
|
||||
)
|
||||
tool_definitions.append(SendLocalFileTool)
|
||||
if AgentCapabilityManager.supports_audio_output():
|
||||
tool_definitions.append(SendVoiceMessageTool)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
if not allow_message_tools and getattr(tool, "sends_message", False):
|
||||
continue
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
@@ -184,6 +256,8 @@ class MoviePilotToolFactory:
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
if not allow_message_tools and getattr(tool, "sends_message", False):
|
||||
continue
|
||||
tool.set_message_attr(
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
|
||||
540
app/agent/tools/impl/_filter_rule_utils.py
Normal file
540
app/agent/tools/impl/_filter_rule_utils.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""过滤规则 Agent 工具共用的校验、查询和引用处理逻辑。"""
|
||||
|
||||
import copy
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.modules.filter.RuleParser import RuleParser
|
||||
from app.modules.filter.builtin_rules import BUILTIN_RULE_SET
|
||||
from app.schemas import CustomRule, FilterRuleGroup
|
||||
from app.schemas.event import ConfigChangeEventData
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
|
||||
RULE_ID_PATTERN = re.compile(r"^[A-Za-z0-9]+$")
|
||||
RULE_TOKEN_PATTERN = re.compile(r"[A-Za-z][A-Za-z0-9]*|[0-9][A-Za-z0-9]+")
|
||||
NUMERIC_RANGE_PATTERN = re.compile(
|
||||
r"^\d+(?:\.\d+)?(?:\s*-\s*\d+(?:\.\d+)?)?$"
|
||||
)
|
||||
|
||||
MEDIA_TYPE_ALIASES = {
|
||||
"movie": "电影",
|
||||
"film": "电影",
|
||||
"tv": "电视剧",
|
||||
"series": "电视剧",
|
||||
"show": "电视剧",
|
||||
"电影": "电影",
|
||||
"电视剧": "电视剧",
|
||||
}
|
||||
|
||||
RULE_STRING_SYNTAX = {
|
||||
"level_separator": ">",
|
||||
"and_operator": "&",
|
||||
"not_operator": "!",
|
||||
"supported_grouping": "Parentheses are supported inside a single level.",
|
||||
"spacing_note": "Prefer spaces around '&', and '>' for readability; use '!RULE' for negation.",
|
||||
"match_order": "Levels are evaluated from left to right. The first matched level wins and stops further matching.",
|
||||
"match_result": "If no level matches, the torrent is filtered out. If a level matches, the torrent is kept.",
|
||||
"writing_workflow": [
|
||||
"First query built-in rules and custom rules to learn valid rule IDs.",
|
||||
"Compose one priority level with '&', '!' and optional parentheses.",
|
||||
"Join multiple priority levels with '>' from highest priority to lowest priority.",
|
||||
"Use spaces around '&', and '>' for readability.",
|
||||
],
|
||||
"examples": [
|
||||
{
|
||||
"description": "Prefer torrents with special subtitles and Chinese dubbing at 4K, otherwise fall back to Chinese subtitles and Chinese dubbing at 4K.",
|
||||
"rule_string": "SPECSUB & CNVOI & 4K & !BLU & !REMUX & !WEBDL > CNSUB & CNVOI & 4K & !BLU & !REMUX & !WEBDL",
|
||||
},
|
||||
{
|
||||
"description": "Inside one level, require 4K and reject Blu-ray source.",
|
||||
"rule_string": "4K & !BLU",
|
||||
},
|
||||
{
|
||||
"description": "Inside one level, accept either special subtitles or Chinese subtitles, then also require 1080P.",
|
||||
"rule_string": "(SPECSUB | CNSUB) & 1080P",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_optional_text(value: Optional[str]) -> Optional[str]:
|
||||
"""把空白字符串折叠为 None,避免保存无意义的空值。"""
|
||||
if value is None:
|
||||
return None
|
||||
value = str(value).strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def normalize_media_type(value: Optional[str]) -> Optional[str]:
|
||||
"""兼容英中文媒体类型输入,最终统一为后端实际使用的中文值。"""
|
||||
value = normalize_optional_text(value)
|
||||
if not value:
|
||||
return None
|
||||
normalized = MEDIA_TYPE_ALIASES.get(value.lower(), value)
|
||||
if normalized not in {"电影", "电视剧"}:
|
||||
raise ValueError(
|
||||
"media_type 仅支持 '电影'、'电视剧'、'movie' 或 'tv'"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def validate_numeric_range(
|
||||
field_name: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""校验 size_range / publish_time 这类单值或区间值。"""
|
||||
value = normalize_optional_text(value)
|
||||
if not value:
|
||||
return None
|
||||
if not NUMERIC_RANGE_PATTERN.match(value):
|
||||
raise ValueError(
|
||||
f"{field_name} 格式无效,支持 '1000' 或 '1000-5000' 这类数字区间格式"
|
||||
)
|
||||
|
||||
parts = [float(item.strip()) for item in value.split("-")]
|
||||
if len(parts) == 2 and parts[0] > parts[1]:
|
||||
raise ValueError(f"{field_name} 区间起始值不能大于结束值")
|
||||
return value
|
||||
|
||||
|
||||
def validate_seeders(value: Optional[str]) -> Optional[str]:
|
||||
"""做种人数最终会被 int() 解析,这里提前拦住非法值。"""
|
||||
value = normalize_optional_text(value)
|
||||
if not value:
|
||||
return None
|
||||
if not value.isdigit():
|
||||
raise ValueError("seeders 必须是非负整数")
|
||||
return value
|
||||
|
||||
|
||||
def get_builtin_rules() -> Dict[str, dict]:
|
||||
"""返回内置规则的深拷贝,避免调用方误改共享常量。"""
|
||||
return copy.deepcopy(BUILTIN_RULE_SET)
|
||||
|
||||
|
||||
def get_custom_rules() -> list[CustomRule]:
|
||||
return RuleHelper().get_custom_rules()
|
||||
|
||||
|
||||
def get_rule_groups() -> list[FilterRuleGroup]:
|
||||
return RuleHelper().get_rule_groups()
|
||||
|
||||
|
||||
def build_custom_rule_map(rules: Optional[Iterable[CustomRule]] = None) -> Dict[str, CustomRule]:
|
||||
return {
|
||||
rule.id: rule
|
||||
for rule in (rules or get_custom_rules())
|
||||
if rule.id
|
||||
}
|
||||
|
||||
|
||||
def build_rule_group_map(
|
||||
groups: Optional[Iterable[FilterRuleGroup]] = None,
|
||||
) -> Dict[str, FilterRuleGroup]:
|
||||
return {
|
||||
group.name: group
|
||||
for group in (groups or get_rule_groups())
|
||||
if group.name
|
||||
}
|
||||
|
||||
|
||||
def extract_rule_tokens(rule_string: Optional[str]) -> list[str]:
|
||||
"""从规则串里提取规则 ID,用于引用分析和未知规则校验。"""
|
||||
if not rule_string:
|
||||
return []
|
||||
# dict.fromkeys 用来在保留顺序的同时去重,便于展示和报错。
|
||||
return list(dict.fromkeys(RULE_TOKEN_PATTERN.findall(rule_string)))
|
||||
|
||||
|
||||
def parse_rule_string(rule_string: str) -> dict:
|
||||
"""使用后端同款 RuleParser 解析规则串,并拆出每一层的元数据。"""
|
||||
normalized = normalize_optional_text(rule_string)
|
||||
if not normalized:
|
||||
raise ValueError("rule_string 不能为空")
|
||||
|
||||
parser = RuleParser()
|
||||
levels = [level.strip() for level in normalized.split(">")]
|
||||
if any(not level for level in levels):
|
||||
raise ValueError("rule_string 不能包含空层级,请检查 '>' 两侧内容")
|
||||
|
||||
parsed_levels = []
|
||||
for index, level in enumerate(levels, start=1):
|
||||
try:
|
||||
parser.parse(level)
|
||||
except Exception as exc: # pragma: no cover - 依赖 pyparsing 的具体异常
|
||||
raise ValueError(f"规则串第 {index} 层语法错误: {exc}") from exc
|
||||
|
||||
parsed_levels.append(
|
||||
{
|
||||
"priority": index,
|
||||
"expression": level,
|
||||
"referenced_rules": extract_rule_tokens(level),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"rule_string": " > ".join(levels),
|
||||
"levels": parsed_levels,
|
||||
"referenced_rules": extract_rule_tokens(normalized),
|
||||
}
|
||||
|
||||
|
||||
def validate_rule_string(rule_string: str, available_rule_ids: Iterable[str]) -> dict:
|
||||
"""校验规则串语法和引用规则是否都存在。"""
|
||||
parsed = parse_rule_string(rule_string)
|
||||
available_ids = set(available_rule_ids)
|
||||
unknown_rules = sorted(
|
||||
{
|
||||
rule_id
|
||||
for rule_id in parsed["referenced_rules"]
|
||||
if rule_id not in available_ids
|
||||
}
|
||||
)
|
||||
if unknown_rules:
|
||||
raise ValueError(
|
||||
f"rule_string 引用了不存在的规则: {', '.join(unknown_rules)}"
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def serialize_builtin_rule(rule_id: str, payload: dict) -> dict:
|
||||
"""把内置规则整理成适合 Agent 阅读的结构。"""
|
||||
data = copy.deepcopy(payload)
|
||||
data["id"] = rule_id
|
||||
data["source"] = "builtin"
|
||||
return data
|
||||
|
||||
|
||||
def serialize_custom_rule(rule: CustomRule, group_refs: Optional[list[str]] = None) -> dict:
|
||||
data = rule.model_dump(exclude_none=True)
|
||||
data["source"] = "custom"
|
||||
data["referenced_by_rule_groups"] = group_refs or []
|
||||
return data
|
||||
|
||||
|
||||
def serialize_rule_group(group: FilterRuleGroup, usage: Optional[dict] = None) -> dict:
|
||||
"""查询时尽量附带解析结果,便于 Agent 理解优先级层级。"""
|
||||
data = group.model_dump(exclude_none=True)
|
||||
if group.rule_string:
|
||||
try:
|
||||
parsed = parse_rule_string(group.rule_string)
|
||||
data["levels"] = parsed["levels"]
|
||||
data["referenced_rules"] = parsed["referenced_rules"]
|
||||
data["syntax_valid"] = True
|
||||
except ValueError as exc:
|
||||
data["syntax_valid"] = False
|
||||
data["syntax_error"] = str(exc)
|
||||
data["referenced_rules"] = extract_rule_tokens(group.rule_string)
|
||||
else:
|
||||
data["syntax_valid"] = False
|
||||
data["syntax_error"] = "rule_string 为空"
|
||||
data["referenced_rules"] = []
|
||||
data["usage"] = usage or default_rule_group_usage()
|
||||
return data
|
||||
|
||||
|
||||
def default_rule_group_usage() -> dict:
|
||||
return {
|
||||
"used_in_global_search": False,
|
||||
"used_in_global_subscribe": False,
|
||||
"used_in_global_best_version": False,
|
||||
"subscribes": [],
|
||||
}
|
||||
|
||||
|
||||
async def collect_rule_group_usages(
|
||||
group_names: Optional[Iterable[str]] = None,
|
||||
) -> Dict[str, dict]:
|
||||
"""收集规则组在全局配置和订阅上的引用情况。"""
|
||||
target_names = set(group_names or [])
|
||||
search_groups = set(
|
||||
SystemConfigOper().get(SystemConfigKey.SearchFilterRuleGroups) or []
|
||||
)
|
||||
subscribe_groups = set(
|
||||
SystemConfigOper().get(SystemConfigKey.SubscribeFilterRuleGroups) or []
|
||||
)
|
||||
best_version_groups = set(
|
||||
SystemConfigOper().get(SystemConfigKey.BestVersionFilterRuleGroups) or []
|
||||
)
|
||||
|
||||
usage_map = {
|
||||
name: default_rule_group_usage()
|
||||
for name in target_names
|
||||
}
|
||||
|
||||
def ensure_usage(name: str) -> dict:
|
||||
if name not in usage_map:
|
||||
usage_map[name] = default_rule_group_usage()
|
||||
return usage_map[name]
|
||||
|
||||
for name in search_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_search"] = True
|
||||
for name in subscribe_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_subscribe"] = True
|
||||
for name in best_version_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_best_version"] = True
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
filter_groups = subscribe.filter_groups or []
|
||||
for name in filter_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"type": subscribe.type,
|
||||
"username": subscribe.username,
|
||||
"best_version": bool(subscribe.best_version),
|
||||
}
|
||||
)
|
||||
|
||||
return usage_map
|
||||
|
||||
|
||||
def collect_custom_rule_group_refs(
|
||||
rule_groups: Iterable[FilterRuleGroup],
|
||||
rule_ids: Optional[Iterable[str]] = None,
|
||||
) -> Dict[str, list[str]]:
|
||||
"""收集自定义规则被哪些规则组引用。"""
|
||||
target_rule_ids = set(rule_ids or [])
|
||||
refs: Dict[str, list[str]] = {
|
||||
rule_id: []
|
||||
for rule_id in target_rule_ids
|
||||
}
|
||||
|
||||
for group in rule_groups:
|
||||
if not group.name or not group.rule_string:
|
||||
continue
|
||||
referenced = set(extract_rule_tokens(group.rule_string))
|
||||
for rule_id in referenced:
|
||||
if target_rule_ids and rule_id not in target_rule_ids:
|
||||
continue
|
||||
refs.setdefault(rule_id, []).append(group.name)
|
||||
|
||||
for names in refs.values():
|
||||
names.sort()
|
||||
return refs
|
||||
|
||||
|
||||
def normalize_custom_rule(
|
||||
rule_id: str,
|
||||
name: str,
|
||||
include: Optional[str],
|
||||
exclude: Optional[str],
|
||||
size_range: Optional[str],
|
||||
seeders: Optional[str],
|
||||
publish_time: Optional[str],
|
||||
existing_rules: Iterable[CustomRule],
|
||||
original_rule_id: Optional[str] = None,
|
||||
) -> CustomRule:
|
||||
"""新增/更新自定义规则时统一走这里,避免多处散落校验逻辑。"""
|
||||
normalized_rule_id = normalize_optional_text(rule_id)
|
||||
normalized_name = normalize_optional_text(name)
|
||||
if not normalized_rule_id:
|
||||
raise ValueError("rule_id 不能为空")
|
||||
if not normalized_name:
|
||||
raise ValueError("name 不能为空")
|
||||
if not RULE_ID_PATTERN.match(normalized_rule_id):
|
||||
raise ValueError("rule_id 仅支持英文字母和数字")
|
||||
if (
|
||||
normalized_rule_id in BUILTIN_RULE_SET
|
||||
and normalized_rule_id != original_rule_id
|
||||
):
|
||||
raise ValueError(
|
||||
f"rule_id '{normalized_rule_id}' 与内置规则冲突,不能覆盖内置规则"
|
||||
)
|
||||
|
||||
for existing_rule in existing_rules:
|
||||
if (
|
||||
existing_rule.id == normalized_rule_id
|
||||
and existing_rule.id != original_rule_id
|
||||
):
|
||||
raise ValueError(f"rule_id '{normalized_rule_id}' 已存在")
|
||||
if (
|
||||
existing_rule.name == normalized_name
|
||||
and existing_rule.id != original_rule_id
|
||||
):
|
||||
raise ValueError(f"规则名称 '{normalized_name}' 已存在")
|
||||
|
||||
return CustomRule(
|
||||
id=normalized_rule_id,
|
||||
name=normalized_name,
|
||||
include=normalize_optional_text(include),
|
||||
exclude=normalize_optional_text(exclude),
|
||||
size_range=validate_numeric_range("size_range", size_range),
|
||||
seeders=validate_seeders(seeders),
|
||||
publish_time=validate_numeric_range("publish_time", publish_time),
|
||||
)
|
||||
|
||||
|
||||
def normalize_rule_group(
|
||||
name: str,
|
||||
rule_string: str,
|
||||
media_type: Optional[str],
|
||||
category: Optional[str],
|
||||
existing_groups: Iterable[FilterRuleGroup],
|
||||
available_rule_ids: Iterable[str],
|
||||
original_name: Optional[str] = None,
|
||||
) -> tuple[FilterRuleGroup, dict]:
|
||||
"""新增/更新规则组时统一校验名字、适用范围和规则串。"""
|
||||
normalized_name = normalize_optional_text(name)
|
||||
if not normalized_name:
|
||||
raise ValueError("规则组名称不能为空")
|
||||
|
||||
for group in existing_groups:
|
||||
if group.name == normalized_name and group.name != original_name:
|
||||
raise ValueError(f"规则组名称 '{normalized_name}' 已存在")
|
||||
|
||||
normalized_media_type = normalize_media_type(media_type)
|
||||
normalized_category = normalize_optional_text(category)
|
||||
if normalized_category and not normalized_media_type:
|
||||
raise ValueError("设置 category 时必须同时设置 media_type")
|
||||
|
||||
parsed = validate_rule_string(rule_string, available_rule_ids)
|
||||
return (
|
||||
FilterRuleGroup(
|
||||
name=normalized_name,
|
||||
rule_string=parsed["rule_string"],
|
||||
media_type=normalized_media_type,
|
||||
category=normalized_category,
|
||||
),
|
||||
parsed,
|
||||
)
|
||||
|
||||
|
||||
async def save_system_config(
|
||||
key: SystemConfigKey, value: Any
|
||||
) -> Optional[bool]:
|
||||
"""通过统一入口保存配置并补发 ConfigChanged 事件。"""
|
||||
normalized_value = value
|
||||
if isinstance(normalized_value, list):
|
||||
normalized_value = [
|
||||
item
|
||||
for item in normalized_value
|
||||
if item is not None and item != ""
|
||||
]
|
||||
normalized_value = normalized_value or None
|
||||
|
||||
success = await SystemConfigOper().async_set(key, normalized_value)
|
||||
if success:
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=normalized_value,
|
||||
change_type="update",
|
||||
),
|
||||
)
|
||||
return success
|
||||
|
||||
|
||||
def replace_rule_id_in_rule_string(
|
||||
rule_string: str, old_rule_id: str, new_rule_id: str
|
||||
) -> str:
|
||||
"""只替换完整 token,避免误伤其他规则名。"""
|
||||
pattern = re.compile(
|
||||
rf"(?<![A-Za-z0-9]){re.escape(old_rule_id)}(?![A-Za-z0-9])"
|
||||
)
|
||||
return pattern.sub(new_rule_id, rule_string)
|
||||
|
||||
|
||||
def replace_group_name_in_list(
|
||||
values: Optional[Iterable[str]], old_name: str, new_name: str
|
||||
) -> list[str]:
|
||||
"""更新配置里的规则组名引用,并顺手去重。"""
|
||||
result = []
|
||||
for value in values or []:
|
||||
mapped = new_name if value == old_name else value
|
||||
if mapped not in result:
|
||||
result.append(mapped)
|
||||
return result
|
||||
|
||||
|
||||
async def rename_rule_group_references(old_name: str, new_name: str) -> dict:
|
||||
"""规则组改名后,联动更新全局设置和订阅引用。"""
|
||||
changed = {
|
||||
"global_settings": {},
|
||||
"subscribes": [],
|
||||
}
|
||||
|
||||
for config_key in (
|
||||
SystemConfigKey.SearchFilterRuleGroups,
|
||||
SystemConfigKey.SubscribeFilterRuleGroups,
|
||||
SystemConfigKey.BestVersionFilterRuleGroups,
|
||||
):
|
||||
original = SystemConfigOper().get(config_key) or []
|
||||
updated = replace_group_name_in_list(original, old_name, new_name)
|
||||
if updated != original:
|
||||
await save_system_config(config_key, updated)
|
||||
changed["global_settings"][config_key.value] = updated
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = replace_group_name_in_list(original, old_name, new_name)
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe.async_update(db, {"filter_groups": updated})
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
async def remove_rule_group_references(group_name: str) -> dict:
|
||||
"""删除规则组后,清理全局设置和订阅里的悬空引用。"""
|
||||
changed = {
|
||||
"global_settings": {},
|
||||
"subscribes": [],
|
||||
}
|
||||
|
||||
for config_key in (
|
||||
SystemConfigKey.SearchFilterRuleGroups,
|
||||
SystemConfigKey.SubscribeFilterRuleGroups,
|
||||
SystemConfigKey.BestVersionFilterRuleGroups,
|
||||
):
|
||||
original = SystemConfigOper().get(config_key) or []
|
||||
updated = [value for value in original if value != group_name]
|
||||
if updated != original:
|
||||
await save_system_config(config_key, updated)
|
||||
changed["global_settings"][config_key.value] = updated
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = [value for value in original if value != group_name]
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe.async_update(db, {"filter_groups": updated})
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
|
||||
return changed
|
||||
291
app/agent/tools/impl/_plugin_tool_utils.py
Normal file
291
app/agent/tools/impl/_plugin_tool_utils.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""插件 Agent 工具共享辅助方法"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
# 默认只向智能体返回一个可读预览,避免超大插件数据挤爆上下文窗口。
|
||||
DEFAULT_PLUGIN_DATA_PREVIEW_CHARS = 12_000
|
||||
MAX_PLUGIN_DATA_PREVIEW_CHARS = 50_000
|
||||
PLUGIN_DATA_KEY_PREVIEW_LIMIT = 50
|
||||
PLUGIN_DATA_TRUNCATION_SUFFIX = "\n...(插件数据内容过长,已截断)"
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT = 50
|
||||
MAX_PLUGIN_CANDIDATE_LIMIT = 200
|
||||
|
||||
|
||||
def get_plugin_snapshot(plugin_id: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
获取已安装插件的基础信息快照。
|
||||
"""
|
||||
plugin_manager = PluginManager()
|
||||
for plugin in plugin_manager.get_local_plugins():
|
||||
if plugin.id == plugin_id:
|
||||
return {
|
||||
"plugin_id": plugin.id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"plugin_version": plugin.plugin_version,
|
||||
"state": plugin.state,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def clamp_preview_chars(max_chars: Optional[int]) -> int:
|
||||
"""
|
||||
约束插件数据预览长度,避免工具结果无限膨胀。
|
||||
"""
|
||||
if max_chars is None:
|
||||
return DEFAULT_PLUGIN_DATA_PREVIEW_CHARS
|
||||
return max(512, min(int(max_chars), MAX_PLUGIN_DATA_PREVIEW_CHARS))
|
||||
|
||||
|
||||
def serialize_for_agent(value: Any) -> str:
|
||||
"""
|
||||
将结果稳定序列化为 JSON 字符串,无法原生序列化的对象退化为字符串。
|
||||
"""
|
||||
return json.dumps(value, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
def build_preview_payload(value: Any, max_chars: Optional[int]) -> tuple[bool, int, int, str]:
|
||||
"""
|
||||
为可能很大的插件数据生成预览结果。
|
||||
"""
|
||||
serialized = serialize_for_agent(value)
|
||||
if len(serialized) <= clamp_preview_chars(max_chars):
|
||||
return False, len(serialized), len(serialized), serialized
|
||||
|
||||
preview_limit = clamp_preview_chars(max_chars)
|
||||
preview = serialized[:preview_limit] + PLUGIN_DATA_TRUNCATION_SUFFIX
|
||||
return True, len(serialized), len(preview), preview
|
||||
|
||||
|
||||
def reload_plugin_runtime(plugin_id: str) -> None:
|
||||
"""
|
||||
重载插件并重新注册其命令、定时任务和 API。
|
||||
"""
|
||||
# 这些依赖只在真正执行重载时才导入,避免普通查询工具引入不必要的初始化开销。
|
||||
from app.api.endpoints.plugin import register_plugin_api
|
||||
from app.command import Command
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
plugin_manager.reload_plugin(plugin_id)
|
||||
Scheduler().update_plugin_job(plugin_id)
|
||||
Command().init_commands(plugin_id)
|
||||
register_plugin_api(plugin_id)
|
||||
|
||||
|
||||
def summarize_plugin(plugin: Any) -> dict[str, Any]:
|
||||
"""
|
||||
提取插件对象中对 Agent 有价值的摘要字段。
|
||||
"""
|
||||
repo_url = getattr(plugin, "repo_url", None)
|
||||
return {
|
||||
"id": getattr(plugin, "id", None),
|
||||
"plugin_name": getattr(plugin, "plugin_name", None),
|
||||
"plugin_desc": getattr(plugin, "plugin_desc", None),
|
||||
"plugin_version": getattr(plugin, "plugin_version", None),
|
||||
"plugin_author": getattr(plugin, "plugin_author", None),
|
||||
"installed": bool(getattr(plugin, "installed", False)),
|
||||
"has_update": bool(getattr(plugin, "has_update", False)),
|
||||
"state": bool(getattr(plugin, "state", False)),
|
||||
"repo_url": repo_url,
|
||||
"source": "local_repo" if PluginHelper.is_local_repo_url(repo_url) else "market",
|
||||
}
|
||||
|
||||
|
||||
async def load_market_plugins(force_refresh: bool = False) -> list[Any]:
|
||||
"""
|
||||
聚合插件市场与本地插件仓库中的候选插件。
|
||||
"""
|
||||
plugin_manager = PluginManager()
|
||||
online_plugins = await plugin_manager.async_get_online_plugins(force=force_refresh)
|
||||
local_repo_plugins = plugin_manager.get_local_repo_plugins()
|
||||
if not online_plugins and not local_repo_plugins:
|
||||
return []
|
||||
return plugin_manager.process_plugins_list(online_plugins + local_repo_plugins, [])
|
||||
|
||||
|
||||
def list_installed_plugins() -> list[Any]:
|
||||
"""
|
||||
返回当前已安装插件列表。
|
||||
"""
|
||||
plugin_manager = PluginManager()
|
||||
return [plugin for plugin in plugin_manager.get_local_plugins() if plugin.installed]
|
||||
|
||||
|
||||
def _normalize_text(value: Optional[str]) -> str:
|
||||
return (value or "").strip().lower()
|
||||
|
||||
|
||||
def is_exact_plugin_match(plugin: Any, query: str) -> bool:
|
||||
"""
|
||||
精确匹配插件 ID 或插件名称,用于安全地自动选择候选。
|
||||
"""
|
||||
normalized_query = _normalize_text(query)
|
||||
return normalized_query in {
|
||||
_normalize_text(getattr(plugin, "id", None)),
|
||||
_normalize_text(getattr(plugin, "plugin_name", None)),
|
||||
}
|
||||
|
||||
|
||||
def search_plugin_candidates(query: str, plugins: list[Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
按插件 ID、名称、描述和作者搜索候选,并返回打分结果。
|
||||
"""
|
||||
normalized_query = _normalize_text(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
tokens = [token for token in normalized_query.replace("-", " ").split() if token]
|
||||
matches: list[dict[str, Any]] = []
|
||||
|
||||
for plugin in plugins:
|
||||
plugin_id = _normalize_text(getattr(plugin, "id", None))
|
||||
plugin_name = _normalize_text(getattr(plugin, "plugin_name", None))
|
||||
plugin_desc = _normalize_text(getattr(plugin, "plugin_desc", None))
|
||||
plugin_author = _normalize_text(getattr(plugin, "plugin_author", None))
|
||||
haystack = "\n".join([plugin_id, plugin_name, plugin_desc, plugin_author])
|
||||
|
||||
score = 0
|
||||
if normalized_query == plugin_id:
|
||||
score = 100
|
||||
elif normalized_query == plugin_name:
|
||||
score = 95
|
||||
elif plugin_id.startswith(normalized_query):
|
||||
score = 85
|
||||
elif plugin_name.startswith(normalized_query):
|
||||
score = 80
|
||||
elif normalized_query in plugin_id:
|
||||
score = 75
|
||||
elif normalized_query in plugin_name:
|
||||
score = 70
|
||||
elif tokens and all(token in plugin_name for token in tokens):
|
||||
score = 68
|
||||
elif tokens and all(token in plugin_id for token in tokens):
|
||||
score = 66
|
||||
elif normalized_query in plugin_desc:
|
||||
score = 45
|
||||
elif normalized_query in plugin_author:
|
||||
score = 40
|
||||
elif tokens and all(token in haystack for token in tokens):
|
||||
score = 35
|
||||
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"plugin": plugin,
|
||||
"score": score,
|
||||
"exact": is_exact_plugin_match(plugin, normalized_query),
|
||||
}
|
||||
)
|
||||
|
||||
return sorted(
|
||||
matches,
|
||||
key=lambda item: (
|
||||
-item["score"],
|
||||
not item["exact"],
|
||||
-int(bool(getattr(item["plugin"], "has_update", False))),
|
||||
-int(bool(getattr(item["plugin"], "installed", False))),
|
||||
-int(getattr(item["plugin"], "add_time", 0) or 0),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def summarize_candidates(matches: list[dict[str, Any]], limit: int = DEFAULT_PLUGIN_CANDIDATE_LIMIT) -> list[dict[str, Any]]:
|
||||
"""
|
||||
压缩候选列表,避免一次性把完整市场数据返回给 Agent。
|
||||
"""
|
||||
return [
|
||||
{
|
||||
**summarize_plugin(item["plugin"]),
|
||||
"score": item["score"],
|
||||
"exact": item["exact"],
|
||||
}
|
||||
for item in matches[:limit]
|
||||
]
|
||||
|
||||
|
||||
async def install_plugin_runtime(
|
||||
plugin_id: str, repo_url: Optional[str], force: bool = False
|
||||
) -> tuple[bool, str, bool]:
|
||||
"""
|
||||
按现有插件接口的行为安装插件,并刷新运行态注册信息。
|
||||
"""
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
plugin_manager = PluginManager()
|
||||
plugin_helper = PluginHelper()
|
||||
|
||||
refreshed_only = False
|
||||
if not force and plugin_id in plugin_manager.get_plugin_ids():
|
||||
refreshed_only = True
|
||||
await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url)
|
||||
message = "插件已存在,已刷新加载"
|
||||
else:
|
||||
if not repo_url:
|
||||
return False, "没有传入仓库地址,无法正确安装插件,请检查配置", False
|
||||
state, message = await plugin_helper.async_install(
|
||||
pid=plugin_id,
|
||||
repo_url=repo_url,
|
||||
force_install=force,
|
||||
)
|
||||
if not state:
|
||||
return False, message, False
|
||||
|
||||
if plugin_id not in install_plugins:
|
||||
install_plugins.append(plugin_id)
|
||||
await SystemConfigOper().async_set(
|
||||
SystemConfigKey.UserInstalledPlugins, install_plugins
|
||||
)
|
||||
|
||||
reload_plugin_runtime(plugin_id)
|
||||
return True, message or "插件安装成功", refreshed_only
|
||||
|
||||
|
||||
async def uninstall_plugin_runtime(plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
按现有卸载逻辑移除插件,并清理运行态注册与分组信息。
|
||||
"""
|
||||
from app.api.endpoints.plugin import _remove_plugin_from_folders, remove_plugin_api
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
config_oper = SystemConfigOper()
|
||||
install_plugins = config_oper.get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
if plugin_id in install_plugins:
|
||||
install_plugins = [plugin for plugin in install_plugins if plugin != plugin_id]
|
||||
await config_oper.async_set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
|
||||
remove_plugin_api(plugin_id)
|
||||
Scheduler().remove_plugin_job(plugin_id)
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
plugin_class = plugin_manager.plugins.get(plugin_id)
|
||||
was_clone = bool(getattr(plugin_class, "is_clone", False))
|
||||
clone_files_removed = False
|
||||
|
||||
if was_clone:
|
||||
plugin_manager.delete_plugin_config(plugin_id)
|
||||
plugin_manager.delete_plugin_data(plugin_id)
|
||||
plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower()
|
||||
if plugin_base_dir.exists():
|
||||
try:
|
||||
shutil.rmtree(plugin_base_dir)
|
||||
plugin_manager.plugins.pop(plugin_id, None)
|
||||
clone_files_removed = True
|
||||
except Exception:
|
||||
clone_files_removed = False
|
||||
|
||||
_remove_plugin_from_folders(plugin_id)
|
||||
plugin_manager.remove_plugin(plugin_id)
|
||||
|
||||
return {
|
||||
"was_clone": was_clone,
|
||||
"clone_files_removed": clone_files_removed,
|
||||
}
|
||||
331
app/agent/tools/impl/_system_setting_utils.py
Normal file
331
app/agent/tools/impl/_system_setting_utils.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""系统设置工具共用的键解析与分组元数据。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import Settings
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SettingSpec:
|
||||
"""描述一个可被 Agent 读写的系统设置项。"""
|
||||
|
||||
key: str
|
||||
source: str
|
||||
group: str
|
||||
label: str
|
||||
|
||||
|
||||
SYSTEMCONFIG_SETTING_METADATA = {
|
||||
SystemConfigKey.Downloaders.value: {
|
||||
"group": "downloaders",
|
||||
"label": "下载器配置",
|
||||
},
|
||||
SystemConfigKey.MediaServers.value: {
|
||||
"group": "media_servers",
|
||||
"label": "媒体服务器配置",
|
||||
},
|
||||
SystemConfigKey.Notifications.value: {
|
||||
"group": "notifications",
|
||||
"label": "消息通知配置",
|
||||
},
|
||||
SystemConfigKey.NotificationSwitchs.value: {
|
||||
"group": "notification_switches",
|
||||
"label": "通知场景开关",
|
||||
},
|
||||
SystemConfigKey.Directories.value: {
|
||||
"group": "directories",
|
||||
"label": "目录配置",
|
||||
},
|
||||
SystemConfigKey.Storages.value: {
|
||||
"group": "storages",
|
||||
"label": "存储配置",
|
||||
},
|
||||
SystemConfigKey.IndexerSites.value: {
|
||||
"group": "search_sites",
|
||||
"label": "搜索站点范围",
|
||||
},
|
||||
SystemConfigKey.RssSites.value: {
|
||||
"group": "subscribe_sites",
|
||||
"label": "订阅站点范围",
|
||||
},
|
||||
SystemConfigKey.UserSiteAuthParams.value: {
|
||||
"group": "site_auth",
|
||||
"label": "站点认证参数",
|
||||
},
|
||||
SystemConfigKey.AIAgentConfig.value: {
|
||||
"group": "ai_agent",
|
||||
"label": "AI 智能体配置",
|
||||
},
|
||||
SystemConfigKey.CustomIdentifiers.value: {
|
||||
"group": "custom_identifiers",
|
||||
"label": "自定义识别词",
|
||||
},
|
||||
SystemConfigKey.CustomReleaseGroups.value: {
|
||||
"group": "customization",
|
||||
"label": "自定义制作组/字幕组",
|
||||
},
|
||||
SystemConfigKey.Customization.value: {
|
||||
"group": "customization",
|
||||
"label": "自定义占位符",
|
||||
},
|
||||
SystemConfigKey.TransferExcludeWords.value: {
|
||||
"group": "transfer",
|
||||
"label": "整理屏蔽词",
|
||||
},
|
||||
SystemConfigKey.TorrentsPriority.value: {
|
||||
"group": "filter_rules",
|
||||
"label": "种子优先级规则",
|
||||
},
|
||||
SystemConfigKey.CustomFilterRules.value: {
|
||||
"group": "filter_rules",
|
||||
"label": "用户自定义规则",
|
||||
},
|
||||
SystemConfigKey.UserFilterRuleGroups.value: {
|
||||
"group": "filter_rules",
|
||||
"label": "用户规则组",
|
||||
},
|
||||
SystemConfigKey.SearchFilterRuleGroups.value: {
|
||||
"group": "filter_rules",
|
||||
"label": "搜索默认过滤规则组",
|
||||
},
|
||||
SystemConfigKey.SubscribeFilterRuleGroups.value: {
|
||||
"group": "filter_rules",
|
||||
"label": "订阅默认过滤规则组",
|
||||
},
|
||||
SystemConfigKey.BestVersionFilterRuleGroups.value: {
|
||||
"group": "filter_rules",
|
||||
"label": "洗版默认过滤规则组",
|
||||
},
|
||||
SystemConfigKey.SubscribeDefaultParams.value: {
|
||||
"group": "subscribe_defaults",
|
||||
"label": "订阅默认参数",
|
||||
},
|
||||
SystemConfigKey.DefaultMovieSubscribeConfig.value: {
|
||||
"group": "subscribe_defaults",
|
||||
"label": "默认电影订阅规则",
|
||||
},
|
||||
SystemConfigKey.DefaultTvSubscribeConfig.value: {
|
||||
"group": "subscribe_defaults",
|
||||
"label": "默认电视剧订阅规则",
|
||||
},
|
||||
SystemConfigKey.UserInstalledPlugins.value: {
|
||||
"group": "plugins",
|
||||
"label": "已安装插件列表",
|
||||
},
|
||||
SystemConfigKey.PluginFolders.value: {
|
||||
"group": "plugins",
|
||||
"label": "插件文件夹分组配置",
|
||||
},
|
||||
SystemConfigKey.PluginInstallReport.value: {
|
||||
"group": "plugins",
|
||||
"label": "插件安装统计",
|
||||
},
|
||||
SystemConfigKey.NotificationSendTime.value: {
|
||||
"group": "notifications",
|
||||
"label": "通知发送时间",
|
||||
},
|
||||
SystemConfigKey.NotificationTemplates.value: {
|
||||
"group": "notifications",
|
||||
"label": "通知模板",
|
||||
},
|
||||
SystemConfigKey.ScrapingSwitchs.value: {
|
||||
"group": "scraping",
|
||||
"label": "刮削开关设置",
|
||||
},
|
||||
SystemConfigKey.FollowSubscribers.value: {
|
||||
"group": "subscribe_sites",
|
||||
"label": "Follow 订阅分享者",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
LIST_ITEM_MATCH_FIELD_DEFAULTS = {
|
||||
SystemConfigKey.Downloaders.value: "name",
|
||||
SystemConfigKey.MediaServers.value: "name",
|
||||
SystemConfigKey.Notifications.value: "name",
|
||||
SystemConfigKey.NotificationSwitchs.value: "type",
|
||||
SystemConfigKey.Directories.value: "name",
|
||||
SystemConfigKey.Storages.value: "name",
|
||||
}
|
||||
|
||||
|
||||
GROUP_ALIASES = {
|
||||
"all": "all",
|
||||
"全部": "all",
|
||||
"settings": "settings",
|
||||
"basic": "settings",
|
||||
"基础设置": "settings",
|
||||
"基础配置": "settings",
|
||||
"systemconfig": "systemconfig",
|
||||
"system_config": "systemconfig",
|
||||
"系统设置": "systemconfig",
|
||||
"系统配置": "systemconfig",
|
||||
"downloaders": "downloaders",
|
||||
"downloader": "downloaders",
|
||||
"下载器": "downloaders",
|
||||
"media_servers": "media_servers",
|
||||
"mediaservers": "media_servers",
|
||||
"media-servers": "media_servers",
|
||||
"媒体服务器": "media_servers",
|
||||
"notifications": "notifications",
|
||||
"notification": "notifications",
|
||||
"消息通知": "notifications",
|
||||
"通知": "notifications",
|
||||
"notification_switches": "notification_switches",
|
||||
"notification_switchs": "notification_switches",
|
||||
"通知开关": "notification_switches",
|
||||
"storages": "storages",
|
||||
"storage": "storages",
|
||||
"存储": "storages",
|
||||
"directories": "directories",
|
||||
"directory": "directories",
|
||||
"目录": "directories",
|
||||
"search_sites": "search_sites",
|
||||
"indexer_sites": "search_sites",
|
||||
"搜索站点": "search_sites",
|
||||
"subscribe_sites": "subscribe_sites",
|
||||
"rss_sites": "subscribe_sites",
|
||||
"订阅站点": "subscribe_sites",
|
||||
"site_auth": "site_auth",
|
||||
"site_auth_params": "site_auth",
|
||||
"站点认证": "site_auth",
|
||||
"ai_agent": "ai_agent",
|
||||
"agent": "ai_agent",
|
||||
"智能体": "ai_agent",
|
||||
"custom_identifiers": "custom_identifiers",
|
||||
"自定义识别词": "custom_identifiers",
|
||||
"filter_rules": "filter_rules",
|
||||
"过滤规则": "filter_rules",
|
||||
"subscribe_defaults": "subscribe_defaults",
|
||||
"订阅默认": "subscribe_defaults",
|
||||
"plugins": "plugins",
|
||||
"插件": "plugins",
|
||||
"customization": "customization",
|
||||
"自定义": "customization",
|
||||
"transfer": "transfer",
|
||||
"整理": "transfer",
|
||||
"scraping": "scraping",
|
||||
"刮削": "scraping",
|
||||
"misc": "misc",
|
||||
"其他": "misc",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_token(value: str) -> str:
|
||||
return str(value).strip().lower().replace("-", "_")
|
||||
|
||||
|
||||
def _build_specs() -> tuple[dict[str, SettingSpec], dict[str, SettingSpec]]:
|
||||
core_specs = {
|
||||
key: SettingSpec(key=key, source="settings", group="settings", label=key)
|
||||
for key in Settings.model_fields.keys()
|
||||
}
|
||||
system_specs = {}
|
||||
for item in SystemConfigKey:
|
||||
metadata = SYSTEMCONFIG_SETTING_METADATA.get(item.value, {})
|
||||
system_specs[item.value] = SettingSpec(
|
||||
key=item.value,
|
||||
source="systemconfig",
|
||||
group=metadata.get("group", "misc"),
|
||||
label=metadata.get("label", item.value),
|
||||
)
|
||||
return core_specs, system_specs
|
||||
|
||||
|
||||
CORE_SETTING_SPECS, SYSTEMCONFIG_SETTING_SPECS = _build_specs()
|
||||
ALL_SETTING_SPECS = {**CORE_SETTING_SPECS, **SYSTEMCONFIG_SETTING_SPECS}
|
||||
|
||||
|
||||
SETTING_KEY_ALIASES = {}
|
||||
for key in CORE_SETTING_SPECS:
|
||||
SETTING_KEY_ALIASES[_normalize_token(key)] = key
|
||||
for item in SystemConfigKey:
|
||||
SETTING_KEY_ALIASES[_normalize_token(item.value)] = item.value
|
||||
SETTING_KEY_ALIASES[_normalize_token(item.name)] = item.value
|
||||
|
||||
SINGLE_KEY_GROUP_ALIASES = {
|
||||
_normalize_token(alias): next(
|
||||
(
|
||||
spec.key
|
||||
for spec in SYSTEMCONFIG_SETTING_SPECS.values()
|
||||
if spec.group == canonical_group
|
||||
),
|
||||
None,
|
||||
)
|
||||
for alias, canonical_group in GROUP_ALIASES.items()
|
||||
if canonical_group not in {"all", "settings", "systemconfig"}
|
||||
and len(
|
||||
[
|
||||
spec.key
|
||||
for spec in SYSTEMCONFIG_SETTING_SPECS.values()
|
||||
if spec.group == canonical_group
|
||||
]
|
||||
)
|
||||
== 1
|
||||
}
|
||||
|
||||
|
||||
def normalize_group(group: Optional[str]) -> str:
|
||||
if not group:
|
||||
return "all"
|
||||
normalized = GROUP_ALIASES.get(_normalize_token(group))
|
||||
if not normalized:
|
||||
raise ValueError(
|
||||
"group 不支持,支持值包括 all/settings/systemconfig 以及"
|
||||
" downloaders、media_servers、notifications、storages、directories、"
|
||||
"search_sites、subscribe_sites、site_auth、ai_agent 等分类别名"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def resolve_setting_spec(setting_key: Optional[str]) -> Optional[SettingSpec]:
|
||||
"""把精确键名、枚举名或单键分组别名解析为统一的设置定义。"""
|
||||
|
||||
if not setting_key:
|
||||
return None
|
||||
|
||||
normalized = _normalize_token(setting_key)
|
||||
resolved_key = SETTING_KEY_ALIASES.get(normalized) or SINGLE_KEY_GROUP_ALIASES.get(
|
||||
normalized
|
||||
)
|
||||
if not resolved_key:
|
||||
return None
|
||||
return ALL_SETTING_SPECS.get(resolved_key)
|
||||
|
||||
|
||||
def list_setting_specs(
|
||||
group: Optional[str] = "all", keyword: Optional[str] = None
|
||||
) -> list[SettingSpec]:
|
||||
"""按分组和关键字筛选可查询的设置项。"""
|
||||
|
||||
normalized_group = normalize_group(group)
|
||||
if normalized_group == "all":
|
||||
specs = list(ALL_SETTING_SPECS.values())
|
||||
elif normalized_group == "settings":
|
||||
specs = list(CORE_SETTING_SPECS.values())
|
||||
elif normalized_group == "systemconfig":
|
||||
specs = list(SYSTEMCONFIG_SETTING_SPECS.values())
|
||||
else:
|
||||
specs = [
|
||||
spec
|
||||
for spec in SYSTEMCONFIG_SETTING_SPECS.values()
|
||||
if spec.group == normalized_group
|
||||
]
|
||||
|
||||
if keyword:
|
||||
normalized_keyword = _normalize_token(keyword)
|
||||
specs = [
|
||||
spec
|
||||
for spec in specs
|
||||
if normalized_keyword in _normalize_token(spec.key)
|
||||
or normalized_keyword in _normalize_token(spec.group)
|
||||
or normalized_keyword in _normalize_token(spec.label)
|
||||
]
|
||||
|
||||
return sorted(specs, key=lambda spec: (spec.source, spec.group, spec.key))
|
||||
|
||||
|
||||
def get_default_list_match_field(setting_key: str) -> Optional[str]:
|
||||
return LIST_ITEM_MATCH_FIELD_DEFAULTS.get(setting_key)
|
||||
111
app/agent/tools/impl/add_custom_filter_rule.py
Normal file
111
app/agent/tools/impl/add_custom_filter_rule.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""新增自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_custom_rules,
|
||||
normalize_custom_rule,
|
||||
save_system_config,
|
||||
serialize_custom_rule,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class AddCustomFilterRuleInput(BaseModel):
|
||||
"""新增自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_id: str = Field(
|
||||
...,
|
||||
description="Unique custom rule ID. Only letters and numbers are allowed.",
|
||||
)
|
||||
name: str = Field(..., description="Display name of the custom rule.")
|
||||
include: Optional[str] = Field(
|
||||
None, description="Optional include regex for the rule."
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None, description="Optional exclude regex for the rule."
|
||||
)
|
||||
size_range: Optional[str] = Field(
|
||||
None, description="Optional size range in MB, for example '1000-5000'."
|
||||
)
|
||||
seeders: Optional[str] = Field(
|
||||
None, description="Optional minimum seeder count as a non-negative integer."
|
||||
)
|
||||
publish_time: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional publish-time filter in minutes, for example '60' or '60-1440'.",
|
||||
)
|
||||
|
||||
|
||||
class AddCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "add_custom_filter_rule"
|
||||
description: str = (
|
||||
"Add a custom filter rule to CustomFilterRules. "
|
||||
"The new rule can then be referenced by rule ID inside filter rule groups."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AddCustomFilterRuleInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"新增自定义过滤规则 {kwargs.get('rule_id', '')}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
rule_id: str,
|
||||
name: str,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
size_range: Optional[str] = None,
|
||||
seeders: Optional[str] = None,
|
||||
publish_time: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, rule_id={rule_id}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
new_rule = normalize_custom_rule(
|
||||
rule_id=rule_id,
|
||||
name=name,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
size_range=size_range,
|
||||
seeders=seeders,
|
||||
publish_time=publish_time,
|
||||
existing_rules=custom_rules,
|
||||
)
|
||||
|
||||
custom_rules.append(new_rule)
|
||||
await save_system_config(
|
||||
SystemConfigKey.CustomFilterRules,
|
||||
[rule.model_dump(exclude_none=True) for rule in custom_rules],
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已新增自定义过滤规则 {new_rule.id}",
|
||||
"custom_rule": serialize_custom_rule(new_rule),
|
||||
"count": len(custom_rules),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"新增自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"新增自定义过滤规则失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -6,7 +6,8 @@ from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.config import settings
|
||||
@@ -275,7 +276,10 @@ class AddDownloadTool(MoviePilotTool):
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = cached_context.media_info if cached_context.media_info else None
|
||||
if not media_info:
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
media_info = await MediaChain().async_recognize_by_meta(
|
||||
meta_info,
|
||||
obtain_images=False,
|
||||
)
|
||||
if not media_info:
|
||||
failed_messages.append(f"{torrent_input} 无法识别媒体信息")
|
||||
continue
|
||||
|
||||
115
app/agent/tools/impl/add_rule_group.py
Normal file
115
app/agent/tools/impl/add_rule_group.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""新增过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
build_custom_rule_map,
|
||||
collect_rule_group_usages,
|
||||
get_builtin_rules,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
normalize_rule_group,
|
||||
save_system_config,
|
||||
serialize_rule_group,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class AddRuleGroupInput(BaseModel):
|
||||
"""新增过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
name: str = Field(..., description="New rule group name.")
|
||||
rule_string: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Rule expression using built-in/custom rule IDs. "
|
||||
"Use '&', '!' inside one level, and use '>' between priority levels. "
|
||||
"Example: 'SPECSUB & CNVOI & 4K & !BLU > CNSUB & CNVOI & 4K & !BLU'."
|
||||
),
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional media type scope: '电影', '电视剧', 'movie', or 'tv'.",
|
||||
)
|
||||
category: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional media category. Only valid when media_type is set.",
|
||||
)
|
||||
|
||||
|
||||
class AddRuleGroupTool(MoviePilotTool):
|
||||
name: str = "add_rule_group"
|
||||
description: str = (
|
||||
"Add a new filter rule group to UserFilterRuleGroups. "
|
||||
"Rule groups are matched level by level from left to right and can be linked to search/subscription flows. "
|
||||
"Before calling this tool, first use query_builtin_filter_rules and query_custom_filter_rules to confirm valid rule IDs, "
|
||||
"and optionally use query_rule_groups to imitate existing rule_string patterns."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AddRuleGroupInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"新增规则组 {kwargs.get('name', '')}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
name: str,
|
||||
rule_string: str,
|
||||
media_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, name={name}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
available_rule_ids = set(get_builtin_rules().keys()) | set(
|
||||
build_custom_rule_map(custom_rules).keys()
|
||||
)
|
||||
rule_groups = get_rule_groups()
|
||||
new_group, _ = normalize_rule_group(
|
||||
name=name,
|
||||
rule_string=rule_string,
|
||||
media_type=media_type,
|
||||
category=category,
|
||||
existing_groups=rule_groups,
|
||||
available_rule_ids=available_rule_ids,
|
||||
)
|
||||
|
||||
rule_groups.append(new_group)
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[group.model_dump(exclude_none=True) for group in rule_groups],
|
||||
)
|
||||
usage = await collect_rule_group_usages([new_group.name])
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已新增规则组 {new_group.name}",
|
||||
"rule_group": serialize_rule_group(
|
||||
new_group, usage.get(new_group.name)
|
||||
),
|
||||
"count": len(rule_groups),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"新增规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"新增规则组失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -1,13 +1,14 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, MessageChannel
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
@@ -101,6 +102,38 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def _resolve_subscribe_username(self) -> Optional[str]:
|
||||
"""优先映射为系统用户名,未绑定时回退当前渠道用户名。"""
|
||||
resolved_username = self._username
|
||||
if not self._channel or not self._user_id:
|
||||
return resolved_username
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return resolved_username
|
||||
|
||||
binding_keys = {
|
||||
MessageChannel.Telegram: ("telegram_userid",),
|
||||
MessageChannel.Discord: ("discord_userid",),
|
||||
MessageChannel.Wechat: ("wechat_userid",),
|
||||
MessageChannel.Feishu: ("feishu_userid", "feishu_openid"),
|
||||
MessageChannel.WechatClawBot: ("wechatclawbot_userid",),
|
||||
MessageChannel.Slack: ("slack_userid",),
|
||||
MessageChannel.VoceChat: ("vocechat_userid",),
|
||||
MessageChannel.SynologyChat: ("synologychat_userid",),
|
||||
MessageChannel.QQ: ("qq_userid", "qq_openid"),
|
||||
}.get(channel)
|
||||
if not binding_keys:
|
||||
return resolved_username
|
||||
|
||||
mapped_username = await self.run_blocking(
|
||||
"db",
|
||||
UserOper().get_name,
|
||||
**{key: self._user_id for key in binding_keys},
|
||||
)
|
||||
return mapped_username or resolved_username
|
||||
|
||||
async def run(
|
||||
self,
|
||||
title: str,
|
||||
@@ -137,6 +170,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
if media_type_enum == MediaType.TV
|
||||
else None
|
||||
)
|
||||
subscribe_username = await self._resolve_subscribe_username()
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
@@ -162,7 +196,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
username=subscribe_username,
|
||||
**subscribe_kwargs,
|
||||
)
|
||||
if sid:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.interaction import (
|
||||
from app.helper.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
@@ -64,6 +64,7 @@ class AskUserChoiceInput(BaseModel):
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
name: str = "ask_user_choice"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
|
||||
97
app/agent/tools/impl/delete_custom_filter_rule.py
Normal file
97
app/agent/tools/impl/delete_custom_filter_rule.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""删除自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
save_system_config,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class DeleteCustomFilterRuleInput(BaseModel):
|
||||
"""删除自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_id: str = Field(..., description="Custom rule ID to delete.")
|
||||
|
||||
|
||||
class DeleteCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "delete_custom_filter_rule"
|
||||
description: str = (
|
||||
"Delete a custom filter rule from CustomFilterRules. "
|
||||
"If the rule is still referenced by rule groups, the deletion is blocked to avoid breaking rule_string expressions."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DeleteCustomFilterRuleInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"删除自定义过滤规则 {kwargs.get('rule_id', '')}"
|
||||
|
||||
async def run(self, rule_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, rule_id={rule_id}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
target_rule = next((rule for rule in custom_rules if rule.id == rule_id), None)
|
||||
if not target_rule:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"自定义过滤规则 '{rule_id}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
refs = collect_custom_rule_group_refs(get_rule_groups(), [rule_id]).get(
|
||||
rule_id, []
|
||||
)
|
||||
if refs:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": (
|
||||
f"自定义过滤规则 '{rule_id}' 仍被规则组引用,无法删除。"
|
||||
),
|
||||
"referenced_by_rule_groups": refs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
remaining_rules = [
|
||||
rule for rule in custom_rules if rule.id != rule_id
|
||||
]
|
||||
await save_system_config(
|
||||
SystemConfigKey.CustomFilterRules,
|
||||
[rule.model_dump(exclude_none=True) for rule in remaining_rules],
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已删除自定义过滤规则 {rule_id}",
|
||||
"count": len(remaining_rules),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"删除自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"删除自定义过滤规则失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
81
app/agent/tools/impl/delete_rule_group.py
Normal file
81
app/agent/tools/impl/delete_rule_group.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""删除过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_rule_groups,
|
||||
remove_rule_group_references,
|
||||
save_system_config,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class DeleteRuleGroupInput(BaseModel):
|
||||
"""删除过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
name: str = Field(..., description="Rule group name to delete.")
|
||||
|
||||
|
||||
class DeleteRuleGroupTool(MoviePilotTool):
|
||||
name: str = "delete_rule_group"
|
||||
description: str = (
|
||||
"Delete a filter rule group from UserFilterRuleGroups. "
|
||||
"The tool also removes dangling references from global settings and subscriptions."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DeleteRuleGroupInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"删除规则组 {kwargs.get('name', '')}"
|
||||
|
||||
async def run(self, name: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, name={name}")
|
||||
|
||||
try:
|
||||
rule_groups = get_rule_groups()
|
||||
if not any(group.name == name for group in rule_groups):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"规则组 '{name}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
remaining_groups = [
|
||||
group for group in rule_groups if group.name != name
|
||||
]
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[group.model_dump(exclude_none=True) for group in remaining_groups],
|
||||
)
|
||||
reference_changes = await remove_rule_group_references(name)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已删除规则组 {name}",
|
||||
"count": len(remaining_groups),
|
||||
"reference_updates": reference_changes,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"删除规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"删除规则组失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -5,7 +5,8 @@ import os
|
||||
import signal
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, TextIO, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -15,7 +16,7 @@ from app.log import logger
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 60
|
||||
MAX_TIMEOUT_SECONDS = 300
|
||||
MAX_OUTPUT_CHARS = 6000
|
||||
MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
@@ -25,40 +26,93 @@ _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
@dataclass
|
||||
class _CommandOutput:
|
||||
"""保存受限命令输出,避免大输出一次性进入内存。"""
|
||||
"""保存前 10KB 预览,并在超限时将完整输出写入临时文件。"""
|
||||
|
||||
limit: int
|
||||
stdout_chunks: list[str] = field(default_factory=list)
|
||||
stderr_chunks: list[str] = field(default_factory=list)
|
||||
captured_chars: int = 0
|
||||
truncated: bool = False
|
||||
preview_limit_bytes: int
|
||||
preview_entries: list[tuple[str, str]] = field(default_factory=list)
|
||||
captured_bytes: int = 0
|
||||
preview_truncated: bool = False
|
||||
temp_file_path: Optional[str] = None
|
||||
temp_file_handle: Optional[TextIO] = None
|
||||
last_written_stream: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
|
||||
if byte_limit <= 0:
|
||||
return ""
|
||||
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
|
||||
|
||||
def _write_chunk(self, stream_name: str, text: str) -> None:
|
||||
if not self.temp_file_handle or not text:
|
||||
return
|
||||
|
||||
if self.last_written_stream != stream_name:
|
||||
if self.temp_file_handle.tell() > 0:
|
||||
self.temp_file_handle.write("\n")
|
||||
title = "标准输出" if stream_name == "stdout" else "错误输出"
|
||||
self.temp_file_handle.write(f"[{title}]\n")
|
||||
self.last_written_stream = stream_name
|
||||
|
||||
self.temp_file_handle.write(text)
|
||||
|
||||
def _ensure_temp_file(self) -> None:
|
||||
if self.temp_file_handle:
|
||||
return
|
||||
|
||||
temp_file = NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
suffix=".log",
|
||||
prefix="moviepilot-command-",
|
||||
delete=False,
|
||||
)
|
||||
self.temp_file_path = temp_file.name
|
||||
self.temp_file_handle = temp_file
|
||||
for stream_name, chunk in self.preview_entries:
|
||||
self._write_chunk(stream_name, chunk)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self.temp_file_handle:
|
||||
return
|
||||
self.temp_file_handle.flush()
|
||||
self.temp_file_handle.close()
|
||||
self.temp_file_handle = None
|
||||
|
||||
def append(self, stream_name: str, text: str) -> None:
|
||||
if not text:
|
||||
return
|
||||
|
||||
remaining = self.limit - self.captured_chars
|
||||
if remaining <= 0:
|
||||
self.truncated = True
|
||||
if self.temp_file_handle:
|
||||
self._write_chunk(stream_name, text)
|
||||
return
|
||||
|
||||
captured = text[:remaining]
|
||||
if stream_name == "stdout":
|
||||
self.stdout_chunks.append(captured)
|
||||
else:
|
||||
self.stderr_chunks.append(captured)
|
||||
chunk_bytes = len(text.encode("utf-8"))
|
||||
remaining = self.preview_limit_bytes - self.captured_bytes
|
||||
if chunk_bytes <= remaining:
|
||||
self.preview_entries.append((stream_name, text))
|
||||
self.captured_bytes += chunk_bytes
|
||||
return
|
||||
|
||||
self.captured_chars += len(captured)
|
||||
if len(text) > remaining:
|
||||
self.truncated = True
|
||||
self.preview_truncated = True
|
||||
self._ensure_temp_file()
|
||||
self._write_chunk(stream_name, text)
|
||||
|
||||
preview = self._clip_text_to_bytes(text, remaining)
|
||||
if preview:
|
||||
self.preview_entries.append((stream_name, preview))
|
||||
self.captured_bytes += len(preview.encode("utf-8"))
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
return "".join(self.stdout_chunks).strip()
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stdout"
|
||||
).strip()
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
return "".join(self.stderr_chunks).strip()
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stderr"
|
||||
).strip()
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
@@ -78,7 +132,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
description: str = (
|
||||
"Safely execute shell commands on the server. Useful for system "
|
||||
"maintenance, checking status, or running custom scripts. Includes "
|
||||
"timeout, concurrency, and hard output limits."
|
||||
"timeout, concurrency, and output preview limits."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
@@ -107,7 +161,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _subprocess_kwargs() -> dict:
|
||||
"""为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。"""
|
||||
"""为子进程创建独立进程组,便于超时场景清理整棵子进程。"""
|
||||
kwargs = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
@@ -124,23 +178,14 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
output: _CommandOutput,
|
||||
limit_reached: asyncio.Event,
|
||||
) -> None:
|
||||
"""按块读取输出,达到上限后通知主流程终止命令。"""
|
||||
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
|
||||
while True:
|
||||
chunk = await stream.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
continue
|
||||
|
||||
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
# 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
|
||||
@@ -205,27 +250,33 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
output: _CommandOutput,
|
||||
timeout: int,
|
||||
timed_out: bool,
|
||||
output_limited: bool,
|
||||
timeout_note: Optional[str],
|
||||
) -> str:
|
||||
if timed_out:
|
||||
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
|
||||
elif output_limited:
|
||||
result = (
|
||||
f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
|
||||
f"已截断并终止进程,退出码: {exit_code})"
|
||||
)
|
||||
else:
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
|
||||
if timeout_note:
|
||||
result += f"\n\n提示:\n{timeout_note}"
|
||||
if output.temp_file_path:
|
||||
file_note = (
|
||||
"截至命令终止前的完整输出"
|
||||
if timed_out
|
||||
else "完整输出"
|
||||
)
|
||||
result += (
|
||||
"\n\n提示:\n"
|
||||
f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
|
||||
f"{file_note}已写入临时文件: {output.temp_file_path}\n"
|
||||
"如需完整内容,请继续读取该文件。"
|
||||
)
|
||||
if output.stdout:
|
||||
result += f"\n\n标准输出:\n{output.stdout}"
|
||||
if output.stderr:
|
||||
result += f"\n\n错误输出:\n{output.stderr}"
|
||||
if output.truncated:
|
||||
result += "\n\n...(输出内容过长,已截断)"
|
||||
if output.preview_truncated:
|
||||
result += "\n\n...(仅展示前 10KB 内容)"
|
||||
if not output.stdout and not output.stderr:
|
||||
result += "\n\n(无输出内容)"
|
||||
return result
|
||||
@@ -252,51 +303,40 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
async with _command_semaphore:
|
||||
# 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。
|
||||
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, **self._subprocess_kwargs()
|
||||
)
|
||||
output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
|
||||
limit_reached = asyncio.Event()
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
limit_task = asyncio.create_task(limit_reached.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stdout, "stdout", output, limit_reached
|
||||
)
|
||||
self._read_stream(process.stdout, "stdout", output)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stderr, "stderr", output, limit_reached
|
||||
)
|
||||
self._read_stream(process.stderr, "stderr", output)
|
||||
),
|
||||
]
|
||||
|
||||
timed_out = False
|
||||
output_limited = False
|
||||
done, _ = await asyncio.wait(
|
||||
{wait_task, limit_task},
|
||||
timeout=normalized_timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if wait_task not in done:
|
||||
if limit_task in done:
|
||||
output_limited = True
|
||||
else:
|
||||
timed_out = True
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
|
||||
limit_task.cancel()
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
|
||||
return self._format_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
output_limited=output_limited,
|
||||
timeout_note=timeout_note,
|
||||
)
|
||||
|
||||
|
||||
118
app/agent/tools/impl/install_plugin.py
Normal file
118
app/agent/tools/impl/install_plugin.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""安装插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
get_plugin_snapshot,
|
||||
install_plugin_runtime,
|
||||
load_market_plugins,
|
||||
summarize_plugin,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class InstallPluginInput(BaseModel):
|
||||
"""安装插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="Exact plugin ID to install. Use query_market_plugins first to find the correct plugin_id.",
|
||||
)
|
||||
force: bool = Field(
|
||||
False,
|
||||
description="Whether to force reinstall or upgrade the specified plugin.",
|
||||
)
|
||||
force_refresh_market: bool = Field(
|
||||
False,
|
||||
description="Whether to refresh plugin market caches before reading the market list.",
|
||||
)
|
||||
|
||||
|
||||
class InstallPluginTool(MoviePilotTool):
|
||||
name: str = "install_plugin"
|
||||
description: str = (
|
||||
"Install a plugin by exact plugin_id from the plugin market or local plugin repositories. "
|
||||
"Use query_market_plugins first when you need filtering or discovery."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = InstallPluginInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
return f"安装插件: {plugin_id or '未知插件'}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
force: bool = False,
|
||||
force_refresh_market: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
plugins = await load_market_plugins(force_refresh=force_refresh_market)
|
||||
if not plugins:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前插件市场没有可用插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
candidate = next((plugin for plugin in plugins if plugin.id == plugin_id), None)
|
||||
if not candidate:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"未在插件市场中找到插件: {plugin_id}。请先调用 query_market_plugins 确认 plugin_id。",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
success, message, refreshed_only = await install_plugin_runtime(
|
||||
candidate.id,
|
||||
getattr(candidate, "repo_url", None),
|
||||
force=force,
|
||||
)
|
||||
if not success:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"plugin": summarize_plugin(candidate),
|
||||
"message": message,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
plugin_snapshot = get_plugin_snapshot(candidate.id)
|
||||
if refreshed_only and getattr(candidate, "has_update", False) and not force:
|
||||
message = "插件已安装,当前仅刷新加载;如需升级到市场新版本,请设置 force=true"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"force": force,
|
||||
"refreshed_only": refreshed_only,
|
||||
"plugin": summarize_plugin(candidate),
|
||||
"runtime": plugin_snapshot,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"安装插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
85
app/agent/tools/impl/query_builtin_filter_rules.py
Normal file
85
app/agent/tools/impl/query_builtin_filter_rules.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""查询内置过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_builtin_rules,
|
||||
serialize_builtin_rule,
|
||||
RULE_STRING_SYNTAX,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryBuiltinFilterRulesInput(BaseModel):
|
||||
"""查询内置过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of built-in rule IDs to query. If omitted, return all built-in rules.",
|
||||
)
|
||||
|
||||
|
||||
class QueryBuiltinFilterRulesTool(MoviePilotTool):
|
||||
name: str = "query_builtin_filter_rules"
|
||||
description: str = (
|
||||
"Query built-in filter rules defined by the backend filter module. "
|
||||
"These rule IDs can be used directly inside rule_string expressions for filter rule groups. "
|
||||
"Use this tool before add_rule_group or update_rule_group to learn valid built-in rule IDs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryBuiltinFilterRulesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
rule_ids = kwargs.get("rule_ids") or []
|
||||
if rule_ids:
|
||||
return f"查询内置过滤规则: {', '.join(rule_ids)}"
|
||||
return "查询所有内置过滤规则"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
rule_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
builtin_rules = get_builtin_rules()
|
||||
if rule_ids:
|
||||
target_ids = set(rule_ids)
|
||||
builtin_rules = {
|
||||
rule_id: payload
|
||||
for rule_id, payload in builtin_rules.items()
|
||||
if rule_id in target_ids
|
||||
}
|
||||
|
||||
serialized = [
|
||||
serialize_builtin_rule(rule_id, payload)
|
||||
for rule_id, payload in builtin_rules.items()
|
||||
]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(serialized),
|
||||
"rule_string_syntax": RULE_STRING_SYNTAX,
|
||||
"rules": serialized,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"查询内置过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询内置过滤规则失败: {exc}",
|
||||
"rules": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
95
app/agent/tools/impl/query_custom_filter_rules.py
Normal file
95
app/agent/tools/impl/query_custom_filter_rules.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""查询自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
serialize_custom_rule,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryCustomFilterRulesInput(BaseModel):
|
||||
"""查询自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of custom rule IDs to query. If omitted, return all custom rules.",
|
||||
)
|
||||
include_group_refs: bool = Field(
|
||||
True,
|
||||
description="Whether to include which rule groups reference each custom rule.",
|
||||
)
|
||||
|
||||
|
||||
class QueryCustomFilterRulesTool(MoviePilotTool):
|
||||
name: str = "query_custom_filter_rules"
|
||||
description: str = (
|
||||
"Query custom filter rules stored in CustomFilterRules. "
|
||||
"Custom rules can be referenced from rule_string expressions in filter rule groups. "
|
||||
"Use this tool before add_rule_group or update_rule_group to learn valid custom rule IDs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryCustomFilterRulesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
rule_ids = kwargs.get("rule_ids") or []
|
||||
if rule_ids:
|
||||
return f"查询自定义过滤规则: {', '.join(rule_ids)}"
|
||||
return "查询所有自定义过滤规则"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
rule_ids: Optional[List[str]] = None,
|
||||
include_group_refs: bool = True,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
if rule_ids:
|
||||
target_ids = set(rule_ids)
|
||||
custom_rules = [
|
||||
rule for rule in custom_rules if rule.id in target_ids
|
||||
]
|
||||
|
||||
refs = {}
|
||||
if include_group_refs:
|
||||
refs = collect_custom_rule_group_refs(
|
||||
get_rule_groups(),
|
||||
[rule.id for rule in custom_rules if rule.id],
|
||||
)
|
||||
|
||||
serialized = [
|
||||
serialize_custom_rule(rule, refs.get(rule.id))
|
||||
for rule in custom_rules
|
||||
]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(serialized),
|
||||
"rules": serialized,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"查询自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询自定义过滤规则失败: {exc}",
|
||||
"rules": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -27,6 +27,7 @@ class QueryCustomIdentifiersTool(MoviePilotTool):
|
||||
"Returns the list of identifier rules used for preprocessing torrent/file names before media recognition. "
|
||||
"Use this tool to check existing rules before adding new ones to avoid duplicates."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -24,6 +24,7 @@ class QueryDirectorySettingsInput(BaseModel):
|
||||
class QueryDirectorySettingsTool(MoviePilotTool):
|
||||
name: str = "query_directory_settings"
|
||||
description: str = "Query system directory configuration settings (NOT file listings). Returns configured directory paths, storage types, transfer modes, and other directory-related settings. Use 'list_directory' to list actual files and folders in a directory."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryDirectorySettingsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -19,6 +19,7 @@ class QueryDownloadersInput(BaseModel):
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -6,7 +6,14 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
MAX_PLUGIN_CANDIDATE_LIMIT,
|
||||
list_installed_plugins,
|
||||
search_plugin_candidates,
|
||||
summarize_candidates,
|
||||
summarize_plugin,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -17,49 +24,89 @@ class QueryInstalledPluginsInput(BaseModel):
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional keyword to filter installed plugins by plugin ID, name, description, or author.",
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
description="Maximum number of plugins to return. Defaults to 50, capped at 200.",
|
||||
)
|
||||
|
||||
|
||||
class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
name: str = "query_installed_plugins"
|
||||
description: str = (
|
||||
"Query all installed plugins in MoviePilot. Returns a list of installed plugins with their ID, name, "
|
||||
"description, version, author, running state, and other information. "
|
||||
"Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands."
|
||||
"Query installed plugins in MoviePilot. Returns all installed plugins or filters them by keywords. "
|
||||
"Use this tool to find the exact plugin_id before uninstall_plugin or other plugin management tools are used."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryInstalledPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
query = kwargs.get("query")
|
||||
if query:
|
||||
return f"查询已安装插件: {query}"
|
||||
return "查询已安装插件"
|
||||
|
||||
@staticmethod
|
||||
def _list_installed_plugins() -> list[dict]:
|
||||
"""读取已加载插件的内存快照。"""
|
||||
plugin_manager = PluginManager()
|
||||
local_plugins = plugin_manager.get_local_plugins()
|
||||
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
|
||||
return [
|
||||
{
|
||||
"id": plugin.id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"plugin_desc": plugin.plugin_desc,
|
||||
"plugin_version": plugin.plugin_version,
|
||||
"plugin_author": plugin.plugin_author,
|
||||
"state": plugin.state,
|
||||
"has_page": plugin.has_page,
|
||||
}
|
||||
for plugin in installed_plugins
|
||||
]
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
def _clamp_results(max_results: Optional[int]) -> int:
|
||||
if max_results is None:
|
||||
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
|
||||
try:
|
||||
installed_plugins = self._list_installed_plugins()
|
||||
return max(1, min(int(max_results), MAX_PLUGIN_CANDIDATE_LIMIT))
|
||||
except (TypeError, ValueError):
|
||||
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
max_results: Optional[int] = DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}")
|
||||
try:
|
||||
installed_plugins = list_installed_plugins()
|
||||
if not installed_plugins:
|
||||
return "当前没有已安装的插件"
|
||||
result_json = json.dumps(installed_plugins, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前没有已安装的插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
limit = self._clamp_results(max_results)
|
||||
if query:
|
||||
matches = search_plugin_candidates(query, installed_plugins)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"total_installed": len(installed_plugins),
|
||||
"match_count": len(matches),
|
||||
"truncated": len(matches) > limit,
|
||||
"plugins": summarize_candidates(matches, limit=limit),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
plugin_summaries = [
|
||||
summarize_plugin(plugin) for plugin in installed_plugins[:limit]
|
||||
]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"total_installed": len(installed_plugins),
|
||||
"returned_count": len(plugin_summaries),
|
||||
"truncated": len(installed_plugins) > limit,
|
||||
"plugins": plugin_summaries,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询已安装插件失败: {e}", exc_info=True)
|
||||
return f"查询已安装插件时发生错误: {str(e)}"
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询已安装插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
117
app/agent/tools/impl/query_market_plugins.py
Normal file
117
app/agent/tools/impl/query_market_plugins.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""查询插件市场工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
MAX_PLUGIN_CANDIDATE_LIMIT,
|
||||
load_market_plugins,
|
||||
search_plugin_candidates,
|
||||
summarize_candidates,
|
||||
summarize_plugin,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryMarketPluginsInput(BaseModel):
|
||||
"""查询插件市场工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional keyword to filter plugin market results by plugin ID, name, description, or author.",
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
description="Maximum number of plugins to return. Defaults to 50, capped at 200.",
|
||||
)
|
||||
force_refresh: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to refresh plugin market caches before querying.",
|
||||
)
|
||||
|
||||
|
||||
class QueryMarketPluginsTool(MoviePilotTool):
|
||||
name: str = "query_market_plugins"
|
||||
description: str = (
|
||||
"Query available plugins from the plugin market and local plugin repositories. "
|
||||
"Can return the full plugin list or filter by keywords before install_plugin is used."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryMarketPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
query = kwargs.get("query")
|
||||
if query:
|
||||
return f"查询插件市场: {query}"
|
||||
return "查询插件市场全部插件"
|
||||
|
||||
@staticmethod
|
||||
def _clamp_results(max_results: Optional[int]) -> int:
|
||||
if max_results is None:
|
||||
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
|
||||
try:
|
||||
return max(1, min(int(max_results), MAX_PLUGIN_CANDIDATE_LIMIT))
|
||||
except (TypeError, ValueError):
|
||||
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
max_results: Optional[int] = DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
force_refresh: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, force_refresh={force_refresh}"
|
||||
)
|
||||
|
||||
try:
|
||||
plugins = await load_market_plugins(force_refresh=force_refresh)
|
||||
if not plugins:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前插件市场没有可用插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
limit = self._clamp_results(max_results)
|
||||
if query:
|
||||
matches = search_plugin_candidates(query, plugins)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"total_available": len(plugins),
|
||||
"match_count": len(matches),
|
||||
"truncated": len(matches) > limit,
|
||||
"plugins": summarize_candidates(matches, limit=limit),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
plugin_summaries = [summarize_plugin(plugin) for plugin in plugins[:limit]]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"total_available": len(plugins),
|
||||
"returned_count": len(plugin_summaries),
|
||||
"truncated": len(plugins) > limit,
|
||||
"plugins": plugin_summaries,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件市场失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询插件市场时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -10,6 +10,10 @@ from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
DIRECTOR_PREVIEW_LIMIT = 10
|
||||
ACTOR_PREVIEW_LIMIT = 20
|
||||
SEASON_PREVIEW_LIMIT = 100
|
||||
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
@@ -64,23 +68,23 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
|
||||
|
||||
# 精简 directors - 只保留姓名和职位
|
||||
director_source = [d for d in (mediainfo.directors or []) if d.get("name")]
|
||||
directors = [
|
||||
{
|
||||
"name": d.get("name"),
|
||||
"job": d.get("job")
|
||||
}
|
||||
for d in (mediainfo.directors or [])
|
||||
if d.get("name")
|
||||
for d in director_source[:DIRECTOR_PREVIEW_LIMIT]
|
||||
]
|
||||
|
||||
# 精简 actors - 只保留姓名和角色
|
||||
actor_source = [a for a in (mediainfo.actors or []) if a.get("name")]
|
||||
actors = [
|
||||
{
|
||||
"name": a.get("name"),
|
||||
"character": a.get("character")
|
||||
}
|
||||
for a in (mediainfo.actors or [])
|
||||
if a.get("name")
|
||||
for a in actor_source[:ACTOR_PREVIEW_LIMIT]
|
||||
]
|
||||
|
||||
# 构建基础媒体详情信息
|
||||
@@ -88,12 +92,20 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
"status": mediainfo.status,
|
||||
"genres": genres,
|
||||
"directors": directors,
|
||||
"actors": actors
|
||||
"directors_total": len(director_source),
|
||||
"directors_truncated": len(director_source) > DIRECTOR_PREVIEW_LIMIT,
|
||||
"actors": actors,
|
||||
"actors_total": len(actor_source),
|
||||
"actors_truncated": len(actor_source) > ACTOR_PREVIEW_LIMIT,
|
||||
}
|
||||
|
||||
# 如果是电视剧,添加电视剧特有信息
|
||||
if mediainfo.type == MediaType.TV:
|
||||
# 精简 season_info - 只保留基础摘要
|
||||
season_source = [
|
||||
s for s in (mediainfo.season_info or [])
|
||||
if s.get("season_number") is not None
|
||||
]
|
||||
season_info = [
|
||||
{
|
||||
"season_number": s.get("season_number"),
|
||||
@@ -101,8 +113,7 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
"episode_count": s.get("episode_count"),
|
||||
"air_date": s.get("air_date")
|
||||
}
|
||||
for s in (mediainfo.season_info or [])
|
||||
if s.get("season_number") is not None
|
||||
for s in season_source[:SEASON_PREVIEW_LIMIT]
|
||||
]
|
||||
|
||||
result.update({
|
||||
@@ -110,7 +121,9 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
"number_of_episodes": mediainfo.number_of_episodes,
|
||||
"first_air_date": mediainfo.first_air_date,
|
||||
"last_air_date": mediainfo.last_air_date,
|
||||
"season_info": season_info
|
||||
"season_info": season_info,
|
||||
"season_info_total": len(season_source),
|
||||
"season_info_truncated": len(season_source) > SEASON_PREVIEW_LIMIT,
|
||||
})
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
75
app/agent/tools/impl/query_personas.py
Normal file
75
app/agent/tools/impl/query_personas.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""查询可用人格工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPersonasInput(BaseModel):
|
||||
"""查询人格工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional search keyword for persona_id, label, description, or aliases. "
|
||||
"Use this when the user asks for a certain speaking style but the exact persona name is unknown."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class QueryPersonasTool(MoviePilotTool):
|
||||
name: str = "query_personas"
|
||||
description: str = (
|
||||
"List all available personas (人格) and show which one is currently active. "
|
||||
"Use this before switching persona when the user asks for a different speaking style but does not name "
|
||||
"an exact persona_id. The result includes persona_id, label, description, aliases, and whether it is active."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryPersonasInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
query = kwargs.get("query")
|
||||
if query:
|
||||
return f"查询人格列表: {query}"
|
||||
return "查询人格列表"
|
||||
|
||||
async def run(self, query: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info("执行工具: %s, 参数: query=%s", self.name, query)
|
||||
try:
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
personas = runtime_config.list_personas()
|
||||
|
||||
if query:
|
||||
normalized = query.strip().casefold()
|
||||
personas = [
|
||||
persona
|
||||
for persona in personas
|
||||
if normalized in persona["persona_id"].casefold()
|
||||
or normalized in persona["label"].casefold()
|
||||
or normalized in persona["description"].casefold()
|
||||
or any(normalized in alias.casefold() for alias in persona["aliases"])
|
||||
]
|
||||
|
||||
payload = {
|
||||
"active_persona": runtime_config.active_persona,
|
||||
"count": len(personas),
|
||||
"personas": personas,
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("查询人格列表失败: %s", e, exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询人格列表时发生错误: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
88
app/agent/tools/impl/query_plugin_config.py
Normal file
88
app/agent/tools/impl/query_plugin_config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""查询插件配置工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import get_plugin_snapshot
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPluginConfigInput(BaseModel):
|
||||
"""查询插件配置工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to query. Use query_installed_plugins first to discover valid plugin IDs.",
|
||||
)
|
||||
|
||||
|
||||
class QueryPluginConfigTool(MoviePilotTool):
|
||||
name: str = "query_plugin_config"
|
||||
description: str = (
|
||||
"Query the saved configuration of an installed plugin. "
|
||||
"Returns the current saved config and, when available, the plugin's default config model. "
|
||||
"Use this before update_plugin_config so you only change the intended keys."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginConfigInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
return f"查询插件配置: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
def _query_plugin_config(plugin_id: str) -> str:
|
||||
"""
|
||||
读取插件已保存配置,并尽量补充默认配置模型方便后续精确修改。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
saved_config = plugin_manager.get_plugin_config(plugin_id) or {}
|
||||
result = {
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"config": saved_config,
|
||||
}
|
||||
|
||||
# get_form 的 model 通常就是插件期望的配置结构,适合作为修改前的键参考。
|
||||
plugin_instance = plugin_manager.running_plugins.get(plugin_id)
|
||||
if plugin_instance and hasattr(plugin_instance, "get_form"):
|
||||
try:
|
||||
_form_schema, default_model = plugin_instance.get_form()
|
||||
if default_model is not None:
|
||||
result["default_model"] = default_model
|
||||
except Exception as err:
|
||||
logger.warning(f"读取插件 {plugin_id} 默认配置模型失败: {err}")
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
async def run(self, plugin_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
try:
|
||||
# 插件配置来自内存配置缓存和运行态插件实例,直接读取即可。
|
||||
return self._query_plugin_config(plugin_id)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件配置失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询插件配置时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
158
app/agent/tools/impl/query_plugin_data.py
Normal file
158
app/agent/tools/impl/query_plugin_data.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""查询插件数据工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
PLUGIN_DATA_KEY_PREVIEW_LIMIT,
|
||||
build_preview_payload,
|
||||
get_plugin_snapshot,
|
||||
)
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPluginDataInput(BaseModel):
|
||||
"""查询插件数据工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to query. Use query_installed_plugins first to discover valid plugin IDs.",
|
||||
)
|
||||
key: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional plugin data key. If omitted, returns all plugin data entries for the plugin.",
|
||||
)
|
||||
max_chars: Optional[int] = Field(
|
||||
None,
|
||||
description="Maximum number of preview characters to return when plugin data is too large. Default 12000, capped at 50000.",
|
||||
)
|
||||
|
||||
|
||||
class QueryPluginDataTool(MoviePilotTool):
|
||||
name: str = "query_plugin_data"
|
||||
description: str = (
|
||||
"Query persisted data of an installed plugin. "
|
||||
"Optionally specify a key to read a single data item; otherwise all plugin data entries are returned. "
|
||||
"When the result is too large, the tool automatically truncates it and returns a preview instead."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginDataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
key = kwargs.get("key")
|
||||
if key:
|
||||
return f"查询插件数据: {plugin_id}.{key}"
|
||||
return f"查询插件全部数据: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
async def _query_plugin_data(
|
||||
plugin_id: str, key: Optional[str] = None, max_chars: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
插件数据改走异步 ORM 查询,避免再套一层线程池。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
plugin_data_oper = PluginDataOper()
|
||||
if key:
|
||||
value = await plugin_data_oper.async_get_data(plugin_id, key)
|
||||
if value is None:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"key": key,
|
||||
"found": False,
|
||||
"message": f"插件 {plugin_id} 没有数据项 {key}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
truncated, total_chars, returned_chars, preview = build_preview_payload(
|
||||
value, max_chars
|
||||
)
|
||||
result = {
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"key": key,
|
||||
"found": True,
|
||||
"truncated": truncated,
|
||||
"total_chars": total_chars,
|
||||
"returned_chars": returned_chars,
|
||||
}
|
||||
if truncated:
|
||||
result["value_preview"] = preview
|
||||
result["message"] = "插件数据内容过大,已截断预览"
|
||||
else:
|
||||
result["value"] = value
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
rows = await plugin_data_oper.async_get_data_all(plugin_id) or []
|
||||
data_map = {row.key: row.value for row in rows}
|
||||
keys = list(data_map.keys())
|
||||
key_preview = keys[:PLUGIN_DATA_KEY_PREVIEW_LIMIT]
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"count": len(data_map),
|
||||
"keys": key_preview,
|
||||
"keys_truncated": len(keys) > PLUGIN_DATA_KEY_PREVIEW_LIMIT,
|
||||
}
|
||||
|
||||
if not data_map:
|
||||
result["data"] = {}
|
||||
result["truncated"] = False
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
truncated, total_chars, returned_chars, preview = build_preview_payload(
|
||||
data_map, max_chars
|
||||
)
|
||||
result["truncated"] = truncated
|
||||
result["total_chars"] = total_chars
|
||||
result["returned_chars"] = returned_chars
|
||||
if truncated:
|
||||
result["data_preview"] = preview
|
||||
result["message"] = "插件数据内容过大,已截断。请传入 key 精确查询单个数据项。"
|
||||
else:
|
||||
result["data"] = data_map
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
key: Optional[str] = None,
|
||||
max_chars: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, key={key}"
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._query_plugin_data(plugin_id, key, max_chars)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件数据失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询插件数据时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -12,13 +12,15 @@ from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
MAX_PAGE_SIZE = 50
|
||||
|
||||
|
||||
class QueryPopularSubscribesInput(BaseModel):
|
||||
"""查询热门订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)")
|
||||
min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)")
|
||||
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
|
||||
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
|
||||
@@ -69,6 +71,8 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
# 外部统计接口支持传入 count,这里做硬上限,避免 Agent 一次拉取过多结果。
|
||||
count = min(count, MAX_PAGE_SIZE)
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
@@ -1,63 +1,104 @@
|
||||
"""查询规则组工具"""
|
||||
"""查询过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_rule_group_usages,
|
||||
get_rule_groups,
|
||||
serialize_rule_group,
|
||||
RULE_STRING_SYNTAX,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
group_names: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of rule group names to query. If omitted, return all rule groups.",
|
||||
)
|
||||
include_usage: bool = Field(
|
||||
True,
|
||||
description="Whether to include where each rule group is referenced by global settings or subscriptions.",
|
||||
)
|
||||
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
description: str = "Query all filter rule groups available in the system. Rule groups are used to filter torrents when searching or subscribing. Returns rule group names, media types, and categories, but excludes rule_string to keep results concise."
|
||||
description: str = (
|
||||
"Query filter rule groups (过滤规则组 / 优先级规则组). "
|
||||
"Each rule group contains a rule_string made of built-in rules and/or custom rules. "
|
||||
"Inside one level use '&', '|', '!' and optional parentheses; use '>' between levels. "
|
||||
"Levels are evaluated from left to right, and the first matched level wins. "
|
||||
"The result includes parsed levels and syntax guidance so the agent can learn existing patterns before writing a new rule group."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryRuleGroupsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
group_names = kwargs.get("group_names") or []
|
||||
if group_names:
|
||||
return f"查询规则组: {', '.join(group_names)}"
|
||||
return "查询所有规则组"
|
||||
|
||||
@staticmethod
|
||||
def _load_rule_groups() -> dict:
|
||||
"""从内存配置缓存中读取规则组。"""
|
||||
rule_groups = RuleHelper().get_rule_groups()
|
||||
if not rule_groups:
|
||||
return {
|
||||
"message": "未找到任何规则组",
|
||||
"rule_groups": [],
|
||||
}
|
||||
|
||||
simplified_groups = [
|
||||
{
|
||||
"name": group.name,
|
||||
"media_type": group.media_type,
|
||||
"category": group.category,
|
||||
}
|
||||
for group in rule_groups
|
||||
]
|
||||
return {
|
||||
"message": f"找到 {len(simplified_groups)} 个规则组",
|
||||
"rule_groups": simplified_groups,
|
||||
}
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
group_names: Optional[List[str]] = None,
|
||||
include_usage: bool = True,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
result = self._load_rule_groups()
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
error_message = f"查询规则组失败: {str(e)}"
|
||||
logger.error(f"查询规则组失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
rule_groups = get_rule_groups()
|
||||
if group_names:
|
||||
target_names = set(group_names)
|
||||
rule_groups = [
|
||||
group for group in rule_groups if group.name in target_names
|
||||
]
|
||||
|
||||
usage_map = {}
|
||||
if include_usage:
|
||||
usage_map = await collect_rule_group_usages(
|
||||
[group.name for group in rule_groups if group.name]
|
||||
)
|
||||
|
||||
serialized = [
|
||||
serialize_rule_group(group, usage_map.get(group.name))
|
||||
for group in rule_groups
|
||||
]
|
||||
message = (
|
||||
f"找到 {len(serialized)} 个规则组"
|
||||
if serialized
|
||||
else "未找到任何规则组"
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"count": len(serialized),
|
||||
"rule_string_syntax": RULE_STRING_SYNTAX,
|
||||
"rule_groups": serialized,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"查询规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询规则组失败: {exc}",
|
||||
"rule_groups": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class QuerySchedulersInput(BaseModel):
|
||||
@@ -27,6 +26,8 @@ class QuerySchedulersTool(MoviePilotTool):
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
scheduler = Scheduler()
|
||||
schedulers = scheduler.list()
|
||||
if schedulers:
|
||||
|
||||
@@ -11,6 +11,14 @@ from app.db.models.site import Site
|
||||
from app.db.models.siteuserdata import SiteUserData
|
||||
from app.log import logger
|
||||
|
||||
SITE_USERDATA_DETAIL_PREVIEW_LIMIT = 10
|
||||
|
||||
|
||||
def _preview_list(value, limit: int = SITE_USERDATA_DETAIL_PREVIEW_LIMIT) -> tuple[list, int, bool]:
|
||||
"""返回列表字段预览,避免做种明细或未读消息一次性撑大工具结果。"""
|
||||
items = list(value) if isinstance(value, (list, tuple)) else []
|
||||
return items[:limit], len(items), len(items) > limit
|
||||
|
||||
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
@@ -110,6 +118,13 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
else 0
|
||||
)
|
||||
|
||||
seeding_preview, seeding_count, seeding_truncated = _preview_list(
|
||||
user_data.seeding_info
|
||||
)
|
||||
unread_preview, unread_count, unread_truncated = _preview_list(
|
||||
user_data.message_unread_contents
|
||||
)
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
@@ -131,13 +146,13 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info": user_data.seeding_info
|
||||
if user_data.seeding_info
|
||||
else [],
|
||||
"seeding_info_count": seeding_count,
|
||||
"seeding_info": seeding_preview,
|
||||
"seeding_info_truncated": seeding_truncated,
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents": user_data.message_unread_contents
|
||||
if user_data.message_unread_contents
|
||||
else [],
|
||||
"message_unread_contents_count": unread_count,
|
||||
"message_unread_contents": unread_preview,
|
||||
"message_unread_contents_truncated": unread_truncated,
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time,
|
||||
|
||||
@@ -9,13 +9,15 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
|
||||
MAX_PAGE_SIZE = 50
|
||||
|
||||
|
||||
class QuerySubscribeSharesInput(BaseModel):
|
||||
"""查询订阅分享工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
name: Optional[str] = Field(None, description="Filter shares by media name (partial match, optional)")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30, max: 50)")
|
||||
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
|
||||
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
|
||||
max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)")
|
||||
@@ -63,6 +65,8 @@ class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
# 订阅分享是外部列表型结果,限制单页大小能降低工具上下文占用。
|
||||
count = min(count, MAX_PAGE_SIZE)
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
shares = await subscribe_helper.async_get_shares(
|
||||
|
||||
@@ -33,6 +33,9 @@ QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"sites",
|
||||
"downloader",
|
||||
"best_version",
|
||||
"best_version_full",
|
||||
"current_priority",
|
||||
"episode_priority",
|
||||
"save_path",
|
||||
"custom_words",
|
||||
"media_category",
|
||||
|
||||
186
app/agent/tools/impl/query_system_settings.py
Normal file
186
app/agent/tools/impl/query_system_settings.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""统一查询系统设置工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._system_setting_utils import (
|
||||
SettingSpec,
|
||||
list_setting_specs,
|
||||
resolve_setting_spec,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySystemSettingsInput(BaseModel):
|
||||
"""查询系统设置工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
setting_key: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Exact setting key to query. Supports Settings field names like 'APP_DOMAIN' or 'TMDB_API_KEY', "
|
||||
"SystemConfigKey values like 'Downloaders' or 'MediaServers', enum names, and some single-key aliases "
|
||||
"such as 'downloaders', 'directories', 'search_sites', 'subscribe_sites', 'site_auth', 'ai_agent', "
|
||||
"and 'custom_identifiers'."
|
||||
),
|
||||
)
|
||||
group: Optional[str] = Field(
|
||||
"all",
|
||||
description=(
|
||||
"Optional group filter when setting_key is not provided. Supports 'all', 'settings', 'systemconfig', "
|
||||
"and category aliases such as 'downloaders', 'media_servers', 'notifications', 'notification_switches', "
|
||||
"'storages', 'directories', 'search_sites', 'subscribe_sites', 'site_auth', 'ai_agent', 'filter_rules', "
|
||||
"'subscribe_defaults', 'plugins', and 'custom_identifiers'. Chinese aliases are also accepted."
|
||||
),
|
||||
)
|
||||
keyword: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional keyword used to fuzzy match setting keys, group names, or labels when listing settings."
|
||||
),
|
||||
)
|
||||
include_values: Optional[bool] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Whether to include full setting values. Default behavior: when a single setting is matched it returns the full value; "
|
||||
"when multiple settings are matched it returns summaries only unless this is explicitly set to true."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class QuerySystemSettingsTool(MoviePilotTool):
|
||||
name: str = "query_system_settings"
|
||||
description: str = (
|
||||
"Query system settings across both the basic Settings module and all SystemConfig-backed categories. "
|
||||
"Use this tool to inspect downloaders, media servers, notification channels, storages, directories, search-site ranges, "
|
||||
"subscribe-site ranges, site auth params, AI agent config, and any other system setting before making changes."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySystemSettingsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息。"""
|
||||
|
||||
setting_key = kwargs.get("setting_key")
|
||||
group = kwargs.get("group", "all")
|
||||
keyword = kwargs.get("keyword")
|
||||
if setting_key:
|
||||
return f"查询系统设置: {setting_key}"
|
||||
if keyword:
|
||||
return f"筛选系统设置: {group} / {keyword}"
|
||||
return f"查询系统设置分组: {group}"
|
||||
|
||||
@staticmethod
|
||||
def _load_setting_value(spec: SettingSpec):
|
||||
if spec.source == "settings":
|
||||
return getattr(settings, spec.key)
|
||||
return SystemConfigOper().get(spec.key)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_value(value) -> dict:
|
||||
summary = {
|
||||
"has_value": value is not None,
|
||||
"value_type": type(value).__name__,
|
||||
}
|
||||
if isinstance(value, list):
|
||||
summary["item_count"] = len(value)
|
||||
if value:
|
||||
summary["item_type"] = type(value[0]).__name__
|
||||
elif isinstance(value, dict):
|
||||
keys = list(value.keys())
|
||||
summary["item_count"] = len(keys)
|
||||
summary["keys_preview"] = keys[:10]
|
||||
if len(keys) > 10:
|
||||
summary["keys_truncated"] = True
|
||||
elif isinstance(value, str):
|
||||
summary["length"] = len(value)
|
||||
preview = value[:200]
|
||||
if preview:
|
||||
summary["value_preview"] = preview
|
||||
if len(value) > len(preview):
|
||||
summary["value_truncated"] = True
|
||||
elif value is not None:
|
||||
summary["value_preview"] = value
|
||||
return summary
|
||||
|
||||
async def run(
|
||||
self,
|
||||
setting_key: Optional[str] = None,
|
||||
group: Optional[str] = "all",
|
||||
keyword: Optional[str] = None,
|
||||
include_values: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
"执行工具: %s, setting_key=%s, group=%s, keyword=%s",
|
||||
self.name,
|
||||
setting_key,
|
||||
group,
|
||||
keyword,
|
||||
)
|
||||
|
||||
try:
|
||||
if setting_key:
|
||||
spec = resolve_setting_spec(setting_key)
|
||||
if not spec:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"系统设置项 '{setting_key}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
specs = [spec]
|
||||
else:
|
||||
specs = list_setting_specs(group=group, keyword=keyword)
|
||||
if not specs:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": "没有找到匹配的系统设置项",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
should_include_values = (
|
||||
include_values if include_values is not None else len(specs) == 1
|
||||
)
|
||||
settings_payload = []
|
||||
for spec in specs:
|
||||
value = self._load_setting_value(spec)
|
||||
item = {
|
||||
"setting_key": spec.key,
|
||||
"source": spec.source,
|
||||
"group": spec.group,
|
||||
"label": spec.label,
|
||||
}
|
||||
item.update(self._summarize_value(value))
|
||||
if should_include_values:
|
||||
item["value"] = value
|
||||
settings_payload.append(item)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"matched_count": len(settings_payload),
|
||||
"include_values": should_include_values,
|
||||
"settings": settings_payload,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询系统设置失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询系统设置时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -62,8 +62,8 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
|
||||
# 每页记录数
|
||||
count = 50
|
||||
# 每页固定 30 条,与工具说明保持一致,避免整理路径等字段撑大上下文。
|
||||
count = 30
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
|
||||
@@ -115,9 +115,7 @@ class QueryWorkflowsTool(MoviePilotTool):
|
||||
"last_time": wf.last_time,
|
||||
"current_action": wf.current_action
|
||||
}
|
||||
# 如果有结果,添加结果信息
|
||||
if wf.result:
|
||||
simplified["result"] = wf.result
|
||||
# wf.result 往往是执行日志或上下文快照,不适合作为列表查询结果返回。
|
||||
simplified_workflows.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
|
||||
|
||||
@@ -49,8 +49,7 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
context = None
|
||||
|
||||
|
||||
# 根据提供的参数选择识别方式
|
||||
if path:
|
||||
# 文件路径识别
|
||||
@@ -60,7 +59,10 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"message": "文件路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
context = await media_chain.async_recognize_by_path(path)
|
||||
context = await media_chain.async_recognize_by_path(
|
||||
path,
|
||||
obtain_images=False,
|
||||
)
|
||||
if context:
|
||||
return self._format_context_result(context, "文件")
|
||||
else:
|
||||
@@ -73,7 +75,10 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
elif title:
|
||||
# 种子标题识别
|
||||
metainfo = MetaInfo(title, subtitle)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(metainfo)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(
|
||||
metainfo,
|
||||
obtain_images=False,
|
||||
)
|
||||
if mediainfo:
|
||||
context = Context(meta_info=metainfo, media_info=mediainfo)
|
||||
return self._format_context_result(context, "种子")
|
||||
|
||||
84
app/agent/tools/impl/reload_plugin.py
Normal file
84
app/agent/tools/impl/reload_plugin.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""重载插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
get_plugin_snapshot,
|
||||
reload_plugin_runtime,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ReloadPluginInput(BaseModel):
|
||||
"""重载插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to reload so the latest saved config takes effect.",
|
||||
)
|
||||
|
||||
|
||||
class ReloadPluginTool(MoviePilotTool):
|
||||
name: str = "reload_plugin"
|
||||
description: str = (
|
||||
"Reload an installed plugin so its latest saved configuration takes effect. "
|
||||
"This also refreshes the plugin's registered commands, scheduled services, and API routes."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = ReloadPluginInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
return f"重载插件: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
def _reload_plugin_sync(plugin_id: str) -> str:
|
||||
"""
|
||||
按后台接口同样的流程重载插件,确保最新配置和注册信息一起刷新。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
reload_plugin_runtime(plugin_id)
|
||||
refreshed_plugin = get_plugin_snapshot(plugin_id) or plugin_info
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
**refreshed_plugin,
|
||||
"message": "插件已重载,最新配置已生效",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
|
||||
async def run(self, plugin_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
try:
|
||||
return await self.run_blocking(
|
||||
"plugin", self._reload_plugin_sync, plugin_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"重载插件失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"重载插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
@@ -36,6 +35,8 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
@staticmethod
|
||||
def _run_scheduler_sync(job_id: str) -> tuple[bool, str]:
|
||||
"""同步触发定时服务,避免调度器扫描阻塞事件循环。"""
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
scheduler = Scheduler()
|
||||
for scheduler_item in scheduler.list():
|
||||
if scheduler_item.id == job_id:
|
||||
|
||||
@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem
|
||||
|
||||
@@ -80,8 +79,7 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
if not scrape_path.exists():
|
||||
if not Path(path).exists():
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"刮削路径不存在: {path}"},
|
||||
ensure_ascii=False,
|
||||
@@ -89,11 +87,12 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
scrape_path = Path(path)
|
||||
meta = MetaInfoPath(scrape_path)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
context = await media_chain.async_recognize_by_path(
|
||||
path,
|
||||
obtain_images=True,
|
||||
)
|
||||
|
||||
if not mediainfo:
|
||||
if not context or not context.media_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
@@ -108,8 +107,8 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"storage",
|
||||
media_chain.scrape_metadata,
|
||||
fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
meta=context.meta_info,
|
||||
mediainfo=context.media_info,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
@@ -119,11 +118,11 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season,
|
||||
"title": context.media_info.title,
|
||||
"year": context.media_info.year,
|
||||
"type": context.media_info.type.value if context.media_info.type else None,
|
||||
"tmdb_id": context.media_info.tmdb_id,
|
||||
"season": context.media_info.season,
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
|
||||
@@ -73,7 +73,7 @@ class SearchMediaTool(MoviePilotTool):
|
||||
filtered_results.append(result)
|
||||
|
||||
if filtered_results:
|
||||
# 限制最多30条结果
|
||||
# 搜索结果只返回前 30 条,后续可通过更精确的年份/类型条件缩小范围。
|
||||
total_count = len(filtered_results)
|
||||
limited_results = filtered_results[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
@@ -96,8 +96,8 @@ class SearchMediaTool(MoviePilotTool):
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 100:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}"
|
||||
if total_count > len(limited_results):
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 {len(limited_results)} 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
|
||||
@@ -35,7 +35,7 @@ class SearchPersonTool(MoviePilotTool):
|
||||
persons = await media_chain.async_search_persons(name=name)
|
||||
|
||||
if persons:
|
||||
# 限制最多30条结果
|
||||
# 人物搜索结果只返回前 30 条,避免 biography/别名等字段挤占上下文。
|
||||
total_count = len(persons)
|
||||
limited_persons = persons[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
@@ -72,8 +72,8 @@ class SearchPersonTool(MoviePilotTool):
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
if total_count > len(limited_persons):
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 {len(limited_persons)} 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关人物信息: {name}"
|
||||
|
||||
@@ -28,7 +28,7 @@ class SearchWebInput(BaseModel):
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
20,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)",
|
||||
description="Maximum number of search results to return (default: 20, max: 20)",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class SendLocalFileInput(BaseModel):
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
|
||||
@@ -37,6 +37,7 @@ class SendMessageInput(BaseModel):
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
sends_message: bool = True
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -5,13 +5,11 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendVoiceMessageInput(BaseModel):
|
||||
@@ -29,10 +27,12 @@ class SendVoiceMessageInput(BaseModel):
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
|
||||
name: str = "send_voice_message"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a voice reply to the current user. Prefer this when the user sent a voice message "
|
||||
"or when spoken playback is more natural. On channels without voice support or when TTS "
|
||||
"is unavailable, it automatically falls back to sending the same content as plain text."
|
||||
"Send a voice reply to the current user. Use this only when the user explicitly asks for "
|
||||
"a voice reply or when spoken playback is clearly better than plain text. On channels "
|
||||
"without voice support or when TTS is unavailable, it automatically falls back to sending "
|
||||
"the same content as plain text."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendVoiceMessageInput
|
||||
require_admin: bool = False
|
||||
@@ -43,18 +43,6 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
message = message[:40] + "..."
|
||||
return f"发送语音回复: {message}"
|
||||
|
||||
def _supports_real_voice_reply(self) -> bool:
|
||||
channel = self._channel or ""
|
||||
if channel == MessageChannel.Telegram.value:
|
||||
return True
|
||||
if channel != MessageChannel.Wechat.value:
|
||||
return False
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != self._source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
async def run(self, message: str, **kwargs) -> str:
|
||||
if not message:
|
||||
return "语音回复内容不能为空"
|
||||
@@ -62,11 +50,25 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
reply_mode = AgentCapabilityManager.resolve_reply_mode(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
)
|
||||
fallback_reason = "当前渠道不支持语音回复"
|
||||
if not AgentCapabilityManager.supports_audio_output():
|
||||
fallback_reason = "当前未启用音频输出"
|
||||
if (
|
||||
reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE
|
||||
and AgentCapabilityManager.is_audio_output_available()
|
||||
):
|
||||
voice_file = await asyncio.to_thread(
|
||||
AgentCapabilityManager.synthesize_speech, message
|
||||
)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
elif reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE:
|
||||
fallback_reason = "当前未配置可用的语音合成能力"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
|
||||
@@ -85,7 +87,11 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
username=self._username,
|
||||
text=message,
|
||||
voice_path=voice_path,
|
||||
voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
|
||||
voice_caption=(
|
||||
message
|
||||
if voice_path and settings.AUDIO_OUTPUT_INCLUDE_TEXT
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
@@ -93,4 +99,4 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
|
||||
if used_voice:
|
||||
return "语音回复已发送"
|
||||
return "当前未使用语音通道,已自动回退为文字回复"
|
||||
return f"{fallback_reason},已自动回退为文字回复"
|
||||
|
||||
62
app/agent/tools/impl/switch_persona.py
Normal file
62
app/agent/tools/impl/switch_persona.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""切换当前激活人格工具。"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SwitchPersonaInput(BaseModel):
|
||||
"""切换人格工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
persona_id: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"The target persona to activate. This can be the exact persona_id, label, or one of the persona aliases. "
|
||||
"If the exact persona is unclear, call query_personas first."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SwitchPersonaTool(MoviePilotTool):
|
||||
name: str = "switch_persona"
|
||||
description: str = (
|
||||
"Switch the active persona (人格) used by the agent runtime. "
|
||||
"This change is persistent for future turns. "
|
||||
"Use this when the user explicitly asks to change the speaking style, tone, or response persona. "
|
||||
"If the user asks for a vague style and you are not sure which persona matches best, call query_personas first."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SwitchPersonaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> str:
|
||||
persona_id = kwargs.get("persona_id") or "未知人格"
|
||||
return f"切换人格: {persona_id}"
|
||||
|
||||
async def run(self, persona_id: str, **kwargs) -> str:
|
||||
logger.info("执行工具: %s, 参数: persona_id=%s", self.name, persona_id)
|
||||
try:
|
||||
runtime_config = agent_runtime_manager.set_active_persona(persona_id)
|
||||
payload = {
|
||||
"success": True,
|
||||
"active_persona": runtime_config.active_persona,
|
||||
"persona": runtime_config.persona.to_dict(is_active=True),
|
||||
"message": f"已切换为人格 `{runtime_config.active_persona}`",
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("切换人格失败: %s", e, exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"切换人格时发生错误: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -6,7 +6,6 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem, MediaType
|
||||
|
||||
@@ -124,6 +123,8 @@ class TransferFileTool(MoviePilotTool):
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
from app.chain.transfer import TransferChain
|
||||
|
||||
state, errormsg = TransferChain().manual_transfer(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
|
||||
84
app/agent/tools/impl/uninstall_plugin.py
Normal file
84
app/agent/tools/impl/uninstall_plugin.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""卸载插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
list_installed_plugins,
|
||||
summarize_plugin,
|
||||
uninstall_plugin_runtime,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UninstallPluginInput(BaseModel):
|
||||
"""卸载插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="Exact plugin ID to uninstall. Use query_installed_plugins first to find the correct plugin_id.",
|
||||
)
|
||||
|
||||
|
||||
class UninstallPluginTool(MoviePilotTool):
|
||||
name: str = "uninstall_plugin"
|
||||
description: str = (
|
||||
"Uninstall an installed plugin by exact plugin_id. "
|
||||
"Use query_installed_plugins first when you need filtering or discovery."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UninstallPluginInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
return f"卸载插件: {plugin_id or '未知插件'}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
try:
|
||||
plugins = list_installed_plugins()
|
||||
if not plugins:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前没有已安装的插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
candidate = next((plugin for plugin in plugins if plugin.id == plugin_id), None)
|
||||
if not candidate:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"未找到已安装插件: {plugin_id}。请先调用 query_installed_plugins 确认 plugin_id。",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
cleanup_result = await uninstall_plugin_runtime(candidate.id)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"插件 {candidate.id} 已卸载",
|
||||
"plugin": summarize_plugin(candidate),
|
||||
**cleanup_result,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"卸载插件失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"卸载插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
190
app/agent/tools/impl/update_custom_filter_rule.py
Normal file
190
app/agent/tools/impl/update_custom_filter_rule.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""更新自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
normalize_custom_rule,
|
||||
replace_rule_id_in_rule_string,
|
||||
save_system_config,
|
||||
serialize_custom_rule,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateCustomFilterRuleInput(BaseModel):
|
||||
"""更新自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
current_rule_id: str = Field(
|
||||
..., description="Existing custom rule ID to update."
|
||||
)
|
||||
new_rule_id: Optional[str] = Field(
|
||||
None,
|
||||
description="New rule ID. If omitted, keep the original rule ID.",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="New display name. If omitted, keep the original name."
|
||||
)
|
||||
include: Optional[str] = Field(
|
||||
None,
|
||||
description="New include regex. Pass an empty string to clear it.",
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None,
|
||||
description="New exclude regex. Pass an empty string to clear it.",
|
||||
)
|
||||
size_range: Optional[str] = Field(
|
||||
None,
|
||||
description="New size range in MB. Pass an empty string to clear it.",
|
||||
)
|
||||
seeders: Optional[str] = Field(
|
||||
None,
|
||||
description="New minimum seeder count. Pass an empty string to clear it.",
|
||||
)
|
||||
publish_time: Optional[str] = Field(
|
||||
None,
|
||||
description="New publish-time filter in minutes. Pass an empty string to clear it.",
|
||||
)
|
||||
|
||||
|
||||
class UpdateCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "update_custom_filter_rule"
|
||||
description: str = (
|
||||
"Update an existing custom filter rule. "
|
||||
"If the rule ID is renamed, all rule groups that reference the old ID are updated automatically."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateCustomFilterRuleInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
current_rule_id = kwargs.get("current_rule_id", "")
|
||||
new_rule_id = kwargs.get("new_rule_id")
|
||||
if new_rule_id and new_rule_id != current_rule_id:
|
||||
return f"更新自定义过滤规则 {current_rule_id} -> {new_rule_id}"
|
||||
return f"更新自定义过滤规则 {current_rule_id}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
current_rule_id: str,
|
||||
new_rule_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
size_range: Optional[str] = None,
|
||||
seeders: Optional[str] = None,
|
||||
publish_time: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, current_rule_id={current_rule_id}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
rule_map = {rule.id: rule for rule in custom_rules if rule.id}
|
||||
current_rule = rule_map.get(current_rule_id)
|
||||
if not current_rule:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"自定义过滤规则 '{current_rule_id}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
updated_rule = normalize_custom_rule(
|
||||
rule_id=new_rule_id or current_rule.id,
|
||||
name=name if name is not None else current_rule.name,
|
||||
include=include if include is not None else current_rule.include,
|
||||
exclude=exclude if exclude is not None else current_rule.exclude,
|
||||
size_range=(
|
||||
size_range if size_range is not None else current_rule.size_range
|
||||
),
|
||||
seeders=seeders if seeders is not None else current_rule.seeders,
|
||||
publish_time=(
|
||||
publish_time
|
||||
if publish_time is not None
|
||||
else current_rule.publish_time
|
||||
),
|
||||
existing_rules=custom_rules,
|
||||
original_rule_id=current_rule.id,
|
||||
)
|
||||
|
||||
rule_groups = get_rule_groups()
|
||||
updated_rule_groups = rule_groups
|
||||
renamed_group_refs = []
|
||||
if updated_rule.id != current_rule.id:
|
||||
updated_rule_groups = []
|
||||
for group in rule_groups:
|
||||
if not group.rule_string:
|
||||
updated_rule_groups.append(group)
|
||||
continue
|
||||
new_rule_string = replace_rule_id_in_rule_string(
|
||||
group.rule_string,
|
||||
current_rule.id,
|
||||
updated_rule.id,
|
||||
)
|
||||
if new_rule_string == group.rule_string:
|
||||
updated_rule_groups.append(group)
|
||||
continue
|
||||
renamed_group_refs.append(group.name)
|
||||
updated_rule_groups.append(
|
||||
group.model_copy(update={"rule_string": new_rule_string})
|
||||
)
|
||||
|
||||
# 先保存规则组引用,再保存规则自身,避免在过滤模块重载时出现新规则 ID 尚未同步的问题。
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[
|
||||
group.model_dump(exclude_none=True)
|
||||
for group in updated_rule_groups
|
||||
],
|
||||
)
|
||||
|
||||
final_rules = []
|
||||
for rule in custom_rules:
|
||||
if rule.id == current_rule.id:
|
||||
final_rules.append(updated_rule)
|
||||
else:
|
||||
final_rules.append(rule)
|
||||
|
||||
await save_system_config(
|
||||
SystemConfigKey.CustomFilterRules,
|
||||
[rule.model_dump(exclude_none=True) for rule in final_rules],
|
||||
)
|
||||
|
||||
updated_refs = collect_custom_rule_group_refs(
|
||||
updated_rule_groups,
|
||||
[updated_rule.id],
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已更新自定义过滤规则 {updated_rule.id}",
|
||||
"custom_rule": serialize_custom_rule(
|
||||
updated_rule,
|
||||
updated_refs.get(updated_rule.id),
|
||||
),
|
||||
"rule_groups_updated_for_rule_id_rename": renamed_group_refs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"更新自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"更新自定义过滤规则失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -52,6 +52,7 @@ class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
"Lines starting with '#' are comments. "
|
||||
"The replacement target supports: {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]} for direct TMDB ID matching."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UpdateCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
131
app/agent/tools/impl/update_persona_definition.py
Normal file
131
app/agent/tools/impl/update_persona_definition.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""更新人格定义工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UpdatePersonaDefinitionInput(BaseModel):
|
||||
"""更新人格定义工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
persona_id: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Target persona to update. For existing personas this can be persona_id, label, or alias. "
|
||||
"For new personas, provide the new lowercase persona_id."
|
||||
),
|
||||
)
|
||||
label: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional new label shown to users, such as 默认 or 说明型.",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short description of the persona's intended style.",
|
||||
)
|
||||
aliases: Optional[list[str]] = Field(
|
||||
None,
|
||||
description="Optional full replacement list of aliases for this persona.",
|
||||
)
|
||||
instructions: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional full replacement body for PERSONA.md, excluding YAML frontmatter. "
|
||||
"Use this when the persona definition should be rewritten completely."
|
||||
),
|
||||
)
|
||||
append_instructions: Optional[list[str]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional extra persona rules to append to the existing PERSONA body. "
|
||||
"Use this for small adjustments such as '回答更短' or '复杂问题给两步解释'."
|
||||
),
|
||||
)
|
||||
create_if_missing: bool = Field(
|
||||
False,
|
||||
description="Whether to create a new runtime persona if the target persona does not already exist.",
|
||||
)
|
||||
|
||||
|
||||
class UpdatePersonaDefinitionTool(MoviePilotTool):
|
||||
name: str = "update_persona_definition"
|
||||
description: str = (
|
||||
"Create or update a runtime persona definition (人格定义) without manually editing PERSONA.md files. "
|
||||
"Use this when the user explicitly asks to modify how a persona is defined, such as changing tone rules, "
|
||||
"rewriting the persona body, adjusting aliases, or creating a new persona."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdatePersonaDefinitionInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> str:
|
||||
persona_id = kwargs.get("persona_id") or "未知人格"
|
||||
action = "创建/更新人格定义"
|
||||
return f"{action}: {persona_id}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
persona_id: str,
|
||||
label: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
aliases: Optional[list[str]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
append_instructions: Optional[list[str]] = None,
|
||||
create_if_missing: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info("执行工具: %s, 参数: persona_id=%s", self.name, persona_id)
|
||||
if not any(
|
||||
value is not None
|
||||
for value in (label, description, aliases, instructions, append_instructions)
|
||||
):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": "未提供任何要更新的人格定义字段。",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
try:
|
||||
persona, created = agent_runtime_manager.update_persona_definition(
|
||||
persona_id,
|
||||
label=label,
|
||||
description=description,
|
||||
aliases=aliases,
|
||||
instructions=instructions,
|
||||
append_instructions=append_instructions,
|
||||
create_if_missing=create_if_missing,
|
||||
)
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
payload = {
|
||||
"success": True,
|
||||
"created": created,
|
||||
"active_persona": runtime_config.active_persona,
|
||||
"persona": persona.to_dict(
|
||||
is_active=persona.persona_id == runtime_config.active_persona
|
||||
),
|
||||
"message": (
|
||||
f"已创建人格 `{persona.persona_id}`"
|
||||
if created
|
||||
else f"已更新人格 `{persona.persona_id}` 的定义"
|
||||
),
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("更新人格定义失败: %s", e, exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"更新人格定义时发生错误: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
153
app/agent/tools/impl/update_plugin_config.py
Normal file
153
app/agent/tools/impl/update_plugin_config.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""修改插件配置工具"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import get_plugin_snapshot
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UpdatePluginConfigInput(BaseModel):
|
||||
"""修改插件配置工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to update. Use query_plugin_config first to inspect the current config.",
|
||||
)
|
||||
updates: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Config items to save. By default this tool merges these keys into the existing config "
|
||||
"instead of replacing the whole config."
|
||||
),
|
||||
)
|
||||
remove_keys: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional config keys to remove from the saved plugin config.",
|
||||
)
|
||||
replace: Optional[bool] = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to replace the entire saved config with 'updates'. "
|
||||
"Default false, which performs a partial merge update."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdatePluginConfigTool(MoviePilotTool):
|
||||
name: str = "update_plugin_config"
|
||||
description: str = (
|
||||
"Update the saved configuration of an installed plugin. "
|
||||
"By default this performs a partial merge update and does NOT reload the plugin automatically. "
|
||||
"Call reload_plugin afterwards to apply the latest saved config to the running plugin."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UpdatePluginConfigInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
replace = kwargs.get("replace", False)
|
||||
action = "覆盖插件配置" if replace else "修改插件配置"
|
||||
return f"{action}: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
async def _update_plugin_config(
|
||||
plugin_id: str,
|
||||
updates: Optional[Dict[str, Any]] = None,
|
||||
remove_keys: Optional[List[str]] = None,
|
||||
replace: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
仅异步保存插件配置,不主动生效,让 Agent 可以先批量改完再显式重载插件。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
remove_keys = remove_keys or []
|
||||
if not replace and not updates and not remove_keys:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供任何需要修改的配置项"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
current_config = dict(plugin_manager.get_plugin_config(plugin_id) or {})
|
||||
|
||||
# merge 模式以当前保存值为基准,replace 模式则从空配置开始重建。
|
||||
next_config = {} if replace else dict(current_config)
|
||||
if updates:
|
||||
next_config.update(updates)
|
||||
for key in remove_keys:
|
||||
next_config.pop(key, None)
|
||||
|
||||
changed_keys = sorted(
|
||||
key
|
||||
for key in set(current_config.keys()) | set(next_config.keys())
|
||||
if current_config.get(key) != next_config.get(key)
|
||||
or (key in current_config) != (key in next_config)
|
||||
)
|
||||
|
||||
if not await plugin_manager.async_save_plugin_config(plugin_id, next_config):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"保存插件 {plugin_id} 配置失败",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"message": "插件配置已保存,请调用 reload_plugin 使最新配置生效",
|
||||
"replace": replace,
|
||||
"changed_keys": changed_keys,
|
||||
"removed_keys": remove_keys,
|
||||
"config_requires_reload": True,
|
||||
"previous_config": current_config,
|
||||
"saved_config": next_config,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
updates: Optional[Dict[str, Any]] = None,
|
||||
remove_keys: Optional[List[str]] = None,
|
||||
replace: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, replace={replace}"
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._update_plugin_config(
|
||||
plugin_id, updates, remove_keys, replace
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"修改插件配置失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"修改插件配置时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
157
app/agent/tools/impl/update_rule_group.py
Normal file
157
app/agent/tools/impl/update_rule_group.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""更新过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
build_custom_rule_map,
|
||||
collect_rule_group_usages,
|
||||
get_builtin_rules,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
normalize_rule_group,
|
||||
rename_rule_group_references,
|
||||
save_system_config,
|
||||
serialize_rule_group,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateRuleGroupInput(BaseModel):
|
||||
"""更新过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
current_name: str = Field(..., description="Existing rule group name to update.")
|
||||
new_name: Optional[str] = Field(
|
||||
None,
|
||||
description="New rule group name. If omitted, keep the original name.",
|
||||
)
|
||||
rule_string: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"New rule_string. If omitted, keep the original rule_string. "
|
||||
"Example: 'SPECSUB & CNVOI & 4K & !BLU > CNSUB & CNVOI & 4K & !BLU'."
|
||||
),
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
None,
|
||||
description="New media type scope. Pass an empty string to clear it.",
|
||||
)
|
||||
category: Optional[str] = Field(
|
||||
None,
|
||||
description="New category. Pass an empty string to clear it.",
|
||||
)
|
||||
|
||||
|
||||
class UpdateRuleGroupTool(MoviePilotTool):
|
||||
name: str = "update_rule_group"
|
||||
description: str = (
|
||||
"Update a filter rule group. "
|
||||
"If the rule group name changes, its references in global search/subscription settings and per-subscription bindings are updated automatically. "
|
||||
"Before changing rule_string, first use query_builtin_filter_rules and query_custom_filter_rules to confirm valid rule IDs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateRuleGroupInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
current_name = kwargs.get("current_name", "")
|
||||
new_name = kwargs.get("new_name")
|
||||
if new_name and new_name != current_name:
|
||||
return f"更新规则组 {current_name} -> {new_name}"
|
||||
return f"更新规则组 {current_name}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
current_name: str,
|
||||
new_name: Optional[str] = None,
|
||||
rule_string: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, current_name={current_name}")
|
||||
|
||||
try:
|
||||
rule_groups = get_rule_groups()
|
||||
group_map = {group.name: group for group in rule_groups if group.name}
|
||||
current_group = group_map.get(current_name)
|
||||
if not current_group:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"规则组 '{current_name}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
available_rule_ids = set(get_builtin_rules().keys()) | set(
|
||||
build_custom_rule_map(get_custom_rules()).keys()
|
||||
)
|
||||
updated_group, _ = normalize_rule_group(
|
||||
name=new_name or current_group.name,
|
||||
rule_string=(
|
||||
rule_string
|
||||
if rule_string is not None
|
||||
else current_group.rule_string
|
||||
),
|
||||
media_type=(
|
||||
media_type
|
||||
if media_type is not None
|
||||
else current_group.media_type
|
||||
),
|
||||
category=(
|
||||
category if category is not None else current_group.category
|
||||
),
|
||||
existing_groups=rule_groups,
|
||||
available_rule_ids=available_rule_ids,
|
||||
original_name=current_group.name,
|
||||
)
|
||||
|
||||
final_groups = []
|
||||
for group in rule_groups:
|
||||
if group.name == current_group.name:
|
||||
final_groups.append(updated_group)
|
||||
else:
|
||||
final_groups.append(group)
|
||||
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[group.model_dump(exclude_none=True) for group in final_groups],
|
||||
)
|
||||
|
||||
reference_changes = {}
|
||||
if updated_group.name != current_group.name:
|
||||
reference_changes = await rename_rule_group_references(
|
||||
current_group.name,
|
||||
updated_group.name,
|
||||
)
|
||||
|
||||
usage = await collect_rule_group_usages([updated_group.name])
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已更新规则组 {updated_group.name}",
|
||||
"rule_group": serialize_rule_group(
|
||||
updated_group, usage.get(updated_group.name)
|
||||
),
|
||||
"reference_updates": reference_changes,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"更新规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"更新规则组失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -74,6 +74,10 @@ class UpdateSubscribeInput(BaseModel):
|
||||
None,
|
||||
description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)",
|
||||
)
|
||||
best_version_full: Optional[int] = Field(
|
||||
None,
|
||||
description="For TV best-version subscriptions, only download full-season packs: 0 for no, 1 for yes (optional)",
|
||||
)
|
||||
custom_words: Optional[str] = Field(
|
||||
None, description="Custom recognition words (optional)"
|
||||
)
|
||||
@@ -140,6 +144,7 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
best_version_full: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
@@ -230,6 +235,8 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
if best_version_full is not None:
|
||||
subscribe_dict["best_version_full"] = best_version_full
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
|
||||
305
app/agent/tools/impl/update_system_settings.py
Normal file
305
app/agent/tools/impl/update_system_settings.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""统一更新系统设置工具。"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._system_setting_utils import (
|
||||
SettingSpec,
|
||||
get_default_list_match_field,
|
||||
resolve_setting_spec,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.event import ConfigChangeEventData
|
||||
from app.schemas.types import EventType
|
||||
|
||||
SettingValue = Optional[Union[list, dict, bool, int, float, str]]
|
||||
|
||||
|
||||
class UpdateSystemSettingsInput(BaseModel):
|
||||
"""更新系统设置工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
setting_key: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Exact setting key to update. Supports Settings field names, SystemConfigKey values, enum names, and common aliases "
|
||||
"such as 'downloaders', 'directories', 'search_sites', 'subscribe_sites', 'site_auth', 'ai_agent', and 'custom_identifiers'."
|
||||
),
|
||||
)
|
||||
value: SettingValue = Field(
|
||||
None,
|
||||
description=(
|
||||
"The new value or list item payload. For replace: this becomes the entire setting value. For merge_dict: this should be a dict of keys to merge. "
|
||||
"For upsert_list_item/remove_list_item: this can be a dict item or a scalar list item."
|
||||
),
|
||||
)
|
||||
operation: Literal[
|
||||
"replace",
|
||||
"merge_dict",
|
||||
"upsert_list_item",
|
||||
"remove_list_item",
|
||||
] = Field(
|
||||
"replace",
|
||||
description=(
|
||||
"Update operation. replace replaces the whole value; merge_dict merges dict keys (optionally with remove_keys); "
|
||||
"upsert_list_item inserts or replaces one item inside a list; remove_list_item removes one item from a list."
|
||||
),
|
||||
)
|
||||
remove_keys: Optional[list[str]] = Field(
|
||||
None,
|
||||
description="Optional dict keys to delete when operation is merge_dict.",
|
||||
)
|
||||
match_field: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional match field for list item upsert/remove. If omitted, common SystemConfig categories use built-in defaults such as 'name' or 'type'."
|
||||
),
|
||||
)
|
||||
match_value: SettingValue = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional explicit value used to locate a list item when operation is upsert_list_item or remove_list_item."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdateSystemSettingsTool(MoviePilotTool):
|
||||
name: str = "update_system_settings"
|
||||
description: str = (
|
||||
"Update system settings across both the basic Settings module and all SystemConfig-backed categories. "
|
||||
"Supports full replacement, shallow dict merge, and generic list item upsert/remove so the agent can manage downloaders, media servers, notification channels, storages, directories, search-site ranges, subscribe-site ranges, site auth params, AI agent config, and other system settings through one tool."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UpdateSystemSettingsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息。"""
|
||||
|
||||
setting_key = kwargs.get("setting_key", "")
|
||||
operation = kwargs.get("operation", "replace")
|
||||
action_map = {
|
||||
"replace": "覆盖系统设置",
|
||||
"merge_dict": "合并系统设置",
|
||||
"upsert_list_item": "更新列表项",
|
||||
"remove_list_item": "移除列表项",
|
||||
}
|
||||
return f"{action_map.get(operation, '更新系统设置')}: {setting_key}"
|
||||
|
||||
@staticmethod
|
||||
def _load_setting_value(spec: SettingSpec):
|
||||
if spec.source == "settings":
|
||||
return getattr(settings, spec.key)
|
||||
return SystemConfigOper().get(spec.key)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_systemconfig_value(value: Any):
|
||||
if isinstance(value, list):
|
||||
filtered = [item for item in value if item is not None]
|
||||
return filtered or None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _resolve_list_match(
|
||||
spec: SettingSpec,
|
||||
operation: str,
|
||||
value: Any,
|
||||
match_field: Optional[str],
|
||||
match_value: Any,
|
||||
) -> tuple[Optional[str], Any]:
|
||||
resolved_field = match_field or get_default_list_match_field(spec.key)
|
||||
resolved_value = match_value
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not resolved_field:
|
||||
raise ValueError(
|
||||
f"{operation} 需要提供 match_field,或使用带默认匹配字段的系统配置项"
|
||||
)
|
||||
if resolved_value is None:
|
||||
resolved_value = value.get(resolved_field)
|
||||
if resolved_value is None:
|
||||
raise ValueError(
|
||||
f"{operation} 缺少匹配值,请在 value.{resolved_field} 或 match_value 中提供"
|
||||
)
|
||||
else:
|
||||
if resolved_value is None:
|
||||
resolved_value = value
|
||||
|
||||
return resolved_field, resolved_value
|
||||
|
||||
@classmethod
|
||||
def _prepare_next_value(
|
||||
cls,
|
||||
spec: SettingSpec,
|
||||
current_value: Any,
|
||||
value: Any,
|
||||
operation: str,
|
||||
remove_keys: Optional[list[str]] = None,
|
||||
match_field: Optional[str] = None,
|
||||
match_value: Any = None,
|
||||
) -> Any:
|
||||
remove_keys = remove_keys or []
|
||||
if operation == "replace":
|
||||
return value
|
||||
|
||||
if operation == "merge_dict":
|
||||
if remove_keys and not isinstance(remove_keys, list):
|
||||
raise ValueError("remove_keys 必须是字符串列表")
|
||||
if current_value is not None and not isinstance(current_value, dict):
|
||||
raise ValueError("merge_dict 仅支持当前值为 dict 的设置项")
|
||||
if value is not None and not isinstance(value, dict):
|
||||
raise ValueError("merge_dict 的 value 必须是 dict 或 null")
|
||||
next_value = dict(current_value or {})
|
||||
if value:
|
||||
next_value.update(value)
|
||||
for key in remove_keys:
|
||||
next_value.pop(key, None)
|
||||
return next_value
|
||||
|
||||
if operation not in {"upsert_list_item", "remove_list_item"}:
|
||||
raise ValueError(f"不支持的操作: {operation}")
|
||||
|
||||
if current_value is not None and not isinstance(current_value, list):
|
||||
raise ValueError(f"{operation} 仅支持当前值为 list 的设置项")
|
||||
|
||||
next_items = list(copy.deepcopy(current_value or []))
|
||||
resolved_field, resolved_match_value = cls._resolve_list_match(
|
||||
spec, operation, value, match_field, match_value
|
||||
)
|
||||
|
||||
if operation == "upsert_list_item":
|
||||
if value is None:
|
||||
raise ValueError("upsert_list_item 必须提供 value")
|
||||
replaced = False
|
||||
for index, item in enumerate(next_items):
|
||||
if resolved_field:
|
||||
if isinstance(item, dict) and item.get(resolved_field) == resolved_match_value:
|
||||
next_items[index] = value
|
||||
replaced = True
|
||||
break
|
||||
elif item == resolved_match_value:
|
||||
next_items[index] = value
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
next_items.append(value)
|
||||
return next_items
|
||||
|
||||
return [
|
||||
item
|
||||
for item in next_items
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and resolved_field
|
||||
and item.get(resolved_field) == resolved_match_value
|
||||
)
|
||||
and not (not resolved_field and item == resolved_match_value)
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
setting_key: str,
|
||||
value: SettingValue = None,
|
||||
operation: str = "replace",
|
||||
remove_keys: Optional[list[str]] = None,
|
||||
match_field: Optional[str] = None,
|
||||
match_value: SettingValue = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
"执行工具: %s, setting_key=%s, operation=%s",
|
||||
self.name,
|
||||
setting_key,
|
||||
operation,
|
||||
)
|
||||
|
||||
try:
|
||||
spec = resolve_setting_spec(setting_key)
|
||||
if not spec:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"系统设置项 '{setting_key}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
current_value = self._load_setting_value(spec)
|
||||
next_value = self._prepare_next_value(
|
||||
spec=spec,
|
||||
current_value=current_value,
|
||||
value=value,
|
||||
operation=operation,
|
||||
remove_keys=remove_keys,
|
||||
match_field=match_field,
|
||||
match_value=match_value,
|
||||
)
|
||||
|
||||
event_value = next_value
|
||||
changed = False
|
||||
message = ""
|
||||
if spec.source == "settings":
|
||||
success, message = settings.update_setting(spec.key, next_value)
|
||||
if success is False:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": message or f"更新设置 {spec.key} 失败",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
changed = success is True
|
||||
else:
|
||||
normalized_value = self._normalize_systemconfig_value(next_value)
|
||||
event_value = normalized_value
|
||||
success = await SystemConfigOper().async_set(spec.key, normalized_value)
|
||||
changed = success is True
|
||||
|
||||
if changed:
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(
|
||||
key=spec.key,
|
||||
value=event_value,
|
||||
change_type="update",
|
||||
),
|
||||
)
|
||||
|
||||
saved_value = self._load_setting_value(spec)
|
||||
if not changed and not message:
|
||||
message = "配置值未发生变化"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": message or f"系统设置 {spec.key} 已更新",
|
||||
"changed": changed,
|
||||
"operation": operation,
|
||||
"setting": {
|
||||
"setting_key": spec.key,
|
||||
"source": spec.source,
|
||||
"group": spec.group,
|
||||
"label": spec.label,
|
||||
},
|
||||
"previous_value": current_value,
|
||||
"saved_value": saved_value,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"更新系统设置失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"更新系统设置时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent.tools.base import format_tool_result_for_agent
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
@@ -22,7 +23,12 @@ class MoviePilotToolsManager:
|
||||
MoviePilot工具管理器(用于HTTP API)
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str = "api_user",
|
||||
session_id: str = uuid.uuid4(),
|
||||
is_admin: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
@@ -32,6 +38,7 @@ class MoviePilotToolsManager:
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.is_admin = is_admin
|
||||
self.tools: List[Any] = []
|
||||
self._load_tools()
|
||||
|
||||
@@ -63,6 +70,8 @@ class MoviePilotToolsManager:
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
if getattr(tool, "_require_admin", False) and not self.is_admin:
|
||||
continue
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
if args_schema:
|
||||
@@ -214,6 +223,13 @@ class MoviePilotToolsManager:
|
||||
|
||||
return normalized
|
||||
|
||||
def _check_tool_permission(self, tool_instance: Any) -> Optional[str]:
|
||||
"""为 HTTP/MCP/CLI 入口补齐 require_admin 门禁。"""
|
||||
|
||||
if getattr(tool_instance, "_require_admin", False) and not self.is_admin:
|
||||
return "抱歉,您没有执行此工具的权限。只有系统管理员才能执行工具操作。"
|
||||
return None
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
@@ -234,25 +250,21 @@ class MoviePilotToolsManager:
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
permission_error = self._check_tool_permission(tool_instance)
|
||||
if permission_error:
|
||||
return json.dumps({"error": permission_error}, ensure_ascii=False)
|
||||
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
# 调用工具的run方法
|
||||
# 调用工具的run方法。HTTP/MCP 工具调用不会经过 BaseTool._arun,
|
||||
# 因此这里也必须复用同一套返回值格式化和兜底截断逻辑。
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
try:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
|
||||
formated_result = str(result)
|
||||
|
||||
return formated_result
|
||||
return format_tool_result_for_agent(
|
||||
result,
|
||||
tool_name=tool_name,
|
||||
max_chars=getattr(tool_instance, "result_max_chars", None),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps(
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm, notification
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
@@ -18,6 +18,8 @@ api_router.include_router(douban.router, prefix="/douban", tags=["douban"])
|
||||
api_router.include_router(tmdb.router, prefix="/tmdb", tags=["tmdb"])
|
||||
api_router.include_router(history.router, prefix="/history", tags=["history"])
|
||||
api_router.include_router(system.router, prefix="/system", tags=["system"])
|
||||
api_router.include_router(notification.router, prefix="/notification", tags=["notification"])
|
||||
api_router.include_router(llm.router, prefix="/llm", tags=["llm"])
|
||||
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])
|
||||
api_router.include_router(download.router, prefix="/download", tags=["download"])
|
||||
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"])
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
@@ -11,9 +10,12 @@ from app import schemas
|
||||
from app.api.endpoints.openai import (
|
||||
MODEL_ID,
|
||||
_CollectingMoviePilotAgent,
|
||||
_error_response as _openai_error_response,
|
||||
)
|
||||
from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id
|
||||
from app.api.openai_utils import (
|
||||
build_anthropic_messages,
|
||||
build_prompt,
|
||||
build_session_id,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.security import anthropic_api_key_header
|
||||
from app.schemas.types import MessageChannel
|
||||
@@ -91,7 +93,11 @@ async def _stream_anthropic_response(
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse)
|
||||
@router.post(
|
||||
"/messages",
|
||||
summary="Anthropic compatible messages",
|
||||
response_model=schemas.AnthropicMessagesResponse,
|
||||
)
|
||||
async def messages(
|
||||
payload: schemas.AnthropicMessagesRequest,
|
||||
x_api_key: Optional[str] = Security(anthropic_api_key_header),
|
||||
|
||||
@@ -10,60 +10,82 @@ from app.core.security import verify_token
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson])
|
||||
async def bangumi_credits(bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/credits/{bangumiid}",
|
||||
summary="查询Bangumi演职员表",
|
||||
response_model=List[schemas.MediaPerson],
|
||||
)
|
||||
async def bangumi_credits(
|
||||
bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
查询Bangumi演职员表
|
||||
"""
|
||||
persons = await BangumiChain().async_bangumi_credits(bangumiid)
|
||||
if persons:
|
||||
return persons[(page - 1) * count: page * count]
|
||||
return persons[(page - 1) * count : page * count]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/recommend/{bangumiid}", summary="查询Bangumi推荐", response_model=List[schemas.MediaInfo])
|
||||
async def bangumi_recommend(bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/recommend/{bangumiid}",
|
||||
summary="查询Bangumi推荐",
|
||||
response_model=List[schemas.MediaInfo],
|
||||
)
|
||||
async def bangumi_recommend(
|
||||
bangumiid: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
查询Bangumi推荐
|
||||
"""
|
||||
medias = await BangumiChain().async_bangumi_recommend(bangumiid)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
|
||||
return [media.to_dict() for media in medias[(page - 1) * count : page * count]]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
|
||||
async def bangumi_person(person_id: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson
|
||||
)
|
||||
async def bangumi_person(
|
||||
person_id: int, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物详情
|
||||
"""
|
||||
return await BangumiChain().async_person_detail(person_id=person_id)
|
||||
|
||||
|
||||
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
|
||||
async def bangumi_person_credits(person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/person/credits/{person_id}",
|
||||
summary="人物参演作品",
|
||||
response_model=List[schemas.MediaInfo],
|
||||
)
|
||||
async def bangumi_person_credits(
|
||||
person_id: int,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 20,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
根据人物ID查询人物参演作品
|
||||
"""
|
||||
medias = await BangumiChain().async_person_credits(person_id=person_id)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
|
||||
return [media.to_dict() for media in medias[(page - 1) * count : page * count]]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/{bangumiid}", summary="查询Bangumi详情", response_model=schemas.MediaInfo)
|
||||
async def bangumi_info(bangumiid: int,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi_info(
|
||||
bangumiid: int, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
查询Bangumi详情
|
||||
"""
|
||||
|
||||
@@ -18,11 +18,15 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/statistic", summary="媒体数量统计", response_model=schemas.Statistic)
|
||||
def statistic(name: Optional[str] = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
def statistic(
|
||||
name: Optional[str] = None, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
查询媒体数量统计信息
|
||||
"""
|
||||
media_statistics: Optional[List[schemas.Statistic]] = DashboardChain().media_statistic(name)
|
||||
media_statistics: Optional[List[schemas.Statistic]] = (
|
||||
DashboardChain().media_statistic(name)
|
||||
)
|
||||
if media_statistics:
|
||||
# 汇总各媒体库统计信息
|
||||
ret_statistic = schemas.Statistic()
|
||||
@@ -42,7 +46,9 @@ def statistic(name: Optional[str] = None, _: schemas.TokenPayload = Depends(veri
|
||||
return schemas.Statistic()
|
||||
|
||||
|
||||
@router.get("/statistic2", summary="媒体数量统计(API_TOKEN)", response_model=schemas.Statistic)
|
||||
@router.get(
|
||||
"/statistic2", summary="媒体数量统计(API_TOKEN)", response_model=schemas.Statistic
|
||||
)
|
||||
def statistic2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询媒体数量统计信息 API_TOKEN认证(?token=xxx)
|
||||
@@ -65,13 +71,12 @@ def storage(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
if _usage:
|
||||
total += _usage.total
|
||||
available += _usage.available
|
||||
return schemas.Storage(
|
||||
total_storage=total,
|
||||
used_storage=total - available
|
||||
)
|
||||
return schemas.Storage(total_storage=total, used_storage=total - available)
|
||||
|
||||
|
||||
@router.get("/storage2", summary="本地存储空间(API_TOKEN)", response_model=schemas.Storage)
|
||||
@router.get(
|
||||
"/storage2", summary="本地存储空间(API_TOKEN)", response_model=schemas.Storage
|
||||
)
|
||||
def storage2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询本地存储空间信息 API_TOKEN认证(?token=xxx)
|
||||
@@ -88,13 +93,17 @@ def processes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/downloader", summary="下载器信息", response_model=schemas.DownloaderInfo)
|
||||
def downloader(name: Optional[str] = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
def downloader(
|
||||
name: Optional[str] = None, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
查询下载器信息
|
||||
"""
|
||||
# 下载目录空间
|
||||
download_dirs = DirectoryHelper().get_local_download_dirs()
|
||||
_, free_space = SystemUtils.space_usage([Path(d.download_path) for d in download_dirs])
|
||||
_, free_space = SystemUtils.space_usage(
|
||||
[Path(d.download_path) for d in download_dirs]
|
||||
)
|
||||
# 下载器信息
|
||||
downloader_info = schemas.DownloaderInfo()
|
||||
transfer_infos = DashboardChain().downloader_info(name)
|
||||
@@ -108,7 +117,11 @@ def downloader(name: Optional[str] = None, _: schemas.TokenPayload = Depends(ver
|
||||
return downloader_info
|
||||
|
||||
|
||||
@router.get("/downloader2", summary="下载器信息(API_TOKEN)", response_model=schemas.DownloaderInfo)
|
||||
@router.get(
|
||||
"/downloader2",
|
||||
summary="下载器信息(API_TOKEN)",
|
||||
response_model=schemas.DownloaderInfo,
|
||||
)
|
||||
def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询下载器信息 API_TOKEN认证(?token=xxx)
|
||||
@@ -124,7 +137,11 @@ async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return Scheduler().list()
|
||||
|
||||
|
||||
@router.get("/schedule2", summary="后台服务(API_TOKEN)", response_model=List[schemas.ScheduleInfo])
|
||||
@router.get(
|
||||
"/schedule2",
|
||||
summary="后台服务(API_TOKEN)",
|
||||
response_model=List[schemas.ScheduleInfo],
|
||||
)
|
||||
async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询下载器信息 API_TOKEN认证(?token=xxx)
|
||||
@@ -133,9 +150,11 @@ async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
|
||||
|
||||
@router.get("/transfer", summary="文件整理统计", response_model=List[int])
|
||||
async def transfer(days: Optional[int] = 7,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def transfer(
|
||||
days: Optional[int] = 7,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
查询文件整理统计信息
|
||||
"""
|
||||
@@ -167,7 +186,11 @@ def memory(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return SystemUtils.memory_usage()
|
||||
|
||||
|
||||
@router.get("/memory2", summary="获取当前内存使用量和使用率(API_TOKEN)", response_model=List[int])
|
||||
@router.get(
|
||||
"/memory2",
|
||||
summary="获取当前内存使用量和使用率(API_TOKEN)",
|
||||
response_model=List[int],
|
||||
)
|
||||
def memory2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前内存使用率 API_TOKEN认证(?token=xxx)
|
||||
@@ -183,7 +206,9 @@ def network(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return SystemUtils.network_usage()
|
||||
|
||||
|
||||
@router.get("/network2", summary="获取当前网络流量(API_TOKEN)", response_model=List[int])
|
||||
@router.get(
|
||||
"/network2", summary="获取当前网络流量(API_TOKEN)", response_model=List[int]
|
||||
)
|
||||
def network2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前网络流量 API_TOKEN认证(?token=xxx)
|
||||
|
||||
@@ -14,7 +14,11 @@ from app.schemas.types import ChainEventType, MediaType
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/source", summary="获取探索数据源", response_model=List[schemas.DiscoverMediaSource])
|
||||
@router.get(
|
||||
"/source",
|
||||
summary="获取探索数据源",
|
||||
response_model=List[schemas.DiscoverMediaSource],
|
||||
)
|
||||
def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取探索数据源
|
||||
@@ -31,100 +35,123 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo])
|
||||
async def bangumi(type: Optional[int] = 2,
|
||||
cat: Optional[int] = None,
|
||||
sort: Optional[str] = 'rank',
|
||||
year: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def bangumi(
|
||||
type: Optional[int] = 2,
|
||||
cat: Optional[int] = None,
|
||||
sort: Optional[str] = "rank",
|
||||
year: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
探索Bangumi
|
||||
"""
|
||||
medias = await BangumiChain().async_discover(type=type, cat=cat, sort=sort, year=year,
|
||||
limit=count, offset=(page - 1) * count)
|
||||
medias = await BangumiChain().async_discover(
|
||||
type=type, cat=cat, sort=sort, year=year, limit=count, offset=(page - 1) * count
|
||||
)
|
||||
if medias:
|
||||
return [media.to_dict() for media in medias]
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo])
|
||||
async def douban_movies(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo]
|
||||
)
|
||||
async def douban_movies(
|
||||
sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
浏览豆瓣电影信息
|
||||
"""
|
||||
movies = await DoubanChain().async_douban_discover(mtype=MediaType.MOVIE,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
movies = await DoubanChain().async_douban_discover(
|
||||
mtype=MediaType.MOVIE, sort=sort, tags=tags, page=page, count=count
|
||||
)
|
||||
return [media.to_dict() for media in movies] if movies else []
|
||||
|
||||
|
||||
@router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo])
|
||||
async def douban_tvs(sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo]
|
||||
)
|
||||
async def douban_tvs(
|
||||
sort: Optional[str] = "R",
|
||||
tags: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
浏览豆瓣剧集信息
|
||||
"""
|
||||
tvs = await DoubanChain().async_douban_discover(mtype=MediaType.TV,
|
||||
sort=sort, tags=tags, page=page, count=count)
|
||||
tvs = await DoubanChain().async_douban_discover(
|
||||
mtype=MediaType.TV, sort=sort, tags=tags, page=page, count=count
|
||||
)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
|
||||
@router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo])
|
||||
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo]
|
||||
)
|
||||
async def tmdb_movies(
|
||||
sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
浏览TMDB电影信息
|
||||
"""
|
||||
movies = await TmdbChain().async_tmdb_discover(mtype=MediaType.MOVIE,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
movies = await TmdbChain().async_tmdb_discover(
|
||||
mtype=MediaType.MOVIE,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page,
|
||||
)
|
||||
return [movie.to_dict() for movie in movies] if movies else []
|
||||
|
||||
|
||||
@router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo])
|
||||
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
async def tmdb_tvs(
|
||||
sort_by: Optional[str] = "popularity.desc",
|
||||
with_genres: Optional[str] = "",
|
||||
with_original_language: Optional[str] = "",
|
||||
with_keywords: Optional[str] = "",
|
||||
with_watch_providers: Optional[str] = "",
|
||||
vote_average: Optional[float] = 0.0,
|
||||
vote_count: Optional[int] = 0,
|
||||
release_date: Optional[str] = "",
|
||||
page: Optional[int] = 1,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
浏览TMDB剧集信息
|
||||
"""
|
||||
tvs = await TmdbChain().async_tmdb_discover(mtype=MediaType.TV,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page)
|
||||
tvs = await TmdbChain().async_tmdb_discover(
|
||||
mtype=MediaType.TV,
|
||||
sort_by=sort_by,
|
||||
with_genres=with_genres,
|
||||
with_original_language=with_original_language,
|
||||
with_keywords=with_keywords,
|
||||
with_watch_providers=with_watch_providers,
|
||||
vote_average=vote_average,
|
||||
vote_count=vote_count,
|
||||
release_date=release_date,
|
||||
page=page,
|
||||
)
|
||||
return [tv.to_dict() for tv in tvs] if tvs else []
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user