mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-08 19:32:40 +08:00
Compare commits
573 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f08a7b9eb3 | ||
|
|
a6fa764e2a | ||
|
|
01676668f1 | ||
|
|
8e5e4f460d | ||
|
|
f907b8a84d | ||
|
|
a3a4285f90 | ||
|
|
0979163b79 | ||
|
|
248a25eaee | ||
|
|
f95b1fa68a | ||
|
|
d2b5d69051 | ||
|
|
3ca419b735 | ||
|
|
50e275a2f9 | ||
|
|
aeccf78957 | ||
|
|
cb3cef70e5 | ||
|
|
b9bd303bf8 | ||
|
|
57d4786a7f | ||
|
|
df031455b2 | ||
|
|
30059eff4f | ||
|
|
bc289b48c8 | ||
|
|
067d8b99b8 | ||
|
|
00a6a9c42d | ||
|
|
070425d446 | ||
|
|
7405883444 | ||
|
|
66959937ed | ||
|
|
e431efbcba | ||
|
|
ba00baa5a0 | ||
|
|
0fb5d4a164 | ||
|
|
1ac717b67f | ||
|
|
273cbd447e | ||
|
|
cee41567a2 | ||
|
|
1aae5eb1a6 | ||
|
|
28a4c81aff | ||
|
|
5e077cd64d | ||
|
|
e3f957a59b | ||
|
|
55c62a3ab5 | ||
|
|
22e7eef1bd | ||
|
|
d6524907f3 | ||
|
|
357db334cd | ||
|
|
f8bed3909b | ||
|
|
182bbdde91 | ||
|
|
2c70f990c2 | ||
|
|
0b01a6aa91 | ||
|
|
e557dffbc6 | ||
|
|
7f33b0b1b8 | ||
|
|
41ddf77a5b | ||
|
|
8c657ce41d | ||
|
|
3ff3b9ed4a | ||
|
|
ef43419ecd | ||
|
|
2ca375c214 | ||
|
|
cbd45c1d0f | ||
|
|
2592ea3464 | ||
|
|
73ac97cd96 | ||
|
|
e014663e97 | ||
|
|
58592e961f | ||
|
|
9a99b9ce82 | ||
|
|
8c6dca1751 | ||
|
|
cf488d5f5f | ||
|
|
515584d34c | ||
|
|
fb2becc7f2 | ||
|
|
0f8ceb0fac | ||
|
|
a70bf18770 | ||
|
|
2de83c44ab | ||
|
|
7b99f09810 | ||
|
|
6b4ba8bfad | ||
|
|
0c6cfc5020 | ||
|
|
abd9733e7f | ||
|
|
98c3ae5e76 | ||
|
|
bb5a657469 | ||
|
|
7797532350 | ||
|
|
c3a5106adc | ||
|
|
c5fd935dd0 | ||
|
|
ec375a19ae | ||
|
|
51e940617c | ||
|
|
58ec8bd437 | ||
|
|
a096395086 | ||
|
|
4bd08bd915 | ||
|
|
2c849cfa7a | ||
|
|
501d530d1d | ||
|
|
91fc4327f4 | ||
|
|
8d56c67079 | ||
|
|
e52d43458e | ||
|
|
9b125bf9b0 | ||
|
|
0716c65269 | ||
|
|
ba3ce4f1b5 | ||
|
|
07f72b0cdc | ||
|
|
bda19df87f | ||
|
|
5d82fae2b0 | ||
|
|
0813b87221 | ||
|
|
961ecfc720 | ||
|
|
81f30ef25a | ||
|
|
140b0d3df2 | ||
|
|
b3d69d7de4 | ||
|
|
8e65564fb8 | ||
|
|
06ce9bd4de | ||
|
|
274fc2d74f | ||
|
|
2f1a448afe | ||
|
|
99cab7c337 | ||
|
|
81f7548579 | ||
|
|
6ebd50bebc | ||
|
|
378ba51f4d | ||
|
|
63a890e85d | ||
|
|
bf4f9921e2 | ||
|
|
167ae65695 | ||
|
|
2affa7c9b8 | ||
|
|
785540e178 | ||
|
|
bcad4c0bc6 | ||
|
|
5af217fbf5 | ||
|
|
128aa2ef23 | ||
|
|
fce1186dd1 | ||
|
|
9a7b11f804 | ||
|
|
b068a06fa8 | ||
|
|
931a42e981 | ||
|
|
e0a20a6697 | ||
|
|
1ef4374899 | ||
|
|
3b7212740b | ||
|
|
4b80b8dc1f | ||
|
|
b7f24827e6 | ||
|
|
1c08a22881 | ||
|
|
8bd848519d | ||
|
|
e19f2aa76d | ||
|
|
4a99e2896f | ||
|
|
de3c83b0aa | ||
|
|
36bdb831be | ||
|
|
1809690915 | ||
|
|
e51b679380 | ||
|
|
10c26de7cb | ||
|
|
ca5ec8af0f | ||
|
|
d1d7b8ce55 | ||
|
|
77f8983307 | ||
|
|
ba415acd37 | ||
|
|
bcf13099ac | ||
|
|
eb2b34d71c | ||
|
|
d0b665f773 | ||
|
|
a1674b1ae5 | ||
|
|
af83681f6a | ||
|
|
bebacf7b20 | ||
|
|
6dc1fcbc3e | ||
|
|
b599ef4509 | ||
|
|
526b6a1119 | ||
|
|
88173db4ce | ||
|
|
e139b1ab22 | ||
|
|
6c1e0058c1 | ||
|
|
c96633eb83 | ||
|
|
91eb35a77b | ||
|
|
d749d59cad | ||
|
|
80396b4d30 | ||
|
|
64b93a009c | ||
|
|
2b32250504 | ||
|
|
9b5f863832 | ||
|
|
fd422d7446 | ||
|
|
5162b2748e | ||
|
|
56c684ec06 | ||
|
|
7e93b33407 | ||
|
|
7662235802 | ||
|
|
e41f9facc7 | ||
|
|
785b8ede11 | ||
|
|
78b198ad70 | ||
|
|
c2c0515991 | ||
|
|
b97fefdb8d | ||
|
|
840da6dd85 | ||
|
|
972d916126 | ||
|
|
e3ed065f5f | ||
|
|
760ebe6113 | ||
|
|
a329d3ad89 | ||
|
|
01f8561582 | ||
|
|
883ea5c996 | ||
|
|
99cf13ed9b | ||
|
|
91c7ef6801 | ||
|
|
84ef5705e7 | ||
|
|
cf2a0cf8c2 | ||
|
|
48c25c40e4 | ||
|
|
996d8ab954 | ||
|
|
fac2546a92 | ||
|
|
728ea6172a | ||
|
|
f59d225029 | ||
|
|
0b178a715f | ||
|
|
e06e5328c2 | ||
|
|
1c14cd0979 | ||
|
|
f9141f5ba2 | ||
|
|
48da5c976c | ||
|
|
fa38c81c08 | ||
|
|
8d5fe5270f | ||
|
|
0dc0d66549 | ||
|
|
f589fcc2d0 | ||
|
|
edd44a0993 | ||
|
|
2aae496742 | ||
|
|
6f72046f86 | ||
|
|
d4a9b446a6 | ||
|
|
95f571e9b9 | ||
|
|
e8aeae5c07 | ||
|
|
ddf6dc0343 | ||
|
|
36d55a9db7 | ||
|
|
7d41379ad5 | ||
|
|
63e928da96 | ||
|
|
5c983b64bc | ||
|
|
b2d36c0e68 | ||
|
|
6123a1620e | ||
|
|
5ae7c10a00 | ||
|
|
b5a6794381 | ||
|
|
6b575f836a | ||
|
|
c83589cac6 | ||
|
|
d64492bda5 | ||
|
|
33d6c75924 | ||
|
|
89f01bad42 | ||
|
|
767496f81b | ||
|
|
147a477365 | ||
|
|
13171f636f | ||
|
|
fea3f0d3e0 | ||
|
|
a3a254c2ea | ||
|
|
bd9d5f7fc0 | ||
|
|
726738ee9e | ||
|
|
725244bb2f | ||
|
|
d2ac2b8990 | ||
|
|
116569223c | ||
|
|
05442a019f | ||
|
|
db67080bf8 | ||
|
|
21fabf7436 | ||
|
|
a8c6516b31 | ||
|
|
f5ca48a56e | ||
|
|
65ceff9824 | ||
|
|
ed73cfdcc7 | ||
|
|
9cb79a7827 | ||
|
|
984f29005a | ||
|
|
805c3719af | ||
|
|
ea646149c0 | ||
|
|
eae1f8ee4d | ||
|
|
8d1de245a6 | ||
|
|
b8ef5d1efc | ||
|
|
e1098b34e8 | ||
|
|
8296f8d2da | ||
|
|
867c83383d | ||
|
|
1354119d6d | ||
|
|
53af7f81bb | ||
|
|
48b1ac28de | ||
|
|
6e329b17a9 | ||
|
|
6a492198a8 | ||
|
|
8bf9b6e7cb | ||
|
|
42e23ef564 | ||
|
|
c6806ee648 | ||
|
|
076fae696c | ||
|
|
ed294d3ea4 | ||
|
|
043be409d0 | ||
|
|
a5e7483870 | ||
|
|
365335be46 | ||
|
|
62543dd171 | ||
|
|
e2eef8ff21 | ||
|
|
3acf937d56 | ||
|
|
d572e523ba | ||
|
|
82113abe88 | ||
|
|
b7d121c58f | ||
|
|
6d5a85b144 | ||
|
|
78121917c6 | ||
|
|
a0913f0e32 | ||
|
|
e96e284715 | ||
|
|
c572a1b607 | ||
|
|
1845311f98 | ||
|
|
4f806db8b7 | ||
|
|
22858cc1e9 | ||
|
|
a0329a3eb0 | ||
|
|
b3e92088ee | ||
|
|
46db1c20f1 | ||
|
|
9d182e53b2 | ||
|
|
1205fc7fdb | ||
|
|
ff2826a448 | ||
|
|
ee750115ec | ||
|
|
0e13d22c97 | ||
|
|
8e7d040ac4 | ||
|
|
6755202958 | ||
|
|
8b7374a687 | ||
|
|
c17cca2365 | ||
|
|
8016a9539a | ||
|
|
e885fb15a0 | ||
|
|
c7f098771b | ||
|
|
fcd0908032 | ||
|
|
7ff1285084 | ||
|
|
b45b603b97 | ||
|
|
247208b8a9 | ||
|
|
182c46037b | ||
|
|
438d3210bc | ||
|
|
d523c7c916 | ||
|
|
09a19e94d5 | ||
|
|
3971c145df | ||
|
|
055117d83d | ||
|
|
c6baf43986 | ||
|
|
4ff16af3a7 | ||
|
|
17a1bd352b | ||
|
|
7421ca09cc | ||
|
|
9797e696e5 | ||
|
|
c36d6d8b2d | ||
|
|
3873786b99 | ||
|
|
76fdba7f09 | ||
|
|
72799e9638 | ||
|
|
2e77d03fe9 | ||
|
|
0c58eae5e7 | ||
|
|
b609567c38 | ||
|
|
7ecfa44fa0 | ||
|
|
a685b1dc3b | ||
|
|
63ce49a17c | ||
|
|
820fbe4076 | ||
|
|
efa05b7775 | ||
|
|
003781e903 | ||
|
|
ee71bafc96 | ||
|
|
bdd5f1231e | ||
|
|
6fee532c96 | ||
|
|
78aaad7b59 | ||
|
|
b128b0ede2 | ||
|
|
737d2f3bc6 | ||
|
|
179be53a65 | ||
|
|
1867f5e7c2 | ||
|
|
6662d24565 | ||
|
|
5880566a99 | ||
|
|
5d05b32711 | ||
|
|
fa2b720e92 | ||
|
|
d381238f83 | ||
|
|
751d627ead | ||
|
|
3e66a8de9b | ||
|
|
266052b12b | ||
|
|
803f4328f4 | ||
|
|
8e95568e11 | ||
|
|
ab09ee4819 | ||
|
|
41f94a172f | ||
|
|
566e597994 | ||
|
|
765fb9c05f | ||
|
|
b6720a19f7 | ||
|
|
3b130651c4 | ||
|
|
3f6c35dabe | ||
|
|
db2a952bca | ||
|
|
0ea9770bc3 | ||
|
|
0b20956c90 | ||
|
|
9f73b47d54 | ||
|
|
ce9c99af71 | ||
|
|
784024fb5d | ||
|
|
1145b32299 | ||
|
|
ab71df0011 | ||
|
|
fb137252a9 | ||
|
|
f57a680306 | ||
|
|
8bb3eaa320 | ||
|
|
9489730a44 | ||
|
|
d4795bb897 | ||
|
|
63775872c7 | ||
|
|
beff508a1f | ||
|
|
deaae8a2c6 | ||
|
|
46a27bd50c | ||
|
|
24f2993433 | ||
|
|
c80bfbfac5 | ||
|
|
06abfc45c7 | ||
|
|
440a773081 | ||
|
|
0797bcb38b | ||
|
|
d463b5bf0d | ||
|
|
0733c8edcc | ||
|
|
86c7c05cb1 | ||
|
|
18ff7ce753 | ||
|
|
8f2ed1004d | ||
|
|
14961323c3 | ||
|
|
f8c682b183 | ||
|
|
dd92708f60 | ||
|
|
4d9eeccefa | ||
|
|
cd7b251031 | ||
|
|
db614180b9 | ||
|
|
b6e527e5f4 | ||
|
|
77c0f8f39e | ||
|
|
58816d73c8 | ||
|
|
3b194d282e | ||
|
|
397f66433d | ||
|
|
04a4ed1d0e | ||
|
|
625850d4e7 | ||
|
|
6c572baca5 | ||
|
|
ee0406a13f | ||
|
|
608a049ba3 | ||
|
|
4d9b5198e2 | ||
|
|
24b6c970aa | ||
|
|
239c47f469 | ||
|
|
f0fc64c517 | ||
|
|
8481fd38ce | ||
|
|
5f425129d5 | ||
|
|
92955b1315 | ||
|
|
a3872d5bb5 | ||
|
|
a123ff2c04 | ||
|
|
188de34306 | ||
|
|
3d43750e9b | ||
|
|
fea228c68d | ||
|
|
a71a28e563 | ||
|
|
3b5d4982b5 | ||
|
|
b201e9ab8c | ||
|
|
d30b9282fd | ||
|
|
4f304a70b7 | ||
|
|
59a54d4f04 | ||
|
|
1e94d794ed | ||
|
|
5bd210406b | ||
|
|
e00514d36d | ||
|
|
f013bf1931 | ||
|
|
107cbbad1d | ||
|
|
481f1f9d30 | ||
|
|
704364061c | ||
|
|
c1bd2d6cf1 | ||
|
|
a018e1228c | ||
|
|
d962d9c7f6 | ||
|
|
4ea28cbca5 | ||
|
|
1b48b8b4cc | ||
|
|
73df197e33 | ||
|
|
bdc66e55ca | ||
|
|
926343ee86 | ||
|
|
8e6021c5e7 | ||
|
|
ac2b6c76ce | ||
|
|
9e966d0a7f | ||
|
|
6c10defaa1 | ||
|
|
b6a76f6f7c | ||
|
|
84e5b77a5c | ||
|
|
89b0ea0bf1 | ||
|
|
48aeb98bf1 | ||
|
|
8a5d864812 | ||
|
|
ae79e645a6 | ||
|
|
0947deb372 | ||
|
|
69c92911a2 | ||
|
|
b16bb37b75 | ||
|
|
9c9ec8adf2 | ||
|
|
eb0e67fc42 | ||
|
|
9cc50bddab | ||
|
|
d3ba0fa487 | ||
|
|
39f6505a80 | ||
|
|
36a6802439 | ||
|
|
d7e2633a92 | ||
|
|
88049e741e | ||
|
|
ff7fb14087 | ||
|
|
816c64bd48 | ||
|
|
d2756e6f2d | ||
|
|
147e12acbb | ||
|
|
4098018ee9 | ||
|
|
133e7578b9 | ||
|
|
74a2bdbf09 | ||
|
|
f22bc68af4 | ||
|
|
26cc6da650 | ||
|
|
d21f1f1b87 | ||
|
|
7cdaafffe1 | ||
|
|
0265dca197 | ||
|
|
9d68366043 | ||
|
|
c8c671d915 | ||
|
|
142daa9d15 | ||
|
|
2552219991 | ||
|
|
a038b698d7 | ||
|
|
a3b222574e | ||
|
|
e0cd467293 | ||
|
|
9c056030d2 | ||
|
|
19efa9d4cc | ||
|
|
90633a6495 | ||
|
|
edc432fbd8 | ||
|
|
1b7bdbf516 | ||
|
|
8c1be70c85 | ||
|
|
b8e0c0db9e | ||
|
|
7b7fb6cc82 | ||
|
|
62512ba215 | ||
|
|
e1beb64c01 | ||
|
|
c81f26ddad | ||
|
|
340114c2a1 | ||
|
|
cd7767b331 | ||
|
|
25289dad8a | ||
|
|
47c6917129 | ||
|
|
6379cda148 | ||
|
|
91a124ab8f | ||
|
|
2357a7135e | ||
|
|
da0b3b3de9 | ||
|
|
6664fb1716 | ||
|
|
1206f24fa9 | ||
|
|
ffb5823e84 | ||
|
|
d45a7fb262 | ||
|
|
918d192c0f | ||
|
|
f7cd6eac50 | ||
|
|
88f4428ff0 | ||
|
|
069ea22ba2 | ||
|
|
8fac8c5307 | ||
|
|
2285befebb | ||
|
|
1cd0648e4e | ||
|
|
0b7ba285c6 | ||
|
|
30446c4526 | ||
|
|
9b843c9ed2 | ||
|
|
2ce1c3bef8 | ||
|
|
e463094dc7 | ||
|
|
71a9fe10f4 | ||
|
|
ba146e13ef | ||
|
|
c060d7e3e0 | ||
|
|
ba96678822 | ||
|
|
4f6354f383 | ||
|
|
2766e80346 | ||
|
|
7cc3777a60 | ||
|
|
cb1dd9f17d | ||
|
|
31f342fe4f | ||
|
|
e90359eb08 | ||
|
|
58b0768a30 | ||
|
|
3b04506893 | ||
|
|
354165aa0a | ||
|
|
343109836f | ||
|
|
fcadac2adb | ||
|
|
5e7dcdfe97 | ||
|
|
2ec9a57391 | ||
|
|
973c545723 | ||
|
|
fd62eecfef | ||
|
|
b5ca7058c2 | ||
|
|
57a48f099f | ||
|
|
4699f511bf | ||
|
|
cd8f7e72e0 | ||
|
|
78803fa284 | ||
|
|
2e8d75df16 | ||
|
|
7e3bbfd960 | ||
|
|
1734d53b3c | ||
|
|
f37540f4e5 | ||
|
|
addb9d836a | ||
|
|
4184d8c7ac | ||
|
|
724c15a68c | ||
|
|
499bdf9b48 | ||
|
|
41cd1ccda1 | ||
|
|
b9521cb3a9 | ||
|
|
1f40663b90 | ||
|
|
5261ed7c4c | ||
|
|
aa8768b18a | ||
|
|
aad07433f4 | ||
|
|
4a7630079b | ||
|
|
44a6ee1994 | ||
|
|
56bd6e69ed | ||
|
|
d1e04588d0 | ||
|
|
21cdaef6d5 | ||
|
|
a1723d18fb | ||
|
|
9e065138e9 | ||
|
|
1c73c92bfd | ||
|
|
bcd560d74e | ||
|
|
02339562ed | ||
|
|
e5804378c2 | ||
|
|
da1c8a162d | ||
|
|
d457a23a1f | ||
|
|
b6154e58b8 | ||
|
|
5f18776c61 | ||
|
|
68b0b9ec7a | ||
|
|
0f5036972e | ||
|
|
0b199b8421 | ||
|
|
a59730f6eb | ||
|
|
c6c84fe65b | ||
|
|
03c757bba6 | ||
|
|
bfeb8d238a | ||
|
|
daf0c08c4b | ||
|
|
d12c1b9ac4 | ||
|
|
bc242f4fd4 | ||
|
|
a240c1bca9 | ||
|
|
219aa6c574 | ||
|
|
abca1b481a | ||
|
|
db72fd2ef5 | ||
|
|
31cca58943 | ||
|
|
c06a4b759c | ||
|
|
f05a23a490 | ||
|
|
1e0f2ffde0 | ||
|
|
06df42ee3d | ||
|
|
65ee1638f7 | ||
|
|
87eefe7673 | ||
|
|
5c124d3988 | ||
|
|
8c69ce624f | ||
|
|
bb73acdde5 | ||
|
|
993bc3775b | ||
|
|
3d2ff28bcd | ||
|
|
9b78deb802 | ||
|
|
dadc525d0b | ||
|
|
22b2140c94 | ||
|
|
f07496a4a0 | ||
|
|
1b2938cbc8 | ||
|
|
d4d2f58830 | ||
|
|
b3113e13ec | ||
|
|
055c8e26f0 | ||
|
|
2a7a7239d7 | ||
|
|
2fa40dac3f | ||
|
|
6b4fbd7dc2 | ||
|
|
5b0bb19717 | ||
|
|
843dfc430a | ||
|
|
69cb07c527 | ||
|
|
89e8a64734 | ||
|
|
5eb2dec32d | ||
|
|
3723cf8ac2 |
@@ -1,3 +1,84 @@
|
||||
# Ignore git
|
||||
# Git
|
||||
.github
|
||||
.git
|
||||
.git
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
docs/
|
||||
README.md
|
||||
LICENSE
|
||||
|
||||
# Development files
|
||||
.pylintrc
|
||||
*.pyc
|
||||
__pycache__/
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
*.so
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.hypothesis/
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Test files
|
||||
tests/
|
||||
test_*
|
||||
*_test.py
|
||||
|
||||
# Build artifacts
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
|
||||
# Docker
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
|
||||
# Other
|
||||
app.ico
|
||||
frozen.spec
|
||||
1
.github/workflows/beta.yml
vendored
1
.github/workflows/beta.yml
vendored
@@ -52,6 +52,7 @@ jobs:
|
||||
file: docker/Dockerfile
|
||||
platforms: |
|
||||
linux/amd64
|
||||
linux/arm64/v8
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
@@ -30,6 +30,8 @@
|
||||
|
||||
API文档:https://api.movie-pilot.org
|
||||
|
||||
MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
|
||||
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
|
||||
@@ -40,10 +42,11 @@ git clone https://github.com/jxxghp/MoviePilot
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Resources
|
||||
```
|
||||
- 安装后端依赖,设置`app`为源代码根目录,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
```shell
|
||||
cd MoviePilot
|
||||
pip install -r requirements.txt
|
||||
python3 main.py
|
||||
python3 -m app.main
|
||||
```
|
||||
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
```shell
|
||||
|
||||
371
app/agent/__init__.py
Normal file
371
app/agent/__init__.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel # 消息渠道
|
||||
self.source = source # 消息来源
|
||||
self.username = username # 用户名
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler,
|
||||
memory_mananger=self.memory_manager
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt_template
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
if agent_message:
|
||||
# 发送回复
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
else:
|
||||
agent_message = "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
return agent_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message
|
||||
)
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
# 更新渠道信息
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
agent.source = source
|
||||
if username:
|
||||
agent.username = username
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
33
app/agent/callback/__init__.py
Normal file
33
app/agent/callback/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import threading
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
288
app/agent/memory/__init__.py
Normal file
288
app/agent/memory/__init__.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
def get_memory_key(session_id: str, user_id: str):
|
||||
"""计算内存Key"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def get_redis_key(session_id: str, user_id: str):
|
||||
"""计算Redis Key"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
return memory.messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# 每小时清理一次
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
current_time = datetime.now()
|
||||
expired_sessions = []
|
||||
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
for cache_key in expired_sessions:
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
70
app/agent/prompt/Agent Prompt.txt
Normal file
70
app/agent/prompt/Agent Prompt.txt
Normal file
@@ -0,0 +1,70 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
|
||||
## Working Principles
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
|
||||
## Common Operation Workflows
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
118
app/agent/prompt/__init__.py
Normal file
118
app/agent/prompt/__init__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
self.prompts_dir = Path(__file__).parent
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
return base_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
0
app/agent/tools/__init__.py
Normal file
0
app/agent/tools/__init__.py
Normal file
133
app/agent/tools/base.py
Normal file
133
app/agent/tools/base.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送和记忆工具调用前的信息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
# 发送消息
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
|
||||
# 记忆工具调用
|
||||
await self._memory_manager.add_memory(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__,
|
||||
"tool_name": self.__class__.__name__,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
|
||||
# 记忆工具调用结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
await self._memory_manager.add_memory(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
获取工具执行时的友好提示消息
|
||||
|
||||
子类可以重写此方法,根据实际参数生成个性化的提示消息。
|
||||
如果返回 None 或空字符串,将回退使用 explanation 参数。
|
||||
|
||||
Args:
|
||||
**kwargs: 工具的所有参数(包括 explanation)
|
||||
|
||||
Returns:
|
||||
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
|
||||
"""设置记忆客理器"""
|
||||
self._memory_manager = memory_manager
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
)
|
||||
)
|
||||
142
app/agent/tools/factory.py
Normal file
142
app/agent/tools/factory.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
|
||||
from app.agent.tools.impl.search_subscribe import SearchSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_download_tasks import QueryDownloadTasksTool
|
||||
from app.agent.tools.impl.query_library_exists import QueryLibraryExistsTool
|
||||
from app.agent.tools.impl.query_library_latest import QueryLibraryLatestTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.update_site import UpdateSiteTool
|
||||
from app.agent.tools.impl.query_site_userdata import QuerySiteUserdataTool
|
||||
from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||
from app.agent.tools.impl.query_rule_groups import QueryRuleGroupsTool
|
||||
from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.search_person import SearchPersonTool
|
||||
from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
|
||||
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
|
||||
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
|
||||
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
|
||||
from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
|
||||
from app.agent.tools.impl.transfer_file import TransferFileTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
SearchPersonTool,
|
||||
SearchPersonCreditsTool,
|
||||
RecognizeMediaTool,
|
||||
ScrapeMetadataTool,
|
||||
QueryEpisodeScheduleTool,
|
||||
QueryMediaDetailTool,
|
||||
AddSubscribeTool,
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
SearchWebTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QuerySubscribeSharesTool,
|
||||
QueryPopularSubscribesTool,
|
||||
QueryRuleGroupsTool,
|
||||
QuerySubscribeHistoryTool,
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
UpdateSiteTool,
|
||||
QuerySiteUserdataTool,
|
||||
TestSiteTool,
|
||||
UpdateSiteCookieTool,
|
||||
GetRecommendationsTool,
|
||||
QueryLibraryExistsTool,
|
||||
QueryLibraryLatestTool,
|
||||
QueryDirectorySettingsTool,
|
||||
ListDirectoryTool,
|
||||
QueryTransferHistoryTool,
|
||||
TransferFileTool,
|
||||
SendMessageTool,
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
plugin_tools_count = 0
|
||||
plugin_tools_info = PluginManager().get_plugin_agent_tools()
|
||||
for plugin_info in plugin_tools_info:
|
||||
plugin_id = plugin_info.get("plugin_id")
|
||||
plugin_name = plugin_info.get("plugin_name")
|
||||
tool_classes = plugin_info.get("tools", [])
|
||||
for ToolClass in tool_classes:
|
||||
try:
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过")
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
|
||||
|
||||
builtin_tools_count = len(tool_definitions)
|
||||
if plugin_tools_count > 0:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
|
||||
else:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
0
app/agent/tools/impl/__init__.py
Normal file
0
app/agent/tools/impl/__init__.py
Normal file
106
app/agent/tools/impl/add_download.py
Normal file
106
app/agent/tools/impl/add_download.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
|
||||
torrent_title: str = Field(...,
|
||||
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
|
||||
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
|
||||
torrent_description: Optional[str] = Field(None,
|
||||
description="Brief description of the torrent content (optional)")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
description="Directory path where the downloaded files should be saved. Using `<storage>:<path>` for remote storage. e.g. rclone:/MP, smb:/server/share/Movies. (optional, uses default path if not specified)")
|
||||
labels: Optional[str] = Field(None,
|
||||
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
|
||||
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据下载参数生成友好的提示消息"""
|
||||
torrent_title = kwargs.get("torrent_title", "")
|
||||
site_name = kwargs.get("site_name", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
message = f"正在添加下载任务: {torrent_title}"
|
||||
if site_name:
|
||||
message += f" (来源: {site_name})"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
return "错误:必须提供种子标题和下载链接"
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 根据站点名称查询站点cookie
|
||||
if not site_name:
|
||||
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
return f"错误:未找到站点信息:{site_name}"
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=torrent_url,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
return "错误:无法识别媒体信息,无法添加下载任务"
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
|
||||
did = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=labels
|
||||
)
|
||||
if did:
|
||||
return f"成功添加下载任务:{torrent_title}"
|
||||
else:
|
||||
return "添加下载任务失败"
|
||||
except Exception as e:
|
||||
logger.error(f"添加下载任务失败: {e}", exc_info=True)
|
||||
return f"添加下载任务时发生错误: {str(e)}"
|
||||
135
app/agent/tools/impl/add_subscribe.py
Normal file
135
app/agent/tools/impl/add_subscribe.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
"""添加订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
start_episode: Optional[int] = Field(None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
|
||||
total_episode: Optional[int] = Field(None,
|
||||
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)")
|
||||
quality: Optional[str] = Field(None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria. Supports advanced filtering options like quality, resolution, and effect filters using regular expressions."
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据订阅参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year", "")
|
||||
media_type = kwargs.get("media_type", "")
|
||||
season = kwargs.get("season")
|
||||
|
||||
message = f"正在添加订阅: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
# 转换 tmdb_id 为整数
|
||||
tmdbid_int = None
|
||||
if tmdb_id:
|
||||
try:
|
||||
tmdbid_int = int(tmdb_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
if start_episode is not None:
|
||||
subscribe_kwargs['start_episode'] = start_episode
|
||||
if total_episode is not None:
|
||||
subscribe_kwargs['total_episode'] = total_episode
|
||||
if quality:
|
||||
subscribe_kwargs['quality'] = quality
|
||||
if resolution:
|
||||
subscribe_kwargs['resolution'] = resolution
|
||||
if effect:
|
||||
subscribe_kwargs['effect'] = effect
|
||||
if filter_groups:
|
||||
subscribe_kwargs['filter_groups'] = filter_groups
|
||||
if sites:
|
||||
subscribe_kwargs['sites'] = sites
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
**subscribe_kwargs
|
||||
)
|
||||
if sid:
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
if start_episode is not None:
|
||||
params.append(f"开始集数: {start_episode}")
|
||||
if total_episode is not None:
|
||||
params.append(f"总集数: {total_episode}")
|
||||
if quality:
|
||||
params.append(f"质量过滤: {quality}")
|
||||
if resolution:
|
||||
params.append(f"分辨率过滤: {resolution}")
|
||||
if effect:
|
||||
params.append(f"特效过滤: {effect}")
|
||||
if filter_groups:
|
||||
params.append(f"规则组: {', '.join(filter_groups)}")
|
||||
if sites:
|
||||
params.append(f"站点: {', '.join(map(str, sites))}")
|
||||
if params:
|
||||
result_msg += f"\n配置参数: {', '.join(params)}"
|
||||
return result_msg
|
||||
else:
|
||||
return f"添加订阅失败:{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"添加订阅失败: {e}", exc_info=True)
|
||||
return f"添加订阅时发生错误: {str(e)}"
|
||||
76
app/agent/tools/impl/delete_download.py
Normal file
76
app/agent/tools/impl/delete_download.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""删除下载任务工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
task_identifier: str = Field(..., description="Task identifier: can be task hash (unique identifier) or task title/name")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader. Can delete by task hash (unique identifier) or task title/name. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
task_identifier = kwargs.get("task_identifier", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
message = f"正在删除下载任务: {task_identifier}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
message += " (包含文件)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, task_identifier: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: task_identifier={task_identifier}, downloader={downloader}, delete_files={delete_files}")
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果task_identifier看起来像hash(通常是40个字符的十六进制字符串)
|
||||
task_hash = None
|
||||
if len(task_identifier) == 40 and all(c in '0123456789abcdefABCDEF' for c in task_identifier):
|
||||
# 直接使用hash
|
||||
task_hash = task_identifier
|
||||
else:
|
||||
# 通过标题查找任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
for dl in downloads:
|
||||
# 检查标题或名称是否匹配
|
||||
if (task_identifier.lower() in (dl.title or "").lower()) or \
|
||||
(task_identifier.lower() in (dl.name or "").lower()):
|
||||
task_hash = dl.hash
|
||||
break
|
||||
|
||||
if not task_hash:
|
||||
return f"未找到匹配的下载任务:{task_identifier},请使用 query_downloads 工具查询可用的下载任务"
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[task_hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{task_identifier} {files_info}"
|
||||
else:
|
||||
return f"删除下载任务失败:{task_identifier},请检查任务是否存在或下载器是否可用"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
63
app/agent/tools/impl/delete_subscribe.py
Normal file
63
app/agent/tools/impl/delete_subscribe.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""删除订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
|
||||
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
return f"正在删除订阅 (ID: {subscribe_id})"
|
||||
|
||||
async def run(self, subscribe_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
# 获取订阅信息
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return f"订阅 ID {subscribe_id} 不存在"
|
||||
|
||||
# 在删除之前获取订阅信息(用于事件)
|
||||
subscribe_info = subscribe.to_dict()
|
||||
|
||||
# 删除订阅
|
||||
subscribe_oper.delete(subscribe_id)
|
||||
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
})
|
||||
|
||||
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
|
||||
except Exception as e:
|
||||
logger.error(f"删除订阅失败: {e}", exc_info=True)
|
||||
return f"删除订阅时发生错误: {str(e)}"
|
||||
|
||||
170
app/agent/tools/impl/get_recommendations.py
Normal file
170
app/agent/tools/impl/get_recommendations.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""获取推荐工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: "
|
||||
"'tmdb_trending' for TMDB trending content, "
|
||||
"'tmdb_movies' for TMDB popular movies, "
|
||||
"'tmdb_tvs' for TMDB popular TV shows, "
|
||||
"'douban_hot' for Douban popular content, "
|
||||
"'douban_movie_hot' for Douban hot movies, "
|
||||
"'douban_tv_hot' for Douban hot TV shows, "
|
||||
"'douban_movie_showing' for Douban movies currently showing, "
|
||||
"'douban_movies' for Douban latest movies, "
|
||||
"'douban_tvs' for Douban latest TV shows, "
|
||||
"'douban_movie_top250' for Douban movie TOP250, "
|
||||
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
|
||||
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据推荐参数生成友好的提示消息"""
|
||||
source = kwargs.get("source", "tmdb_trending")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
limit = kwargs.get("limit", 20)
|
||||
|
||||
source_map = {
|
||||
"tmdb_trending": "TMDB流行趋势",
|
||||
"tmdb_movies": "TMDB热门电影",
|
||||
"tmdb_tvs": "TMDB热门电视剧",
|
||||
"douban_hot": "豆瓣热门",
|
||||
"douban_movie_hot": "豆瓣热门电影",
|
||||
"douban_tv_hot": "豆瓣热门电视剧",
|
||||
"douban_movie_showing": "豆瓣正在热映",
|
||||
"douban_movies": "豆瓣最新电影",
|
||||
"douban_tvs": "豆瓣最新电视剧",
|
||||
"douban_movie_top250": "豆瓣电影TOP250",
|
||||
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
|
||||
"douban_tv_weekly_global": "豆瓣全球剧集榜",
|
||||
"douban_tv_animation": "豆瓣热门动漫",
|
||||
"bangumi_calendar": "番组计划"
|
||||
}
|
||||
source_desc = source_map.get(source, source)
|
||||
|
||||
message = f"正在获取推荐: {source_desc}"
|
||||
if media_type != "all":
|
||||
message += f" [{media_type}]"
|
||||
message += f" (限制: {limit}条)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
# async_tmdb_trending 只接受 page 参数,返回固定数量的结果
|
||||
# 如果需要限制数量,需要在返回后截取
|
||||
results = await recommend_chain.async_tmdb_trending(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
elif source == "tmdb_movies":
|
||||
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_movies(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
elif source == "tmdb_tvs":
|
||||
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_tvs(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
|
||||
elif source == "douban_movie_hot":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
elif source == "douban_tv_hot":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
elif source == "douban_movie_showing":
|
||||
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
|
||||
elif source == "douban_movies":
|
||||
results = await recommend_chain.async_douban_movies(page=1, count=limit)
|
||||
elif source == "douban_tvs":
|
||||
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
|
||||
elif source == "douban_movie_top250":
|
||||
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
|
||||
elif source == "douban_tv_weekly_chinese":
|
||||
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
|
||||
elif source == "douban_tv_weekly_global":
|
||||
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
|
||||
elif source == "douban_tv_animation":
|
||||
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
|
||||
else:
|
||||
# 不支持的推荐来源
|
||||
supported_sources = [
|
||||
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
|
||||
"douban_hot", "douban_movie_hot", "douban_tv_hot",
|
||||
"douban_movie_showing", "douban_movies", "douban_tvs",
|
||||
"douban_movie_top250", "douban_tv_weekly_chinese",
|
||||
"douban_tv_weekly_global", "douban_tv_animation",
|
||||
"bangumi_calendar"
|
||||
]
|
||||
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
total_count = len(results)
|
||||
limited_results = results[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
# r 应该是字典格式(to_dict的结果),但为了安全起见进行检查
|
||||
if not isinstance(r, dict):
|
||||
logger.warning(f"推荐结果格式异常,跳过: {type(r)}")
|
||||
continue
|
||||
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
"year": r.get("year"),
|
||||
"type": r.get("type"),
|
||||
"season": r.get("season"),
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
"douban_id": r.get("douban_id"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}", exc_info=True)
|
||||
return f"获取推荐时发生错误: {str(e)}"
|
||||
130
app/agent/tools/impl/list_directory.py
Normal file
130
app/agent/tools/impl/list_directory.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""查询文件系统目录内容工具"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.storage import StorageChain
|
||||
from app.log import logger
|
||||
from app.schemas.file import FileItem
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class ListDirectoryInput(BaseModel):
|
||||
"""查询文件系统目录内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(..., description="Directory path to list contents (e.g., '/home/user/downloads' or 'C:/Downloads')")
|
||||
storage: Optional[str] = Field("local", description="Storage type (default: 'local' for local file system, can be 'smb', 'alist', etc.)")
|
||||
sort_by: Optional[str] = Field("name", description="Sort order: 'name' for alphabetical sorting, 'time' for modification time sorting (default: 'name')")
|
||||
|
||||
|
||||
class ListDirectoryTool(MoviePilotTool):
|
||||
name: str = "list_directory"
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directories' to query directory configuration settings."
|
||||
args_schema: Type[BaseModel] = ListDirectoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据目录参数生成友好的提示消息"""
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
|
||||
message = f"正在查询目录: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
sort_by: Optional[str] = "name", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, sort_by={sort_by}")
|
||||
|
||||
try:
|
||||
# 规范化路径
|
||||
if not path:
|
||||
return "错误:路径不能为空"
|
||||
|
||||
# 确保路径格式正确
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not path.startswith("/") and not (len(path) > 1 and path[1] == ":"):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
path = str(Path(path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=path,
|
||||
type="dir"
|
||||
)
|
||||
|
||||
# 查询目录内容
|
||||
storage_chain = StorageChain()
|
||||
file_list = storage_chain.list_files(fileitem, recursion=False)
|
||||
|
||||
if file_list is None:
|
||||
return f"无法访问目录:{path},请检查路径是否正确或存储是否可用"
|
||||
|
||||
if not file_list:
|
||||
return f"目录 {path} 为空"
|
||||
|
||||
# 排序
|
||||
if sort_by == "time":
|
||||
file_list.sort(key=lambda x: x.modify_time or 0, reverse=True)
|
||||
else:
|
||||
# 默认按名称排序(目录优先,然后按名称)
|
||||
file_list.sort(key=lambda x: (
|
||||
0 if x.type == "dir" else 1,
|
||||
StringUtils.natural_sort_key(x.name or "")
|
||||
))
|
||||
|
||||
# 限制返回数量
|
||||
total_count = len(file_list)
|
||||
limited_list = file_list[:20]
|
||||
|
||||
# 转换为字典格式
|
||||
simplified_items = []
|
||||
for item in limited_list:
|
||||
# 格式化文件大小
|
||||
size_str = None
|
||||
if item.size:
|
||||
size_str = StringUtils.str_filesize(item.size)
|
||||
|
||||
# 格式化修改时间
|
||||
modify_time_str = None
|
||||
if item.modify_time:
|
||||
try:
|
||||
modify_time_str = datetime.fromtimestamp(item.modify_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
modify_time_str = str(item.modify_time)
|
||||
|
||||
simplified = {
|
||||
"name": item.name,
|
||||
"type": item.type,
|
||||
"path": item.path,
|
||||
"size": size_str,
|
||||
"modify_time": modify_time_str
|
||||
}
|
||||
# 如果是文件,添加扩展名
|
||||
if item.type == "file" and item.extension:
|
||||
simplified["extension"] = item.extension
|
||||
simplified_items.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_items, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 20 个项目。\n\n{result_json}"
|
||||
else:
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询目录内容失败: {e}", exc_info=True)
|
||||
return f"查询目录内容时发生错误: {str(e)}"
|
||||
|
||||
134
app/agent/tools/impl/query_directory_settings.py
Normal file
134
app/agent/tools/impl/query_directory_settings.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""查询系统目录设置工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryDirectorySettingsInput(BaseModel):
|
||||
"""查询系统目录设置工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
directory_type: Optional[str] = Field("all",
|
||||
description="Filter directories by type: 'download' for download directories, 'library' for media library directories, 'all' for all directories")
|
||||
storage_type: Optional[str] = Field("all",
|
||||
description="Filter directories by storage type: 'local' for local storage, 'remote' for remote storage, 'all' for all storage types")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter directories by name (partial match, optional)")
|
||||
|
||||
|
||||
class QueryDirectorySettingsTool(MoviePilotTool):
|
||||
name: str = "query_directory_settings"
|
||||
description: str = "Query system directory configuration settings (NOT file listings). Returns configured directory paths, storage types, transfer modes, and other directory-related settings. Use 'list_directory' to list actual files and folders in a directory."
|
||||
args_schema: Type[BaseModel] = QueryDirectorySettingsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
directory_type = kwargs.get("directory_type", "all")
|
||||
storage_type = kwargs.get("storage_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询目录配置"]
|
||||
|
||||
if directory_type != "all":
|
||||
type_map = {"download": "下载目录", "library": "媒体库目录"}
|
||||
parts.append(f"类型: {type_map.get(directory_type, directory_type)}")
|
||||
|
||||
if storage_type != "all":
|
||||
storage_map = {"local": "本地存储", "remote": "远程存储"}
|
||||
parts.append(f"存储: {storage_map.get(storage_type, storage_type)}")
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, directory_type: Optional[str] = "all",
|
||||
storage_type: Optional[str] = "all",
|
||||
name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: directory_type={directory_type}, storage_type={storage_type}, name={name}")
|
||||
|
||||
try:
|
||||
directory_helper = DirectoryHelper()
|
||||
|
||||
# 根据目录类型获取目录列表
|
||||
if directory_type == "download":
|
||||
dirs = directory_helper.get_download_dirs()
|
||||
elif directory_type == "library":
|
||||
dirs = directory_helper.get_library_dirs()
|
||||
else:
|
||||
dirs = directory_helper.get_dirs()
|
||||
|
||||
# 按存储类型过滤
|
||||
filtered_dirs = []
|
||||
for d in dirs:
|
||||
# 按存储类型过滤
|
||||
if storage_type == "local":
|
||||
# 对于下载目录,检查 storage;对于媒体库目录,检查 library_storage
|
||||
if directory_type == "download" and d.storage != "local":
|
||||
continue
|
||||
elif directory_type == "library" and d.library_storage != "local":
|
||||
continue
|
||||
elif directory_type == "all":
|
||||
# 检查是否有本地存储配置
|
||||
if d.download_path and d.storage != "local":
|
||||
continue
|
||||
if d.library_path and d.library_storage != "local":
|
||||
continue
|
||||
elif storage_type == "remote":
|
||||
# 对于下载目录,检查 storage;对于媒体库目录,检查 library_storage
|
||||
if directory_type == "download" and d.storage == "local":
|
||||
continue
|
||||
elif directory_type == "library" and d.library_storage == "local":
|
||||
continue
|
||||
elif directory_type == "all":
|
||||
# 检查是否有远程存储配置
|
||||
if d.download_path and d.storage == "local":
|
||||
continue
|
||||
if d.library_path and d.library_storage == "local":
|
||||
continue
|
||||
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and d.name and name.lower() not in d.name.lower():
|
||||
continue
|
||||
|
||||
filtered_dirs.append(d)
|
||||
|
||||
if filtered_dirs:
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_dirs = []
|
||||
for d in filtered_dirs:
|
||||
simplified = {
|
||||
"name": d.name,
|
||||
"priority": d.priority,
|
||||
"storage": d.storage,
|
||||
"download_path": d.download_path,
|
||||
"library_path": d.library_path,
|
||||
"library_storage": d.library_storage,
|
||||
"media_type": d.media_type,
|
||||
"media_category": d.media_category,
|
||||
"monitor_type": d.monitor_type,
|
||||
"monitor_mode": d.monitor_mode,
|
||||
"transfer_type": d.transfer_type,
|
||||
"overwrite_mode": d.overwrite_mode,
|
||||
"renaming": d.renaming,
|
||||
"scraping": d.scraping,
|
||||
"notify": d.notify,
|
||||
"download_type_folder": d.download_type_folder,
|
||||
"download_category_folder": d.download_category_folder,
|
||||
"library_type_folder": d.library_type_folder,
|
||||
"library_category_folder": d.library_category_folder
|
||||
}
|
||||
simplified_dirs.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_dirs, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return "未找到相关目录配置"
|
||||
except Exception as e:
|
||||
logger.error(f"查询系统目录设置失败: {e}", exc_info=True)
|
||||
return f"查询系统目录设置时发生错误: {str(e)}"
|
||||
|
||||
232
app/agent/tools/impl/query_download_tasks.py
Normal file
232
app/agent/tools/impl/query_download_tasks.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus
|
||||
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
"""查询下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
hash: Optional[str] = Field(None, description="Query specific download task by hash (optional, if provided will search for this specific task regardless of status)")
|
||||
title: Optional[str] = Field(None, description="Query download tasks by title/name (optional, supports partial match, searches all tasks if provided)")
|
||||
|
||||
|
||||
class QueryDownloadTasksTool(MoviePilotTool):
|
||||
name: str = "query_download_tasks"
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
@staticmethod
|
||||
def _get_all_torrents(download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
"""
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
all_torrents = []
|
||||
# 查询正在下载的任务
|
||||
downloading_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.DOWNLOADING
|
||||
) or []
|
||||
all_torrents.extend(downloading_torrents)
|
||||
|
||||
# 查询已完成的任务(可转移状态)
|
||||
transfer_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.TRANSFER
|
||||
) or []
|
||||
all_torrents.extend(transfer_torrents)
|
||||
|
||||
return all_torrents
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
downloader = kwargs.get("downloader")
|
||||
status = kwargs.get("status", "all")
|
||||
hash_value = kwargs.get("hash")
|
||||
title = kwargs.get("title")
|
||||
|
||||
parts = ["正在查询下载任务"]
|
||||
|
||||
if downloader:
|
||||
parts.append(f"下载器: {downloader}")
|
||||
|
||||
if status != "all":
|
||||
status_map = {"downloading": "下载中", "completed": "已完成", "paused": "已暂停"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
if hash_value:
|
||||
parts.append(f"Hash: {hash_value[:8]}...")
|
||||
elif title:
|
||||
parts.append(f"标题: {title}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all",
|
||||
hash: Optional[str] = None,
|
||||
title: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, hash={hash}, title={title}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果提供了hash,直接查询该hash的任务(不限制状态)
|
||||
if hash:
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
|
||||
if not torrents:
|
||||
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
|
||||
# 转换为DownloadingTorrent格式
|
||||
downloads = []
|
||||
for torrent in torrents:
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
downloads.append(torrent)
|
||||
filtered_downloads = downloads
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
title_lower = title.lower()
|
||||
for torrent in all_torrents:
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
|
||||
# 检查标题或名称是否匹配(包括下载历史中的标题)
|
||||
matched = False
|
||||
# 检查torrent的title和name字段
|
||||
if (title_lower in (torrent.title or "").lower()) or \
|
||||
(title_lower in (torrent.name or "").lower()):
|
||||
matched = True
|
||||
# 检查下载历史中的标题
|
||||
if history and history.title:
|
||||
if title_lower in history.title.lower():
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
if not filtered_downloads:
|
||||
return f"未找到标题包含 '{title}' 的下载任务"
|
||||
else:
|
||||
# 根据status决定查询方式
|
||||
if status == "downloading":
|
||||
# 如果status为下载中,使用downloading方法
|
||||
downloads = download_chain.downloading(name=downloader) or []
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
else:
|
||||
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
|
||||
# 查询所有状态的任务
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
if downloader and torrent.downloader != downloader:
|
||||
continue
|
||||
# 根据status过滤
|
||||
if status == "completed":
|
||||
# 已完成的任务(state为seeding或completed)
|
||||
if torrent.state not in ["seeding", "completed"]:
|
||||
continue
|
||||
elif status == "paused":
|
||||
# 已暂停的任务
|
||||
if torrent.state != "paused":
|
||||
continue
|
||||
# status == "all" 时不过滤
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
limited_downloads = filtered_downloads[:20]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_downloads = []
|
||||
for d in limited_downloads:
|
||||
simplified = {
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
|
||||
# 如果查询的是特定hash或title,添加明确的状态信息
|
||||
if hash:
|
||||
return f"找到hash为 {hash} 的下载任务:\n\n{result_json}"
|
||||
elif title:
|
||||
return f"找到 {total_count} 个标题包含 '{title}' 的下载任务:\n\n{result_json}"
|
||||
|
||||
return result_json
|
||||
return "未找到相关下载任务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}", exc_info=True)
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
38
app/agent/tools/impl/query_downloaders.py
Normal file
38
app/agent/tools/impl/query_downloaders.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""查询下载器工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryDownloadersInput(BaseModel):
|
||||
"""查询下载器工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询下载器配置"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
downloaders_config = system_config_oper.get(SystemConfigKey.Downloaders)
|
||||
if downloaders_config:
|
||||
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
|
||||
return "未配置下载器。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载器失败: {e}")
|
||||
return f"查询下载器时发生错误: {str(e)}"
|
||||
116
app/agent/tools/impl/query_episode_schedule.py
Normal file
116
app/agent/tools/impl/query_episode_schedule.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""查询剧集上映时间工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryEpisodeScheduleInput(BaseModel):
|
||||
"""查询剧集上映时间工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series")
|
||||
season: int = Field(..., description="Season number to query")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
|
||||
|
||||
class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
name: str = "query_episode_schedule"
|
||||
description: str = "Query TV series episode air dates and schedule. Returns detailed information for each episode including air date, episode number, title, overview, and other metadata. Filters out episodes without air dates."
|
||||
args_schema: Type[BaseModel] = QueryEpisodeScheduleInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
season = kwargs.get("season")
|
||||
episode_group = kwargs.get("episode_group")
|
||||
|
||||
message = f"正在查询剧集上映时间: TMDB ID {tmdb_id} 第{season}季"
|
||||
if episode_group:
|
||||
message += f" (剧集组: {episode_group})"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, tmdb_id: int, season: int, episode_group: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, season={season}, episode_group={episode_group}")
|
||||
|
||||
try:
|
||||
# 获取媒体信息(用于获取标题和海报)
|
||||
media_chain = MediaChain()
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=MediaType.TV)
|
||||
if not mediainfo:
|
||||
return f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
|
||||
# 获取集列表
|
||||
tmdb_chain = TmdbChain()
|
||||
episodes = await tmdb_chain.async_tmdb_episodes(
|
||||
tmdbid=tmdb_id,
|
||||
season=season,
|
||||
episode_group=episode_group
|
||||
)
|
||||
|
||||
if not episodes:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 第{season}季的集信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 过滤掉没有上映日期的集,并构建每集的详细信息
|
||||
episode_list = []
|
||||
for episode in episodes:
|
||||
air_date = episode.air_date
|
||||
|
||||
# 过滤掉没有上映日期的数据
|
||||
if not air_date:
|
||||
continue
|
||||
|
||||
episode_info = {
|
||||
"episode_number": episode.episode_number,
|
||||
"name": episode.name,
|
||||
"air_date": air_date,
|
||||
"runtime": episode.runtime,
|
||||
"vote_average": episode.vote_average,
|
||||
"still_path": episode.still_path,
|
||||
"episode_type": episode.episode_type,
|
||||
"season_number": episode.season_number
|
||||
}
|
||||
episode_list.append(episode_info)
|
||||
|
||||
if not episode_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 第{season}季的播出时间信息(所有集都没有播出日期)"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 按播出日期排序
|
||||
episode_list.sort(key=lambda x: (x["air_date"] or "", x["episode_number"] or 0))
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"season": season,
|
||||
"episode_group": episode_group,
|
||||
"series_title": mediainfo.title if mediainfo else None,
|
||||
"series_poster": mediainfo.poster_path if mediainfo else None,
|
||||
"total_episodes": len(episodes),
|
||||
"episodes_with_air_date": len(episode_list),
|
||||
"episodes": episode_list
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询剧集上映时间失败: {str(e)}"
|
||||
logger.error(f"查询剧集上映时间失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id,
|
||||
"season": season
|
||||
}, ensure_ascii=False)
|
||||
97
app/agent/tools/impl/query_library_exists.py
Normal file
97
app/agent/tools/impl/query_library_exists.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
title = kwargs.get("title")
|
||||
year = kwargs.get("year")
|
||||
|
||||
parts = ["正在查询媒体库"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if year:
|
||||
parts.append(f"年份: {year}")
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
|
||||
# 创建 MediaInfo 对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
|
||||
# 转换媒体类型
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
# media_type == "all" 时不设置类型,让媒体服务器自动判断
|
||||
|
||||
# 调用媒体服务器接口实时查询
|
||||
media_chain = MediaServerChain()
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
if not existsinfo:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 如果找到了,获取详细信息
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
result_items.append(item_dict)
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有详细信息,返回基本信息
|
||||
result_dict = {
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
}
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
86
app/agent/tools/impl/query_library_latest.py
Normal file
86
app/agent/tools/impl/query_library_latest.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""查询媒体服务器最近入库影片工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryLibraryLatestInput(BaseModel):
|
||||
"""查询媒体服务器最近入库影片工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
|
||||
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
|
||||
|
||||
|
||||
class QueryLibraryLatestTool(MoviePilotTool):
|
||||
name: str = "query_library_latest"
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
|
||||
args_schema: Type[BaseModel] = QueryLibraryLatestInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
server = kwargs.get("server")
|
||||
count = kwargs.get("count", 20)
|
||||
|
||||
parts = ["正在查询媒体服务器最近入库影片"]
|
||||
|
||||
if server:
|
||||
parts.append(f"服务器: {server}")
|
||||
else:
|
||||
parts.append("所有服务器")
|
||||
|
||||
parts.append(f"数量: {count}条")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
|
||||
try:
|
||||
media_chain = MediaServerChain()
|
||||
results = []
|
||||
|
||||
# 如果没有指定服务器,获取所有启用的媒体服务器
|
||||
if not server:
|
||||
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
|
||||
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
|
||||
|
||||
if not enabled_servers:
|
||||
return "未找到启用的媒体服务器"
|
||||
|
||||
# 遍历所有启用的服务器
|
||||
for server_name in enabled_servers:
|
||||
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server_name
|
||||
results.append(item_dict)
|
||||
else:
|
||||
# 查询指定服务器
|
||||
latest_items = media_chain.latest(server=server, count=count, username=self._username)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server
|
||||
results.append(item_dict)
|
||||
|
||||
if not results:
|
||||
server_info = f"服务器 {server}" if server else "所有服务器"
|
||||
return f"未找到 {server_info} 的最近入库影片"
|
||||
|
||||
# 限制返回数量,避免结果过多
|
||||
if len(results) > count:
|
||||
results = results[:count]
|
||||
|
||||
return json.dumps(results, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
|
||||
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"
|
||||
|
||||
120
app/agent/tools/impl/query_media_detail.py
Normal file
120
app/agent/tools/impl/query_media_detail.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""查询媒体详情工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series)")
|
||||
media_type: str = Field(..., description="Media type: 'movie' or 'tv'")
|
||||
|
||||
|
||||
class QueryMediaDetailTool(MoviePilotTool):
|
||||
name: str = "query_media_detail"
|
||||
description: str = "Query detailed media information from TMDB by ID and media_type. IMPORTANT: Convert search results type: '电影'→'movie', '电视剧'→'tv'. Returns core metadata including title, year, overview, status, genres, directors, actors, and season count for TV series."
|
||||
args_schema: Type[BaseModel] = QueryMediaDetailInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
|
||||
async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
|
||||
mtype = None
|
||||
if media_type:
|
||||
if media_type.lower() == 'movie':
|
||||
mtype = MediaType.MOVIE
|
||||
elif media_type.lower() == 'tv':
|
||||
mtype = MediaType.TV
|
||||
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=mtype)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 精简 genres - 只保留名称
|
||||
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
|
||||
|
||||
# 精简 directors - 只保留姓名和职位
|
||||
directors = [
|
||||
{
|
||||
"name": d.get("name"),
|
||||
"job": d.get("job")
|
||||
}
|
||||
for d in (mediainfo.directors or [])
|
||||
if d.get("name")
|
||||
]
|
||||
|
||||
# 精简 actors - 只保留姓名和角色
|
||||
actors = [
|
||||
{
|
||||
"name": a.get("name"),
|
||||
"character": a.get("character")
|
||||
}
|
||||
for a in (mediainfo.actors or [])
|
||||
if a.get("name")
|
||||
]
|
||||
|
||||
# 构建基础媒体详情信息
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"overview": mediainfo.overview,
|
||||
"status": mediainfo.status,
|
||||
"genres": genres,
|
||||
"directors": directors,
|
||||
"actors": actors
|
||||
}
|
||||
|
||||
# 如果是电视剧,添加电视剧特有信息
|
||||
if mediainfo.type == MediaType.TV:
|
||||
# 精简 season_info - 只保留基础摘要
|
||||
season_info = [
|
||||
{
|
||||
"season_number": s.get("season_number"),
|
||||
"name": s.get("name"),
|
||||
"episode_count": s.get("episode_count"),
|
||||
"air_date": s.get("air_date")
|
||||
}
|
||||
for s in (mediainfo.season_info or [])
|
||||
if s.get("season_number") is not None
|
||||
]
|
||||
|
||||
result.update({
|
||||
"number_of_seasons": mediainfo.number_of_seasons,
|
||||
"number_of_episodes": mediainfo.number_of_episodes,
|
||||
"first_air_date": mediainfo.first_air_date,
|
||||
"last_air_date": mediainfo.last_air_date,
|
||||
"season_info": season_info
|
||||
})
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询媒体详情失败: {str(e)}"
|
||||
logger.error(f"查询媒体详情失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id
|
||||
}, ensure_ascii=False)
|
||||
152
app/agent/tools/impl/query_popular_subscribes.py
Normal file
152
app/agent/tools/impl/query_popular_subscribes.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""查询热门订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
import cn2an
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.context import MediaInfo
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryPopularSubscribesInput(BaseModel):
|
||||
"""查询热门订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
stype: str = Field(..., description="Media type: '电影' for films, '电视剧' for television series")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)")
|
||||
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
|
||||
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
|
||||
max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)")
|
||||
sort_type: Optional[str] = Field(None, description="Sort type (optional, e.g., 'count', 'rating')")
|
||||
|
||||
|
||||
class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
name: str = "query_popular_subscribes"
|
||||
description: str = "Query popular subscriptions based on user shared data. Shows media with the most subscribers, supports filtering by genre, rating, minimum subscribers, and pagination."
|
||||
args_schema: Type[BaseModel] = QueryPopularSubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
stype = kwargs.get("stype", "")
|
||||
page = kwargs.get("page", 1)
|
||||
min_sub = kwargs.get("min_sub")
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = [f"正在查询热门订阅 [{stype}]"]
|
||||
|
||||
if min_sub:
|
||||
parts.append(f"最少订阅: {min_sub}")
|
||||
if min_rating:
|
||||
parts.append(f"最低评分: {min_rating}")
|
||||
if max_rating:
|
||||
parts.append(f"最高评分: {max_rating}")
|
||||
if page > 1:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, stype: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: stype={stype}, page={page}, count={count}, min_sub={min_sub}, "
|
||||
f"genre_id={genre_id}, min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}")
|
||||
|
||||
try:
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
subscribes = await subscribe_helper.async_get_statistic(
|
||||
stype=stype,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
if not subscribes:
|
||||
return "未找到热门订阅数据(可能订阅统计功能未启用)"
|
||||
|
||||
# 转换为MediaInfo格式并过滤
|
||||
ret_medias = []
|
||||
for sub in subscribes:
|
||||
# 订阅人数
|
||||
subscriber_count = sub.get("count", 0)
|
||||
# 如果设置了最小订阅人数,进行过滤
|
||||
if min_sub and subscriber_count < min_sub:
|
||||
continue
|
||||
|
||||
media = MediaInfo()
|
||||
media.type = MediaType(sub.get("type"))
|
||||
media.tmdb_id = sub.get("tmdbid")
|
||||
# 处理标题
|
||||
title = sub.get("name")
|
||||
season = sub.get("season")
|
||||
if season and int(season) > 1 and media.tmdb_id:
|
||||
# 小写数据转大写
|
||||
season_str = cn2an.an2cn(season, "low")
|
||||
title = f"{title} 第{season_str}季"
|
||||
media.title = title
|
||||
media.year = sub.get("year")
|
||||
media.douban_id = sub.get("doubanid")
|
||||
media.bangumi_id = sub.get("bangumiid")
|
||||
media.tvdb_id = sub.get("tvdbid")
|
||||
media.imdb_id = sub.get("imdbid")
|
||||
media.season = sub.get("season")
|
||||
media.vote_average = sub.get("vote")
|
||||
media.poster_path = sub.get("poster")
|
||||
media.backdrop_path = sub.get("backdrop")
|
||||
media.popularity = subscriber_count
|
||||
ret_medias.append(media)
|
||||
|
||||
if not ret_medias:
|
||||
return "未找到符合条件的热门订阅"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_medias = []
|
||||
for media in ret_medias:
|
||||
media_dict = media.to_dict()
|
||||
simplified = {
|
||||
"type": media_dict.get("type"),
|
||||
"title": media_dict.get("title"),
|
||||
"year": media_dict.get("year"),
|
||||
"tmdb_id": media_dict.get("tmdb_id"),
|
||||
"douban_id": media_dict.get("douban_id"),
|
||||
"bangumi_id": media_dict.get("bangumi_id"),
|
||||
"tvdb_id": media_dict.get("tvdb_id"),
|
||||
"imdb_id": media_dict.get("imdb_id"),
|
||||
"season": media_dict.get("season"),
|
||||
"vote_average": media_dict.get("vote_average"),
|
||||
"poster_path": media_dict.get("poster_path"),
|
||||
"backdrop_path": media_dict.get("backdrop_path"),
|
||||
"popularity": media_dict.get("popularity"), # 订阅人数
|
||||
"subscriber_count": media_dict.get("popularity") # 明确标注为订阅人数
|
||||
}
|
||||
simplified_medias.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_medias, ensure_ascii=False, indent=2)
|
||||
|
||||
pagination_info = f"第 {page} 页,每页 {count} 条,共 {len(simplified_medias)} 条结果"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询热门订阅失败: {e}", exc_info=True)
|
||||
return f"查询热门订阅时发生错误: {str(e)}"
|
||||
|
||||
65
app/agent/tools/impl/query_rule_groups.py
Normal file
65
app/agent/tools/impl/query_rule_groups.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""查询规则组工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
description: str = "Query all filter rule groups available in the system. Rule groups are used to filter torrents when searching or subscribing. Returns rule group names, media types, and categories, but excludes rule_string to keep results concise."
|
||||
args_schema: Type[BaseModel] = QueryRuleGroupsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
return "正在查询所有规则组"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
rule_helper = RuleHelper()
|
||||
rule_groups = rule_helper.get_rule_groups()
|
||||
|
||||
if not rule_groups:
|
||||
return json.dumps({
|
||||
"message": "未找到任何规则组",
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 精简字段,过滤掉 rule_string 避免结果过大
|
||||
simplified_groups = []
|
||||
for group in rule_groups:
|
||||
simplified = {
|
||||
"name": group.name,
|
||||
"media_type": group.media_type,
|
||||
"category": group.category
|
||||
}
|
||||
simplified_groups.append(simplified)
|
||||
|
||||
result = {
|
||||
"message": f"找到 {len(simplified_groups)} 个规则组",
|
||||
"rule_groups": simplified_groups
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询规则组失败: {str(e)}"
|
||||
logger.error(f"查询规则组失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
55
app/agent/tools/impl/query_schedulers.py
Normal file
55
app/agent/tools/impl/query_schedulers.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""查询定时服务工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class QuerySchedulersInput(BaseModel):
|
||||
"""查询定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QuerySchedulersTool(MoviePilotTool):
|
||||
name: str = "query_schedulers"
|
||||
description: str = "Query scheduled tasks and list all available scheduler jobs. Shows job status, next run time, and provider information."
|
||||
args_schema: Type[BaseModel] = QuerySchedulersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询定时服务"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
scheduler = Scheduler()
|
||||
schedulers = scheduler.list()
|
||||
if schedulers:
|
||||
# 转换为字典列表以便JSON序列化
|
||||
schedulers_list = []
|
||||
for s in schedulers:
|
||||
schedulers_list.append({
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"provider": s.provider,
|
||||
"status": s.status,
|
||||
"next_run": s.next_run
|
||||
})
|
||||
result_json = json.dumps(schedulers_list, ensure_ascii=False, indent=2)
|
||||
# 限制最多30条结果
|
||||
total_count = len(schedulers_list)
|
||||
if total_count > 30:
|
||||
limited_schedulers = schedulers_list[:30]
|
||||
limited_json = json.dumps(limited_schedulers, ensure_ascii=False, indent=2)
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{limited_json}"
|
||||
return result_json
|
||||
return "未找到定时服务"
|
||||
except Exception as e:
|
||||
logger.error(f"查询定时服务失败: {e}", exc_info=True)
|
||||
return f"查询定时服务时发生错误: {str(e)}"
|
||||
|
||||
136
app/agent/tools/impl/query_site_userdata.py
Normal file
136
app/agent/tools/impl/query_site_userdata.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""查询站点用户数据工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.db.models.siteuserdata import SiteUserData
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for")
|
||||
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||
|
||||
|
||||
class QuerySiteUserdataTool(MoviePilotTool):
|
||||
name: str = "query_site_userdata"
|
||||
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
workdate = kwargs.get("workdate")
|
||||
|
||||
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||
if workdate:
|
||||
message += f" (日期: {workdate})"
|
||||
else:
|
||||
message += " (最新数据)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 获取站点用户数据
|
||||
user_data_list = await SiteUserData.async_get_by_domain(
|
||||
db,
|
||||
domain=site.domain,
|
||||
workdate=workdate
|
||||
)
|
||||
|
||||
if not user_data_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": []
|
||||
}
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
|
||||
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
|
||||
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
|
||||
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
"username": user_data.username,
|
||||
"userid": user_data.userid,
|
||||
"user_level": user_data.user_level,
|
||||
"join_at": user_data.join_at,
|
||||
"bonus": user_data.bonus,
|
||||
"upload": user_data.upload,
|
||||
"upload_gb": round(upload_gb, 2),
|
||||
"download": user_data.download,
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching) if user_data.leeching else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
|
||||
reverse=True
|
||||
)
|
||||
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
82
app/agent/tools/impl/query_sites.py
Normal file
82
app/agent/tools/impl/query_sites.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""查询站点工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter sites by name (partial match, optional)")
|
||||
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询站点"]
|
||||
|
||||
if status != "all":
|
||||
status_map = {"active": "已启用", "inactive": "已禁用"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
# 获取所有站点(按优先级排序)
|
||||
sites = await site_oper.async_list()
|
||||
filtered_sites = []
|
||||
for site in sites:
|
||||
# 按状态过滤
|
||||
if status == "active" and not site.is_active:
|
||||
continue
|
||||
if status == "inactive" and site.is_active:
|
||||
continue
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and name.lower() not in (site.name or "").lower():
|
||||
continue
|
||||
filtered_sites.append(site)
|
||||
if filtered_sites:
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_sites = []
|
||||
for s in filtered_sites:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"domain": s.domain,
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"downloader": s.downloader,
|
||||
"proxy": s.proxy,
|
||||
"timeout": s.timeout
|
||||
}
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return "未找到相关站点"
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点失败: {e}", exc_info=True)
|
||||
return f"查询站点时发生错误: {str(e)}"
|
||||
|
||||
113
app/agent/tools/impl/query_subscribe_history.py
Normal file
113
app/agent/tools/impl/query_subscribe_history.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""查询订阅历史工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribeHistoryInput(BaseModel):
|
||||
"""查询订阅历史工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all", description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types (default: 'all')")
|
||||
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
|
||||
|
||||
|
||||
class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_history"
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Returns up to 30 records."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
parts = ["正在查询订阅历史"]
|
||||
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 根据类型查询
|
||||
if media_type == "all":
|
||||
# 查询所有类型,需要分别查询电影和电视剧
|
||||
movie_history = await SubscribeHistory.async_list_by_type(db, mtype="movie", page=1, count=100)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(db, mtype="tv", page=1, count=100)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
# 按日期排序
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
# 查询指定类型
|
||||
all_history = await SubscribeHistory.async_list_by_type(db, mtype=media_type, page=1, count=100)
|
||||
|
||||
# 按名称过滤
|
||||
filtered_history = []
|
||||
if name:
|
||||
name_lower = name.lower()
|
||||
for record in all_history:
|
||||
if record.name and name_lower in record.name.lower():
|
||||
filtered_history.append(record)
|
||||
else:
|
||||
filtered_history = all_history
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 限制最多30条
|
||||
total_count = len(filtered_history)
|
||||
limited_history = filtered_history[:30]
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in limited_history:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
"bangumiid": record.bangumiid,
|
||||
"poster": record.poster,
|
||||
"vote": record.vote,
|
||||
"total_episode": record.total_episode,
|
||||
"date": record.date,
|
||||
"username": record.username
|
||||
}
|
||||
# 添加过滤规则信息(如果有)
|
||||
if record.filter:
|
||||
simplified["filter"] = record.filter
|
||||
if record.quality:
|
||||
simplified["quality"] = record.quality
|
||||
if record.resolution:
|
||||
simplified["resolution"] = record.resolution
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
|
||||
return f"查询订阅历史时发生错误: {str(e)}"
|
||||
|
||||
113
app/agent/tools/impl/query_subscribe_shares.py
Normal file
113
app/agent/tools/impl/query_subscribe_shares.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""查询订阅分享工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribeSharesInput(BaseModel):
|
||||
"""查询订阅分享工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
name: Optional[str] = Field(None, description="Filter shares by media name (partial match, optional)")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
genre_id: Optional[int] = Field(None, description="Filter by genre ID (optional)")
|
||||
min_rating: Optional[float] = Field(None, description="Minimum rating filter (optional, e.g., 7.5)")
|
||||
max_rating: Optional[float] = Field(None, description="Maximum rating filter (optional, e.g., 10.0)")
|
||||
sort_type: Optional[str] = Field(None, description="Sort type (optional, e.g., 'count', 'rating')")
|
||||
|
||||
|
||||
class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_shares"
|
||||
description: str = "Query shared subscriptions from other users. Shows popular subscriptions shared by the community with filtering and pagination support."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeSharesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
name = kwargs.get("name")
|
||||
page = kwargs.get("page", 1)
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = ["正在查询订阅分享"]
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
if min_rating:
|
||||
parts.append(f"最低评分: {min_rating}")
|
||||
if max_rating:
|
||||
parts.append(f"最高评分: {max_rating}")
|
||||
if page > 1:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: name={name}, page={page}, count={count}, genre_id={genre_id}, "
|
||||
f"min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}")
|
||||
|
||||
try:
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
shares = await subscribe_helper.async_get_shares(
|
||||
name=name,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
if not shares:
|
||||
return "未找到订阅分享数据(可能订阅分享功能未启用)"
|
||||
|
||||
# 简化字段,只保留关键信息
|
||||
simplified_shares = []
|
||||
for share in shares:
|
||||
simplified = {
|
||||
"id": share.get("id"),
|
||||
"name": share.get("name"),
|
||||
"year": share.get("year"),
|
||||
"type": share.get("type"),
|
||||
"season": share.get("season"),
|
||||
"tmdbid": share.get("tmdbid"),
|
||||
"doubanid": share.get("doubanid"),
|
||||
"bangumiid": share.get("bangumiid"),
|
||||
"poster": share.get("poster"),
|
||||
"vote": share.get("vote"),
|
||||
"share_title": share.get("share_title"),
|
||||
"share_comment": share.get("share_comment"),
|
||||
"share_user": share.get("share_user"),
|
||||
"fork_count": share.get("fork_count", 0)
|
||||
}
|
||||
# 截断过长的描述
|
||||
if simplified.get("description") and len(simplified["description"]) > 200:
|
||||
simplified["description"] = simplified["description"][:200] + "..."
|
||||
simplified_shares.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_shares, ensure_ascii=False, indent=2)
|
||||
|
||||
pagination_info = f"第 {page} 页,每页 {count} 条,共 {len(simplified_shares)} 条结果"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅分享失败: {e}", exc_info=True)
|
||||
return f"查询订阅分享时发生错误: {str(e)}"
|
||||
|
||||
90
app/agent/tools/impl/query_subscribes.py
Normal file
90
app/agent/tools/impl/query_subscribes.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""查询订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
|
||||
parts = ["正在查询订阅"]
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
status_map = {"R": "已启用", "P": "已禁用"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
# 根据媒体类型过滤条件生成提示
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": s.type,
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}", exc_info=True)
|
||||
return f"查询订阅时发生错误: {str(e)}"
|
||||
133
app/agent/tools/impl/query_transfer_history.py
Normal file
133
app/agent/tools/impl/query_transfer_history.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""查询整理历史记录工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
import jieba
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryTransferHistoryInput(BaseModel):
|
||||
"""查询整理历史记录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: Optional[str] = Field(None, description="Search by title (optional, supports partial match)")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter by status: 'success' for successful transfers, 'failed' for failed transfers, 'all' for all records (default: 'all')")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1, each page contains 30 records)")
|
||||
|
||||
|
||||
class QueryTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "query_transfer_history"
|
||||
description: str = "Query file transfer history records. Shows transfer status, source and destination paths, media information, and transfer details. Supports filtering by title and status."
|
||||
args_schema: Type[BaseModel] = QueryTransferHistoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
title = kwargs.get("title")
|
||||
status = kwargs.get("status", "all")
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询整理历史"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if status != "all":
|
||||
status_map = {"success": "成功", "failed": "失败"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
if page > 1:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, title: Optional[str] = None,
|
||||
status: Optional[str] = "all",
|
||||
page: Optional[int] = 1, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, status={status}, page={page}")
|
||||
|
||||
try:
|
||||
# 处理状态参数
|
||||
status_bool = None
|
||||
if status == "success":
|
||||
status_bool = True
|
||||
elif status == "failed":
|
||||
status_bool = False
|
||||
|
||||
# 处理页码参数
|
||||
if page is None or page < 1:
|
||||
page = 1
|
||||
|
||||
# 每页记录数
|
||||
count = 50
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 处理标题搜索
|
||||
if title:
|
||||
# 使用 jieba 分词处理标题
|
||||
words = jieba.cut(title, HMM=False)
|
||||
title_search = "%".join(words)
|
||||
# 查询记录
|
||||
result = await TransferHistory.async_list_by_title(
|
||||
db, title=title_search, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await TransferHistory.async_count_by_title(
|
||||
db, title=title_search, status=status_bool
|
||||
)
|
||||
else:
|
||||
# 查询所有记录
|
||||
result = await TransferHistory.async_list_by_page(
|
||||
db, page=page, count=count, status=status_bool
|
||||
)
|
||||
total = await TransferHistory.async_count(db, status=status_bool)
|
||||
|
||||
if not result:
|
||||
return "未找到相关整理历史记录"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in result:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"title": record.title,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"category": record.category,
|
||||
"seasons": record.seasons,
|
||||
"episodes": record.episodes,
|
||||
"src": record.src,
|
||||
"dest": record.dest,
|
||||
"mode": record.mode,
|
||||
"status": "成功" if record.status else "失败",
|
||||
"date": record.date,
|
||||
"downloader": record.downloader,
|
||||
"download_hash": record.download_hash
|
||||
}
|
||||
# 如果失败,添加错误信息
|
||||
if not record.status and record.errmsg:
|
||||
simplified["errmsg"] = record.errmsg
|
||||
# 添加媒体ID信息(如果有)
|
||||
if record.tmdbid:
|
||||
simplified["tmdbid"] = record.tmdbid
|
||||
if record.imdbid:
|
||||
simplified["imdbid"] = record.imdbid
|
||||
if record.doubanid:
|
||||
simplified["doubanid"] = record.doubanid
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 计算总页数
|
||||
total_pages = (total + count - 1) // count if total > 0 else 1
|
||||
|
||||
# 构建分页信息
|
||||
pagination_info = f"第 {page}/{total_pages} 页,共 {total} 条记录(每页 {count} 条)"
|
||||
|
||||
return f"{pagination_info}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询整理历史记录失败: {e}", exc_info=True)
|
||||
return f"查询整理历史记录时发生错误: {str(e)}"
|
||||
128
app/agent/tools/impl/query_workflows.py
Normal file
128
app/agent/tools/impl/query_workflows.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""查询工作流工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryWorkflowsInput(BaseModel):
|
||||
"""查询工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
state: Optional[str] = Field("all", description="Filter workflows by state: 'W' for waiting, 'R' for running, 'P' for paused, 'S' for success, 'F' for failed, 'all' for all workflows (default: 'all')")
|
||||
name: Optional[str] = Field(None, description="Filter workflows by name (partial match, optional)")
|
||||
trigger_type: Optional[str] = Field("all", description="Filter workflows by trigger type: 'timer' for scheduled, 'event' for event-triggered, 'manual' for manual, 'all' for all types (default: 'all')")
|
||||
|
||||
|
||||
class QueryWorkflowsTool(MoviePilotTool):
|
||||
name: str = "query_workflows"
|
||||
description: str = "Query workflow list and status. Shows workflow name, description, trigger type, state, execution count, and other workflow details. Supports filtering by state, name, and trigger type."
|
||||
args_schema: Type[BaseModel] = QueryWorkflowsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
state = kwargs.get("state", "all")
|
||||
name = kwargs.get("name")
|
||||
trigger_type = kwargs.get("trigger_type", "all")
|
||||
|
||||
parts = ["正在查询工作流"]
|
||||
|
||||
if state != "all":
|
||||
state_map = {"W": "等待", "R": "运行中", "P": "暂停", "S": "成功", "F": "失败"}
|
||||
parts.append(f"状态: {state_map.get(state, state)}")
|
||||
|
||||
if trigger_type != "all":
|
||||
trigger_map = {"timer": "定时触发", "event": "事件触发", "manual": "手动触发"}
|
||||
parts.append(f"触发类型: {trigger_map.get(trigger_type, trigger_type)}")
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, state: Optional[str] = "all",
|
||||
name: Optional[str] = None,
|
||||
trigger_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: state={state}, name={name}, trigger_type={trigger_type}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflows = await workflow_oper.async_list()
|
||||
|
||||
# 过滤工作流
|
||||
filtered_workflows = []
|
||||
for wf in workflows:
|
||||
# 按状态过滤
|
||||
if state != "all" and wf.state != state:
|
||||
continue
|
||||
|
||||
# 按触发类型过滤
|
||||
if trigger_type != "all":
|
||||
if trigger_type == "timer" and wf.trigger_type not in ["timer", None]:
|
||||
continue
|
||||
elif trigger_type == "event" and wf.trigger_type != "event":
|
||||
continue
|
||||
elif trigger_type == "manual" and wf.trigger_type != "manual":
|
||||
continue
|
||||
|
||||
# 按名称过滤(部分匹配)
|
||||
if name and wf.name and name.lower() not in wf.name.lower():
|
||||
continue
|
||||
|
||||
filtered_workflows.append(wf)
|
||||
|
||||
if not filtered_workflows:
|
||||
return "未找到相关工作流"
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_workflows = []
|
||||
for wf in filtered_workflows:
|
||||
# 状态说明
|
||||
state_map = {
|
||||
"W": "等待",
|
||||
"R": "运行中",
|
||||
"P": "暂停",
|
||||
"S": "成功",
|
||||
"F": "失败"
|
||||
}
|
||||
state_desc = state_map.get(wf.state, wf.state)
|
||||
|
||||
# 触发类型说明
|
||||
trigger_type_map = {
|
||||
"timer": "定时触发",
|
||||
"event": "事件触发",
|
||||
"manual": "手动触发"
|
||||
}
|
||||
trigger_type_desc = trigger_type_map.get(wf.trigger_type, wf.trigger_type or "定时触发")
|
||||
|
||||
simplified = {
|
||||
"id": wf.id,
|
||||
"name": wf.name,
|
||||
"description": wf.description,
|
||||
"trigger_type": trigger_type_desc,
|
||||
"state": state_desc,
|
||||
"run_count": wf.run_count,
|
||||
"timer": wf.timer,
|
||||
"event_type": wf.event_type,
|
||||
"add_time": wf.add_time,
|
||||
"last_time": wf.last_time,
|
||||
"current_action": wf.current_action
|
||||
}
|
||||
# 如果有结果,添加结果信息
|
||||
if wf.result:
|
||||
simplified["result"] = wf.result
|
||||
simplified_workflows.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_workflows, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询工作流失败: {e}", exc_info=True)
|
||||
return f"查询工作流时发生错误: {str(e)}"
|
||||
|
||||
162
app/agent/tools/impl/recognize_media.py
Normal file
162
app/agent/tools/impl/recognize_media.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""识别媒体信息工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RecognizeMediaInput(BaseModel):
|
||||
"""识别媒体信息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: Optional[str] = Field(None, description="The title of the torrent/media to recognize (required for torrent recognition)")
|
||||
subtitle: Optional[str] = Field(None, description="The subtitle or description of the torrent (optional, helps improve recognition accuracy)")
|
||||
path: Optional[str] = Field(None, description="The file path to recognize (required for file recognition, mutually exclusive with title)")
|
||||
|
||||
|
||||
class RecognizeMediaTool(MoviePilotTool):
|
||||
name: str = "recognize_media"
|
||||
description: str = "Extract/identify media information from torrent titles or file paths (NOT database search). Supports two modes: 1) Extract from torrent title and optional subtitle, 2) Extract from file path. Returns detailed media information. Use 'search_media' to search TMDB database, or 'scrape_metadata' to generate metadata files for existing files."
|
||||
args_schema: Type[BaseModel] = RecognizeMediaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据识别参数生成友好的提示消息"""
|
||||
title = kwargs.get("title")
|
||||
subtitle = kwargs.get("subtitle")
|
||||
path = kwargs.get("path")
|
||||
|
||||
if path:
|
||||
message = f"正在识别文件媒体信息: {path}"
|
||||
elif title:
|
||||
message = f"正在识别种子媒体信息: {title}"
|
||||
if subtitle:
|
||||
message += f" ({subtitle})"
|
||||
else:
|
||||
message = "正在识别媒体信息"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: Optional[str] = None, subtitle: Optional[str] = None,
|
||||
path: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, subtitle={subtitle}, path={path}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
context = None
|
||||
|
||||
# 根据提供的参数选择识别方式
|
||||
if path:
|
||||
# 文件路径识别
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "文件路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
context = await media_chain.async_recognize_by_path(path)
|
||||
if context:
|
||||
return self._format_context_result(context, "文件")
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无法识别文件媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif title:
|
||||
# 种子标题识别
|
||||
metainfo = MetaInfo(title, subtitle)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(metainfo)
|
||||
if mediainfo:
|
||||
context = Context(meta_info=metainfo, media_info=mediainfo)
|
||||
return self._format_context_result(context, "种子")
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无法识别种子媒体信息: {title}",
|
||||
"title": title,
|
||||
"subtitle": subtitle
|
||||
}, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "必须提供 title(标题)或 path(文件路径)参数之一"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"识别媒体信息失败: {str(e)}"
|
||||
logger.error(f"识别媒体信息失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def _format_context_result(self, context: Context, source_type: str) -> str:
|
||||
"""格式化识别结果为JSON字符串"""
|
||||
if not context:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "识别结果为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
context_dict = context.to_dict()
|
||||
media_info = context_dict.get("media_info")
|
||||
meta_info = context_dict.get("meta_info")
|
||||
|
||||
# 构建简化的结果
|
||||
result = {
|
||||
"success": True,
|
||||
"source_type": source_type,
|
||||
"media_info": None,
|
||||
"meta_info": None
|
||||
}
|
||||
|
||||
# 处理媒体信息
|
||||
if media_info:
|
||||
result["media_info"] = {
|
||||
"title": media_info.get("title"),
|
||||
"en_title": media_info.get("en_title"),
|
||||
"year": media_info.get("year"),
|
||||
"type": media_info.get("type"),
|
||||
"season": media_info.get("season"),
|
||||
"tmdb_id": media_info.get("tmdb_id"),
|
||||
"imdb_id": media_info.get("imdb_id"),
|
||||
"douban_id": media_info.get("douban_id"),
|
||||
"bangumi_id": media_info.get("bangumi_id"),
|
||||
"overview": media_info.get("overview"),
|
||||
"vote_average": media_info.get("vote_average"),
|
||||
"poster_path": media_info.get("poster_path"),
|
||||
"backdrop_path": media_info.get("backdrop_path"),
|
||||
"detail_link": media_info.get("detail_link"),
|
||||
"title_year": media_info.get("title_year"),
|
||||
"source": media_info.get("source")
|
||||
}
|
||||
|
||||
# 处理元数据信息
|
||||
if meta_info:
|
||||
result["meta_info"] = {
|
||||
"name": meta_info.get("name"),
|
||||
"title": meta_info.get("title"),
|
||||
"year": meta_info.get("year"),
|
||||
"type": meta_info.get("type"),
|
||||
"begin_season": meta_info.get("begin_season"),
|
||||
"end_season": meta_info.get("end_season"),
|
||||
"begin_episode": meta_info.get("begin_episode"),
|
||||
"end_episode": meta_info.get("end_episode"),
|
||||
"total_episode": meta_info.get("total_episode"),
|
||||
"part": meta_info.get("part"),
|
||||
"season_episode": meta_info.get("season_episode"),
|
||||
"episode_list": meta_info.get("episode_list"),
|
||||
"tmdbid": meta_info.get("tmdbid"),
|
||||
"doubanid": meta_info.get("doubanid")
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
53
app/agent/tools/impl/run_scheduler.py
Normal file
53
app/agent/tools/impl/run_scheduler.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""运行定时服务工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
|
||||
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
job_id = kwargs.get("job_id", "")
|
||||
return f"正在运行定时服务 (ID: {job_id})"
|
||||
|
||||
async def run(self, job_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: job_id={job_id}")
|
||||
|
||||
try:
|
||||
scheduler = Scheduler()
|
||||
# 检查定时服务是否存在
|
||||
schedulers = scheduler.list()
|
||||
job_exists = False
|
||||
job_name = None
|
||||
for s in schedulers:
|
||||
if s.id == job_id:
|
||||
job_exists = True
|
||||
job_name = s.name
|
||||
break
|
||||
|
||||
if not job_exists:
|
||||
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
|
||||
|
||||
# 运行定时服务
|
||||
scheduler.start(job_id)
|
||||
|
||||
return f"成功触发定时服务:{job_name} (ID: {job_id})"
|
||||
except Exception as e:
|
||||
logger.error(f"运行定时服务失败: {e}", exc_info=True)
|
||||
return f"运行定时服务时发生错误: {str(e)}"
|
||||
|
||||
72
app/agent/tools/impl/run_workflow.py
Normal file
72
app/agent/tools/impl/run_workflow.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""执行工作流工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.workflow import WorkflowChain
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_identifier: str = Field(..., description="Workflow identifier: can be workflow ID (integer as string) or workflow name")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually. Can run workflow by ID or name. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_identifier = kwargs.get("workflow_identifier", "")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
message = f"正在执行工作流: {workflow_identifier}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
message += " (从头开始)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_identifier: str,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_identifier={workflow_identifier}, from_begin={from_begin}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
|
||||
# 尝试解析为工作流ID
|
||||
workflow = None
|
||||
if workflow_identifier.isdigit():
|
||||
# 如果是数字,尝试作为工作流ID查询
|
||||
workflow = await workflow_oper.async_get(int(workflow_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称查询
|
||||
if not workflow:
|
||||
workflow = await workflow_oper.async_get_by_name(workflow_identifier)
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_identifier},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
return f"工作流执行成功:{workflow.name} (ID: {workflow.id})"
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
119
app/agent/tools/impl/scrape_metadata.py
Normal file
119
app/agent/tools/impl/scrape_metadata.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""刮削媒体元数据工具"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem
|
||||
|
||||
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据刮削参数生成友好的提示消息"""
|
||||
path = kwargs.get("path", "")
|
||||
storage = kwargs.get("storage", "local")
|
||||
overwrite = kwargs.get("overwrite", False)
|
||||
|
||||
message = f"正在刮削媒体元数据: {path}"
|
||||
if storage != "local":
|
||||
message += f" [存储: {storage}]"
|
||||
if overwrite:
|
||||
message += " [覆盖模式]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
if not scrape_path.exists():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
scrape_path = Path(path)
|
||||
meta = MetaInfoPath(scrape_path)
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: media_chain.scrape_metadata(
|
||||
fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
overwrite=overwrite
|
||||
)
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
104
app/agent/tools/impl/search_media.py
Normal file
104
app/agent/tools/impl/search_media.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""搜索媒体工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
"""搜索媒体工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows and anime (optional, only applicable for series)")
|
||||
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "Search TMDB database for media resources (movies, TV shows, anime, etc.) by title, year, type, and other criteria. Returns detailed media information from TMDB. Use 'recognize_media' to extract info from torrent titles/file paths, or 'scrape_metadata' to generate metadata files."
|
||||
args_schema: Type[BaseModel] = SearchMediaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year")
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
|
||||
message = f"正在搜索媒体: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 使用 MediaChain.search 方法
|
||||
meta, results = await media_chain.async_search(title=title)
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if year and result.year != year:
|
||||
continue
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
if filtered_results:
|
||||
# 限制最多30条结果
|
||||
total_count = len(filtered_results)
|
||||
limited_results = filtered_results[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
simplified = {
|
||||
"title": r.title,
|
||||
"en_title": r.en_title,
|
||||
"year": r.year,
|
||||
"type": r.type.value if r.type else None,
|
||||
"season": r.season,
|
||||
"tmdb_id": r.tmdb_id,
|
||||
"imdb_id": r.imdb_id,
|
||||
"douban_id": r.douban_id,
|
||||
"overview": r.overview[:200] + "..." if r.overview and len(r.overview) > 200 else r.overview,
|
||||
"vote_average": r.vote_average,
|
||||
"poster_path": r.poster_path,
|
||||
"detail_link": r.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
else:
|
||||
return f"未找到相关媒体资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索媒体失败: {str(e)}"
|
||||
logger.error(f"搜索媒体失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
83
app/agent/tools/impl/search_person.py
Normal file
83
app/agent/tools/impl/search_person.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""搜索人物工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonInput(BaseModel):
|
||||
"""搜索人物工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
name: str = Field(..., description="The name of the person to search for (e.g., 'Tom Hanks', '周杰伦')")
|
||||
|
||||
|
||||
class SearchPersonTool(MoviePilotTool):
|
||||
name: str = "search_person"
|
||||
description: str = "Search for person information including actors, directors, etc. Supports searching by name. Returns detailed person information from TMDB, Douban, or Bangumi database."
|
||||
args_schema: Type[BaseModel] = SearchPersonInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
name = kwargs.get("name", "")
|
||||
return f"正在搜索人物: {name}"
|
||||
|
||||
async def run(self, name: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: name={name}")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
# 使用 MediaChain.async_search_persons 方法搜索人物
|
||||
persons = await media_chain.async_search_persons(name=name)
|
||||
|
||||
if persons:
|
||||
# 限制最多30条结果
|
||||
total_count = len(persons)
|
||||
limited_persons = persons[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for person in limited_persons:
|
||||
simplified = {
|
||||
"name": person.name,
|
||||
"id": person.id,
|
||||
"source": person.source,
|
||||
"profile_path": person.profile_path,
|
||||
"original_name": person.original_name,
|
||||
"known_for_department": person.known_for_department,
|
||||
"popularity": person.popularity,
|
||||
"biography": person.biography[:200] + "..." if person.biography and len(person.biography) > 200 else person.biography,
|
||||
"birthday": person.birthday,
|
||||
"deathday": person.deathday,
|
||||
"place_of_birth": person.place_of_birth,
|
||||
"gender": person.gender,
|
||||
"imdb_id": person.imdb_id,
|
||||
"also_known_as": person.also_known_as[:5] if person.also_known_as else [], # 限制别名数量
|
||||
}
|
||||
# 添加豆瓣特有字段
|
||||
if person.source == "douban":
|
||||
simplified["url"] = person.url
|
||||
simplified["avatar"] = person.avatar
|
||||
simplified["latin_name"] = person.latin_name
|
||||
simplified["roles"] = person.roles[:5] if person.roles else [] # 限制角色数量
|
||||
# 添加Bangumi特有字段
|
||||
if person.source == "bangumi":
|
||||
simplified["career"] = person.career
|
||||
simplified["relation"] = person.relation
|
||||
|
||||
simplified_results.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关人物信息: {name}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索人物失败: {str(e)}"
|
||||
logger.error(f"搜索人物失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
85
app/agent/tools/impl/search_person_credits.py
Normal file
85
app/agent/tools/impl/search_person_credits.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""搜索演员参演作品工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchPersonCreditsInput(BaseModel):
|
||||
"""搜索演员参演作品工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
person_id: int = Field(..., description="The ID of the person/actor to search for credits (e.g., 31 for Tom Hanks in TMDB)")
|
||||
source: str = Field(..., description="The data source: 'tmdb' for TheMovieDB, 'douban' for Douban, 'bangumi' for Bangumi")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
|
||||
|
||||
class SearchPersonCreditsTool(MoviePilotTool):
|
||||
name: str = "search_person_credits"
|
||||
description: str = "Search for films and TV shows that a person/actor has appeared in (filmography). Supports searching by person ID from TMDB, Douban, or Bangumi database. Returns a list of media works the person has participated in."
|
||||
args_schema: Type[BaseModel] = SearchPersonCreditsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
person_id = kwargs.get("person_id", "")
|
||||
source = kwargs.get("source", "")
|
||||
return f"正在搜索人物参演作品: {source} ID {person_id}"
|
||||
|
||||
async def run(self, person_id: int, source: str, page: Optional[int] = 1, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: person_id={person_id}, source={source}, page={page}")
|
||||
|
||||
try:
|
||||
# 根据source选择相应的chain
|
||||
if source.lower() == "tmdb":
|
||||
tmdb_chain = TmdbChain()
|
||||
medias = await tmdb_chain.async_person_credits(person_id=person_id, page=page)
|
||||
elif source.lower() == "douban":
|
||||
douban_chain = DoubanChain()
|
||||
medias = await douban_chain.async_person_credits(person_id=person_id, page=page)
|
||||
elif source.lower() == "bangumi":
|
||||
bangumi_chain = BangumiChain()
|
||||
medias = await bangumi_chain.async_person_credits(person_id=person_id)
|
||||
else:
|
||||
return f"不支持的数据源: {source}。支持的数据源: tmdb, douban, bangumi"
|
||||
|
||||
if medias:
|
||||
# 限制最多30条结果
|
||||
total_count = len(medias)
|
||||
limited_medias = medias[:30]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for media in limited_medias:
|
||||
simplified = {
|
||||
"title": media.title,
|
||||
"en_title": media.en_title,
|
||||
"year": media.year,
|
||||
"type": media.type.value if media.type else None,
|
||||
"season": media.season,
|
||||
"tmdb_id": media.tmdb_id,
|
||||
"imdb_id": media.imdb_id,
|
||||
"douban_id": media.douban_id,
|
||||
"overview": media.overview[:200] + "..." if media.overview and len(media.overview) > 200 else media.overview,
|
||||
"vote_average": media.vote_average,
|
||||
"poster_path": media.poster_path,
|
||||
"backdrop_path": media.backdrop_path,
|
||||
"detail_link": media.detail_link
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到人物 ID {person_id} ({source}) 的参演作品"
|
||||
except Exception as e:
|
||||
error_message = f"搜索演员参演作品失败: {str(e)}"
|
||||
logger.error(f"搜索演员参演作品失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
127
app/agent/tools/impl/search_subscribe.py
Normal file
127
app/agent/tools/impl/search_subscribe.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""搜索订阅缺失剧集工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import global_vars
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
name: str = "search_subscribe"
|
||||
description: str = "Search for missing episodes/resources for a specific subscription. This tool will search torrent sites for the missing episodes of the subscription and automatically download matching resources. Use this when a user wants to search for missing episodes of a specific subscription."
|
||||
args_schema: Type[BaseModel] = SearchSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
manual = kwargs.get("manual", False)
|
||||
|
||||
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
if manual:
|
||||
message += "(手动搜索)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
|
||||
filter_groups: Optional[List[str]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
|
||||
|
||||
try:
|
||||
# 先验证订阅是否存在
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribe = subscribe_oper.get(subscribe_id)
|
||||
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 获取订阅信息用于返回
|
||||
subscribe_info = {
|
||||
"id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"year": subscribe.year,
|
||||
"type": subscribe.type,
|
||||
"season": subscribe.season,
|
||||
"state": subscribe.state,
|
||||
"total_episode": subscribe.total_episode,
|
||||
"lack_episode": subscribe.lack_episode,
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
}
|
||||
|
||||
# 检查订阅状态
|
||||
if subscribe.state == "S":
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索",
|
||||
"subscribe": subscribe_info
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 如果提供了 filter_groups 参数,先更新订阅的规则组
|
||||
if filter_groups is not None:
|
||||
subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups})
|
||||
logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}")
|
||||
|
||||
# 调用 SubscribeChain 的 search 方法
|
||||
# search 方法是同步的,需要在异步环境中运行
|
||||
subscribe_chain = SubscribeChain()
|
||||
|
||||
# 在线程池中执行同步的搜索操作
|
||||
# 当 sid 有值时,state 参数会被忽略,直接处理该订阅
|
||||
await global_vars.loop.run_in_executor(
|
||||
None,
|
||||
lambda: subscribe_chain.search(
|
||||
sid=subscribe_id,
|
||||
state='R', # 默认状态,当 sid 有值时此参数会被忽略
|
||||
manual=manual
|
||||
)
|
||||
)
|
||||
|
||||
# 重新获取订阅信息以获取更新后的状态
|
||||
updated_subscribe = subscribe_oper.get(subscribe_id)
|
||||
if updated_subscribe:
|
||||
subscribe_info.update({
|
||||
"state": updated_subscribe.state,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"last_update": updated_subscribe.last_update,
|
||||
"filter_groups": updated_subscribe.filter_groups
|
||||
})
|
||||
|
||||
# 如果提供了规则组,会在返回信息中显示
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成",
|
||||
"subscribe": subscribe_info
|
||||
}
|
||||
|
||||
if filter_groups is not None:
|
||||
result["message"] += f"(已应用规则组: {', '.join(filter_groups)})"
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索订阅缺失剧集失败: {str(e)}"
|
||||
logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
142
app/agent/tools/impl/search_torrents.py
Normal file
142
app/agent/tools/impl/search_torrents.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(...,
|
||||
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
|
||||
filter_pattern: Optional[str] = Field(None,
|
||||
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year")
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
filter_pattern = kwargs.get("filter_pattern")
|
||||
|
||||
message = f"正在搜索种子: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
if filter_pattern:
|
||||
message += f" 过滤: {filter_pattern}"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
# 编译正则表达式(如果提供)
|
||||
regex_pattern = None
|
||||
if filter_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
for torrent in torrents:
|
||||
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
|
||||
if year and torrent.meta_info and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
if not regex_pattern.search(torrent.torrent_info.title):
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
|
||||
if filtered_torrents:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_torrents)
|
||||
limited_torrents = filtered_torrents[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_torrents = []
|
||||
for t in limited_torrents:
|
||||
simplified = {}
|
||||
# 精简 torrent_info
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
"enclosure": t.torrent_info.enclosure,
|
||||
"page_url": t.torrent_info.page_url,
|
||||
"volume_factor": t.torrent_info.volume_factor,
|
||||
"pubdate": t.torrent_info.pubdate
|
||||
}
|
||||
# 精简 media_info
|
||||
if t.media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": t.media_info.title,
|
||||
"en_title": t.media_info.en_title,
|
||||
"year": t.media_info.year,
|
||||
"type": t.media_info.type.value if t.media_info.type else None,
|
||||
"season": t.media_info.season,
|
||||
"tmdb_id": t.media_info.tmdb_id
|
||||
}
|
||||
# 精简 meta_info
|
||||
if t.meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": t.meta_info.name,
|
||||
"cn_name": t.meta_info.cn_name,
|
||||
"en_name": t.meta_info.en_name,
|
||||
"year": t.meta_info.year,
|
||||
"type": t.meta_info.type.value if t.meta_info.type else None,
|
||||
"begin_season": t.meta_info.begin_season
|
||||
}
|
||||
simplified_torrents.append(simplified)
|
||||
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关种子资源: {title}"
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
logger.error(f"搜索种子失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
193
app/agent/tools/impl/search_web.py
Normal file
193
app/agent/tools/impl/search_web.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""搜索网络内容工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.http import AsyncRequestUtils
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5, description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
name: str = "search_web"
|
||||
description: str = "Search the web for information when you need to find current information, facts, or references that you're uncertain about. Returns search results with titles, snippets, and URLs. Use this tool to get up-to-date information from the internet."
|
||||
args_schema: Type[BaseModel] = SearchWebInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 5)
|
||||
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
max_results: 最大返回结果数(默认5,最大10)
|
||||
|
||||
Returns:
|
||||
格式化的搜索结果JSON字符串
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
|
||||
# 使用DuckDuckGo API进行搜索
|
||||
search_results = await self._search_duckduckgo_api(query, max_results)
|
||||
|
||||
if not search_results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
|
||||
# 裁剪结果以避免占用过多上下文
|
||||
formatted_results = self._format_and_truncate_results(search_results, max_results)
|
||||
|
||||
result_json = json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索网络内容失败: {str(e)}"
|
||||
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_duckduckgo_api(query: str, max_results: int) -> list:
|
||||
"""
|
||||
使用DuckDuckGo API进行搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
try:
|
||||
# DuckDuckGo Instant Answer API
|
||||
api_url = "https://api.duckduckgo.com/"
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": "1",
|
||||
"skip_disambig": "1"
|
||||
}
|
||||
|
||||
# 使用代理(如果配置了)
|
||||
http_utils = AsyncRequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
data = await http_utils.get_json(api_url, params=params)
|
||||
|
||||
results = []
|
||||
|
||||
if data:
|
||||
# 处理AbstractText(摘要)
|
||||
if data.get("AbstractText"):
|
||||
results.append({
|
||||
"title": data.get("Heading", query),
|
||||
"snippet": data.get("AbstractText", ""),
|
||||
"url": data.get("AbstractURL", ""),
|
||||
"source": "DuckDuckGo Abstract"
|
||||
})
|
||||
|
||||
# 处理RelatedTopics(相关主题)
|
||||
related_topics = data.get("RelatedTopics", [])
|
||||
for topic in related_topics[:max_results - len(results)]:
|
||||
if isinstance(topic, dict):
|
||||
text = topic.get("Text", "")
|
||||
first_url = topic.get("FirstURL", "")
|
||||
if text and first_url:
|
||||
# 提取标题(通常在" - "之前)
|
||||
title = text.split(" - ")[0] if " - " in text else text[:100]
|
||||
snippet = text
|
||||
|
||||
results.append({
|
||||
"title": title.strip(),
|
||||
"snippet": snippet,
|
||||
"url": first_url,
|
||||
"source": "DuckDuckGo Related"
|
||||
})
|
||||
|
||||
# 处理Results(搜索结果)
|
||||
api_results = data.get("Results", [])
|
||||
for result in api_results[:max_results - len(results)]:
|
||||
if isinstance(result, dict):
|
||||
title = result.get("Text", "")
|
||||
url = result.get("FirstURL", "")
|
||||
if title and url:
|
||||
results.append({
|
||||
"title": title,
|
||||
"snippet": result.get("Text", ""),
|
||||
"url": url,
|
||||
"source": "DuckDuckGo Results"
|
||||
})
|
||||
|
||||
return results[:max_results]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo API搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: list, max_results: int) -> dict:
|
||||
"""
|
||||
格式化并裁剪搜索结果以避免占用过多上下文
|
||||
|
||||
Args:
|
||||
results: 原始搜索结果列表
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
格式化后的结果字典
|
||||
"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
|
||||
# 限制结果数量
|
||||
limited_results = results[:max_results]
|
||||
|
||||
for idx, result in enumerate(limited_results, 1):
|
||||
title = result.get("title", "")[:200] # 限制标题长度
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
source = result.get("source", "Unknown")
|
||||
|
||||
# 裁剪摘要,避免过长
|
||||
max_snippet_length = 300 # 每个摘要最多300字符
|
||||
if len(snippet) > max_snippet_length:
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本,移除多余的空白字符
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
|
||||
# 添加提示信息
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"注意:共找到 {len(results)} 条结果,为节省上下文空间,仅显示前 {max_results} 条结果。"
|
||||
|
||||
return formatted
|
||||
45
app/agent/tools/impl/send_message.py
Normal file
45
app/agent/tools/impl/send_message.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""发送消息工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message_type = kwargs.get("message_type", "info")
|
||||
|
||||
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
|
||||
type_desc = type_map.get(message_type, message_type)
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
return f"正在发送{type_desc}消息: {message}"
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return f"发送消息时发生错误: {str(e)}"
|
||||
72
app/agent/tools/impl/test_site.py
Normal file
72
app/agent/tools/impl/test_site.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""测试站点连通性工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class TestSiteInput(BaseModel):
|
||||
"""测试站点连通性工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
|
||||
|
||||
class TestSiteTool(MoviePilotTool):
|
||||
name: str = "test_site"
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
args_schema: Type[BaseModel] = TestSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据测试参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
return f"正在测试站点连通性: {site_identifier}"
|
||||
|
||||
async def run(self, site_identifier: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
# 测试站点连通性
|
||||
status, message = site_chain.test(site.domain)
|
||||
|
||||
if status:
|
||||
return f"站点连通性测试成功:{site.name} ({site.domain})\n{message}"
|
||||
else:
|
||||
return f"站点连通性测试失败:{site.name} ({site.domain})\n{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
|
||||
return f"测试站点连通性时发生错误: {str(e)}"
|
||||
|
||||
134
app/agent/tools/impl/transfer_file.py
Normal file
134
app/agent/tools/impl/transfer_file.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""整理文件或目录工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem, MediaType
|
||||
|
||||
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
media_type: Optional[str] = Field(None, description="Media type: '电影' for films, '电视剧' for television series (optional, will be auto-detected if not specified)")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
|
||||
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
|
||||
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据整理参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
media_type = kwargs.get("media_type")
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
|
||||
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
|
||||
if background:
|
||||
message += " [后台运行]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, file_path: str, storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
|
||||
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}")
|
||||
|
||||
try:
|
||||
if not file_path:
|
||||
return "错误:必须提供文件或目录路径"
|
||||
|
||||
# 规范化路径
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
file_path = str(Path(file_path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not file_path.startswith("/"):
|
||||
file_path = "/" + file_path
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=file_path,
|
||||
type="dir" if file_path.endswith("/") else "file"
|
||||
)
|
||||
|
||||
# 处理目标路径
|
||||
target_path_obj = None
|
||||
if target_path:
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
# 处理媒体类型
|
||||
mtype = None
|
||||
if media_type:
|
||||
try:
|
||||
mtype = MediaType(media_type)
|
||||
except ValueError:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
# 调用整理方法
|
||||
transfer_chain = TransferChain()
|
||||
state, errormsg = transfer_chain.manual_transfer(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
target_path=target_path_obj,
|
||||
tmdbid=tmdbid,
|
||||
doubanid=doubanid,
|
||||
mtype=mtype,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
)
|
||||
|
||||
if not state:
|
||||
# 处理错误信息
|
||||
if isinstance(errormsg, list):
|
||||
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
|
||||
if errormsg:
|
||||
error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
|
||||
if len(errormsg) > 5:
|
||||
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
|
||||
else:
|
||||
error_text = str(errormsg)
|
||||
return f"整理失败:{error_text}"
|
||||
else:
|
||||
if background:
|
||||
return f"整理任务已提交到后台运行:{file_path}"
|
||||
else:
|
||||
return f"整理成功:{file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"整理文件失败: {e}", exc_info=True)
|
||||
return f"整理文件时发生错误: {str(e)}"
|
||||
|
||||
203
app/agent/tools/impl/update_site.py
Normal file
203
app/agent/tools/impl/update_site.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""更新站点工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.site import Site
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update")
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||
token: Optional[str] = Field(None, description="API token (optional)")
|
||||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
fields_updated = []
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("url"):
|
||||
fields_updated.append("URL")
|
||||
if kwargs.get("pri") is not None:
|
||||
fields_updated.append("优先级")
|
||||
if kwargs.get("cookie"):
|
||||
fields_updated.append("Cookie")
|
||||
if kwargs.get("ua"):
|
||||
fields_updated.append("User-Agent")
|
||||
if kwargs.get("proxy") is not None:
|
||||
fields_updated.append("代理设置")
|
||||
if kwargs.get("is_active") is not None:
|
||||
fields_updated.append("启用状态")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
|
||||
async def run(self, site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
if ua is not None:
|
||||
site_dict["ua"] = ua
|
||||
if apikey is not None:
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
if filter is not None:
|
||||
site_dict["filter"] = filter
|
||||
if note is not None:
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
if limit_count is not None:
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": updated_site.domain if updated_site else site.domain
|
||||
})
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys())
|
||||
}
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
"name": updated_site.name,
|
||||
"domain": updated_site.domain,
|
||||
"url": updated_site.url,
|
||||
"pri": updated_site.pri,
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
88
app/agent/tools/impl/update_site_cookie.py
Normal file
88
app/agent/tools/impl/update_site_cookie.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""更新站点Cookie和UA工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: str, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
# 更新站点Cookie和UA
|
||||
status, message = site_chain.update_cookie(
|
||||
site_info=site,
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code
|
||||
)
|
||||
|
||||
if status:
|
||||
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
|
||||
else:
|
||||
return f"站点【{site.name}】Cookie和UA更新失败\n错误原因:{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
|
||||
return f"更新站点Cookie和UA时发生错误: {str(e)}"
|
||||
|
||||
239
app/agent/tools/impl/update_subscribe.py
Normal file
239
app/agent/tools/impl/update_subscribe.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""更新订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
|
||||
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update")
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
|
||||
|
||||
class UpdateSubscribeTool(MoviePilotTool):
|
||||
name: str = "update_subscribe"
|
||||
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
fields_updated = []
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("total_episode") is not None:
|
||||
fields_updated.append("总集数")
|
||||
if kwargs.get("lack_episode") is not None:
|
||||
fields_updated.append("缺失集数")
|
||||
if kwargs.get("quality"):
|
||||
fields_updated.append("质量过滤")
|
||||
if kwargs.get("resolution"):
|
||||
fields_updated.append("分辨率过滤")
|
||||
if kwargs.get("state"):
|
||||
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||
if kwargs.get("sites"):
|
||||
fields_updated.append("站点")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(self, subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
if year is not None:
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
if resolution is not None:
|
||||
subscribe_dict["resolution"] = resolution
|
||||
if effect is not None:
|
||||
subscribe_dict["effect"] = effect
|
||||
if include is not None:
|
||||
subscribe_dict["include"] = include
|
||||
if exclude is not None:
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||
}, ensure_ascii=False)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
if downloader is not None:
|
||||
subscribe_dict["downloader"] = downloader
|
||||
if save_path is not None:
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
if media_category is not None:
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys())
|
||||
}
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
"name": updated_subscribe.name,
|
||||
"year": updated_subscribe.year,
|
||||
"type": updated_subscribe.type,
|
||||
"season": updated_subscribe.season,
|
||||
"state": updated_subscribe.state,
|
||||
"total_episode": updated_subscribe.total_episode,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
272
app/agent/tools/manager.py
Normal file
272
app/agent/tools/manager.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""MoviePilot工具管理器
|
||||
用于HTTP API调用工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import ConversationMemoryManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.input_schema = input_schema
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""加载所有MoviePilot工具"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
memory_mananger=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {e}", exc_info=True)
|
||||
self.tools = []
|
||||
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""
|
||||
列出所有可用的工具
|
||||
|
||||
Returns:
|
||||
工具定义列表
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, 'args_schema', None)
|
||||
if args_schema:
|
||||
# 将Pydantic模型转换为JSON Schema
|
||||
input_schema = self._convert_to_json_schema(args_schema)
|
||||
else:
|
||||
# 如果没有args_schema,使用基本信息
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
|
||||
tools_list.append(ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema
|
||||
))
|
||||
|
||||
return tools_list
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Any]:
|
||||
"""
|
||||
获取指定工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具实例,如果未找到返回None
|
||||
"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
properties = schema.get("properties", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
if key not in properties:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
# 从 anyOf 中提取实际类型
|
||||
for type_option in any_of:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
|
||||
normalized[key] = None
|
||||
elif field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
|
||||
normalized[key] = None
|
||||
elif field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
elif isinstance(value, (int, float)):
|
||||
normalized[key] = value != 0
|
||||
else:
|
||||
normalized[key] = True
|
||||
else:
|
||||
normalized[key] = value
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果(字符串)
|
||||
"""
|
||||
tool_instance = self.get_tool(tool_name)
|
||||
|
||||
if not tool_instance:
|
||||
error_msg = json.dumps({
|
||||
"error": f"工具 '{tool_name}' 未找到"
|
||||
}, ensure_ascii=False)
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
try:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
|
||||
formated_result = str(result)
|
||||
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
return error_msg
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将Pydantic模型转换为JSON Schema
|
||||
|
||||
Args:
|
||||
args_schema: Pydantic模型类
|
||||
|
||||
Returns:
|
||||
JSON Schema字典
|
||||
"""
|
||||
# 获取Pydantic模型的字段信息
|
||||
schema = args_schema.model_json_schema()
|
||||
|
||||
# 构建JSON Schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
if "properties" in schema:
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
# 转换字段类型
|
||||
field_type = field_info.get("type", "string")
|
||||
field_description = field_info.get("description", "")
|
||||
|
||||
# 处理可选字段
|
||||
if field_name not in schema.get("required", []):
|
||||
# 可选字段
|
||||
default_value = field_info.get("default")
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
}
|
||||
if default_value is not None:
|
||||
properties[field_name]["default"] = default_value
|
||||
else:
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
}
|
||||
required.append(field_name)
|
||||
|
||||
# 处理枚举类型
|
||||
if "enum" in field_info:
|
||||
properties[field_name]["enum"] = field_info["enum"]
|
||||
|
||||
# 处理数组类型
|
||||
if field_type == "array" and "items" in field_info:
|
||||
properties[field_name]["items"] = field_info["items"]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
@@ -2,11 +2,12 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, monitoring
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
api_router.include_router(site.router, prefix="/site", tags=["site"])
|
||||
api_router.include_router(message.router, prefix="/message", tags=["message"])
|
||||
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])
|
||||
@@ -28,4 +29,4 @@ api_router.include_router(discover.router, prefix="/discover", tags=["discover"]
|
||||
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
|
||||
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
|
||||
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
|
||||
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
|
||||
|
||||
@@ -123,7 +123,7 @@ async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
查询下载器信息 API_TOKEN认证(?token=xxx)
|
||||
"""
|
||||
return schedule()
|
||||
return await schedule()
|
||||
|
||||
|
||||
@router.get("/transfer", summary="文件整理统计", response_model=List[int])
|
||||
@@ -137,7 +137,7 @@ async def transfer(days: Optional[int] = 7,
|
||||
return [stat[1] for stat in transfer_stat]
|
||||
|
||||
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=int)
|
||||
@router.get("/cpu", summary="获取当前CPU使用率", response_model=float)
|
||||
def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率
|
||||
@@ -145,7 +145,7 @@ def cpu(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return SystemUtils.cpu_usage()
|
||||
|
||||
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=int)
|
||||
@router.get("/cpu2", summary="获取当前CPU使用率(API_TOKEN)", response_model=float)
|
||||
def cpu2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
|
||||
"""
|
||||
获取当前CPU使用率 API_TOKEN认证(?token=xxx)
|
||||
|
||||
@@ -6,12 +6,13 @@ from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import MediaInfo, Context, TorrentInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.schemas.types import ChainEventType, SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -40,10 +41,10 @@ def download(
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 手动下载始终使用选择的下载器
|
||||
torrentinfo.site_downloader = downloader
|
||||
# 上下文
|
||||
@@ -64,7 +65,10 @@ def download(
|
||||
@router.post("/add", summary="添加下载(不含媒体信息)", response_model=schemas.Response)
|
||||
def add(
|
||||
torrent_in: schemas.TorrentInfo,
|
||||
tmdbid: Annotated[int | None, Body()] = None,
|
||||
doubanid: Annotated[str | None, Body()] = None,
|
||||
downloader: Annotated[str | None, Body()] = None,
|
||||
# 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
save_path: Annotated[str | None, Body()] = None,
|
||||
current_user: User = Depends(get_current_active_user)) -> Any:
|
||||
"""
|
||||
@@ -73,18 +77,23 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo)
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.dict())
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
# 上下文
|
||||
context = Context(
|
||||
meta_info=metainfo,
|
||||
media_info=mediainfo,
|
||||
torrent_info=torrentinfo
|
||||
)
|
||||
|
||||
did = DownloadChain().download_single(context=context, username=current_user.name,
|
||||
downloader=downloader, save_path=save_path, source="Manual")
|
||||
if not did:
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType, MediaType
|
||||
from app.schemas.types import EventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -70,7 +70,7 @@ async def transfer_history(title: Optional[str] = None,
|
||||
|
||||
return schemas.Response(success=True,
|
||||
data={
|
||||
"list": result,
|
||||
"list": [item.to_dict() for item in result],
|
||||
"total": total,
|
||||
})
|
||||
|
||||
@@ -90,7 +90,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
# 册除媒体库文件
|
||||
if deletedest and history.dest_fileitem:
|
||||
dest_fileitem = schemas.FileItem(**history.dest_fileitem)
|
||||
StorageChain().delete_media_file(fileitem=dest_fileitem, mtype=MediaType(history.type))
|
||||
StorageChain().delete_media_file(dest_fileitem)
|
||||
|
||||
# 删除源文件
|
||||
if deletesrc and history.src_fileitem:
|
||||
|
||||
@@ -8,8 +8,10 @@ from app import schemas
|
||||
from app.chain.user import UserChain
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.wallpaper import WallpaperHelper
|
||||
from app.helper.image import WallpaperHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -27,9 +29,19 @@ def login_access_token(
|
||||
mfa_code=otp_password)
|
||||
|
||||
if not success:
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
# 是否显示配置向导
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user_or_message.id,
|
||||
@@ -45,6 +57,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
wizard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
369
app/api/endpoints/mcp.py
Normal file
369
app/api/endpoints/mcp.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""工具API端点
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
# 导入版本号
|
||||
try:
|
||||
from version import APP_VERSION
|
||||
except ImportError:
|
||||
APP_VERSION = "unknown"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# MCP 协议版本
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 成功响应"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": result
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 错误响应"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
}
|
||||
if data is not None:
|
||||
error["error"]["data"] = data
|
||||
return error
|
||||
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
MCP 标准 JSON-RPC 2.0 端点
|
||||
|
||||
处理所有 MCP 协议消息(初始化、工具列表、工具调用等)
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析请求体失败: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(None, -32700, "Parse error", str(e))
|
||||
)
|
||||
|
||||
# 验证 JSON-RPC 格式
|
||||
if not isinstance(body, dict) or body.get("jsonrpc") != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(body.get("id"), -32600, "Invalid Request")
|
||||
)
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
# 如果有 id,则为请求;没有 id 则为通知
|
||||
is_notification = request_id is None
|
||||
|
||||
try:
|
||||
# 处理初始化请求
|
||||
if method == "initialize":
|
||||
result = await handle_initialize(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理已初始化通知
|
||||
elif method == "notifications/initialized":
|
||||
if is_notification:
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "initialized must be a notification"}
|
||||
)
|
||||
|
||||
# 处理工具列表请求
|
||||
if method == "tools/list":
|
||||
result = await handle_tools_list()
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理工具调用请求
|
||||
elif method == "tools/call":
|
||||
result = await handle_tools_call(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理 ping 请求
|
||||
elif method == "ping":
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, {}))
|
||||
|
||||
# 未知方法
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"MCP 请求参数错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_jsonrpc_error(request_id, -32603, "Internal error", str(e))
|
||||
)
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
|
||||
|
||||
# 版本协商:选择客户端和服务器都支持的版本
|
||||
negotiated_version = MCP_PROTOCOL_VERSION
|
||||
if protocol_version in MCP_PROTOCOL_VERSIONS:
|
||||
# 客户端版本在支持列表中,使用客户端版本
|
||||
negotiated_version = protocol_version
|
||||
logger.info(f"使用客户端协议版本: {negotiated_version}")
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"logging": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
for tool in tools:
|
||||
mcp_tool = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
mcp_tools.append(mcp_tool)
|
||||
|
||||
return {
|
||||
"tools": mcp_tools
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result_text
|
||||
}
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"错误: {str(e)}"
|
||||
}
|
||||
],
|
||||
"isError": True
|
||||
}
|
||||
|
||||
|
||||
@router.delete("", summary="终止 MCP 会话", response_model=None)
|
||||
async def delete_mcp_session(
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
终止 MCP 会话(无状态模式下仅返回成功)
|
||||
"""
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
async def list_tools(
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有可用的工具列表
|
||||
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具定义
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
for tool in tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
tools_list.append(tool_dict)
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
|
||||
async def call_tool(
|
||||
request: schemas.ToolCallRequest,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定的工具
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 使用当前用户ID创建管理器实例
|
||||
manager = get_tools_manager()
|
||||
|
||||
# 调用工具
|
||||
result_text = await manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
result=result_text
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {request.tool_name} 失败: {e}", exc_info=True)
|
||||
return schemas.ToolCallResponse(
|
||||
success=False,
|
||||
error=f"调用工具失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools/{tool_name}", summary="获取工具详情", response_model=Dict[str, Any])
|
||||
async def get_tool_info(
|
||||
tool_name: str,
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定工具的详细信息
|
||||
|
||||
Returns:
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"工具 '{tool_name}' 未找到")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具信息失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具信息失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tools/{tool_name}/schema", summary="获取工具参数Schema", response_model=Dict[str, Any])
|
||||
async def get_tool_schema(
|
||||
tool_name: str,
|
||||
_: Annotated[str, Depends(verify_apikey)]
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定工具的参数Schema(JSON Schema格式)
|
||||
|
||||
Returns:
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return tool.input_schema
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"工具 '{tool_name}' 未找到")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具Schema失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"获取工具Schema失败: {str(e)}")
|
||||
@@ -85,25 +85,26 @@ async def search(title: str,
|
||||
return obj.get("source")
|
||||
return obj.source
|
||||
|
||||
result = []
|
||||
media_chain = MediaChain()
|
||||
if type == "media":
|
||||
_, medias = await media_chain.async_search(title=title)
|
||||
if medias:
|
||||
result = [media.to_dict() for media in medias]
|
||||
result = [media.to_dict() for media in medias] if medias else []
|
||||
elif type == "collection":
|
||||
result = await media_chain.async_search_collections(name=title)
|
||||
else:
|
||||
result = await media_chain.async_search_persons(name=title)
|
||||
if result:
|
||||
# 按设置的顺序对结果进行排序
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') or []
|
||||
sort_order = {}
|
||||
for index, source in enumerate(setting_order):
|
||||
sort_order[source] = index
|
||||
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return result[(page - 1) * count:page * count]
|
||||
return []
|
||||
collections = await media_chain.async_search_collections(name=title)
|
||||
result = [collection.to_dict() for collection in collections] if collections else []
|
||||
else: # person
|
||||
persons = await media_chain.async_search_persons(name=title)
|
||||
result = [person.model_dump() for person in persons] if persons else []
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
# 排序和分页
|
||||
setting_order = settings.SEARCH_SOURCE.split(',') if settings.SEARCH_SOURCE else []
|
||||
sort_order = {source: index for index, source in enumerate(setting_order)}
|
||||
|
||||
sorted_result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
|
||||
return sorted_result[(page - 1) * count:page * count]
|
||||
|
||||
|
||||
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
|
||||
|
||||
@@ -79,10 +79,10 @@ def exists(media_in: schemas.MediaInfo,
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return []
|
||||
return {}
|
||||
if media_in.season:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
@@ -108,7 +108,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
meta.year = media_in.year
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.dict())
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
|
||||
@@ -132,7 +132,7 @@ async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload
|
||||
"""
|
||||
客户端webpush通知订阅
|
||||
"""
|
||||
subinfo = subscription.dict()
|
||||
subinfo = subscription.model_dump()
|
||||
if subinfo not in global_vars.get_subscriptions():
|
||||
global_vars.push_subscription(subinfo)
|
||||
logger.debug(f"通知订阅成功: {subinfo}")
|
||||
@@ -148,7 +148,7 @@ def send_notification(payload: schemas.SubscriptionMessage, _: schemas.TokenPayl
|
||||
try:
|
||||
webpush(
|
||||
subscription_info=sub,
|
||||
data=json.dumps(payload.dict()),
|
||||
data=json.dumps(payload.model_dump()),
|
||||
vapid_private_key=settings.VAPID.get("privateKey"),
|
||||
vapid_claims={
|
||||
"sub": settings.VAPID.get("subject")
|
||||
|
||||
463
app/api/endpoints/mfa.py
Normal file
463
app/api/endpoints/mfa.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""
|
||||
MFA (Multi-Factor Authentication) API 端点
|
||||
包含 OTP 和 PassKey 相关功能
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db import get_async_db
|
||||
from app.db.models.passkey import PassKey
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_user_async
|
||||
from app.helper.passkey import PassKeyHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
"""OTP验证请求"""
|
||||
uri: str
|
||||
otpPassword: str
|
||||
|
||||
class OtpDisableRequest(schemas.BaseModel):
|
||||
"""OTP禁用请求"""
|
||||
password: str
|
||||
|
||||
class PassKeyDeleteRequest(schemas.BaseModel):
|
||||
"""PassKey删除请求"""
|
||||
passkey_id: int
|
||||
password: str
|
||||
|
||||
# ==================== 通用 MFA 接口 ====================
|
||||
|
||||
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
|
||||
async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
检查指定用户是否启用了任何双重验证方式(OTP 或 PassKey)
|
||||
"""
|
||||
user: User = await User.async_get_by_name(db, username)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
|
||||
# 检查是否启用了OTP
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
|
||||
|
||||
# ==================== OTP 相关接口 ====================
|
||||
|
||||
@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""生成 OTP 密钥及对应的 URI"""
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
|
||||
async def otp_verify(
|
||||
data: OtpVerifyRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
|
||||
if not OtpUtils.is_legal(data.uri, data.otpPassword):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
data: OtpDisableRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,不允许关闭 OTP
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
|
||||
if has_passkey:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
# ==================== PassKey 相关接口 ====================
|
||||
|
||||
class PassKeyRegistrationStart(schemas.BaseModel):
|
||||
"""PassKey注册开始请求"""
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyRegistrationFinish(schemas.BaseModel):
|
||||
"""PassKey注册完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyAuthenticationStart(schemas.BaseModel):
|
||||
"""PassKey认证开始请求"""
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class PassKeyAuthenticationFinish(schemas.BaseModel):
|
||||
"""PassKey认证完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
|
||||
|
||||
@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_start(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""开始注册 PassKey - 生成注册选项"""
|
||||
try:
|
||||
# 安全检查:必须先启用 OTP
|
||||
if not current_user.is_otp:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
|
||||
)
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
] if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
user_id=current_user.id,
|
||||
username=current_user.name,
|
||||
display_name=current_user.settings.get('nickname') if current_user.settings else None,
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey注册选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成注册选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_finish(
|
||||
passkey_req: PassKeyRegistrationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""完成注册 PassKey - 验证并保存凭证"""
|
||||
try:
|
||||
# 验证注册响应
|
||||
credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge
|
||||
)
|
||||
|
||||
# 提取transports
|
||||
transports = None
|
||||
if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']:
|
||||
transports = ','.join(passkey_req.credential['response']['transports'])
|
||||
|
||||
# 保存到数据库
|
||||
passkey = PassKey(
|
||||
user_id=current_user.id,
|
||||
credential_id=credential_id,
|
||||
public_key=public_key,
|
||||
sign_count=sign_count,
|
||||
name=passkey_req.name or "通行密钥",
|
||||
aaguid=aaguid,
|
||||
transports=transports
|
||||
)
|
||||
passkey.create()
|
||||
|
||||
logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥注册成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"注册PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"注册失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response)
|
||||
def passkey_authenticate_start(
|
||||
passkey_req: PassKeyAuthenticationStart = Body(...)
|
||||
) -> Any:
|
||||
"""开始 PassKey 认证 - 生成认证选项"""
|
||||
try:
|
||||
existing_credentials = None
|
||||
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
if not user:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="用户不存在"
|
||||
)
|
||||
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
|
||||
if not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="该用户未注册通行密钥"
|
||||
)
|
||||
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
]
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成认证选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token)
|
||||
def passkey_authenticate_finish(
|
||||
passkey_req: PassKeyAuthenticationFinish
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise HTTPException(status_code=400, detail="无效的凭证")
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
|
||||
|
||||
# 获取用户
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在或已禁用")
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="通行密钥验证失败")
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
# 生成token
|
||||
level = SitesHelper().auth_level
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
level=level
|
||||
),
|
||||
token_type="bearer",
|
||||
super_user=user.is_superuser,
|
||||
user_id=user.id,
|
||||
user_name=user.name,
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
wizard=show_wizard
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
def passkey_list(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""获取当前用户的所有 PassKey"""
|
||||
try:
|
||||
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
|
||||
key_list = [
|
||||
{
|
||||
'id': pk.id,
|
||||
'name': pk.name,
|
||||
'created_at': pk.created_at.isoformat() if pk.created_at else None,
|
||||
'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None,
|
||||
'aaguid': pk.aaguid,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=key_list
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取PassKey列表失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"获取列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
|
||||
async def passkey_delete(
|
||||
data: PassKeyDeleteRequest,
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""删除指定的 PassKey"""
|
||||
try:
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
|
||||
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥已删除"
|
||||
)
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或无权删除"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"删除PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"删除失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response)
|
||||
def passkey_verify_mfa(
|
||||
passkey_req: PassKeyAuthenticationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="无效的凭证"
|
||||
)
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey or passkey.user_id != current_user.id:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if not success:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="二次验证成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
)
|
||||
@@ -1,409 +0,0 @@
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app import schemas
|
||||
from app.core.security import verify_apitoken
|
||||
from app.monitoring import monitor, get_metrics_response
|
||||
from app.schemas.monitoring import (
|
||||
PerformanceSnapshot,
|
||||
EndpointStats,
|
||||
ErrorRequest,
|
||||
MonitoringOverview
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/overview", summary="获取监控概览", response_model=schemas.MonitoringOverview)
|
||||
def get_overview(_: str = Depends(verify_apitoken)) -> Any:
|
||||
"""
|
||||
获取完整的监控概览信息
|
||||
"""
|
||||
# 获取性能快照
|
||||
performance = monitor.get_performance_snapshot()
|
||||
|
||||
# 获取最活跃端点
|
||||
top_endpoints = monitor.get_top_endpoints(limit=10)
|
||||
|
||||
# 获取最近错误
|
||||
recent_errors = monitor.get_recent_errors(limit=20)
|
||||
|
||||
# 检查告警
|
||||
alerts = monitor.check_alerts()
|
||||
|
||||
return MonitoringOverview(
|
||||
performance=PerformanceSnapshot(
|
||||
timestamp=performance.timestamp,
|
||||
cpu_usage=performance.cpu_usage,
|
||||
memory_usage=performance.memory_usage,
|
||||
active_requests=performance.active_requests,
|
||||
request_rate=performance.request_rate,
|
||||
avg_response_time=performance.avg_response_time,
|
||||
error_rate=performance.error_rate,
|
||||
slow_requests=performance.slow_requests
|
||||
),
|
||||
top_endpoints=[EndpointStats(**endpoint) for endpoint in top_endpoints],
|
||||
recent_errors=[ErrorRequest(**error) for error in recent_errors],
|
||||
alerts=alerts
|
||||
)
|
||||
|
||||
|
||||
@router.get("/performance", summary="获取性能快照", response_model=schemas.PerformanceSnapshot)
|
||||
def get_performance(_: str = Depends(verify_apitoken)) -> Any:
|
||||
"""
|
||||
获取当前性能快照
|
||||
"""
|
||||
snapshot = monitor.get_performance_snapshot()
|
||||
return PerformanceSnapshot(
|
||||
timestamp=snapshot.timestamp,
|
||||
cpu_usage=snapshot.cpu_usage,
|
||||
memory_usage=snapshot.memory_usage,
|
||||
active_requests=snapshot.active_requests,
|
||||
request_rate=snapshot.request_rate,
|
||||
avg_response_time=snapshot.avg_response_time,
|
||||
error_rate=snapshot.error_rate,
|
||||
slow_requests=snapshot.slow_requests
|
||||
)
|
||||
|
||||
|
||||
@router.get("/endpoints", summary="获取端点统计", response_model=List[schemas.EndpointStats])
|
||||
def get_endpoints(
|
||||
limit: int = Query(10, ge=1, le=50, description="返回的端点数量"),
|
||||
_: str = Depends(verify_apitoken)
|
||||
) -> Any:
|
||||
"""
|
||||
获取最活跃的API端点统计
|
||||
"""
|
||||
endpoints = monitor.get_top_endpoints(limit=limit)
|
||||
return [EndpointStats(**endpoint) for endpoint in endpoints]
|
||||
|
||||
|
||||
@router.get("/errors", summary="获取错误请求", response_model=List[schemas.ErrorRequest])
|
||||
def get_errors(
|
||||
limit: int = Query(20, ge=1, le=100, description="返回的错误数量"),
|
||||
_: str = Depends(verify_apitoken)
|
||||
) -> Any:
|
||||
"""
|
||||
获取最近的错误请求记录
|
||||
"""
|
||||
errors = monitor.get_recent_errors(limit=limit)
|
||||
return [ErrorRequest(**error) for error in errors]
|
||||
|
||||
|
||||
@router.get("/alerts", summary="获取告警信息", response_model=List[str])
|
||||
def get_alerts(_: str = Depends(verify_apitoken)) -> Any:
|
||||
"""
|
||||
获取当前告警信息
|
||||
"""
|
||||
return monitor.check_alerts()
|
||||
|
||||
|
||||
@router.get("/metrics", summary="Prometheus指标")
|
||||
def get_prometheus_metrics(_: str = Depends(verify_apitoken)) -> Any:
|
||||
"""
|
||||
获取Prometheus格式的监控指标
|
||||
"""
|
||||
return get_metrics_response()
|
||||
|
||||
|
||||
@router.get("/dashboard", summary="监控仪表板", response_class=HTMLResponse)
|
||||
def get_dashboard(_: str = Depends(verify_apitoken)) -> Any:
|
||||
"""
|
||||
获取实时监控仪表板HTML页面
|
||||
"""
|
||||
return HTMLResponse(content="""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>MoviePilot 性能监控仪表板</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
color: #333;
|
||||
}
|
||||
.metrics-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
|
||||
gap: 20px;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.metric-card {
|
||||
background: white;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
text-align: center;
|
||||
}
|
||||
.metric-value {
|
||||
font-size: 2em;
|
||||
font-weight: bold;
|
||||
color: #2196F3;
|
||||
}
|
||||
.metric-label {
|
||||
color: #666;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.chart-container {
|
||||
background: white;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.alerts {
|
||||
background: #fff3cd;
|
||||
border: 1px solid #ffeaa7;
|
||||
border-radius: 5px;
|
||||
padding: 15px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.alert-item {
|
||||
color: #856404;
|
||||
margin: 5px 0;
|
||||
}
|
||||
.refresh-btn {
|
||||
background: #2196F3;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.refresh-btn:hover {
|
||||
background: #1976D2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>🎬 MoviePilot 性能监控仪表板</h1>
|
||||
<button class="refresh-btn" onclick="refreshData()">刷新数据</button>
|
||||
</div>
|
||||
|
||||
<div id="alerts" class="alerts" style="display: none;">
|
||||
<h3>⚠️ 告警信息</h3>
|
||||
<div id="alerts-list"></div>
|
||||
</div>
|
||||
|
||||
<div class="metrics-grid">
|
||||
<div class="metric-card">
|
||||
<div class="metric-value" id="cpu-usage">--</div>
|
||||
<div class="metric-label">CPU使用率 (%)</div>
|
||||
</div>
|
||||
<div class="metric-card">
|
||||
<div class="metric-value" id="memory-usage">--</div>
|
||||
<div class="metric-label">内存使用率 (%)</div>
|
||||
</div>
|
||||
<div class="metric-card">
|
||||
<div class="metric-value" id="active-requests">--</div>
|
||||
<div class="metric-label">活跃请求数</div>
|
||||
</div>
|
||||
<div class="metric-card">
|
||||
<div class="metric-value" id="request-rate">--</div>
|
||||
<div class="metric-label">请求率 (req/min)</div>
|
||||
</div>
|
||||
<div class="metric-card">
|
||||
<div class="metric-value" id="avg-response-time">--</div>
|
||||
<div class="metric-label">平均响应时间 (s)</div>
|
||||
</div>
|
||||
<div class="metric-card">
|
||||
<div class="metric-value" id="error-rate">--</div>
|
||||
<div class="metric-label">错误率 (%)</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="chart-container">
|
||||
<h3>📊 性能趋势</h3>
|
||||
<canvas id="performanceChart" width="400" height="200"></canvas>
|
||||
</div>
|
||||
|
||||
<div class="chart-container">
|
||||
<h3>🔥 最活跃端点</h3>
|
||||
<canvas id="endpointsChart" width="400" height="200"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let performanceChart, endpointsChart;
|
||||
let performanceData = {
|
||||
labels: [],
|
||||
cpu: [],
|
||||
memory: [],
|
||||
requests: []
|
||||
};
|
||||
|
||||
// 初始化图表
|
||||
function initCharts() {
|
||||
const ctx1 = document.getElementById('performanceChart').getContext('2d');
|
||||
performanceChart = new Chart(ctx1, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: performanceData.labels,
|
||||
datasets: [{
|
||||
label: 'CPU使用率 (%)',
|
||||
data: performanceData.cpu,
|
||||
borderColor: '#2196F3',
|
||||
backgroundColor: 'rgba(33, 150, 243, 0.1)',
|
||||
tension: 0.4
|
||||
}, {
|
||||
label: '内存使用率 (%)',
|
||||
data: performanceData.memory,
|
||||
borderColor: '#4CAF50',
|
||||
backgroundColor: 'rgba(76, 175, 80, 0.1)',
|
||||
tension: 0.4
|
||||
}, {
|
||||
label: '活跃请求数',
|
||||
data: performanceData.requests,
|
||||
borderColor: '#FF9800',
|
||||
backgroundColor: 'rgba(255, 152, 0, 0.1)',
|
||||
tension: 0.4
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
scales: {
|
||||
y: {
|
||||
beginAtZero: true
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
const ctx2 = document.getElementById('endpointsChart').getContext('2d');
|
||||
endpointsChart = new Chart(ctx2, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: [],
|
||||
datasets: [{
|
||||
label: '请求数',
|
||||
data: [],
|
||||
backgroundColor: 'rgba(33, 150, 243, 0.8)'
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
scales: {
|
||||
y: {
|
||||
beginAtZero: true
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 更新性能数据
|
||||
function updatePerformanceData(data) {
|
||||
const now = new Date().toLocaleTimeString();
|
||||
|
||||
performanceData.labels.push(now);
|
||||
performanceData.cpu.push(data.performance.cpu_usage);
|
||||
performanceData.memory.push(data.performance.memory_usage);
|
||||
performanceData.requests.push(data.performance.active_requests);
|
||||
|
||||
// 保持最近20个数据点
|
||||
if (performanceData.labels.length > 20) {
|
||||
performanceData.labels.shift();
|
||||
performanceData.cpu.shift();
|
||||
performanceData.memory.shift();
|
||||
performanceData.requests.shift();
|
||||
}
|
||||
|
||||
// 更新图表
|
||||
performanceChart.data.labels = performanceData.labels;
|
||||
performanceChart.data.datasets[0].data = performanceData.cpu;
|
||||
performanceChart.data.datasets[1].data = performanceData.memory;
|
||||
performanceChart.data.datasets[2].data = performanceData.requests;
|
||||
performanceChart.update();
|
||||
|
||||
// 更新端点图表
|
||||
const endpointLabels = data.top_endpoints.map(e => e.endpoint.substring(0, 20));
|
||||
const endpointData = data.top_endpoints.map(e => e.count);
|
||||
|
||||
endpointsChart.data.labels = endpointLabels;
|
||||
endpointsChart.data.datasets[0].data = endpointData;
|
||||
endpointsChart.update();
|
||||
}
|
||||
|
||||
// 更新指标显示
|
||||
function updateMetrics(data) {
|
||||
document.getElementById('cpu-usage').textContent = data.performance.cpu_usage.toFixed(1);
|
||||
document.getElementById('memory-usage').textContent = data.performance.memory_usage.toFixed(1);
|
||||
document.getElementById('active-requests').textContent = data.performance.active_requests;
|
||||
document.getElementById('request-rate').textContent = data.performance.request_rate.toFixed(0);
|
||||
document.getElementById('avg-response-time').textContent = data.performance.avg_response_time.toFixed(3);
|
||||
document.getElementById('error-rate').textContent = (data.performance.error_rate * 100).toFixed(2);
|
||||
}
|
||||
|
||||
// 更新告警
|
||||
function updateAlerts(alerts) {
|
||||
const alertsDiv = document.getElementById('alerts');
|
||||
const alertsList = document.getElementById('alerts-list');
|
||||
|
||||
if (alerts.length > 0) {
|
||||
alertsDiv.style.display = 'block';
|
||||
alertsList.innerHTML = alerts.map(alert =>
|
||||
`<div class="alert-item">⚠️ ${alert}</div>`
|
||||
).join('');
|
||||
} else {
|
||||
alertsDiv.style.display = 'none';
|
||||
}
|
||||
}
|
||||
|
||||
// 获取URL中的token参数
|
||||
function getTokenFromUrl() {
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
return urlParams.get('token');
|
||||
}
|
||||
|
||||
// 刷新数据
|
||||
async function refreshData() {
|
||||
try {
|
||||
const token = getTokenFromUrl();
|
||||
if (!token) {
|
||||
console.error('未找到token参数');
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(`/api/v1/monitoring/overview?token=${token}`);
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
updateMetrics(data);
|
||||
updatePerformanceData(data);
|
||||
updateAlerts(data.alerts);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取监控数据失败:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 页面加载完成后初始化
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
initCharts();
|
||||
refreshData();
|
||||
|
||||
// 每5秒自动刷新
|
||||
setInterval(refreshData, 5000);
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
@@ -463,6 +463,36 @@ async def update_folder_plugins(folder_name: str, plugin_ids: List[str],
|
||||
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新")
|
||||
|
||||
|
||||
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
|
||||
def clone_plugin(plugin_id: str,
|
||||
clone_data: dict,
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
创建插件分身
|
||||
"""
|
||||
try:
|
||||
success, message = PluginManager().clone_plugin(
|
||||
plugin_id=plugin_id,
|
||||
suffix=clone_data.get("suffix", ""),
|
||||
name=clone_data.get("name", ""),
|
||||
description=clone_data.get("description", ""),
|
||||
version=clone_data.get("version", ""),
|
||||
icon=clone_data.get("icon", "")
|
||||
)
|
||||
|
||||
if success:
|
||||
# 注册插件服务
|
||||
reload_plugin(message)
|
||||
# 将分身插件添加到原插件所在的文件夹中
|
||||
_add_clone_to_plugin_folder(plugin_id, message)
|
||||
return schemas.Response(success=True, message="插件分身创建成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message=message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件分身失败:{str(e)}")
|
||||
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", summary="获取插件配置")
|
||||
async def plugin_config(plugin_id: str,
|
||||
_: User = Depends(get_current_active_superuser_async)) -> dict:
|
||||
@@ -528,36 +558,6 @@ def uninstall_plugin(plugin_id: str,
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
|
||||
def clone_plugin(plugin_id: str,
|
||||
clone_data: dict,
|
||||
_: User = Depends(get_current_active_superuser)) -> Any:
|
||||
"""
|
||||
创建插件分身
|
||||
"""
|
||||
try:
|
||||
success, message = PluginManager().clone_plugin(
|
||||
plugin_id=plugin_id,
|
||||
suffix=clone_data.get("suffix", ""),
|
||||
name=clone_data.get("name", ""),
|
||||
description=clone_data.get("description", ""),
|
||||
version=clone_data.get("version", ""),
|
||||
icon=clone_data.get("icon", "")
|
||||
)
|
||||
|
||||
if success:
|
||||
# 注册插件服务
|
||||
reload_plugin(message)
|
||||
# 将分身插件添加到原插件所在的文件夹中
|
||||
_add_clone_to_plugin_folder(plugin_id, message)
|
||||
return schemas.Response(success=True, message="插件分身创建成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message=message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件分身失败:{str(e)}")
|
||||
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
|
||||
|
||||
|
||||
def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
|
||||
"""
|
||||
将分身插件添加到原插件所在的文件夹中
|
||||
|
||||
@@ -20,7 +20,7 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询搜索结果
|
||||
"""
|
||||
torrents = await SearchChain().async_last_search_results()
|
||||
torrents = await SearchChain().async_last_search_results() or []
|
||||
return [torrent.to_dict() for torrent in torrents]
|
||||
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ async def add_site(
|
||||
site_in.name = site_info.get("name")
|
||||
site_in.id = None
|
||||
site_in.public = 1 if site_info.get("public") else 0
|
||||
site = Site(**site_in.dict())
|
||||
site = Site(**site_in.model_dump())
|
||||
site.create(db)
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
@@ -92,7 +92,7 @@ async def update_site(
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
await site.async_update(db, site_in.dict())
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
@@ -219,10 +219,10 @@ async def read_userdata(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
|
||||
if not user_data:
|
||||
user_datas = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
|
||||
if not user_datas:
|
||||
return schemas.Response(success=False, data=[])
|
||||
return schemas.Response(success=True, data=user_data)
|
||||
return schemas.Response(success=True, data=[data.to_dict() for data in user_datas])
|
||||
|
||||
|
||||
@router.get("/test/{site_id}", summary="连接测试", response_model=schemas.Response)
|
||||
@@ -399,7 +399,7 @@ def auth_site(
|
||||
if not auth_info or not auth_info.site or not auth_info.params:
|
||||
return schemas.Response(success=False, message="请输入认证站点和认证参数")
|
||||
status, msg = SitesHelper().check_user(auth_info.site, auth_info.params)
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.dict())
|
||||
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.model_dump())
|
||||
# 认证成功后,重新初始化插件
|
||||
PluginManager().init_config()
|
||||
Scheduler().init_plugin_jobs()
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.schemas.types import ProgressKey
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -80,7 +81,7 @@ def list_files(fileitem: schemas.FileItem,
|
||||
file_list = StorageChain().list_files(fileitem)
|
||||
if file_list:
|
||||
if sort == "name":
|
||||
file_list.sort(key=lambda x: x.name or "")
|
||||
file_list.sort(key=lambda x: StringUtils.natural_sort_key(x.name or ""))
|
||||
else:
|
||||
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
|
||||
return file_list
|
||||
@@ -171,15 +172,14 @@ def rename(fileitem: schemas.FileItem,
|
||||
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
|
||||
if sub_files:
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.BatchRename)
|
||||
progress = ProgressHelper(ProgressKey.BatchRename)
|
||||
progress.start()
|
||||
total = len(sub_files)
|
||||
handled = 0
|
||||
for sub_file in sub_files:
|
||||
handled += 1
|
||||
progress.update(value=handled / total * 100,
|
||||
text=f"正在处理 {sub_file.name} ...",
|
||||
key=ProgressKey.BatchRename)
|
||||
text=f"正在处理 {sub_file.name} ...")
|
||||
if sub_file.type == "dir":
|
||||
continue
|
||||
if not sub_file.extension:
|
||||
@@ -190,19 +190,19 @@ def rename(fileitem: schemas.FileItem,
|
||||
meta = MetaInfoPath(sub_path)
|
||||
mediainfo = transferchain.recognize_media(meta)
|
||||
if not mediainfo:
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
|
||||
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
|
||||
if not new_path:
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
|
||||
ret: schemas.Response = rename(fileitem=sub_file,
|
||||
new_name=Path(new_path).name,
|
||||
recursive=False)
|
||||
if not ret.success:
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
|
||||
progress.end(ProgressKey.BatchRename)
|
||||
progress.end()
|
||||
# 重命名自己
|
||||
result = StorageChain().rename_file(fileitem, new_name)
|
||||
if result:
|
||||
|
||||
@@ -79,7 +79,7 @@ async def create_subscribe(
|
||||
# 订阅用户
|
||||
subscribe_in.username = current_user.name
|
||||
# 转化为字典
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if subscribe_in.id:
|
||||
subscribe_dict.pop("id", None)
|
||||
sid, message = await SubscribeChain().async_add(mtype=mtype,
|
||||
@@ -106,7 +106,7 @@ async def update_subscribe(
|
||||
return schemas.Response(success=False, message="订阅不存在")
|
||||
# 避免更新缺失集数
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
subscribe_dict = subscribe_in.dict()
|
||||
subscribe_dict = subscribe_in.model_dump()
|
||||
if not subscribe_in.lack_episode:
|
||||
# 没有缺失集数时,缺失集数清空,避免更新为0
|
||||
subscribe_dict.pop("lack_episode")
|
||||
@@ -421,11 +421,23 @@ async def popular_subscribes(
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询热门订阅
|
||||
"""
|
||||
subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count)
|
||||
subscribes = await SubscribeHelper().async_get_statistic(
|
||||
stype=stype,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
if subscribes:
|
||||
ret_medias = []
|
||||
for sub in subscribes:
|
||||
@@ -517,7 +529,7 @@ async def subscribe_fork(
|
||||
"""
|
||||
复用订阅
|
||||
"""
|
||||
sub_dict = sub.dict()
|
||||
sub_dict = sub.model_dump()
|
||||
sub_dict.pop("id")
|
||||
for key in list(sub_dict.keys()):
|
||||
if not hasattr(schemas.Subscribe(), key):
|
||||
@@ -570,11 +582,23 @@ async def popular_subscribes(
|
||||
name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
genre_id: Optional[int] = None,
|
||||
min_rating: Optional[float] = None,
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询分享的订阅
|
||||
"""
|
||||
return await SubscribeHelper().async_get_shares(name=name, page=page, count=count)
|
||||
return await SubscribeHelper().async_get_shares(
|
||||
name=name,
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
min_rating=min_rating,
|
||||
max_rating=max_rating,
|
||||
sort_type=sort_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections import deque
|
||||
@@ -8,13 +7,13 @@ from typing import Optional, Union, Annotated
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.config import global_vars, settings
|
||||
@@ -26,12 +25,14 @@ from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.helper.system import SystemHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas import ConfigChangeEventData
|
||||
@@ -47,101 +48,43 @@ router = APIRouter()
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: bool = False,
|
||||
use_disk_cache: bool = False,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Response:
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
"""
|
||||
if not url:
|
||||
raise HTTPException(status_code=404, detail="URL not provided")
|
||||
return None
|
||||
|
||||
if allowed_domains is None:
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS)
|
||||
|
||||
# 验证URL安全性
|
||||
if not SecurityUtils.is_safe_url(url, allowed_domains):
|
||||
raise HTTPException(status_code=404, detail="Unsafe URL")
|
||||
logger.warn(f"Blocked unsafe image URL: {url}")
|
||||
return None
|
||||
|
||||
# 后续观察系统性能表现,如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求,可以考虑重新引入内存缓存
|
||||
cache_path: Optional[AsyncPath] = None
|
||||
if use_disk_cache:
|
||||
# 生成缓存路径
|
||||
base_path = AsyncPath(settings.CACHE_PATH)
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = base_path / "images" / sanitized_path
|
||||
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 确保缓存路径和文件类型合法
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path,
|
||||
user_path=cache_path,
|
||||
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
|
||||
raise HTTPException(status_code=400, detail="Invalid cache path or file type")
|
||||
|
||||
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
|
||||
if cache_path and await cache_path.exists():
|
||||
try:
|
||||
async with aiofiles.open(cache_path, 'rb') as f:
|
||||
content = await f.read()
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
return Response(content=content, media_type="image/jpeg", headers=headers)
|
||||
except Exception as e:
|
||||
# 如果读取磁盘缓存发生异常,这里仅记录日志,尝试再次请求远端进行处理
|
||||
logger.debug(f"Failed to read cache file {cache_path}: {e}")
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if proxy else None
|
||||
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
|
||||
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
|
||||
if not response:
|
||||
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
content = response.content
|
||||
Image.open(io.BytesIO(content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
raise HTTPException(status_code=502, detail="Invalid image format")
|
||||
|
||||
response_headers = response.headers
|
||||
|
||||
cache_control_header = response_headers.get("Cache-Control", "")
|
||||
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
|
||||
|
||||
# 如果需要使用磁盘缓存,则保存到磁盘
|
||||
if use_disk_cache and cache_path:
|
||||
try:
|
||||
if not await cache_path.parent.exists():
|
||||
await cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
|
||||
await tmp_file.write(content)
|
||||
temp_path = AsyncPath(tmp_file.name)
|
||||
await temp_path.replace(cache_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to write cache file {cache_path}: {e}")
|
||||
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
if if_none_match == etag:
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
return Response(status_code=304, headers=headers)
|
||||
|
||||
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
content = await ImageHelper().async_fetch_image(
|
||||
url=url,
|
||||
proxy=proxy,
|
||||
use_cache=use_cache,
|
||||
cookies=cookies,
|
||||
)
|
||||
if content:
|
||||
# 检查 If-None-Match
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
if if_none_match == etag:
|
||||
return Response(status_code=304, headers=headers)
|
||||
# 返回缓存图片
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
@@ -149,6 +92,7 @@ async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
) -> Response:
|
||||
@@ -159,7 +103,12 @@ async def proxy_img(
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_disk_cache=cache,
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
|
||||
|
||||
@@ -173,8 +122,7 @@ async def cache_img(
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
proxy = "doubanio.com" not in url
|
||||
return await fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
|
||||
|
||||
@@ -186,18 +134,24 @@ def get_global_setting(token: str):
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
info = settings.dict(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
# 白名单模式,仅包含前端业务逻辑必需的字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE"
|
||||
}
|
||||
)
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
@@ -208,7 +162,7 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.dict(
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
)
|
||||
info.update({
|
||||
@@ -243,13 +197,11 @@ async def set_env_setting(env: dict,
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
for key in success_updates.keys():
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=getattr(settings, key, None),
|
||||
change_type="update"
|
||||
))
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
@@ -265,14 +217,14 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
|
||||
"""
|
||||
实时获取处理进度,返回格式为SSE
|
||||
"""
|
||||
progress = ProgressHelper()
|
||||
progress = ProgressHelper(process_type)
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
while not global_vars.is_system_stopped:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
detail = progress.get(process_type)
|
||||
detail = progress.get()
|
||||
yield f"data: {json.dumps(detail)}\n\n"
|
||||
await asyncio.sleep(0.5)
|
||||
except asyncio.CancelledError:
|
||||
@@ -334,6 +286,18 @@ async def set_setting(
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = LLMHelper().get_models(provider, api_key, base_url)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(request: Request, role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
|
||||
@@ -135,8 +135,8 @@ def refresh_cache(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
@router.post("/cache/reidentify/{domain}/{torrent_hash}", summary="重新识别种子", response_model=schemas.Response)
|
||||
async def reidentify_cache(domain: str, torrent_hash: str,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
重新识别指定的种子
|
||||
:param domain: 站点域名
|
||||
|
||||
@@ -8,7 +8,7 @@ from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.storage import StorageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.config import settings
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db import get_db
|
||||
@@ -75,6 +75,8 @@ async def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Dep
|
||||
:param _: Token校验
|
||||
"""
|
||||
TransferChain().remove_from_queue(fileitem)
|
||||
# 取消整理
|
||||
global_vars.stop_transfer(fileitem.path)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -109,7 +111,7 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
if history.dest_fileitem:
|
||||
# 删除旧的已整理文件
|
||||
dest_fileitem = FileItem(**history.dest_fileitem)
|
||||
state = StorageChain().delete_media_file(dest_fileitem, mtype=MediaType(history.type))
|
||||
state = StorageChain().delete_media_file(dest_fileitem)
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=f"{dest_fileitem.path} 删除失败")
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ async def create_user(
|
||||
user = await current_user.async_get_by_name(db, name=user_in.name)
|
||||
if user:
|
||||
return schemas.Response(success=False, message="用户已存在")
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
user_info["hashed_password"] = get_password_hash(user_info["password"])
|
||||
user_info.pop("password")
|
||||
@@ -59,7 +59,7 @@ async def update_user(
|
||||
"""
|
||||
更新用户
|
||||
"""
|
||||
user_info = user_in.dict()
|
||||
user_info = user_in.model_dump()
|
||||
if user_info.get("password"):
|
||||
# 正则表达式匹配密码包含字母、数字、特殊字符中的至少两项
|
||||
pattern = r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{6,50}$'
|
||||
@@ -111,45 +111,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db),
|
||||
return schemas.Response(success=True, message=file.filename)
|
||||
|
||||
|
||||
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
async def otp_judge(
|
||||
data: dict,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
|
||||
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
user: User = await User.async_get_by_name(db, userid)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
return schemas.Response(success=user.is_otp)
|
||||
|
||||
|
||||
@router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response)
|
||||
def get_config(key: str,
|
||||
current_user: User = Depends(get_current_active_user)):
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.chain.workflow import WorkflowChain
|
||||
from app.core.config import global_vars
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_token
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import Workflow
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -47,7 +47,7 @@ async def create_workflow(workflow: schemas.Workflow,
|
||||
workflow.state = "P"
|
||||
if not workflow.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
workflow_obj = Workflow(**workflow.dict())
|
||||
workflow_obj = Workflow(**workflow.model_dump())
|
||||
await workflow_obj.async_create(db)
|
||||
return schemas.Response(success=True, message="创建工作流成功")
|
||||
|
||||
@@ -277,7 +277,7 @@ def update_workflow(workflow: schemas.Workflow,
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not wf.trigger_type:
|
||||
workflow.trigger_type = "timer"
|
||||
wf.update(db, workflow.dict())
|
||||
wf.update(db, workflow.model_dump())
|
||||
# 更新后的工作流对象
|
||||
updated_workflow = workflow_oper.get(workflow.id)
|
||||
# 更新定时任务
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -128,9 +128,12 @@ async def get_cookie(
|
||||
@cookie_router.post("/get/{uuid}")
|
||||
async def post_cookie(
|
||||
uuid: Annotated[str, Path(min_length=5, pattern="^[a-zA-Z0-9]+$")],
|
||||
request: schemas.CookiePassword):
|
||||
request: Optional[schemas.CookiePassword] = Body(None)):
|
||||
"""
|
||||
POST 下载加密数据
|
||||
"""
|
||||
data = await load_encrypt_data(uuid)
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
if request is not None:
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
else:
|
||||
return data
|
||||
|
||||
@@ -4,16 +4,15 @@ import pickle
|
||||
import traceback
|
||||
from abc import ABCMeta
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, Tuple, List, Set, Union, Dict
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
from transmission_rpc import File
|
||||
|
||||
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.event import EventManager
|
||||
@@ -48,78 +47,66 @@ class ChainBase(metaclass=ABCMeta):
|
||||
send_callback=self.run_module
|
||||
)
|
||||
self.pluginmanager = PluginManager()
|
||||
self.filecache = FileCache()
|
||||
self.async_filecache = AsyncFileCache()
|
||||
|
||||
@staticmethod
|
||||
def load_cache(filename: str) -> Any:
|
||||
def load_cache(self, filename: str) -> Any:
|
||||
"""
|
||||
从本地加载缓存
|
||||
加载缓存
|
||||
"""
|
||||
cache_path = settings.TEMP_PATH / filename
|
||||
if cache_path.exists():
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
except Exception as err:
|
||||
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
content = self.filecache.get(filename)
|
||||
if not content:
|
||||
return None
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as err:
|
||||
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def async_load_cache(filename: str) -> Any:
|
||||
async def async_load_cache(self, filename: str) -> Any:
|
||||
"""
|
||||
异步从本地加载缓存
|
||||
异步加载缓存
|
||||
"""
|
||||
cache_path = settings.TEMP_PATH / filename
|
||||
if cache_path.exists():
|
||||
try:
|
||||
async with aiofiles.open(cache_path, 'rb') as f:
|
||||
content = await f.read()
|
||||
return pickle.loads(content)
|
||||
except Exception as err:
|
||||
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
content = await self.async_filecache.get(filename)
|
||||
if not content:
|
||||
return None
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as err:
|
||||
logger.error(f"异步加载缓存 {filename} 出错:{str(err)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def async_save_cache(cache: Any, filename: str) -> None:
|
||||
async def async_save_cache(self, cache: Any, filename: str) -> None:
|
||||
"""
|
||||
异步保存缓存到本地
|
||||
异步保存缓存
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(settings.TEMP_PATH / filename, 'wb') as f:
|
||||
await f.write(pickle.dumps(cache))
|
||||
await self.async_filecache.set(filename, pickle.dumps(cache))
|
||||
except Exception as err:
|
||||
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
|
||||
logger.error(f"异步保存缓存 {filename} 出错:{str(err)}")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def save_cache(cache: Any, filename: str) -> None:
|
||||
def save_cache(self, cache: Any, filename: str) -> None:
|
||||
"""
|
||||
保存缓存到本地
|
||||
保存缓存
|
||||
"""
|
||||
try:
|
||||
with open(settings.TEMP_PATH / filename, 'wb') as f:
|
||||
pickle.dump(cache, f) # noqa
|
||||
self.filecache.set(filename, pickle.dumps(cache))
|
||||
except Exception as err:
|
||||
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def remove_cache(filename: str) -> None:
|
||||
def remove_cache(self, filename: str) -> None:
|
||||
"""
|
||||
删除本地缓存
|
||||
删除缓存,同时删除Redis和本地缓存
|
||||
"""
|
||||
cache_path = settings.TEMP_PATH / filename
|
||||
if cache_path.exists():
|
||||
cache_path.unlink()
|
||||
self.filecache.delete(filename)
|
||||
|
||||
@staticmethod
|
||||
async def async_remove_cache(filename: str) -> None:
|
||||
async def async_remove_cache(self, filename: str) -> None:
|
||||
"""
|
||||
异步删除本地缓存
|
||||
异步删除缓存,同时删除Redis和本地缓存
|
||||
"""
|
||||
cache_path = AsyncPath(settings.TEMP_PATH) / filename
|
||||
if await cache_path.exists():
|
||||
try:
|
||||
await cache_path.unlink()
|
||||
except Exception as err:
|
||||
logger.error(f"异步删除缓存 {filename} 出错:{str(err)}")
|
||||
await self.async_filecache.delete(filename)
|
||||
|
||||
@staticmethod
|
||||
def __is_valid_empty(ret):
|
||||
@@ -372,9 +359,10 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
with fresh(not cache):
|
||||
return self.run_module("recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
async def async_recognize_media(self, meta: MetaBase = None,
|
||||
mtype: Optional[MediaType] = None,
|
||||
@@ -405,9 +393,10 @@ class ChainBase(metaclass=ABCMeta):
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
async with async_fresh(not cache):
|
||||
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
|
||||
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
|
||||
episode_group=episode_group, cache=cache)
|
||||
|
||||
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
|
||||
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
|
||||
@@ -700,13 +689,13 @@ class ChainBase(metaclass=ABCMeta):
|
||||
return self.run_module("filter_torrents", rule_groups=rule_groups,
|
||||
torrent_list=torrent_list, mediainfo=mediainfo)
|
||||
|
||||
def download(self, content: Union[Path, str], download_dir: Path, cookie: str,
|
||||
def download(self, content: Union[Path, str, bytes], download_dir: Path, cookie: str,
|
||||
episodes: Set[int] = None, category: Optional[str] = None, label: Optional[str] = None,
|
||||
downloader: Optional[str] = None
|
||||
) -> Optional[Tuple[Optional[str], Optional[str], Optional[str], str]]:
|
||||
"""
|
||||
根据种子文件,选择并添加下载任务
|
||||
:param content: 种子文件地址或者磁力链接
|
||||
:param content: 种子文件地址或者磁力链接或者种子内容
|
||||
:param download_dir: 下载目录
|
||||
:param cookie: cookie
|
||||
:param episodes: 需要下载的集数
|
||||
@@ -719,15 +708,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
cookie=cookie, episodes=episodes, category=category, label=label,
|
||||
downloader=downloader)
|
||||
|
||||
def download_added(self, context: Context, download_dir: Path, torrent_path: Path = None) -> None:
|
||||
def download_added(self, context: Context, download_dir: Path, torrent_content: Union[str, bytes] = None) -> None:
|
||||
"""
|
||||
添加下载任务成功后,从站点下载字幕,保存到下载目录
|
||||
:param context: 上下文,包括识别信息、媒体信息、种子信息
|
||||
:param download_dir: 下载目录
|
||||
:param torrent_path: 种子文件地址
|
||||
:param torrent_content: 种子内容,如果有则直接使用该内容,否则从context中获取种子文件路径
|
||||
:return: None,该方法可被多个模块同时处理
|
||||
"""
|
||||
return self.run_module("download_added", context=context, torrent_path=torrent_path,
|
||||
return self.run_module("download_added", context=context,
|
||||
torrent_content=torrent_content,
|
||||
download_dir=download_dir)
|
||||
|
||||
def list_torrents(self, status: TorrentStatus = None,
|
||||
@@ -860,12 +850,18 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.dict())
|
||||
self.messageoper.add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -912,23 +908,23 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
self.messagequeue.send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
|
||||
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
self.messagequeue.send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
async def async_post_message(self,
|
||||
message: Optional[Notification] = None,
|
||||
meta: Optional[MetaBase] = None,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
torrentinfo: Optional[TorrentInfo] = None,
|
||||
transferinfo: Optional[TransferInfo] = None,
|
||||
**kwargs) -> None:
|
||||
message: Optional[Notification] = None,
|
||||
meta: Optional[MetaBase] = None,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
torrentinfo: Optional[TorrentInfo] = None,
|
||||
transferinfo: Optional[TransferInfo] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
异步发送消息
|
||||
:param message: Notification实例
|
||||
@@ -939,12 +935,18 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param kwargs: 其他参数(覆盖业务对象属性值)
|
||||
:return: 成功或失败
|
||||
"""
|
||||
# 添加格式化的时间参数
|
||||
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 渲染消息
|
||||
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
|
||||
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
|
||||
# 检查消息是否有效
|
||||
if not message:
|
||||
logger.warning("消息为空,跳过发送")
|
||||
return
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.dict())
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
# 发送消息按设置隔离
|
||||
if not message.userid and message.mtype:
|
||||
# 消息隔离设置
|
||||
@@ -991,15 +993,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
break
|
||||
# 按设定发送
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**send_message.dict(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message)
|
||||
data={**send_message.model_dump(), "type": send_message.mtype})
|
||||
await self.messagequeue.async_send_message("post_message", message=send_message, **kwargs)
|
||||
if not send_orignal:
|
||||
return
|
||||
# 发送消息事件
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
|
||||
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
|
||||
data={**message.model_dump(), "type": message.mtype})
|
||||
# 按原消息发送
|
||||
await self.messagequeue.async_send_message("post_message", message=message,
|
||||
immediately=True if message.userid else False)
|
||||
immediately=True if message.userid else False, **kwargs)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
@@ -1010,7 +1013,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [media.to_dict() for media in medias]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
@@ -1023,7 +1026,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
|
||||
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
|
||||
self.messageoper.add(**message.dict(), note=note_list)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
|
||||
immediately=True if message.userid else False)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional, Tuple, Set, Dict, Union
|
||||
|
||||
from app import schemas
|
||||
from app.chain import ChainBase
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo, TorrentInfo, Context
|
||||
from app.core.event import eventmanager, Event
|
||||
@@ -18,7 +19,7 @@ from app.db.mediaserver_oper import MediaServerOper
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
from app.schemas import ExistMediaInfo, FileURI, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
|
||||
ResourceDownloadEventData
|
||||
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, \
|
||||
ChainEventType
|
||||
@@ -35,10 +36,10 @@ class DownloadChain(ChainBase):
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Union[str, int] = None
|
||||
) -> Tuple[Optional[Union[Path, str]], str, list]:
|
||||
) -> Tuple[Optional[Union[str, bytes]], str, list]:
|
||||
"""
|
||||
下载种子文件,如果是磁力链,会返回磁力链接本身
|
||||
:return: 种子路径,种子目录名,种子文件清单
|
||||
:return: 种子内容,种子目录名,种子文件清单
|
||||
"""
|
||||
|
||||
def __get_redict_url(url: str, ua: Optional[str] = None, cookie: Optional[str] = None) -> Optional[str]:
|
||||
@@ -117,7 +118,7 @@ class DownloadChain(ChainBase):
|
||||
logger.error(f"{torrent.title} 无法获取下载地址:{torrent.enclosure}!")
|
||||
return None, "", []
|
||||
# 下载种子文件
|
||||
torrent_file, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
|
||||
_, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
|
||||
url=torrent_url,
|
||||
cookie=site_cookie,
|
||||
ua=torrent.site_ua or settings.USER_AGENT,
|
||||
@@ -127,7 +128,7 @@ class DownloadChain(ChainBase):
|
||||
# 磁力链
|
||||
return content, "", []
|
||||
|
||||
if not torrent_file:
|
||||
if not content:
|
||||
logger.error(f"下载种子文件失败:{torrent.title} - {torrent_url}")
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
@@ -139,9 +140,11 @@ class DownloadChain(ChainBase):
|
||||
return None, "", []
|
||||
|
||||
# 返回 种子文件路径,种子目录名,种子文件清单
|
||||
return torrent_file, download_folder, files
|
||||
return content, download_folder, files
|
||||
|
||||
def download_single(self, context: Context, torrent_file: Path = None,
|
||||
def download_single(self, context: Context,
|
||||
torrent_file: Path = None,
|
||||
torrent_content: Optional[Union[str, bytes]] = None,
|
||||
episodes: Set[int] = None,
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
@@ -154,11 +157,12 @@ class DownloadChain(ChainBase):
|
||||
下载及发送通知
|
||||
:param context: 资源上下文
|
||||
:param torrent_file: 种子文件路径
|
||||
:param torrent_content: 种子内容(磁力链或种子文件内容)
|
||||
:param episodes: 需要下载的集数
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、Subscribe、Manual等)
|
||||
:param downloader: 下载器
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param userid: 用户ID
|
||||
:param username: 调用下载的用户名/插件名
|
||||
:param label: 自定义标签
|
||||
@@ -207,26 +211,35 @@ class DownloadChain(ChainBase):
|
||||
# 实际下载的集数
|
||||
download_episodes = StringUtils.format_ep(list(episodes)) if episodes else None
|
||||
_folder_name = ""
|
||||
if not torrent_file:
|
||||
if not torrent_file and not torrent_content:
|
||||
# 下载种子文件,得到的可能是文件也可能是磁力链
|
||||
content, _folder_name, _file_list = self.download_torrent(_torrent,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid)
|
||||
if not content:
|
||||
return None
|
||||
else:
|
||||
content = torrent_file
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_torrent_info(torrent_file)
|
||||
torrent_content, _folder_name, _file_list = self.download_torrent(_torrent,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid)
|
||||
elif torrent_file:
|
||||
if torrent_file.exists():
|
||||
torrent_content = torrent_file.read_bytes()
|
||||
else:
|
||||
# 缓存处理器
|
||||
cache_backend = FileCache()
|
||||
# 读取缓存的种子文件
|
||||
torrent_content = cache_backend.get(torrent_file.as_posix(), region="torrents")
|
||||
|
||||
if not torrent_content:
|
||||
return None
|
||||
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
|
||||
storage = 'local'
|
||||
# 下载目录
|
||||
if save_path:
|
||||
# 下载目录使用自定义的
|
||||
download_dir = Path(save_path)
|
||||
else:
|
||||
# 根据媒体信息查询下载目录配置
|
||||
dir_info = DirectoryHelper().get_dir(_media, storage="local", include_unsorted=True)
|
||||
dir_info = DirectoryHelper().get_dir(_media, include_unsorted=True)
|
||||
storage = dir_info.storage if dir_info else storage
|
||||
# 拼装子目录
|
||||
if dir_info:
|
||||
# 一级目录
|
||||
@@ -247,9 +260,11 @@ class DownloadChain(ChainBase):
|
||||
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
|
||||
title="下载失败", role="system")
|
||||
return None
|
||||
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
|
||||
download_dir = Path(fileURI.uri)
|
||||
|
||||
# 添加下载
|
||||
result: Optional[tuple] = self.download(content=content,
|
||||
result: Optional[tuple] = self.download(content=torrent_content,
|
||||
cookie=_torrent.site_cookie,
|
||||
episodes=episodes,
|
||||
download_dir=download_dir,
|
||||
@@ -278,7 +293,7 @@ class DownloadChain(ChainBase):
|
||||
# 登记下载记录
|
||||
downloadhis = DownloadHistoryOper()
|
||||
downloadhis.add(
|
||||
path=str(download_path),
|
||||
path=download_path.as_posix(),
|
||||
type=_media.type.value,
|
||||
title=_media.title,
|
||||
year=_media.year,
|
||||
@@ -319,8 +334,8 @@ class DownloadChain(ChainBase):
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
"downloader": _downloader,
|
||||
"fullpath": str(_save_path / file),
|
||||
"savepath": str(_save_path),
|
||||
"fullpath": (_save_path / file).as_posix(),
|
||||
"savepath": _save_path.as_posix(),
|
||||
"filepath": file,
|
||||
"torrentname": _meta.org_string,
|
||||
})
|
||||
@@ -346,7 +361,7 @@ class DownloadChain(ChainBase):
|
||||
username=username,
|
||||
)
|
||||
# 下载成功后处理
|
||||
self.download_added(context=context, download_dir=download_dir, torrent_path=torrent_file)
|
||||
self.download_added(context=context, download_dir=download_dir, torrent_content=torrent_content)
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(EventType.DownloadAdded, {
|
||||
"hash": _hash,
|
||||
@@ -388,7 +403,7 @@ class DownloadChain(ChainBase):
|
||||
根据缺失数据,自动种子列表中组合择优下载
|
||||
:param contexts: 资源上下文列表
|
||||
:param no_exists: 缺失的剧集信息
|
||||
:param save_path: 保存路径
|
||||
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
|
||||
:param channel: 通知渠道
|
||||
:param source: 来源(消息通知、订阅、手工下载等)
|
||||
:param userid: 用户ID
|
||||
@@ -560,7 +575,7 @@ class DownloadChain(ChainBase):
|
||||
logger.info(f"开始下载 {torrent.title} ...")
|
||||
download_id = self.download_single(
|
||||
context=context,
|
||||
torrent_file=content if isinstance(content, Path) else None,
|
||||
torrent_content=content,
|
||||
save_path=save_path,
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -727,7 +742,7 @@ class DownloadChain(ChainBase):
|
||||
logger.info(f"开始下载 {torrent.title} ...")
|
||||
download_id = self.download_single(
|
||||
context=context,
|
||||
torrent_file=content if isinstance(content, Path) else None,
|
||||
torrent_content=content,
|
||||
episodes=selected_episodes,
|
||||
save_path=save_path,
|
||||
channel=channel,
|
||||
@@ -982,7 +997,7 @@ class DownloadChain(ChainBase):
|
||||
# 发出下载任务删除事件,如需处理辅种,可监听该事件
|
||||
self.eventmanager.send_event(EventType.DownloadDeleted, {
|
||||
"hash": hash_str,
|
||||
"torrents": [torrent.dict() for torrent in torrents]
|
||||
"torrents": [torrent.model_dump() for torrent in torrents]
|
||||
})
|
||||
else:
|
||||
logger.info(f"没有在下载器中查询到 {hash_str} 对应的下载任务")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from threading import Lock
|
||||
from typing import Optional, List, Tuple, Union
|
||||
|
||||
@@ -20,6 +22,9 @@ from app.utils.string import StringUtils
|
||||
recognize_lock = Lock()
|
||||
scraping_lock = Lock()
|
||||
|
||||
current_umask = os.umask(0)
|
||||
os.umask(current_umask)
|
||||
|
||||
|
||||
class MediaChain(ChainBase):
|
||||
"""
|
||||
@@ -310,6 +315,21 @@ class MediaChain(ChainBase):
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_bluray_folder(fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in StorageChain().list_files(fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
@eventmanager.register(EventType.MetadataScrape)
|
||||
def scrape_metadata_event(self, event: Event):
|
||||
"""
|
||||
@@ -349,51 +369,60 @@ class MediaChain(ChainBase):
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
if file_list:
|
||||
# 1. 收集fileitem和file_list中每个文件之间所有子目录
|
||||
all_dirs = set()
|
||||
root_path = Path(fileitem.path)
|
||||
# 如果是BDMV原盘目录,只对根目录进行刮削,不处理子目录
|
||||
if self.is_bluray_folder(fileitem):
|
||||
logger.info(f"检测到BDMV原盘目录,只对根目录进行刮削:{fileitem.path}")
|
||||
self.scrape_metadata(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=True,
|
||||
recursive=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
# 1. 收集fileitem和file_list中每个文件之间所有子目录
|
||||
all_dirs = set()
|
||||
root_path = Path(fileitem.path)
|
||||
|
||||
logger.debug(f"开始收集目录,根目录:{root_path}")
|
||||
# 收集根目录
|
||||
all_dirs.add(root_path)
|
||||
logger.debug(f"开始收集目录,根目录:{root_path}")
|
||||
# 收集根目录
|
||||
all_dirs.add(root_path)
|
||||
|
||||
# 收集所有目录(包括所有层级)
|
||||
for sub_file in file_list:
|
||||
sub_path = Path(sub_file)
|
||||
# 收集从根目录到文件的所有父目录
|
||||
current_path = sub_path.parent
|
||||
while current_path != root_path and current_path.is_relative_to(root_path):
|
||||
all_dirs.add(current_path)
|
||||
current_path = current_path.parent
|
||||
# 收集所有目录(包括所有层级)
|
||||
for sub_file in file_list:
|
||||
sub_path = Path(sub_file)
|
||||
# 收集从根目录到文件的所有父目录
|
||||
current_path = sub_path.parent
|
||||
while current_path != root_path and current_path.is_relative_to(root_path):
|
||||
all_dirs.add(current_path)
|
||||
current_path = current_path.parent
|
||||
|
||||
logger.debug(f"共收集到 {len(all_dirs)} 个目录")
|
||||
logger.debug(f"共收集到 {len(all_dirs)} 个目录")
|
||||
|
||||
# 2. 初始化一遍子目录,但不处理文件
|
||||
for sub_dir in all_dirs:
|
||||
sub_dir_item = storagechain.get_file_item(storage=fileitem.storage, path=sub_dir)
|
||||
if sub_dir_item:
|
||||
logger.info(f"为目录生成海报和nfo:{sub_dir}")
|
||||
# 初始化目录元数据,但不处理文件
|
||||
self.scrape_metadata(fileitem=sub_dir_item,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=True,
|
||||
recursive=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
logger.warn(f"无法获取目录项:{sub_dir}")
|
||||
# 2. 初始化一遍子目录,但不处理文件
|
||||
for sub_dir in all_dirs:
|
||||
sub_dir_item = storagechain.get_file_item(storage=fileitem.storage, path=sub_dir)
|
||||
if sub_dir_item:
|
||||
logger.info(f"为目录生成海报和nfo:{sub_dir}")
|
||||
# 初始化目录元数据,但不处理文件
|
||||
self.scrape_metadata(fileitem=sub_dir_item,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=True,
|
||||
recursive=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
logger.warn(f"无法获取目录项:{sub_dir}")
|
||||
|
||||
# 3. 刮削每个文件
|
||||
logger.info(f"开始刮削 {len(file_list)} 个文件")
|
||||
for sub_file_path in file_list:
|
||||
sub_file_item = storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=Path(sub_file_path))
|
||||
if sub_file_item:
|
||||
self.scrape_metadata(fileitem=sub_file_item,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
logger.warn(f"无法获取文件项:{sub_file_path}")
|
||||
# 3. 刮削每个文件
|
||||
logger.info(f"开始刮削 {len(file_list)} 个文件")
|
||||
for sub_file_path in file_list:
|
||||
sub_file_item = storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=Path(sub_file_path))
|
||||
if sub_file_item:
|
||||
self.scrape_metadata(fileitem=sub_file_item,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
overwrite=overwrite)
|
||||
else:
|
||||
logger.warn(f"无法获取文件项:{sub_file_path}")
|
||||
else:
|
||||
# 执行全量刮削
|
||||
logger.info(f"开始刮削目录 {fileitem.path} ...")
|
||||
@@ -417,20 +446,6 @@ class MediaChain(ChainBase):
|
||||
|
||||
storagechain = StorageChain()
|
||||
|
||||
def is_bluray_folder(_fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not _fileitem or _fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in storagechain.list_files(_fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __list_files(_fileitem: schemas.FileItem):
|
||||
"""
|
||||
列出下级文件
|
||||
@@ -446,36 +461,65 @@ class MediaChain(ChainBase):
|
||||
"""
|
||||
if not _fileitem or not _content or not _path:
|
||||
return
|
||||
# 保存文件到临时目录
|
||||
tmp_dir = settings.TEMP_PATH / StringUtils.generate_random_str(10)
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
tmp_file = tmp_dir / _path.name
|
||||
tmp_file.write_bytes(_content)
|
||||
# 获取文件的父目录
|
||||
try:
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file, new_name=_path.name)
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 写入内容
|
||||
if isinstance(_content, bytes):
|
||||
tmp_file.write(_content)
|
||||
else:
|
||||
tmp_file.write(_content.encode('utf-8'))
|
||||
tmp_file.flush()
|
||||
tmp_file.close() # 关闭文件句柄
|
||||
|
||||
# 刮削文件只需要读写权限
|
||||
tmp_file_path.chmod(0o666 & ~current_umask)
|
||||
|
||||
# 上传文件
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path, new_name=_path.name)
|
||||
if item:
|
||||
logger.info(f"已保存文件:{item.path}")
|
||||
else:
|
||||
logger.warn(f"文件保存失败:{_path}")
|
||||
finally:
|
||||
if tmp_file.exists():
|
||||
tmp_file.unlink()
|
||||
|
||||
def __download_image(_url: str) -> Optional[bytes]:
|
||||
def __download_and_save_image(_fileitem: schemas.FileItem, _path: Path, _url: str):
|
||||
"""
|
||||
下载图片并保存
|
||||
流式下载图片并直接保存到文件(减少内存占用)
|
||||
:param _fileitem: 关联的媒体文件项
|
||||
:param _path: 图片文件路径
|
||||
:param _url: 图片下载URL
|
||||
"""
|
||||
if not _fileitem or not _url or not _path:
|
||||
return
|
||||
try:
|
||||
logger.info(f"正在下载图片:{_url} ...")
|
||||
r = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(url=_url)
|
||||
if r:
|
||||
return r.content
|
||||
else:
|
||||
logger.info(f"{_url} 图片下载失败,请检查网络连通性!")
|
||||
request_utils = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT)
|
||||
with request_utils.get_stream(url=_url) as r:
|
||||
if r and r.status_code == 200:
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(delete=True, delete_on_close=False, suffix=_path.suffix) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 流式写入文件
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
tmp_file.write(chunk)
|
||||
tmp_file.flush()
|
||||
tmp_file.close() # 关闭文件句柄
|
||||
|
||||
# 刮削的图片只需要读写权限
|
||||
tmp_file_path.chmod(0o666 & ~current_umask)
|
||||
|
||||
# 上传文件
|
||||
item = storagechain.upload_file(fileitem=_fileitem, path=tmp_file_path,
|
||||
new_name=_path.name)
|
||||
if item:
|
||||
logger.info(f"已保存图片:{item.path}")
|
||||
else:
|
||||
logger.warn(f"图片保存失败:{_path}")
|
||||
else:
|
||||
logger.info(f"{_url} 图片下载失败")
|
||||
except Exception as err:
|
||||
logger.error(f"{_url} 图片下载失败:{str(err)}!")
|
||||
return None
|
||||
|
||||
if not fileitem:
|
||||
return
|
||||
@@ -521,7 +565,7 @@ class MediaChain(ChainBase):
|
||||
# 电影目录
|
||||
if recursive:
|
||||
# 处理文件
|
||||
if is_bluray_folder(fileitem):
|
||||
if self.is_bluray_folder(fileitem):
|
||||
# 原盘目录
|
||||
if scraping_switchs.get('movie_nfo', True):
|
||||
nfo_path = filepath / (filepath.name + ".nfo")
|
||||
@@ -541,6 +585,9 @@ class MediaChain(ChainBase):
|
||||
# 处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
@@ -571,14 +618,11 @@ class MediaChain(ChainBase):
|
||||
should_scrape = True # 未知类型默认刮削
|
||||
|
||||
if should_scrape:
|
||||
image_path = filepath.with_name(image_name)
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 写入图片到当前目录
|
||||
if content:
|
||||
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -624,13 +668,10 @@ class MediaChain(ChainBase):
|
||||
for episode, image_url in image_dict.items():
|
||||
image_path = filepath.with_suffix(Path(image_url).suffix)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -640,6 +681,9 @@ class MediaChain(ChainBase):
|
||||
if recursive:
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir" and not file.name.lower().startswith("season"):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
parent=fileitem if file.type == "file" else None,
|
||||
@@ -678,13 +722,10 @@ class MediaChain(ChainBase):
|
||||
image_path = filepath.with_name(image_name)
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到剧集目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -714,13 +755,11 @@ class MediaChain(ChainBase):
|
||||
continue
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__save_file(_fileitem=parent, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
if not parent:
|
||||
parent = storagechain.get_parent_item(fileitem)
|
||||
__download_and_save_image(_fileitem=parent, _path=image_path,
|
||||
_url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
@@ -770,11 +809,8 @@ class MediaChain(ChainBase):
|
||||
image_path = filepath / image_name
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
|
||||
path=image_path):
|
||||
# 下载图片
|
||||
content = __download_image(image_url)
|
||||
# 保存图片文件到当前目录
|
||||
if content:
|
||||
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
|
||||
# 流式下载图片并直接保存
|
||||
__download_and_save_image(_fileitem=fileitem, _path=image_path, _url=image_url)
|
||||
else:
|
||||
logger.info(f"已存在图片文件:{image_path}")
|
||||
else:
|
||||
|
||||
@@ -113,6 +113,16 @@ class MediaServerChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("mediaserver_play_url", server=server, item_id=item_id)
|
||||
|
||||
def get_image_cookies(
|
||||
self, server: Optional[str], image_url: str
|
||||
) -> Optional[str | dict]:
|
||||
"""
|
||||
获取图片的Cookies
|
||||
"""
|
||||
return self.run_module(
|
||||
"mediaserver_image_cookies", server=server, image_url=image_url
|
||||
)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
同步媒体库所有数据到本地数据库
|
||||
@@ -167,7 +177,7 @@ class MediaServerChain(ChainBase):
|
||||
for episode in espisodes_info:
|
||||
seasoninfo[episode.season] = episode.episodes
|
||||
# 插入数据
|
||||
item_dict = item.dict()
|
||||
item_dict = item.model_dump()
|
||||
item_dict["seasoninfo"] = seasoninfo
|
||||
item_dict["item_type"] = item_type
|
||||
dboper.add(**item_dict)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import settings
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.meta import MetaBase
|
||||
from app.db.user_oper import UserOper
|
||||
@@ -33,6 +37,10 @@ class MessageChain(ChainBase):
|
||||
_cache_file = "__user_messages__"
|
||||
# 每页数据量
|
||||
_page_size: int = 8
|
||||
# 用户会话信息 {userid: (session_id, last_time)}
|
||||
_user_sessions: Dict[Union[str, int], tuple] = {}
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 15
|
||||
|
||||
@staticmethod
|
||||
def __get_noexits_info(
|
||||
@@ -156,15 +164,15 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
# 处理消息
|
||||
if text.startswith('CALLBACK:'):
|
||||
# 处理按钮回调(适配支持回调的渠道)
|
||||
# 处理按钮回调(适配支持回调的渠),优先级最高
|
||||
if ChannelCapabilityManager.supports_callbacks(channel):
|
||||
self._handle_callback(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username,
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
elif text.startswith('/') and not text.lower().startswith('/ai'):
|
||||
# 执行特定命令命令(但不是/ai)
|
||||
self.eventmanager.send_event(
|
||||
EventType.CommandExcute,
|
||||
{
|
||||
@@ -174,265 +182,231 @@ class MessageChain(ChainBase):
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
elif text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
# 选择项目
|
||||
if not cache_data \
|
||||
or not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
# 选择的序号
|
||||
_choice = int(text) + _current_page * self._page_size - 1
|
||||
# 缓存类型
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 缓存列表
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
# 选择
|
||||
elif text.lower().startswith('/ai'):
|
||||
# 用户指定AI智能体消息响应
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
|
||||
# 普通消息,全局智能体响应
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
else:
|
||||
# 非智能体普通消息响应
|
||||
if text.isdigit():
|
||||
# 用户选择了具体的条目
|
||||
# 缓存
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
# 选择项目
|
||||
if not cache_data.get('items') \
|
||||
or len(cache_data.get('items')) < int(text):
|
||||
# 发送消息
|
||||
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
if cache_type in ["Search", "ReSearch"]:
|
||||
# 当前媒体信息
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
_current_media = mediainfo
|
||||
# 查询缺失的媒体信息
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=_current_media)
|
||||
if exist_flag and cache_type == "Search":
|
||||
# 媒体库中已存在
|
||||
# 选择的序号
|
||||
_choice = int(text) + _current_page * self._page_size - 1
|
||||
# 缓存类型
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 缓存列表
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
# 选择
|
||||
try:
|
||||
if cache_type in ["Search", "ReSearch"]:
|
||||
# 当前媒体信息
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
_current_media = mediainfo
|
||||
# 查询缺失的媒体信息
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=_current_media)
|
||||
if exist_flag and cache_type == "Search":
|
||||
# 媒体库中已存在
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"【{_current_media.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
|
||||
userid=userid))
|
||||
return
|
||||
elif exist_flag:
|
||||
# 没有缺失,但要全量重新搜索和下载
|
||||
no_exists = self.__get_noexits_info(_current_meta, _current_media)
|
||||
# 发送缺失的媒体信息
|
||||
messages = []
|
||||
if no_exists and cache_type == "Search":
|
||||
# 发送缺失消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
elif no_exists:
|
||||
# 发送总集数的消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季总 {no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
if messages:
|
||||
self.post_message(Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
|
||||
userid=userid))
|
||||
# 搜索种子,过滤掉不需要的剧集,以便选择
|
||||
logger.info(f"开始搜索 {mediainfo.title_year} ...")
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"【{_current_media.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
|
||||
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
|
||||
userid=userid))
|
||||
return
|
||||
elif exist_flag:
|
||||
# 没有缺失,但要全量重新搜索和下载
|
||||
no_exists = self.__get_noexits_info(_current_meta, _current_media)
|
||||
# 发送缺失的媒体信息
|
||||
messages = []
|
||||
if no_exists and cache_type == "Search":
|
||||
# 发送缺失消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
elif no_exists:
|
||||
# 发送总集数的消息
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
messages = [
|
||||
f"第 {sea} 季总 {no_exist.total_episode} 集"
|
||||
for sea, no_exist in no_exists.get(mediakey).items()]
|
||||
if messages:
|
||||
self.post_message(Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title_year}:\n" + "\n".join(messages),
|
||||
userid=userid))
|
||||
# 搜索种子,过滤掉不需要的剧集,以便选择
|
||||
logger.info(f"开始搜索 {mediainfo.title_year} ...")
|
||||
self.post_message(
|
||||
Notification(channel=channel,
|
||||
source=source,
|
||||
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
|
||||
userid=userid))
|
||||
# 开始搜索
|
||||
contexts = SearchChain().process(mediainfo=mediainfo,
|
||||
no_exists=no_exists)
|
||||
if not contexts:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"{mediainfo.title}"
|
||||
f"{_current_meta.sea} 未搜索到需要的资源!",
|
||||
userid=userid))
|
||||
return
|
||||
# 搜索结果排序
|
||||
contexts = TorrentHelper().sort_torrents(contexts)
|
||||
try:
|
||||
# 判断是否设置自动下载
|
||||
auto_download_user = settings.AUTO_DOWNLOAD_USER
|
||||
# 匹配到自动下载用户
|
||||
if auto_download_user \
|
||||
and (auto_download_user == "all"
|
||||
or any(userid == user for user in auto_download_user.split(","))):
|
||||
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
|
||||
# 自动选择下载
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=contexts,
|
||||
userid=userid,
|
||||
username=username,
|
||||
no_exists=no_exists)
|
||||
else:
|
||||
# 更新缓存
|
||||
user_cache[userid] = {
|
||||
"type": "Torrent",
|
||||
"items": contexts
|
||||
}
|
||||
_current_page = 0
|
||||
# 保存缓存
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
# 删除原消息
|
||||
if (original_message_id and original_chat_id and
|
||||
ChannelCapabilityManager.supports_deletion(channel)):
|
||||
self.delete_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id
|
||||
)
|
||||
# 发送种子数据
|
||||
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=mediainfo.title,
|
||||
items=contexts[:self._page_size],
|
||||
userid=userid,
|
||||
total=len(contexts))
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
elif cache_type in ["Subscribe", "ReSubscribe"]:
|
||||
# 订阅或洗版媒体
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
# 洗版标识
|
||||
best_version = False
|
||||
# 查询缺失的媒体信息
|
||||
if cache_type == "Subscribe":
|
||||
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=mediainfo)
|
||||
if exist_flag:
|
||||
# 开始搜索
|
||||
contexts = SearchChain().process(mediainfo=mediainfo,
|
||||
no_exists=no_exists)
|
||||
if not contexts:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"【{mediainfo.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
|
||||
title=f"{mediainfo.title}"
|
||||
f"{_current_meta.sea} 未搜索到需要的资源!",
|
||||
userid=userid))
|
||||
return
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
season=_current_meta.begin_season,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=mp_name or username,
|
||||
best_version=best_version)
|
||||
elif cache_type == "Torrent":
|
||||
if int(text) == 0:
|
||||
# 自动选择下载,强制下载模式
|
||||
self.__auto_download(channel=channel,
|
||||
# 搜索结果排序
|
||||
contexts = TorrentHelper().sort_torrents(contexts)
|
||||
try:
|
||||
# 判断是否设置自动下载
|
||||
auto_download_user = settings.AUTO_DOWNLOAD_USER
|
||||
# 匹配到自动下载用户
|
||||
if auto_download_user \
|
||||
and (auto_download_user == "all"
|
||||
or any(userid == user for user in auto_download_user.split(","))):
|
||||
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
|
||||
# 自动选择下载
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=contexts,
|
||||
userid=userid,
|
||||
username=username,
|
||||
no_exists=no_exists)
|
||||
else:
|
||||
# 更新缓存
|
||||
user_cache[userid] = {
|
||||
"type": "Torrent",
|
||||
"items": contexts
|
||||
}
|
||||
_current_page = 0
|
||||
# 保存缓存
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
# 删除原消息
|
||||
if (original_message_id and original_chat_id and
|
||||
ChannelCapabilityManager.supports_deletion(channel)):
|
||||
self.delete_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id
|
||||
)
|
||||
# 发送种子数据
|
||||
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=mediainfo.title,
|
||||
items=contexts[:self._page_size],
|
||||
userid=userid,
|
||||
total=len(contexts))
|
||||
finally:
|
||||
contexts.clear()
|
||||
del contexts
|
||||
elif cache_type in ["Subscribe", "ReSubscribe"]:
|
||||
# 订阅或洗版媒体
|
||||
mediainfo: MediaInfo = cache_list[_choice]
|
||||
# 洗版标识
|
||||
best_version = False
|
||||
# 查询缺失的媒体信息
|
||||
if cache_type == "Subscribe":
|
||||
exist_flag, _ = DownloadChain().get_no_exists_info(meta=_current_meta,
|
||||
mediainfo=mediainfo)
|
||||
if exist_flag:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"【{mediainfo.title_year}"
|
||||
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
|
||||
userid=userid))
|
||||
return
|
||||
else:
|
||||
best_version = True
|
||||
# 转换用户名
|
||||
mp_name = UserOper().get_name(
|
||||
**{f"{channel.name.lower()}_userid": userid}) if channel else None
|
||||
# 添加订阅,状态为N
|
||||
SubscribeChain().add(title=mediainfo.title,
|
||||
year=mediainfo.year,
|
||||
mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
season=_current_meta.begin_season,
|
||||
channel=channel,
|
||||
source=source,
|
||||
cache_list=cache_list,
|
||||
userid=userid,
|
||||
username=username)
|
||||
else:
|
||||
# 下载种子
|
||||
context: Context = cache_list[_choice]
|
||||
# 下载
|
||||
DownloadChain().download_single(context, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
username=mp_name or username,
|
||||
best_version=best_version)
|
||||
elif cache_type == "Torrent":
|
||||
if int(text) == 0:
|
||||
# 自动选择下载,强制下载模式
|
||||
self.__auto_download(channel=channel,
|
||||
source=source,
|
||||
cache_list=cache_list,
|
||||
userid=userid,
|
||||
username=username)
|
||||
else:
|
||||
# 下载种子
|
||||
context: Context = cache_list[_choice]
|
||||
# 下载
|
||||
DownloadChain().download_single(context, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
if _current_page == 0:
|
||||
# 第一页
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "p":
|
||||
# 上一页
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是第一页了!", userid=userid))
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
# 减一页
|
||||
_current_page -= 1
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
if _current_page == 0:
|
||||
start = 0
|
||||
end = self._page_size
|
||||
else:
|
||||
start = _current_page * self._page_size
|
||||
end = start + self._page_size
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
# 发送媒体数据
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid).copy()
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
total = len(cache_list)
|
||||
# 加一页
|
||||
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
|
||||
if not cache_list:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
|
||||
return
|
||||
else:
|
||||
# 第一页
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是第一页了!", userid=userid))
|
||||
return
|
||||
# 减一页
|
||||
_current_page -= 1
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
try:
|
||||
# 加一页
|
||||
_current_page += 1
|
||||
if _current_page == 0:
|
||||
start = 0
|
||||
end = self._page_size
|
||||
else:
|
||||
start = _current_page * self._page_size
|
||||
end = start + self._page_size
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=total,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
@@ -440,93 +414,145 @@ class MessageChain(ChainBase):
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list,
|
||||
items=cache_list[start:end],
|
||||
userid=userid,
|
||||
total=total,
|
||||
total=len(cache_list),
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
else:
|
||||
# 搜索或订阅
|
||||
if text.startswith("订阅"):
|
||||
# 订阅
|
||||
content = re.sub(r"订阅[::\s]*", "", text)
|
||||
action = "Subscribe"
|
||||
elif text.startswith("洗版"):
|
||||
# 洗版
|
||||
content = re.sub(r"洗版[::\s]*", "", text)
|
||||
action = "ReSubscribe"
|
||||
elif text.startswith("搜索") or text.startswith("下载"):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
action = "Search"
|
||||
|
||||
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
|
||||
# 搜索
|
||||
meta, medias = MediaChain().search(content)
|
||||
# 识别
|
||||
if not meta.name:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
|
||||
return
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
# 记录当前状态
|
||||
_current_meta = meta
|
||||
# 保存缓存
|
||||
user_cache[userid] = {
|
||||
'type': action,
|
||||
'items': medias
|
||||
}
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
_current_page = 0
|
||||
_current_media = None
|
||||
# 发送媒体列表
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=meta.name,
|
||||
items=medias[:self._page_size],
|
||||
userid=userid, total=len(medias))
|
||||
finally:
|
||||
medias.clear()
|
||||
del medias
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
elif text.lower() == "n":
|
||||
# 下一页
|
||||
cache_data: dict = user_cache.get(userid)
|
||||
if not cache_data:
|
||||
# 没有缓存
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="输入有误!", userid=userid))
|
||||
return
|
||||
cache_data = cache_data.copy()
|
||||
try:
|
||||
cache_type: str = cache_data.get('type')
|
||||
# 产生副本,避免修改原值
|
||||
cache_list: list = cache_data.get('items').copy()
|
||||
total = len(cache_list)
|
||||
# 加一页
|
||||
cache_list = cache_list[(_current_page + 1) * self._page_size:(_current_page + 2) * self._page_size]
|
||||
if not cache_list:
|
||||
# 没有数据
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
|
||||
return
|
||||
else:
|
||||
try:
|
||||
# 加一页
|
||||
_current_page += 1
|
||||
if cache_type == "Torrent":
|
||||
# 发送种子数据
|
||||
self.__post_torrents_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_media.title,
|
||||
items=cache_list,
|
||||
userid=userid,
|
||||
total=total,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
else:
|
||||
# 发送媒体数据
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=_current_meta.name,
|
||||
items=cache_list,
|
||||
userid=userid,
|
||||
total=total,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id)
|
||||
finally:
|
||||
cache_list.clear()
|
||||
del cache_list
|
||||
finally:
|
||||
cache_data.clear()
|
||||
del cache_data
|
||||
else:
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(
|
||||
EventType.UserMessage,
|
||||
{
|
||||
"text": content,
|
||||
"userid": userid,
|
||||
"channel": channel,
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
# 搜索或订阅
|
||||
if text.startswith("订阅"):
|
||||
# 订阅
|
||||
content = re.sub(r"订阅[::\s]*", "", text)
|
||||
action = "Subscribe"
|
||||
elif text.startswith("洗版"):
|
||||
# 洗版
|
||||
content = re.sub(r"洗版[::\s]*", "", text)
|
||||
action = "ReSubscribe"
|
||||
elif text.startswith("搜索") or text.startswith("下载"):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
action = "Search"
|
||||
|
||||
if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]:
|
||||
# 搜索
|
||||
meta, medias = MediaChain().search(content)
|
||||
# 识别
|
||||
if not meta.name:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
|
||||
return
|
||||
# 开始搜索
|
||||
if not medias:
|
||||
self.post_message(Notification(
|
||||
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!",
|
||||
userid=userid))
|
||||
return
|
||||
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
|
||||
try:
|
||||
# 记录当前状态
|
||||
_current_meta = meta
|
||||
# 保存缓存
|
||||
user_cache[userid] = {
|
||||
'type': action,
|
||||
'items': medias
|
||||
}
|
||||
self.save_cache(user_cache, self._cache_file)
|
||||
_current_page = 0
|
||||
_current_media = None
|
||||
# 发送媒体列表
|
||||
self.__post_medias_message(channel=channel,
|
||||
source=source,
|
||||
title=meta.name,
|
||||
items=medias[:self._page_size],
|
||||
userid=userid, total=len(medias))
|
||||
finally:
|
||||
medias.clear()
|
||||
del medias
|
||||
else:
|
||||
# 广播事件
|
||||
self.eventmanager.send_event(
|
||||
EventType.UserMessage,
|
||||
{
|
||||
"text": content,
|
||||
"userid": userid,
|
||||
"channel": channel,
|
||||
"source": source
|
||||
}
|
||||
)
|
||||
finally:
|
||||
user_cache.clear()
|
||||
del user_cache
|
||||
@@ -815,3 +841,133 @@ class MessageChain(ChainBase):
|
||||
buttons.append(page_buttons)
|
||||
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_session_id(userid: Union[str, int]) -> str:
|
||||
"""
|
||||
获取或创建会话ID
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, last_time = MessageChain._user_sessions[userid]
|
||||
|
||||
# 计算时间差
|
||||
time_diff = current_time - last_time
|
||||
|
||||
# 如果时间差小于等于15分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
|
||||
# 更新最后使用时间
|
||||
MessageChain._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(
|
||||
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
return session_id
|
||||
|
||||
# 创建新的会话ID
|
||||
new_session_id = f"user_{userid}_{int(time.time())}"
|
||||
MessageChain._user_sessions[userid] = (new_session_id, current_time)
|
||||
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
|
||||
return new_session_id
|
||||
|
||||
@staticmethod
|
||||
def clear_user_session(userid: Union[str, int]) -> bool:
|
||||
"""
|
||||
清除指定用户的会话信息
|
||||
返回是否成功清除
|
||||
"""
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def remote_clear_session(self, channel: MessageChannel, userid: Union[str, int], source: Optional[str] = None):
|
||||
"""
|
||||
清除用户会话(远程命令接口)
|
||||
"""
|
||||
# 获取并清除会话信息
|
||||
session_id = None
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
if session_id:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.clear_session(
|
||||
session_id=session_id,
|
||||
user_id=str(userid)
|
||||
),
|
||||
global_vars.loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"清除智能体会话记忆失败: {e}")
|
||||
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="智能体会话已清除,下次将创建新的会话",
|
||||
userid=userid
|
||||
))
|
||||
else:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="您当前没有活跃的智能体会话",
|
||||
userid=userid
|
||||
))
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
try:
|
||||
# 检查AI智能体是否启用
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="MoviePilot智能助手未启用,请在系统设置中启用"
|
||||
))
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
if text.lower().startswith("/ai"):
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
|
||||
else:
|
||||
user_message = text.strip() # 按原消息处理
|
||||
if not user_message:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="请输入您的问题或需求"
|
||||
))
|
||||
return
|
||||
|
||||
# 生成或复用会话ID
|
||||
session_id = self._get_or_create_session_id(userid)
|
||||
|
||||
# 在事件循环中处理
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username
|
||||
),
|
||||
global_vars.loop
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")
|
||||
|
||||
@@ -1,48 +1,105 @@
|
||||
import asyncio
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from anyio import Path as AsyncPath
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cache_backend, cached
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings, global_vars
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.asyncio import AsyncUtils
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.http import AsyncRequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
# 推荐相关的专用缓存
|
||||
recommend_ttl = 24 * 3600
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
|
||||
class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
推荐处理链,单例运行
|
||||
"""
|
||||
|
||||
# 推荐数据的缓存页数
|
||||
# 推荐缓存时间
|
||||
recommend_ttl = 24 * 3600
|
||||
# 推荐缓存页数
|
||||
cache_max_pages = 5
|
||||
# 推荐缓存区域
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
def refresh_recommend(self):
|
||||
"""
|
||||
刷新推荐数据 - 同步包装器
|
||||
刷新推荐
|
||||
"""
|
||||
try:
|
||||
AsyncUtils.run_async(self.async_refresh_recommend())
|
||||
except Exception as e:
|
||||
logger.error(f"刷新推荐数据失败:{str(e)}")
|
||||
raise
|
||||
logger.debug("Starting to refresh Recommend data.")
|
||||
|
||||
# 推荐来源方法
|
||||
recommend_methods = [
|
||||
self.tmdb_movies,
|
||||
self.tmdb_tvs,
|
||||
self.tmdb_trending,
|
||||
self.bangumi_calendar,
|
||||
self.douban_movie_showing,
|
||||
self.douban_movies,
|
||||
self.douban_tvs,
|
||||
self.douban_movie_top250,
|
||||
self.douban_tv_weekly_chinese,
|
||||
self.douban_tv_weekly_global,
|
||||
self.douban_tv_animation,
|
||||
self.douban_movie_hot,
|
||||
self.douban_tv_hot,
|
||||
]
|
||||
|
||||
# 缓存并刷新所有推荐数据
|
||||
recommends = []
|
||||
# 记录哪些方法已完成
|
||||
methods_finished = set()
|
||||
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
|
||||
for page in range(1, self.cache_max_pages + 1):
|
||||
for method in recommend_methods:
|
||||
if global_vars.is_system_stopped:
|
||||
return
|
||||
if method in methods_finished:
|
||||
continue
|
||||
logger.debug(f"Fetch {method.__name__} data for page {page}.")
|
||||
data = method(page=page)
|
||||
if not data:
|
||||
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
|
||||
methods_finished.add(method)
|
||||
continue
|
||||
recommends.extend(data)
|
||||
# 如果所有方法都已经完成,提前结束循环
|
||||
if len(methods_finished) == len(recommend_methods):
|
||||
break
|
||||
|
||||
# 缓存收集到的海报
|
||||
self.__cache_posters(recommends)
|
||||
logger.debug("Recommend data refresh completed.")
|
||||
|
||||
def __cache_posters(self, datas: List[dict]):
|
||||
"""
|
||||
提取 poster_path 并缓存图片
|
||||
:param datas: 数据列表
|
||||
"""
|
||||
if not settings.GLOBAL_IMAGE_CACHE:
|
||||
return
|
||||
|
||||
for data in datas:
|
||||
if global_vars.is_system_stopped:
|
||||
return
|
||||
poster_path = data.get("poster_path")
|
||||
if poster_path:
|
||||
poster_url = poster_path.replace("original", "w500")
|
||||
logger.debug(f"Caching poster image: {poster_url}")
|
||||
self.__fetch_and_save_image(poster_url)
|
||||
|
||||
@staticmethod
|
||||
def __fetch_and_save_image(url: str):
|
||||
"""
|
||||
请求并保存图片
|
||||
:param url: 图片路径
|
||||
"""
|
||||
ImageHelper().fetch_image(url=url)
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
@@ -199,162 +256,6 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
tvs = DoubanChain().tv_hot(page=page, count=count)
|
||||
return [media.to_dict() for media in tvs] if tvs else []
|
||||
|
||||
# 异步版本的方法
|
||||
async def async_refresh_recommend(self):
|
||||
"""
|
||||
异步刷新推荐
|
||||
"""
|
||||
logger.debug("Starting to async refresh Recommend data.")
|
||||
cache_backend.clear(region=recommend_cache_region)
|
||||
logger.debug("Recommend Cache has been cleared.")
|
||||
|
||||
# 推荐来源方法
|
||||
recommend_methods = [
|
||||
self.async_tmdb_movies,
|
||||
self.async_tmdb_tvs,
|
||||
self.async_tmdb_trending,
|
||||
self.async_bangumi_calendar,
|
||||
self.async_douban_movie_showing,
|
||||
self.async_douban_movies,
|
||||
self.async_douban_tvs,
|
||||
self.async_douban_movie_top250,
|
||||
self.async_douban_tv_weekly_chinese,
|
||||
self.async_douban_tv_weekly_global,
|
||||
self.async_douban_tv_animation,
|
||||
self.async_douban_movie_hot,
|
||||
self.async_douban_tv_hot,
|
||||
]
|
||||
|
||||
# 缓存并刷新所有推荐数据
|
||||
recommends = []
|
||||
# 记录哪些方法已完成
|
||||
methods_finished = set()
|
||||
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
|
||||
for page in range(1, self.cache_max_pages + 1):
|
||||
# 为每个页面并发执行所有方法
|
||||
tasks = []
|
||||
for method in recommend_methods:
|
||||
if global_vars.is_system_stopped:
|
||||
return
|
||||
if method in methods_finished:
|
||||
continue
|
||||
tasks.append(self._async_fetch_method_data(method, page, methods_finished))
|
||||
|
||||
# 并发执行所有任务
|
||||
if tasks:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, list) and result:
|
||||
recommends.extend(result)
|
||||
|
||||
# 如果所有方法都已经完成,提前结束循环
|
||||
if len(methods_finished) == len(recommend_methods):
|
||||
break
|
||||
|
||||
# 缓存收集到的海报
|
||||
await self.__async_cache_posters(recommends)
|
||||
logger.debug("Async recommend data refresh completed.")
|
||||
|
||||
@staticmethod
|
||||
async def _async_fetch_method_data(method, page: int, methods_finished: set):
|
||||
"""
|
||||
异步获取方法数据的辅助函数
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Async fetch {method.__name__} data for page {page}.")
|
||||
data = await method(page=page)
|
||||
if not data:
|
||||
logger.debug(f"Method {method.__name__} finished fetching data. Ending pagination early.")
|
||||
methods_finished.add(method)
|
||||
return []
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching data from {method.__name__}: {e}")
|
||||
methods_finished.add(method)
|
||||
return []
|
||||
|
||||
async def __async_cache_posters(self, datas: List[dict]):
|
||||
"""
|
||||
异步提取 poster_path 并缓存图片
|
||||
:param datas: 数据列表
|
||||
"""
|
||||
if not settings.GLOBAL_IMAGE_CACHE:
|
||||
return
|
||||
|
||||
tasks = []
|
||||
for data in datas:
|
||||
if global_vars.is_system_stopped:
|
||||
return
|
||||
poster_path = data.get("poster_path")
|
||||
if poster_path:
|
||||
poster_url = poster_path.replace("original", "w500")
|
||||
logger.debug(f"Async caching poster image: {poster_url}")
|
||||
tasks.append(self.__async_fetch_and_save_image(poster_url))
|
||||
|
||||
# 并发缓存图片
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@staticmethod
|
||||
async def __async_fetch_and_save_image(url: str):
|
||||
"""
|
||||
异步请求并保存图片
|
||||
:param url: 图片路径
|
||||
"""
|
||||
if not settings.GLOBAL_IMAGE_CACHE or not url:
|
||||
return
|
||||
|
||||
# 生成缓存路径
|
||||
base_path = AsyncPath(settings.CACHE_PATH)
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = base_path / "images" / sanitized_path
|
||||
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 确保缓存路径和文件类型合法
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path,
|
||||
user_path=cache_path,
|
||||
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
|
||||
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
|
||||
return
|
||||
|
||||
# 本地存在缓存图片,则直接跳过
|
||||
if await cache_path.exists():
|
||||
logger.debug(f"Cache hit: Image already exists at {cache_path}")
|
||||
return
|
||||
|
||||
# 请求远程图片
|
||||
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
|
||||
proxies = settings.PROXY if not referer else None
|
||||
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT,
|
||||
proxies=proxies, referer=referer).get_res(url=url)
|
||||
if not response:
|
||||
logger.debug(f"Empty response for URL: {url}")
|
||||
return
|
||||
|
||||
# 验证下载的内容是否为有效图片
|
||||
try:
|
||||
Image.open(io.BytesIO(response.content)).verify()
|
||||
except Exception as e:
|
||||
logger.debug(f"Invalid image format for URL {url}: {e}")
|
||||
return
|
||||
|
||||
if not cache_path:
|
||||
return
|
||||
|
||||
try:
|
||||
if not await cache_path.parent.exists():
|
||||
await cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
|
||||
await tmp_file.write(response.content)
|
||||
temp_path = AsyncPath(tmp_file.name)
|
||||
await temp_path.replace(cache_path)
|
||||
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
@cached(ttl=recommend_ttl, region=recommend_cache_region)
|
||||
async def async_tmdb_movies(self, sort_by: Optional[str] = "popularity.desc",
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import asyncio
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
from typing import Dict, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.chain import ChainBase
|
||||
@@ -18,7 +17,6 @@ from app.core.event import eventmanager, Event
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import NotExistMediaInfo
|
||||
@@ -59,7 +57,7 @@ class SearchChain(ChainBase):
|
||||
results = self.process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
self.save_cache(pickle.dumps(results), self.__result_temp_file)
|
||||
self.save_cache(results, self.__result_temp_file)
|
||||
return results
|
||||
|
||||
def search_by_title(self, title: str, page: Optional[int] = 0,
|
||||
@@ -85,36 +83,20 @@ class SearchChain(ChainBase):
|
||||
torrent_info=torrent) for torrent in torrents]
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
self.save_cache(pickle.dumps(contexts), self.__result_temp_file)
|
||||
self.save_cache(contexts, self.__result_temp_file)
|
||||
return contexts
|
||||
|
||||
def last_search_results(self) -> List[Context]:
|
||||
def last_search_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
获取上次搜索结果
|
||||
"""
|
||||
# 读取本地文件缓存
|
||||
content = self.load_cache(self.__result_temp_file)
|
||||
if not content:
|
||||
return []
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as e:
|
||||
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
|
||||
return []
|
||||
return self.load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_search_results(self) -> List[Context]:
|
||||
async def async_last_search_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次搜索结果
|
||||
"""
|
||||
# 读取本地文件缓存
|
||||
content = await self.async_load_cache(self.__result_temp_file)
|
||||
if not content:
|
||||
return []
|
||||
try:
|
||||
return pickle.loads(content)
|
||||
except Exception as e:
|
||||
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
|
||||
return []
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -143,7 +125,7 @@ class SearchChain(ChainBase):
|
||||
results = await self.async_process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
await self.async_save_cache(pickle.dumps(results), self.__result_temp_file)
|
||||
await self.async_save_cache(results, self.__result_temp_file)
|
||||
return results
|
||||
|
||||
async def async_search_by_title(self, title: str, page: Optional[int] = 0,
|
||||
@@ -169,7 +151,7 @@ class SearchChain(ChainBase):
|
||||
torrent_info=torrent) for torrent in torrents]
|
||||
# 保存到本地文件
|
||||
if cache_local:
|
||||
await self.async_save_cache(pickle.dumps(contexts), self.__result_temp_file)
|
||||
await self.async_save_cache(contexts, self.__result_temp_file)
|
||||
return contexts
|
||||
|
||||
@staticmethod
|
||||
@@ -233,12 +215,11 @@ class SearchChain(ChainBase):
|
||||
return []
|
||||
|
||||
# 开始新进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.Search)
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
|
||||
# 开始过滤
|
||||
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...',
|
||||
key=ProgressKey.Search)
|
||||
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...')
|
||||
# 匹配订阅附加参数
|
||||
if filter_params:
|
||||
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
|
||||
@@ -256,7 +237,7 @@ class SearchChain(ChainBase):
|
||||
logger.info(f"过滤规则/剧集过滤完成,剩余 {len(torrents)} 个资源")
|
||||
|
||||
# 过滤完成
|
||||
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
|
||||
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源')
|
||||
|
||||
# 总数
|
||||
_total = len(torrents)
|
||||
@@ -269,14 +250,13 @@ class SearchChain(ChainBase):
|
||||
try:
|
||||
# 英文标题应该在别名/原标题中,不需要再匹配
|
||||
logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}")
|
||||
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
|
||||
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...')
|
||||
for torrent in torrents:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
_count += 1
|
||||
progress.update(value=(_count / _total) * 96,
|
||||
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...',
|
||||
key=ProgressKey.Search)
|
||||
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...')
|
||||
if not torrent.title:
|
||||
continue
|
||||
|
||||
@@ -309,8 +289,7 @@ class SearchChain(ChainBase):
|
||||
# 匹配完成
|
||||
logger.info(f"匹配完成,共匹配到 {len(_match_torrents)} 个资源")
|
||||
progress.update(value=97,
|
||||
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
|
||||
key=ProgressKey.Search)
|
||||
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源')
|
||||
|
||||
# 去掉mediainfo中多余的数据
|
||||
mediainfo.clear()
|
||||
@@ -326,16 +305,14 @@ class SearchChain(ChainBase):
|
||||
|
||||
# 排序
|
||||
progress.update(value=99,
|
||||
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...',
|
||||
key=ProgressKey.Search)
|
||||
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...')
|
||||
contexts = torrenthelper.sort_torrents(contexts)
|
||||
|
||||
# 结束进度
|
||||
logger.info(f'搜索完成,共 {len(contexts)} 个资源')
|
||||
progress.update(value=100,
|
||||
text=f'搜索完成,共 {len(contexts)} 个资源',
|
||||
key=ProgressKey.Search)
|
||||
progress.end(ProgressKey.Search)
|
||||
text=f'搜索完成,共 {len(contexts)} 个资源')
|
||||
progress.end()
|
||||
|
||||
# 去重后返回
|
||||
return self.__remove_duplicate(contexts)
|
||||
@@ -347,9 +324,6 @@ class SearchChain(ChainBase):
|
||||
:param _torrents: 种子列表
|
||||
:return: 去重后的种子列表
|
||||
"""
|
||||
if not settings.SEARCH_MULTIPLE_NAME:
|
||||
return _torrents
|
||||
# 通过encosure去重
|
||||
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
|
||||
for t in _torrents}.values())
|
||||
|
||||
@@ -407,16 +381,23 @@ class SearchChain(ChainBase):
|
||||
if search_count > 0:
|
||||
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
|
||||
time.sleep(random.randint(1, 10))
|
||||
|
||||
# 搜索站点
|
||||
torrents.extend(
|
||||
self.__search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area
|
||||
) or []
|
||||
)
|
||||
results = self.__search_all_sites(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area
|
||||
) or []
|
||||
# 合并结果
|
||||
|
||||
search_count += 1
|
||||
torrents.extend(results)
|
||||
|
||||
# 有结果则停止
|
||||
if not settings.SEARCH_MULTIPLE_NAME and torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
# 处理结果
|
||||
return self.__parse_result(
|
||||
@@ -539,8 +520,8 @@ class SearchChain(ChainBase):
|
||||
return []
|
||||
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.Search)
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
# 开始计时
|
||||
start_time = datetime.now()
|
||||
# 总数
|
||||
@@ -549,8 +530,7 @@ class SearchChain(ChainBase):
|
||||
finish_count = 0
|
||||
# 更新进度
|
||||
progress.update(value=0,
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...",
|
||||
key=ProgressKey.Search)
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...")
|
||||
# 结果集
|
||||
results = []
|
||||
# 多线程
|
||||
@@ -579,17 +559,15 @@ class SearchChain(ChainBase):
|
||||
results.extend(result)
|
||||
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
|
||||
progress.update(value=finish_count / total_num * 100,
|
||||
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...",
|
||||
key=ProgressKey.Search)
|
||||
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...")
|
||||
# 计算耗时
|
||||
end_time = datetime.now()
|
||||
# 更新进度
|
||||
progress.update(value=100,
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒",
|
||||
key=ProgressKey.Search)
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
# 结束进度
|
||||
progress.end(ProgressKey.Search)
|
||||
progress.end()
|
||||
|
||||
# 返回
|
||||
return results
|
||||
@@ -624,8 +602,8 @@ class SearchChain(ChainBase):
|
||||
return []
|
||||
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.Search)
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
# 开始计时
|
||||
start_time = datetime.now()
|
||||
# 总数
|
||||
@@ -634,8 +612,7 @@ class SearchChain(ChainBase):
|
||||
finish_count = 0
|
||||
# 更新进度
|
||||
progress.update(value=0,
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...",
|
||||
key=ProgressKey.Search)
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...")
|
||||
# 结果集
|
||||
results = []
|
||||
|
||||
@@ -666,18 +643,16 @@ class SearchChain(ChainBase):
|
||||
results.extend(result)
|
||||
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
|
||||
progress.update(value=finish_count / total_num * 100,
|
||||
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...",
|
||||
key=ProgressKey.Search)
|
||||
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...")
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.now()
|
||||
# 更新进度
|
||||
progress.update(value=100,
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒",
|
||||
key=ProgressKey.Search)
|
||||
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
# 结束进度
|
||||
progress.end(ProgressKey.Search)
|
||||
progress.end()
|
||||
|
||||
# 返回
|
||||
return results
|
||||
|
||||
@@ -44,6 +44,7 @@ class SiteChain(ChainBase):
|
||||
"star-space.net": self.__indexphp_test,
|
||||
"yemapt.org": self.__yema_test,
|
||||
"hddolby.com": self.__hddolby_test,
|
||||
"rousi.pro": self.__rousi_test,
|
||||
}
|
||||
|
||||
def refresh_userdata(self, site: dict = None) -> Optional[SiteUserData]:
|
||||
@@ -56,7 +57,7 @@ class SiteChain(ChainBase):
|
||||
if userdata:
|
||||
SiteOper().update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
|
||||
name=site.get("name"),
|
||||
payload=userdata.dict())
|
||||
payload=userdata.model_dump())
|
||||
# 发送事件
|
||||
eventmanager.send_event(EventType.SiteRefreshed, {
|
||||
"site_id": site.get("id")
|
||||
@@ -249,6 +250,32 @@ class SiteChain(ChainBase):
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __rousi_test(site: Site) -> Tuple[bool, str]:
|
||||
"""
|
||||
判断站点是否已经登陆:rousi
|
||||
"""
|
||||
url = f"https://{StringUtils.get_url_domain(site.url)}/api/v1/profile"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {site.apikey}",
|
||||
}
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
proxies=settings.PROXY if site.proxy else None,
|
||||
timeout=site.timeout or 15
|
||||
).get_res(url=url)
|
||||
if res is None:
|
||||
return False, "无法打开网站!"
|
||||
if res.status_code == 200:
|
||||
user_info = res.json()
|
||||
if user_info and user_info.get("code") == 0:
|
||||
return True, "连接成功"
|
||||
return False, "APIKEY已过期"
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __parse_favicon(url: str, cookie: str, ua: str) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
@@ -313,6 +340,11 @@ class SiteChain(ChainBase):
|
||||
siteoper = SiteOper()
|
||||
rsshelper = RssHelper()
|
||||
for domain, cookie in cookies.items():
|
||||
# 检查系统是否停止
|
||||
if global_vars.is_system_stopped:
|
||||
logger.info("系统正在停止,中断CookieCloud同步")
|
||||
return False, "系统正在停止,同步被中断"
|
||||
|
||||
# 索引器信息
|
||||
indexer = siteshelper.get_indexer(domain)
|
||||
# 数据库的站点信息
|
||||
@@ -331,7 +363,7 @@ class SiteChain(ChainBase):
|
||||
cookie=cookie,
|
||||
ua=site_info.ua or settings.USER_AGENT,
|
||||
proxy=True if site_info.proxy else False,
|
||||
timeout=site_info.timeout
|
||||
timeout=site_info.timeout or 15
|
||||
)
|
||||
if rss_url:
|
||||
logger.info(f"更新站点 {domain} RSS地址 ...")
|
||||
|
||||
@@ -6,7 +6,6 @@ from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class StorageChain(ChainBase):
|
||||
@@ -134,8 +133,7 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("support_transtype", storage=storage)
|
||||
|
||||
def delete_media_file(self, fileitem: schemas.FileItem,
|
||||
mtype: MediaType = None, delete_self: bool = True) -> bool:
|
||||
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
@@ -152,7 +150,8 @@ class StorageChain(ChainBase):
|
||||
return False
|
||||
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
if fileitem.path == "/" or len(Path(fileitem.path).parts) <= 2:
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 根目录或一级目录不允许删除")
|
||||
return False
|
||||
if fileitem.type == "dir":
|
||||
@@ -162,13 +161,7 @@ class StorageChain(ChainBase):
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
return False
|
||||
elif self.any_files(fileitem, extensions=media_exts) is False:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 不存在其它媒体文件,正在删除空目录")
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
return False
|
||||
# 不处理父目录
|
||||
return True
|
||||
|
||||
elif delete_self:
|
||||
# 本身是文件,需要删除文件
|
||||
logger.warn(f"正在删除文件【{fileitem.storage}】{fileitem.path}")
|
||||
@@ -176,35 +169,43 @@ class StorageChain(ChainBase):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
return False
|
||||
|
||||
if mtype:
|
||||
# 重命名格式
|
||||
rename_format = settings.RENAME_FORMAT(mtype)
|
||||
media_path = DirectoryHelper.get_media_root_path(
|
||||
rename_format, rename_path=Path(fileitem.path)
|
||||
)
|
||||
if not media_path:
|
||||
return True
|
||||
# 处理媒体文件根目录
|
||||
dir_item = self.get_file_item(storage=fileitem.storage, path=media_path)
|
||||
else:
|
||||
# 处理上级目录
|
||||
dir_item = self.get_parent_item(fileitem)
|
||||
# 检查和删除上级空目录
|
||||
dir_item = fileitem if fileitem.type == "dir" else self.get_parent_item(fileitem)
|
||||
if not dir_item:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 上级目录不存在")
|
||||
return True
|
||||
|
||||
# 检查和删除上级目录
|
||||
if dir_item and len(Path(dir_item.path).parts) > 2:
|
||||
# 如何目录是所有下载目录、媒体库目录的上级,则不处理
|
||||
for d in DirectoryHelper().get_dirs():
|
||||
if d.download_path and Path(d.download_path).is_relative_to(Path(dir_item.path)):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 是下载目录本级或上级目录,不删除")
|
||||
return True
|
||||
if d.library_path and Path(d.library_path).is_relative_to(Path(dir_item.path)):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 是媒体库目录本级或上级目录,不删除")
|
||||
return True
|
||||
# 不存在其他媒体文件,删除空目录
|
||||
if self.any_files(dir_item, extensions=media_exts) is False:
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 不存在其它媒体文件,正在删除空目录")
|
||||
if not self.delete_file(dir_item):
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 删除失败")
|
||||
return False
|
||||
# 查找操作文件项匹配的配置目录(资源目录、媒体库目录)
|
||||
associated_dir = max(
|
||||
(
|
||||
Path(p)
|
||||
for d in DirectoryHelper().get_dirs()
|
||||
for p in (d.download_path, d.library_path)
|
||||
if p and fileitem_path.is_relative_to(p)
|
||||
),
|
||||
key=lambda path: len(path.parts),
|
||||
default=None,
|
||||
)
|
||||
|
||||
while dir_item and len(Path(dir_item.path).parts) > 2:
|
||||
# 目录是资源目录、媒体库目录的上级,则不处理
|
||||
if associated_dir and associated_dir.is_relative_to(Path(dir_item.path)):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 位于资源或媒体库目录结构中,不删除")
|
||||
break
|
||||
|
||||
elif not associated_dir and self.list_files(dir_item, recursion=False):
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 不是空目录,不删除")
|
||||
break
|
||||
|
||||
if self.any_files(dir_item, extensions=media_exts) is not False:
|
||||
logger.debug(f"【{dir_item.storage}】{dir_item.path} 存在媒体文件,不删除")
|
||||
break
|
||||
|
||||
# 删除空目录并继续处理父目录
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 不存在其它媒体文件,正在删除空目录")
|
||||
if not self.delete_file(dir_item):
|
||||
logger.warn(f"【{dir_item.storage}】{dir_item.path} 删除失败")
|
||||
return False
|
||||
dir_item = self.get_parent_item(dir_item)
|
||||
|
||||
return True
|
||||
|
||||
@@ -42,7 +42,7 @@ class SubscribeChain(ChainBase):
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
|
||||
@staticmethod
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
@@ -158,7 +158,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
@@ -169,7 +169,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
@@ -949,7 +949,7 @@ class SubscribeChain(ChainBase):
|
||||
and torrent_mediainfo.douban_id != mediainfo.douban_id:
|
||||
continue
|
||||
logger.info(
|
||||
f'{mediainfo.title_year} 通过媒体信ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
f'{mediainfo.title_year} 通过媒体ID匹配到可选资源:{torrent_info.site_name} - {torrent_info.title}')
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -1184,6 +1184,42 @@ class SubscribeChain(ChainBase):
|
||||
logger.error(f'follow用户分享订阅 {title} 添加失败:{message}')
|
||||
logger.info(f'follow用户分享订阅刷新完成,共添加 {success_count} 个订阅')
|
||||
|
||||
async def cache_calendar(self):
|
||||
"""
|
||||
预缓存订阅日历,实际上就是查询一遍所有订阅的媒体信息
|
||||
前端请示是异常的,所以需要使用异步缓存方法
|
||||
"""
|
||||
logger.info(f'开始预缓存订阅日历 ...')
|
||||
for subscribe in await SubscribeOper().async_list():
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
try:
|
||||
mtype = MediaType(subscribe.type)
|
||||
except ValueError:
|
||||
logger.error(f'订阅 {subscribe.name} 类型错误:{subscribe.type}')
|
||||
continue
|
||||
# 识别媒体信息
|
||||
if mtype == MediaType.MOVIE:
|
||||
mediainfo: MediaInfo = await self.async_recognize_media(mtype=mtype,
|
||||
tmdbid=subscribe.tmdbid,
|
||||
doubanid=subscribe.doubanid,
|
||||
bangumiid=subscribe.bangumiid,
|
||||
episode_group=subscribe.episode_group,
|
||||
cache=False)
|
||||
if not mediainfo:
|
||||
logger.warn(
|
||||
f'未识别到媒体信息,标题:{subscribe.name},tmdbid:{subscribe.tmdbid},doubanid:{subscribe.doubanid}')
|
||||
continue
|
||||
else:
|
||||
episodes = await TmdbChain().async_tmdb_episodes(tmdbid=subscribe.tmdbid,
|
||||
season=subscribe.season,
|
||||
episode_group=subscribe.episode_group)
|
||||
if not episodes:
|
||||
logger.warn(
|
||||
f'未识别到季集信息,标题:{subscribe.name},tmdbid:{subscribe.tmdbid},豆瓣ID:{subscribe.doubanid},季:{subscribe.season}')
|
||||
continue
|
||||
logger.info(f'订阅日历预缓存完成')
|
||||
|
||||
@staticmethod
|
||||
def __update_subscribe_note(subscribe: Subscribe, downloads: Optional[List[Context]]):
|
||||
"""
|
||||
|
||||
@@ -150,7 +150,7 @@ class TorrentsChain(ChainBase):
|
||||
return []
|
||||
# 解析RSS
|
||||
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
|
||||
timeout=int(site.get("timeout") or 30))
|
||||
timeout=int(site.get("timeout") or 30), ua=site.get("ua") if site.get("ua") else None)
|
||||
if rss_items is None:
|
||||
# rss过期,尝试保留原配置生成新的rss
|
||||
self.__renew_rss_url(domain=domain, site=site)
|
||||
|
||||
@@ -33,6 +33,7 @@ from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey,
|
||||
SystemConfigKey, ChainEventType, ContentType
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
downloader_lock = threading.Lock()
|
||||
job_lock = threading.Lock()
|
||||
@@ -329,8 +330,12 @@ class JobManager:
|
||||
# 计算状态为完成的任务数
|
||||
if __mediaid__ not in self._job_view:
|
||||
return 0
|
||||
return sum([task.fileitem.size for task in self._job_view[__mediaid__].tasks if
|
||||
task.state == "completed" and task.fileitem.size is not None])
|
||||
return sum([
|
||||
task.fileitem.size if task.fileitem.size is not None
|
||||
else (SystemUtils.get_directory_size(Path(task.fileitem.path)) if task.fileitem.storage == "local" else 0)
|
||||
for task in self._job_view[__mediaid__].tasks
|
||||
if task.state == "completed"
|
||||
])
|
||||
|
||||
def total(self) -> int:
|
||||
"""
|
||||
@@ -501,7 +506,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
# 获取整理屏蔽词
|
||||
transfer_exclude_words = SystemConfigOper().get(SystemConfigKey.TransferExcludeWords)
|
||||
for t in tasks:
|
||||
if t.download_hash and self._can_delete_torrent(t.download_hash, t.downloader, transfer_exclude_words):
|
||||
if t.download_hash and self._can_delete_torrent(t.download_hash, t.downloader,
|
||||
transfer_exclude_words):
|
||||
if self.remove_torrents(t.download_hash, downloader=t.downloader):
|
||||
logger.info(f"移动模式删除种子成功:{t.download_hash}")
|
||||
if t.fileitem:
|
||||
@@ -554,8 +560,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
processed_num = 0
|
||||
# 失败数量
|
||||
fail_num = 0
|
||||
# 已完成文件
|
||||
finished_files = []
|
||||
|
||||
progress = ProgressHelper()
|
||||
progress = ProgressHelper(ProgressKey.FileTransfer)
|
||||
|
||||
while not global_vars.is_system_stopped:
|
||||
try:
|
||||
@@ -570,7 +578,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
if __queue_start:
|
||||
logger.info("开始整理队列处理...")
|
||||
# 启动进度
|
||||
progress.start(ProgressKey.FileTransfer)
|
||||
progress.start()
|
||||
# 重置计数
|
||||
processed_num = 0
|
||||
fail_num = 0
|
||||
@@ -578,8 +586,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
__process_msg = f"开始整理队列处理,当前共 {total_num} 个文件 ..."
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=0,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
text=__process_msg)
|
||||
# 队列已开始
|
||||
__queue_start = False
|
||||
# 更新进度
|
||||
@@ -587,7 +594,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=processed_num / total_num * 100,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
data={
|
||||
"current": Path(fileitem.path).as_posix(),
|
||||
"finished": finished_files
|
||||
})
|
||||
# 整理
|
||||
state, err_msg = self.__handle_transfer(task=task, callback=item.callback)
|
||||
if not state:
|
||||
@@ -595,20 +605,20 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
fail_num += 1
|
||||
# 更新进度
|
||||
processed_num += 1
|
||||
finished_files.append(Path(fileitem.path).as_posix())
|
||||
__process_msg = f"{fileitem.name} 整理完成"
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=processed_num / total_num * 100,
|
||||
progress.update(value=(processed_num / total_num) * 100,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
data={})
|
||||
except queue.Empty:
|
||||
if not __queue_start:
|
||||
# 结束进度
|
||||
__end_msg = f"整理队列处理完成,共整理 {processed_num} 个文件,失败 {fail_num} 个"
|
||||
logger.info(__end_msg)
|
||||
progress.update(value=100,
|
||||
text=__end_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
progress.end(ProgressKey.FileTransfer)
|
||||
text=__end_msg)
|
||||
progress.end()
|
||||
# 重置计数
|
||||
processed_num = 0
|
||||
fail_num = 0
|
||||
@@ -1106,6 +1116,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
file_meta=file_meta)
|
||||
if begin_ep is not None:
|
||||
file_meta.begin_episode = begin_ep
|
||||
if part is not None:
|
||||
file_meta.part = part
|
||||
if end_ep is not None:
|
||||
file_meta.end_episode = end_ep
|
||||
@@ -1115,10 +1126,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
downloadhis = DownloadHistoryOper()
|
||||
if bluray_dir:
|
||||
# 蓝光原盘,按目录名查询
|
||||
download_history = downloadhis.get_by_path(str(file_path))
|
||||
download_history = downloadhis.get_by_path(file_path.as_posix())
|
||||
else:
|
||||
# 按文件全路径查询
|
||||
download_file = downloadhis.get_file_by_fullpath(str(file_path))
|
||||
download_file = downloadhis.get_file_by_fullpath(file_path.as_posix())
|
||||
if download_file:
|
||||
download_history = downloadhis.get_by_hash(download_file.download_hash)
|
||||
|
||||
@@ -1164,15 +1175,16 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
processed_num = 0
|
||||
# 失败数量
|
||||
fail_num = 0
|
||||
# 已完成文件
|
||||
finished_files = []
|
||||
|
||||
# 启动进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.FileTransfer)
|
||||
progress = ProgressHelper(ProgressKey.FileTransfer)
|
||||
progress.start()
|
||||
__process_msg = f"开始整理,共 {total_num} 个文件 ..."
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=0,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
text=__process_msg)
|
||||
try:
|
||||
for transfer_task in transfer_tasks:
|
||||
if global_vars.is_system_stopped:
|
||||
@@ -1184,7 +1196,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__process_msg)
|
||||
progress.update(value=(processed_num + fail_num) / total_num * 100,
|
||||
text=__process_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
data={
|
||||
"current": Path(transfer_task.fileitem.path).as_posix(),
|
||||
"finished": finished_files,
|
||||
})
|
||||
state, err_msg = self.__handle_transfer(
|
||||
task=transfer_task,
|
||||
callback=self.__default_callback
|
||||
@@ -1196,6 +1211,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
fail_num += 1
|
||||
else:
|
||||
processed_num += 1
|
||||
# 记录已完成
|
||||
finished_files.append(Path(transfer_task.fileitem.path).as_posix())
|
||||
finally:
|
||||
transfer_tasks.clear()
|
||||
del transfer_tasks
|
||||
@@ -1205,8 +1222,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
logger.info(__end_msg)
|
||||
progress.update(value=100,
|
||||
text=__end_msg,
|
||||
key=ProgressKey.FileTransfer)
|
||||
progress.end(ProgressKey.FileTransfer)
|
||||
data={})
|
||||
progress.end()
|
||||
|
||||
error_msg = "、".join(err_msgs[:2]) + (f",等{len(err_msgs)}个文件错误!" if len(err_msgs) > 2 else "")
|
||||
return all_success, error_msg
|
||||
@@ -1351,12 +1368,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
else:
|
||||
# 更新媒体图片
|
||||
self.obtain_images(mediainfo=mediainfo)
|
||||
# 开始进度
|
||||
progress = ProgressHelper()
|
||||
progress.start(ProgressKey.FileTransfer)
|
||||
progress.update(value=0,
|
||||
text=f"开始整理 {fileitem.path} ...",
|
||||
key=ProgressKey.FileTransfer)
|
||||
|
||||
# 开始整理
|
||||
state, errmsg = self.do_transfer(
|
||||
fileitem=fileitem,
|
||||
@@ -1377,7 +1389,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
if not state:
|
||||
return False, errmsg
|
||||
|
||||
progress.end(ProgressKey.FileTransfer)
|
||||
logger.info(f"{fileitem.path} 整理完成")
|
||||
return True, ""
|
||||
else:
|
||||
@@ -1431,11 +1442,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
|
||||
for keyword in exclude_words:
|
||||
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
|
||||
logger.debug(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
logger.warn(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _can_delete_torrent(self, download_hash: str, downloader: str, transfer_exclude_words) -> bool:
|
||||
"""
|
||||
检查是否可以删除种子文件
|
||||
@@ -1467,13 +1477,9 @@ class TransferChain(ChainBase, metaclass=Singleton):
|
||||
for file in torrent_files:
|
||||
file_path = save_path / file.name
|
||||
# 如果存在未被屏蔽的媒体文件,则不删除种子
|
||||
if (
|
||||
file_path.suffix in self.all_exts
|
||||
and not self._is_blocked_by_exclude_words(
|
||||
str(file_path), transfer_exclude_words
|
||||
)
|
||||
and file_path.exists()
|
||||
):
|
||||
if (file_path.suffix in self.all_exts
|
||||
and not self._is_blocked_by_exclude_words(file_path.as_posix(), transfer_exclude_words)
|
||||
and file_path.exists()):
|
||||
return False
|
||||
|
||||
# 所有媒体文件都被屏蔽或不存在,可以删除种子
|
||||
|
||||
@@ -52,7 +52,10 @@ class UserChain(ChainBase):
|
||||
success, user_or_message = self.password_authenticate(credentials=credentials)
|
||||
if success:
|
||||
# 如果用户启用了二次验证码,则进一步验证
|
||||
if not self._verify_mfa(user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
logger.info(f"用户 {username} 通过密码认证成功")
|
||||
return True, user_or_message
|
||||
@@ -63,7 +66,10 @@ class UserChain(ChainBase):
|
||||
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
|
||||
if aux_success:
|
||||
# 辅助认证成功后再验证二次验证码
|
||||
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
|
||||
mfa_result = self._verify_mfa(aux_user_or_message, credentials.mfa_code)
|
||||
if mfa_result == "MFA_REQUIRED":
|
||||
return False, "MFA_REQUIRED"
|
||||
elif not mfa_result:
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
return True, aux_user_or_message
|
||||
else:
|
||||
@@ -159,22 +165,46 @@ class UserChain(ChainBase):
|
||||
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
|
||||
|
||||
@staticmethod
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
|
||||
def _verify_mfa(user: User, mfa_code: Optional[str]) -> Union[bool, str]:
|
||||
"""
|
||||
验证 MFA(二次验证码)
|
||||
检查用户是否启用了 OTP 或 PassKey,如果启用了任何一种,都需要提供验证
|
||||
|
||||
:param user: 用户对象
|
||||
:param mfa_code: 二次验证码
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:param mfa_code: 二次验证码(如果提供了则验证OTP)
|
||||
:return:
|
||||
- 如果验证成功返回 True
|
||||
- 如果需要MFA但未提供,返回 "MFA_REQUIRED"
|
||||
- 如果MFA验证失败,返回 False
|
||||
"""
|
||||
if not user.is_otp:
|
||||
# 检查用户是否有PassKey
|
||||
from app.db.models.passkey import PassKey
|
||||
has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
|
||||
|
||||
# 如果用户既没有启用OTP也没有PassKey,直接通过
|
||||
if not user.is_otp and not has_passkey:
|
||||
return True
|
||||
|
||||
# 如果用户启用了OTP或PassKey,但没有提供验证码,需要进行二次验证
|
||||
if not mfa_code:
|
||||
logger.info(f"用户 {user.name} 缺少 MFA 认证码")
|
||||
return False
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
# 如果提供了验证码,且用户启用了 OTP,则验证 OTP
|
||||
if user.is_otp:
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
# OTP 验证成功
|
||||
return True
|
||||
|
||||
# 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证
|
||||
if has_passkey:
|
||||
logger.info(
|
||||
f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
|
||||
)
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
return True
|
||||
|
||||
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:
|
||||
|
||||
@@ -11,7 +11,7 @@ from pydantic.fields import Callable
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.core.workflow import WorkFlowManager
|
||||
from app.workflow import WorkFlowManager
|
||||
from app.db.models import Workflow
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
@@ -180,7 +180,7 @@ class WorkflowExecutor:
|
||||
"""
|
||||
合并上下文
|
||||
"""
|
||||
for key, value in context.dict().items():
|
||||
for key, value in context.model_dump().items():
|
||||
if not getattr(self.context, key, None):
|
||||
setattr(self.context, key, value)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Union, Dict, Optional
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.chain.site import SiteChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.system import SystemChain
|
||||
@@ -140,6 +141,12 @@ class Command(metaclass=Singleton):
|
||||
"description": "当前版本",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
},
|
||||
"/clear_session": {
|
||||
"func": MessageChain().remote_clear_session,
|
||||
"description": "清除会话",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
}
|
||||
}
|
||||
# 插件命令集合
|
||||
@@ -215,7 +222,7 @@ class Command(metaclass=Singleton):
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
|
||||
|
||||
def __trigger_register_commands_event(self) -> (Optional[Event], dict):
|
||||
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
|
||||
"""
|
||||
触发事件,允许调整命令数据
|
||||
"""
|
||||
|
||||
1429
app/core/cache.py
1429
app/core/cache.py
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
@@ -6,12 +7,14 @@ import re
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
from asyncio import AbstractEventLoop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from pydantic import BaseModel, BaseSettings, validator, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from app.log import logger, log_settings, LogConfigModel
|
||||
from app.schemas import MediaType
|
||||
@@ -49,8 +52,7 @@ class ConfigModel(BaseModel):
|
||||
Pydantic 配置模型,描述所有配置项及其类型和默认值
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略未定义的配置项
|
||||
model_config = ConfigDict(extra="ignore") # 忽略未定义的配置项
|
||||
|
||||
# ==================== 基础应用配置 ====================
|
||||
# 项目名称
|
||||
@@ -75,6 +77,8 @@ class ConfigModel(BaseModel):
|
||||
DEBUG: bool = False
|
||||
# 是否开发模式
|
||||
DEV: bool = False
|
||||
# 高级设置模式
|
||||
ADVANCED_MODE: bool = True
|
||||
|
||||
# ==================== 安全认证配置 ====================
|
||||
# 密钥
|
||||
@@ -87,8 +91,10 @@ class ConfigModel(BaseModel):
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
|
||||
# RESOURCE_TOKEN过期时间
|
||||
RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 30
|
||||
# 超级管理员
|
||||
# 超级管理员初始用户名
|
||||
SUPERUSER: str = "admin"
|
||||
# 超级管理员初始密码
|
||||
SUPERUSER_PASSWORD: Optional[str] = None
|
||||
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
|
||||
AUXILIARY_AUTH_ENABLE: bool = False
|
||||
# API密钥,需要更换
|
||||
@@ -114,7 +120,7 @@ class ConfigModel(BaseModel):
|
||||
# 数据库连接池获取连接的超时时间(秒)
|
||||
DB_POOL_TIMEOUT: int = 30
|
||||
# SQLite 连接池大小
|
||||
DB_SQLITE_POOL_SIZE: int = 30
|
||||
DB_SQLITE_POOL_SIZE: int = 10
|
||||
# SQLite 连接池溢出数量
|
||||
DB_SQLITE_MAX_OVERFLOW: int = 50
|
||||
# PostgreSQL 主机地址
|
||||
@@ -128,7 +134,7 @@ class ConfigModel(BaseModel):
|
||||
# PostgreSQL 密码
|
||||
DB_POSTGRESQL_PASSWORD: str = "moviepilot"
|
||||
# PostgreSQL 连接池大小
|
||||
DB_POSTGRESQL_POOL_SIZE: int = 30
|
||||
DB_POSTGRESQL_POOL_SIZE: int = 10
|
||||
# PostgreSQL 连接池溢出数量
|
||||
DB_POSTGRESQL_MAX_OVERFLOW: int = 50
|
||||
|
||||
@@ -136,9 +142,17 @@ class ConfigModel(BaseModel):
|
||||
# 缓存类型,支持 cachetools 和 redis,默认使用 cachetools
|
||||
CACHE_BACKEND_TYPE: str = "cachetools"
|
||||
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要
|
||||
CACHE_BACKEND_URL: Optional[str] = None
|
||||
CACHE_BACKEND_URL: Optional[str] = "redis://localhost:6379"
|
||||
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
|
||||
CACHE_REDIS_MAXMEMORY: Optional[str] = None
|
||||
# 全局图片缓存,将媒体图片缓存到本地
|
||||
GLOBAL_IMAGE_CACHE: bool = False
|
||||
# 全局图片缓存保留天数
|
||||
GLOBAL_IMAGE_CACHE_DAYS: int = 7
|
||||
# 临时文件保留天数
|
||||
TEMP_FILE_DAYS: int = 3
|
||||
# 元数据识别缓存过期时间(小时),0为自动
|
||||
META_CACHE_EXPIRE: int = 0
|
||||
|
||||
# ==================== 网络代理配置 ====================
|
||||
# 网络代理服务器地址
|
||||
@@ -159,23 +173,13 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 媒体元数据配置 ====================
|
||||
# 媒体搜索来源 themoviedb/douban/bangumi,多个用,分隔
|
||||
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
|
||||
SEARCH_SOURCE: str = "themoviedb"
|
||||
# 媒体识别来源 themoviedb/douban
|
||||
RECOGNIZE_SOURCE: str = "themoviedb"
|
||||
# 元数据识别缓存过期时间(小时)
|
||||
META_CACHE_EXPIRE: int = 0
|
||||
# 电视剧动漫的分类genre_ids
|
||||
ANIME_GENREIDS: List[int] = Field(default=[16])
|
||||
# 刮削来源 themoviedb/douban
|
||||
SCRAP_SOURCE: str = "themoviedb"
|
||||
# 新增已入库媒体是否跟随TMDB信息变化
|
||||
SCRAP_FOLLOW_TMDB: bool = True
|
||||
# 检查本地媒体库是否存在资源开关
|
||||
LOCAL_EXISTS_SEARCH: bool = False
|
||||
# 搜索多个名称
|
||||
SEARCH_MULTIPLE_NAME: bool = False
|
||||
# 最大搜索名称数量
|
||||
MAX_SEARCH_NAME_LIMIT: int = 2
|
||||
# 电视剧动漫的分类genre_ids
|
||||
ANIME_GENREIDS: List[int] = Field(default=[16])
|
||||
|
||||
# ==================== TMDB配置 ====================
|
||||
# TMDB图片地址
|
||||
@@ -237,8 +241,6 @@ class ConfigModel(BaseModel):
|
||||
'.aifc', '.aiff', '.alac', '.adif', '.adts',
|
||||
'.flac', '.midi', '.opus', '.sfalc']
|
||||
)
|
||||
# 下载器临时文件后缀
|
||||
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
|
||||
|
||||
# ==================== 媒体服务器配置 ====================
|
||||
# 媒体服务器同步间隔(小时)
|
||||
@@ -253,6 +255,10 @@ class ConfigModel(BaseModel):
|
||||
SUBSCRIBE_STATISTIC_SHARE: bool = True
|
||||
# 订阅搜索开关
|
||||
SUBSCRIBE_SEARCH: bool = False
|
||||
# 订阅搜索时间间隔(小时)
|
||||
SUBSCRIBE_SEARCH_INTERVAL: int = 24
|
||||
# 检查本地媒体库是否存在资源开关
|
||||
LOCAL_EXISTS_SEARCH: bool = True
|
||||
|
||||
# ==================== 站点配置 ====================
|
||||
# 站点数据刷新间隔(小时)
|
||||
@@ -261,6 +267,18 @@ class ConfigModel(BaseModel):
|
||||
SITE_MESSAGE: bool = True
|
||||
# 不能缓存站点资源的站点域名,多个使用,分隔
|
||||
NO_CACHE_SITE_KEY: str = "m-team"
|
||||
# OCR服务器地址,用于识别站点验证码
|
||||
OCR_HOST: str = "https://movie-pilot.org"
|
||||
# 仿真类型:playwright 或 flaresolverr
|
||||
BROWSER_EMULATION: str = "playwright"
|
||||
# FlareSolverr 服务地址,例如 http://127.0.0.1:8191
|
||||
FLARESOLVERR_URL: Optional[str] = None
|
||||
|
||||
# ==================== 搜索配置 ====================
|
||||
# 搜索多个名称
|
||||
SEARCH_MULTIPLE_NAME: bool = False
|
||||
# 最大搜索名称数量
|
||||
MAX_SEARCH_NAME_LIMIT: int = 3
|
||||
|
||||
# ==================== 下载配置 ====================
|
||||
# 种子标签
|
||||
@@ -269,6 +287,8 @@ class ConfigModel(BaseModel):
|
||||
DOWNLOAD_SUBTITLE: bool = True
|
||||
# 交互搜索自动下载用户ID,使用,分割
|
||||
AUTO_DOWNLOAD_USER: Optional[str] = None
|
||||
# 下载器临时文件后缀
|
||||
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
|
||||
|
||||
# ==================== CookieCloud配置 ====================
|
||||
# CookieCloud是否启动本地服务
|
||||
@@ -284,7 +304,7 @@ class ConfigModel(BaseModel):
|
||||
# CookieCloud同步黑名单,多个域名,分割
|
||||
COOKIECLOUD_BLACKLIST: Optional[str] = None
|
||||
|
||||
# ==================== 重命名配置 ====================
|
||||
# ==================== 整理配置 ====================
|
||||
# 电影重命名格式
|
||||
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
|
||||
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
|
||||
@@ -298,12 +318,14 @@ class ConfigModel(BaseModel):
|
||||
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
|
||||
# 为指定默认字幕添加.default后缀
|
||||
DEFAULT_SUB: Optional[str] = "zh-cn"
|
||||
# 新增已入库媒体是否跟随TMDB信息变化
|
||||
SCRAP_FOLLOW_TMDB: bool = True
|
||||
|
||||
# ==================== 服务地址配置 ====================
|
||||
# OCR服务器地址
|
||||
OCR_HOST: str = "https://movie-pilot.org"
|
||||
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
|
||||
MP_SERVER_HOST: str = "https://movie-pilot.org"
|
||||
|
||||
# ==================== 个性化 ====================
|
||||
# 登录页面电影海报,tmdb/bing/mediaserver
|
||||
WALLPAPER: str = "tmdb"
|
||||
# 自定义壁纸api地址
|
||||
@@ -331,7 +353,7 @@ class ConfigModel(BaseModel):
|
||||
# 是否开启插件热加载
|
||||
PLUGIN_AUTO_RELOAD: bool = False
|
||||
|
||||
# ==================== GitHub配置 ====================
|
||||
# ==================== Github & PIP ====================
|
||||
# Github token,提高请求api限流阈值 ghp_****
|
||||
GITHUB_TOKEN: Optional[str] = None
|
||||
# Github代理服务器,格式:https://mirror.ghproxy.com/
|
||||
@@ -344,14 +366,12 @@ class ConfigModel(BaseModel):
|
||||
# ==================== 性能配置 ====================
|
||||
# 大内存模式
|
||||
BIG_MEMORY_MODE: bool = False
|
||||
# FastApi性能监控
|
||||
PERFORMANCE_MONITOR_ENABLE: bool = False
|
||||
# 全局图片缓存,将媒体图片缓存到本地
|
||||
GLOBAL_IMAGE_CACHE: bool = False
|
||||
# 是否启用编码探测的性能模式
|
||||
ENCODING_DETECTION_PERFORMANCE_MODE: bool = True
|
||||
# 编码探测的最低置信度阈值
|
||||
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
|
||||
# 主动内存回收时间间隔(分钟),0为不启用
|
||||
MEMORY_GC_INTERVAL: int = 30
|
||||
|
||||
# ==================== 安全配置 ====================
|
||||
# 允许的图片缓存域名
|
||||
@@ -373,6 +393,8 @@ class ConfigModel(BaseModel):
|
||||
])
|
||||
# 允许的图片文件后缀格式
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
# PassKey 是否强制用户验证(生物识别等)
|
||||
PASSKEY_REQUIRE_UV: bool = True
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
# 工作流数据共享
|
||||
@@ -380,19 +402,43 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 存储配置 ====================
|
||||
# 对rclone进行快照对比时,是否检查文件夹的修改时间
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
# 对OpenList进行快照对比时,是否检查文件夹的修改时间
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
|
||||
|
||||
# ==================== 浏览器仿真配置 ====================
|
||||
# 仿真类型:playwright 或 flaresolverr
|
||||
BROWSER_EMULATION: str = "playwright"
|
||||
# FlareSolverr 服务地址,例如 http://127.0.0.1:8191
|
||||
FLARESOLVERR_URL: Optional[str] = None
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
# Playwright浏览器类型,chromium/firefox
|
||||
PLAYWRIGHT_BROWSER_TYPE: str = "chromium"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# 合局AI智能体
|
||||
AI_AGENT_GLOBAL: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 15
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 30
|
||||
# 内存记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 1
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -400,10 +446,11 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
系统配置类
|
||||
"""
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = SystemUtils.get_env_path()
|
||||
env_file_encoding = "utf-8"
|
||||
model_config = SettingsConfigDict(
|
||||
case_sensitive=True,
|
||||
env_file=SystemUtils.get_env_path(),
|
||||
env_file_encoding="utf-8",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -500,33 +547,54 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
|
||||
return default, True
|
||||
|
||||
@validator('*', pre=True, always=True)
|
||||
def generic_type_validator(cls, value: Any, field): # noqa
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def generic_type_validator(cls, data: Any): # noqa
|
||||
"""
|
||||
通用校验器,尝试将配置值转换为期望的类型
|
||||
"""
|
||||
if field.name == "API_TOKEN":
|
||||
converted_value, needs_update = cls.validate_api_token(value, value)
|
||||
else:
|
||||
converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default,
|
||||
field.name)
|
||||
if needs_update:
|
||||
cls.update_env_config(field, value, converted_value)
|
||||
return converted_value
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
# 处理 API_TOKEN 特殊验证
|
||||
if 'API_TOKEN' in data:
|
||||
converted_value, needs_update = cls.validate_api_token(data['API_TOKEN'], data['API_TOKEN'])
|
||||
if needs_update:
|
||||
cls.update_env_config("API_TOKEN", data["API_TOKEN"], converted_value)
|
||||
data['API_TOKEN'] = converted_value
|
||||
|
||||
# 对其他字段进行类型转换
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
if field_name not in data:
|
||||
continue
|
||||
value = data[field_name]
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
field = cls.model_fields.get(field_name)
|
||||
if field:
|
||||
converted_value, needs_update = cls.generic_type_converter(
|
||||
value, value, field.annotation, field.default, field_name
|
||||
)
|
||||
if needs_update:
|
||||
cls.update_env_config(field_name, value, converted_value)
|
||||
data[field_name] = converted_value
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def update_env_config(field: Any, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
def update_env_config(field_name: str, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
"""
|
||||
更新 env 配置
|
||||
"""
|
||||
message = None
|
||||
is_converted = original_value is not None and str(original_value) != str(converted_value)
|
||||
if is_converted:
|
||||
message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
message = f"配置项 '{field_name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
logger.warning(message)
|
||||
|
||||
if field.name in os.environ:
|
||||
message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
if field_name in os.environ:
|
||||
message = f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
logger.warning(message)
|
||||
return False, message
|
||||
else:
|
||||
@@ -536,10 +604,10 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
else:
|
||||
value_to_write = str(converted_value) if converted_value is not None else ""
|
||||
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field.name, value_to_set=value_to_write,
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field_name, value_to_set=value_to_write,
|
||||
quote_mode="always")
|
||||
if is_converted:
|
||||
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
|
||||
logger.info(f"配置项 '{field_name}' 已自动修正并写入到 'app.env' 文件")
|
||||
return True, message
|
||||
|
||||
def update_setting(self, key: str, value: Any) -> Tuple[Optional[bool], str]:
|
||||
@@ -553,19 +621,17 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return False, f"配置项 '{key}' 不存在"
|
||||
|
||||
try:
|
||||
field = self.__fields__[key]
|
||||
field = Settings.model_fields[key]
|
||||
original_value = getattr(self, key)
|
||||
if field.name == "API_TOKEN":
|
||||
if key == "API_TOKEN":
|
||||
converted_value, needs_update = self.validate_api_token(value, original_value)
|
||||
else:
|
||||
converted_value, needs_update = self.generic_type_converter(value,
|
||||
original_value,
|
||||
field.type_,
|
||||
field.default,
|
||||
key)
|
||||
converted_value, needs_update = self.generic_type_converter(
|
||||
value, original_value, field.annotation, field.default, key
|
||||
)
|
||||
# 如果没有抛出异常,则统一使用 converted_value 进行更新
|
||||
if needs_update or str(value) != str(converted_value):
|
||||
success, message = self.update_env_config(field, value, converted_value)
|
||||
success, message = self.update_env_config(key, value, converted_value)
|
||||
# 仅成功更新配置时,才更新内存
|
||||
if success:
|
||||
setattr(self, key, converted_value)
|
||||
@@ -657,7 +723,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
douban=512,
|
||||
bangumi=512,
|
||||
fanart=512,
|
||||
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
|
||||
meta=(self.META_CACHE_EXPIRE or 72) * 3600,
|
||||
scheduler=100,
|
||||
threadpool=100
|
||||
)
|
||||
@@ -668,7 +734,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
douban=256,
|
||||
bangumi=256,
|
||||
fanart=128,
|
||||
meta=(self.META_CACHE_EXPIRE or 2) * 3600,
|
||||
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
|
||||
scheduler=50,
|
||||
threadpool=50
|
||||
)
|
||||
@@ -792,6 +858,10 @@ class GlobalVar(object):
|
||||
SUBSCRIPTIONS: List[dict] = []
|
||||
# 需应急停止的工作流
|
||||
EMERGENCY_STOP_WORKFLOWS: List[int] = []
|
||||
# 需应急停止文件整理
|
||||
EMERGENCY_STOP_TRANSFER: List[str] = []
|
||||
# 当前事件循环
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = asyncio.get_event_loop()
|
||||
|
||||
def stop_system(self):
|
||||
"""
|
||||
@@ -832,12 +902,43 @@ class GlobalVar(object):
|
||||
if workflow_id in self.EMERGENCY_STOP_WORKFLOWS:
|
||||
self.EMERGENCY_STOP_WORKFLOWS.remove(workflow_id)
|
||||
|
||||
def is_workflow_stopped(self, workflow_id: int):
|
||||
def is_workflow_stopped(self, workflow_id: int) -> bool:
|
||||
"""
|
||||
是否停止工作流
|
||||
"""
|
||||
return self.is_system_stopped or workflow_id in self.EMERGENCY_STOP_WORKFLOWS
|
||||
|
||||
def stop_transfer(self, path: str):
|
||||
"""
|
||||
停止文件整理
|
||||
"""
|
||||
if path not in self.EMERGENCY_STOP_TRANSFER:
|
||||
self.EMERGENCY_STOP_TRANSFER.append(path)
|
||||
|
||||
def is_transfer_stopped(self, path: str) -> bool:
|
||||
"""
|
||||
是否停止文件整理
|
||||
"""
|
||||
if self.is_system_stopped:
|
||||
return True
|
||||
if path in self.EMERGENCY_STOP_TRANSFER:
|
||||
self.EMERGENCY_STOP_TRANSFER.remove(path)
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def loop(self) -> AbstractEventLoop:
|
||||
"""
|
||||
当前循环
|
||||
"""
|
||||
return self.CURRENT_EVENT_LOOP
|
||||
|
||||
def set_loop(self, loop: AbstractEventLoop):
|
||||
"""
|
||||
设置循环
|
||||
"""
|
||||
self.CURRENT_EVENT_LOOP = loop
|
||||
|
||||
|
||||
# 全局标识
|
||||
global_vars = GlobalVar()
|
||||
|
||||
@@ -95,18 +95,20 @@ class TorrentInfo:
|
||||
if upload_volume_factor is None or download_volume_factor is None:
|
||||
return "未知"
|
||||
free_strs = {
|
||||
"1.0 1.0": "普通",
|
||||
"1.0 0.0": "免费",
|
||||
"2.0 1.0": "2X",
|
||||
"4.0 1.0": "4X",
|
||||
"2.0 0.0": "2X免费",
|
||||
"4.0 0.0": "4X免费",
|
||||
"1.0 0.5": "50%",
|
||||
"2.0 0.5": "2X 50%",
|
||||
"1.0 0.7": "70%",
|
||||
"1.0 0.3": "30%"
|
||||
"1.00 1.00": "普通",
|
||||
"1.00 0.00": "免费",
|
||||
"2.00 1.00": "2X",
|
||||
"4.00 1.00": "4X",
|
||||
"2.00 0.00": "2X免费",
|
||||
"4.00 0.00": "4X免费",
|
||||
"1.00 0.50": "50%",
|
||||
"2.00 0.50": "2X 50%",
|
||||
"1.00 0.70": "70%",
|
||||
"1.00 0.30": "30%",
|
||||
"1.00 0.75": "75%",
|
||||
"1.00 0.25": "25%"
|
||||
}
|
||||
return free_strs.get('%.1f %.1f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
return free_strs.get('%.2f %.2f' % (upload_volume_factor, download_volume_factor), "未知")
|
||||
|
||||
@property
|
||||
def volume_factor(self):
|
||||
@@ -250,6 +252,8 @@ class MediaInfo:
|
||||
production_countries: list = field(default_factory=list)
|
||||
# 语种
|
||||
spoken_languages: list = field(default_factory=list)
|
||||
# 所有发行日期
|
||||
release_dates: list = field(default_factory=list)
|
||||
# 状态
|
||||
status: str = None
|
||||
# 标签
|
||||
@@ -257,7 +261,7 @@ class MediaInfo:
|
||||
# 评价数量
|
||||
vote_count: int = None
|
||||
# 流行度
|
||||
popularity: int = None
|
||||
popularity: float = None
|
||||
# 时长
|
||||
runtime: int = None
|
||||
# 下一集
|
||||
@@ -433,6 +437,18 @@ class MediaInfo:
|
||||
if self.release_date:
|
||||
# 年份
|
||||
self.year = self.release_date[:4]
|
||||
# 所有发行日期
|
||||
self.release_dates = [
|
||||
{
|
||||
"date": release_date.get("release_date"),
|
||||
"iso_code": result.get("iso_3166_1"),
|
||||
"note": release_date.get("note"),
|
||||
"type": release_date.get("type"),
|
||||
}
|
||||
for result in info.get("release_dates", {}).get("results", [])
|
||||
for release_date in result.get("release_dates", [])
|
||||
if release_date.get("release_date")
|
||||
]
|
||||
else:
|
||||
# 电视剧
|
||||
self.title = info.get('name')
|
||||
@@ -483,7 +499,7 @@ class MediaInfo:
|
||||
continue
|
||||
if current_value is None:
|
||||
setattr(self, key, value)
|
||||
elif type(current_value) == type(value):
|
||||
elif type(current_value) is type(value):
|
||||
setattr(self, key, value)
|
||||
|
||||
def set_douban_info(self, info: dict):
|
||||
@@ -624,7 +640,7 @@ class MediaInfo:
|
||||
continue
|
||||
if current_value is None:
|
||||
setattr(self, key, value)
|
||||
elif type(current_value) == type(value):
|
||||
elif type(current_value) is type(value):
|
||||
setattr(self, key, value)
|
||||
|
||||
def set_bangumi_info(self, info: dict):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import random
|
||||
@@ -10,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union, Any
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ChainEventData
|
||||
@@ -71,15 +73,24 @@ class EventManager(metaclass=Singleton):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
|
||||
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
|
||||
self.__event_queue = PriorityQueue() # 优先级队列
|
||||
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {} # 广播事件的订阅者
|
||||
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {} # 链式事件的订阅者
|
||||
self.__disabled_handlers = set() # 禁用的事件处理器集合
|
||||
self.__disabled_classes = set() # 禁用的事件处理器类集合
|
||||
self.__lock = threading.Lock() # 线程锁
|
||||
self.__event = threading.Event() # 退出事件
|
||||
# 动态线程池,用于消费事件
|
||||
self.__executor = ThreadHelper()
|
||||
# 用于保存启动的事件消费者线程
|
||||
self.__consumer_threads = []
|
||||
# 优先级队列
|
||||
self.__event_queue = PriorityQueue()
|
||||
# 广播事件的订阅者
|
||||
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {}
|
||||
# 链式事件的订阅者
|
||||
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {}
|
||||
# 禁用的事件处理器集合
|
||||
self.__disabled_handlers = set()
|
||||
# 禁用的事件处理器类集合
|
||||
self.__disabled_classes = set()
|
||||
# 线程锁
|
||||
self.__lock = threading.Lock()
|
||||
# 退出事件
|
||||
self.__event = threading.Event()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -438,7 +449,15 @@ class EventManager(metaclass=Singleton):
|
||||
isolated_event = Event(event_type=event.event_type,
|
||||
event_data=event_data_copy,
|
||||
priority=event.priority)
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
# 对于异步函数,直接在事件循环中运行
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.__safe_invoke_handler_async(handler, isolated_event),
|
||||
global_vars.loop
|
||||
)
|
||||
else:
|
||||
# 对于同步函数,在线程池中运行
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
|
||||
|
||||
def __safe_invoke_handler(self, handler: Callable, event: Event):
|
||||
"""
|
||||
@@ -450,10 +469,7 @@ class EventManager(metaclass=Singleton):
|
||||
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
|
||||
return
|
||||
|
||||
try:
|
||||
self.__invoke_handler_by_type_sync(handler, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event, handler, e)
|
||||
self.__invoke_handler_by_type_sync(handler, event)
|
||||
|
||||
async def __safe_invoke_handler_async(self, handler: Callable, event: Event):
|
||||
"""
|
||||
@@ -465,10 +481,7 @@ class EventManager(metaclass=Singleton):
|
||||
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.__invoke_handler_by_type_async(handler, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event, handler, e)
|
||||
await self.__invoke_handler_by_type_async(handler, event)
|
||||
|
||||
def __invoke_handler_by_type_sync(self, handler: Callable, event: Event):
|
||||
"""
|
||||
@@ -486,7 +499,17 @@ class EventManager(metaclass=Singleton):
|
||||
|
||||
if class_name in plugin_manager.get_plugin_ids():
|
||||
# 插件处理器
|
||||
plugin_manager.run_plugin_method(class_name, method_name, event)
|
||||
plugin = plugin_manager.running_plugins.get(class_name)
|
||||
if not plugin:
|
||||
return
|
||||
method = getattr(plugin, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=plugin.name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
elif class_name in module_manager.get_module_ids():
|
||||
# 模块处理器
|
||||
module = module_manager.get_running_module(class_name)
|
||||
@@ -495,16 +518,24 @@ class EventManager(metaclass=Singleton):
|
||||
method = getattr(module, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
method(event)
|
||||
try:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=module.get_name(),
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
else:
|
||||
# 全局处理器
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj or not hasattr(class_obj, method_name):
|
||||
return
|
||||
method = getattr(class_obj, method_name)
|
||||
method = getattr(class_obj, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
method(event)
|
||||
try:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=class_name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_handler_by_type_async(self, handler: Callable, event: Event):
|
||||
"""
|
||||
@@ -537,52 +568,63 @@ class EventManager(metaclass=Singleton):
|
||||
names = handler.__qualname__.split(".")
|
||||
return names[0], names[1]
|
||||
|
||||
@staticmethod
|
||||
async def __invoke_plugin_method_async(handler: Any, class_name: str, method_name: str, event: Event):
|
||||
async def __invoke_plugin_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用插件方法
|
||||
"""
|
||||
plugin = handler.running_plugins.get(class_name)
|
||||
if plugin and hasattr(plugin, method_name):
|
||||
method = getattr(plugin, method_name)
|
||||
if not plugin:
|
||||
return
|
||||
method = getattr(plugin, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
# 插件同步函数在异步环境中运行,避免阻塞
|
||||
await run_in_threadpool(method, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=plugin.name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
@staticmethod
|
||||
async def __invoke_module_method_async(handler: Any, class_name: str, method_name: str, event: Event):
|
||||
async def __invoke_module_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用模块方法
|
||||
"""
|
||||
module = handler.get_running_module(class_name)
|
||||
if not module:
|
||||
return
|
||||
|
||||
method = getattr(module, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
method(event)
|
||||
try:
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=module.get_name(),
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
async def __invoke_global_method_async(self, class_name: str, method_name: str, event: Event):
|
||||
"""
|
||||
异步调用全局对象方法
|
||||
"""
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj or not hasattr(class_obj, method_name):
|
||||
if not class_obj:
|
||||
return
|
||||
|
||||
method = getattr(class_obj, method_name)
|
||||
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
method(event)
|
||||
method = getattr(class_obj, method_name, None)
|
||||
if not method:
|
||||
return
|
||||
try:
|
||||
if inspect.iscoroutinefunction(method):
|
||||
await method(event)
|
||||
else:
|
||||
method(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=class_name,
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
|
||||
@staticmethod
|
||||
def __get_class_instance(class_name: str):
|
||||
@@ -609,7 +651,11 @@ class EventManager(metaclass=Singleton):
|
||||
module_name = f"app.chain.{class_name[:-5].lower()}"
|
||||
module = importlib.import_module(module_name)
|
||||
elif class_name.endswith("Helper"):
|
||||
module_name = f"app.helper.{class_name[:-6].lower()}"
|
||||
# 特殊处理 Async 类
|
||||
if class_name.startswith("Async"):
|
||||
module_name = f"app.helper.{class_name[5:-6].lower()}"
|
||||
else:
|
||||
module_name = f"app.helper.{class_name[:-6].lower()}"
|
||||
module = importlib.import_module(module_name)
|
||||
else:
|
||||
module_name = f"app.{class_name.lower()}"
|
||||
@@ -649,18 +695,16 @@ class EventManager(metaclass=Singleton):
|
||||
"""
|
||||
logger.debug(f"{stage} - {event}")
|
||||
|
||||
def __handle_event_error(self, event: Event, handler: Callable, e: Exception):
|
||||
def __handle_event_error(self, event: Event, module_name: str,
|
||||
class_name: str, method_name: str, e: Exception):
|
||||
"""
|
||||
全局错误处理器,用于处理事件处理中的异常
|
||||
"""
|
||||
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
|
||||
|
||||
names = handler.__qualname__.split(".")
|
||||
class_name, method_name = names[0], names[1]
|
||||
logger.error(f"{module_name} 事件处理出错:{str(e)} - {traceback.format_exc()}")
|
||||
|
||||
# 发送系统错误通知
|
||||
from app.helper.message import MessageHelper
|
||||
MessageHelper().put(title=f"{event.event_type} 事件处理出错",
|
||||
MessageHelper().put(title=f"{module_name} 处理事件 {event.event_type} 时出错",
|
||||
message=f"{class_name}.{method_name}:{str(e)}",
|
||||
role="system")
|
||||
self.send_event(
|
||||
|
||||
@@ -94,7 +94,6 @@ class MetaVideo(MetaBase):
|
||||
title = re.sub(r'\d{4}[\s._-]\d{1,2}[\s._-]\d{1,2}', "", title)
|
||||
# 拆分tokens
|
||||
tokens = Tokens(title)
|
||||
self.tokens = tokens
|
||||
# 实例化StreamingPlatforms对象
|
||||
streaming_platforms = StreamingPlatforms()
|
||||
# 解析名称、年份、季、集、资源类型、分辨率等
|
||||
@@ -102,7 +101,7 @@ class MetaVideo(MetaBase):
|
||||
while token:
|
||||
self._index += 1 # 更新当前处理的token索引
|
||||
# Part
|
||||
self.__init_part(token)
|
||||
self.__init_part(token, tokens)
|
||||
# 标题
|
||||
if self._continue_flag:
|
||||
self.__init_name(token)
|
||||
@@ -123,7 +122,7 @@ class MetaVideo(MetaBase):
|
||||
self.__init_resource_type(token)
|
||||
# 流媒体平台
|
||||
if self._continue_flag:
|
||||
self.__init_web_source(token, streaming_platforms)
|
||||
self.__init_web_source(token, tokens, streaming_platforms)
|
||||
# 视频编码
|
||||
if self._continue_flag:
|
||||
self.__init_video_encode(token)
|
||||
@@ -311,7 +310,7 @@ class MetaVideo(MetaBase):
|
||||
self.en_name = token
|
||||
self._last_token_type = "enname"
|
||||
|
||||
def __init_part(self, token: str):
|
||||
def __init_part(self, token: str, tokens: Tokens):
|
||||
"""
|
||||
识别Part
|
||||
"""
|
||||
@@ -327,12 +326,12 @@ class MetaVideo(MetaBase):
|
||||
if re_res:
|
||||
if not self.part:
|
||||
self.part = re_res.group(1)
|
||||
nextv = self.tokens.cur()
|
||||
nextv = tokens.cur()
|
||||
if nextv \
|
||||
and ((nextv.isdigit() and (len(nextv) == 1 or len(nextv) == 2 and nextv.startswith('0')))
|
||||
or nextv.upper() in ['A', 'B', 'C', 'I', 'II', 'III']):
|
||||
self.part = "%s%s" % (self.part, nextv)
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
self._last_token_type = "part"
|
||||
self._continue_flag = False
|
||||
# self._stop_name_flag = False
|
||||
@@ -582,7 +581,7 @@ class MetaVideo(MetaBase):
|
||||
self._effect.append(effect)
|
||||
self._last_token = effect.upper()
|
||||
|
||||
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
|
||||
def __init_web_source(self, token: str, tokens: Tokens, streaming_platforms: StreamingPlatforms):
|
||||
"""
|
||||
识别流媒体平台
|
||||
"""
|
||||
@@ -594,10 +593,10 @@ class MetaVideo(MetaBase):
|
||||
|
||||
prev_token = None
|
||||
prev_idx = self._index - 2
|
||||
if 0 <= prev_idx < len(self.tokens.tokens):
|
||||
prev_token = self.tokens.tokens[prev_idx]
|
||||
if 0 <= prev_idx < len(tokens.tokens):
|
||||
prev_token = tokens.tokens[prev_idx]
|
||||
|
||||
next_token = self.tokens.peek()
|
||||
next_token = tokens.peek()
|
||||
|
||||
if streaming_platforms.is_streaming_platform(token):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(token)
|
||||
@@ -616,7 +615,7 @@ class MetaVideo(MetaBase):
|
||||
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
|
||||
query_range = 2
|
||||
if is_next:
|
||||
self.tokens.get_next()
|
||||
tokens.get_next()
|
||||
break
|
||||
|
||||
if not platform_name:
|
||||
@@ -626,8 +625,8 @@ class MetaVideo(MetaBase):
|
||||
match_start_idx = self._index - query_range
|
||||
match_end_idx = self._index - 1
|
||||
start_index = max(0, match_start_idx - query_range)
|
||||
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = self.tokens.tokens[start_index:end_index]
|
||||
end_index = min(len(tokens.tokens), match_end_idx + 1 + query_range)
|
||||
tokens_to_check = tokens.tokens[start_index:end_index]
|
||||
|
||||
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
|
||||
self.web_source = platform_name
|
||||
|
||||
@@ -71,12 +71,14 @@ def MetaInfoPath(path: Path) -> MetaBase:
|
||||
file_meta = MetaInfo(title=path.name)
|
||||
# 上级目录元数据
|
||||
dir_meta = MetaInfo(title=path.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
# 上上级目录元数据
|
||||
root_meta = MetaInfo(title=path.parent.parent.name)
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
if file_meta.type == MediaType.TV or root_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
return file_meta
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ class ModuleManager(metaclass=Singleton):
|
||||
# 通过模板开关控制加载
|
||||
_module.init_module()
|
||||
self._running_modules[module_id] = _module
|
||||
logger.info(f"Moudle Loaded:{module_id}")
|
||||
logger.debug(f"Moudle Loaded:{module_id}")
|
||||
except Exception as err:
|
||||
logger.error(f"Load Moudle Error:{module_id},{str(err)} - {traceback.format_exc()}", exc_info=True)
|
||||
|
||||
@@ -61,7 +61,7 @@ class ModuleManager(metaclass=Singleton):
|
||||
if hasattr(module, "stop"):
|
||||
try:
|
||||
module.stop()
|
||||
logger.info(f"Moudle Stoped:{module_id}")
|
||||
logger.debug(f"Moudle Stoped:{module_id}")
|
||||
except Exception as err:
|
||||
logger.error(f"Stop Moudle Error:{module_id},{str(err)} - {traceback.format_exc()}", exc_info=True)
|
||||
logger.info("所有模块停止完成")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import ast
|
||||
import asyncio
|
||||
import concurrent
|
||||
import concurrent.futures
|
||||
@@ -5,6 +6,7 @@ import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
@@ -13,12 +15,12 @@ from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
from watchfiles import watch
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager, Event
|
||||
from app.core.event import eventmanager
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
@@ -26,68 +28,16 @@ from app.helper.sites import SitesHelper # noqa
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
from app.utils.crypto import RSAUtils
|
||||
from app.utils.limit import rate_limit_window
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PluginMonitorHandler(FileSystemEventHandler):
|
||||
|
||||
def on_modified(self, event):
|
||||
"""
|
||||
插件文件修改后重载
|
||||
"""
|
||||
if event.is_directory:
|
||||
return
|
||||
# 使用 pathlib 处理文件路径,跳过非 .py 文件以及 pycache 目录中的文件
|
||||
event_path = Path(event.src_path)
|
||||
if not event_path.name.endswith(".py") or "pycache" in event_path.parts:
|
||||
return
|
||||
|
||||
# 读取插件根目录下的__init__.py文件,读取class XXXX(_PluginBase)的类名
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if plugins_root not in event_path.parents:
|
||||
return
|
||||
# 获取插件目录路径,没有找到__init__.py时,说明不是有效包,跳过插件重载
|
||||
# 插件重载目前没有支持app/plugins/plugin/package/__init__.py的场景,这里也不做支持
|
||||
plugin_dir = event_path.parent
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
logger.debug(f"{plugin_dir} 下没有找到 __init__.py,跳过插件重载")
|
||||
return
|
||||
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
pid = None
|
||||
for line in lines:
|
||||
if line.startswith("class") and "(_PluginBase)" in line:
|
||||
pid = line.split("class ")[1].split("(_PluginBase)")[0].strip()
|
||||
if pid:
|
||||
self.__reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
@staticmethod
|
||||
@rate_limit_window(max_calls=1, window_seconds=2, source="PluginMonitor", enable_logging=False)
|
||||
def __reload_plugin(pid):
|
||||
"""
|
||||
重新加载插件
|
||||
"""
|
||||
try:
|
||||
logger.info(f"插件 {pid} 文件修改,重新加载...")
|
||||
PluginManager().reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件文件修改后重载出错:{str(e)}")
|
||||
|
||||
|
||||
class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
插件管理器
|
||||
"""
|
||||
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""插件管理器"""
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
@@ -96,8 +46,10 @@ class PluginManager(metaclass=Singleton):
|
||||
self._running_plugins: dict = {}
|
||||
# 配置Key
|
||||
self._config_key: str = "plugin.%s"
|
||||
# 监听器
|
||||
self._observer: Observer = None
|
||||
# 监控线程
|
||||
self._monitor_thread: Optional[threading.Thread] = None
|
||||
# 监控停止事件
|
||||
self._stop_monitor_event = threading.Event()
|
||||
# 开发者模式监测插件修改
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
self.__start_monitor()
|
||||
@@ -264,7 +216,6 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(module_name)
|
||||
importlib.reload(module)
|
||||
|
||||
# 检查模块中的类
|
||||
for name, obj in module.__dict__.items():
|
||||
@@ -299,29 +250,20 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
return self._plugins
|
||||
|
||||
@eventmanager.register(EventType.ConfigChanged)
|
||||
def handle_config_changed(self, event: Event):
|
||||
"""
|
||||
处理配置变更事件
|
||||
:param event: 事件对象
|
||||
"""
|
||||
if not event:
|
||||
return
|
||||
event_data: schemas.ConfigChangeEventData = event.event_data
|
||||
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
|
||||
return
|
||||
logger.info("配置变更,重新加载插件文件修改监测...")
|
||||
def on_config_changed(self):
|
||||
self.reload_monitor()
|
||||
|
||||
def get_reload_name(self) -> str:
|
||||
return "插件文件修改监测"
|
||||
|
||||
def reload_monitor(self):
|
||||
"""
|
||||
重新加载插件文件修改监测
|
||||
"""
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
if self._observer and self._observer.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
else:
|
||||
self.__start_monitor()
|
||||
# 先关闭已有监测,再重新启动
|
||||
self.stop_monitor()
|
||||
self.__start_monitor()
|
||||
else:
|
||||
self.stop_monitor()
|
||||
|
||||
@@ -329,25 +271,123 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
启用监测插件文件修改监测
|
||||
"""
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("插件文件修改监测已经在运行中...")
|
||||
return
|
||||
|
||||
logger.info("开始监测插件文件修改...")
|
||||
monitor_handler = PluginMonitorHandler()
|
||||
self._observer = Observer()
|
||||
self._observer.schedule(monitor_handler, str(settings.ROOT_PATH / "app" / "plugins"), recursive=True)
|
||||
self._observer.start()
|
||||
|
||||
# 在启动新线程之前,确保停止事件是清除状态
|
||||
self._stop_monitor_event.clear()
|
||||
|
||||
# 创建并启动监控线程
|
||||
self._monitor_thread = threading.Thread(
|
||||
target=self._run_file_watcher,
|
||||
daemon=True
|
||||
)
|
||||
self._monitor_thread.start()
|
||||
|
||||
def stop_monitor(self):
|
||||
"""
|
||||
停止监测插件文件修改监测
|
||||
"""
|
||||
# 停止监测
|
||||
if self._observer and self._observer.is_alive():
|
||||
if self._monitor_thread and self._monitor_thread.is_alive():
|
||||
logger.info("正在停止插件文件修改监测...")
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
self._stop_monitor_event.set()
|
||||
self._monitor_thread.join(timeout=5)
|
||||
if self._monitor_thread.is_alive():
|
||||
logger.warning("插件文件修改监测线程在5秒内未能正常停止。")
|
||||
self._monitor_thread = None
|
||||
logger.info("插件文件修改监测停止完成")
|
||||
else:
|
||||
logger.info("未启用插件文件修改监测,无需停止")
|
||||
|
||||
def _run_file_watcher(self):
|
||||
"""
|
||||
运行 watchfiles 监视器的主循环。
|
||||
"""
|
||||
# 监视插件目录
|
||||
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
|
||||
logger.info(">>> 监控线程已启动,准备进入watch循环...")
|
||||
# 使用 watchfiles 监视目录变化,并响应变化事件
|
||||
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
|
||||
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
yield_on_timeout=True):
|
||||
# 如果收到停止事件,退出循环
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
# 处理变化事件
|
||||
plugins_to_reload = set()
|
||||
for _change_type, path_str in changes:
|
||||
event_path = Path(path_str)
|
||||
|
||||
# 跳过非 .py 文件以及 pycache 目录中的文件
|
||||
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
|
||||
continue
|
||||
|
||||
# 解析插件ID
|
||||
pid = self._get_plugin_id_from_path(event_path)
|
||||
# 跳过无效插件文件
|
||||
if pid:
|
||||
# 收集需要重载的插件ID,自动去重,避免重复重载
|
||||
plugins_to_reload.add(pid)
|
||||
|
||||
# 触发重载
|
||||
if plugins_to_reload:
|
||||
logger.info(f"检测到插件文件变化,准备重载: {list(plugins_to_reload)}")
|
||||
for pid in plugins_to_reload:
|
||||
try:
|
||||
self.reload_plugin(pid)
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {pid} 热重载失败: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_id_from_path(event_path: Path) -> Optional[str]:
|
||||
"""
|
||||
根据文件路径解析出插件的ID。
|
||||
:param event_path: 被修改文件的 Path 对象。
|
||||
:return: 插件ID字符串,如果不是有效插件文件则返回 None。
|
||||
"""
|
||||
try:
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if not event_path.is_relative_to(plugins_root):
|
||||
return None
|
||||
|
||||
try:
|
||||
plugin_dir_name = event_path.relative_to(plugins_root).parts[0]
|
||||
plugin_dir = plugins_root / plugin_dir_name
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
return None
|
||||
|
||||
# 读取 __init__.py 文件,查找插件主类名
|
||||
with open(init_file, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# 遍历AST,查找继承自 _PluginBase 的类
|
||||
for node in ast.walk(tree):
|
||||
# 检查节点是否为类定义
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# 遍历该类的所有基类
|
||||
for base in node.bases:
|
||||
# 检查基类是否是我们寻找的 _PluginBase
|
||||
# ast.Name 用于处理简单的基类名
|
||||
if isinstance(base, ast.Name) and base.id == '_PluginBase':
|
||||
# 返回这个类的名字
|
||||
return node.name
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"从路径解析插件ID时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __stop_plugin(plugin: Any):
|
||||
"""
|
||||
@@ -410,6 +450,10 @@ class PluginManager(metaclass=Singleton):
|
||||
except KeyError:
|
||||
# 模块可能已经被删除
|
||||
pass
|
||||
|
||||
importlib.invalidate_caches()
|
||||
logger.debug("已清除查找器的缓存")
|
||||
|
||||
if plugin_id:
|
||||
if modules_to_remove:
|
||||
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
|
||||
@@ -693,6 +737,36 @@ class PluginManager(metaclass=Singleton):
|
||||
logger.error(f"获取插件 {plugin_id} 动作出错:{str(e)}")
|
||||
return ret_actions
|
||||
|
||||
def get_plugin_agent_tools(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件智能体工具
|
||||
[{
|
||||
"plugin_id": "插件ID",
|
||||
"plugin_name": "插件名称",
|
||||
"tools": [ToolClass1, ToolClass2, ...]
|
||||
}]
|
||||
"""
|
||||
ret_tools = []
|
||||
# 创建字典快照避免并发修改
|
||||
running_plugins_snapshot = dict(self._running_plugins)
|
||||
for plugin_id, plugin in running_plugins_snapshot.items():
|
||||
if pid and pid != plugin_id:
|
||||
continue
|
||||
if hasattr(plugin, "get_agent_tools") and ObjectUtils.check_method(plugin.get_agent_tools):
|
||||
try:
|
||||
if not plugin.get_state():
|
||||
continue
|
||||
tools = plugin.get_agent_tools()
|
||||
if tools:
|
||||
ret_tools.append({
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tools": tools
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {plugin_id} 智能体工具出错:{str(e)}")
|
||||
return ret_tools
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_remote_entry(plugin_id: str, dist_path: str) -> str:
|
||||
"""
|
||||
@@ -1024,7 +1098,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version, force)
|
||||
with fresh(force):
|
||||
online_plugins = PluginHelper().get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
@@ -1165,6 +1240,7 @@ class PluginManager(metaclass=Singleton):
|
||||
async def async_get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
|
||||
"""
|
||||
异步获取所有在线插件信息
|
||||
:param force: 是否强制刷新(忽略缓存)
|
||||
"""
|
||||
if not settings.PLUGIN_MARKET:
|
||||
return []
|
||||
@@ -1230,7 +1306,8 @@ class PluginManager(metaclass=Singleton):
|
||||
# 已安装插件
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
|
||||
async with async_fresh(force):
|
||||
online_plugins = await PluginHelper().async_get_plugins(market, package_version)
|
||||
if online_plugins is None:
|
||||
logger.warning(
|
||||
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
|
||||
|
||||
@@ -252,19 +252,19 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
"""
|
||||
使用 API Token 进行身份认证
|
||||
:param token: API Token,从 URL 查询参数中获取
|
||||
:param token: API Token,从 URL 查询参数中获取 token=xxx
|
||||
:return: 返回校验通过的 API Token
|
||||
"""
|
||||
return __verify_key(token, settings.API_TOKEN, "API_TOKEN")
|
||||
return __verify_key(token, settings.API_TOKEN, "token")
|
||||
|
||||
|
||||
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
|
||||
"""
|
||||
使用 API Key 进行身份认证
|
||||
:param apikey: API Key,从 URL 查询参数或请求头中获取
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx,或请求头中获取 X-API-KEY=xxx
|
||||
:return: 返回校验通过的 API Key
|
||||
"""
|
||||
return __verify_key(apikey, settings.API_TOKEN, "API_KEY")
|
||||
return __verify_key(apikey, settings.API_TOKEN, "apikey")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
|
||||
@@ -454,7 +454,6 @@ class Base:
|
||||
|
||||
@db_update
|
||||
def update(self, db: Session, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
@@ -462,7 +461,6 @@ class Base:
|
||||
|
||||
@async_db_update
|
||||
async def async_update(self, db: AsyncSession, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .passkey import PassKey
|
||||
from .plugindata import PluginData
|
||||
from .site import Site
|
||||
from .siteicon import SiteIcon
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user