mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-08 21:02:44 +08:00
Compare commits
638 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75fca971d4 | ||
|
|
22f3244bf5 | ||
|
|
aafc4b3a39 | ||
|
|
18906e5ab2 | ||
|
|
9675d199f9 | ||
|
|
78e8faa203 | ||
|
|
d5ed9bc654 | ||
|
|
770065d9ed | ||
|
|
abc4154e2c | ||
|
|
fd6c9d5d34 | ||
|
|
dc428e7de0 | ||
|
|
0c51d79be7 | ||
|
|
1b489ba581 | ||
|
|
4d9f17b083 | ||
|
|
3c7cd2186f | ||
|
|
5acfd683b9 | ||
|
|
6b01901a4a | ||
|
|
1ca54afd6c | ||
|
|
9c75c2d22e | ||
|
|
79ec3ed2c3 | ||
|
|
7072d2cfe8 | ||
|
|
c0c08b0b84 | ||
|
|
01329195ee | ||
|
|
ad40b99313 | ||
|
|
1e338e48ab | ||
|
|
ac9c9598f4 | ||
|
|
02cb5dfc31 | ||
|
|
8109ffb445 | ||
|
|
0ecbcb89fa | ||
|
|
8f38c06424 | ||
|
|
902394f86e | ||
|
|
9fefd807f9 | ||
|
|
a8fb4a6d84 | ||
|
|
7806267e92 | ||
|
|
eb5e17a115 | ||
|
|
2ae98d628d | ||
|
|
8b9dc0e77f | ||
|
|
2f151cea64 | ||
|
|
b777e8cab1 | ||
|
|
663e37bd03 | ||
|
|
8960620883 | ||
|
|
5b892b3a63 | ||
|
|
974d5f2f49 | ||
|
|
f70881bb4f | ||
|
|
376c65335f | ||
|
|
d7a5c32b08 | ||
|
|
4cda182ccd | ||
|
|
60ac901c6c | ||
|
|
388afa8d3c | ||
|
|
ec0915e488 | ||
|
|
244112be5c | ||
|
|
1f526adbe7 | ||
|
|
c4cfd70f7c | ||
|
|
c9149d1761 | ||
|
|
c68450fc7f | ||
|
|
d9eb3295b0 | ||
|
|
5440dbae51 | ||
|
|
321bf94de8 | ||
|
|
84b938c0d2 | ||
|
|
fc47382938 | ||
|
|
2e034f7990 | ||
|
|
e61299f748 | ||
|
|
cbff2fed17 | ||
|
|
9c51f73a72 | ||
|
|
70109635c7 | ||
|
|
8999c3a855 | ||
|
|
7bd775130e | ||
|
|
4bba7dbe76 | ||
|
|
0cab21b83c | ||
|
|
ca9cbc1160 | ||
|
|
02439f55a9 | ||
|
|
2d358e376c | ||
|
|
b349aa2693 | ||
|
|
e3fee39043 | ||
|
|
a1a72df6c6 | ||
|
|
cdf40a7046 | ||
|
|
b9b19c9acc | ||
|
|
8c603baa43 | ||
|
|
a977948f2b | ||
|
|
f70eaf9363 | ||
|
|
bfea0174dd | ||
|
|
296d815e3e | ||
|
|
c3b7a50642 | ||
|
|
8e0a9f94f6 | ||
|
|
6806900436 | ||
|
|
a8ecdc8206 | ||
|
|
60e1e3c173 | ||
|
|
f859d99d91 | ||
|
|
31640b780c | ||
|
|
aaeb4d2634 | ||
|
|
75d4c0153c | ||
|
|
8d7ff2bd1d | ||
|
|
c3e96ae73f | ||
|
|
d8c86069f2 | ||
|
|
a25c709927 | ||
|
|
d7c62fb55a | ||
|
|
27cc559c86 | ||
|
|
e7d14691df | ||
|
|
20387a0085 | ||
|
|
740b0a1396 | ||
|
|
7d0c790185 | ||
|
|
a12147d0f5 | ||
|
|
213a298813 | ||
|
|
1acf78342c | ||
|
|
c85d3adb34 | ||
|
|
83bf59dd4d | ||
|
|
d5d6442e1d | ||
|
|
a1fa469026 | ||
|
|
4b4b808b76 | ||
|
|
a6f16dcf8f | ||
|
|
c822782910 | ||
|
|
e598d5edc4 | ||
|
|
d38b6dfc0a | ||
|
|
0a4091d93c | ||
|
|
0399ab73cf | ||
|
|
940cececf4 | ||
|
|
94c75eb1c7 | ||
|
|
de4dbf283b | ||
|
|
10807a6fb7 | ||
|
|
04b8475761 | ||
|
|
e6e50d7f0a | ||
|
|
94ed065344 | ||
|
|
d94b5962b4 | ||
|
|
dcca318733 | ||
|
|
4a789297fe | ||
|
|
1249929b6a | ||
|
|
864af45f85 | ||
|
|
bd68bcfd27 | ||
|
|
17373bc0fe | ||
|
|
4612d3cdde | ||
|
|
517300afe9 | ||
|
|
3c7fdfec3c | ||
|
|
cfc8d26558 | ||
|
|
1c16b8bfec | ||
|
|
aae50004b1 | ||
|
|
4fbd2a7612 | ||
|
|
cede1a1100 | ||
|
|
5d3511cbc2 | ||
|
|
a66e082a8c | ||
|
|
2406438d1b | ||
|
|
be42c78aca | ||
|
|
78b8b30351 | ||
|
|
80e35fa938 | ||
|
|
e82494c444 | ||
|
|
309b7b8a77 | ||
|
|
f2daa633b6 | ||
|
|
630d13ac52 | ||
|
|
40c79b249b | ||
|
|
6f4df912d8 | ||
|
|
5744228a9d | ||
|
|
8c46ece44a | ||
|
|
4cbf1a886e | ||
|
|
17519d5a96 | ||
|
|
faa046eea4 | ||
|
|
873e3832b6 | ||
|
|
d4a15d3b53 | ||
|
|
6ca6a94631 | ||
|
|
61fced0df3 | ||
|
|
b2f6ffddee | ||
|
|
c85805b15d | ||
|
|
a0838ed9cd | ||
|
|
63bbec5db4 | ||
|
|
4bc67dc816 | ||
|
|
9620a06552 | ||
|
|
9b00a5f3f1 | ||
|
|
faa77be843 | ||
|
|
28f158c479 | ||
|
|
90c3afcfa4 | ||
|
|
565e10b6a5 | ||
|
|
773ed5e6f7 | ||
|
|
8351312b2b | ||
|
|
41f53d39a0 | ||
|
|
4873ffda84 | ||
|
|
b79609bb8b | ||
|
|
bdcbb5cce6 | ||
|
|
d1503f9df3 | ||
|
|
210c3234d2 | ||
|
|
c13abfdd0d | ||
|
|
30b332ac7e | ||
|
|
7e9c489aeb | ||
|
|
5739ca7f97 | ||
|
|
e4451c7e6a | ||
|
|
5cded77387 | ||
|
|
ea4e0dd764 | ||
|
|
f105357f96 | ||
|
|
bc2302baeb | ||
|
|
afcdefbbf3 | ||
|
|
3ad8557065 | ||
|
|
e68d607c9b | ||
|
|
8e9cf67190 | ||
|
|
0cb6cd8761 | ||
|
|
17aa795b3e | ||
|
|
7d47096e6e | ||
|
|
48b59df11b | ||
|
|
a90a3b2445 | ||
|
|
d18b68d24a | ||
|
|
78c4ec8bfe | ||
|
|
b50a3b9aae | ||
|
|
4f3eaa12d5 | ||
|
|
cedb0f565c | ||
|
|
226432ec7f | ||
|
|
d93ab0143c | ||
|
|
3d32d66ab1 | ||
|
|
e814eed047 | ||
|
|
96395c1469 | ||
|
|
6065c29891 | ||
|
|
f38cb274e4 | ||
|
|
7bfee87cbf | ||
|
|
2ce2a3754c | ||
|
|
510476c214 | ||
|
|
6cd071c84b | ||
|
|
406e17b3fa | ||
|
|
dd184255ad | ||
|
|
77a0b38081 | ||
|
|
14c3d66ce6 | ||
|
|
858da38680 | ||
|
|
9f381b3c73 | ||
|
|
b8fc20b981 | ||
|
|
b89825525a | ||
|
|
e09cfc6704 | ||
|
|
0c9c303c60 | ||
|
|
3156b43739 | ||
|
|
591aa990a6 | ||
|
|
3be29f36a7 | ||
|
|
7638db4c3b | ||
|
|
0312a500a6 | ||
|
|
1a88b5355a | ||
|
|
3374773de5 | ||
|
|
872b5fe3da | ||
|
|
be15e9871c | ||
|
|
024a6a253b | ||
|
|
1af662df7b | ||
|
|
b4f64eb593 | ||
|
|
86aa86208c | ||
|
|
018e814615 | ||
|
|
e4d6e5cfc7 | ||
|
|
770cd77632 | ||
|
|
9f1692b33d | ||
|
|
6f63e0a5d7 | ||
|
|
6a90e2c796 | ||
|
|
23b90ff0f9 | ||
|
|
dc86af2fa4 | ||
|
|
425b822046 | ||
|
|
65c18b1d52 | ||
|
|
1bddf3daa7 | ||
|
|
600b6af876 | ||
|
|
4bdf16331d | ||
|
|
87cbda0528 | ||
|
|
9897941bf9 | ||
|
|
31938812d0 | ||
|
|
19d879d3f6 | ||
|
|
cc41036c63 | ||
|
|
a9f2b40529 | ||
|
|
86000ea19a | ||
|
|
0422c3b9e7 | ||
|
|
64c8bd5b5a | ||
|
|
a7eba2c5fc | ||
|
|
2b7753e43e | ||
|
|
47c1e5b5b8 | ||
|
|
14ee97def0 | ||
|
|
92e262f732 | ||
|
|
c46880b701 | ||
|
|
473e9b9300 | ||
|
|
28945ef153 | ||
|
|
b6b5d9f9c4 | ||
|
|
ba5de1ab31 | ||
|
|
002ebeaade | ||
|
|
894756000c | ||
|
|
cdb178c503 | ||
|
|
7c48cafc71 | ||
|
|
74d4592238 | ||
|
|
0044dd104e | ||
|
|
05041e2eae | ||
|
|
78908f216d | ||
|
|
efc68ae701 | ||
|
|
e9340a8b4b | ||
|
|
66e199d516 | ||
|
|
6151d8a787 | ||
|
|
296261da8a | ||
|
|
383371dd6f | ||
|
|
bb8c026bda | ||
|
|
344993dd6f | ||
|
|
ffb048c314 | ||
|
|
3eef9b8faa | ||
|
|
5704bb646b | ||
|
|
fbc684b3a7 | ||
|
|
6529b2a9c3 | ||
|
|
a1701e2edf | ||
|
|
eba6391de7 | ||
|
|
9f2c3c9688 | ||
|
|
57f5a19d0c | ||
|
|
c8d53c6964 | ||
|
|
643cda1abe | ||
|
|
03d118a73a | ||
|
|
51dd7f5c17 | ||
|
|
af7e1e7a3c | ||
|
|
ea5d855bc3 | ||
|
|
5f74367cd6 | ||
|
|
26e41e1c14 | ||
|
|
1bb2b50043 | ||
|
|
7bdb629f03 | ||
|
|
fd92f986da | ||
|
|
69a1207102 | ||
|
|
def652c768 | ||
|
|
c35faf5356 | ||
|
|
0615a33206 | ||
|
|
e77530bdc5 | ||
|
|
8c62df63cc | ||
|
|
bd36eade77 | ||
|
|
d2c023081a | ||
|
|
63d0850b38 | ||
|
|
c86659428f | ||
|
|
bf7cc6caf0 | ||
|
|
26b8be6041 | ||
|
|
f978f9196f | ||
|
|
75cb8d2a3c | ||
|
|
17a21ed707 | ||
|
|
f390647139 | ||
|
|
aacd91e196 | ||
|
|
258171c9c4 | ||
|
|
812c5873aa | ||
|
|
4c3d47f1f0 | ||
|
|
ba7b6ba869 | ||
|
|
d0471ae512 | ||
|
|
636c4be9fb | ||
|
|
6bec765a9d | ||
|
|
d61d16ccc4 | ||
|
|
f2a5715b24 | ||
|
|
c064c3781f | ||
|
|
bb4dffe2a4 | ||
|
|
37cf3eeef3 | ||
|
|
40395b2999 | ||
|
|
32afe6445f | ||
|
|
793a991913 | ||
|
|
d278224ff1 | ||
|
|
9b4d0ce6a8 | ||
|
|
a1829fe590 | ||
|
|
2b2b39365c | ||
|
|
1147930f3f | ||
|
|
636f338ed7 | ||
|
|
72365d00b4 | ||
|
|
19d8086732 | ||
|
|
30488418e5 | ||
|
|
2f0badd74a | ||
|
|
6045b0579b | ||
|
|
498f1fec74 | ||
|
|
f6a541f2b9 | ||
|
|
8ce78eabca | ||
|
|
2c34c5309f | ||
|
|
77e680168a | ||
|
|
8a7e59742f | ||
|
|
42bac14770 | ||
|
|
8323834483 | ||
|
|
1751caef62 | ||
|
|
d622d1474d | ||
|
|
f28be2e7de | ||
|
|
17773913ae | ||
|
|
d469c2d3f9 | ||
|
|
4e74d32882 | ||
|
|
7b8cd37a9b | ||
|
|
eda306d726 | ||
|
|
94f3b1fe84 | ||
|
|
c50e3ba293 | ||
|
|
eff7818912 | ||
|
|
270bcff8f3 | ||
|
|
e04963c2dc | ||
|
|
f369967c91 | ||
|
|
cd982c5526 | ||
|
|
16e03c9d37 | ||
|
|
d38b1f5364 | ||
|
|
f57ba4d05e | ||
|
|
172eeaafcf | ||
|
|
3115ed28b2 | ||
|
|
d8dc53805c | ||
|
|
7218d10e1b | ||
|
|
89bf85f501 | ||
|
|
8334a468d0 | ||
|
|
3da80ed077 | ||
|
|
2883ccbe87 | ||
|
|
5d3443fee4 | ||
|
|
27756a53db | ||
|
|
71cde6661d | ||
|
|
a857337b31 | ||
|
|
4ee21ffae4 | ||
|
|
d8399f7e85 | ||
|
|
574ac8d32f | ||
|
|
a2611bfa7d | ||
|
|
853badb76f | ||
|
|
5d69e1d2a5 | ||
|
|
6494f28bdb | ||
|
|
f55916bda2 | ||
|
|
04691ee197 | ||
|
|
2ac0e564e1 | ||
|
|
6072a29a20 | ||
|
|
8658942385 | ||
|
|
cc4859950c | ||
|
|
23b81ad6f1 | ||
|
|
e3b9dca5c0 | ||
|
|
a2359a1ad2 | ||
|
|
cb875b1b34 | ||
|
|
b92a85b4bc | ||
|
|
8c7dd6bab2 | ||
|
|
aad7df64d7 | ||
|
|
8474342007 | ||
|
|
61ccb4be65 | ||
|
|
1c6f69707c | ||
|
|
e08e8c482a | ||
|
|
548c1d2cab | ||
|
|
5a071bf3d1 | ||
|
|
1bffcbd947 | ||
|
|
274a36a83a | ||
|
|
ec40f36114 | ||
|
|
af19f274a7 | ||
|
|
2316004194 | ||
|
|
98762198ef | ||
|
|
1469de22a4 | ||
|
|
1e687f960a | ||
|
|
7f01b835fd | ||
|
|
e46b6c5c01 | ||
|
|
74226ad8df | ||
|
|
f8ae7be539 | ||
|
|
37b16e380d | ||
|
|
9ea3e9f652 | ||
|
|
54422b5181 | ||
|
|
712995dcf3 | ||
|
|
c2767b0fd6 | ||
|
|
179cc61f65 | ||
|
|
f3b910d55a | ||
|
|
f4157b52ea | ||
|
|
79710310ce | ||
|
|
3412498438 | ||
|
|
b896b07a08 | ||
|
|
379bff0622 | ||
|
|
474f47aa9f | ||
|
|
f1e26a4133 | ||
|
|
e37f881207 | ||
|
|
306c0b707b | ||
|
|
08c448ee30 | ||
|
|
1532014067 | ||
|
|
fa9f604af9 | ||
|
|
3b3d0d6539 | ||
|
|
9641d33040 | ||
|
|
eca339d107 | ||
|
|
ca18705d88 | ||
|
|
8f17b52466 | ||
|
|
8cf84e722b | ||
|
|
7c4d736b54 | ||
|
|
1b3ae6ab25 | ||
|
|
a4ad08136e | ||
|
|
df5e7997c5 | ||
|
|
b2cb3768c1 | ||
|
|
fa169c5cd3 | ||
|
|
bbb3975b67 | ||
|
|
4502a9c4fa | ||
|
|
86905a2670 | ||
|
|
b1e60a4867 | ||
|
|
1efe3324fb | ||
|
|
55c1e37d39 | ||
|
|
7fa700317c | ||
|
|
bbe831a57c | ||
|
|
90c86c056c | ||
|
|
36f22a28df | ||
|
|
ac03c51e2c | ||
|
|
bd9e92f705 | ||
|
|
281eff5eb2 | ||
|
|
abbd2253ad | ||
|
|
46466624ae | ||
|
|
0ba8d51b2a | ||
|
|
a1408ee18f | ||
|
|
58030bbcff | ||
|
|
e1b3e6ef01 | ||
|
|
298a6ba8ab | ||
|
|
e5bf47629f | ||
|
|
ea29ee9f66 | ||
|
|
868c2254de | ||
|
|
567522c87a | ||
|
|
25fd47f57b | ||
|
|
f89d6342d1 | ||
|
|
b02affdea3 | ||
|
|
6e5ade943b | ||
|
|
a6ed0c0d00 | ||
|
|
68402aadd7 | ||
|
|
85cacd447b | ||
|
|
11262b321a | ||
|
|
bf290f063d | ||
|
|
7ac0fbaf76 | ||
|
|
7489c76722 | ||
|
|
bcdf1b6efe | ||
|
|
8a9dbe212c | ||
|
|
16bd71a6cb | ||
|
|
71caad0655 | ||
|
|
2c62ffe34a | ||
|
|
3450a89880 | ||
|
|
a081a69bbe | ||
|
|
271d1d23d5 | ||
|
|
605aba1a3c | ||
|
|
be3c2b4c7c | ||
|
|
08eb32d7bd | ||
|
|
2b9cda15e4 | ||
|
|
f6055b290a | ||
|
|
ec665e05e4 | ||
|
|
2b6d7205ec | ||
|
|
41381a920c | ||
|
|
f1b3fc2254 | ||
|
|
a677ed307d | ||
|
|
0ab23ee972 | ||
|
|
43f56d39be | ||
|
|
a39caee5f5 | ||
|
|
2edfdf47c8 | ||
|
|
3819461db5 | ||
|
|
85654dd7dd | ||
|
|
619a70416b | ||
|
|
16d996fe70 | ||
|
|
1baeb6da19 | ||
|
|
1641d432dd | ||
|
|
1bf9862e47 | ||
|
|
602a394043 | ||
|
|
22a2415ca5 | ||
|
|
feb034352d | ||
|
|
a7c8942c78 | ||
|
|
95f2ac3811 | ||
|
|
91354295f2 | ||
|
|
c9c4ab5911 | ||
|
|
a26c5e40dd | ||
|
|
80f5c7bc44 | ||
|
|
4833b39c52 | ||
|
|
f478958943 | ||
|
|
0469ad46d6 | ||
|
|
5fe5deb9df | ||
|
|
ce83bc24bd | ||
|
|
dce729c8cb | ||
|
|
a9d17cd96f | ||
|
|
294bb3d4a1 | ||
|
|
b31b9261f2 | ||
|
|
2211f8d9e4 | ||
|
|
b9b7b00a7f | ||
|
|
843faf6103 | ||
|
|
4af5dad9a8 | ||
|
|
52437c9d18 | ||
|
|
c6cb4c8479 | ||
|
|
c3714ec251 | ||
|
|
dbe2f94af1 | ||
|
|
07fd5f8a9e | ||
|
|
9e64b4cd7f | ||
|
|
f08a7b9eb3 | ||
|
|
a6fa764e2a | ||
|
|
01676668f1 | ||
|
|
8e5e4f460d | ||
|
|
f907b8a84d | ||
|
|
a3a4285f90 | ||
|
|
0979163b79 | ||
|
|
248a25eaee | ||
|
|
f95b1fa68a | ||
|
|
d2b5d69051 | ||
|
|
3ca419b735 | ||
|
|
50e275a2f9 | ||
|
|
aeccf78957 | ||
|
|
cb3cef70e5 | ||
|
|
b9bd303bf8 | ||
|
|
57d4786a7f | ||
|
|
df031455b2 | ||
|
|
30059eff4f | ||
|
|
bc289b48c8 | ||
|
|
067d8b99b8 | ||
|
|
00a6a9c42d | ||
|
|
070425d446 | ||
|
|
7405883444 | ||
|
|
66959937ed | ||
|
|
e431efbcba | ||
|
|
ba00baa5a0 | ||
|
|
0fb5d4a164 | ||
|
|
1ac717b67f | ||
|
|
273cbd447e | ||
|
|
cee41567a2 | ||
|
|
1aae5eb1a6 | ||
|
|
28a4c81aff | ||
|
|
5e077cd64d | ||
|
|
e3f957a59b | ||
|
|
55c62a3ab5 | ||
|
|
22e7eef1bd | ||
|
|
d6524907f3 | ||
|
|
357db334cd | ||
|
|
f8bed3909b | ||
|
|
182bbdde91 | ||
|
|
2c70f990c2 | ||
|
|
0b01a6aa91 | ||
|
|
e557dffbc6 | ||
|
|
7f33b0b1b8 | ||
|
|
41ddf77a5b | ||
|
|
8c657ce41d | ||
|
|
3ff3b9ed4a | ||
|
|
ef43419ecd | ||
|
|
2ca375c214 | ||
|
|
cbd45c1d0f | ||
|
|
2592ea3464 | ||
|
|
73ac97cd96 | ||
|
|
e014663e97 | ||
|
|
58592e961f | ||
|
|
9a99b9ce82 | ||
|
|
8c6dca1751 | ||
|
|
cf488d5f5f | ||
|
|
515584d34c | ||
|
|
fb2becc7f2 | ||
|
|
0f8ceb0fac | ||
|
|
a70bf18770 | ||
|
|
2de83c44ab | ||
|
|
7b99f09810 | ||
|
|
6b4ba8bfad | ||
|
|
0c6cfc5020 | ||
|
|
abd9733e7f | ||
|
|
98c3ae5e76 | ||
|
|
bb5a657469 | ||
|
|
7797532350 | ||
|
|
c3a5106adc | ||
|
|
c5fd935dd0 | ||
|
|
ec375a19ae | ||
|
|
51e940617c | ||
|
|
58ec8bd437 | ||
|
|
a096395086 | ||
|
|
4bd08bd915 | ||
|
|
2c849cfa7a | ||
|
|
501d530d1d | ||
|
|
91fc4327f4 | ||
|
|
8d56c67079 | ||
|
|
e52d43458e | ||
|
|
9b125bf9b0 | ||
|
|
0716c65269 | ||
|
|
ba3ce4f1b5 | ||
|
|
07f72b0cdc | ||
|
|
bda19df87f | ||
|
|
5d82fae2b0 | ||
|
|
0813b87221 | ||
|
|
961ecfc720 | ||
|
|
81f30ef25a | ||
|
|
140b0d3df2 | ||
|
|
b3d69d7de4 | ||
|
|
8e65564fb8 | ||
|
|
06ce9bd4de |
109
.github/workflows/build.yml
vendored
109
.github/workflows/build.yml
vendored
@@ -14,6 +14,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Release version
|
||||
id: release_version
|
||||
@@ -66,6 +69,98 @@ jobs:
|
||||
cache-from: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-to: type=gha, scope=${{ github.workflow }}-docker
|
||||
|
||||
- name: Generate Changelog
|
||||
id: changelog
|
||||
run: |
|
||||
# 获取上一个 tag(排除当前版本的 tag)
|
||||
PREVIOUS_TAG=$(git tag -l 'v*' --sort=-v:refname | grep -v "^v${{ env.app_version }}$" | head -n 1)
|
||||
echo "Previous tag: $PREVIOUS_TAG"
|
||||
|
||||
# 使用 || 作为分隔符,同时获取 commit 消息和作者 GitHub 用户名
|
||||
if [ -z "$PREVIOUS_TAG" ]; then
|
||||
COMMITS=$(git log --pretty=format:"%s||%an" HEAD)
|
||||
else
|
||||
COMMITS=$(git log --pretty=format:"%s||%an" ${PREVIOUS_TAG}..HEAD)
|
||||
fi
|
||||
|
||||
# 分类收集 commit 消息(使用关联数组去重)
|
||||
declare -A SEEN
|
||||
FEATURES=""
|
||||
FIXES=""
|
||||
OTHERS=""
|
||||
|
||||
while IFS= read -r line; do
|
||||
# 跳过空行
|
||||
if [ -z "$line" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# 分离 commit 消息和作者
|
||||
msg=$(echo "$line" | sed 's/||[^|]*$//')
|
||||
author=$(echo "$line" | sed 's/.*||//')
|
||||
|
||||
# 跳过 Merge commit 和版本更新 commit
|
||||
if echo "$msg" | grep -qE "^Merge pull request|^Merge branch|^更新 version"; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# 按 Conventional Commits 前缀分类
|
||||
if echo "$msg" | grep -qiE "^feat(\(.+\))?:"; then
|
||||
desc=$(echo "$msg" | sed -E 's/^feat(\([^)]*\))?:\s*//')
|
||||
category="FEATURES"
|
||||
elif echo "$msg" | grep -qiE "^fix(\(.+\))?:"; then
|
||||
desc=$(echo "$msg" | sed -E 's/^fix(\([^)]*\))?:\s*//')
|
||||
category="FIXES"
|
||||
elif echo "$msg" | grep -qiE "^(docs|style|refactor|perf|test|build|ci|chore|revert)(\(.+\))?:"; then
|
||||
desc=$(echo "$msg" | sed -E 's/^(docs|style|refactor|perf|test|build|ci|chore|revert)(\([^)]*\))?:\s*//')
|
||||
category="OTHERS"
|
||||
else
|
||||
desc="$msg"
|
||||
category="OTHERS"
|
||||
fi
|
||||
|
||||
# 使用 "分类+描述" 作为去重的 key,跳过重复内容
|
||||
dedup_key="${category}::${desc}"
|
||||
if [ -n "${SEEN[$dedup_key]+x}" ]; then
|
||||
continue
|
||||
fi
|
||||
SEEN[$dedup_key]=1
|
||||
|
||||
# 添加 by @author 引用
|
||||
entry="- ${desc} by @${author}"
|
||||
|
||||
case "$category" in
|
||||
FEATURES) FEATURES="${FEATURES}${entry}\n" ;;
|
||||
FIXES) FIXES="${FIXES}${entry}\n" ;;
|
||||
OTHERS) OTHERS="${OTHERS}${entry}\n" ;;
|
||||
esac
|
||||
done <<< "$COMMITS"
|
||||
|
||||
# 组装 changelog
|
||||
CHANGELOG=""
|
||||
|
||||
if [ -n "$FEATURES" ]; then
|
||||
CHANGELOG="${CHANGELOG}### ✨ 新功能\n\n${FEATURES}\n"
|
||||
fi
|
||||
|
||||
if [ -n "$FIXES" ]; then
|
||||
CHANGELOG="${CHANGELOG}### 🐛 修复\n\n${FIXES}\n"
|
||||
fi
|
||||
|
||||
if [ -n "$OTHERS" ]; then
|
||||
CHANGELOG="${CHANGELOG}### 🔧 其他\n\n${OTHERS}\n"
|
||||
fi
|
||||
|
||||
# 添加版本对比链接
|
||||
if [ -n "$PREVIOUS_TAG" ]; then
|
||||
CHANGELOG="${CHANGELOG}**完整更新记录**: https://github.com/${{ github.repository }}/compare/${PREVIOUS_TAG}...v${{ env.app_version }}"
|
||||
fi
|
||||
|
||||
# 写入环境变量
|
||||
echo "CHANGELOG<<EOF" >> $GITHUB_ENV
|
||||
echo -e "$CHANGELOG" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
- name: Get existing release body
|
||||
id: get_release_body
|
||||
continue-on-error: true
|
||||
@@ -73,9 +168,17 @@ jobs:
|
||||
release_body=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
|
||||
"https://api.github.com/repos/${{ github.repository }}/releases/tags/v${{ env.app_version }}" | \
|
||||
jq -r '.body // ""')
|
||||
echo "RELEASE_BODY<<EOF" >> $GITHUB_ENV
|
||||
echo "$release_body" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
# 如果已有手动编写的 release body,则保留;否则使用自动生成的 changelog
|
||||
if [ -n "$release_body" ] && [ "$release_body" != "null" ] && [ "$release_body" != "" ]; then
|
||||
echo "RELEASE_BODY<<EOF" >> $GITHUB_ENV
|
||||
echo "$release_body" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
else
|
||||
echo "RELEASE_BODY<<EOF" >> $GITHUB_ENV
|
||||
echo "${{ env.CHANGELOG }}" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Delete Release
|
||||
uses: dev-drprasad/delete-tag-and-release@v1.1
|
||||
|
||||
1
.github/workflows/issues.yml
vendored
1
.github/workflows/issues.yml
vendored
@@ -29,4 +29,5 @@ jobs:
|
||||
days-before-pr-close: -1
|
||||
# 排除带有RFC标签的issue
|
||||
exempt-issue-labels: "RFC"
|
||||
operations-per-run: 500
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -27,4 +27,7 @@ venv
|
||||
|
||||
# Pylint
|
||||
pylint-report.json
|
||||
.pylint.d/
|
||||
.pylint.d/
|
||||
|
||||
# AI
|
||||
.claude/
|
||||
|
||||
@@ -26,6 +26,11 @@
|
||||
|
||||
官方Wiki:https://wiki.movie-pilot.org
|
||||
|
||||
### 为 AI Agent 添加 Skills
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
|
||||
## 参与开发
|
||||
|
||||
API文档:https://api.movie-pilot.org
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,33 +1,362 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.message import (
|
||||
MessageResponse,
|
||||
ChannelCapabilityManager,
|
||||
ChannelCapability,
|
||||
)
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
class _StreamChain(ChainBase):
|
||||
pass
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
|
||||
class StreamingHandler:
|
||||
"""
|
||||
流式Token缓冲管理器
|
||||
|
||||
负责从 LLM 流式 token 中积累文本,并在支持消息编辑的渠道上实时推送给用户。
|
||||
|
||||
工作流程:
|
||||
1. Agent开始处理时调用 start_streaming(),检查渠道能力并启动定时刷新
|
||||
2. LLM 产生 token 时调用 emit() 积累到缓冲区
|
||||
3. 定时器周期性调用 _flush():
|
||||
- 第一次有内容时发送新消息(通过 send_direct_message 获取 message_id)
|
||||
- 后续有新内容时编辑同一条消息(通过 edit_message)
|
||||
- 当消息长度接近渠道限制时,冻结当前消息并发送新消息继续输出
|
||||
4. 工具调用时:
|
||||
- 流式渠道:工具消息直接 emit() 追加到 buffer,与 Agent 文字合并为同一条流式消息
|
||||
- 非流式渠道:调用 take() 取出已积累的文字,与工具消息合并独立发送
|
||||
5. Agent最终完成时调用 stop_streaming():执行最后一次刷新,
|
||||
返回是否已通过流式发送完所有内容(调用方据此决定是否还需额外发送)
|
||||
"""
|
||||
|
||||
# 流式输出的刷新间隔(秒)
|
||||
FLUSH_INTERVAL = 0.3
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
self._buffer = ""
|
||||
# 流式输出相关状态
|
||||
self._streaming_enabled = False
|
||||
self._flush_task: Optional[asyncio.Task] = None
|
||||
# 当前消息的发送信息(用于编辑消息)
|
||||
self._message_response: Optional[MessageResponse] = None
|
||||
# 已发送给用户的文本(用于追踪增量)
|
||||
self._sent_text = ""
|
||||
# 当前消息的起始偏移量(buffer 中属于当前消息的起始位置)
|
||||
self._msg_start_offset = 0
|
||||
# 当前渠道的单条消息最大长度(0 表示不限制)
|
||||
self._max_message_length = 0
|
||||
# 消息发送所需的上下文信息
|
||||
self._channel: Optional[str] = None
|
||||
self._source: Optional[str] = None
|
||||
self._user_id: Optional[str] = None
|
||||
self._username: Optional[str] = None
|
||||
self._title: str = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
def emit(self, token: str):
|
||||
"""
|
||||
接收 LLM 流式 token,积累到缓冲区。
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
# 如果存量消息结束是两个换行,则去掉新消息前面的换行,避免过多空行
|
||||
if self._buffer.endswith("\n\n") and token.startswith("\n"):
|
||||
token = token.lstrip("\n")
|
||||
self._buffer += token
|
||||
|
||||
async def take(self) -> str:
|
||||
"""
|
||||
获取当前已积累的消息内容,获取后清空缓冲区。
|
||||
|
||||
用于非流式渠道:工具调用前取出 Agent 已产出的文字,
|
||||
与工具提示合并后独立发送。
|
||||
|
||||
注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。
|
||||
"""
|
||||
with self._lock:
|
||||
if not self._buffer:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
message = self._buffer
|
||||
logger.info(f"Agent消息: {message}")
|
||||
self._buffer = ""
|
||||
return message
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
def clear(self):
|
||||
"""
|
||||
清空缓冲区(不返回内容)
|
||||
"""
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
重置缓冲区,清空已发送的文本从头更新,但保持消息编辑能力。
|
||||
|
||||
与 clear 的区别:
|
||||
- clear:完全重置所有状态,后续会开新消息
|
||||
- reset:只清空buffer,保留消息编辑状态,后续继续编辑同一条消息
|
||||
"""
|
||||
with self._lock:
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._msg_start_offset = 0
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
channel: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
title: str = "",
|
||||
):
|
||||
"""
|
||||
启动流式输出。
|
||||
始终标记为流式状态(用于 buffer 收集 token),
|
||||
但只有渠道支持消息编辑时才启动定时刷新任务(实时推送给用户)。
|
||||
:param channel: 消息渠道
|
||||
:param source: 消息来源
|
||||
:param user_id: 用户ID
|
||||
:param username: 用户名
|
||||
:param title: 消息标题
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._user_id = user_id
|
||||
self._username = username
|
||||
self._title = title
|
||||
|
||||
self._streaming_enabled = True
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
|
||||
# 检查渠道是否支持消息编辑,不支持则仅收集 token 到 buffer,不实时推送
|
||||
if not self._can_stream():
|
||||
logger.debug(f"渠道 {channel} 不支持消息编辑,仅启用 buffer 收集模式")
|
||||
return
|
||||
|
||||
# 从渠道能力中获取单条消息最大长度
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
self._max_message_length = ChannelCapabilityManager.get_max_message_length(
|
||||
channel_enum
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
self._max_message_length = 0
|
||||
|
||||
# 启动异步定时刷新任务
|
||||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||||
logger.debug("流式输出已启动")
|
||||
|
||||
async def stop_streaming(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
停止流式输出。执行最后一次刷新确保所有内容都已发送。
|
||||
:return: (all_sent, final_text)
|
||||
all_sent: 是否已经通过流式编辑将最终完整内容发送给了用户
|
||||
(True 表示调用方无需再额外发送消息)
|
||||
final_text: 流式发送的完整文本内容(用于调用方保存消息记录)
|
||||
"""
|
||||
if not self._streaming_enabled:
|
||||
return False, ""
|
||||
|
||||
self._streaming_enabled = False
|
||||
|
||||
# 取消定时任务
|
||||
await self._cancel_flush_task()
|
||||
|
||||
# 执行最后一次刷新
|
||||
await self._flush()
|
||||
|
||||
# 检查是否所有缓冲内容都已发送
|
||||
with self._lock:
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
current_msg_text = self._buffer[self._msg_start_offset :]
|
||||
all_sent = (
|
||||
self._message_response is not None
|
||||
and self._sent_text
|
||||
and current_msg_text == self._sent_text
|
||||
)
|
||||
# 保留最终文本用于返回(返回完整 buffer 内容,包含所有分段消息)
|
||||
final_text = self._buffer if all_sent else ""
|
||||
# 重置状态
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
if all_sent:
|
||||
# 所有内容已通过流式发送,清空缓冲区
|
||||
self._buffer = ""
|
||||
return all_sent, final_text
|
||||
|
||||
def _can_stream(self) -> bool:
|
||||
"""
|
||||
检查当前渠道是否支持流式输出(消息编辑)
|
||||
"""
|
||||
if not self._channel:
|
||||
return False
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
return ChannelCapabilityManager.supports_capability(
|
||||
channel_enum, ChannelCapability.MESSAGE_EDITING
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
return False
|
||||
|
||||
async def _flush_loop(self):
|
||||
"""
|
||||
定时刷新循环,定期将缓冲区内容发送/编辑到用户
|
||||
"""
|
||||
try:
|
||||
while self._streaming_enabled:
|
||||
await asyncio.sleep(self.FLUSH_INTERVAL)
|
||||
if self._streaming_enabled:
|
||||
await self._flush()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"流式刷新异常: {e}")
|
||||
|
||||
async def _cancel_flush_task(self):
|
||||
"""
|
||||
取消当前的定时刷新任务
|
||||
"""
|
||||
if self._flush_task and not self._flush_task.done():
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._flush_task = None
|
||||
|
||||
async def _flush(self):
|
||||
"""
|
||||
将当前缓冲区内容刷新到用户消息
|
||||
- 如果还没有发送过消息,先发送一条新消息并记录message_id
|
||||
- 如果已经发送过消息,编辑该消息为最新的完整内容
|
||||
- 如果当前消息内容超过长度限制,冻结当前消息并发送新消息继续输出
|
||||
"""
|
||||
with self._lock:
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
current_text = self._buffer[self._msg_start_offset :]
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
|
||||
chain = _StreamChain()
|
||||
|
||||
try:
|
||||
if self._message_response is None:
|
||||
# 第一次发送:发送新消息并获取 message_id
|
||||
response = chain.send_direct_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
logger.debug(
|
||||
f"流式输出初始消息已发送: message_id={response.message_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"流式输出初始消息发送失败或未返回message_id,降级为非流式输出"
|
||||
)
|
||||
self._streaming_enabled = False
|
||||
else:
|
||||
# 检查当前消息内容是否超过长度限制
|
||||
if (
|
||||
self._max_message_length
|
||||
and len(current_text) > self._max_message_length
|
||||
):
|
||||
# 消息过长,冻结当前消息(保持最后一次成功编辑的内容)
|
||||
# 将 offset 移动到已发送文本之后,开启新消息
|
||||
logger.debug(
|
||||
f"流式消息长度 {len(current_text)} 超过限制 {self._max_message_length},启用新消息"
|
||||
)
|
||||
with self._lock:
|
||||
self._msg_start_offset += len(self._sent_text)
|
||||
current_text = self._buffer[self._msg_start_offset :]
|
||||
self._message_response = None
|
||||
self._sent_text = ""
|
||||
|
||||
# 如果偏移后还有新内容,立即发送为新消息
|
||||
if current_text:
|
||||
response = chain.send_direct_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
logger.debug(
|
||||
f"流式输出新消息已发送: message_id={response.message_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug("流式输出新消息发送失败,降级为非流式输出")
|
||||
self._streaming_enabled = False
|
||||
else:
|
||||
# 后续更新:编辑已有消息
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
success = chain.edit_message(
|
||||
channel=channel_enum,
|
||||
source=self._message_response.source,
|
||||
message_id=self._message_response.message_id,
|
||||
chat_id=self._message_response.chat_id,
|
||||
text=current_text,
|
||||
title=self._title,
|
||||
)
|
||||
if success:
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
else:
|
||||
logger.debug("流式输出消息编辑失败")
|
||||
except Exception as e:
|
||||
logger.error(f"流式输出刷新失败: {e}")
|
||||
|
||||
@property
|
||||
def is_streaming(self) -> bool:
|
||||
"""
|
||||
是否正在流式输出
|
||||
"""
|
||||
return self._streaming_enabled
|
||||
|
||||
@property
|
||||
def is_auto_flushing(self) -> bool:
|
||||
"""
|
||||
是否正在定时刷新(渠道支持消息编辑时自动推送 buffer 内容)
|
||||
"""
|
||||
return self._flush_task is not None
|
||||
|
||||
@property
|
||||
def has_sent_message(self) -> bool:
|
||||
"""
|
||||
是否已经通过流式输出发送过消息(当前轮次)
|
||||
"""
|
||||
return self._message_response is not None
|
||||
|
||||
@@ -1,39 +1,45 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
class MemoryManager:
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
# 内存缓存清理任务
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
def initialize(self):
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
self.cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_memories()
|
||||
)
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
"""
|
||||
关闭记忆管理器
|
||||
"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
@@ -41,214 +47,77 @@ class ConversationMemoryManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
@staticmethod
|
||||
def _get_memory_key(session_id: str, user_id: str):
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
计算内存Key
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
def get_memory(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
def get_agent_messages(
|
||||
self, session_id: str, user_id: str
|
||||
) -> List[BaseMessage]:
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
return memory.messages
|
||||
|
||||
return messages
|
||||
def save_agent_messages(
|
||||
self, session_id: str, user_id: str, messages: List[BaseMessage]
|
||||
):
|
||||
"""
|
||||
保存Agent消息(仅内存缓存)
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
"""
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
memory.messages = messages
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
# 更新内存缓存
|
||||
self.save_memory(memory)
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
def save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到内存缓存
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
def clear_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
"""
|
||||
清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
@@ -263,7 +132,9 @@ class ConversationMemoryManager:
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
if (
|
||||
current_time - memory.updated_at
|
||||
).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
@@ -278,3 +149,6 @@ class ConversationMemoryManager:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
|
||||
memory_manager = MemoryManager()
|
||||
|
||||
0
app/agent/middleware/__init__.py
Normal file
0
app/agent/middleware/__init__.py
Normal file
406
app/agent/middleware/activity_log.py
Normal file
406
app/agent/middleware/activity_log.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
活动日志中间件 - 自动记录 Agent 每次交互的操作摘要。
|
||||
|
||||
按日期存储在 CONFIG_PATH/agent/activity/YYYY-MM-DD.md 中,
|
||||
每次 Agent 执行完毕后自动调用 LLM 对本轮对话生成简洁的活动摘要,
|
||||
并在每次 Agent 启动时加载近几天的活动日志注入系统提示词。
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any, NotRequired, TypedDict
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# 活动日志保留天数
|
||||
DEFAULT_RETENTION_DAYS = 7
|
||||
|
||||
# 注入系统提示词时加载的天数
|
||||
PROMPT_LOAD_DAYS = 3
|
||||
|
||||
# 每日日志文件最大大小 (256KB)
|
||||
MAX_LOG_FILE_SIZE = 256 * 1024
|
||||
|
||||
# 提取本轮对话上下文的最大字符数(避免过长的对话消耗太多 token)
|
||||
MAX_CONTEXT_FOR_SUMMARY = 4000
|
||||
|
||||
# LLM 总结的提示词
|
||||
SUMMARY_PROMPT = """请根据以下 AI 助手与用户的对话记录,生成一条简洁的活动摘要(中文,一句话,不超过80字)。
|
||||
摘要应包含:用户的需求是什么、助手做了什么、结果如何。
|
||||
只输出摘要内容,不要加任何前缀、标点序号或解释。
|
||||
|
||||
对话记录:
|
||||
{conversation}"""
|
||||
|
||||
|
||||
class ActivityLogState(AgentState):
|
||||
"""ActivityLogMiddleware 的状态模型。"""
|
||||
|
||||
activity_log_contents: NotRequired[Annotated[dict[str, str], PrivateStateAttr]]
|
||||
"""将日期字符串映射到日志内容的字典。标记为私有,不包含在最终代理状态中。"""
|
||||
|
||||
|
||||
class ActivityLogStateUpdate(TypedDict):
|
||||
"""ActivityLogMiddleware 的状态更新。"""
|
||||
|
||||
activity_log_contents: dict[str, str]
|
||||
|
||||
|
||||
def _extract_last_round(messages: list) -> list | None:
|
||||
"""从完整消息列表中提取最后一轮交互。
|
||||
|
||||
从最后一条 HumanMessage 到消息末尾即为本轮交互。
|
||||
|
||||
参数:
|
||||
messages: Agent 执行后的完整消息列表。
|
||||
|
||||
返回:
|
||||
本轮交互的消息子列表,如果无有效交互则返回 None。
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 找到最后一条用户消息的索引
|
||||
last_human_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if isinstance(messages[i], HumanMessage) and messages[i].content:
|
||||
last_human_idx = i
|
||||
break
|
||||
|
||||
if last_human_idx is None:
|
||||
return None
|
||||
|
||||
round_messages = messages[last_human_idx:]
|
||||
|
||||
# 检查是否为系统心跳消息
|
||||
user_msg = round_messages[0]
|
||||
user_content = (
|
||||
user_msg.content if isinstance(user_msg.content, str) else str(user_msg.content)
|
||||
)
|
||||
if user_content.strip().startswith("[System Heartbeat]"):
|
||||
return None
|
||||
|
||||
return round_messages
|
||||
|
||||
|
||||
def _format_conversation_for_summary(round_messages: list) -> str:
|
||||
"""将本轮对话消息格式化为文本,供 LLM 总结。
|
||||
|
||||
参数:
|
||||
round_messages: 本轮交互的消息列表。
|
||||
|
||||
返回:
|
||||
格式化后的对话文本。
|
||||
"""
|
||||
lines = []
|
||||
total_len = 0
|
||||
|
||||
for msg in round_messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
line = f"用户: {content}"
|
||||
elif isinstance(msg, AIMessage):
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
tool_names = [
|
||||
tc["name"]
|
||||
for tc in msg.tool_calls
|
||||
if isinstance(tc, dict) and "name" in tc
|
||||
]
|
||||
line = f"助手调用工具: {', '.join(tool_names)}"
|
||||
elif msg.content:
|
||||
content = (
|
||||
msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
)
|
||||
line = f"助手: {content}"
|
||||
else:
|
||||
continue
|
||||
elif isinstance(msg, ToolMessage):
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
# 工具返回可能很长,截断
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
line = f"工具返回: {content}"
|
||||
else:
|
||||
continue
|
||||
|
||||
# 控制总长度
|
||||
if total_len + len(line) > MAX_CONTEXT_FOR_SUMMARY:
|
||||
lines.append("...(后续对话省略)")
|
||||
break
|
||||
lines.append(line)
|
||||
total_len += len(line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def _summarize_with_llm(conversation_text: str) -> str | None:
|
||||
"""调用 LLM 对对话文本生成活动摘要。
|
||||
|
||||
参数:
|
||||
conversation_text: 格式化后的对话文本。
|
||||
|
||||
返回:
|
||||
LLM 生成的摘要字符串,失败时返回 None。
|
||||
"""
|
||||
try:
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
llm = LLMHelper.get_llm(streaming=False)
|
||||
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
||||
response = await llm.ainvoke(prompt)
|
||||
summary = response.content.strip()
|
||||
# 清理模型可能输出的前缀(如 "摘要:" "总结:")
|
||||
summary = re.sub(r"^(摘要|总结|活动记录)[::]\s*", "", summary)
|
||||
return summary if summary else None
|
||||
except Exception as e:
|
||||
logger.debug("LLM summarization failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
ACTIVITY_LOG_SYSTEM_PROMPT = """<activity_log>
|
||||
{activity_log}
|
||||
</activity_log>
|
||||
|
||||
<activity_log_guidelines>
|
||||
The above <activity_log> contains a record of your recent interactions with the user, automatically maintained by the system.
|
||||
|
||||
**How to use this information:**
|
||||
- Reference past activities when relevant to provide continuity (e.g., "之前帮你订阅了《XXX》,现在有更新了")
|
||||
- Use activity history to understand ongoing tasks and user patterns
|
||||
- When the user asks "你之前帮我做了什么" or similar questions, refer to this log
|
||||
- Activity logs are automatically recorded after each interaction - you do NOT need to manually update them
|
||||
|
||||
**What is automatically logged:**
|
||||
- Each user interaction: what was asked, which tools were used, and the outcome
|
||||
- Timestamps for all activities
|
||||
- The log is organized by date for easy reference
|
||||
|
||||
**Important:**
|
||||
- Activity logs are READ-ONLY from your perspective - the system manages them automatically
|
||||
- Do not attempt to edit or write to activity log files
|
||||
- For long-term preferences and knowledge, continue to use MEMORY.md
|
||||
- Activity logs are retained for {retention_days} days and then automatically cleaned up
|
||||
</activity_log_guidelines>
|
||||
"""
|
||||
|
||||
|
||||
class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, ResponseT]): # noqa
|
||||
"""自动记录和加载 Agent 活动日志的中间件。
|
||||
|
||||
- abefore_agent: 加载近几天的活动日志
|
||||
- awrap_model_call: 将活动日志注入系统提示词
|
||||
- aafter_agent: 从本次对话中提取摘要并追加到当日日志文件
|
||||
|
||||
参数:
|
||||
activity_dir: 活动日志存储目录路径。
|
||||
retention_days: 日志保留天数(默认 7 天)。
|
||||
prompt_load_days: 注入系统提示词时加载的天数(默认 3 天)。
|
||||
"""
|
||||
|
||||
state_schema = ActivityLogState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
activity_dir: str,
|
||||
retention_days: int = DEFAULT_RETENTION_DAYS,
|
||||
prompt_load_days: int = PROMPT_LOAD_DAYS,
|
||||
) -> None:
|
||||
self.activity_dir = activity_dir
|
||||
self.retention_days = retention_days
|
||||
self.prompt_load_days = prompt_load_days
|
||||
|
||||
def _get_log_path(self, date_str: str) -> AsyncPath:
|
||||
"""获取指定日期的日志文件路径。"""
|
||||
return AsyncPath(self.activity_dir) / f"{date_str}.md"
|
||||
|
||||
def _format_activity_log(self, contents: dict[str, str]) -> str:
|
||||
"""格式化活动日志用于系统提示词注入。"""
|
||||
if not contents:
|
||||
return ACTIVITY_LOG_SYSTEM_PROMPT.format(
|
||||
activity_log="(暂无活动记录)",
|
||||
retention_days=self.retention_days,
|
||||
)
|
||||
|
||||
# 按日期排序(最近的在前)
|
||||
sorted_dates = sorted(contents.keys(), reverse=True)
|
||||
sections = []
|
||||
for date_str in sorted_dates:
|
||||
content = contents[date_str].strip()
|
||||
if content:
|
||||
sections.append(f"### {date_str}\n{content}")
|
||||
|
||||
if not sections:
|
||||
return ACTIVITY_LOG_SYSTEM_PROMPT.format(
|
||||
activity_log="(暂无活动记录)",
|
||||
retention_days=self.retention_days,
|
||||
)
|
||||
|
||||
log_body = "\n\n".join(sections)
|
||||
return ACTIVITY_LOG_SYSTEM_PROMPT.format(
|
||||
activity_log=log_body,
|
||||
retention_days=self.retention_days,
|
||||
)
|
||||
|
||||
async def _load_recent_logs(self) -> dict[str, str]:
|
||||
"""加载近几天的活动日志。"""
|
||||
contents: dict[str, str] = {}
|
||||
today = datetime.now().date()
|
||||
|
||||
for i in range(self.prompt_load_days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
log_path = self._get_log_path(date_str)
|
||||
|
||||
if await log_path.exists():
|
||||
try:
|
||||
content = await log_path.read_text(encoding="utf-8")
|
||||
contents[date_str] = content
|
||||
logger.debug("Loaded activity log for %s", date_str)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load activity log %s: %s", date_str, e)
|
||||
|
||||
return contents
|
||||
|
||||
async def _append_activity(self, summary: str) -> None:
|
||||
"""将一条活动记录追加到当日日志文件。"""
|
||||
today_str = datetime.now().strftime("%Y-%m-%d")
|
||||
now_str = datetime.now().strftime("%H:%M")
|
||||
log_path = self._get_log_path(today_str)
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = AsyncPath(self.activity_dir)
|
||||
if not await dir_path.exists():
|
||||
await dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 检查文件大小
|
||||
if await log_path.exists():
|
||||
stat = await log_path.stat()
|
||||
if stat.st_size >= MAX_LOG_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Activity log %s exceeds size limit (%d bytes), skipping append",
|
||||
today_str,
|
||||
stat.st_size,
|
||||
)
|
||||
return
|
||||
|
||||
# 追加记录
|
||||
entry = f"- **{now_str}** {summary}\n"
|
||||
try:
|
||||
if await log_path.exists():
|
||||
existing = await log_path.read_text(encoding="utf-8")
|
||||
await log_path.write_text(existing + entry, encoding="utf-8")
|
||||
else:
|
||||
header = f"# {today_str} 活动日志\n\n"
|
||||
await log_path.write_text(header + entry, encoding="utf-8")
|
||||
logger.debug("Activity logged: %s", summary[:80])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to append activity log: %s", e)
|
||||
|
||||
async def _cleanup_old_logs(self) -> None:
|
||||
"""清理超过保留天数的旧日志文件。"""
|
||||
dir_path = AsyncPath(self.activity_dir)
|
||||
if not await dir_path.exists():
|
||||
return
|
||||
|
||||
cutoff_date = datetime.now().date() - timedelta(days=self.retention_days)
|
||||
date_pattern = re.compile(r"^(\d{4}-\d{2}-\d{2})\.md$")
|
||||
|
||||
try:
|
||||
async for path in dir_path.iterdir():
|
||||
if not await path.is_file():
|
||||
continue
|
||||
match = date_pattern.match(path.name)
|
||||
if not match:
|
||||
continue
|
||||
try:
|
||||
file_date = datetime.strptime(match.group(1), "%Y-%m-%d").date()
|
||||
if file_date < cutoff_date:
|
||||
await path.unlink()
|
||||
logger.debug("Cleaned up old activity log: %s", path.name)
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup old activity logs: %s", e)
|
||||
|
||||
async def abefore_agent(
|
||||
self, state: ActivityLogState, runtime: Runtime
|
||||
) -> ActivityLogStateUpdate | None:
|
||||
"""在 Agent 执行前加载近期活动日志。"""
|
||||
# 如果已经加载则跳过
|
||||
if "activity_log_contents" in state:
|
||||
return None
|
||||
|
||||
contents = await self._load_recent_logs()
|
||||
|
||||
# 趁机清理旧日志(低频操作,不影响性能)
|
||||
await self._cleanup_old_logs()
|
||||
|
||||
return ActivityLogStateUpdate(activity_log_contents=contents)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将活动日志注入系统消息。"""
|
||||
contents = request.state.get("activity_log_contents", {})
|
||||
activity_log_prompt = self._format_activity_log(contents)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, activity_log_prompt
|
||||
)
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""异步包装模型调用,注入活动日志到系统提示词。"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
|
||||
async def aafter_agent(
|
||||
self, state: ActivityLogState, runtime: Runtime
|
||||
) -> dict[str, Any] | None:
|
||||
"""Agent 执行完毕后,调用 LLM 对本轮对话生成摘要并追加到当日活动日志。"""
|
||||
try:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 提取本轮交互
|
||||
round_messages = _extract_last_round(messages)
|
||||
if not round_messages:
|
||||
return None
|
||||
|
||||
# 格式化对话文本
|
||||
conversation_text = _format_conversation_for_summary(round_messages)
|
||||
if not conversation_text:
|
||||
return None
|
||||
|
||||
# 调用 LLM 生成摘要
|
||||
summary = await _summarize_with_llm(conversation_text)
|
||||
if summary:
|
||||
await self._append_activity(summary)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record activity: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["ActivityLogMiddleware"]
|
||||
350
app/agent/middleware/jobs.py
Normal file
350
app/agent/middleware/jobs.py
Normal file
@@ -0,0 +1,350 @@
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
import yaml # noqa
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# JOB.md 文件最大限制为 1MB
|
||||
MAX_JOB_FILE_SIZE = 1 * 1024 * 1024
|
||||
|
||||
|
||||
class JobMetadata(TypedDict):
|
||||
"""Job 元数据。"""
|
||||
|
||||
path: str
|
||||
"""JOB.md 文件路径。"""
|
||||
|
||||
id: str
|
||||
"""Job 标识符(目录名)。"""
|
||||
|
||||
name: str
|
||||
"""Job 名称。"""
|
||||
|
||||
description: str
|
||||
"""Job 描述。"""
|
||||
|
||||
schedule: str
|
||||
"""调度类型: once(一次性)/ recurring(重复性)。"""
|
||||
|
||||
status: str
|
||||
"""当前状态: pending / in_progress / completed / cancelled。"""
|
||||
|
||||
last_run: str | None
|
||||
"""上次执行时间。"""
|
||||
|
||||
|
||||
class JobsState(AgentState):
|
||||
"""jobs 中间件状态。"""
|
||||
|
||||
jobs_metadata: NotRequired[Annotated[list[JobMetadata], PrivateStateAttr]]
|
||||
"""已加载的 job 元数据列表,不传播给父 agent。"""
|
||||
|
||||
|
||||
class JobsStateUpdate(TypedDict):
|
||||
"""jobs 中间件状态更新项。"""
|
||||
|
||||
jobs_metadata: list[JobMetadata]
|
||||
"""待合并的 job 元数据列表。"""
|
||||
|
||||
|
||||
def _parse_job_metadata(
|
||||
content: str,
|
||||
job_path: str,
|
||||
job_id: str,
|
||||
) -> JobMetadata | None:
|
||||
"""从 JOB.md 内容中解析 YAML 前言并验证元数据。"""
|
||||
if len(content) > MAX_JOB_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Skipping %s: content too large (%d bytes)", job_path, len(content)
|
||||
)
|
||||
return None
|
||||
|
||||
# 匹配 --- 分隔的 YAML 前言
|
||||
frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n"
|
||||
match = re.match(frontmatter_pattern, content, re.DOTALL)
|
||||
if not match:
|
||||
logger.warning("Skipping %s: no valid YAML frontmatter found", job_path)
|
||||
return None
|
||||
frontmatter_str = match.group(1)
|
||||
|
||||
# 解析 YAML
|
||||
try:
|
||||
frontmatter_data = yaml.safe_load(frontmatter_str)
|
||||
except yaml.YAMLError as e:
|
||||
logger.warning("Invalid YAML in %s: %s", job_path, e)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter_data, dict):
|
||||
logger.warning("Skipping %s: frontmatter is not a mapping", job_path)
|
||||
return None
|
||||
|
||||
# Job 名称和描述
|
||||
name = str(frontmatter_data.get("name", "")).strip()
|
||||
description = str(frontmatter_data.get("description", "")).strip()
|
||||
if not name:
|
||||
logger.warning("Skipping %s: missing required 'name'", job_path)
|
||||
return None
|
||||
|
||||
# 调度类型
|
||||
schedule = str(frontmatter_data.get("schedule", "once")).strip().lower()
|
||||
if schedule not in ("once", "recurring"):
|
||||
schedule = "once"
|
||||
|
||||
# 状态
|
||||
status = str(frontmatter_data.get("status", "pending")).strip().lower()
|
||||
if status not in ("pending", "in_progress", "completed", "cancelled"):
|
||||
status = "pending"
|
||||
|
||||
# 上次执行时间
|
||||
last_run = str(frontmatter_data.get("last_run", "")).strip() or None
|
||||
|
||||
return JobMetadata(
|
||||
id=job_id,
|
||||
name=name,
|
||||
description=description,
|
||||
path=job_path,
|
||||
schedule=schedule,
|
||||
status=status,
|
||||
last_run=last_run,
|
||||
)
|
||||
|
||||
|
||||
async def _alist_jobs(source_path: AsyncPath) -> list[JobMetadata]:
|
||||
"""异步列出指定路径下的所有任务。
|
||||
|
||||
扫描包含 JOB.md 的目录并解析其元数据。
|
||||
"""
|
||||
jobs: list[JobMetadata] = []
|
||||
|
||||
if not await source_path.exists():
|
||||
return []
|
||||
|
||||
# 查找所有任务目录(包含 JOB.md 的目录)
|
||||
job_dirs: list[AsyncPath] = []
|
||||
async for path in source_path.iterdir():
|
||||
if await path.is_dir() and await (path / "JOB.md").is_file():
|
||||
job_dirs.append(path)
|
||||
|
||||
if not job_dirs:
|
||||
return []
|
||||
|
||||
# 解析 JOB.md
|
||||
for job_path in job_dirs:
|
||||
job_md_path = job_path / "JOB.md"
|
||||
|
||||
job_content = await job_md_path.read_text(encoding="utf-8")
|
||||
|
||||
# 解析元数据
|
||||
job_metadata = _parse_job_metadata(
|
||||
content=job_content,
|
||||
job_path=str(job_md_path),
|
||||
job_id=job_path.name,
|
||||
)
|
||||
if job_metadata:
|
||||
jobs.append(job_metadata)
|
||||
|
||||
return jobs
|
||||
|
||||
|
||||
JOBS_SYSTEM_PROMPT = """
|
||||
<jobs_system>
|
||||
You have a **scheduled jobs** system that allows you to track and execute long-running or recurring tasks.
|
||||
|
||||
**Jobs Location:** `{jobs_location}`
|
||||
|
||||
**Current Jobs:**
|
||||
|
||||
{jobs_list}
|
||||
|
||||
**Job File Format:**
|
||||
|
||||
Each job is a directory containing a `JOB.md` file with YAML frontmatter followed by task details:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: 任务名称(简短中文描述)
|
||||
description: 任务的详细描述,说明要做什么
|
||||
schedule: once 或 recurring
|
||||
status: pending / in_progress / completed / cancelled
|
||||
last_run: "YYYY-MM-DD HH:MM"(上次执行时间,可选)
|
||||
---
|
||||
# 任务详情
|
||||
|
||||
## 目标
|
||||
详细描述这个任务要完成的目标。
|
||||
|
||||
## 执行日志
|
||||
记录每次执行的情况和结果。
|
||||
|
||||
- **2024-01-15 10:00** - 执行了XXX操作,结果:成功/失败
|
||||
- **2024-01-16 10:00** - 继续执行XXX...
|
||||
```
|
||||
|
||||
**Job Lifecycle Rules:**
|
||||
|
||||
1. **Creating a Job**: When a user asks you to do something periodically or at a later time:
|
||||
- Create a new directory under the jobs location, directory name is the `job-id` (lowercase, hyphens, 1-64 chars)
|
||||
- Write a `JOB.md` file with proper frontmatter and detailed task description
|
||||
- Set `schedule: once` for one-time tasks, `schedule: recurring` for repeating tasks (e.g., daily sign-in, weekly checks)
|
||||
- Set initial `status: pending`
|
||||
|
||||
2. **Executing a Job**: When you work on a job:
|
||||
- Update `status: in_progress` in the frontmatter
|
||||
- Execute the required actions using your tools
|
||||
- Log the execution result in the "执行日志" section with timestamp
|
||||
- Update `last_run` in frontmatter to current time
|
||||
|
||||
3. **Completing a Job**:
|
||||
- For `schedule: once` tasks: set `status: completed` after successful execution
|
||||
- For `schedule: recurring` tasks: keep `status: pending` after execution, only update `last_run` time. The job stays active for the next scheduled run.
|
||||
- Set `status: cancelled` if the user explicitly asks to cancel/stop a task
|
||||
|
||||
4. **Heartbeat Check**: You will be periodically woken up to check pending jobs. When woken up:
|
||||
- Read the jobs directory to find all active jobs (status: pending or in_progress)
|
||||
- Skip jobs with `status: completed` or `status: cancelled`
|
||||
- For `schedule: recurring` jobs, check `last_run` to determine if it's time to run again
|
||||
- Execute pending jobs and update their status/logs accordingly
|
||||
|
||||
**Important Notes:**
|
||||
- Each job MUST have its own separate directory and JOB.md file to avoid conflicts
|
||||
- Always update the frontmatter fields (status, last_run) when executing a job
|
||||
- Keep execution logs concise but informative
|
||||
- For recurring jobs, maintain a rolling log (keep recent entries, you can summarize/remove old entries to keep the file manageable)
|
||||
- When creating jobs, make the description detailed enough that you can understand and execute the task in future sessions without additional context
|
||||
|
||||
**When to Create Jobs:**
|
||||
- User says "每天帮我..." / "定期..." / "定时..." / "提醒我..." / "以后每次..."
|
||||
- User requests a task that should be done repeatedly
|
||||
- User asks for monitoring or periodic checking of something
|
||||
|
||||
**When NOT to Create Jobs:**
|
||||
- User asks for an immediate one-time action (just do it now)
|
||||
- Simple questions or conversations
|
||||
- Tasks that are already handled by MoviePilot's built-in scheduler services
|
||||
</jobs_system>
|
||||
"""
|
||||
|
||||
|
||||
class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
|
||||
"""加载并向系统提示词注入 Agent Jobs 的中间件。
|
||||
|
||||
扫描 jobs 目录下的 JOB.md 文件,解析元数据并注入到系统提示词中,
|
||||
使智能体了解当前的长期任务及其状态。
|
||||
"""
|
||||
|
||||
state_schema = JobsState
|
||||
|
||||
def __init__(self, *, sources: list[str]) -> None:
|
||||
"""初始化 Jobs 中间件。"""
|
||||
self.sources = sources
|
||||
self.system_prompt_template = JOBS_SYSTEM_PROMPT
|
||||
|
||||
@staticmethod
|
||||
def _format_jobs_list(jobs: list[JobMetadata]) -> str:
|
||||
"""格式化任务元数据列表用于系统提示词。"""
|
||||
if not jobs:
|
||||
return "(No active jobs. You can create jobs when users request periodic or scheduled tasks.)"
|
||||
|
||||
lines = []
|
||||
for job in jobs:
|
||||
status_emoji = {
|
||||
"pending": "⏳",
|
||||
"in_progress": "🔄",
|
||||
"completed": "✅",
|
||||
"cancelled": "❌",
|
||||
}.get(job["status"], "❓")
|
||||
|
||||
schedule_label = (
|
||||
"recurring (重复)"
|
||||
if job["schedule"] == "recurring"
|
||||
else "once (一次性)"
|
||||
)
|
||||
desc_line = (
|
||||
f"- {status_emoji} **{job['id']}**: {job['name']}"
|
||||
f" [{schedule_label}] - {job['description']}"
|
||||
)
|
||||
if job.get("last_run"):
|
||||
desc_line += f" (上次执行: {job['last_run']})"
|
||||
lines.append(desc_line)
|
||||
lines.append(f" -> Read `{job['path']}` for full details")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将任务文档注入模型请求的系统消息中。"""
|
||||
jobs_metadata = request.state.get("jobs_metadata", []) # noqa
|
||||
|
||||
# 过滤:只展示活跃任务(pending / in_progress / recurring)
|
||||
active_jobs = [
|
||||
j
|
||||
for j in jobs_metadata
|
||||
if j["status"] in ("pending", "in_progress")
|
||||
or (j["schedule"] == "recurring" and j["status"] not in ("cancelled",))
|
||||
]
|
||||
|
||||
jobs_list = self._format_jobs_list(active_jobs)
|
||||
jobs_location = self.sources[0] if self.sources else ""
|
||||
|
||||
jobs_section = self.system_prompt_template.format(
|
||||
jobs_location=jobs_location,
|
||||
jobs_list=jobs_list,
|
||||
)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, jobs_section
|
||||
)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self, state: JobsState, runtime: Runtime, config: RunnableConfig
|
||||
) -> JobsStateUpdate | None:
|
||||
"""在 Agent 执行前异步加载任务元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "jobs_metadata" in state:
|
||||
return None
|
||||
|
||||
all_jobs: list[JobMetadata] = []
|
||||
|
||||
# 遍历源加载任务
|
||||
for source_path_str in self.sources:
|
||||
source_path = AsyncPath(source_path_str)
|
||||
if not await source_path.exists():
|
||||
await source_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
source_jobs = await _alist_jobs(source_path)
|
||||
all_jobs.extend(source_jobs)
|
||||
|
||||
return JobsStateUpdate(jobs_metadata=all_jobs)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""在模型调用时注入任务文档。"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
|
||||
|
||||
__all__ = ["JobMetadata", "JobsMiddleware"]
|
||||
396
app/agent/middleware/memory.py
Normal file
396
app/agent/middleware/memory.py
Normal file
@@ -0,0 +1,396 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, NotRequired, TypedDict, Dict
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# 记忆文件最大限制为 100KB,防止单文件过大导致上下文溢出
|
||||
MAX_MEMORY_FILE_SIZE = 100 * 1024
|
||||
|
||||
# 默认记忆文件名(用户主记忆)
|
||||
DEFAULT_MEMORY_FILE = "MEMORY.md"
|
||||
|
||||
|
||||
class MemoryState(AgentState):
|
||||
"""`MemoryMiddleware` 的状态模型。
|
||||
|
||||
属性:
|
||||
memory_contents: 将源路径映射到其加载内容的字典。
|
||||
标记为私有,因此不包含在最终的代理状态中。
|
||||
memory_empty: 记忆文件是否为空或不存在。
|
||||
标记为私有,用于判断是否需要触发初始化引导流程。
|
||||
"""
|
||||
|
||||
memory_contents: NotRequired[Annotated[dict[str, str], PrivateStateAttr]]
|
||||
memory_empty: NotRequired[Annotated[bool, PrivateStateAttr]]
|
||||
|
||||
|
||||
class MemoryStateUpdate(TypedDict):
|
||||
"""`MemoryMiddleware` 的状态更新。"""
|
||||
|
||||
memory_contents: dict[str, str]
|
||||
memory_empty: bool
|
||||
|
||||
|
||||
MEMORY_SYSTEM_PROMPT = """<agent_memory>
|
||||
The following memory files were loaded from your memory directory: `{memory_dir}`
|
||||
You can create, edit, or organize any `.md` files in this directory to manage your knowledge.
|
||||
|
||||
{agent_memory}
|
||||
</agent_memory>
|
||||
|
||||
<memory_guidelines>
|
||||
The above <agent_memory> was loaded from `.md` files in your memory directory (`{memory_dir}`). As you learn from your interactions with the user, you can save new knowledge by calling the `edit_file` or `write_file` tool on files in this directory.
|
||||
|
||||
**Memory file organization:**
|
||||
- All `.md` files in `{memory_dir}` are automatically loaded as memory.
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- You may create additional `.md` files to organize knowledge by topic (e.g., `MEDIA_RULES.md`, `DOWNLOAD_PREFERENCES.md`, `SITE_CONFIGS.md`, etc.).
|
||||
- Keep each file focused on a specific domain or topic for better organization.
|
||||
- Subdirectories are NOT scanned — only `.md` files directly in `{memory_dir}`.
|
||||
|
||||
**Learning from feedback:**
|
||||
- One of your MAIN PRIORITIES is to learn from your interactions with the user. These learnings can be implicit or explicit. This means that in the future, you will remember this important information.
|
||||
- When you need to remember something, updating memory must be your FIRST, IMMEDIATE action - before responding to the user, before calling other tools, before doing anything else. Just update memory immediately.
|
||||
- When user says something is better/worse, capture WHY and encode it as a pattern.
|
||||
- Each correction is a chance to improve permanently - don't just fix the immediate issue, update your instructions.
|
||||
- A great opportunity to update your memories is when the user interrupts a tool call and provides feedback. You should update your memories immediately before revising the tool call.
|
||||
- Look for the underlying principle behind corrections, not just the specific mistake.
|
||||
- The user might not explicitly ask you to remember something, but if they provide information that is useful for future use, you should update your memories immediately.
|
||||
|
||||
**Asking for information:**
|
||||
- If you lack context to perform an action (e.g. send a Slack DM, requires a user ID/email) you should explicitly ask the user for this information.
|
||||
- It is preferred for you to ask for information, don't assume anything that you do not know!
|
||||
- When the user provides information that is useful for future use, you should update your memories immediately.
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something (e.g., "remember my email", "save this preference")
|
||||
- When the user describes your role or how you should behave (e.g., "you are a web researcher", "always do X")
|
||||
- When the user gives feedback on your work - capture what was wrong and how to improve
|
||||
- When the user provides information required for tool use (e.g., slack channel ID, email addresses)
|
||||
- When the user provides context useful for future tasks, such as how to use tools, or which actions to take in a particular situation
|
||||
- When you discover new patterns or preferences (coding styles, conventions, workflows)
|
||||
|
||||
**When to NOT update memories:**
|
||||
- When the information is temporary or transient (e.g., "I'm running late", "I'm on my phone right now")
|
||||
- When the information is a one-time task request (e.g., "Find me a recipe", "What's 25 * 4?")
|
||||
- When the information is a simple question that doesn't reveal lasting preferences (e.g., "What day is it?", "Can you explain X?")
|
||||
- When the information is an acknowledgment or small talk (e.g., "Sounds good!", "Hello", "Thanks for that")
|
||||
- When the information is stale or irrelevant in future conversations
|
||||
- Never store API keys, access tokens, passwords, or any other credentials in any file, memory, or system prompt.
|
||||
- If the user asks where to put API keys or provides an API key, do NOT echo or save it.
|
||||
- Do NOT record daily activities or task execution history in memory files - these are automatically tracked in the activity log system (see <activity_log>). Memory files are only for long-term knowledge, preferences, and patterns.
|
||||
|
||||
**Examples:**
|
||||
Example 1 (remembering user information):
|
||||
User: Can you connect to my google account?
|
||||
Agent: Sure, I'll connect to your google account, what's your google account email?
|
||||
User: john@example.com
|
||||
Agent: Let me save this to my memory.
|
||||
Tool Call: edit_file(...) -> remembers that the user's google account email is john@example.com
|
||||
|
||||
Example 2 (remembering implicit user preferences):
|
||||
User: Can you write me an example for creating a deep agent in LangChain?
|
||||
Agent: Sure, I'll write you an example for creating a deep agent in LangChain <example code in Python>
|
||||
User: Can you do this in JavaScript
|
||||
Agent: Let me save this to my memory.
|
||||
Tool Call: edit_file(...) -> remembers that the user prefers to get LangChain code examples in JavaScript
|
||||
Agent: Sure, here is the JavaScript example<example code in JavaScript>
|
||||
|
||||
Example 3 (do not remember transient information):
|
||||
User: I'm going to play basketball tonight so I will be offline for a few hours.
|
||||
Agent: Okay I'll add a block to your calendar.
|
||||
Tool Call: create_calendar_event(...) -> just calls a tool, does not commit anything to memory, as it is transient information
|
||||
</memory_guidelines>
|
||||
"""
|
||||
|
||||
MEMORY_ONBOARDING_PROMPT = """<agent_memory>
|
||||
(No memory loaded — this is a brand new user with no saved preferences.)
|
||||
Memory directory: {memory_dir}
|
||||
Default memory file: {memory_file}
|
||||
</agent_memory>
|
||||
|
||||
<memory_onboarding>
|
||||
**IMPORTANT — First-time user detected!**
|
||||
|
||||
The memory directory is currently empty. This means this is likely the user's first interaction, or their preferences have been reset.
|
||||
|
||||
**Your MANDATORY first action in this conversation:**
|
||||
Before doing ANYTHING else (before answering questions, before calling tools, before performing any task), you MUST proactively greet the user warmly and ask them about their preferences so you can provide personalized service going forward. Specifically, ask about:
|
||||
|
||||
1. **How to address the user** — Ask what name or nickname they'd like you to call them (e.g., a real name, a nickname, or a fun title). This is the top priority for building a personal connection.
|
||||
2. **Communication style preference** — Do they prefer a cute/playful tone (with emojis), a formal/professional tone, a concise/minimalist style, or something else?
|
||||
3. **Media preferences** — What types of media do they primarily care about? (e.g., movies, TV shows, anime, documentaries, etc.)
|
||||
4. **Quality preferences** — Do they have preferred video quality (4K, 1080p), codecs (H.265, H.264), or subtitle language preferences?
|
||||
5. **Any other special requests** — Anything else they'd like you to always keep in mind?
|
||||
|
||||
**After the user replies**, you MUST immediately:
|
||||
1. Use the `write_file` tool to save ALL their preferences to the memory file at: `{memory_file}`
|
||||
2. Format the memory file in clean Markdown with clear sections (e.g., `## User Profile`, `## Communication Style`, `## Media Preferences`, etc.)
|
||||
3. The `## User Profile` section MUST include the user's preferred name/nickname at the top
|
||||
4. Only AFTER saving the preferences, proceed to help with whatever the user originally asked about (if anything)
|
||||
5. From this point on, always address the user by their preferred name/nickname in conversations
|
||||
6. You may also create additional `.md` files in the memory directory (`{memory_dir}`) for different topics as needed.
|
||||
|
||||
**If the user skips the preference questions** and directly asks you to do something:
|
||||
- Go ahead and help them with their request first
|
||||
- But still ask about their preferences naturally at the end of the interaction
|
||||
- Save whatever you learn about them (implicit or explicit) to the memory file
|
||||
|
||||
**Example onboarding flow:**
|
||||
The greeting should introduce yourself, explain this is the first meeting, and ask the above questions in a numbered list. Adapt the tone to your persona defined in the base system prompt.
|
||||
</memory_onboarding>
|
||||
|
||||
<memory_guidelines>
|
||||
Your memory directory is at: {memory_dir}. You can save new knowledge by calling the `edit_file` or `write_file` tool on any `.md` file in this directory.
|
||||
|
||||
**Memory file organization:**
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- You may create additional `.md` files to organize knowledge by topic.
|
||||
- All `.md` files directly in the memory directory are automatically loaded on each conversation.
|
||||
|
||||
**Learning from feedback:**
|
||||
- One of your MAIN PRIORITIES is to learn from your interactions with the user. These learnings can be implicit or explicit. This means that in the future, you will remember this important information.
|
||||
- When you need to remember something, updating memory must be your FIRST, IMMEDIATE action - before responding to the user, before calling other tools, before doing anything else. Just update memory immediately.
|
||||
- When user says something is better/worse, capture WHY and encode it as a pattern.
|
||||
- Each correction is a chance to improve permanently - don't just fix the immediate issue, update your instructions.
|
||||
- The user might not explicitly ask you to remember something, but if they provide information that is useful for future use, you should update your memories immediately.
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something
|
||||
- When the user describes your role or how you should behave
|
||||
- When the user gives feedback on your work
|
||||
- When the user provides information required for tool use
|
||||
- When you discover new patterns or preferences
|
||||
|
||||
**When to NOT update memories:**
|
||||
- Temporary/transient information
|
||||
- One-time task requests
|
||||
- Simple questions, acknowledgments, or small talk
|
||||
- Never store API keys, access tokens, passwords, or credentials
|
||||
- Do NOT record daily activities in memory files — those go to the activity log
|
||||
</memory_guidelines>
|
||||
"""
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # noqa
|
||||
"""从代理记忆目录加载所有 MD 文件作为记忆的中间件。
|
||||
|
||||
自动扫描指定目录下的所有 `.md` 文件,加载其内容并注入到系统提示词中。
|
||||
支持多文件记忆组织:用户可以创建多个 `.md` 文件来按主题组织知识。
|
||||
|
||||
参数:
|
||||
memory_dir: 记忆文件目录路径。
|
||||
"""
|
||||
|
||||
state_schema = MemoryState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
memory_dir: str,
|
||||
) -> None:
|
||||
"""初始化记忆中间件。
|
||||
|
||||
参数:
|
||||
memory_dir: 记忆文件目录路径(例如,`"/config/agent"`)。
|
||||
该目录下所有 `.md` 文件都会被自动加载为记忆。
|
||||
"""
|
||||
self.memory_dir = memory_dir
|
||||
self.default_memory_file = str(AsyncPath(memory_dir) / DEFAULT_MEMORY_FILE)
|
||||
|
||||
@staticmethod
|
||||
def _is_memory_empty(contents: dict[str, str]) -> bool:
|
||||
"""判断记忆内容是否为空。
|
||||
|
||||
检查所有源文件的内容,如果全部为空或仅包含空白字符则返回 True。
|
||||
|
||||
参数:
|
||||
contents: 将源路径映射到内容的字典。
|
||||
|
||||
返回:
|
||||
如果记忆为空则返回 True,否则返回 False。
|
||||
"""
|
||||
if not contents:
|
||||
return True
|
||||
return all(not content.strip() for content in contents.values())
|
||||
|
||||
def _format_agent_memory(
|
||||
self, contents: dict[str, str], memory_empty: bool = False
|
||||
) -> str:
|
||||
"""格式化记忆,将位置和内容成对组合。
|
||||
|
||||
当记忆为空时,返回初始化引导提示词,引导智能体主动询问用户偏好。
|
||||
当记忆非空时,返回标准记忆系统提示词,包含所有加载的文件内容。
|
||||
|
||||
参数:
|
||||
contents: 将源路径映射到内容的字典。
|
||||
memory_empty: 记忆是否为空的标志位。
|
||||
|
||||
返回:
|
||||
在 <agent_memory> 标签中包装了位置+内容对的格式化字符串。
|
||||
"""
|
||||
# 记忆为空时返回初始化引导提示词
|
||||
if memory_empty or self._is_memory_empty(contents):
|
||||
return MEMORY_ONBOARDING_PROMPT.format(
|
||||
memory_dir=self.memory_dir,
|
||||
memory_file=self.default_memory_file,
|
||||
)
|
||||
|
||||
# 按文件名排序,确保 MEMORY.md 排在最前面
|
||||
sorted_paths = sorted(
|
||||
[p for p in contents if contents[p].strip()],
|
||||
key=lambda p: (0 if AsyncPath(p).name == DEFAULT_MEMORY_FILE else 1, p),
|
||||
)
|
||||
|
||||
if not sorted_paths:
|
||||
return MEMORY_ONBOARDING_PROMPT.format(
|
||||
memory_dir=self.memory_dir,
|
||||
memory_file=self.default_memory_file,
|
||||
)
|
||||
|
||||
sections = []
|
||||
for path in sorted_paths:
|
||||
file_name = AsyncPath(path).name
|
||||
sections.append(f"### {file_name}\n**Path:** `{path}`\n\n{contents[path]}")
|
||||
|
||||
memory_body = "\n\n---\n\n".join(sections)
|
||||
return MEMORY_SYSTEM_PROMPT.format(
|
||||
agent_memory=memory_body,
|
||||
memory_dir=self.memory_dir,
|
||||
)
|
||||
|
||||
async def _scan_memory_files(self) -> list[str]:
|
||||
"""扫描记忆目录下的所有 .md 文件。
|
||||
|
||||
仅扫描目录下直接存在的 `.md` 文件(不递归子目录)。
|
||||
文件大小超过限制的将被跳过。
|
||||
|
||||
返回:
|
||||
发现的 .md 文件路径列表。
|
||||
"""
|
||||
dir_path = AsyncPath(self.memory_dir)
|
||||
if not await dir_path.exists():
|
||||
return []
|
||||
|
||||
md_files: list[str] = []
|
||||
async for entry in dir_path.iterdir():
|
||||
if await entry.is_file() and entry.name.lower().endswith(".md"):
|
||||
md_files.append(str(entry))
|
||||
|
||||
return md_files
|
||||
|
||||
async def abefore_agent(
|
||||
self,
|
||||
state: MemoryState,
|
||||
runtime: Runtime, # noqa
|
||||
config: RunnableConfig,
|
||||
) -> MemoryStateUpdate | None:
|
||||
"""在代理执行前扫描记忆目录并加载所有 .md 文件的内容。
|
||||
|
||||
自动发现目录下所有 `.md` 文件并加载其内容到状态中。
|
||||
如果状态中尚未存在则进行加载。
|
||||
同时检测记忆文件是否为空,设置 memory_empty 标志位,
|
||||
以便在系统提示词中触发初始化引导流程。
|
||||
|
||||
参数:
|
||||
state: 当前代理状态。
|
||||
runtime: 运行时上下文。
|
||||
config: Runnable 配置。
|
||||
|
||||
返回:
|
||||
填充了 memory_contents 和 memory_empty 的状态更新。
|
||||
"""
|
||||
# 如果已经加载则跳过
|
||||
if "memory_contents" in state:
|
||||
return None
|
||||
|
||||
# 扫描目录下所有 .md 文件
|
||||
md_files = await self._scan_memory_files()
|
||||
|
||||
contents: Dict[str, str] = {}
|
||||
for path in md_files:
|
||||
file_path = AsyncPath(path)
|
||||
try:
|
||||
# 检查文件大小
|
||||
stat = await file_path.stat()
|
||||
if stat.st_size > MAX_MEMORY_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Skipping memory file %s: too large (%d bytes, max %d)",
|
||||
path,
|
||||
stat.st_size,
|
||||
MAX_MEMORY_FILE_SIZE,
|
||||
)
|
||||
continue
|
||||
contents[path] = await file_path.read_text(encoding="utf-8")
|
||||
logger.debug("Loaded memory from: %s", path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read memory file %s: %s", path, e)
|
||||
|
||||
if contents:
|
||||
logger.info(
|
||||
"Loaded %d memory file(s) from %s: %s",
|
||||
len(contents),
|
||||
self.memory_dir,
|
||||
[AsyncPath(p).name for p in contents],
|
||||
)
|
||||
|
||||
# 检测记忆是否为空(文件不存在、文件内容为空白)
|
||||
is_empty = self._is_memory_empty(contents)
|
||||
if is_empty:
|
||||
logger.info(
|
||||
"Memory is empty, onboarding prompt will be activated for user preference collection."
|
||||
)
|
||||
|
||||
return MemoryStateUpdate(memory_contents=contents, memory_empty=is_empty)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将记忆内容注入系统消息。
|
||||
|
||||
参数:
|
||||
request: 要修改的模型请求。
|
||||
|
||||
返回:
|
||||
将记忆注入系统消息后的修改后请求。
|
||||
"""
|
||||
contents = request.state.get("memory_contents", {}) # noqa
|
||||
memory_empty = request.state.get("memory_empty", False) # noqa
|
||||
agent_memory = self._format_agent_memory(contents, memory_empty=memory_empty)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, agent_memory
|
||||
)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""异步包装模型调用,将记忆注入系统提示词。
|
||||
|
||||
参数:
|
||||
request: 正在处理的模型请求。
|
||||
handler: 使用修改后的请求进行调用的异步处理函数。
|
||||
|
||||
返回:
|
||||
来自处理函数的模型响应。
|
||||
"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
43
app/agent/middleware/patch_tool_calls.py
Normal file
43
app/agent/middleware/patch_tool_calls.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Overwrite
|
||||
|
||||
|
||||
class PatchToolCallsMiddleware(AgentMiddleware):
|
||||
"""修复消息历史中悬空工具调用的中间件。"""
|
||||
|
||||
def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""在代理运行之前,处理任何 AIMessage 中悬空的工具调用。"""
|
||||
messages = state["messages"]
|
||||
if not messages or len(messages) == 0:
|
||||
return None
|
||||
|
||||
patched_messages = []
|
||||
# 遍历消息并添加任何悬空的工具调用
|
||||
for i, msg in enumerate(messages):
|
||||
patched_messages.append(msg)
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
corresponding_tool_msg = next(
|
||||
(msg for msg in messages[i:] if msg.type == "tool" and msg.tool_call_id == tool_call["id"]),
|
||||
# ty: ignore[unresolved-attribute]
|
||||
None,
|
||||
)
|
||||
if corresponding_tool_msg is None:
|
||||
# 我们有一个悬空的工具调用,需要一个 ToolMessage
|
||||
tool_msg = (
|
||||
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
|
||||
"cancelled - another message came in before it could be completed."
|
||||
)
|
||||
patched_messages.append(
|
||||
ToolMessage(
|
||||
content=tool_msg,
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
|
||||
return {"messages": Overwrite(patched_messages)}
|
||||
524
app/agent/middleware/skills.py
Normal file
524
app/agent/middleware/skills.py
Normal file
@@ -0,0 +1,524 @@
|
||||
import re
|
||||
import shutil
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
import yaml # noqa
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.agents.middleware.types import PrivateStateAttr # noqa
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# 安全提示: SKILL.md 文件最大限制为 10MB,防止 DoS 攻击
|
||||
MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# Agent Skills 规范约束 (https://agentskills.io/specification)
|
||||
MAX_SKILL_NAME_LENGTH = 64
|
||||
MAX_SKILL_DESCRIPTION_LENGTH = 1024
|
||||
MAX_SKILL_COMPATIBILITY_LENGTH = 500
|
||||
|
||||
|
||||
class SkillMetadata(TypedDict):
|
||||
"""Skill 元数据,符合 Agent Skills 规范。"""
|
||||
|
||||
path: str
|
||||
"""SKILL.md 文件路径。"""
|
||||
|
||||
id: str
|
||||
"""Skill 标识符。
|
||||
约束: 1-64 字符,仅限小写字母/数字/连字符,不能以连字符开头或结尾,无连续连字符,需与父目录名一致。
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""Skill 名称。
|
||||
约束: Skill中文描述。
|
||||
"""
|
||||
|
||||
version: int
|
||||
"""Skill 版本号。
|
||||
用于内置技能的版本管理,同步时比较版本号决定是否覆盖用户目录中的旧版本。
|
||||
"""
|
||||
|
||||
description: str
|
||||
"""Skill 功能描述。
|
||||
约束: 1-1024 字符,应说明功能及适用场景。
|
||||
"""
|
||||
|
||||
license: str | None
|
||||
"""许可证信息。"""
|
||||
|
||||
compatibility: str | None
|
||||
"""环境依赖或兼容性要求 (最多 500 字符)。"""
|
||||
|
||||
metadata: dict[str, str]
|
||||
"""附加元数据。"""
|
||||
|
||||
allowed_tools: list[str]
|
||||
"""(实验性) Skill 建议使用的工具列表。"""
|
||||
|
||||
|
||||
class SkillsState(AgentState):
|
||||
"""skills 中间件状态。"""
|
||||
|
||||
skills_metadata: NotRequired[Annotated[list[SkillMetadata], PrivateStateAttr]]
|
||||
"""已加载的 skill 元数据列表,不传播给父 agent。"""
|
||||
|
||||
|
||||
class SkillsStateUpdate(TypedDict):
|
||||
"""skills 中间件状态更新项。"""
|
||||
|
||||
skills_metadata: list[SkillMetadata]
|
||||
"""待合并的 skill 元数据列表。"""
|
||||
|
||||
|
||||
def _parse_skill_metadata( # noqa: C901
|
||||
content: str,
|
||||
skill_path: str,
|
||||
skill_id: str,
|
||||
) -> SkillMetadata | None:
|
||||
"""从 SKILL.md 内容中解析 YAML 前言并验证元数据。"""
|
||||
if len(content) > MAX_SKILL_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Skipping %s: content too large (%d bytes)", skill_path, len(content)
|
||||
)
|
||||
return None
|
||||
|
||||
# 匹配 --- 分隔的 YAML 前言
|
||||
frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n"
|
||||
match = re.match(frontmatter_pattern, content, re.DOTALL)
|
||||
if not match:
|
||||
logger.warning("Skipping %s: no valid YAML frontmatter found", skill_path)
|
||||
return None
|
||||
frontmatter_str = match.group(1)
|
||||
|
||||
# 解析 YAML
|
||||
try:
|
||||
frontmatter_data = yaml.safe_load(frontmatter_str)
|
||||
except yaml.YAMLError as e:
|
||||
logger.warning("Invalid YAML in %s: %s", skill_path, e)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter_data, dict):
|
||||
logger.warning("Skipping %s: frontmatter is not a mapping", skill_path)
|
||||
return None
|
||||
|
||||
# SKill名称和描述
|
||||
name = str(frontmatter_data.get("name", "")).strip()
|
||||
description = str(frontmatter_data.get("description", "")).strip()
|
||||
if not name or not description:
|
||||
logger.warning(
|
||||
"Skipping %s: missing required 'name' or 'description'", skill_path
|
||||
)
|
||||
return None
|
||||
description_str = description
|
||||
if len(description_str) > MAX_SKILL_DESCRIPTION_LENGTH:
|
||||
logger.warning(
|
||||
"Description exceeds %d characters in %s, truncating",
|
||||
MAX_SKILL_DESCRIPTION_LENGTH,
|
||||
skill_path,
|
||||
)
|
||||
description_str = description_str[:MAX_SKILL_DESCRIPTION_LENGTH]
|
||||
|
||||
# 可选的工具列表,支持空格或逗号分隔
|
||||
raw_tools = frontmatter_data.get("allowed-tools")
|
||||
if isinstance(raw_tools, str):
|
||||
allowed_tools = [
|
||||
t.strip(",") # 兼容 Claude Code 风格的逗号分隔
|
||||
for t in raw_tools.split()
|
||||
if t.strip(",")
|
||||
]
|
||||
else:
|
||||
if raw_tools is not None:
|
||||
logger.warning(
|
||||
"Ignoring non-string 'allowed-tools' in %s (got %s)",
|
||||
skill_path,
|
||||
type(raw_tools).__name__,
|
||||
)
|
||||
allowed_tools = []
|
||||
|
||||
# 能力或环境兼容性说明,最多 500 字符
|
||||
compatibility_str = str(frontmatter_data.get("compatibility", "")).strip() or None
|
||||
if compatibility_str and len(compatibility_str) > MAX_SKILL_COMPATIBILITY_LENGTH:
|
||||
logger.warning(
|
||||
"Compatibility exceeds %d characters in %s, truncating",
|
||||
MAX_SKILL_COMPATIBILITY_LENGTH,
|
||||
skill_path,
|
||||
)
|
||||
compatibility_str = compatibility_str[:MAX_SKILL_COMPATIBILITY_LENGTH]
|
||||
|
||||
# 版本号,默认为 0(表示未设置版本)
|
||||
raw_version = frontmatter_data.get("version")
|
||||
version = 0
|
||||
if raw_version is not None:
|
||||
try:
|
||||
version = int(raw_version)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Invalid 'version' in %s (got %r), defaulting to 0",
|
||||
skill_path,
|
||||
raw_version,
|
||||
)
|
||||
|
||||
return SkillMetadata(
|
||||
id=skill_id,
|
||||
name=name,
|
||||
version=version,
|
||||
description=description_str,
|
||||
path=skill_path,
|
||||
metadata=_validate_metadata(frontmatter_data.get("metadata", {}), skill_path),
|
||||
license=str(frontmatter_data.get("license", "")).strip() or None,
|
||||
compatibility=compatibility_str,
|
||||
allowed_tools=allowed_tools,
|
||||
)
|
||||
|
||||
|
||||
def _validate_metadata(
|
||||
raw: object,
|
||||
skill_path: str,
|
||||
) -> dict[str, str]:
|
||||
"""验证并规范化 YAML 前言中的元数据字段,确保为 dict[str, str] 类型。"""
|
||||
if not isinstance(raw, dict):
|
||||
if raw:
|
||||
logger.warning(
|
||||
"Ignoring non-dict metadata in %s (got %s)",
|
||||
skill_path,
|
||||
type(raw).__name__,
|
||||
)
|
||||
return {}
|
||||
return {str(k): str(v) for k, v in raw.items()}
|
||||
|
||||
|
||||
def _format_skill_annotations(skill: SkillMetadata) -> str:
|
||||
"""构建许可证和兼容性说明字符串。"""
|
||||
parts: list[str] = []
|
||||
if skill.get("license"):
|
||||
parts.append(f"License: {skill['license']}")
|
||||
if skill.get("compatibility"):
|
||||
parts.append(f"Compatibility: {skill['compatibility']}")
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
async def _alist_skills(source_path: AsyncPath) -> list[SkillMetadata]:
|
||||
"""异步列出指定路径下的所有技能。
|
||||
|
||||
扫描包含 SKILL.md 的目录并解析其元数据。
|
||||
"""
|
||||
skills: list[SkillMetadata] = []
|
||||
|
||||
# 查找所有技能目录 (包含 SKILL.md 的目录)
|
||||
skill_dirs: List[AsyncPath] = []
|
||||
async for path in source_path.iterdir():
|
||||
if await path.is_dir() and await (path / "SKILL.md").is_file():
|
||||
skill_dirs.append(path)
|
||||
|
||||
if not skill_dirs:
|
||||
return []
|
||||
|
||||
# 解析已下载的 SKILL.md
|
||||
for skill_path in skill_dirs:
|
||||
skill_md_path = skill_path / "SKILL.md"
|
||||
|
||||
skill_content = await skill_md_path.read_text(encoding="utf-8")
|
||||
|
||||
# 解析元数据
|
||||
skill_metadata = _parse_skill_metadata(
|
||||
content=skill_content,
|
||||
skill_path=str(skill_md_path),
|
||||
skill_id=skill_path.name,
|
||||
)
|
||||
if skill_metadata:
|
||||
skills.append(skill_metadata)
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
SKILLS_SYSTEM_PROMPT = """
|
||||
<skills_system>
|
||||
You have access to a skills library that provides specialized capabilities and domain knowledge.
|
||||
|
||||
{skills_locations}
|
||||
|
||||
**Available Skills:**
|
||||
|
||||
{skills_list}
|
||||
|
||||
**How to Use Skills (Progressive Disclosure):**
|
||||
|
||||
Skills follow a **progressive disclosure** pattern - you see their name and description above, but only read full instructions when needed:
|
||||
|
||||
1. **Recognize when a skill applies**: Check if the user's task matches a skill's description
|
||||
2. **Read the skill's full instructions**: Use the path shown in the skill list above
|
||||
3. **Follow the skill's instructions**: SKILL.md contains step-by-step workflows, best practices, and examples
|
||||
4. **Access supporting files**: Skills may include helper scripts, configs, or reference docs - use absolute paths
|
||||
|
||||
**Creating New Skills:**
|
||||
|
||||
When you identify a repetitive complex workflow or specialized task that would benefit from being a skill, you can create one:
|
||||
|
||||
1. **Directory Structure**: Create a new directory in one of the skills locations. The directory name is the `skill-id`.
|
||||
- Path format: `<skills_location>/<skill-id>/SKILL.md`
|
||||
- `skill-id` constraints: 1-64 characters, lowercase letters, numbers, and hyphens only.
|
||||
2. **SKILL.md Format**: Must start with a YAML frontmatter followed by markdown instructions.
|
||||
```markdown
|
||||
---
|
||||
name: Brief tool name (Chinese)
|
||||
description: Detailed functional description and use cases (1-1024 chars)
|
||||
allowed-tools: "tool1 tool2" (optional, space-separated list of recommended tools)
|
||||
compatibility: "Environment requirements" (optional, max 500 chars)
|
||||
---
|
||||
# Skill Instructions
|
||||
Step-by-step workflows, best practices, and examples go here.
|
||||
```
|
||||
3. **Supporting Files**: You can add `.py` scripts, `.yaml` configs, or other files within the same skill directory. Reference them using absolute paths in `SKILL.md`.
|
||||
|
||||
**When to Use Skills:**
|
||||
- User's request matches a skill's domain (e.g., "research X" -> web-research skill)
|
||||
- You need specialized knowledge or structured workflows
|
||||
- A skill provides proven patterns for complex tasks
|
||||
|
||||
**Executing Skill Scripts:**
|
||||
Skills may contain Python scripts or other executable files. Always use absolute paths from the skill list.
|
||||
|
||||
**Example Workflow:**
|
||||
|
||||
User: "Can you research the latest developments in quantum computing?"
|
||||
|
||||
1. Check available skills -> See "web-research" skill with its path
|
||||
2. Read the skill using the path shown
|
||||
3. Follow the skill's research workflow (search -> organize -> synthesize)
|
||||
4. Use any helper scripts with absolute paths
|
||||
|
||||
Remember: Skills make you more capable and consistent. When in doubt, check if a skill exists for the task!
|
||||
</skills_system>
|
||||
"""
|
||||
|
||||
|
||||
def _extract_version(skill_md: Path) -> int:
|
||||
"""从 SKILL.md 文件中快速提取 version 字段,无法提取时返回 0。"""
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return 0
|
||||
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
|
||||
if not match:
|
||||
return 0
|
||||
try:
|
||||
frontmatter = yaml.safe_load(match.group(1))
|
||||
except yaml.YAMLError:
|
||||
return 0
|
||||
if not isinstance(frontmatter, dict):
|
||||
return 0
|
||||
raw = frontmatter.get("version")
|
||||
if raw is None:
|
||||
return 0
|
||||
try:
|
||||
return int(raw)
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
|
||||
def _sync_bundled_skills(bundled_dir: Path, target_dir: Path) -> None:
|
||||
"""将项目自带的技能同步到用户目录。
|
||||
|
||||
- 目标目录中不存在对应技能子目录时,直接复制。
|
||||
- 目标目录中已存在时,比较内置与用户目录中 SKILL.md 的 version 字段:
|
||||
- 内置版本更高时,直接覆盖用户目录中的旧版本。
|
||||
- 版本相同或用户版本更高时,跳过。
|
||||
- 内置 SKILL.md 无 version 字段(视为 0)时,不覆盖。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bundled_dir : Path
|
||||
项目内置技能目录(如 ``ROOT_PATH / "skills"``)。
|
||||
target_dir : Path
|
||||
用户配置技能目录(如 ``CONFIG_PATH / "agent" / "skills"``)。
|
||||
"""
|
||||
if not bundled_dir.is_dir():
|
||||
return
|
||||
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for skill_src in bundled_dir.iterdir():
|
||||
if not skill_src.is_dir():
|
||||
continue
|
||||
skill_md = skill_src / "SKILL.md"
|
||||
if not skill_md.is_file():
|
||||
continue
|
||||
|
||||
skill_dst = target_dir / skill_src.name
|
||||
|
||||
if not skill_dst.exists():
|
||||
# 目标不存在,直接复制
|
||||
try:
|
||||
shutil.copytree(str(skill_src), str(skill_dst))
|
||||
logger.info(
|
||||
"已自动复制内置技能 '%s' -> '%s'", skill_src.name, skill_dst
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("复制内置技能 '%s' 失败: %s", skill_src.name, e)
|
||||
continue
|
||||
|
||||
# 目标已存在,比较版本号
|
||||
bundled_version = _extract_version(skill_md)
|
||||
if bundled_version <= 0:
|
||||
# 内置技能无版本号,保持旧逻辑不覆盖
|
||||
continue
|
||||
|
||||
user_skill_md = skill_dst / "SKILL.md"
|
||||
user_version = _extract_version(user_skill_md) if user_skill_md.is_file() else 0
|
||||
|
||||
if bundled_version <= user_version:
|
||||
# 用户版本 >= 内置版本,跳过
|
||||
continue
|
||||
|
||||
# 内置版本更高,删除旧版本后覆盖
|
||||
try:
|
||||
shutil.rmtree(str(skill_dst))
|
||||
shutil.copytree(str(skill_src), str(skill_dst))
|
||||
logger.info(
|
||||
"已更新内置技能 '%s' (v%d -> v%d)",
|
||||
skill_src.name,
|
||||
user_version,
|
||||
bundled_version,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("更新内置技能 '%s' 失败: %s", skill_src.name, e)
|
||||
|
||||
|
||||
class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # noqa
|
||||
"""加载并向系统提示词注入 Agent Skill 的中间件。
|
||||
|
||||
按源顺序加载 Skill,后加载的会覆盖重名的。
|
||||
启动时自动将项目内置技能(bundled_skills_dir)同步到用户技能目录。
|
||||
"""
|
||||
|
||||
state_schema = SkillsState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sources: list[str],
|
||||
bundled_skills_dir: str | None = None,
|
||||
) -> None:
|
||||
"""初始化 Skill 中间件。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sources : list[str]
|
||||
用户技能目录列表。
|
||||
bundled_skills_dir : str | None
|
||||
项目内置技能目录路径。若提供,在首次加载前会将其中不存在于
|
||||
sources 首个目录的技能自动复制过去。
|
||||
"""
|
||||
self.sources = sources
|
||||
self.bundled_skills_dir = bundled_skills_dir
|
||||
self.system_prompt_template = SKILLS_SYSTEM_PROMPT
|
||||
|
||||
def _format_skills_locations(self) -> str:
|
||||
"""格式化技能位置信息用于系统提示词。"""
|
||||
locations = []
|
||||
|
||||
for i, source_path in enumerate(self.sources):
|
||||
suffix = " (higher priority)" if i == len(self.sources) - 1 else ""
|
||||
locations.append(f"**MoviePilot Skills**: `{source_path}`{suffix}")
|
||||
|
||||
return "\n".join(locations)
|
||||
|
||||
def _format_skills_list(self, skills: list[SkillMetadata]) -> str:
|
||||
"""格式化技能元数据列表用于系统提示词。"""
|
||||
if not skills:
|
||||
paths = [f"{source_path}" for source_path in self.sources]
|
||||
return f"(No skills available yet. You can create skills in {' or '.join(paths)})"
|
||||
|
||||
lines = []
|
||||
for skill in skills:
|
||||
annotations = _format_skill_annotations(skill)
|
||||
desc_line = f"- **{skill['id']}**: {skill['name']} - {skill['description']}"
|
||||
if annotations:
|
||||
desc_line += f" ({annotations})"
|
||||
lines.append(desc_line)
|
||||
if skill["allowed_tools"]:
|
||||
lines.append(f" -> Allowed tools: {', '.join(skill['allowed_tools'])}")
|
||||
lines.append(f" -> Read `{skill['path']}` for full instructions")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将技能文档注入模型请求的系统消息中。"""
|
||||
skills_metadata = request.state.get("skills_metadata", []) # noqa
|
||||
skills_locations = self._format_skills_locations()
|
||||
skills_list = self._format_skills_list(skills_metadata)
|
||||
|
||||
skills_section = self.system_prompt_template.format(
|
||||
skills_locations=skills_locations,
|
||||
skills_list=skills_list,
|
||||
)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, skills_section
|
||||
)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self, state: SkillsState, runtime: Runtime, config: RunnableConfig
|
||||
) -> SkillsStateUpdate | None: # ty: ignore[invalid-method-override]
|
||||
"""在 Agent 执行前异步加载技能元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
首次加载时,会先将内置技能同步到用户目录(如不存在)。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "skills_metadata" in state:
|
||||
return None
|
||||
|
||||
# 自动同步内置技能到首个用户技能目录
|
||||
if self.bundled_skills_dir and self.sources:
|
||||
bundled = Path(self.bundled_skills_dir)
|
||||
target = Path(self.sources[0])
|
||||
try:
|
||||
_sync_bundled_skills(bundled, target)
|
||||
except Exception as e:
|
||||
logger.warning("同步内置技能失败: %s", e)
|
||||
|
||||
all_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
# 遍历源按顺序加载技能,重名时后者覆盖前者
|
||||
for source_path in self.sources:
|
||||
skill_source_path = AsyncPath(source_path)
|
||||
if not await skill_source_path.exists():
|
||||
await skill_source_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
source_skills = await _alist_skills(skill_source_path)
|
||||
for skill in source_skills:
|
||||
all_skills[skill["name"]] = skill
|
||||
|
||||
skills = list(all_skills.values())
|
||||
return SkillsStateUpdate(skills_metadata=skills)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""在模型调用时注入技能文档。"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
|
||||
|
||||
__all__ = ["SkillMetadata", "SkillsMiddleware"]
|
||||
21
app/agent/middleware/utils.py
Normal file
21
app/agent/middleware/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from langchain_core.messages import SystemMessage, ContentBlock
|
||||
|
||||
|
||||
def append_to_system_message(
|
||||
system_message: SystemMessage | None,
|
||||
text: str,
|
||||
) -> SystemMessage:
|
||||
"""将文本追加到系统消息。
|
||||
|
||||
参数:
|
||||
system_message: 现有的系统消息或 None。
|
||||
text: 要添加到系统消息的文本。
|
||||
|
||||
返回:
|
||||
追加了文本的新 SystemMessage。
|
||||
"""
|
||||
new_content: list[ContentBlock] = list(system_message.content_blocks) if system_message else [] # noqa
|
||||
if new_content:
|
||||
text = f"\n\n{text}"
|
||||
new_content.append({"type": "text", "text": text})
|
||||
return SystemMessage(content_blocks=new_content)
|
||||
@@ -1,70 +1,62 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
You are an AI media assistant powered by MoviePilot. You specialize in managing home media ecosystems: searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
Core Capabilities:
|
||||
1. Media Search & Recognition — Identify movies, TV shows, and anime; recognize media from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management — Create rules for automated downloading; monitor trending content.
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
<communication>
|
||||
{verbose_spec}
|
||||
|
||||
## Working Principles
|
||||
- Tone: friendly, concise. Like a knowledgeable friend, not a corporate bot.
|
||||
- Use emojis sparingly (1-3 per response): greetings, completions, errors.
|
||||
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles/paths.
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
</communication>
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
<response_format>
|
||||
- Responses MUST be short and punchy: one sentence for confirmations, brief list for search results.
|
||||
- NO filler phrases like "Let me help you", "Here are the results", "I found..." — skip all unnecessary preamble.
|
||||
- NO repeating what user said.
|
||||
- NO narrating your internal reasoning.
|
||||
- After task completion: one line summary only.
|
||||
- When error occurs: brief acknowledgment + suggestion, then move on.
|
||||
</response_format>
|
||||
|
||||
## Common Operation Workflows
|
||||
<flow>
|
||||
1. Media Discovery: Identify exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (already in library? already subscribed?).
|
||||
3. Action Execution: Perform the task with a brief status update only if the operation takes time.
|
||||
4. Final Confirmation: State the result concisely.
|
||||
</flow>
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
<tool_calling_strategy>
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when all automated methods are exhausted.
|
||||
</tool_calling_strategy>
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
<media_management_rules>
|
||||
1. Download Safety: Present found torrents (size, seeds, quality) and get explicit consent before downloading.
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
</media_management_rules>
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
<system_info>
|
||||
{moviepilot_info}
|
||||
</system_info>
|
||||
|
||||
@@ -1,13 +1,25 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from time import strftime
|
||||
from typing import Dict
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import (
|
||||
ChannelCapability,
|
||||
ChannelCapabilities,
|
||||
MessageChannel,
|
||||
ChannelCapabilityManager,
|
||||
)
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
"""
|
||||
提示词管理器
|
||||
"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
@@ -17,22 +29,20 @@ class PromptManager:
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
"""
|
||||
加载指定的提示词
|
||||
"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
@@ -46,73 +56,122 @@ class PromptManager:
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
|
||||
# 识别渠道
|
||||
markdown_spec = ""
|
||||
msg_channel = (
|
||||
next(
|
||||
(c for c in MessageChannel if c.value.lower() == channel.lower()), None
|
||||
)
|
||||
if channel
|
||||
else None
|
||||
)
|
||||
# 获取渠道能力说明
|
||||
if msg_channel:
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
markdown_spec = self._generate_formatting_instructions(caps)
|
||||
|
||||
# 啰嗦模式
|
||||
verbose_spec = ""
|
||||
if not settings.AI_AGENT_VERBOSE:
|
||||
verbose_spec = (
|
||||
"\n\n[Important Instruction] STRICTLY ENFORCED: DO NOT output any conversational "
|
||||
"text, thinking processes, or explanations before or during tool calls. Call tools "
|
||||
"directly without any transitional phrases. "
|
||||
"You MUST remain completely silent until the task is completely finished. "
|
||||
"DO NOT output any content whatsoever until your final summary reply."
|
||||
)
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.format(
|
||||
markdown_spec=markdown_spec,
|
||||
verbose_spec=verbose_spec,
|
||||
moviepilot_info=moviepilot_info,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
def _get_moviepilot_info() -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
获取MoviePilot系统信息,用于注入到系统提示词中
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
# 获取主机名和IP地址
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
except Exception: # noqa
|
||||
hostname = "localhost"
|
||||
ip_address = "127.0.0.1"
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
# 配置文件和日志文件目录
|
||||
config_path = str(settings.CONFIG_PATH)
|
||||
log_path = str(settings.LOG_PATH)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
# API地址构建
|
||||
api_port = settings.PORT
|
||||
api_path = settings.API_V1_STR
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
# API令牌
|
||||
api_token = settings.API_TOKEN or "未设置"
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
# 数据库信息
|
||||
db_type = settings.DB_TYPE
|
||||
if db_type == "sqlite":
|
||||
db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
|
||||
else:
|
||||
db_password = settings.DB_POSTGRESQL_PASSWORD or ""
|
||||
db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
info_lines = [
|
||||
f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
f"- 运行环境: {SystemUtils.platform} {'docker' if SystemUtils.is_docker() else ''}",
|
||||
f"- 主机名: {hostname}",
|
||||
f"- IP地址: {ip_address}",
|
||||
f"- API端口: {api_port}",
|
||||
f"- API路径: {api_path}",
|
||||
f"- API令牌: {api_token}",
|
||||
f"- 外网域名: {settings.APP_DOMAIN or '未设置'}",
|
||||
f"- 数据库类型: {db_type}",
|
||||
f"- 数据库: {db_info}",
|
||||
f"- 配置文件目录: {config_path}",
|
||||
f"- 日志文件目录: {log_path}",
|
||||
f"- 系统安装目录: {settings.ROOT_PATH}",
|
||||
]
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
return "\n".join(info_lines)
|
||||
|
||||
@staticmethod
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
根据渠道能力动态生成格式指令
|
||||
"""
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
instructions.append(
|
||||
"- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown."
|
||||
)
|
||||
instructions.append(
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators)."
|
||||
)
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks."
|
||||
)
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler
|
||||
from app.agent import StreamingHandler
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
@@ -16,55 +20,104 @@ class ToolChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
"""
|
||||
MoviePilot专用工具基类(LangChain v1 / langchain_core)
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_channel: Optional[str] = PrivateAttr(default=None)
|
||||
_source: Optional[str] = PrivateAttr(default=None)
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
_require_admin: bool = PrivateAttr(default=False)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
self._require_admin = getattr(self.__class__, "require_admin", False)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送运行工具前的消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
# 发送执行工具说明
|
||||
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
异步运行工具,负责:
|
||||
1. 在工具调用前将流式消息推送给用户
|
||||
2. 持久化工具调用记录到会话记忆
|
||||
3. 调用具体工具逻辑(子类实现的 execute 方法)
|
||||
4. 持久化工具结果到会话记忆
|
||||
5. 权限检查
|
||||
"""
|
||||
|
||||
permission_result = await self._check_permission()
|
||||
if permission_result:
|
||||
return permission_result
|
||||
|
||||
# 获取工具执行提示消息
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
return result
|
||||
|
||||
# 发送工具执行过程消息
|
||||
if self._stream_handler and self._stream_handler.is_streaming:
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
if self._stream_handler.is_auto_flushing:
|
||||
# 渠道支持编辑:工具消息追加到 buffer,由定时刷新推送
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 非VERBOSE,重置缓冲区从头更新,保持消息编辑能力
|
||||
self._stream_handler.reset()
|
||||
else:
|
||||
# 未启用流式传输,不发送任何工具消息内容
|
||||
pass
|
||||
|
||||
logger.debug(f"Executing tool {self.name} with args: {kwargs}")
|
||||
|
||||
# 执行具体工具逻辑
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f"Tool {self.name} executed with result: {result}")
|
||||
except Exception as e:
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
|
||||
result = error_message
|
||||
|
||||
# 格式化结果
|
||||
if isinstance(result, str):
|
||||
formatted_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formatted_result = str(result)
|
||||
else:
|
||||
formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
return formatted_result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
获取工具执行时的友好提示消息
|
||||
|
||||
获取工具执行时的友好提示消息。
|
||||
|
||||
子类可以重写此方法,根据实际参数生成个性化的提示消息。
|
||||
如果返回 None 或空字符串,将回退使用 explanation 参数。
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: 工具的所有参数(包括 explanation)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
|
||||
"""
|
||||
@@ -72,20 +125,134 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
"""子类实现具体的工具执行逻辑"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
"""
|
||||
设置消息属性
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
self._callback_handler = callback_handler
|
||||
def set_stream_handler(self, stream_handler: StreamingHandler):
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
1. 首先检查工具是否需要管理员权限
|
||||
2. 如果需要管理员权限,则检查用户是否是渠道管理员
|
||||
3. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员
|
||||
4. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID
|
||||
5. 如果都不是,返回权限拒绝消息
|
||||
"""
|
||||
if not self._require_admin:
|
||||
return None
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return None
|
||||
|
||||
user_id_str = str(self._user_id) if self._user_id else None
|
||||
|
||||
channel_type_map = {
|
||||
MessageChannel.Telegram: "telegram",
|
||||
MessageChannel.Discord: "discord",
|
||||
MessageChannel.Wechat: "wechat",
|
||||
MessageChannel.Slack: "slack",
|
||||
MessageChannel.VoceChat: "vocechat",
|
||||
MessageChannel.SynologyChat: "synologychat",
|
||||
MessageChannel.QQ: "qqbot",
|
||||
}
|
||||
|
||||
channel_type = None
|
||||
for key, value in channel_type_map.items():
|
||||
if self._channel == key.value:
|
||||
channel_type = value
|
||||
break
|
||||
|
||||
if not channel_type:
|
||||
return None
|
||||
|
||||
admin_key_map = {
|
||||
"telegram": "TELEGRAM_ADMINS",
|
||||
"discord": "DISCORD_ADMINS",
|
||||
"wechat": "WECHAT_ADMINS",
|
||||
"slack": "SLACK_ADMINS",
|
||||
"vocechat": "VOCECHAT_ADMINS",
|
||||
"synologychat": "SYNOLOGYCHAT_ADMINS",
|
||||
"qqbot": "QQBOT_ADMINS",
|
||||
}
|
||||
|
||||
user_id_key_map = {
|
||||
"telegram": "TELEGRAM_CHAT_ID",
|
||||
"vocechat": "VOCECHAT_CHANNEL_ID",
|
||||
"wechat": "WECHAT_BOT_CHAT_ID",
|
||||
}
|
||||
|
||||
admin_key = admin_key_map.get(channel_type)
|
||||
user_id_key = user_id_key_map.get(channel_type)
|
||||
|
||||
try:
|
||||
configs = ServiceConfigHelper.get_notification_configs()
|
||||
for config in configs:
|
||||
if config.name == self._source and config.config:
|
||||
channel_admins = config.config.get(admin_key) if admin_key else None
|
||||
if channel_admins:
|
||||
admin_list = [
|
||||
aid.strip()
|
||||
for aid in str(channel_admins).split(",")
|
||||
if aid.strip()
|
||||
]
|
||||
if user_id_str and user_id_str in admin_list:
|
||||
return None
|
||||
|
||||
user = (
|
||||
UserOper().get_by_name(self._username)
|
||||
if self._username
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有渠道管理员或系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中,"
|
||||
"或联系系统管理员为您设置权限。"
|
||||
)
|
||||
else:
|
||||
user = (
|
||||
UserOper().get_by_name(self._username)
|
||||
if self._username
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
|
||||
if user_id_key:
|
||||
config_user_id = config.config.get(user_id_key)
|
||||
if config_user_id and str(config_user_id) == user_id_str:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系系统管理员为您设置权限。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"检查权限失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
@@ -93,6 +260,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
text=message,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
@@ -27,7 +25,9 @@ from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
|
||||
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
|
||||
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
|
||||
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
|
||||
from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
@@ -36,23 +36,46 @@ from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool
|
||||
from app.agent.tools.impl.delete_transfer_history import DeleteTransferHistoryTool
|
||||
from app.agent.tools.impl.modify_download import ModifyDownloadTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
|
||||
from app.agent.tools.impl.transfer_file import TransferFileTool
|
||||
from app.agent.tools.impl.execute_command import ExecuteCommandTool
|
||||
from app.agent.tools.impl.edit_file import EditFileTool
|
||||
from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
|
||||
from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool
|
||||
from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool
|
||||
from app.agent.tools.impl.run_slash_command import RunSlashCommandTool
|
||||
from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool
|
||||
from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
"""
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
@@ -61,10 +84,12 @@ class MoviePilotToolFactory:
|
||||
RecognizeMediaTool,
|
||||
ScrapeMetadataTool,
|
||||
QueryEpisodeScheduleTool,
|
||||
QueryMediaDetailTool,
|
||||
AddSubscribeTool,
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
GetSearchResultsTool,
|
||||
SearchWebTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
@@ -75,6 +100,9 @@ class MoviePilotToolFactory:
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
DeleteDownloadHistoryTool,
|
||||
DeleteTransferHistoryTool,
|
||||
ModifyDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
UpdateSiteTool,
|
||||
@@ -92,18 +120,26 @@ class MoviePilotToolFactory:
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool
|
||||
RunWorkflowTool,
|
||||
ExecuteCommandTool,
|
||||
EditFileTool,
|
||||
WriteFileTool,
|
||||
ReadFileTool,
|
||||
BrowseWebpageTool,
|
||||
QueryInstalledPluginsTool,
|
||||
QueryPluginCapabilitiesTool,
|
||||
RunSlashCommandTool,
|
||||
ListSlashCommandsTool,
|
||||
QueryCustomIdentifiersTool,
|
||||
UpdateCustomIdentifiersTool,
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tools.append(tool)
|
||||
|
||||
|
||||
# 加载插件提供的工具
|
||||
plugin_tools_count = 0
|
||||
plugin_tools_info = PluginManager().get_plugin_agent_tools()
|
||||
@@ -115,24 +151,31 @@ class MoviePilotToolFactory:
|
||||
try:
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过")
|
||||
logger.warning(
|
||||
f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过"
|
||||
)
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
logger.debug(
|
||||
f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
|
||||
|
||||
logger.error(
|
||||
f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}"
|
||||
)
|
||||
|
||||
builtin_tools_count = len(tool_definitions)
|
||||
if plugin_tools_count > 0:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
|
||||
logger.info(
|
||||
f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)"
|
||||
)
|
||||
else:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
|
||||
176
app/agent/tools/impl/_torrent_search_utils.py
Normal file
176
app/agent/tools/impl/_torrent_search_utils.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""种子搜索工具辅助函数"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.context import Context
|
||||
from app.utils.crypto import HashUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
SEARCH_RESULT_CACHE_FILE = "__search_result__"
|
||||
TORRENT_RESULT_LIMIT = 50
|
||||
|
||||
|
||||
def build_torrent_ref(context: Optional[Context]) -> str:
|
||||
"""生成用于下载校验的短引用"""
|
||||
if not context or not context.torrent_info:
|
||||
return ""
|
||||
return HashUtils.sha1(context.torrent_info.enclosure or "")[:7]
|
||||
|
||||
|
||||
def sort_season_options(options: List[str]) -> List[str]:
|
||||
"""按前端逻辑排序季集选项"""
|
||||
if len(options) <= 1:
|
||||
return options
|
||||
|
||||
parsed_options = []
|
||||
for index, option in enumerate(options):
|
||||
match = re.match(r"^S(\d+)(?:-S(\d+))?\s*(?:E(\d+)(?:-E(\d+))?)?$", option or "")
|
||||
if not match:
|
||||
parsed_options.append({
|
||||
"original": option,
|
||||
"season_num": 0,
|
||||
"episode_num": 0,
|
||||
"max_episode_num": 0,
|
||||
"is_whole_season": False,
|
||||
"index": index,
|
||||
})
|
||||
continue
|
||||
|
||||
episode_num = int(match.group(3)) if match.group(3) else 0
|
||||
max_episode_num = int(match.group(4)) if match.group(4) else episode_num
|
||||
parsed_options.append({
|
||||
"original": option,
|
||||
"season_num": int(match.group(1)),
|
||||
"episode_num": episode_num,
|
||||
"max_episode_num": max_episode_num,
|
||||
"is_whole_season": not match.group(3),
|
||||
"index": index,
|
||||
})
|
||||
|
||||
whole_seasons = [item for item in parsed_options if item["is_whole_season"]]
|
||||
episodes = [item for item in parsed_options if not item["is_whole_season"]]
|
||||
|
||||
whole_seasons.sort(key=lambda item: (-item["season_num"], item["index"]))
|
||||
episodes.sort(
|
||||
key=lambda item: (
|
||||
-item["season_num"],
|
||||
-(item["max_episode_num"] or item["episode_num"]),
|
||||
-item["episode_num"],
|
||||
item["index"],
|
||||
)
|
||||
)
|
||||
return [item["original"] for item in whole_seasons + episodes]
|
||||
|
||||
|
||||
def append_option(options: List[str], value: Optional[str]) -> None:
|
||||
"""按前端逻辑收集去重后的筛选项"""
|
||||
if value and value not in options:
|
||||
options.append(value)
|
||||
|
||||
|
||||
def build_filter_options(items: List[Context]) -> dict:
|
||||
"""从搜索结果中构建筛选项汇总"""
|
||||
filter_options = {
|
||||
"site": [],
|
||||
"season": [],
|
||||
"freeState": [],
|
||||
"edition": [],
|
||||
"resolution": [],
|
||||
"videoCode": [],
|
||||
"releaseGroup": [],
|
||||
}
|
||||
|
||||
for item in items:
|
||||
torrent_info = item.torrent_info
|
||||
meta_info = item.meta_info
|
||||
append_option(filter_options["site"], getattr(torrent_info, "site_name", None))
|
||||
append_option(filter_options["season"], getattr(meta_info, "season_episode", None))
|
||||
append_option(filter_options["freeState"], getattr(torrent_info, "volume_factor", None))
|
||||
append_option(filter_options["edition"], getattr(meta_info, "edition", None))
|
||||
append_option(filter_options["resolution"], getattr(meta_info, "resource_pix", None))
|
||||
append_option(filter_options["videoCode"], getattr(meta_info, "video_encode", None))
|
||||
append_option(filter_options["releaseGroup"], getattr(meta_info, "resource_team", None))
|
||||
|
||||
filter_options["season"] = sort_season_options(filter_options["season"])
|
||||
return filter_options
|
||||
|
||||
|
||||
def match_filter(filter_values: Optional[List[str]], value: Optional[str]) -> bool:
|
||||
"""匹配前端同款多选筛选规则"""
|
||||
return not filter_values or bool(value and value in filter_values)
|
||||
|
||||
|
||||
def filter_contexts(items: List[Context],
|
||||
site: Optional[List[str]] = None,
|
||||
season: Optional[List[str]] = None,
|
||||
free_state: Optional[List[str]] = None,
|
||||
video_code: Optional[List[str]] = None,
|
||||
edition: Optional[List[str]] = None,
|
||||
resolution: Optional[List[str]] = None,
|
||||
release_group: Optional[List[str]] = None) -> List[Context]:
|
||||
"""按前端同款维度筛选结果"""
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
torrent_info = item.torrent_info
|
||||
meta_info = item.meta_info
|
||||
if (
|
||||
match_filter(site, getattr(torrent_info, "site_name", None))
|
||||
and match_filter(free_state, getattr(torrent_info, "volume_factor", None))
|
||||
and match_filter(season, getattr(meta_info, "season_episode", None))
|
||||
and match_filter(release_group, getattr(meta_info, "resource_team", None))
|
||||
and match_filter(video_code, getattr(meta_info, "video_encode", None))
|
||||
and match_filter(resolution, getattr(meta_info, "resource_pix", None))
|
||||
and match_filter(edition, getattr(meta_info, "edition", None))
|
||||
):
|
||||
filtered_items.append(item)
|
||||
return filtered_items
|
||||
|
||||
|
||||
def simplify_search_result(context: Context, index: int) -> dict:
|
||||
"""精简单条搜索结果"""
|
||||
simplified = {}
|
||||
torrent_info = context.torrent_info
|
||||
meta_info = context.meta_info
|
||||
media_info = context.media_info
|
||||
|
||||
if torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": torrent_info.title,
|
||||
"size": StringUtils.format_size(torrent_info.size),
|
||||
"seeders": torrent_info.seeders,
|
||||
"peers": torrent_info.peers,
|
||||
"site_name": torrent_info.site_name,
|
||||
"torrent_url": f"{build_torrent_ref(context)}:{index}",
|
||||
"page_url": torrent_info.page_url,
|
||||
"volume_factor": torrent_info.volume_factor,
|
||||
"freedate_diff": torrent_info.freedate_diff,
|
||||
"pubdate": torrent_info.pubdate,
|
||||
}
|
||||
|
||||
if media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": media_info.title,
|
||||
"en_title": media_info.en_title,
|
||||
"year": media_info.year,
|
||||
"type": media_info.type.value if media_info.type else None,
|
||||
"season": media_info.season,
|
||||
"tmdb_id": media_info.tmdb_id,
|
||||
}
|
||||
|
||||
if meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": meta_info.name,
|
||||
"cn_name": meta_info.cn_name,
|
||||
"en_name": meta_info.en_name,
|
||||
"year": meta_info.year,
|
||||
"type": meta_info.type.value if meta_info.type else None,
|
||||
"begin_season": meta_info.begin_season,
|
||||
"season_episode": meta_info.season_episode,
|
||||
"resource_team": meta_info.resource_team,
|
||||
"video_encode": meta_info.video_encode,
|
||||
"edition": meta_info.edition,
|
||||
"resource_pix": meta_info.resource_pix,
|
||||
}
|
||||
|
||||
return simplified
|
||||
@@ -1,27 +1,31 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
from app.schemas import TorrentInfo, FileURI
|
||||
from app.utils.crypto import HashUtils
|
||||
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
|
||||
torrent_title: str = Field(...,
|
||||
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
|
||||
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
|
||||
torrent_description: Optional[str] = Field(None,
|
||||
description="Brief description of the torrent content (optional)")
|
||||
torrent_url: List[str] = Field(
|
||||
...,
|
||||
description="One or more torrent_url values. Supports refs from get_search_results (`hash:id`) and magnet links."
|
||||
)
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
@@ -32,75 +36,242 @@ class AddDownloadInput(BaseModel):
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
description: str = "Add torrent download tasks using refs from get_search_results or magnet links."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据下载参数生成友好的提示消息"""
|
||||
torrent_title = kwargs.get("torrent_title", "")
|
||||
site_name = kwargs.get("site_name", "")
|
||||
torrent_urls = self._normalize_torrent_urls(kwargs.get("torrent_url"))
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
message = f"正在添加下载任务: {torrent_title}"
|
||||
if site_name:
|
||||
message += f" (来源: {site_name})"
|
||||
|
||||
if torrent_urls:
|
||||
if len(torrent_urls) == 1:
|
||||
if self._is_torrent_ref(torrent_urls[0]):
|
||||
message = f"正在添加下载任务: 资源 {torrent_urls[0]}"
|
||||
else:
|
||||
message = "正在添加下载任务: 磁力链接"
|
||||
else:
|
||||
message = f"正在批量添加下载任务: 共 {len(torrent_urls)} 个资源"
|
||||
else:
|
||||
message = "正在添加下载任务"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
@staticmethod
|
||||
def _build_torrent_ref(context: Context) -> str:
|
||||
"""生成用于校验缓存项的短引用"""
|
||||
if not context or not context.torrent_info:
|
||||
return ""
|
||||
return HashUtils.sha1(context.torrent_info.enclosure or "")[:7]
|
||||
|
||||
@staticmethod
|
||||
def _is_torrent_ref(torrent_ref: Optional[str]) -> bool:
|
||||
"""判断是否为内部搜索结果引用"""
|
||||
if not torrent_ref:
|
||||
return False
|
||||
return bool(re.fullmatch(r"[0-9a-f]{7}:\d+", str(torrent_ref).strip()))
|
||||
|
||||
@staticmethod
|
||||
def _is_magnet_link_input(torrent_url: Optional[str]) -> bool:
|
||||
"""判断输入是否为允许直接添加的磁力链接"""
|
||||
if not torrent_url:
|
||||
return False
|
||||
value = str(torrent_url).strip()
|
||||
return value.startswith("magnet:")
|
||||
|
||||
@classmethod
|
||||
def _resolve_cached_context(cls, torrent_ref: str) -> Optional[Context]:
|
||||
"""从最近一次搜索缓存中解析种子上下文,仅支持 hash:id 格式"""
|
||||
ref = str(torrent_ref).strip()
|
||||
if ":" not in ref:
|
||||
return None
|
||||
try:
|
||||
ref_hash, ref_index = ref.split(":", 1)
|
||||
index = int(ref_index)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if index < 1:
|
||||
return None
|
||||
|
||||
results = SearchChain().last_search_results() or []
|
||||
if index > len(results):
|
||||
return None
|
||||
context = results[index - 1]
|
||||
if not ref_hash or cls._build_torrent_ref(context) != ref_hash:
|
||||
return None
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def _merge_labels_with_system_tag(labels: Optional[str]) -> Optional[str]:
|
||||
"""合并用户标签与系统默认标签,确保任务可被系统管理"""
|
||||
system_tag = (settings.TORRENT_TAG or "").strip()
|
||||
user_labels = [item.strip() for item in (labels or "").split(",") if item.strip()]
|
||||
|
||||
if system_tag and system_tag not in user_labels:
|
||||
user_labels.append(system_tag)
|
||||
|
||||
return ",".join(user_labels) if user_labels else None
|
||||
|
||||
@staticmethod
|
||||
def _format_failed_result(failed_messages: List[str]) -> str:
|
||||
"""统一格式化失败结果"""
|
||||
return ", ".join([message for message in failed_messages if message])
|
||||
|
||||
@staticmethod
|
||||
def _build_failure_message(torrent_ref: str, error_msg: Optional[str] = None) -> str:
|
||||
"""构造失败提示"""
|
||||
normalized_error = (error_msg or "").strip()
|
||||
prefix = "添加种子任务失败:"
|
||||
if normalized_error.startswith(prefix):
|
||||
normalized_error = normalized_error[len(prefix):].lstrip()
|
||||
if AddDownloadTool._is_magnet_link_input(normalized_error):
|
||||
normalized_error = ""
|
||||
if normalized_error:
|
||||
return f"{torrent_ref} {normalized_error}"
|
||||
if AddDownloadTool._is_torrent_ref(torrent_ref):
|
||||
return torrent_ref
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _normalize_torrent_urls(cls, torrent_url: Optional[List[str] | str]) -> List[str]:
|
||||
"""统一规范 torrent_url 输入,保留所有非空值"""
|
||||
if torrent_url is None:
|
||||
return []
|
||||
|
||||
if isinstance(torrent_url, str):
|
||||
candidates = torrent_url.split(",")
|
||||
else:
|
||||
candidates = torrent_url
|
||||
|
||||
return [str(item).strip() for item in candidates if item and str(item).strip()]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_direct_download_dir(save_path: Optional[str]) -> Optional[Path]:
|
||||
"""解析直接下载使用的目录,优先使用 save_path,其次使用默认下载目录"""
|
||||
if save_path:
|
||||
return Path(save_path)
|
||||
|
||||
download_dirs = DirectoryHelper().get_download_dirs()
|
||||
if not download_dirs:
|
||||
return None
|
||||
|
||||
dir_conf = download_dirs[0]
|
||||
if not dir_conf.download_path:
|
||||
return None
|
||||
|
||||
return Path(FileURI(storage=dir_conf.storage or "local", path=dir_conf.download_path).uri)
|
||||
|
||||
async def run(self, torrent_url: Optional[List[str]] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
f"执行工具: {self.name}, 参数: torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
return "错误:必须提供种子标题和下载链接"
|
||||
torrent_inputs = self._normalize_torrent_urls(torrent_url)
|
||||
if not torrent_inputs:
|
||||
return "错误:torrent_url 不能为空。"
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
merged_labels = self._merge_labels_with_system_tag(labels)
|
||||
success_count = 0
|
||||
failed_messages = []
|
||||
|
||||
# 根据站点名称查询站点cookie
|
||||
if not site_name:
|
||||
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
return f"错误:未找到站点信息:{site_name}"
|
||||
for torrent_input in torrent_inputs:
|
||||
if self._is_torrent_ref(torrent_input):
|
||||
cached_context = self._resolve_cached_context(torrent_input)
|
||||
if not cached_context or not cached_context.torrent_info:
|
||||
failed_messages.append(f"{torrent_input} 引用无效,请重新使用 get_search_results 查看搜索结果")
|
||||
continue
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=torrent_url,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
return "错误:无法识别媒体信息,无法添加下载任务"
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
cached_torrent = cached_context.torrent_info
|
||||
site_name = cached_torrent.site_name
|
||||
torrent_title = cached_torrent.title or torrent_input
|
||||
torrent_description = cached_torrent.description
|
||||
enclosure = cached_torrent.enclosure
|
||||
|
||||
did = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=labels
|
||||
)
|
||||
if did:
|
||||
return f"成功添加下载任务:{torrent_title}"
|
||||
else:
|
||||
return "添加下载任务失败"
|
||||
if not site_name:
|
||||
failed_messages.append(f"{torrent_input} 缺少站点名称")
|
||||
continue
|
||||
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
failed_messages.append(f"{torrent_input} 未找到站点信息 {site_name}")
|
||||
continue
|
||||
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=enclosure,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = cached_context.media_info if cached_context.media_info else None
|
||||
if not media_info:
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
failed_messages.append(f"{torrent_input} 无法识别媒体信息")
|
||||
continue
|
||||
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
else:
|
||||
if not self._is_magnet_link_input(torrent_input):
|
||||
failed_messages.append(
|
||||
f"{torrent_input} 不是有效的下载内容,非 hash:id 时仅支持 magnet: 开头"
|
||||
)
|
||||
continue
|
||||
download_dir = self._resolve_direct_download_dir(save_path)
|
||||
if not download_dir:
|
||||
failed_messages.append(f"{torrent_input} 缺少保存路径,且系统未配置可用下载目录")
|
||||
continue
|
||||
result = download_chain.download(
|
||||
content=torrent_input,
|
||||
download_dir=download_dir,
|
||||
cookie=None,
|
||||
label=merged_labels,
|
||||
downloader=downloader
|
||||
)
|
||||
if result:
|
||||
_, did, _, error_msg = result
|
||||
else:
|
||||
did, error_msg = None, "未找到下载器"
|
||||
if did:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_messages.append(self._build_failure_message(torrent_input, error_msg))
|
||||
continue
|
||||
|
||||
did, error_msg = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=merged_labels,
|
||||
return_detail=True
|
||||
)
|
||||
if did:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_messages.append(self._build_failure_message(torrent_input, error_msg))
|
||||
|
||||
if success_count and not failed_messages:
|
||||
return "任务添加成功"
|
||||
|
||||
if success_count:
|
||||
return f"部分任务添加失败:{self._format_failed_result(failed_messages)}"
|
||||
|
||||
return f"任务添加失败:{self._format_failed_result(failed_messages)}"
|
||||
except Exception as e:
|
||||
logger.error(f"添加下载任务失败: {e}", exc_info=True)
|
||||
return f"添加下载任务时发生错误: {str(e)}"
|
||||
|
||||
@@ -16,11 +16,13 @@ class AddSubscribeInput(BaseModel):
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
description="Allowed values: movie, tv")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
tmdb_id: Optional[int] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)")
|
||||
douban_id: Optional[str] = Field(None,
|
||||
description="Douban ID for precise media identification (optional, alternative to tmdb_id)")
|
||||
start_episode: Optional[int] = Field(None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
|
||||
total_episode: Optional[int] = Field(None,
|
||||
@@ -32,9 +34,9 @@ class AddSubscribeInput(BaseModel):
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
|
||||
description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)")
|
||||
description="List of site IDs to search from (optional, can be obtained from query_sites tool)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
@@ -60,26 +62,23 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
season: Optional[int] = None, tmdb_id: Optional[int] = None,
|
||||
douban_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, douban_id={douban_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
# 转换 tmdb_id 为整数
|
||||
tmdbid_int = None
|
||||
if tmdb_id:
|
||||
try:
|
||||
tmdbid_int = int(tmdb_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
@@ -99,15 +98,19 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
subscribe_kwargs['sites'] = sites
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
mtype=media_type_enum,
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
**subscribe_kwargs
|
||||
)
|
||||
if sid:
|
||||
if message and "已存在" in message:
|
||||
return f"订阅已存在:{title} ({year})。如需修改参数请先删除旧订阅。"
|
||||
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
|
||||
539
app/agent/tools/impl/browse_webpage.py
Normal file
539
app/agent/tools/impl/browse_webpage.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""浏览器操作工具 - 让Agent能够通过Playwright控制浏览器进行网页交互"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
# 页面内容最大长度
|
||||
MAX_CONTENT_LENGTH = 8000
|
||||
# 默认超时时间(秒)
|
||||
DEFAULT_TIMEOUT = 30
|
||||
# 截图最大宽度
|
||||
SCREENSHOT_MAX_WIDTH = 1280
|
||||
# 截图最大高度
|
||||
SCREENSHOT_MAX_HEIGHT = 720
|
||||
|
||||
|
||||
class BrowserAction(str, Enum):
|
||||
"""浏览器操作类型"""
|
||||
|
||||
GOTO = "goto"
|
||||
GET_CONTENT = "get_content"
|
||||
SCREENSHOT = "screenshot"
|
||||
CLICK = "click"
|
||||
FILL = "fill"
|
||||
SELECT = "select"
|
||||
EVALUATE = "evaluate"
|
||||
WAIT = "wait"
|
||||
|
||||
|
||||
class BrowseWebpageInput(BaseModel):
|
||||
"""浏览器操作工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this browser action is being performed",
|
||||
)
|
||||
action: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"The browser action to perform. Available actions:\n"
|
||||
"- 'goto': Navigate to a URL, returns page title and text summary\n"
|
||||
"- 'get_content': Get current page content (text or HTML)\n"
|
||||
"- 'screenshot': Take a screenshot of the current page, returns base64 image\n"
|
||||
"- 'click': Click on an element specified by selector\n"
|
||||
"- 'fill': Fill text into an input element specified by selector\n"
|
||||
"- 'select': Select an option from a dropdown element\n"
|
||||
"- 'evaluate': Execute JavaScript code on the page and return the result\n"
|
||||
"- 'wait': Wait for an element to appear on the page"
|
||||
),
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
None, description="URL to navigate to (required for 'goto' action)"
|
||||
)
|
||||
selector: Optional[str] = Field(
|
||||
None,
|
||||
description="CSS selector or text selector for the target element (for 'click', 'fill', 'select', 'wait' actions). "
|
||||
"Supports CSS selectors like '#id', '.class', 'tag', and Playwright text selectors like 'text=Click me'",
|
||||
)
|
||||
value: Optional[str] = Field(
|
||||
None,
|
||||
description="Value to fill into input or option value to select (for 'fill' and 'select' actions)",
|
||||
)
|
||||
script: Optional[str] = Field(
|
||||
None,
|
||||
description="JavaScript code to execute on the page (for 'evaluate' action). "
|
||||
"The script should return a value that can be serialized to JSON.",
|
||||
)
|
||||
content_type: Optional[str] = Field(
|
||||
"text",
|
||||
description="Content type for 'get_content' action: 'text' for readable text, 'html' for raw HTML",
|
||||
)
|
||||
timeout: Optional[int] = Field(
|
||||
DEFAULT_TIMEOUT, description="Timeout in seconds for the action (default: 30)"
|
||||
)
|
||||
cookies: Optional[str] = Field(
|
||||
None,
|
||||
description="Cookies to set for the browser context, format: 'name1=value1; name2=value2'",
|
||||
)
|
||||
user_agent: Optional[str] = Field(
|
||||
None, description="Custom User-Agent string for the browser context"
|
||||
)
|
||||
|
||||
|
||||
class BrowseWebpageTool(MoviePilotTool):
|
||||
name: str = "browse_webpage"
|
||||
description: str = (
|
||||
"Control a real browser (Playwright) to interact with web pages. "
|
||||
"Supports navigating to URLs, reading page content, taking screenshots, "
|
||||
"clicking elements, filling forms, selecting dropdown options, executing JavaScript, and waiting for elements. "
|
||||
"Use this tool when you need to interact with dynamic web pages, "
|
||||
"fill in forms, click buttons, or extract content from JavaScript-rendered pages. "
|
||||
"The browser session persists across multiple calls within the same conversation - "
|
||||
"first call 'goto' to open a page, then use other actions to interact with it."
|
||||
)
|
||||
args_schema: Type[BaseModel] = BrowseWebpageInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据操作类型生成友好的提示消息"""
|
||||
action = kwargs.get("action", "")
|
||||
url = kwargs.get("url", "")
|
||||
selector = kwargs.get("selector", "")
|
||||
action_messages = {
|
||||
"goto": f"正在打开网页: {url}",
|
||||
"get_content": "正在获取页面内容",
|
||||
"screenshot": "正在截取页面截图",
|
||||
"click": f"正在点击元素: {selector}",
|
||||
"fill": f"正在填写表单: {selector}",
|
||||
"select": f"正在选择选项: {selector}",
|
||||
"evaluate": "正在执行 JavaScript",
|
||||
"wait": f"正在等待元素: {selector}",
|
||||
}
|
||||
return action_messages.get(action, f"正在执行浏览器操作: {action}")
|
||||
|
||||
async def run(
|
||||
self,
|
||||
action: str,
|
||||
url: Optional[str] = None,
|
||||
selector: Optional[str] = None,
|
||||
value: Optional[str] = None,
|
||||
script: Optional[str] = None,
|
||||
content_type: Optional[str] = "text",
|
||||
timeout: Optional[int] = DEFAULT_TIMEOUT,
|
||||
cookies: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""执行浏览器操作"""
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 动作: {action}, URL: {url}, 选择器: {selector}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证操作类型
|
||||
try:
|
||||
browser_action = BrowserAction(action)
|
||||
except ValueError:
|
||||
valid_actions = ", ".join([a.value for a in BrowserAction])
|
||||
return f"错误: 不支持的操作类型 '{action}',支持的操作: {valid_actions}"
|
||||
|
||||
# 参数校验
|
||||
if browser_action == BrowserAction.GOTO and not url:
|
||||
return "错误: 'goto' 操作需要提供 url 参数"
|
||||
if (
|
||||
browser_action
|
||||
in (
|
||||
BrowserAction.CLICK,
|
||||
BrowserAction.FILL,
|
||||
BrowserAction.SELECT,
|
||||
BrowserAction.WAIT,
|
||||
)
|
||||
and not selector
|
||||
):
|
||||
return f"错误: '{action}' 操作需要提供 selector 参数"
|
||||
if browser_action == BrowserAction.FILL and value is None:
|
||||
return "错误: 'fill' 操作需要提供 value 参数"
|
||||
if browser_action == BrowserAction.EVALUATE and not script:
|
||||
return "错误: 'evaluate' 操作需要提供 script 参数"
|
||||
|
||||
# 在线程池中运行同步的 Playwright 操作
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._execute_browser_action(
|
||||
browser_action=browser_action,
|
||||
url=url,
|
||||
selector=selector,
|
||||
value=value,
|
||||
script=script,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
cookies=cookies,
|
||||
user_agent=user_agent,
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"浏览器操作失败: {e}", exc_info=True)
|
||||
return f"浏览器操作失败: {str(e)}"
|
||||
|
||||
def _execute_browser_action(
|
||||
self,
|
||||
browser_action: BrowserAction,
|
||||
url: Optional[str],
|
||||
selector: Optional[str],
|
||||
value: Optional[str],
|
||||
script: Optional[str],
|
||||
content_type: Optional[str],
|
||||
timeout: int,
|
||||
cookies: Optional[str],
|
||||
user_agent: Optional[str],
|
||||
) -> str:
|
||||
"""在同步上下文中执行 Playwright 浏览器操作"""
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
browser = None
|
||||
context = None
|
||||
page = None
|
||||
try:
|
||||
# 启动浏览器
|
||||
browser_type = settings.PLAYWRIGHT_BROWSER_TYPE or "chromium"
|
||||
browser = playwright[browser_type].launch(headless=True)
|
||||
|
||||
# 创建上下文
|
||||
context_kwargs = {}
|
||||
if user_agent:
|
||||
context_kwargs["user_agent"] = user_agent
|
||||
# 设置视口大小
|
||||
context_kwargs["viewport"] = {
|
||||
"width": SCREENSHOT_MAX_WIDTH,
|
||||
"height": SCREENSHOT_MAX_HEIGHT,
|
||||
}
|
||||
|
||||
context = browser.new_context(**context_kwargs)
|
||||
page = context.new_page()
|
||||
page.set_default_timeout(timeout * 1000)
|
||||
|
||||
# 设置 cookies
|
||||
if cookies:
|
||||
page.set_extra_http_headers({"cookie": cookies})
|
||||
|
||||
# 对于非 goto 操作,如果提供了 url 先导航
|
||||
if url and browser_action != BrowserAction.GOTO:
|
||||
page.goto(
|
||||
url, wait_until="domcontentloaded", timeout=timeout * 1000
|
||||
)
|
||||
page.wait_for_load_state("networkidle", timeout=timeout * 1000)
|
||||
|
||||
# 执行具体操作
|
||||
result = self._do_action(
|
||||
page,
|
||||
browser_action,
|
||||
url,
|
||||
selector,
|
||||
value,
|
||||
script,
|
||||
content_type,
|
||||
timeout,
|
||||
)
|
||||
return result
|
||||
|
||||
finally:
|
||||
if page:
|
||||
page.close()
|
||||
if context:
|
||||
context.close()
|
||||
if browser:
|
||||
browser.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Playwright 执行失败: {e}", exc_info=True)
|
||||
return f"Playwright 执行失败: {str(e)}"
|
||||
|
||||
def _do_action(
|
||||
self,
|
||||
page,
|
||||
browser_action: BrowserAction,
|
||||
url: Optional[str],
|
||||
selector: Optional[str],
|
||||
value: Optional[str],
|
||||
script: Optional[str],
|
||||
content_type: Optional[str],
|
||||
timeout: int,
|
||||
) -> str:
|
||||
"""执行具体的浏览器操作"""
|
||||
|
||||
if browser_action == BrowserAction.GOTO:
|
||||
return self._action_goto(page, url, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.GET_CONTENT:
|
||||
return self._action_get_content(page, content_type)
|
||||
|
||||
elif browser_action == BrowserAction.SCREENSHOT:
|
||||
return self._action_screenshot(page)
|
||||
|
||||
elif browser_action == BrowserAction.CLICK:
|
||||
return self._action_click(page, selector, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.FILL:
|
||||
return self._action_fill(page, selector, value, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.SELECT:
|
||||
return self._action_select(page, selector, value, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.EVALUATE:
|
||||
return self._action_evaluate(page, script)
|
||||
|
||||
elif browser_action == BrowserAction.WAIT:
|
||||
return self._action_wait(page, selector, timeout)
|
||||
|
||||
return f"未知操作: {browser_action}"
|
||||
|
||||
@staticmethod
|
||||
def _action_goto(page, url: str, timeout: int) -> str:
|
||||
"""导航到URL"""
|
||||
response = page.goto(url, wait_until="domcontentloaded", timeout=timeout * 1000)
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=min(timeout, 15) * 1000)
|
||||
except Exception:
|
||||
# networkidle 超时不是致命错误,页面可能已经可用
|
||||
pass
|
||||
|
||||
status = response.status if response else "unknown"
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
# 提取页面可读文本摘要
|
||||
text_content = page.inner_text("body")
|
||||
if text_content and len(text_content) > MAX_CONTENT_LENGTH:
|
||||
text_content = text_content[:MAX_CONTENT_LENGTH] + "\n\n...(内容已截断)"
|
||||
|
||||
# 提取页面链接
|
||||
links = page.evaluate("""
|
||||
() => {
|
||||
const links = [];
|
||||
document.querySelectorAll('a[href]').forEach(a => {
|
||||
const text = a.innerText.trim();
|
||||
const href = a.href;
|
||||
if (text && href && !href.startsWith('javascript:')) {
|
||||
links.push({text: text.substring(0, 80), href: href});
|
||||
}
|
||||
});
|
||||
return links.slice(0, 30);
|
||||
}
|
||||
""")
|
||||
|
||||
# 提取表单信息
|
||||
forms = page.evaluate("""
|
||||
() => {
|
||||
const forms = [];
|
||||
document.querySelectorAll('input, textarea, select, button').forEach(el => {
|
||||
const info = {
|
||||
tag: el.tagName.toLowerCase(),
|
||||
type: el.type || '',
|
||||
name: el.name || '',
|
||||
id: el.id || '',
|
||||
placeholder: el.placeholder || '',
|
||||
value: el.tagName.toLowerCase() === 'select' ? '' : (el.value || '').substring(0, 50),
|
||||
text: el.innerText ? el.innerText.trim().substring(0, 50) : ''
|
||||
};
|
||||
// 只保留有标识信息的元素
|
||||
if (info.name || info.id || info.placeholder || info.text) {
|
||||
forms.push(info);
|
||||
}
|
||||
});
|
||||
return forms.slice(0, 30);
|
||||
}
|
||||
""")
|
||||
|
||||
result = {
|
||||
"status": status,
|
||||
"url": page_url,
|
||||
"title": title,
|
||||
"text_content": text_content,
|
||||
}
|
||||
if links:
|
||||
result["links"] = links
|
||||
if forms:
|
||||
result["form_elements"] = forms
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _action_get_content(page, content_type: Optional[str]) -> str:
|
||||
"""获取页面内容"""
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
if content_type == "html":
|
||||
content = page.content()
|
||||
else:
|
||||
content = page.inner_text("body")
|
||||
|
||||
if content and len(content) > MAX_CONTENT_LENGTH:
|
||||
content = content[:MAX_CONTENT_LENGTH] + "\n\n...(内容已截断)"
|
||||
|
||||
result = {
|
||||
"url": page_url,
|
||||
"title": title,
|
||||
"content_type": content_type,
|
||||
"content": content,
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _action_screenshot(page) -> str:
|
||||
"""截取页面截图"""
|
||||
screenshot_bytes = page.screenshot(
|
||||
full_page=False,
|
||||
type="jpeg",
|
||||
quality=60,
|
||||
)
|
||||
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
# 限制截图大小(base64编码后大约增大33%)
|
||||
max_b64_size = 200 * 1024 # ~150KB 原始图片
|
||||
if len(screenshot_b64) > max_b64_size:
|
||||
# 降低质量重新截图
|
||||
screenshot_bytes = page.screenshot(
|
||||
full_page=False,
|
||||
type="jpeg",
|
||||
quality=30,
|
||||
)
|
||||
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
result = {
|
||||
"url": page_url,
|
||||
"title": title,
|
||||
"screenshot_base64": screenshot_b64,
|
||||
"format": "jpeg",
|
||||
"note": "截图已以 base64 编码返回",
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _action_click(page, selector: str, timeout: int) -> str:
|
||||
"""点击元素"""
|
||||
page.click(selector, timeout=timeout * 1000)
|
||||
|
||||
# 等待可能的页面变化
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=5000)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功点击元素: {selector}",
|
||||
"current_url": page_url,
|
||||
"current_title": title,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_fill(page, selector: str, value: str, timeout: int) -> str:
|
||||
"""填写表单"""
|
||||
page.fill(selector, value, timeout=timeout * 1000)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功填写元素 '{selector}' 的值为 '{value}'",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_select(page, selector: str, value: Optional[str], timeout: int) -> str:
|
||||
"""选择下拉选项"""
|
||||
if value:
|
||||
page.select_option(selector, value=value, timeout=timeout * 1000)
|
||||
else:
|
||||
return "错误: 'select' 操作需要提供 value 参数"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功选择元素 '{selector}' 的选项 '{value}'",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_evaluate(page, script: str) -> str:
|
||||
"""执行 JavaScript"""
|
||||
result = page.evaluate(script)
|
||||
|
||||
# 格式化结果
|
||||
if result is None:
|
||||
formatted = "null"
|
||||
elif isinstance(result, (dict, list)):
|
||||
formatted = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
formatted = str(result)
|
||||
|
||||
# 限制结果长度
|
||||
if len(formatted) > MAX_CONTENT_LENGTH:
|
||||
formatted = formatted[:MAX_CONTENT_LENGTH] + "\n\n...(结果已截断)"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"result": formatted,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_wait(page, selector: str, timeout: int) -> str:
|
||||
"""等待元素出现"""
|
||||
element = page.wait_for_selector(selector, timeout=timeout * 1000)
|
||||
|
||||
if element:
|
||||
visible = element.is_visible()
|
||||
text = element.inner_text()
|
||||
if text and len(text) > 200:
|
||||
text = text[:200] + "..."
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"元素 '{selector}' 已出现",
|
||||
"visible": visible,
|
||||
"text": text,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"等待元素 '{selector}' 超时",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
@@ -11,66 +11,73 @@ from app.log import logger
|
||||
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
task_identifier: str = Field(..., description="Task identifier: can be task hash (unique identifier) or task title/name")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of specific downloader (optional, if not provided will search all downloaders)",
|
||||
)
|
||||
delete_files: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader. Can delete by task hash (unique identifier) or task title/name. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
task_identifier = kwargs.get("task_identifier", "")
|
||||
hash_value = kwargs.get("hash", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
message = f"正在删除下载任务: {task_identifier}"
|
||||
|
||||
message = f"正在删除下载任务: {hash_value}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
message += " (包含文件)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, task_identifier: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: task_identifier={task_identifier}, downloader={downloader}, delete_files={delete_files}")
|
||||
async def run(
|
||||
self,
|
||||
hash: str,
|
||||
downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}"
|
||||
)
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果task_identifier看起来像hash(通常是40个字符的十六进制字符串)
|
||||
task_hash = None
|
||||
if len(task_identifier) == 40 and all(c in '0123456789abcdefABCDEF' for c in task_identifier):
|
||||
# 直接使用hash
|
||||
task_hash = task_identifier
|
||||
else:
|
||||
# 通过标题查找任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
for dl in downloads:
|
||||
# 检查标题或名称是否匹配
|
||||
if (task_identifier.lower() in (dl.title or "").lower()) or \
|
||||
(task_identifier.lower() in (dl.name or "").lower()):
|
||||
task_hash = dl.hash
|
||||
break
|
||||
|
||||
if not task_hash:
|
||||
return f"未找到匹配的下载任务:{task_identifier},请使用 query_downloads 工具查询可用的下载任务"
|
||||
|
||||
|
||||
# 仅支持通过hash删除任务
|
||||
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[task_hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
result = download_chain.remove_torrents(
|
||||
hashs=[hash], downloader=downloader, delete_file=delete_files
|
||||
)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{task_identifier} {files_info}"
|
||||
return f"成功删除下载任务:{hash} {files_info}"
|
||||
else:
|
||||
return f"删除下载任务失败:{task_identifier},请检查任务是否存在或下载器是否可用"
|
||||
return f"删除下载任务失败:{hash},请检查任务是否存在或下载器是否可用"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
44
app/agent/tools/impl/delete_download_history.py
Normal file
44
app/agent/tools/impl/delete_download_history.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""删除下载历史记录工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteDownloadHistoryInput(BaseModel):
|
||||
"""删除下载历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the download history record to delete"
|
||||
)
|
||||
|
||||
|
||||
class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_download_history"
|
||||
description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除下载历史记录 ID: {history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
try:
|
||||
async with AsyncSessionFactory() as db:
|
||||
await DownloadHistory.async_delete(db, history_id)
|
||||
return f"下载历史记录 ID: {history_id} 已成功删除"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载历史记录失败: {e}", exc_info=True)
|
||||
return f"删除下载历史记录时发生错误: {str(e)}"
|
||||
@@ -14,14 +14,22 @@ from app.schemas.types import EventType
|
||||
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to delete (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
@@ -37,27 +45,25 @@ class DeleteSubscribeTool(MoviePilotTool):
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return f"订阅 ID {subscribe_id} 不存在"
|
||||
|
||||
|
||||
# 在删除之前获取订阅信息(用于事件)
|
||||
subscribe_info = subscribe.to_dict()
|
||||
|
||||
|
||||
# 删除订阅
|
||||
subscribe_oper.delete(subscribe_id)
|
||||
|
||||
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeDeleted,
|
||||
{"subscribe_id": subscribe_id, "subscribe_info": subscribe_info},
|
||||
)
|
||||
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
})
|
||||
|
||||
SubscribeHelper().sub_done_async(
|
||||
{"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid}
|
||||
)
|
||||
|
||||
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
|
||||
except Exception as e:
|
||||
logger.error(f"删除订阅失败: {e}", exc_info=True)
|
||||
return f"删除订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
57
app/agent/tools/impl/delete_transfer_history.py
Normal file
57
app/agent/tools/impl/delete_transfer_history.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""删除整理历史记录工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteTransferHistoryInput(BaseModel):
|
||||
"""删除整理历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the transfer history record to delete"
|
||||
)
|
||||
|
||||
|
||||
class DeleteTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_transfer_history"
|
||||
description: str = "Delete a specific transfer history record by its ID. This is useful when you need to remove a failed transfer record before retrying the transfer, as the system skips files that already have transfer history."
|
||||
args_schema: Type[BaseModel] = DeleteTransferHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除整理历史记录: ID={history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
try:
|
||||
transferhis = TransferHistoryOper()
|
||||
|
||||
# 查询历史记录是否存在
|
||||
history = transferhis.get(history_id)
|
||||
if not history:
|
||||
return f"错误:整理历史记录不存在,ID={history_id}"
|
||||
|
||||
# 保存信息用于返回
|
||||
title = history.title or "未知"
|
||||
src = history.src or "未知"
|
||||
status = "成功" if history.status else "失败"
|
||||
|
||||
# 删除记录
|
||||
transferhis.delete(history_id)
|
||||
|
||||
return f"已删除整理历史记录:ID={history_id},标题={title},源路径={src},状态={status}"
|
||||
except Exception as e:
|
||||
logger.error(f"删除整理历史记录失败: {e}", exc_info=True)
|
||||
return f"删除整理历史记录时发生错误: {str(e)}"
|
||||
74
app/agent/tools/impl/edit_file.py
Normal file
74
app/agent/tools/impl/edit_file.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""文件编辑工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class EditFileInput(BaseModel):
|
||||
"""Input parameters for edit file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to edit")
|
||||
old_text: str = Field(..., description="The exact old text to be replaced")
|
||||
new_text: str = Field(..., description="The new text to replace with")
|
||||
|
||||
|
||||
class EditFileTool(MoviePilotTool):
|
||||
name: str = "edit_file"
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在编辑文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, old_text: str, new_text: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
# 校验逻辑:如果要替换特定文本,文件必须存在且包含该文本
|
||||
if not await path.exists():
|
||||
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
|
||||
if old_text:
|
||||
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
if await path.exists():
|
||||
content = await path.read_text(encoding="utf-8")
|
||||
if old_text not in content:
|
||||
logger.warning(f"编辑文件 {file_path} 失败:未找到指定的旧文本块")
|
||||
return f"错误:在文件 {file_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进 and 换行符。"
|
||||
occurrences = content.count(old_text)
|
||||
new_content = content.replace(old_text, new_text)
|
||||
else:
|
||||
# 文件不存在且 old_text 为空的情形(初始化新文件)
|
||||
new_content = new_text
|
||||
occurrences = 1
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有访问/修改 {file_path} 的权限"
|
||||
except UnicodeDecodeError:
|
||||
return f"错误:{file_path} 不是文本文件,无法编辑"
|
||||
except Exception as e:
|
||||
logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
96
app/agent/tools/impl/execute_command.py
Normal file
96
app/agent/tools/impl/execute_command.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""执行Shell命令工具"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
..., description="Clear explanation of why this command is being executed"
|
||||
)
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(
|
||||
60, description="Max execution time in seconds (default: 60)"
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行系统命令: {command}"
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
)
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = [
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_str = stderr.decode("utf-8", errors="replace").strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
if stdout_str:
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时处理
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return f"命令执行超时 (限制: {timeout}秒)"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return f"执行命令时发生错误: {str(e)}"
|
||||
@@ -8,44 +8,53 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: "
|
||||
"'tmdb_trending' for TMDB trending content, "
|
||||
"'tmdb_movies' for TMDB popular movies, "
|
||||
"'tmdb_tvs' for TMDB popular TV shows, "
|
||||
"'douban_hot' for Douban popular content, "
|
||||
"'douban_movie_hot' for Douban hot movies, "
|
||||
"'douban_tv_hot' for Douban hot TV shows, "
|
||||
"'douban_movie_showing' for Douban movies currently showing, "
|
||||
"'douban_movies' for Douban latest movies, "
|
||||
"'douban_tvs' for Douban latest TV shows, "
|
||||
"'douban_movie_top250' for Douban movie TOP250, "
|
||||
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
|
||||
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
source: Optional[str] = Field(
|
||||
"tmdb_trending",
|
||||
description="Recommendation source: "
|
||||
"'tmdb_trending' for TMDB trending content, "
|
||||
"'tmdb_movies' for TMDB popular movies, "
|
||||
"'tmdb_tvs' for TMDB popular TV shows, "
|
||||
"'douban_hot' for Douban popular content, "
|
||||
"'douban_movie_hot' for Douban hot movies, "
|
||||
"'douban_tv_hot' for Douban hot TV shows, "
|
||||
"'douban_movie_showing' for Douban movies currently showing, "
|
||||
"'douban_movies' for Douban latest movies, "
|
||||
"'douban_tvs' for Douban latest TV shows, "
|
||||
"'douban_movie_top250' for Douban movie TOP250, "
|
||||
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
|
||||
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar",
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1, description="Page number for pagination (default: 1, 20 items per page)"
|
||||
)
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules. Supports pagination with 20 items per page."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据推荐参数生成友好的提示消息"""
|
||||
source = kwargs.get("source", "tmdb_trending")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
limit = kwargs.get("limit", 20)
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
source_map = {
|
||||
"tmdb_trending": "TMDB流行趋势",
|
||||
"tmdb_movies": "TMDB热门电影",
|
||||
@@ -60,110 +69,160 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
|
||||
"douban_tv_weekly_global": "豆瓣全球剧集榜",
|
||||
"douban_tv_animation": "豆瓣热门动漫",
|
||||
"bangumi_calendar": "番组计划"
|
||||
"bangumi_calendar": "番组计划",
|
||||
}
|
||||
source_desc = source_map.get(source, source)
|
||||
|
||||
|
||||
message = f"正在获取推荐: {source_desc}"
|
||||
if media_type != "all":
|
||||
message += f" [{media_type}]"
|
||||
message += f" (限制: {limit}条)"
|
||||
|
||||
message += f" (第{page}页)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
async def run(
|
||||
self,
|
||||
source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all",
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
page_size = 20
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, page={page}"
|
||||
)
|
||||
try:
|
||||
if media_type != "all":
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
media_type = media_type_enum.to_agent() # 归一化为 "movie"/"tv"
|
||||
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
# async_tmdb_trending 只接受 page 参数,返回固定数量的结果
|
||||
# 如果需要限制数量,需要在返回后截取
|
||||
results = await recommend_chain.async_tmdb_trending(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
results = await recommend_chain.async_tmdb_trending(page=page)
|
||||
elif source == "tmdb_movies":
|
||||
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_movies(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
results = await recommend_chain.async_tmdb_movies(page=page)
|
||||
elif source == "tmdb_tvs":
|
||||
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_tvs(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
results = await recommend_chain.async_tmdb_tvs(page=page)
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
|
||||
results.extend(
|
||||
await recommend_chain.async_douban_movie_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
)
|
||||
results.extend(
|
||||
await recommend_chain.async_douban_tv_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
)
|
||||
elif source == "douban_movie_hot":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_hot":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_movie_showing":
|
||||
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_showing(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_movies":
|
||||
results = await recommend_chain.async_douban_movies(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movies(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tvs":
|
||||
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tvs(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_movie_top250":
|
||||
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_top250(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_weekly_chinese":
|
||||
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_weekly_chinese(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_weekly_global":
|
||||
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_weekly_global(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_animation":
|
||||
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_animation(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
|
||||
results = await recommend_chain.async_bangumi_calendar(
|
||||
page=page, count=page_size
|
||||
)
|
||||
else:
|
||||
# 不支持的推荐来源
|
||||
supported_sources = [
|
||||
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
|
||||
"douban_hot", "douban_movie_hot", "douban_tv_hot",
|
||||
"douban_movie_showing", "douban_movies", "douban_tvs",
|
||||
"douban_movie_top250", "douban_tv_weekly_chinese",
|
||||
"douban_tv_weekly_global", "douban_tv_animation",
|
||||
"bangumi_calendar"
|
||||
"tmdb_trending",
|
||||
"tmdb_movies",
|
||||
"tmdb_tvs",
|
||||
"douban_hot",
|
||||
"douban_movie_hot",
|
||||
"douban_tv_hot",
|
||||
"douban_movie_showing",
|
||||
"douban_movies",
|
||||
"douban_tvs",
|
||||
"douban_movie_top250",
|
||||
"douban_tv_weekly_chinese",
|
||||
"douban_tv_weekly_global",
|
||||
"douban_tv_animation",
|
||||
"bangumi_calendar",
|
||||
]
|
||||
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
# 对于TMDB来源,API自身按页返回,取前page_size条
|
||||
total_count = len(results)
|
||||
limited_results = results[:20]
|
||||
page_results = results[:page_size]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
for r in page_results:
|
||||
# r 应该是字典格式(to_dict的结果),但为了安全起见进行检查
|
||||
if not isinstance(r, dict):
|
||||
logger.warning(f"推荐结果格式异常,跳过: {type(r)}")
|
||||
continue
|
||||
|
||||
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
"year": r.get("year"),
|
||||
"type": r.get("type"),
|
||||
"type": media_type_to_agent(r.get("type")),
|
||||
"season": r.get("season"),
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
"douban_id": r.get("douban_id"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
"detail_link": r.get("detail_link"),
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
result_json = json.dumps(
|
||||
simplified_results, ensure_ascii=False, indent=2
|
||||
)
|
||||
has_more = total_count > page_size
|
||||
payload_msg = f"第 {page} 页,当前页 {len(simplified_results)} 条结果。"
|
||||
if has_more:
|
||||
payload_msg += (
|
||||
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}", exc_info=True)
|
||||
|
||||
155
app/agent/tools/impl/get_search_results.py
Normal file
155
app/agent/tools/impl/get_search_results.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""获取搜索结果工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from ._torrent_search_utils import (
|
||||
TORRENT_RESULT_LIMIT,
|
||||
build_filter_options,
|
||||
filter_contexts,
|
||||
simplify_search_result,
|
||||
)
|
||||
|
||||
|
||||
class GetSearchResultsInput(BaseModel):
|
||||
"""获取搜索结果工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site: Optional[List[str]] = Field(None, description="Site name filters")
|
||||
season: Optional[List[str]] = Field(None, description="Season or episode filters")
|
||||
free_state: Optional[List[str]] = Field(None, description="Promotion state filters")
|
||||
video_code: Optional[List[str]] = Field(None, description="Video codec filters")
|
||||
edition: Optional[List[str]] = Field(None, description="Edition filters")
|
||||
resolution: Optional[List[str]] = Field(None, description="Resolution filters")
|
||||
release_group: Optional[List[str]] = Field(
|
||||
None, description="Release group filters"
|
||||
)
|
||||
title_pattern: Optional[str] = Field(
|
||||
None,
|
||||
description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')",
|
||||
)
|
||||
show_filter_options: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to return only optional filter options for re-checking available conditions",
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1,
|
||||
description="Page number for pagination (default: 1, each page returns up to 50 results)",
|
||||
)
|
||||
|
||||
|
||||
class GetSearchResultsTool(MoviePilotTool):
|
||||
name: str = "get_search_results"
|
||||
description: str = "Get cached torrent search results from search_torrents with optional filters. Supports pagination with up to 50 results per page."
|
||||
args_schema: Type[BaseModel] = GetSearchResultsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return "正在获取搜索结果"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
site: Optional[List[str]] = None,
|
||||
season: Optional[List[str]] = None,
|
||||
free_state: Optional[List[str]] = None,
|
||||
video_code: Optional[List[str]] = None,
|
||||
edition: Optional[List[str]] = None,
|
||||
resolution: Optional[List[str]] = None,
|
||||
release_group: Optional[List[str]] = None,
|
||||
title_pattern: Optional[str] = None,
|
||||
show_filter_options: bool = False,
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
items = await SearchChain().async_last_search_results() or []
|
||||
if not items:
|
||||
return "没有可用的搜索结果,请先使用 search_torrents 搜索"
|
||||
|
||||
if show_filter_options:
|
||||
payload = {
|
||||
"total_count": len(items),
|
||||
"filter_options": build_filter_options(items),
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
regex_pattern = None
|
||||
if title_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(title_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {title_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
filtered_items = filter_contexts(
|
||||
items=items,
|
||||
site=site,
|
||||
season=season,
|
||||
free_state=free_state,
|
||||
video_code=video_code,
|
||||
edition=edition,
|
||||
resolution=resolution,
|
||||
release_group=release_group,
|
||||
)
|
||||
if regex_pattern:
|
||||
filtered_items = [
|
||||
item
|
||||
for item in filtered_items
|
||||
if item.torrent_info
|
||||
and item.torrent_info.title
|
||||
and regex_pattern.search(item.torrent_info.title)
|
||||
]
|
||||
if not filtered_items:
|
||||
return "没有符合筛选条件的搜索结果,请调整筛选条件"
|
||||
|
||||
total_count = len(filtered_items)
|
||||
filtered_ids = {id(item) for item in filtered_items}
|
||||
matched_indices = [
|
||||
index
|
||||
for index, item in enumerate(items, start=1)
|
||||
if id(item) in filtered_ids
|
||||
]
|
||||
|
||||
# 分页
|
||||
page_size = TORRENT_RESULT_LIMIT
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_items = filtered_items[start:end]
|
||||
page_indices = matched_indices[start:end]
|
||||
|
||||
if not page_items:
|
||||
return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {(total_count + page_size - 1) // page_size} 页。"
|
||||
|
||||
results = [
|
||||
simplify_search_result(item, index)
|
||||
for item, index in zip(page_items, page_indices)
|
||||
]
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
payload = {
|
||||
"total_count": total_count,
|
||||
"page": page,
|
||||
"total_pages": total_pages,
|
||||
"results": results,
|
||||
}
|
||||
if page < total_pages:
|
||||
payload["message"] = (
|
||||
f"搜索结果共 {total_count} 条,当前第 {page}/{total_pages} 页,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
error_message = f"获取搜索结果失败: {str(e)}"
|
||||
logger.error(f"获取搜索结果失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
@@ -24,7 +24,7 @@ class ListDirectoryInput(BaseModel):
|
||||
|
||||
class ListDirectoryTool(MoviePilotTool):
|
||||
name: str = "list_directory"
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directories' to query directory configuration settings."
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directory_settings' to query directory configuration settings."
|
||||
args_schema: Type[BaseModel] = ListDirectoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -120,8 +120,8 @@ class ListDirectoryTool(MoviePilotTool):
|
||||
result_json = json.dumps(simplified_items, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 20 个项目。\n\n{result_json}"
|
||||
if total_count > 100:
|
||||
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 100 个项目。\n\n{result_json}"
|
||||
else:
|
||||
return result_json
|
||||
except Exception as e:
|
||||
|
||||
79
app/agent/tools/impl/list_slash_commands.py
Normal file
79
app/agent/tools/impl/list_slash_commands.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""查询所有可用斜杠命令工具(系统命令 + 插件命令)"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ListSlashCommandsInput(BaseModel):
|
||||
"""查询所有可用斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class ListSlashCommandsTool(MoviePilotTool):
|
||||
name: str = "list_slash_commands"
|
||||
description: str = (
|
||||
"List all available slash commands in the system, including system preset commands "
|
||||
"(e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
"and plugin-registered commands. "
|
||||
"Use this tool to discover what slash commands are available before executing them with run_slash_command. "
|
||||
"This is especially useful when the user describes an action in natural language and you need to "
|
||||
"find the matching command to fulfill their request."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ListSlashCommandsInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询所有可用命令"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
from app.command import Command
|
||||
|
||||
command_obj = Command()
|
||||
all_commands = command_obj.get_commands()
|
||||
|
||||
if not all_commands:
|
||||
return "当前没有可用的命令"
|
||||
|
||||
commands_list = []
|
||||
for cmd, info in all_commands.items():
|
||||
cmd_info = {
|
||||
"command": cmd,
|
||||
"description": info.get("description", ""),
|
||||
}
|
||||
if info.get("category"):
|
||||
cmd_info["category"] = info["category"]
|
||||
# 标识命令类型
|
||||
if info.get("type") == "scheduler":
|
||||
cmd_info["type"] = "scheduler"
|
||||
elif info.get("pid"):
|
||||
cmd_info["type"] = "plugin"
|
||||
cmd_info["plugin_id"] = info["pid"]
|
||||
else:
|
||||
cmd_info["type"] = "system"
|
||||
commands_list.append(cmd_info)
|
||||
|
||||
result = {
|
||||
"total": len(commands_list),
|
||||
"commands": commands_list,
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询可用命令失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询可用命令时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
124
app/agent/tools/impl/modify_download.py
Normal file
124
app/agent/tools/impl/modify_download.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""修改下载任务工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ModifyDownloadInput(BaseModel):
|
||||
"""修改下载任务工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
action: Optional[str] = Field(
|
||||
None,
|
||||
description="Action to perform on the task: 'start' to resume downloading, 'stop' to pause downloading. "
|
||||
"If not provided, no start/stop action will be performed.",
|
||||
)
|
||||
tags: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of tags to set on the download task. If provided, these tags will be added to the task. "
|
||||
"Example: ['movie', 'hd']",
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of specific downloader (optional, if not provided will search all downloaders)",
|
||||
)
|
||||
|
||||
|
||||
class ModifyDownloadTool(MoviePilotTool):
|
||||
"""修改下载任务工具"""
|
||||
|
||||
name: str = "modify_download"
|
||||
description: str = (
|
||||
"Modify a download task in the downloader by task hash. "
|
||||
"Supports: 1) Setting tags on a download task, "
|
||||
"2) Starting (resuming) a paused download task, "
|
||||
"3) Stopping (pausing) a downloading task. "
|
||||
"Multiple operations can be performed in a single call."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ModifyDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
hash_value = kwargs.get("hash", "")
|
||||
action = kwargs.get("action")
|
||||
tags = kwargs.get("tags")
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
parts = [f"正在修改下载任务: {hash_value}"]
|
||||
if action == "start":
|
||||
parts.append("操作: 开始下载")
|
||||
elif action == "stop":
|
||||
parts.append("操作: 暂停下载")
|
||||
if tags:
|
||||
parts.append(f"标签: {', '.join(tags)}")
|
||||
if downloader:
|
||||
parts.append(f"下载器: {downloader}")
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
hash: str,
|
||||
action: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: hash={hash}, action={action}, tags={tags}, downloader={downloader}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 校验 hash 格式
|
||||
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
# 校验参数:至少需要一个操作
|
||||
if not action and not tags:
|
||||
return "参数错误:至少需要指定 action(start/stop)或 tags 中的一个。"
|
||||
|
||||
# 校验 action 参数
|
||||
if action and action not in ("start", "stop"):
|
||||
return f"参数错误:action 只支持 'start'(开始下载)或 'stop'(暂停下载),收到: '{action}'。"
|
||||
|
||||
download_chain = DownloadChain()
|
||||
results = []
|
||||
|
||||
# 设置标签
|
||||
if tags:
|
||||
tag_result = download_chain.set_torrents_tag(
|
||||
hashs=[hash], tags=tags, downloader=downloader
|
||||
)
|
||||
if tag_result:
|
||||
results.append(f"成功设置标签:{', '.join(tags)}")
|
||||
else:
|
||||
results.append(f"设置标签失败,请检查任务是否存在或下载器是否可用")
|
||||
|
||||
# 执行开始/暂停操作
|
||||
if action:
|
||||
action_result = download_chain.set_downloading(
|
||||
hash_str=hash, oper=action, name=downloader
|
||||
)
|
||||
action_desc = "开始" if action == "start" else "暂停"
|
||||
if action_result:
|
||||
results.append(f"成功{action_desc}下载任务")
|
||||
else:
|
||||
results.append(
|
||||
f"{action_desc}下载任务失败,请检查任务是否存在或下载器是否可用"
|
||||
)
|
||||
|
||||
return f"下载任务 {hash}:" + ";".join(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"修改下载任务失败: {e}", exc_info=True)
|
||||
return f"修改下载任务时发生错误: {str(e)}"
|
||||
66
app/agent/tools/impl/query_custom_identifiers.py
Normal file
66
app/agent/tools/impl/query_custom_identifiers.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""查询自定义识别词工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryCustomIdentifiersInput(BaseModel):
|
||||
"""查询自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class QueryCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "query_custom_identifiers"
|
||||
description: str = (
|
||||
"Query all currently configured custom identifiers (自定义识别词). "
|
||||
"Returns the list of identifier rules used for preprocessing torrent/file names before media recognition. "
|
||||
"Use this tool to check existing rules before adding new ones to avoid duplicates."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询自定义识别词"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
identifiers = system_config_oper.get(SystemConfigKey.CustomIdentifiers)
|
||||
if identifiers:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(identifiers),
|
||||
"identifiers": identifiers,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": 0,
|
||||
"identifiers": [],
|
||||
"message": "当前没有配置自定义识别词",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询自定义识别词失败: {e}")
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询自定义识别词时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -9,6 +9,8 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus, media_type_to_agent
|
||||
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
@@ -20,13 +22,48 @@ class QueryDownloadTasksInput(BaseModel):
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
hash: Optional[str] = Field(None, description="Query specific download task by hash (optional, if provided will search for this specific task regardless of status)")
|
||||
title: Optional[str] = Field(None, description="Query download tasks by title/name (optional, supports partial match, searches all tasks if provided)")
|
||||
tag: Optional[str] = Field(None, description="Filter download tasks by tag (optional, supports partial match, e.g. 'movie' will match tasks with tag 'movie' or 'movie_2024')")
|
||||
|
||||
|
||||
class QueryDownloadTasksTool(MoviePilotTool):
|
||||
name: str = "query_download_tasks"
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash, title, or tag. Shows download progress, completion status, tags, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
@staticmethod
|
||||
def _get_all_torrents(download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
|
||||
"""
|
||||
查询所有状态的任务(包括下载中和已完成的任务)
|
||||
"""
|
||||
all_torrents = []
|
||||
# 查询正在下载的任务
|
||||
downloading_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.DOWNLOADING
|
||||
) or []
|
||||
all_torrents.extend(downloading_torrents)
|
||||
|
||||
# 查询已完成的任务(可转移状态)
|
||||
transfer_torrents = download_chain.list_torrents(
|
||||
downloader=downloader,
|
||||
status=TorrentStatus.TRANSFER
|
||||
) or []
|
||||
all_torrents.extend(transfer_torrents)
|
||||
|
||||
return all_torrents
|
||||
|
||||
@staticmethod
|
||||
def _format_progress(progress: Optional[float]) -> Optional[str]:
|
||||
"""
|
||||
将下载进度格式化为保留一位小数的百分比字符串
|
||||
"""
|
||||
try:
|
||||
if progress is None:
|
||||
return None
|
||||
return f"{float(progress):.1f}%"
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
downloader = kwargs.get("downloader")
|
||||
@@ -47,20 +84,25 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
parts.append(f"Hash: {hash_value[:8]}...")
|
||||
elif title:
|
||||
parts.append(f"标题: {title}")
|
||||
|
||||
tag = kwargs.get("tag")
|
||||
if tag:
|
||||
parts.append(f"标签: {tag}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all",
|
||||
hash: Optional[str] = None,
|
||||
title: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, hash={hash}, title={title}")
|
||||
title: Optional[str] = None,
|
||||
tag: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, hash={hash}, title={title}, tag={tag}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果提供了hash,直接查询该hash的任务(不限制状态)
|
||||
if hash:
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
|
||||
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
|
||||
if not torrents:
|
||||
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
|
||||
# 转换为DownloadingTorrent格式
|
||||
@@ -69,30 +111,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
downloads.append(torrent)
|
||||
filtered_downloads = downloads
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
# 检查标题或名称是否匹配
|
||||
if (title.lower() in (torrent.title or "").lower()) or \
|
||||
(title.lower() in (torrent.name or "").lower()):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
if hasattr(torrent, "media"):
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
@@ -101,8 +120,46 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
torrent.userid = history.userid
|
||||
if hasattr(torrent, "username"):
|
||||
torrent.username = history.username
|
||||
torrent.userid = history.userid
|
||||
downloads.append(torrent)
|
||||
filtered_downloads = downloads
|
||||
elif title:
|
||||
# 如果提供了title,查询所有任务并搜索匹配的标题
|
||||
# 查询所有状态的任务
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
title_lower = title.lower()
|
||||
for torrent in all_torrents:
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
|
||||
# 检查标题或名称是否匹配(包括下载历史中的标题)
|
||||
matched = False
|
||||
# 检查torrent的title和name字段
|
||||
if (title_lower in (torrent.title or "").lower()) or \
|
||||
(title_lower in (getattr(torrent, "name", None) or "").lower()):
|
||||
matched = True
|
||||
# 检查下载历史中的标题
|
||||
if history and history.title:
|
||||
if title_lower in history.title.lower():
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
if history:
|
||||
if hasattr(torrent, "media"):
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "username"):
|
||||
torrent.username = history.username
|
||||
torrent.userid = history.userid
|
||||
filtered_downloads.append(torrent)
|
||||
if not filtered_downloads:
|
||||
return f"未找到标题包含 '{title}' 的下载任务"
|
||||
@@ -110,7 +167,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 根据status决定查询方式
|
||||
if status == "downloading":
|
||||
# 如果status为下载中,使用downloading方法
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
downloads = download_chain.downloading(name=downloader) or []
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
@@ -119,7 +176,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
else:
|
||||
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
|
||||
# 查询所有状态的任务
|
||||
all_torrents = download_chain.list_torrents(downloader=downloader) or []
|
||||
all_torrents = self._get_all_torrents(download_chain, downloader)
|
||||
filtered_downloads = []
|
||||
for torrent in all_torrents:
|
||||
if downloader and torrent.downloader != downloader:
|
||||
@@ -137,17 +194,28 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "media"):
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "username"):
|
||||
torrent.username = history.username
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
# 按tag过滤
|
||||
if tag and filtered_downloads:
|
||||
tag_lower = tag.lower()
|
||||
filtered_downloads = [
|
||||
d for d in filtered_downloads
|
||||
if d.tags and tag_lower in d.tags.lower()
|
||||
]
|
||||
if not filtered_downloads:
|
||||
return f"未找到标签包含 '{tag}' 的下载任务"
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
@@ -159,24 +227,26 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"name": getattr(d, "name", None),
|
||||
"year": getattr(d, "year", None),
|
||||
"season_episode": getattr(d, "season_episode", None),
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"progress": self._format_progress(d.progress),
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
"upspeed": getattr(d, "upspeed", None),
|
||||
"dlspeed": getattr(d, "dlspeed", None),
|
||||
"tags": d.tags,
|
||||
"left_time": getattr(d, "left_time", None)
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
media = getattr(d, "media", None)
|
||||
if media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
"tmdbid": media.get("tmdbid"),
|
||||
"type": media_type_to_agent(media.get("type")),
|
||||
"title": media.get("title"),
|
||||
"season": media.get("season"),
|
||||
"episode": media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
|
||||
@@ -6,23 +6,21 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryEpisodeScheduleInput(BaseModel):
|
||||
"""查询剧集上映时间工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series (can be obtained from search_media tool)")
|
||||
season: int = Field(..., description="Season number to query")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
|
||||
|
||||
class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
name: str = "query_episode_schedule"
|
||||
description: str = "Query TV series episode air dates and schedule. Returns detailed information for each episode including air date, episode number, title, overview, and other metadata. Filters out episodes without air dates."
|
||||
description: str = "Query TV series episode air dates and schedule. Returns non-duplicated schedule fields, including episode list, air-date statistics, and per-episode metadata. Filters out episodes without air dates."
|
||||
args_schema: Type[BaseModel] = QueryEpisodeScheduleInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -41,12 +39,6 @@ class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, season={season}, episode_group={episode_group}")
|
||||
|
||||
try:
|
||||
# 获取媒体信息(用于获取标题和海报)
|
||||
media_chain = MediaChain()
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=MediaType.TV)
|
||||
if not mediainfo:
|
||||
return f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
|
||||
# 获取集列表
|
||||
tmdb_chain = TmdbChain()
|
||||
episodes = await tmdb_chain.async_tmdb_episodes(
|
||||
@@ -92,12 +84,7 @@ class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
episode_list.sort(key=lambda x: (x["air_date"] or "", x["episode_number"] or 0))
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"season": season,
|
||||
"episode_group": episode_group,
|
||||
"series_title": mediainfo.title if mediainfo else None,
|
||||
"series_poster": mediainfo.poster_path if mediainfo else None,
|
||||
"total_episodes": len(episodes),
|
||||
"episodes_with_air_date": len(episode_list),
|
||||
"episodes": episode_list
|
||||
|
||||
65
app/agent/tools/impl/query_installed_plugins.py
Normal file
65
app/agent/tools/impl/query_installed_plugins.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""查询已安装插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryInstalledPluginsInput(BaseModel):
|
||||
"""查询已安装插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
name: str = "query_installed_plugins"
|
||||
description: str = (
|
||||
"Query all installed plugins in MoviePilot. Returns a list of installed plugins with their ID, name, "
|
||||
"description, version, author, running state, and other information. "
|
||||
"Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryInstalledPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询已安装插件"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
local_plugins = plugin_manager.get_local_plugins()
|
||||
# 仅返回已安装的插件
|
||||
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
|
||||
|
||||
if not installed_plugins:
|
||||
return "当前没有已安装的插件"
|
||||
|
||||
plugins_list = []
|
||||
for plugin in installed_plugins:
|
||||
plugins_list.append(
|
||||
{
|
||||
"id": plugin.id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"plugin_desc": plugin.plugin_desc,
|
||||
"plugin_version": plugin.plugin_version,
|
||||
"plugin_author": plugin.plugin_author,
|
||||
"state": plugin.state,
|
||||
"has_page": plugin.has_page,
|
||||
}
|
||||
)
|
||||
|
||||
result_json = json.dumps(plugins_list, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询已安装插件失败: {e}", exc_info=True)
|
||||
return f"查询已安装插件时发生错误: {str(e)}"
|
||||
@@ -1,96 +1,176 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Type, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
def _sort_seasons(seasons: Optional[dict]) -> dict:
|
||||
"""按季号、集号升序整理季集信息,保证输出稳定。"""
|
||||
if not seasons:
|
||||
return {}
|
||||
|
||||
def _sort_key(value):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
return OrderedDict(
|
||||
(season, sorted(episodes, key=_sort_key))
|
||||
for season, episodes in sorted(seasons.items(), key=lambda item: _sort_key(item[0]))
|
||||
)
|
||||
|
||||
|
||||
def _filter_regular_seasons(seasons: Optional[dict]) -> OrderedDict:
|
||||
"""仅保留正片季,忽略 season 0 等特殊季。"""
|
||||
sorted_seasons = _sort_seasons(seasons)
|
||||
regular_seasons = OrderedDict()
|
||||
for season, episodes in sorted_seasons.items():
|
||||
try:
|
||||
season_number = int(season)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if season_number > 0:
|
||||
regular_seasons[season_number] = episodes
|
||||
return regular_seasons
|
||||
|
||||
|
||||
def _build_tv_server_result(existing_seasons: OrderedDict, total_seasons: OrderedDict) -> dict[str, Any]:
|
||||
"""构建单个服务器的电视剧存在性结果。"""
|
||||
seasons_result = OrderedDict()
|
||||
missing_seasons = []
|
||||
all_seasons = sorted(set(total_seasons.keys()) | set(existing_seasons.keys()))
|
||||
|
||||
for season in all_seasons:
|
||||
existing_episodes = existing_seasons.get(season, [])
|
||||
total_episodes = total_seasons.get(season)
|
||||
if total_episodes is not None:
|
||||
missing_episodes = [episode for episode in total_episodes if episode not in existing_episodes]
|
||||
total_episode_count = len(total_episodes)
|
||||
else:
|
||||
missing_episodes = None
|
||||
total_episode_count = None
|
||||
seasons_result[str(season)] = {
|
||||
"existing_episodes": existing_episodes,
|
||||
"total_episodes": total_episode_count,
|
||||
"missing_episodes": missing_episodes
|
||||
}
|
||||
if total_episodes is not None and not existing_episodes:
|
||||
missing_seasons.append(season)
|
||||
|
||||
return {
|
||||
"seasons": seasons_result,
|
||||
"missing_seasons": missing_seasons
|
||||
}
|
||||
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
description: str = "Check whether media already exists in Plex, Emby, or Jellyfin by media ID. Results are grouped by media server; TV results include existing episodes, total episodes, and missing episodes/seasons. Requires tmdb_id or douban_id from search_media."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
title = kwargs.get("title")
|
||||
year = kwargs.get("year")
|
||||
|
||||
parts = ["正在查询媒体库"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if year:
|
||||
parts.append(f"年份: {year}")
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
douban_id = kwargs.get("douban_id")
|
||||
media_type = kwargs.get("media_type")
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
if tmdb_id:
|
||||
message = f"正在查询媒体库: TMDB={tmdb_id}"
|
||||
elif douban_id:
|
||||
message = f"正在查询媒体库: 豆瓣={douban_id}"
|
||||
else:
|
||||
message = "正在查询媒体库"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
return message
|
||||
|
||||
async def run(self, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None,
|
||||
media_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}")
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
|
||||
# 创建 MediaInfo 对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
|
||||
# 转换媒体类型
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
# media_type == "all" 时不设置类型,让媒体服务器自动判断
|
||||
|
||||
# 调用媒体服务器接口实时查询
|
||||
if not tmdb_id and not douban_id:
|
||||
return "参数错误:tmdb_id 和 douban_id 至少需要提供一个,请先使用 search_media 工具获取媒体 ID。"
|
||||
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
media_chain = MediaServerChain()
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
if not existsinfo:
|
||||
mediainfo = media_chain.recognize_media(
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
mtype=media_type_enum,
|
||||
)
|
||||
if not mediainfo:
|
||||
media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}"
|
||||
return f"未识别到媒体信息: {media_id}"
|
||||
|
||||
# 2. 遍历所有媒体服务器,分别查询存在性信息
|
||||
server_results = OrderedDict()
|
||||
media_server_helper = MediaServerHelper()
|
||||
total_seasons = _filter_regular_seasons(mediainfo.seasons)
|
||||
global_existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
for service_name in sorted(media_server_helper.get_services().keys()):
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo, server=service_name)
|
||||
if not existsinfo:
|
||||
continue
|
||||
|
||||
if existsinfo.type == MediaType.TV:
|
||||
existing_seasons = _filter_regular_seasons(existsinfo.seasons)
|
||||
server_results[service_name] = _build_tv_server_result(
|
||||
existing_seasons=existing_seasons,
|
||||
total_seasons=total_seasons
|
||||
)
|
||||
else:
|
||||
server_results[service_name] = {
|
||||
"exists": True
|
||||
}
|
||||
|
||||
if global_existsinfo:
|
||||
fallback_server_name = global_existsinfo.server or "local"
|
||||
if fallback_server_name not in server_results:
|
||||
if global_existsinfo.type == MediaType.TV:
|
||||
server_results[fallback_server_name] = _build_tv_server_result(
|
||||
existing_seasons=_filter_regular_seasons(global_existsinfo.seasons),
|
||||
total_seasons=total_seasons
|
||||
)
|
||||
else:
|
||||
server_results[fallback_server_name] = {
|
||||
"exists": True
|
||||
}
|
||||
|
||||
if not server_results:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 如果找到了,获取详细信息
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
result_items.append(item_dict)
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有详细信息,返回基本信息
|
||||
|
||||
# 3. 组装统一的存在性结果,不查询媒体服务器详情
|
||||
result_dict = {
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": media_type_to_agent(mediainfo.type),
|
||||
"servers": server_results
|
||||
}
|
||||
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
|
||||
@@ -10,52 +10,70 @@ from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
|
||||
PAGE_SIZE = 20
|
||||
|
||||
|
||||
class QueryLibraryLatestInput(BaseModel):
|
||||
"""查询媒体服务器最近入库影片工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
|
||||
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
server: Optional[str] = Field(
|
||||
None,
|
||||
description="Media server name (optional, if not specified queries all enabled media servers)",
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1, description="Page number for pagination (default: 1, 20 items per page)"
|
||||
)
|
||||
|
||||
|
||||
class QueryLibraryLatestTool(MoviePilotTool):
|
||||
name: str = "query_library_latest"
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata. Supports pagination with 20 items per page."
|
||||
args_schema: Type[BaseModel] = QueryLibraryLatestInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
server = kwargs.get("server")
|
||||
count = kwargs.get("count", 20)
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询媒体服务器最近入库影片"]
|
||||
|
||||
|
||||
if server:
|
||||
parts.append(f"服务器: {server}")
|
||||
else:
|
||||
parts.append("所有服务器")
|
||||
|
||||
parts.append(f"数量: {count}条")
|
||||
|
||||
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
|
||||
async def run(
|
||||
self, server: Optional[str] = None, page: Optional[int] = 1, **kwargs
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
# 为了支持分页,需要获取足够多的数据再切片
|
||||
fetch_count = page * PAGE_SIZE
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, page={page}")
|
||||
try:
|
||||
media_chain = MediaServerChain()
|
||||
results = []
|
||||
|
||||
|
||||
# 如果没有指定服务器,获取所有启用的媒体服务器
|
||||
if not server:
|
||||
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
|
||||
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
|
||||
|
||||
|
||||
if not enabled_servers:
|
||||
return "未找到启用的媒体服务器"
|
||||
|
||||
|
||||
# 遍历所有启用的服务器
|
||||
for server_name in enabled_servers:
|
||||
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
|
||||
latest_items = media_chain.latest(
|
||||
server=server_name, count=fetch_count, username=self._username
|
||||
)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
@@ -63,24 +81,37 @@ class QueryLibraryLatestTool(MoviePilotTool):
|
||||
results.append(item_dict)
|
||||
else:
|
||||
# 查询指定服务器
|
||||
latest_items = media_chain.latest(server=server, count=count, username=self._username)
|
||||
latest_items = media_chain.latest(
|
||||
server=server, count=fetch_count, username=self._username
|
||||
)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server
|
||||
results.append(item_dict)
|
||||
|
||||
|
||||
if not results:
|
||||
server_info = f"服务器 {server}" if server else "所有服务器"
|
||||
return f"未找到 {server_info} 的最近入库影片"
|
||||
|
||||
# 限制返回数量,避免结果过多
|
||||
if len(results) > count:
|
||||
results = results[:count]
|
||||
|
||||
return json.dumps(results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
# 分页
|
||||
total_count = len(results)
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_results = results[start:end]
|
||||
|
||||
if not page_results:
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
|
||||
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
payload_msg = f"第 {page}/{total_pages} 页,当前页 {len(page_results)} 条结果,共 {total_count} 条。"
|
||||
if page < total_pages:
|
||||
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
|
||||
|
||||
result_json = json.dumps(page_results, ensure_ascii=False, indent=2)
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
|
||||
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"
|
||||
|
||||
|
||||
126
app/agent/tools/impl/query_media_detail.py
Normal file
126
app/agent/tools/impl/query_media_detail.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""查询媒体详情工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID of the media (movie or TV series, can be obtained from search_media tool)")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID of the media (alternative to tmdb_id)")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
|
||||
|
||||
class QueryMediaDetailTool(MoviePilotTool):
|
||||
name: str = "query_media_detail"
|
||||
description: str = "Query supplementary media details from TMDB by ID and media_type. Accepts tmdb_id or douban_id (at least one required). media_type accepts 'movie' or 'tv'. Returns non-duplicated detail fields such as status, genres, directors, actors, and season info for TV series."
|
||||
args_schema: Type[BaseModel] = QueryMediaDetailInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
douban_id = kwargs.get("douban_id")
|
||||
if tmdb_id:
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
return f"正在查询媒体详情: 豆瓣 ID {douban_id}"
|
||||
|
||||
async def run(self, media_type: str, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}")
|
||||
|
||||
if tmdb_id is None and douban_id is None:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "必须提供 tmdb_id 或 douban_id 之一"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, doubanid=douban_id, mtype=media_type_enum)
|
||||
|
||||
if not mediainfo:
|
||||
id_info = f"TMDB ID {tmdb_id}" if tmdb_id else f"豆瓣 ID {douban_id}"
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 {id_info} 的媒体信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 精简 genres - 只保留名称
|
||||
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
|
||||
|
||||
# 精简 directors - 只保留姓名和职位
|
||||
directors = [
|
||||
{
|
||||
"name": d.get("name"),
|
||||
"job": d.get("job")
|
||||
}
|
||||
for d in (mediainfo.directors or [])
|
||||
if d.get("name")
|
||||
]
|
||||
|
||||
# 精简 actors - 只保留姓名和角色
|
||||
actors = [
|
||||
{
|
||||
"name": a.get("name"),
|
||||
"character": a.get("character")
|
||||
}
|
||||
for a in (mediainfo.actors or [])
|
||||
if a.get("name")
|
||||
]
|
||||
|
||||
# 构建基础媒体详情信息
|
||||
result = {
|
||||
"status": mediainfo.status,
|
||||
"genres": genres,
|
||||
"directors": directors,
|
||||
"actors": actors
|
||||
}
|
||||
|
||||
# 如果是电视剧,添加电视剧特有信息
|
||||
if mediainfo.type == MediaType.TV:
|
||||
# 精简 season_info - 只保留基础摘要
|
||||
season_info = [
|
||||
{
|
||||
"season_number": s.get("season_number"),
|
||||
"name": s.get("name"),
|
||||
"episode_count": s.get("episode_count"),
|
||||
"air_date": s.get("air_date")
|
||||
}
|
||||
for s in (mediainfo.season_info or [])
|
||||
if s.get("season_number") is not None
|
||||
]
|
||||
|
||||
result.update({
|
||||
"number_of_seasons": mediainfo.number_of_seasons,
|
||||
"number_of_episodes": mediainfo.number_of_episodes,
|
||||
"first_air_date": mediainfo.first_air_date,
|
||||
"last_air_date": mediainfo.last_air_date,
|
||||
"season_info": season_info
|
||||
})
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询媒体详情失败: {str(e)}"
|
||||
logger.error(f"查询媒体详情失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id,
|
||||
"douban_id": douban_id
|
||||
}, ensure_ascii=False)
|
||||
118
app/agent/tools/impl/query_plugin_capabilities.py
Normal file
118
app/agent/tools/impl/query_plugin_capabilities.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""查询插件能力工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPluginCapabilitiesInput(BaseModel):
|
||||
"""查询插件能力工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional plugin ID to query capabilities for a specific plugin. "
|
||||
"If not provided, returns capabilities of all running plugins. "
|
||||
"Use query_installed_plugins tool to get the plugin IDs first.",
|
||||
)
|
||||
|
||||
|
||||
class QueryPluginCapabilitiesTool(MoviePilotTool):
|
||||
name: str = "query_plugin_capabilities"
|
||||
description: str = (
|
||||
"Query the capabilities of installed plugins, including supported commands and scheduled services. "
|
||||
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_slash_command tool. "
|
||||
"Scheduled services are periodic tasks that can be triggered via the run_scheduler tool. "
|
||||
"Optionally specify a plugin_id to query a specific plugin, or omit to query all running plugins."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginCapabilitiesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
if plugin_id:
|
||||
return f"正在查询插件 {plugin_id} 的能力"
|
||||
return "正在查询所有插件的能力"
|
||||
|
||||
async def run(self, plugin_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
result = {}
|
||||
|
||||
# 获取插件命令
|
||||
commands = plugin_manager.get_plugin_commands(pid=plugin_id)
|
||||
if commands:
|
||||
commands_list = []
|
||||
for cmd in commands:
|
||||
cmd_info = {
|
||||
"cmd": cmd.get("cmd"),
|
||||
"desc": cmd.get("desc"),
|
||||
"plugin_id": cmd.get("pid"),
|
||||
}
|
||||
# data 字段可能包含额外参数信息
|
||||
if cmd.get("data"):
|
||||
cmd_info["data"] = cmd.get("data")
|
||||
commands_list.append(cmd_info)
|
||||
result["commands"] = commands_list
|
||||
|
||||
# 获取插件动作
|
||||
actions = plugin_manager.get_plugin_actions(pid=plugin_id)
|
||||
if actions:
|
||||
actions_list = []
|
||||
for action_group in actions:
|
||||
plugin_actions = {
|
||||
"plugin_id": action_group.get("plugin_id"),
|
||||
"plugin_name": action_group.get("plugin_name"),
|
||||
"actions": [],
|
||||
}
|
||||
for action in action_group.get("actions", []):
|
||||
plugin_actions["actions"].append(
|
||||
{
|
||||
"id": action.get("id"),
|
||||
"name": action.get("name"),
|
||||
}
|
||||
)
|
||||
actions_list.append(plugin_actions)
|
||||
result["actions"] = actions_list
|
||||
|
||||
# 获取插件定时服务
|
||||
services = plugin_manager.get_plugin_services(pid=plugin_id)
|
||||
if services:
|
||||
services_list = []
|
||||
for svc in services:
|
||||
svc_info = {
|
||||
"id": svc.get("id"),
|
||||
"name": svc.get("name"),
|
||||
}
|
||||
# 包含触发器信息
|
||||
trigger = svc.get("trigger")
|
||||
if trigger:
|
||||
svc_info["trigger"] = str(trigger)
|
||||
# 包含定时器参数
|
||||
svc_kwargs = svc.get("kwargs")
|
||||
if svc_kwargs:
|
||||
svc_info["trigger_kwargs"] = {
|
||||
k: str(v) for k, v in svc_kwargs.items()
|
||||
}
|
||||
services_list.append(svc_info)
|
||||
result["services"] = services_list
|
||||
|
||||
if not result:
|
||||
if plugin_id:
|
||||
return f"插件 {plugin_id} 没有注册任何命令、动作或定时服务"
|
||||
return "当前没有运行中的插件注册了命令、动作或定时服务"
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件能力失败: {e}", exc_info=True)
|
||||
return f"查询插件能力时发生错误: {str(e)}"
|
||||
@@ -10,13 +10,13 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.context import MediaInfo
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
class QueryPopularSubscribesInput(BaseModel):
|
||||
"""查询热门订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
stype: str = Field(..., description="Media type: '电影' for films, '电视剧' for television series")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)")
|
||||
@@ -33,13 +33,13 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
stype = kwargs.get("stype", "")
|
||||
media_type = kwargs.get("media_type", "")
|
||||
page = kwargs.get("page", 1)
|
||||
min_sub = kwargs.get("min_sub")
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = [f"正在查询热门订阅 [{stype}]"]
|
||||
parts = [f"正在查询热门订阅 [{media_type}]"]
|
||||
|
||||
if min_sub:
|
||||
parts.append(f"最少订阅: {min_sub}")
|
||||
@@ -52,7 +52,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, stype: str,
|
||||
async def run(self, media_type: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
@@ -61,7 +61,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: stype={stype}, page={page}, count={count}, min_sub={min_sub}, "
|
||||
f"执行工具: {self.name}, 参数: media_type={media_type}, page={page}, count={count}, min_sub={min_sub}, "
|
||||
f"genre_id={genre_id}, min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}")
|
||||
|
||||
try:
|
||||
@@ -69,10 +69,13 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
subscribes = await subscribe_helper.async_get_statistic(
|
||||
stype=stype,
|
||||
stype=media_type_enum.to_agent(),
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
@@ -94,7 +97,15 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
continue
|
||||
|
||||
media = MediaInfo()
|
||||
media.type = MediaType(sub.get("type"))
|
||||
raw_type = str(sub.get("type") or "").strip().lower()
|
||||
if raw_type in ["movie", "电影"]:
|
||||
media.type = MediaType.MOVIE
|
||||
elif raw_type in ["tv", "电视剧"]:
|
||||
media.type = MediaType.TV
|
||||
else:
|
||||
# 跳过无法识别类型的数据,避免单条脏数据导致整批失败
|
||||
logger.warning(f"跳过未知媒体类型: {sub.get('type')}")
|
||||
continue
|
||||
media.tmdb_id = sub.get("tmdbid")
|
||||
# 处理标题
|
||||
title = sub.get("name")
|
||||
@@ -124,7 +135,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
for media in ret_medias:
|
||||
media_dict = media.to_dict()
|
||||
simplified = {
|
||||
"type": media_dict.get("type"),
|
||||
"type": media_type_to_agent(media_dict.get("type")),
|
||||
"title": media_dict.get("title"),
|
||||
"year": media_dict.get("year"),
|
||||
"tmdb_id": media_dict.get("tmdb_id"),
|
||||
@@ -149,4 +160,3 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询热门订阅失败: {e}", exc_info=True)
|
||||
return f"查询热门订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -62,4 +62,3 @@ class QueryRuleGroupsTool(MoviePilotTool):
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -52,4 +52,3 @@ class QuerySchedulersTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询定时服务失败: {e}", exc_info=True)
|
||||
return f"查询定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -14,60 +14,74 @@ from app.log import logger
|
||||
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for")
|
||||
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to query user data for (can be obtained from query_sites tool)",
|
||||
)
|
||||
workdate: Optional[str] = Field(
|
||||
None,
|
||||
description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)",
|
||||
)
|
||||
|
||||
|
||||
class QuerySiteUserdataTool(MoviePilotTool):
|
||||
name: str = "query_site_userdata"
|
||||
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
workdate = kwargs.get("workdate")
|
||||
|
||||
|
||||
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||
if workdate:
|
||||
message += f" (日期: {workdate})"
|
||||
else:
|
||||
message += " (最新数据)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 获取站点用户数据
|
||||
user_data_list = await SiteUserData.async_get_by_domain(
|
||||
db,
|
||||
domain=site.domain,
|
||||
workdate=workdate
|
||||
db, domain=site.domain, workdate=workdate
|
||||
)
|
||||
|
||||
|
||||
if not user_data_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
@@ -76,16 +90,26 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": []
|
||||
"user_data": [],
|
||||
}
|
||||
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
|
||||
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
|
||||
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
|
||||
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
|
||||
|
||||
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
|
||||
download_gb = (
|
||||
user_data.download / (1024**3) if user_data.download else 0
|
||||
)
|
||||
seeding_size_gb = (
|
||||
user_data.seeding_size / (1024**3)
|
||||
if user_data.seeding_size
|
||||
else 0
|
||||
)
|
||||
leeching_size_gb = (
|
||||
user_data.leeching_size / (1024**3)
|
||||
if user_data.leeching_size
|
||||
else 0
|
||||
)
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
@@ -100,37 +124,46 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching) if user_data.leeching else 0,
|
||||
"leeching": int(user_data.leeching)
|
||||
if user_data.leeching
|
||||
else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
|
||||
"seeding_info": user_data.seeding_info
|
||||
if user_data.seeding_info
|
||||
else [],
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
|
||||
"message_unread_contents": user_data.message_unread_contents
|
||||
if user_data.message_unread_contents
|
||||
else [],
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time
|
||||
"updated_time": user_data.updated_time,
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
|
||||
reverse=True
|
||||
key=lambda x: (
|
||||
x.get("updated_day", ""),
|
||||
x.get("updated_time", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
result["message"] = (
|
||||
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
)
|
||||
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,35 +12,45 @@ from app.log import logger
|
||||
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter sites by name (partial match, optional)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
status: Optional[str] = Field(
|
||||
"all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="Filter sites by name (partial match, optional)"
|
||||
)
|
||||
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
|
||||
parts = ["正在查询站点"]
|
||||
|
||||
|
||||
if status != "all":
|
||||
status_map = {"active": "已启用", "inactive": "已禁用"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
async def run(
|
||||
self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
@@ -68,9 +78,10 @@ class QuerySitesTool(MoviePilotTool):
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"cookie": s.cookie,
|
||||
"downloader": s.downloader,
|
||||
"proxy": s.proxy,
|
||||
"timeout": s.timeout
|
||||
"timeout": s.timeout,
|
||||
}
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
@@ -79,4 +90,3 @@ class QuerySitesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点失败: {e}", exc_info=True)
|
||||
return f"查询站点时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -9,105 +9,179 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
PAGE_SIZE = 20
|
||||
|
||||
|
||||
class QuerySubscribeHistoryInput(BaseModel):
|
||||
"""查询订阅历史工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all", description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types (default: 'all')")
|
||||
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="Filter by media name (partial match, optional)"
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1,
|
||||
description="Page number for pagination (default: 1, 20 items per page). Ignored when name filter is provided.",
|
||||
)
|
||||
|
||||
|
||||
class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_history"
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Returns up to 30 records."
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Supports pagination with 20 records per page."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询订阅历史"]
|
||||
|
||||
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
else:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
media_type: Optional[str] = "all",
|
||||
name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
if media_type not in ["all", "movie", "tv"]:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 根据类型查询
|
||||
if media_type == "all":
|
||||
# 查询所有类型,需要分别查询电影和电视剧
|
||||
movie_history = await SubscribeHistory.async_list_by_type(db, mtype="movie", page=1, count=100)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(db, mtype="tv", page=1, count=100)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
# 按日期排序
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
# 查询指定类型
|
||||
all_history = await SubscribeHistory.async_list_by_type(db, mtype=media_type, page=1, count=100)
|
||||
|
||||
# 按名称过滤
|
||||
filtered_history = []
|
||||
if name:
|
||||
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
|
||||
fetch_count = 500
|
||||
if media_type == "all":
|
||||
movie_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="movie", page=1, count=fetch_count
|
||||
)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="tv", page=1, count=fetch_count
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
all_history = list(
|
||||
await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=media_type, page=1, count=fetch_count
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称过滤
|
||||
name_lower = name.lower()
|
||||
for record in all_history:
|
||||
if record.name and name_lower in record.name.lower():
|
||||
filtered_history.append(record)
|
||||
filtered_history = [
|
||||
record
|
||||
for record in all_history
|
||||
if record.name and name_lower in record.name.lower()
|
||||
]
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 名称过滤时直接返回所有匹配结果,不分页
|
||||
simplified_records = self._simplify_records(filtered_history)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
)
|
||||
return result_json
|
||||
else:
|
||||
filtered_history = all_history
|
||||
|
||||
# 无名称过滤时,直接利用数据库分页
|
||||
if media_type == "all":
|
||||
movie_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="movie", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="tv", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
filtered_history = all_history
|
||||
else:
|
||||
filtered_history = list(
|
||||
await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=media_type, page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
)
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 限制最多30条
|
||||
|
||||
# 分页切片
|
||||
total_count = len(filtered_history)
|
||||
limited_history = filtered_history[:30]
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in limited_history:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
"bangumiid": record.bangumiid,
|
||||
"poster": record.poster,
|
||||
"vote": record.vote,
|
||||
"total_episode": record.total_episode,
|
||||
"date": record.date,
|
||||
"username": record.username
|
||||
}
|
||||
# 添加过滤规则信息(如果有)
|
||||
if record.filter:
|
||||
simplified["filter"] = record.filter
|
||||
if record.quality:
|
||||
simplified["quality"] = record.quality
|
||||
if record.resolution:
|
||||
simplified["resolution"] = record.resolution
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
|
||||
return result_json
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_records = filtered_history[start:end]
|
||||
|
||||
if not page_records:
|
||||
return f"第 {page} 页没有数据。"
|
||||
|
||||
simplified_records = self._simplify_records(page_records)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
)
|
||||
|
||||
has_more = total_count > end
|
||||
payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。"
|
||||
if has_more:
|
||||
payload_msg += (
|
||||
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
|
||||
return f"查询订阅历史时发生错误: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
def _simplify_records(records) -> list:
|
||||
"""转换为字典格式,只保留关键信息"""
|
||||
simplified_records = []
|
||||
for record in records:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
"bangumiid": record.bangumiid,
|
||||
"poster": record.poster,
|
||||
"vote": record.vote,
|
||||
"total_episode": record.total_episode,
|
||||
"date": record.date,
|
||||
"username": record.username,
|
||||
}
|
||||
if record.filter:
|
||||
simplified["filter"] = record.filter
|
||||
if record.quality:
|
||||
simplified["quality"] = record.quality
|
||||
if record.resolution:
|
||||
simplified["resolution"] = record.resolution
|
||||
simplified_records.append(simplified)
|
||||
return simplified_records
|
||||
|
||||
@@ -110,4 +110,3 @@ class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅分享失败: {e}", exc_info=True)
|
||||
return f"查询订阅分享时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -8,82 +8,150 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.subscribe import Subscribe as SubscribeSchema
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
PAGE_SIZE = 100
|
||||
|
||||
QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"id",
|
||||
"name",
|
||||
"year",
|
||||
"type",
|
||||
"season",
|
||||
"total_episode",
|
||||
"start_episode",
|
||||
"lack_episode",
|
||||
"filter",
|
||||
"include",
|
||||
"exclude",
|
||||
"quality",
|
||||
"resolution",
|
||||
"effect",
|
||||
"state",
|
||||
"last_update",
|
||||
"sites",
|
||||
"downloader",
|
||||
"best_version",
|
||||
"save_path",
|
||||
"custom_words",
|
||||
"media_category",
|
||||
"filter_groups",
|
||||
"episode_group",
|
||||
]
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
status: Optional[str] = Field(
|
||||
"all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions",
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
tmdb_id: Optional[int] = Field(
|
||||
None,
|
||||
description="Filter by TMDB ID to check if a specific media is already subscribed",
|
||||
)
|
||||
douban_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Filter by Douban ID to check if a specific media is already subscribed",
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1, description="Page number for pagination (default: 1, 100 items per page)"
|
||||
)
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription. Supports pagination with 100 items per page."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询订阅"]
|
||||
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
status_map = {"R": "已启用", "P": "已禁用"}
|
||||
status_map = {"R": "已启用", "S": "已暂停"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
|
||||
# 根据媒体类型过滤条件生成提示
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
status: Optional[str] = "all",
|
||||
media_type: Optional[str] = "all",
|
||||
tmdb_id: Optional[int] = None,
|
||||
douban_id: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}, page={page}"
|
||||
)
|
||||
try:
|
||||
if media_type != "all" and not MediaType.from_agent(media_type):
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
if (
|
||||
media_type != "all"
|
||||
and sub.type != MediaType.from_agent(media_type).value
|
||||
):
|
||||
continue
|
||||
if tmdb_id is not None and sub.tmdbid != tmdb_id:
|
||||
continue
|
||||
if douban_id is not None and sub.doubanid != douban_id:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": s.type,
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
# 分页
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_subscribes = filtered_subscribes[start:end]
|
||||
|
||||
if not page_subscribes:
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
|
||||
|
||||
full_subscribes = [
|
||||
SubscribeSchema.model_validate(s, from_attributes=True).model_dump(
|
||||
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS), exclude_none=True
|
||||
)
|
||||
for s in page_subscribes
|
||||
]
|
||||
result_json = json.dumps(full_subscribes, ensure_ascii=False, indent=2)
|
||||
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
payload_msg = f"第 {page}/{total_pages} 页,当前页 {len(page_subscribes)} 条结果,共 {total_count} 条。"
|
||||
if page < total_pages:
|
||||
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
|
||||
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}", exc_info=True)
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class QueryTransferHistoryInput(BaseModel):
|
||||
@@ -95,7 +96,7 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
"id": record.id,
|
||||
"title": record.title,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"category": record.category,
|
||||
"seasons": record.seasons,
|
||||
"episodes": record.episodes,
|
||||
|
||||
@@ -125,4 +125,3 @@ class QueryWorkflowsTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询工作流失败: {e}", exc_info=True)
|
||||
return f"查询工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
81
app/agent/tools/impl/read_file.py
Normal file
81
app/agent/tools/impl/read_file.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""文件读取工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
# 最大读取大小 50KB
|
||||
MAX_READ_SIZE = 50 * 1024
|
||||
|
||||
|
||||
class ReadFileInput(BaseModel):
|
||||
"""Input parameters for read file tool"""
|
||||
file_path: str = Field(..., description="The absolute path of the file to read")
|
||||
start_line: Optional[int] = Field(None, description="The starting line number (1-based, inclusive). If not provided, reading starts from the beginning of the file.")
|
||||
end_line: Optional[int] = Field(None, description="The ending line number (1-based, inclusive). If not provided, reading goes until the end of the file.")
|
||||
|
||||
|
||||
class ReadFileTool(MoviePilotTool):
|
||||
name: str = "read_file"
|
||||
description: str = "Read the content of a text file. Supports reading by line range. Each read is limited to 50KB; content exceeding this limit will be truncated."
|
||||
args_schema: Type[BaseModel] = ReadFileInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在读取文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, start_line: Optional[int] = None,
|
||||
end_line: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}, start_line={start_line}, end_line={end_line}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
if not await path.exists():
|
||||
return f"错误:文件 {file_path} 不存在"
|
||||
|
||||
if not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
content = await path.read_text(encoding="utf-8")
|
||||
truncated = False
|
||||
|
||||
if start_line is not None or end_line is not None:
|
||||
lines = content.splitlines(keepends=True)
|
||||
total_lines = len(lines)
|
||||
|
||||
# 将行号转换为索引(1-based -> 0-based)
|
||||
s = (start_line - 1) if start_line and start_line >= 1 else 0
|
||||
e = end_line if end_line and end_line >= 1 else total_lines
|
||||
|
||||
# 确保范围有效
|
||||
s = max(0, min(s, total_lines))
|
||||
e = max(s, min(e, total_lines))
|
||||
|
||||
content = "".join(lines[s:e])
|
||||
|
||||
# 检查大小限制
|
||||
content_bytes = content.encode("utf-8")
|
||||
if len(content_bytes) > MAX_READ_SIZE:
|
||||
content = content_bytes[:MAX_READ_SIZE].decode("utf-8", errors="ignore")
|
||||
truncated = True
|
||||
|
||||
if truncated:
|
||||
return f"{content}\n\n[警告:文件内容已超过50KB限制,以上内容已被截断。请使用 start_line/end_line 参数分段读取。]"
|
||||
|
||||
return content
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有权限读取 {file_path}"
|
||||
except UnicodeDecodeError:
|
||||
return f"错误:{file_path} 不是文本文件,无法读取"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
@@ -10,6 +10,7 @@ from app.chain.media import MediaChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class RecognizeMediaInput(BaseModel):
|
||||
@@ -98,7 +99,8 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"message": error_message
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def _format_context_result(self, context: Context, source_type: str) -> str:
|
||||
@staticmethod
|
||||
def _format_context_result(context: Context, source_type: str) -> str:
|
||||
"""格式化识别结果为JSON字符串"""
|
||||
if not context:
|
||||
return json.dumps({
|
||||
@@ -124,7 +126,7 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"title": media_info.get("title"),
|
||||
"en_title": media_info.get("en_title"),
|
||||
"year": media_info.get("year"),
|
||||
"type": media_info.get("type"),
|
||||
"type": media_type_to_agent(media_info.get("type")),
|
||||
"season": media_info.get("season"),
|
||||
"tmdb_id": media_info.get("tmdb_id"),
|
||||
"imdb_id": media_info.get("imdb_id"),
|
||||
@@ -145,7 +147,7 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"name": meta_info.get("name"),
|
||||
"title": meta_info.get("title"),
|
||||
"year": meta_info.get("year"),
|
||||
"type": meta_info.get("type"),
|
||||
"type": media_type_to_agent(meta_info.get("type")),
|
||||
"begin_season": meta_info.get("begin_season"),
|
||||
"end_season": meta_info.get("end_season"),
|
||||
"begin_episode": meta_info.get("begin_episode"),
|
||||
@@ -159,4 +161,3 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@@ -11,14 +11,22 @@ from app.scheduler import Scheduler
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
job_id: str = Field(
|
||||
...,
|
||||
description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)",
|
||||
)
|
||||
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
@@ -39,15 +47,14 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
job_exists = True
|
||||
job_name = s.name
|
||||
break
|
||||
|
||||
|
||||
if not job_exists:
|
||||
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
|
||||
|
||||
|
||||
# 运行定时服务
|
||||
scheduler.start(job_id)
|
||||
|
||||
|
||||
return f"成功触发定时服务:{job_name} (ID: {job_id})"
|
||||
except Exception as e:
|
||||
logger.error(f"运行定时服务失败: {e}", exc_info=True)
|
||||
return f"运行定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
115
app/agent/tools/impl/run_slash_command.py
Normal file
115
app/agent/tools/impl/run_slash_command.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""运行斜杠命令工具(系统命令 + 插件命令)"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, MessageChannel
|
||||
|
||||
|
||||
class RunSlashCommandInput(BaseModel):
|
||||
"""运行斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
command: str = Field(
|
||||
...,
|
||||
description="The slash command to execute, e.g. '/cookiecloud'. "
|
||||
"Must start with '/'. Can include arguments after the command, e.g. '/command arg1 arg2'. "
|
||||
"Use query_plugin_capabilities tool to discover available plugin commands, "
|
||||
"or list_slash_commands tool to discover all available commands (including system commands).",
|
||||
)
|
||||
|
||||
|
||||
class RunSlashCommandTool(MoviePilotTool):
|
||||
name: str = "run_slash_command"
|
||||
description: str = (
|
||||
"Execute a slash command (system or plugin) by sending a CommandExcute event. "
|
||||
"This tool supports ALL registered slash commands, including: "
|
||||
"1) System preset commands (e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
"2) Plugin commands registered by installed plugins. "
|
||||
"Use the query_plugin_capabilities tool to discover plugin commands, "
|
||||
"or the list_slash_commands tool to discover all available commands. "
|
||||
"The command will be executed asynchronously. "
|
||||
"Note: This tool triggers the command execution but the actual processing happens in the background."
|
||||
)
|
||||
args_schema: Type[BaseModel] = RunSlashCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行命令: {command}"
|
||||
|
||||
async def run(self, command: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}")
|
||||
|
||||
try:
|
||||
# 确保命令以 / 开头
|
||||
if not command.startswith("/"):
|
||||
command = f"/{command}"
|
||||
|
||||
# 从全局 Command 单例中验证命令是否存在(包含系统预设命令 + 插件命令 + 其他命令)
|
||||
from app.command import Command
|
||||
|
||||
cmd_name = command.split()[0]
|
||||
command_obj = Command()
|
||||
matched_command = command_obj.get(cmd_name)
|
||||
|
||||
if not matched_command:
|
||||
# 列出所有可用命令帮助用户
|
||||
all_commands = command_obj.get_commands()
|
||||
available_cmds = [
|
||||
f"{cmd} - {info.get('description', '无描述')}"
|
||||
for cmd, info in all_commands.items()
|
||||
]
|
||||
result = {
|
||||
"success": False,
|
||||
"message": f"命令 {cmd_name} 不存在",
|
||||
}
|
||||
if available_cmds:
|
||||
result["available_commands"] = available_cmds
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
# 构建消息渠道,优先使用当前会话的渠道信息
|
||||
channel = None
|
||||
if self._channel:
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except (ValueError, KeyError):
|
||||
channel = None
|
||||
|
||||
# 发送命令执行事件,与 message.py 中的方式一致
|
||||
eventmanager.send_event(
|
||||
EventType.CommandExcute,
|
||||
{
|
||||
"cmd": command,
|
||||
"user": self._user_id,
|
||||
"channel": channel,
|
||||
"source": self._source,
|
||||
},
|
||||
)
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"命令 {cmd_name} 已触发执行",
|
||||
"command": command,
|
||||
"command_desc": matched_command.get("description", ""),
|
||||
}
|
||||
# 如果是插件命令,附加插件ID
|
||||
if matched_command.get("pid"):
|
||||
result["plugin_id"] = matched_command["pid"]
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"执行命令时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -13,55 +13,61 @@ from app.log import logger
|
||||
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_identifier: str = Field(..., description="Workflow identifier: can be workflow ID (integer as string) or workflow name")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
workflow_id: int = Field(
|
||||
..., description="Workflow ID (can be obtained from query_workflows tool)"
|
||||
)
|
||||
from_begin: Optional[bool] = Field(
|
||||
True,
|
||||
description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)",
|
||||
)
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually. Can run workflow by ID or name. Supports running from the beginning or continuing from the last executed action."
|
||||
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_identifier = kwargs.get("workflow_identifier", "")
|
||||
workflow_id = kwargs.get("workflow_id")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
message = f"正在执行工作流: {workflow_identifier}"
|
||||
|
||||
message = f"正在执行工作流: {workflow_id}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
message += " (从头开始)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_identifier: str,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_identifier={workflow_identifier}, from_begin={from_begin}")
|
||||
async def run(
|
||||
self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
|
||||
# 尝试解析为工作流ID
|
||||
workflow = None
|
||||
if workflow_identifier.isdigit():
|
||||
# 如果是数字,尝试作为工作流ID查询
|
||||
workflow = await workflow_oper.async_get(int(workflow_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称查询
|
||||
workflow = await workflow_oper.async_get(workflow_id)
|
||||
|
||||
if not workflow:
|
||||
workflow = await workflow_oper.async_get_by_name(workflow_identifier)
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_identifier},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
|
||||
|
||||
state, errmsg = workflow_chain.process(
|
||||
workflow.id, from_begin=from_begin
|
||||
)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
@@ -69,4 +75,3 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -16,18 +16,29 @@ from app.schemas import FileItem
|
||||
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')",
|
||||
)
|
||||
overwrite: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to overwrite existing metadata files (default: False)",
|
||||
)
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -44,33 +55,38 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
async def run(
|
||||
self,
|
||||
path: str,
|
||||
storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": "刮削路径不能为空"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
storage=storage, path=path, type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
if not scrape_path.exists():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"刮削路径不存在: {path}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
@@ -79,11 +95,14 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
await global_vars.loop.run_in_executor(
|
||||
@@ -92,28 +111,31 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
overwrite=overwrite
|
||||
)
|
||||
overwrite=overwrite,
|
||||
),
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season,
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "path": path},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
@@ -17,7 +17,7 @@ class SearchMediaInput(BaseModel):
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
description="Allowed values: movie, tv")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows and anime (optional, only applicable for series)")
|
||||
|
||||
@@ -56,14 +56,19 @@ class SearchMediaTool(MoviePilotTool):
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if year and result.year != year:
|
||||
continue
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
if media_type_enum and result.type != media_type_enum:
|
||||
continue
|
||||
if season is not None and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
@@ -78,7 +83,7 @@ class SearchMediaTool(MoviePilotTool):
|
||||
"title": r.title,
|
||||
"en_title": r.en_title,
|
||||
"year": r.year,
|
||||
"type": r.type.value if r.type else None,
|
||||
"type": media_type_to_agent(r.type),
|
||||
"season": r.season,
|
||||
"tmdb_id": r.tmdb_id,
|
||||
"imdb_id": r.imdb_id,
|
||||
@@ -91,8 +96,8 @@ class SearchMediaTool(MoviePilotTool):
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
if total_count > 100:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
|
||||
@@ -72,8 +72,8 @@ class SearchPersonTool(MoviePilotTool):
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关人物信息: {name}"
|
||||
|
||||
@@ -10,15 +10,16 @@ from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import global_vars
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes (can be obtained from query_subscribes tool)")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
description="List of filter rule group names to apply for this search (optional, can be obtained from query_rule_groups tool. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
@@ -58,7 +59,7 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"year": subscribe.year,
|
||||
"type": subscribe.type,
|
||||
"type": media_type_to_agent(subscribe.type),
|
||||
"season": subscribe.season,
|
||||
"state": subscribe.state,
|
||||
"total_episode": subscribe.total_episode,
|
||||
|
||||
@@ -1,141 +1,109 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, SystemConfigKey
|
||||
from ._torrent_search_utils import (
|
||||
SEARCH_RESULT_CACHE_FILE,
|
||||
build_filter_options,
|
||||
)
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(...,
|
||||
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
area: Optional[str] = Field(None, description="Search scope: 'title' (default) or 'imdbid'")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
|
||||
filter_pattern: Optional[str] = Field(None,
|
||||
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
description: str = ("Search for torrent files by media ID across configured indexer sites, cache the matched results, "
|
||||
"and return available filter options for follow-up selection. "
|
||||
"Requires tmdb_id or douban_id (can be obtained from search_media tool) for accurate matching.")
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year")
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
douban_id = kwargs.get("douban_id")
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
filter_pattern = kwargs.get("filter_pattern")
|
||||
|
||||
message = f"正在搜索种子: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
|
||||
if tmdb_id:
|
||||
message = f"正在搜索种子: TMDB={tmdb_id}"
|
||||
elif douban_id:
|
||||
message = f"正在搜索种子: 豆瓣={douban_id}"
|
||||
else:
|
||||
message = "正在搜索种子"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
if filter_pattern:
|
||||
message += f" 过滤: {filter_pattern}"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
async def run(self, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None,
|
||||
media_type: Optional[str] = None, area: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
|
||||
f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}, area={area}, sites={sites}")
|
||||
|
||||
if not tmdb_id and not douban_id:
|
||||
return "参数错误:tmdb_id 和 douban_id 至少需要提供一个,请先使用 search_media 工具获取媒体 ID。"
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
# 编译正则表达式(如果提供)
|
||||
regex_pattern = None
|
||||
if filter_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
for torrent in torrents:
|
||||
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
|
||||
if year and torrent.meta_info and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
if not regex_pattern.search(torrent.torrent_info.title):
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
filtered_torrents = await search_chain.async_search_by_id(
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
mtype=media_type_enum,
|
||||
area=area or "title",
|
||||
sites=sites,
|
||||
cache_local=False,
|
||||
)
|
||||
|
||||
# 获取站点信息
|
||||
all_indexers = await SitesHelper().async_get_indexers()
|
||||
all_sites = [{"id": indexer.get("id"), "name": indexer.get("name")} for indexer in (all_indexers or [])]
|
||||
|
||||
if sites:
|
||||
search_site_ids = sites
|
||||
else:
|
||||
configured_sites = SystemConfigOper().get(SystemConfigKey.IndexerSites)
|
||||
search_site_ids = configured_sites if configured_sites else []
|
||||
|
||||
if filtered_torrents:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_torrents)
|
||||
limited_torrents = filtered_torrents[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_torrents = []
|
||||
for t in limited_torrents:
|
||||
simplified = {}
|
||||
# 精简 torrent_info
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
"enclosure": t.torrent_info.enclosure,
|
||||
"page_url": t.torrent_info.page_url,
|
||||
"volume_factor": t.torrent_info.volume_factor,
|
||||
"pubdate": t.torrent_info.pubdate
|
||||
}
|
||||
# 精简 media_info
|
||||
if t.media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": t.media_info.title,
|
||||
"en_title": t.media_info.en_title,
|
||||
"year": t.media_info.year,
|
||||
"type": t.media_info.type.value if t.media_info.type else None,
|
||||
"season": t.media_info.season,
|
||||
"tmdb_id": t.media_info.tmdb_id
|
||||
}
|
||||
# 精简 meta_info
|
||||
if t.meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": t.meta_info.name,
|
||||
"cn_name": t.meta_info.cn_name,
|
||||
"en_name": t.meta_info.en_name,
|
||||
"year": t.meta_info.year,
|
||||
"type": t.meta_info.type.value if t.meta_info.type else None,
|
||||
"begin_season": t.meta_info.begin_season
|
||||
}
|
||||
simplified_torrents.append(simplified)
|
||||
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
await search_chain.async_save_cache(filtered_torrents, SEARCH_RESULT_CACHE_FILE)
|
||||
result_json = json.dumps({
|
||||
"total_count": len(filtered_torrents),
|
||||
"message": "搜索完成。请使用 get_search_results 工具获取搜索结果。",
|
||||
"all_sites": all_sites,
|
||||
"search_site_ids": search_site_ids,
|
||||
"filter_options": build_filter_options(filtered_torrents),
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关种子资源: {title}"
|
||||
media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}"
|
||||
result_json = json.dumps({
|
||||
"message": f"未找到相关种子资源: {media_id}",
|
||||
"all_sites": all_sites,
|
||||
"search_site_ids": search_site_ids,
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
logger.error(f"搜索种子失败: {e}", exc_info=True)
|
||||
|
||||
@@ -1,22 +1,35 @@
|
||||
"""搜索网络内容工具"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List, Dict
|
||||
|
||||
import httpx
|
||||
from ddgs import DDGS
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.http import AsyncRequestUtils
|
||||
|
||||
# 搜索超时时间(秒)
|
||||
SEARCH_TIMEOUT = 20
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5, description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: str = Field(
|
||||
..., description="The search query string to search for on the web"
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
20,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)",
|
||||
)
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
@@ -27,167 +40,200 @@ class SearchWebTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 5)
|
||||
max_results = kwargs.get("max_results", 20)
|
||||
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
|
||||
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
max_results: 最大返回结果数(默认5,最大10)
|
||||
|
||||
Returns:
|
||||
格式化的搜索结果JSON字符串
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
|
||||
# 使用DuckDuckGo API进行搜索
|
||||
search_results = await self._search_duckduckgo_api(query, max_results)
|
||||
|
||||
if not search_results:
|
||||
max_results = min(max(1, max_results or 20), 20)
|
||||
results = []
|
||||
|
||||
# 1. 优先使用 Exa (如果配置了 API Key)
|
||||
if settings.EXA_API_KEY:
|
||||
logger.info("使用 Exa 进行搜索...")
|
||||
results = await self._search_exa(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Exa,使用 Tavily (如果配置了 API Key)
|
||||
if not results and settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 3. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
|
||||
if not results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
|
||||
# 裁剪结果以避免占用过多上下文
|
||||
formatted_results = self._format_and_truncate_results(search_results, max_results)
|
||||
|
||||
result_json = json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
|
||||
|
||||
# 格式化并裁剪结果
|
||||
formatted_results = self._format_and_truncate_results(results, max_results)
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索网络内容失败: {str(e)}"
|
||||
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_duckduckgo_api(query: str, max_results: int) -> list:
|
||||
"""
|
||||
使用DuckDuckGo API进行搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
async def _search_tavily(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Tavily API 进行搜索"""
|
||||
try:
|
||||
# DuckDuckGo Instant Answer API
|
||||
api_url = "https://api.duckduckgo.com/"
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": "1",
|
||||
"skip_disambig": "1"
|
||||
}
|
||||
|
||||
# 使用代理(如果配置了)
|
||||
http_utils = AsyncRequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
data = await http_utils.get_json(api_url, params=params)
|
||||
|
||||
results = []
|
||||
|
||||
if data:
|
||||
# 处理AbstractText(摘要)
|
||||
if data.get("AbstractText"):
|
||||
results.append({
|
||||
"title": data.get("Heading", query),
|
||||
"snippet": data.get("AbstractText", ""),
|
||||
"url": data.get("AbstractURL", ""),
|
||||
"source": "DuckDuckGo Abstract"
|
||||
})
|
||||
|
||||
# 处理RelatedTopics(相关主题)
|
||||
related_topics = data.get("RelatedTopics", [])
|
||||
for topic in related_topics[:max_results - len(results)]:
|
||||
if isinstance(topic, dict):
|
||||
text = topic.get("Text", "")
|
||||
first_url = topic.get("FirstURL", "")
|
||||
if text and first_url:
|
||||
# 提取标题(通常在" - "之前)
|
||||
title = text.split(" - ")[0] if " - " in text else text[:100]
|
||||
snippet = text
|
||||
|
||||
results.append({
|
||||
"title": title.strip(),
|
||||
"snippet": snippet,
|
||||
"url": first_url,
|
||||
"source": "DuckDuckGo Related"
|
||||
})
|
||||
|
||||
# 处理Results(搜索结果)
|
||||
api_results = data.get("Results", [])
|
||||
for result in api_results[:max_results - len(results)]:
|
||||
if isinstance(result, dict):
|
||||
title = result.get("Text", "")
|
||||
url = result.get("FirstURL", "")
|
||||
if title and url:
|
||||
results.append({
|
||||
"title": title,
|
||||
"snippet": result.get("Text", ""),
|
||||
"url": url,
|
||||
"source": "DuckDuckGo Results"
|
||||
})
|
||||
|
||||
return results[:max_results]
|
||||
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
# 从设置中随机选择一个 API Key(如果有多个)
|
||||
tavity_api_key = random.choice(settings.TAVILY_API_KEY)
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={
|
||||
"api_key": tavity_api_key,
|
||||
"query": query,
|
||||
"search_depth": "basic",
|
||||
"max_results": max_results,
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("content", ""),
|
||||
"url": result.get("url", ""),
|
||||
"source": "Tavily",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo API搜索失败: {e}")
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: list, max_results: int) -> dict:
|
||||
"""
|
||||
格式化并裁剪搜索结果以避免占用过多上下文
|
||||
|
||||
Args:
|
||||
results: 原始搜索结果列表
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
格式化后的结果字典
|
||||
"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
|
||||
# 限制结果数量
|
||||
limited_results = results[:max_results]
|
||||
|
||||
for idx, result in enumerate(limited_results, 1):
|
||||
title = result.get("title", "")[:200] # 限制标题长度
|
||||
async def _search_exa(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Exa API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={
|
||||
"x-api-key": settings.EXA_API_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"query": query,
|
||||
"numResults": max_results,
|
||||
"type": "auto",
|
||||
"contents": {"highlights": {"maxCharacters": 2000}},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
highlights = result.get("highlights", [])
|
||||
snippet = (
|
||||
highlights[0] if highlights else result.get("text", "")[:500]
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": snippet,
|
||||
"url": result.get("url", ""),
|
||||
"source": "Exa",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Exa 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
"""从代理设置中提取代理URL"""
|
||||
if not proxy_setting:
|
||||
return None
|
||||
if isinstance(proxy_setting, dict):
|
||||
return proxy_setting.get("http") or proxy_setting.get("https")
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
try:
|
||||
|
||||
def sync_search():
|
||||
results = []
|
||||
ddgs_kwargs = {"timeout": SEARCH_TIMEOUT}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
if proxy_url:
|
||||
ddgs_kwargs["proxy"] = proxy_url
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(query, max_results=max_results)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("body", ""),
|
||||
"url": result.get("href", ""),
|
||||
"source": "DuckDuckGo",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
return results
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, sync_search)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
|
||||
"""格式化并裁剪搜索结果"""
|
||||
formatted = {"total_results": len(results), "results": []}
|
||||
|
||||
for idx, result in enumerate(results[:max_results], 1):
|
||||
title = result.get("title", "")[:200]
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
source = result.get("source", "Unknown")
|
||||
|
||||
# 裁剪摘要,避免过长
|
||||
max_snippet_length = 300 # 每个摘要最多300字符
|
||||
|
||||
# 裁剪摘要
|
||||
max_snippet_length = 1000 # 增加到1000字符,提供更多上下文
|
||||
if len(snippet) > max_snippet_length:
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本,移除多余的空白字符
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
|
||||
# 添加提示信息
|
||||
|
||||
# 清理文本
|
||||
snippet = re.sub(r"\s+", " ", snippet).strip()
|
||||
|
||||
formatted["results"].append(
|
||||
{
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"注意:共找到 {len(results)} 条结果,为节省上下文空间,仅显示前 {max_results} 条结果。"
|
||||
|
||||
formatted["note"] = f"仅显示前 {max_results} 条结果。"
|
||||
|
||||
return formatted
|
||||
|
||||
@@ -10,35 +10,47 @@ from app.log import logger
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
message_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Title of the message, a short summary of the message content",
|
||||
)
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message_type = kwargs.get("message_type", "info")
|
||||
|
||||
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
|
||||
type_desc = type_map.get(message_type, message_type)
|
||||
|
||||
title = kwargs.get("message_type") or ""
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
return f"正在发送{type_desc}消息: {message}"
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
if title:
|
||||
return f"正在发送消息: [{title}] {message}"
|
||||
return f"正在发送消息: {message}"
|
||||
|
||||
async def run(
|
||||
self, message: str, message_type: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
title = message_type or ""
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, message={message}")
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
await self.send_tool_message(message, title=title)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
@@ -8,53 +8,31 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class TestSiteInput(BaseModel):
|
||||
"""测试站点连通性工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
site_identifier: int = Field(..., description="Site ID to test (can be obtained from query_sites tool)")
|
||||
|
||||
|
||||
class TestSiteTool(MoviePilotTool):
|
||||
name: str = "test_site"
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = TestSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据测试参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
return f"正在测试站点连通性: {site_identifier}"
|
||||
|
||||
async def run(self, site_identifier: str, **kwargs) -> str:
|
||||
async def run(self, site_identifier: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
@@ -69,4 +47,3 @@ class TestSiteTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
|
||||
return f"测试站点连通性时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -13,23 +13,53 @@ from app.schemas import FileItem, MediaType
|
||||
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
media_type: Optional[str] = Field(None, description="Media type: '电影' for films, '电视剧' for television series (optional, will be auto-detected if not specified)")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
|
||||
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)",
|
||||
)
|
||||
target_path: Optional[str] = Field(
|
||||
None,
|
||||
description="Target path for the transferred file/directory (optional, uses default library path if not specified)",
|
||||
)
|
||||
target_storage: Optional[str] = Field(
|
||||
None,
|
||||
description="Target storage type (optional, uses default storage if not specified)",
|
||||
)
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
tmdbid: Optional[int] = Field(
|
||||
None,
|
||||
description="TMDB ID for precise media identification (optional but recommended for accuracy)",
|
||||
)
|
||||
doubanid: Optional[str] = Field(
|
||||
None, description="Douban ID for media identification (optional)"
|
||||
)
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
transfer_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)",
|
||||
)
|
||||
background: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to run transfer in background (default: False, runs synchronously)",
|
||||
)
|
||||
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据整理参数生成友好的提示消息"""
|
||||
@@ -37,67 +67,79 @@ class TransferFileTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
|
||||
transfer_map = {
|
||||
"move": "移动",
|
||||
"copy": "复制",
|
||||
"link": "硬链接",
|
||||
"softlink": "软链接",
|
||||
}
|
||||
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
|
||||
if background:
|
||||
message += " [后台运行]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, file_path: str, storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
|
||||
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}")
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not file_path:
|
||||
return "错误:必须提供文件或目录路径"
|
||||
|
||||
|
||||
# 规范化路径
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
|
||||
if not file_path.startswith("/") and not (
|
||||
len(file_path) > 1 and file_path[1] == ":"
|
||||
):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
file_path = str(Path(file_path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not file_path.startswith("/"):
|
||||
file_path = "/" + file_path
|
||||
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=file_path,
|
||||
type="dir" if file_path.endswith("/") else "file"
|
||||
type="dir" if file_path.endswith("/") else "file",
|
||||
)
|
||||
|
||||
|
||||
# 处理目标路径
|
||||
target_path_obj = None
|
||||
if target_path:
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
|
||||
# 处理媒体类型
|
||||
mtype = None
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
try:
|
||||
mtype = MediaType(media_type)
|
||||
except ValueError:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
|
||||
# 调用整理方法
|
||||
transfer_chain = TransferChain()
|
||||
state, errormsg = transfer_chain.manual_transfer(
|
||||
@@ -106,18 +148,20 @@ class TransferFileTool(MoviePilotTool):
|
||||
target_path=target_path_obj,
|
||||
tmdbid=tmdbid,
|
||||
doubanid=doubanid,
|
||||
mtype=mtype,
|
||||
mtype=media_type_enum,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
background=background,
|
||||
)
|
||||
|
||||
|
||||
if not state:
|
||||
# 处理错误信息
|
||||
if isinstance(errormsg, list):
|
||||
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
|
||||
if errormsg:
|
||||
error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
|
||||
error_text += f":\n" + "\n".join(
|
||||
str(e) for e in errormsg[:5]
|
||||
) # 只显示前5个错误
|
||||
if len(errormsg) > 5:
|
||||
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
|
||||
else:
|
||||
@@ -131,4 +175,3 @@ class TransferFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"整理文件失败: {e}", exc_info=True)
|
||||
return f"整理文件时发生错误: {str(e)}"
|
||||
|
||||
|
||||
95
app/agent/tools/impl/update_custom_identifiers.py
Normal file
95
app/agent/tools/impl/update_custom_identifiers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""更新自定义识别词工具"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateCustomIdentifiersInput(BaseModel):
|
||||
"""更新自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
identifiers: List[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"The complete list of custom identifier rules to save. "
|
||||
"This REPLACES the entire existing list. "
|
||||
"Always query existing identifiers first, merge new rules, then pass the full list."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "update_custom_identifiers"
|
||||
description: str = (
|
||||
"Update the full list of custom identifiers (自定义识别词) used for preprocessing torrent/file names. "
|
||||
"This tool REPLACES all existing identifier rules with the provided list. "
|
||||
"IMPORTANT: Always use 'query_custom_identifiers' first to get existing rules, "
|
||||
"then merge new rules into the list before calling this tool to avoid accidentally deleting existing rules. "
|
||||
"Supported rule formats (spaces around operators are required): "
|
||||
"1) Block word: just the word/regex to remove; "
|
||||
"2) Replacement: '被替换词 => 替换词'; "
|
||||
"3) Episode offset: '前定位词 <> 后定位词 >> EP±N'; "
|
||||
"4) Combined: '被替换词 => 替换词 && 前定位词 <> 后定位词 >> EP±N'; "
|
||||
"Lines starting with '#' are comments. "
|
||||
"The replacement target supports: {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]} for direct TMDB ID matching."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
identifiers = kwargs.get("identifiers", [])
|
||||
return f"正在更新自定义识别词(共 {len(identifiers)} 条规则)"
|
||||
|
||||
async def run(self, identifiers: List[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 规则数量: {len(identifiers) if identifiers else 0}"
|
||||
)
|
||||
try:
|
||||
if identifiers is None:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "必须提供 identifiers 参数"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 过滤空字符串
|
||||
identifiers = [i for i in identifiers if i is not None]
|
||||
|
||||
system_config_oper = SystemConfigOper()
|
||||
|
||||
# 保存
|
||||
value = identifiers if identifiers else None
|
||||
success = await system_config_oper.async_set(
|
||||
SystemConfigKey.CustomIdentifiers, value
|
||||
)
|
||||
if success:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"自定义识别词已更新,共 {len(identifiers)} 条规则",
|
||||
"count": len(identifiers),
|
||||
"identifiers": identifiers,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "保存自定义识别词失败"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"更新自定义识别词失败: {e}")
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"更新自定义识别词时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -16,37 +16,67 @@ from app.utils.string import StringUtils
|
||||
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to update (can be obtained from query_sites tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
url: Optional[str] = Field(
|
||||
None, description="Site URL (optional, will be automatically formatted)"
|
||||
)
|
||||
pri: Optional[int] = Field(
|
||||
None,
|
||||
description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)",
|
||||
)
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||
token: Optional[str] = Field(None, description="API token (optional)")
|
||||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
proxy: Optional[int] = Field(
|
||||
None, description="Whether to use proxy: 0 for no, 1 for yes (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||
timeout: Optional[int] = Field(
|
||||
None, description="Request timeout in seconds (optional, default: 15)"
|
||||
)
|
||||
limit_interval: Optional[int] = Field(
|
||||
None, description="Rate limit interval in seconds (optional)"
|
||||
)
|
||||
limit_count: Optional[int] = Field(
|
||||
None, description="Rate limit count per interval (optional)"
|
||||
)
|
||||
limit_seconds: Optional[int] = Field(
|
||||
None, description="Rate limit seconds between requests (optional)"
|
||||
)
|
||||
is_active: Optional[bool] = Field(
|
||||
None,
|
||||
description="Whether site is active: True for enabled, False for disabled (optional)",
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None, description="Downloader name for this site (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("url"):
|
||||
@@ -63,60 +93,63 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
fields_updated.append("启用状态")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
|
||||
async def run(self, site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
@@ -126,7 +159,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
@@ -136,7 +169,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
@@ -144,39 +177,40 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": updated_site.domain if updated_site else site.domain
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SiteUpdated,
|
||||
{"domain": updated_site.domain if updated_site else site.domain},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys())
|
||||
"updated_fields": list(site_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
@@ -187,17 +221,15 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout
|
||||
"timeout": updated_site.timeout,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -8,76 +8,73 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_identifier: int = Field(
|
||||
...,
|
||||
description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)",
|
||||
)
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
two_step_code: Optional[str] = Field(
|
||||
None,
|
||||
description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)",
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: str, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
async def run(
|
||||
self,
|
||||
site_identifier: int,
|
||||
username: str,
|
||||
password: str,
|
||||
two_step_code: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}"
|
||||
)
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
|
||||
# 更新站点Cookie和UA
|
||||
status, message = site_chain.update_cookie(
|
||||
site_info=site,
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code
|
||||
two_step_code=two_step_code,
|
||||
)
|
||||
|
||||
|
||||
if status:
|
||||
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
|
||||
else:
|
||||
@@ -85,4 +82,3 @@ class UpdateSiteCookieTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
|
||||
return f"更新站点Cookie和UA时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -15,40 +15,87 @@ from app.schemas.types import EventType
|
||||
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to update (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
total_episode: Optional[int] = Field(
|
||||
None, description="Total number of episodes (optional)"
|
||||
)
|
||||
lack_episode: Optional[int] = Field(
|
||||
None, description="Number of missing episodes (optional)"
|
||||
)
|
||||
start_episode: Optional[int] = Field(
|
||||
None, description="Starting episode number (optional)"
|
||||
)
|
||||
quality: Optional[str] = Field(
|
||||
None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
|
||||
)
|
||||
effect: Optional[str] = Field(
|
||||
None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
|
||||
)
|
||||
include: Optional[str] = Field(
|
||||
None, description="Include filter as regular expression (optional)"
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None, description="Exclude filter as regular expression (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
state: Optional[str] = Field(
|
||||
None,
|
||||
description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)",
|
||||
)
|
||||
sites: Optional[List[int]] = Field(
|
||||
None, description="List of site IDs to search from (optional)"
|
||||
)
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
save_path: Optional[str] = Field(
|
||||
None, description="Save path for downloaded files (optional)"
|
||||
)
|
||||
best_version: Optional[int] = Field(
|
||||
None,
|
||||
description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)",
|
||||
)
|
||||
custom_words: Optional[str] = Field(
|
||||
None, description="Custom recognition words (optional)"
|
||||
)
|
||||
media_category: Optional[str] = Field(
|
||||
None, description="Custom media category (optional)"
|
||||
)
|
||||
episode_group: Optional[str] = Field(
|
||||
None, description="Episode group ID (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSubscribeTool(MoviePilotTool):
|
||||
name: str = "update_subscribe"
|
||||
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("total_episode") is not None:
|
||||
@@ -61,57 +108,62 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
fields_updated.append("分辨率过滤")
|
||||
if kwargs.get("state"):
|
||||
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||
fields_updated.append(
|
||||
f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})"
|
||||
)
|
||||
if kwargs.get("sites"):
|
||||
fields_updated.append("站点")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(self, subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
@@ -119,27 +171,29 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||
subscribe_dict["lack_episode"] = old_lack + (
|
||||
total_episode - (subscribe.total_episode or 0)
|
||||
)
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
@@ -153,17 +207,20 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
@@ -173,7 +230,7 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
@@ -181,35 +238,40 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys())
|
||||
"updated_fields": list(subscribe_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
@@ -223,17 +285,19 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect
|
||||
"effect": updated_subscribe.effect,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
54
app/agent/tools/impl/write_file.py
Normal file
54
app/agent/tools/impl/write_file.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""文件写入工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
"""Input parameters for write file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to write")
|
||||
content: str = Field(..., description="The content to write into the file")
|
||||
|
||||
|
||||
class WriteFileTool(MoviePilotTool):
|
||||
name: str = "write_file"
|
||||
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在写入文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, content: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 路径已存在但不是一个文件"
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(content, encoding="utf-8")
|
||||
|
||||
logger.info(f"成功写入文件 {file_path}")
|
||||
return f"成功写入文件 {file_path}"
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有权限写入 {file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"写入文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
@@ -1,8 +1,5 @@
|
||||
"""MoviePilot工具管理器
|
||||
用于HTTP API调用工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
@@ -10,7 +7,9 @@ from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
"""
|
||||
工具定义
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
@@ -19,12 +18,14 @@ class ToolDefinition:
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
"""
|
||||
MoviePilot工具管理器(用于HTTP API)
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
@@ -35,7 +36,9 @@ class MoviePilotToolsManager:
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""加载所有MoviePilot工具"""
|
||||
"""
|
||||
加载所有MoviePilot工具
|
||||
"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
@@ -44,7 +47,7 @@ class MoviePilotToolsManager:
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None
|
||||
stream_handler=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -54,40 +57,38 @@ class MoviePilotToolsManager:
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""
|
||||
列出所有可用的工具
|
||||
|
||||
|
||||
Returns:
|
||||
工具定义列表
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, 'args_schema', None)
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
if args_schema:
|
||||
# 将Pydantic模型转换为JSON Schema
|
||||
input_schema = self._convert_to_json_schema(args_schema)
|
||||
else:
|
||||
# 如果没有args_schema,使用基本信息
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
tools_list.append(ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema
|
||||
))
|
||||
tools_list.append(
|
||||
ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema,
|
||||
)
|
||||
)
|
||||
|
||||
return tools_list
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Any]:
|
||||
"""
|
||||
获取指定工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
|
||||
Returns:
|
||||
工具实例,如果未找到返回None
|
||||
"""
|
||||
@@ -96,49 +97,178 @@ class MoviePilotToolsManager:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_field_schema(field_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
解析字段schema,兼容 Optional[T] 生成的 anyOf 结构
|
||||
"""
|
||||
if field_info.get("type"):
|
||||
return field_info
|
||||
|
||||
any_of = field_info.get("anyOf")
|
||||
if not any_of:
|
||||
return field_info
|
||||
|
||||
for type_option in any_of:
|
||||
if type_option.get("type") and type_option["type"] != "null":
|
||||
merged = dict(type_option)
|
||||
if "description" not in merged and field_info.get("description"):
|
||||
merged["description"] = field_info["description"]
|
||||
if "default" not in merged and "default" in field_info:
|
||||
merged["default"] = field_info["default"]
|
||||
return merged
|
||||
|
||||
return field_info
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scalar_value(field_type: Optional[str], value: Any, key: str) -> Any:
|
||||
"""
|
||||
根据字段类型规范化单个值
|
||||
"""
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,返回 None")
|
||||
return None
|
||||
if field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,返回 None")
|
||||
return None
|
||||
if field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return True
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _parse_array_string(value: str, key: str, item_type: str = "string") -> list:
|
||||
"""
|
||||
将逗号分隔的字符串解析为列表,并根据 item_type 转换元素类型
|
||||
"""
|
||||
trimmed = value.strip()
|
||||
if not trimmed:
|
||||
return []
|
||||
return [
|
||||
MoviePilotToolsManager._normalize_scalar_value(item_type, item.strip(), key)
|
||||
for item in trimmed.split(",")
|
||||
if item.strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(
|
||||
tool_instance: Any, arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, "args_schema", None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
properties = schema.get("properties", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
if key not in properties:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
field_info = MoviePilotToolsManager._resolve_field_schema(properties[key])
|
||||
field_type = field_info.get("type")
|
||||
|
||||
# 数组类型:将字符串解析为列表
|
||||
if field_type == "array" and isinstance(value, str):
|
||||
item_type = field_info.get("items", {}).get("type", "string")
|
||||
normalized[key] = MoviePilotToolsManager._parse_array_string(
|
||||
value, key, item_type
|
||||
)
|
||||
continue
|
||||
|
||||
# 根据类型进行转换
|
||||
normalized[key] = MoviePilotToolsManager._normalize_scalar_value(
|
||||
field_type, value, key
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
|
||||
Returns:
|
||||
工具执行结果(字符串)
|
||||
"""
|
||||
tool_instance = self.get_tool(tool_name)
|
||||
|
||||
if not tool_instance:
|
||||
error_msg = json.dumps({
|
||||
"error": f"工具 '{tool_name}' 未找到"
|
||||
}, ensure_ascii=False)
|
||||
error_msg = json.dumps(
|
||||
{"error": f"工具 '{tool_name}' 未找到"}, ensure_ascii=False
|
||||
)
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**arguments)
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
formated_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
try:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
|
||||
formated_result = str(result)
|
||||
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
error_msg = json.dumps(
|
||||
{"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return error_msg
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将Pydantic模型转换为JSON Schema
|
||||
|
||||
|
||||
Args:
|
||||
args_schema: Pydantic模型类
|
||||
|
||||
|
||||
Returns:
|
||||
JSON Schema字典
|
||||
"""
|
||||
@@ -151,37 +281,39 @@ class MoviePilotToolsManager:
|
||||
|
||||
if "properties" in schema:
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
resolved_field_info = MoviePilotToolsManager._resolve_field_schema(
|
||||
field_info
|
||||
)
|
||||
# 转换字段类型
|
||||
field_type = field_info.get("type", "string")
|
||||
field_description = field_info.get("description", "")
|
||||
field_type = resolved_field_info.get("type", "string")
|
||||
field_description = resolved_field_info.get("description", "")
|
||||
|
||||
# 处理可选字段
|
||||
if field_name not in schema.get("required", []):
|
||||
# 可选字段
|
||||
default_value = field_info.get("default")
|
||||
default_value = resolved_field_info.get("default")
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
"description": field_description,
|
||||
}
|
||||
if default_value is not None:
|
||||
properties[field_name]["default"] = default_value
|
||||
else:
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
"description": field_description,
|
||||
}
|
||||
required.append(field_name)
|
||||
|
||||
# 处理枚举类型
|
||||
if "enum" in field_info:
|
||||
properties[field_name]["enum"] = field_info["enum"]
|
||||
if "enum" in resolved_field_info:
|
||||
properties[field_name]["enum"] = resolved_field_info["enum"]
|
||||
|
||||
# 处理数组类型
|
||||
if field_type == "array" and "items" in field_info:
|
||||
properties[field_name]["items"] = field_info["items"]
|
||||
if field_type == "array" and "items" in resolved_field_info:
|
||||
properties[field_name]["items"] = resolved_field_info["items"]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
return {"type": "object", "properties": properties, "required": required}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
|
||||
@@ -2,11 +2,12 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
api_router.include_router(site.router, prefix="/site", tags=["site"])
|
||||
api_router.include_router(message.router, prefix="/message", tags=["message"])
|
||||
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])
|
||||
|
||||
@@ -26,11 +26,17 @@ def statistic(name: Optional[str] = None, _: schemas.TokenPayload = Depends(veri
|
||||
if media_statistics:
|
||||
# 汇总各媒体库统计信息
|
||||
ret_statistic = schemas.Statistic()
|
||||
has_episode_count = False
|
||||
for media_statistic in media_statistics:
|
||||
ret_statistic.movie_count += media_statistic.movie_count
|
||||
ret_statistic.tv_count += media_statistic.tv_count
|
||||
ret_statistic.episode_count += media_statistic.episode_count
|
||||
ret_statistic.user_count += media_statistic.user_count
|
||||
ret_statistic.movie_count += media_statistic.movie_count or 0
|
||||
ret_statistic.tv_count += media_statistic.tv_count or 0
|
||||
ret_statistic.user_count += media_statistic.user_count or 0
|
||||
if media_statistic.episode_count is not None:
|
||||
ret_statistic.episode_count += media_statistic.episode_count or 0
|
||||
has_episode_count = True
|
||||
if not has_episode_count:
|
||||
# 所有媒体服务都未提供剧集统计时,返回 None 供前端展示“未获取”。
|
||||
ret_statistic.episode_count = None
|
||||
return ret_statistic
|
||||
else:
|
||||
return schemas.Statistic()
|
||||
|
||||
@@ -6,13 +6,12 @@ from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import MediaInfo, Context, TorrentInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.schemas.types import ChainEventType, SystemConfigKey
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -77,13 +76,14 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
|
||||
mediainfo = MediaChain().select_recognize_source(
|
||||
log_name=torrent_in.title,
|
||||
log_context=torrent_in.title,
|
||||
native_fn=lambda: MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid),
|
||||
plugin_fn=lambda: MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
)
|
||||
if not mediainfo:
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
|
||||
@@ -4,6 +4,7 @@ import jieba
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
|
||||
from app import schemas
|
||||
from app.chain.storage import StorageChain
|
||||
@@ -11,7 +12,7 @@ from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType
|
||||
@@ -98,6 +99,8 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
state = StorageChain().delete_media_file(src_fileitem)
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=f"{src_fileitem.path} 删除失败")
|
||||
# 删除下载记录中关联的文件
|
||||
DownloadFiles.delete_by_fullpath(db, Path(src_fileitem.path).as_posix())
|
||||
# 发送事件
|
||||
eventmanager.send_event(
|
||||
EventType.DownloadFileDeleted,
|
||||
|
||||
@@ -29,7 +29,14 @@ def login_access_token(
|
||||
mfa_code=otp_password)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
@@ -50,7 +57,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
wizard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,43 +1,261 @@
|
||||
"""工具API端点
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from typing import List, Any, Dict, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.agent.tools.manager import moviepilot_tool_manager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
# 导入版本号
|
||||
try:
|
||||
from version import APP_VERSION
|
||||
except ImportError:
|
||||
APP_VERSION = "unknown"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局工具管理器实例(单例模式,按用户ID缓存)
|
||||
_tools_managers: Dict[str, MoviePilotToolsManager] = {}
|
||||
# MCP 协议版本
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
MCP_HIDDEN_TOOLS = {
|
||||
"execute_command",
|
||||
"search_web",
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"read_file",
|
||||
}
|
||||
|
||||
|
||||
def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager:
|
||||
def list_exposed_tools():
|
||||
"""
|
||||
获取工具管理器实例(按用户ID缓存)
|
||||
获取 MCP 可见工具列表
|
||||
"""
|
||||
return [
|
||||
tool for tool in moviepilot_tool_manager.list_tools()
|
||||
if tool.name not in MCP_HIDDEN_TOOLS
|
||||
]
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 成功响应
|
||||
"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": result
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
|
||||
str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 错误响应
|
||||
"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
}
|
||||
if data is not None:
|
||||
error["error"]["data"] = data
|
||||
return error
|
||||
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
MCP 标准 JSON-RPC 2.0 端点
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
处理所有 MCP 协议消息(初始化、工具列表、工具调用等)
|
||||
"""
|
||||
global _tools_managers
|
||||
# 使用用户ID作为缓存键
|
||||
cache_key = f"{user_id}_{session_id}"
|
||||
if cache_key not in _tools_managers:
|
||||
_tools_managers[cache_key] = MoviePilotToolsManager(
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析请求体失败: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(None, -32700, "Parse error", str(e))
|
||||
)
|
||||
return _tools_managers[cache_key]
|
||||
|
||||
# 验证 JSON-RPC 格式
|
||||
if not isinstance(body, dict) or body.get("jsonrpc") != "2.0":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(body.get("id"), -32600, "Invalid Request")
|
||||
)
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
# 如果有 id,则为请求;没有 id 则为通知
|
||||
is_notification = request_id is None
|
||||
|
||||
try:
|
||||
# 处理初始化请求
|
||||
if method == "initialize":
|
||||
result = await handle_initialize(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理已初始化通知
|
||||
elif method == "notifications/initialized":
|
||||
if is_notification:
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "initialized must be a notification"}
|
||||
)
|
||||
|
||||
# 处理工具列表请求
|
||||
if method == "tools/list":
|
||||
result = await handle_tools_list()
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理工具调用请求
|
||||
elif method == "tools/call":
|
||||
result = await handle_tools_call(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理 ping 请求
|
||||
elif method == "ping":
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, {}))
|
||||
|
||||
# 未知方法
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"MCP 请求参数错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_jsonrpc_error(request_id, -32603, "Internal error", str(e))
|
||||
)
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
处理初始化请求
|
||||
"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
|
||||
|
||||
# 版本协商:选择客户端和服务器都支持的版本
|
||||
negotiated_version = MCP_PROTOCOL_VERSION
|
||||
if protocol_version in MCP_PROTOCOL_VERSIONS:
|
||||
# 客户端版本在支持列表中,使用客户端版本
|
||||
negotiated_version = protocol_version
|
||||
logger.info(f"使用客户端协议版本: {negotiated_version}")
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"logging": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""
|
||||
处理工具列表请求
|
||||
"""
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
for tool in tools:
|
||||
mcp_tool = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.input_schema
|
||||
}
|
||||
mcp_tools.append(mcp_tool)
|
||||
|
||||
return {
|
||||
"tools": mcp_tools
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
处理工具调用请求
|
||||
"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
try:
|
||||
if tool_name in MCP_HIDDEN_TOOLS:
|
||||
raise ValueError(f"工具 '{tool_name}' 未找到")
|
||||
|
||||
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result_text
|
||||
}
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"工具调用失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"错误: {str(e)}"
|
||||
}
|
||||
],
|
||||
"isError": True
|
||||
}
|
||||
|
||||
|
||||
@router.delete("", summary="终止 MCP 会话", response_model=None)
|
||||
async def delete_mcp_session(
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
终止 MCP 会话(无状态模式下仅返回成功)
|
||||
"""
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
async def list_tools(
|
||||
@@ -49,9 +267,8 @@ async def list_tools(
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具定义
|
||||
tools = manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
@@ -72,7 +289,7 @@ async def list_tools(
|
||||
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
|
||||
async def call_tool(
|
||||
request: schemas.ToolCallRequest,
|
||||
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定的工具
|
||||
@@ -81,11 +298,10 @@ async def call_tool(
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 使用当前用户ID创建管理器实例
|
||||
manager = get_tools_manager()
|
||||
if request.tool_name in MCP_HIDDEN_TOOLS:
|
||||
raise ValueError(f"工具 '{request.tool_name}' 未找到")
|
||||
|
||||
# 调用工具
|
||||
result_text = await manager.call_tool(request.tool_name, request.arguments)
|
||||
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
@@ -111,9 +327,8 @@ async def get_tool_info(
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
@@ -144,9 +359,8 @@ async def get_tool_schema(
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
|
||||
@@ -11,7 +11,10 @@ from app.core.context import Context
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo, MetaInfoPath
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_superuser
|
||||
from app.schemas import MediaType, MediaRecognizeConvertEventData
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import ChainEventType
|
||||
|
||||
router = APIRouter()
|
||||
@@ -131,6 +134,26 @@ def scrape(fileitem: schemas.FileItem,
|
||||
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
|
||||
|
||||
|
||||
@router.get("/category/config", summary="获取分类策略配置", response_model=schemas.Response)
|
||||
def get_category_config(_: User = Depends(get_current_active_user)):
|
||||
"""
|
||||
获取分类策略配置
|
||||
"""
|
||||
config = MediaChain().category_config()
|
||||
return schemas.Response(success=True, data=config.model_dump())
|
||||
|
||||
|
||||
@router.post("/category/config", summary="保存分类策略配置", response_model=schemas.Response)
|
||||
def save_category_config(config: CategoryConfig, _: User = Depends(get_current_active_superuser)):
|
||||
"""
|
||||
保存分类策略配置
|
||||
"""
|
||||
if MediaChain().save_category_config(config):
|
||||
return schemas.Response(success=True, message="保存成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message="保存失败")
|
||||
|
||||
|
||||
@router.get("/category", summary="查询自动分类配置", response_model=dict)
|
||||
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -172,7 +195,7 @@ async def seasons(mediaid: Optional[str] = None,
|
||||
tmdbid = int(mediaid[5:])
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
|
||||
if seasons_info:
|
||||
if season:
|
||||
if season is not None:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
if title:
|
||||
@@ -184,11 +207,11 @@ async def seasons(mediaid: Optional[str] = None,
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
|
||||
if seasons_info:
|
||||
if season:
|
||||
if season is not None:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
else:
|
||||
sea = season or 1
|
||||
sea = season if season is not None else 1
|
||||
return [schemas.MediaSeason(
|
||||
season_number=sea,
|
||||
poster_path=mediainfo.poster_path,
|
||||
|
||||
@@ -21,7 +21,9 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/play/{itemid:path}", summary="在线播放")
|
||||
def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response:
|
||||
def play_item(
|
||||
itemid: str, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
获取媒体服务器播放页面地址
|
||||
"""
|
||||
@@ -36,25 +38,27 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
|
||||
if item:
|
||||
play_url = media_chain.get_play_url(server=name, item_id=itemid)
|
||||
if play_url:
|
||||
return schemas.Response(success=True, data={
|
||||
"url": play_url
|
||||
})
|
||||
return schemas.Response(success=True, data={"url": play_url})
|
||||
return schemas.Response(success=False, message="未找到播放地址")
|
||||
|
||||
|
||||
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
|
||||
async def exists_local(title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response
|
||||
)
|
||||
async def exists_local(
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
判断本地是否存在
|
||||
"""
|
||||
meta = MetaInfo(title)
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
# 返回对象
|
||||
ret_info = {}
|
||||
@@ -63,36 +67,42 @@ async def exists_local(title: Optional[str] = None,
|
||||
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
|
||||
)
|
||||
if exist:
|
||||
ret_info = {
|
||||
"id": exist.item_id
|
||||
}
|
||||
return schemas.Response(success=True if exist else False, data={
|
||||
"item": ret_info
|
||||
})
|
||||
ret_info = {"id": exist.item_id}
|
||||
return schemas.Response(success=True if exist else False, data={"item": ret_info})
|
||||
|
||||
|
||||
@router.post("/exists_remote", summary="查询已存在的剧集信息(媒体服务器)", response_model=Dict[int, list])
|
||||
def exists(media_in: schemas.MediaInfo,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.post(
|
||||
"/exists_remote",
|
||||
summary="查询已存在的剧集信息(媒体服务器)",
|
||||
response_model=Dict[int, list],
|
||||
)
|
||||
def exists(
|
||||
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据媒体信息查询媒体库已存在的剧集信息
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(
|
||||
mediainfo=mediainfo
|
||||
)
|
||||
if not existsinfo:
|
||||
return []
|
||||
if media_in.season:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
}
|
||||
return {}
|
||||
if media_in.season is not None:
|
||||
return {media_in.season: existsinfo.seasons.get(media_in.season) or []}
|
||||
return existsinfo.seasons
|
||||
|
||||
|
||||
@router.post("/notexists", summary="查询媒体库缺失信息(媒体服务器)", response_model=List[schemas.NotExistMediaInfo])
|
||||
def not_exists(media_in: schemas.MediaInfo,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.post(
|
||||
"/notexists",
|
||||
summary="查询媒体库缺失信息(媒体服务器)",
|
||||
response_model=List[schemas.NotExistMediaInfo],
|
||||
)
|
||||
def not_exists(
|
||||
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据媒体信息查询缺失电影/剧集
|
||||
"""
|
||||
@@ -101,7 +111,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
mtype = MediaType(media_in.type) if media_in.type else None
|
||||
if mtype:
|
||||
meta.type = mtype
|
||||
if media_in.season:
|
||||
if media_in.season is not None:
|
||||
meta.begin_season = media_in.season
|
||||
meta.type = MediaType.TV
|
||||
if media_in.year:
|
||||
@@ -109,7 +119,9 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(
|
||||
meta=meta, mediainfo=mediainfo
|
||||
)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
# 电影已存在时返回空列表,不存在时返回空对像列表
|
||||
@@ -120,31 +132,61 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem])
|
||||
def latest(server: str, count: Optional[int] = 20,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem]
|
||||
)
|
||||
def latest(
|
||||
server: str,
|
||||
count: Optional[int] = 20,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器最新入库条目
|
||||
"""
|
||||
return MediaServerChain().latest(server=server, count=count, username=userinfo.username) or []
|
||||
return (
|
||||
MediaServerChain().latest(
|
||||
server=server, count=count, username=userinfo.username
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem])
|
||||
def playing(server: str, count: Optional[int] = 12,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem]
|
||||
)
|
||||
def playing(
|
||||
server: str,
|
||||
count: Optional[int] = 12,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器正在播放条目
|
||||
"""
|
||||
return MediaServerChain().playing(server=server, count=count, username=userinfo.username) or []
|
||||
return (
|
||||
MediaServerChain().playing(
|
||||
server=server, count=count, username=userinfo.username
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary])
|
||||
def library(server: str, hidden: Optional[bool] = False,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary]
|
||||
)
|
||||
def library(
|
||||
server: str,
|
||||
hidden: Optional[bool] = False,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器媒体库列表
|
||||
"""
|
||||
return MediaServerChain().librarys(server=server, username=userinfo.username, hidden=hidden) or []
|
||||
return (
|
||||
MediaServerChain().librarys(
|
||||
server=server, username=userinfo.username, hidden=hidden
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
|
||||
@@ -154,5 +196,9 @@ async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
mediaservers: List[dict] = SystemConfigOper().get(SystemConfigKey.MediaServers)
|
||||
if mediaservers:
|
||||
return [{"name": d.get("name"), "type": d.get("type")} for d in mediaservers if d.get("enabled")]
|
||||
return [
|
||||
{"name": d.get("name"), "type": d.get("type")}
|
||||
for d in mediaservers
|
||||
if d.get("enabled")
|
||||
]
|
||||
return []
|
||||
|
||||
@@ -86,7 +86,10 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int],
|
||||
if not client_configs:
|
||||
return "未找到对应的消息配置"
|
||||
client_config = next((config for config in client_configs if
|
||||
config.type == "wechat" and config.enabled and (not source or config.name == source)), None)
|
||||
config.type == "wechat"
|
||||
and config.enabled
|
||||
and config.config.get("WECHAT_MODE", "app") != "bot"
|
||||
and (not source or config.name == source)), None)
|
||||
if not client_config:
|
||||
return "未找到对应的消息配置"
|
||||
try:
|
||||
|
||||
498
app/api/endpoints/mfa.py
Normal file
498
app/api/endpoints/mfa.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
MFA (Multi-Factor Authentication) API 端点
|
||||
包含 OTP 和 PassKey 相关功能
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db import get_async_db
|
||||
from app.db.models.passkey import PassKey
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_user_async
|
||||
from app.helper.passkey import PassKeyHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
def _build_credential_list(passkeys: list[PassKey]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
构建凭证列表
|
||||
|
||||
:param passkeys: PassKey 列表
|
||||
:return: 凭证字典列表
|
||||
"""
|
||||
return [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
|
||||
def _extract_and_standardize_credential_id(credential: dict) -> str:
|
||||
"""
|
||||
从凭证中提取并标准化 credential_id
|
||||
|
||||
:param credential: 凭证字典
|
||||
:return: 标准化后的 credential_id
|
||||
:raises ValueError: 如果凭证无效
|
||||
"""
|
||||
credential_id_raw = credential.get('id') or credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise ValueError("无效的凭证")
|
||||
return PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
|
||||
def _verify_passkey_and_update(
|
||||
credential: dict,
|
||||
challenge: str,
|
||||
passkey: PassKey
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
验证 PassKey 并更新使用时间和签名计数
|
||||
|
||||
:param credential: 凭证字典
|
||||
:param challenge: 挑战值
|
||||
:param passkey: PassKey 对象
|
||||
:return: (验证是否成功, 新的签名计数)
|
||||
"""
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if success:
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
return success, new_sign_count
|
||||
|
||||
|
||||
async def _check_user_has_passkey(db: AsyncSession, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否有 PassKey
|
||||
|
||||
:param db: 数据库会话
|
||||
:param user_id: 用户 ID
|
||||
:return: 是否有 PassKey
|
||||
"""
|
||||
return bool(await PassKey.async_get_by_user_id(db=db, user_id=user_id))
|
||||
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
"""OTP验证请求"""
|
||||
uri: str
|
||||
otpPassword: str
|
||||
|
||||
class OtpDisableRequest(schemas.BaseModel):
|
||||
"""OTP禁用请求"""
|
||||
password: str
|
||||
|
||||
class PassKeyDeleteRequest(schemas.BaseModel):
|
||||
"""PassKey删除请求"""
|
||||
passkey_id: int
|
||||
password: str
|
||||
|
||||
# ==================== 通用 MFA 接口 ====================
|
||||
|
||||
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
|
||||
async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
"""
|
||||
检查指定用户是否启用了任何双重验证方式(OTP 或 PassKey)
|
||||
"""
|
||||
user: User = await User.async_get_by_name(db, username)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
|
||||
# 检查是否启用了OTP
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = await _check_user_has_passkey(db, user.id)
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
|
||||
|
||||
# ==================== OTP 相关接口 ====================
|
||||
|
||||
@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""生成 OTP 密钥及对应的 URI"""
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
|
||||
async def otp_verify(
|
||||
data: OtpVerifyRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
|
||||
if not OtpUtils.is_legal(data.uri, data.otpPassword):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
data: OtpDisableRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,默认不允许关闭 OTP,除非配置允许
|
||||
has_passkey = await _check_user_has_passkey(db, current_user.id)
|
||||
if has_passkey and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
# ==================== PassKey 相关接口 ====================
|
||||
|
||||
class PassKeyRegistrationStart(schemas.BaseModel):
|
||||
"""PassKey注册开始请求"""
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyRegistrationFinish(schemas.BaseModel):
|
||||
"""PassKey注册完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
name: str = "通行密钥"
|
||||
|
||||
|
||||
class PassKeyAuthenticationStart(schemas.BaseModel):
|
||||
"""PassKey认证开始请求"""
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class PassKeyAuthenticationFinish(schemas.BaseModel):
|
||||
"""PassKey认证完成请求"""
|
||||
credential: dict
|
||||
challenge: str
|
||||
|
||||
|
||||
@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_start(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""开始注册 PassKey - 生成注册选项"""
|
||||
try:
|
||||
# 安全检查:默认需要先启用 OTP,除非配置允许在未启用 OTP 时注册
|
||||
if not current_user.is_otp and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
|
||||
)
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = _build_credential_list(existing_passkeys) if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
user_id=current_user.id,
|
||||
username=current_user.name,
|
||||
display_name=current_user.settings.get('nickname') if current_user.settings else None,
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey注册选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成注册选项失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response)
|
||||
def passkey_register_finish(
|
||||
passkey_req: PassKeyRegistrationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""完成注册 PassKey - 验证并保存凭证"""
|
||||
try:
|
||||
# 验证注册响应
|
||||
credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge
|
||||
)
|
||||
|
||||
# 提取transports
|
||||
transports = None
|
||||
if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']:
|
||||
transports = ','.join(passkey_req.credential['response']['transports'])
|
||||
|
||||
# 保存到数据库
|
||||
passkey = PassKey(
|
||||
user_id=current_user.id,
|
||||
credential_id=credential_id,
|
||||
public_key=public_key,
|
||||
sign_count=sign_count,
|
||||
name=passkey_req.name or "通行密钥",
|
||||
aaguid=aaguid,
|
||||
transports=transports
|
||||
)
|
||||
passkey.create()
|
||||
|
||||
logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥注册成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"注册PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"注册失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response)
|
||||
def passkey_authenticate_start(
|
||||
passkey_req: PassKeyAuthenticationStart = Body(...)
|
||||
) -> Any:
|
||||
"""开始 PassKey 认证 - 生成认证选项"""
|
||||
try:
|
||||
existing_credentials = None
|
||||
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if user else None
|
||||
|
||||
if not user or not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
existing_credentials = _build_credential_list(existing_passkeys)
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
existing_credentials=existing_credentials
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
'options': options_json,
|
||||
'challenge': challenge
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token)
|
||||
def passkey_authenticate_finish(
|
||||
passkey_req: PassKeyAuthenticationFinish
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey认证失败,提供的凭证无效: {e}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 查找PassKey并获取用户
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id) if passkey else None
|
||||
if not passkey or not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
# 生成token
|
||||
level = SitesHelper().auth_level
|
||||
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
|
||||
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
level=level
|
||||
),
|
||||
token_type="bearer",
|
||||
super_user=user.is_superuser,
|
||||
user_id=user.id,
|
||||
user_name=user.name,
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
wizard=show_wizard
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
def passkey_list(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""获取当前用户的所有 PassKey"""
|
||||
try:
|
||||
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
|
||||
key_list = [
|
||||
{
|
||||
'id': pk.id,
|
||||
'name': pk.name,
|
||||
'created_at': pk.created_at.isoformat() if pk.created_at else None,
|
||||
'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None,
|
||||
'aaguid': pk.aaguid,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=key_list
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取PassKey列表失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"获取列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
|
||||
async def passkey_delete(
|
||||
data: PassKeyDeleteRequest,
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""删除指定的 PassKey"""
|
||||
try:
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
|
||||
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥已删除"
|
||||
)
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或无权删除"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"删除PassKey失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"删除失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response)
|
||||
def passkey_verify_mfa(
|
||||
passkey_req: PassKeyAuthenticationFinish,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey二次验证失败,提供的凭证无效: {e}")
|
||||
return schemas.Response(success=False, message="验证失败")
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey or passkey.user_id != current_user.id:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="二次验证成功"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="验证失败"
|
||||
)
|
||||
@@ -260,6 +260,14 @@ async def remotes(token: str) -> Any:
|
||||
return PluginManager().get_plugin_remotes()
|
||||
|
||||
|
||||
@router.get("/sidebar_nav", summary="获取插件侧栏导航项", response_model=List[schemas.PluginSidebarNavItem])
|
||||
def plugin_sidebar_nav(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
聚合已启用 Vue 插件声明的侧栏入口(get_sidebar_nav),供前端主界面侧栏展示。
|
||||
"""
|
||||
return PluginManager().get_plugin_sidebar_nav()
|
||||
|
||||
|
||||
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
|
||||
def plugin_form(plugin_id: str,
|
||||
_: User = Depends(get_current_active_superuser)) -> dict:
|
||||
@@ -360,7 +368,18 @@ async def plugin_static_file(plugin_id: str, filepath: str):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
plugin_base_dir = AsyncPath(settings.ROOT_PATH) / "app" / "plugins" / plugin_id.lower()
|
||||
plugin_file_path = plugin_base_dir / filepath
|
||||
plugin_file_path = plugin_base_dir / filepath.lstrip('/')
|
||||
|
||||
try:
|
||||
resolved_base = await plugin_base_dir.resolve()
|
||||
resolved_file = await plugin_file_path.resolve()
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid path")
|
||||
|
||||
if not resolved_file.is_relative_to(resolved_base):
|
||||
logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
if not await plugin_file_path.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在")
|
||||
if not await plugin_file_path.is_file():
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.ai_recommend import AIRecommendChain
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.log import logger
|
||||
from app.schemas import MediaRecognizeConvertEventData
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
|
||||
@@ -36,6 +38,9 @@ async def search_by_id(mediaid: str,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
# 取消正在运行的AI推荐(会清除数据库缓存)
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
if mtype:
|
||||
media_type = MediaType(mtype)
|
||||
else:
|
||||
@@ -159,6 +164,9 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
# 取消正在运行的AI推荐并清除数据库缓存
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
@@ -167,3 +175,87 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
if not torrents:
|
||||
return schemas.Response(success=False, message="未搜索到任何资源")
|
||||
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
|
||||
|
||||
|
||||
@router.post("/recommend", summary="AI推荐资源", response_model=schemas.Response)
|
||||
async def recommend_search_results(
|
||||
filtered_indices: Optional[List[int]] = Body(None, embed=True, description="筛选后的索引列表"),
|
||||
check_only: bool = Body(False, embed=True, description="仅检查状态,不启动新任务"),
|
||||
force: bool = Body(False, embed=True, description="强制重新推荐,清除旧结果"),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
AI推荐资源 - 轮询接口
|
||||
前端轮询此接口,发送筛选后的索引(如果有筛选)
|
||||
后端根据请求变化自动取消旧任务并启动新任务
|
||||
|
||||
参数:
|
||||
- filtered_indices: 筛选后的索引列表(可选,为空或不提供时使用所有结果)
|
||||
- check_only: 仅检查状态(首次打开页面时使用,避免触发不必要的重新推理)
|
||||
- force: 强制重新推荐(清除旧结果并重新启动)
|
||||
|
||||
返回数据结构:
|
||||
{
|
||||
"success": bool,
|
||||
"message": string, // 错误信息(仅在错误时存在)
|
||||
"data": {
|
||||
"status": string, // 状态: disabled | idle | running | completed | error
|
||||
"results": array // 推荐结果(仅status=completed时存在)
|
||||
}
|
||||
}
|
||||
"""
|
||||
# 从缓存获取上次搜索结果
|
||||
results = await SearchChain().async_last_search_results() or []
|
||||
if not results:
|
||||
return schemas.Response(success=False, message="没有可用的搜索结果", data={
|
||||
"status": "error"
|
||||
})
|
||||
|
||||
recommend_chain = AIRecommendChain()
|
||||
|
||||
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
|
||||
if force:
|
||||
# 检查功能是否启用
|
||||
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "disabled"
|
||||
})
|
||||
logger.info("收到新推荐请求,清除旧结果并启动新任务")
|
||||
recommend_chain.cancel_ai_recommend()
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 直接返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测)
|
||||
if check_only:
|
||||
# 返回当前运行状态,不做任何任务启动或取消操作
|
||||
current_status = recommend_chain.get_current_status_only()
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if current_status.get("status") == "error":
|
||||
error_msg = current_status.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=current_status)
|
||||
return schemas.Response(success=True, data=current_status)
|
||||
|
||||
# 获取当前状态(会检测请求是否变化)
|
||||
status_data = recommend_chain.get_status(filtered_indices, len(results))
|
||||
|
||||
# 如果功能未启用,直接返回禁用状态
|
||||
if status_data.get("status") == "disabled":
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
# 如果是空闲状态,启动新任务
|
||||
if status_data["status"] == "idle":
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 立即返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if status_data.get("status") == "error":
|
||||
error_msg = status_data.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=status_data)
|
||||
|
||||
# 返回当前状态
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
@@ -92,10 +92,14 @@ async def update_site(
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
site_in.domain = StringUtils.get_url_domain(site_in.url)
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
"site_id": site_in.id,
|
||||
"domain": site_in.domain,
|
||||
"name": site_in.name,
|
||||
"site_url": site_in.url
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -31,6 +31,17 @@ def qrcode(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return schemas.Response(success=False, message=errmsg)
|
||||
|
||||
|
||||
@router.get("/auth_url/{name}", summary="获取 OAuth2 授权 URL", response_model=schemas.Response)
|
||||
def auth_url(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取 OAuth2 授权 URL
|
||||
"""
|
||||
auth_data, errmsg = StorageChain().generate_auth_url(name)
|
||||
if auth_data:
|
||||
return schemas.Response(success=True, data=auth_data)
|
||||
return schemas.Response(success=False, message=errmsg)
|
||||
|
||||
|
||||
@router.get("/check/{name}", summary="二维码登录确认", response_model=schemas.Response)
|
||||
def check(name: str, ck: Optional[str] = None, t: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@@ -83,7 +94,7 @@ def list_files(fileitem: schemas.FileItem,
|
||||
if sort == "name":
|
||||
file_list.sort(key=lambda x: StringUtils.natural_sort_key(x.name or ""))
|
||||
else:
|
||||
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
|
||||
file_list.sort(key=lambda x: x.modify_time or -math.inf, reverse=True)
|
||||
return file_list
|
||||
|
||||
|
||||
@@ -167,7 +178,7 @@ def rename(fileitem: schemas.FileItem,
|
||||
# 重命名目录内文件
|
||||
if recursive:
|
||||
transferchain = TransferChain()
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
# 递归修改目录内文件(智能识别命名)
|
||||
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
|
||||
if sub_files:
|
||||
|
||||
@@ -199,7 +199,7 @@ async def subscribe_mediaid(
|
||||
# 使用名称检查订阅
|
||||
if title_check and title:
|
||||
meta = MetaInfo(title)
|
||||
if season:
|
||||
if season is not None:
|
||||
meta.begin_season = season
|
||||
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
|
||||
|
||||
|
||||
@@ -23,8 +23,11 @@ from app.core.module import ModuleManager
|
||||
from app.core.security import verify_apitoken, verify_resource_token, verify_token
|
||||
from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.db.user_oper import (
|
||||
get_current_active_superuser,
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
@@ -47,12 +50,13 @@ router = APIRouter()
|
||||
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None,
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
"""
|
||||
@@ -83,72 +87,117 @@ async def fetch_image(
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
图片代理,可选是否使用代理服务器,支持 HTTP 缓存
|
||||
"""
|
||||
# 媒体服务器添加图片代理支持
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
hosts = [
|
||||
config.config.get("host")
|
||||
for config in MediaServerHelper().get_configs().values()
|
||||
if config and config.config and config.config.get("host")
|
||||
]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
return await fetch_image(
|
||||
url=imgurl,
|
||||
proxy=proxy,
|
||||
use_cache=cache,
|
||||
cookies=cookies,
|
||||
if_none_match=if_none_match,
|
||||
allowed_domains=allowed_domains,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cache/image", summary="图片缓存")
|
||||
async def cache_img(
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
return await fetch_image(
|
||||
url=url, use_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match
|
||||
)
|
||||
|
||||
|
||||
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
|
||||
def get_global_setting(token: str):
|
||||
"""
|
||||
查询非敏感系统设置(默认鉴权)
|
||||
仅包含登录前UI初始化必需的字段
|
||||
"""
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# FIXME: 新增敏感配置项时要在此处添加排除项
|
||||
# 白名单模式,仅包含登录前UI初始化必需的字段
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
|
||||
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
|
||||
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
}
|
||||
)
|
||||
# 追加版本信息(用于版本检查)
|
||||
info.update(
|
||||
{
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION,
|
||||
}
|
||||
)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/global/user", summary="查询用户相关系统设置", response_model=schemas.Response
|
||||
)
|
||||
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询用户相关系统设置(登录后获取)
|
||||
包含业务功能相关的配置和用户权限信息
|
||||
"""
|
||||
# 业务功能相关的配置字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP",
|
||||
}
|
||||
)
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
info.update(
|
||||
{
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
}
|
||||
)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
|
||||
@@ -156,22 +205,22 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
info = settings.model_dump(exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"})
|
||||
info.update(
|
||||
{
|
||||
"VERSION": APP_VERSION,
|
||||
"AUTH_VERSION": SitesHelper().auth_version,
|
||||
"INDEXER_VERSION": SitesHelper().indexer_version,
|
||||
"FRONTEND_VERSION": SystemChain().get_frontend_version(),
|
||||
}
|
||||
)
|
||||
info.update({
|
||||
"VERSION": APP_VERSION,
|
||||
"AUTH_VERSION": SitesHelper().auth_version,
|
||||
"INDEXER_VERSION": SitesHelper().indexer_version,
|
||||
"FRONTEND_VERSION": SystemChain().get_frontend_version()
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
|
||||
async def set_env_setting(env: dict,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
async def set_env_setting(
|
||||
env: dict, _: User = Depends(get_current_active_superuser_async)
|
||||
):
|
||||
"""
|
||||
更新系统环境变量(仅管理员)
|
||||
"""
|
||||
@@ -184,30 +233,31 @@ async def set_env_setting(env: dict,
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"{', '.join([v[1] for v in failed_updates.values()])}",
|
||||
data={
|
||||
"success_updates": success_updates,
|
||||
"failed_updates": failed_updates
|
||||
}
|
||||
data={"success_updates": success_updates, "failed_updates": failed_updates},
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(
|
||||
key=success_updates.keys(), change_type="update"
|
||||
),
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="所有配置项更新成功",
|
||||
data={
|
||||
"success_updates": success_updates
|
||||
}
|
||||
data={"success_updates": success_updates},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/progress/{process_type}", summary="实时进度")
|
||||
async def get_progress(request: Request, process_type: str, _: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_progress(
|
||||
request: Request,
|
||||
process_type: str,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取处理进度,返回格式为SSE
|
||||
"""
|
||||
@@ -228,8 +278,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
|
||||
|
||||
|
||||
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
||||
async def get_setting(key: str,
|
||||
_: User = Depends(get_current_active_user_async)):
|
||||
async def get_setting(key: str, _: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统设置(仅管理员)
|
||||
"""
|
||||
@@ -237,16 +286,14 @@ async def get_setting(key: str,
|
||||
value = getattr(settings, key)
|
||||
else:
|
||||
value = SystemConfigOper().get(key)
|
||||
return schemas.Response(success=True, data={
|
||||
"value": value
|
||||
})
|
||||
return schemas.Response(success=True, data={"value": value})
|
||||
|
||||
|
||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||
async def set_setting(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
更新系统设置(仅管理员)
|
||||
@@ -255,11 +302,10 @@ async def set_setting(
|
||||
success, message = settings.update_setting(key=key, value=value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(key=key, value=value, change_type="update"),
|
||||
)
|
||||
elif success is None:
|
||||
success = True
|
||||
return schemas.Response(success=success, message=message)
|
||||
@@ -270,31 +316,40 @@ async def set_setting(
|
||||
success = await SystemConfigOper().async_set(key, value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(key=key, value=value, change_type="update"),
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)):
|
||||
async def get_llm_models(
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = LLMHelper().get_models(provider, api_key, base_url)
|
||||
models = await asyncio.to_thread(
|
||||
LLMHelper().get_models, provider, api_key, base_url
|
||||
)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(request: Request, role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_message(
|
||||
request: Request,
|
||||
role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统消息,返回格式为SSE
|
||||
"""
|
||||
@@ -315,8 +370,12 @@ async def get_message(request: Request, role: Optional[str] = "system",
|
||||
|
||||
|
||||
@router.get("/logging", summary="实时日志")
|
||||
async def get_logging(request: Request, length: Optional[int] = 50, logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_logging(
|
||||
request: Request,
|
||||
length: Optional[int] = 50,
|
||||
logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统日志
|
||||
length = -1 时, 返回text/plain
|
||||
@@ -325,7 +384,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
base_path = AsyncPath(settings.LOG_PATH)
|
||||
log_path = base_path / logfile
|
||||
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
|
||||
if not await SecurityUtils.async_is_safe_path(
|
||||
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
if not await log_path.exists() or not await log_path.is_file():
|
||||
@@ -340,7 +401,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
file_size = file_stat.st_size
|
||||
|
||||
# 读取历史日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 优化大文件读取策略
|
||||
if file_size > 100 * 1024:
|
||||
# 只读取最后100KB的内容
|
||||
@@ -349,9 +412,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
await f.seek(position)
|
||||
content = await f.read()
|
||||
# 找到第一个完整的行
|
||||
first_newline = content.find('\n')
|
||||
first_newline = content.find("\n")
|
||||
if first_newline != -1:
|
||||
content = content[first_newline + 1:]
|
||||
content = content[first_newline + 1 :]
|
||||
else:
|
||||
# 小文件直接读取全部内容
|
||||
content = await f.read()
|
||||
@@ -359,7 +422,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
# 按行分割并添加到队列,只保留非空行
|
||||
lines = [line.strip() for line in content.splitlines() if line.strip()]
|
||||
# 只取最后N行
|
||||
for line in lines[-max(length, 50):]:
|
||||
for line in lines[-max(length, 50) :]:
|
||||
lines_queue.append(line)
|
||||
|
||||
# 输出历史日志
|
||||
@@ -367,7 +430,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
yield f"data: {line}\n\n"
|
||||
|
||||
# 实时监听新日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 移动文件指针到文件末尾,继续监听新增内容
|
||||
await f.seek(0, 2)
|
||||
# 记录初始文件大小
|
||||
@@ -404,7 +469,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
return Response(content="日志文件不存在!", media_type="text/plain")
|
||||
try:
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as file:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as file:
|
||||
text = await file.read()
|
||||
# 倒序输出
|
||||
text = "\n".join(text.split("\n")[::-1])
|
||||
@@ -416,13 +483,16 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
return StreamingResponse(log_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/versions", summary="查询Github所有Release版本", response_model=schemas.Response
|
||||
)
|
||||
async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询Github所有Release版本
|
||||
"""
|
||||
version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
|
||||
f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
version_res = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY, headers=settings.GITHUB_HEADERS
|
||||
).get_res(f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
if version_res:
|
||||
ver_json = version_res.json()
|
||||
if ver_json:
|
||||
@@ -431,10 +501,12 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
|
||||
|
||||
@router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response)
|
||||
def ruletest(title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)):
|
||||
def ruletest(
|
||||
title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
过滤规则测试,规则类型 1-订阅,2-洗版,3-搜索
|
||||
"""
|
||||
@@ -445,7 +517,9 @@ def ruletest(title: str,
|
||||
# 查询规则组详情
|
||||
rulegroup = RuleHelper().get_rule_group(rulegroup_name)
|
||||
if not rulegroup:
|
||||
return schemas.Response(success=False, message=f"过滤规则组 {rulegroup_name} 不存在!")
|
||||
return schemas.Response(
|
||||
success=False, message=f"过滤规则组 {rulegroup_name} 不存在!"
|
||||
)
|
||||
|
||||
# 根据标题查询媒体信息
|
||||
media_info = SearchChain().recognize_media(MetaInfo(title=title, subtitle=subtitle))
|
||||
@@ -453,21 +527,22 @@ def ruletest(title: str,
|
||||
return schemas.Response(success=False, message="未识别到媒体信息!")
|
||||
|
||||
# 过滤
|
||||
result = SearchChain().filter_torrents(rule_groups=[rulegroup.name],
|
||||
torrent_list=[torrent], mediainfo=media_info)
|
||||
result = SearchChain().filter_torrents(
|
||||
rule_groups=[rulegroup.name], torrent_list=[torrent], mediainfo=media_info
|
||||
)
|
||||
if not result:
|
||||
return schemas.Response(success=False, message="不符合过滤规则!")
|
||||
return schemas.Response(success=True, data={
|
||||
"priority": 100 - result[0].pri_order + 1
|
||||
})
|
||||
return schemas.Response(
|
||||
success=True, data={"priority": 100 - result[0].pri_order + 1}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nettest", summary="测试网络连通性")
|
||||
async def nettest(
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
测试网络连通性
|
||||
@@ -539,21 +614,26 @@ async def nettest(
|
||||
return schemas.Response(success=False, message=message, data={"time": time})
|
||||
|
||||
|
||||
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response
|
||||
)
|
||||
def modulelist(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询已加载的模块ID列表
|
||||
"""
|
||||
modules = [{
|
||||
"id": k,
|
||||
"name": v.get_name(),
|
||||
} for k, v in ModuleManager().get_modules().items()]
|
||||
return schemas.Response(success=True, data={
|
||||
"modules": modules
|
||||
})
|
||||
modules = [
|
||||
{
|
||||
"id": k,
|
||||
"name": v.get_name(),
|
||||
}
|
||||
for k, v in ModuleManager().get_modules().items()
|
||||
]
|
||||
return schemas.Response(success=True, data={"modules": modules})
|
||||
|
||||
|
||||
@router.get("/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response
|
||||
)
|
||||
def moduletest(moduleid: str, _: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
模块可用性测试接口
|
||||
@@ -577,25 +657,31 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
|
||||
@router.get("/runscheduler", summary="运行服务", response_model=schemas.Response)
|
||||
def run_scheduler(jobid: str,
|
||||
_: User = Depends(get_current_active_superuser)):
|
||||
def run_scheduler(jobid: str, _: User = Depends(get_current_active_superuser)):
|
||||
"""
|
||||
执行命令(仅管理员)
|
||||
"""
|
||||
if not jobid:
|
||||
return schemas.Response(success=False, message="命令不能为空!")
|
||||
Scheduler().start(jobid)
|
||||
if jobid in {"recommend_refresh", "cookiecloud"}:
|
||||
Scheduler().start(jobid, manual=True)
|
||||
else:
|
||||
Scheduler().start(jobid)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response)
|
||||
def run_scheduler2(jobid: str,
|
||||
_: Annotated[str, Depends(verify_apitoken)]):
|
||||
@router.get(
|
||||
"/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response
|
||||
)
|
||||
def run_scheduler2(jobid: str, _: Annotated[str, Depends(verify_apitoken)]):
|
||||
"""
|
||||
执行命令(API_TOKEN认证)
|
||||
"""
|
||||
if not jobid:
|
||||
return schemas.Response(success=False, message="命令不能为空!")
|
||||
|
||||
Scheduler().start(jobid)
|
||||
if jobid in {"recommend_refresh", "cookiecloud"}:
|
||||
Scheduler().start(jobid, manual=True)
|
||||
else:
|
||||
Scheduler().start(jobid)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -93,6 +93,8 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
:param _: Token校验
|
||||
"""
|
||||
force = False
|
||||
downloader = None
|
||||
download_hash = None
|
||||
target_path = Path(transer_item.target_path) if transer_item.target_path else None
|
||||
if transer_item.logid:
|
||||
# 查询历史记录
|
||||
@@ -101,6 +103,8 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
return schemas.Response(success=False, message=f"整理记录不存在,ID:{transer_item.logid}")
|
||||
# 强制转移
|
||||
force = True
|
||||
downloader = history.downloader
|
||||
download_hash = history.download_hash
|
||||
if history.status and ("move" in history.mode):
|
||||
# 重新整理成功的转移,则使用成功的 dest 做 in_path
|
||||
src_fileitem = FileItem(**history.dest_fileitem)
|
||||
@@ -121,6 +125,7 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
transer_item.tmdbid = int(history.tmdbid) if history.tmdbid else transer_item.tmdbid
|
||||
transer_item.doubanid = str(history.doubanid) if history.doubanid else transer_item.doubanid
|
||||
transer_item.season = int(str(history.seasons).replace("S", "")) if history.seasons else transer_item.season
|
||||
transer_item.episode_group = history.episode_group or transer_item.episode_group
|
||||
if history.episodes:
|
||||
if "-" in str(history.episodes):
|
||||
# E01-E03多集合并
|
||||
@@ -138,8 +143,14 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
else:
|
||||
return schemas.Response(success=False, message=f"缺少参数")
|
||||
|
||||
# 类型
|
||||
mtype = MediaType(transer_item.type_name) if transer_item.type_name else None
|
||||
# 类型(“自动/auto/none”按未指定处理)
|
||||
mtype = None
|
||||
type_name = str(transer_item.type_name).strip() if transer_item.type_name else ""
|
||||
if type_name and type_name.lower() not in {"自动", "auto", "none"}:
|
||||
try:
|
||||
mtype = MediaType(type_name)
|
||||
except ValueError:
|
||||
return schemas.Response(success=False, message=f"不支持的媒体类型:{type_name}")
|
||||
# 自定义格式
|
||||
epformat = None
|
||||
if transer_item.episode_offset or transer_item.episode_part \
|
||||
@@ -167,7 +178,9 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
library_type_folder=transer_item.library_type_folder,
|
||||
library_category_folder=transer_item.library_category_folder,
|
||||
force=force,
|
||||
background=background
|
||||
background=background,
|
||||
downloader=downloader,
|
||||
download_hash=download_hash
|
||||
)
|
||||
# 失败
|
||||
if not state:
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.db.models.user import User
|
||||
from app.db.user_oper import get_current_active_superuser_async, \
|
||||
get_current_active_user_async, get_current_active_user
|
||||
from app.db.userconfig_oper import UserConfigOper
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -111,45 +110,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db),
|
||||
return schemas.Response(success=True, message=file.filename)
|
||||
|
||||
|
||||
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
async def otp_judge(
|
||||
data: dict,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
|
||||
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
|
||||
user: User = await User.async_get_by_name(db, userid)
|
||||
if not user:
|
||||
return schemas.Response(success=False)
|
||||
return schemas.Response(success=user.is_otp)
|
||||
|
||||
|
||||
@router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response)
|
||||
def get_config(key: str,
|
||||
current_user: User = Depends(get_current_active_user)):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
from anyio import Path as AsyncPath
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -128,9 +128,12 @@ async def get_cookie(
|
||||
@cookie_router.post("/get/{uuid}")
|
||||
async def post_cookie(
|
||||
uuid: Annotated[str, Path(min_length=5, pattern="^[a-zA-Z0-9]+$")],
|
||||
request: schemas.CookiePassword):
|
||||
request: Optional[schemas.CookiePassword] = Body(None)):
|
||||
"""
|
||||
POST 下载加密数据
|
||||
"""
|
||||
data = await load_encrypt_data(uuid)
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
if request is not None:
|
||||
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
|
||||
else:
|
||||
return data
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
318
app/chain/ai_recommend.py
Normal file
318
app/chain/ai_recommend.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class AIRecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
AI推荐处理链,单例运行
|
||||
用于基于搜索结果的AI智能推荐
|
||||
"""
|
||||
|
||||
# 缓存文件名
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
# AI推荐状态
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_request_hash: Optional[str] = None # 当前请求的哈希值
|
||||
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存(索引列表)
|
||||
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
|
||||
|
||||
@staticmethod
|
||||
def _calculate_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算请求的哈希值,用于判断请求是否变化
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
def _build_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典
|
||||
:return: 状态字典
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if self._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
# 尝试从数据库加载缓存
|
||||
if self._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
self._ai_recommend_result = cached_indices
|
||||
|
||||
# 只要有结果,始终返回completed状态和数据
|
||||
if self._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": self._ai_recommend_result}
|
||||
|
||||
if self._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": self._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前状态(不校验hash,用于check_only模式)
|
||||
"""
|
||||
return self._build_status()
|
||||
|
||||
def get_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态并检查请求是否变化(用于首次请求或force模式)
|
||||
如果请求变化(筛选条件变化),返回idle状态
|
||||
"""
|
||||
# 计算当前请求的hash
|
||||
request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 检查请求是否变化
|
||||
is_same_request = request_hash == self._current_request_hash
|
||||
|
||||
# 如果请求变化了(筛选条件改变),返回idle状态
|
||||
if not is_same_request:
|
||||
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
|
||||
|
||||
# 请求未变化,返回当前实际状态
|
||||
return self._build_status()
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
|
||||
"""
|
||||
AI推荐
|
||||
:param items: 候选资源列表(JSON字符串格式)
|
||||
:param preference: 用户偏好(可选)
|
||||
:return: AI返回的推荐结果
|
||||
"""
|
||||
# 设置运行状态
|
||||
self._ai_recommend_running = True
|
||||
try:
|
||||
# 导入LLMHelper
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
# 获取LLM实例
|
||||
llm = LLMHelper.get_llm()
|
||||
|
||||
# 构建提示词
|
||||
user_preference = (
|
||||
preference
|
||||
or settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
|
||||
# 添加指令
|
||||
instruction = """
|
||||
Task: Select the best matching items from the list based on user preferences.
|
||||
|
||||
Each item contains:
|
||||
- index: Item number
|
||||
- title: Full torrent title
|
||||
- size: File size
|
||||
- seeders: Number of seeders
|
||||
|
||||
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
|
||||
"""
|
||||
message = (
|
||||
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
|
||||
+ "\n".join(items)
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
response = await llm.ainvoke(message)
|
||||
return response.content
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"AI推荐配置错误: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
raise
|
||||
finally:
|
||||
# 清除运行状态
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
def is_ai_recommend_running(self) -> bool:
|
||||
"""
|
||||
检查AI推荐是否正在运行
|
||||
"""
|
||||
return self._ai_recommend_running
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消正在运行的AI推荐任务
|
||||
"""
|
||||
if self._ai_recommend_task and not self._ai_recommend_task.done():
|
||||
self._ai_recommend_task.cancel()
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
self._current_request_hash = None
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务
|
||||
:param filtered_indices: 筛选后的索引列表
|
||||
:param search_results_count: 搜索结果总数
|
||||
:param results: 搜索结果列表
|
||||
"""
|
||||
# 防护检查:确保AI推荐功能已启用
|
||||
if not self.is_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
# 计算新请求的哈希值
|
||||
new_request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 如果请求变化了,取消旧任务
|
||||
if new_request_hash != self._current_request_hash:
|
||||
self.cancel_ai_recommend()
|
||||
|
||||
# 更新请求哈希值
|
||||
self._current_request_hash = new_request_hash
|
||||
|
||||
# 重置状态
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
|
||||
# 启动新任务
|
||||
async def run_recommend():
|
||||
# 获取当前任务对象,用于在finally中比对
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
self._ai_recommend_running = True
|
||||
|
||||
# 准备数据
|
||||
items = []
|
||||
valid_indices = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
|
||||
if filtered_indices is not None and len(filtered_indices) > 0:
|
||||
results_to_process = [
|
||||
results[i]
|
||||
for i in filtered_indices
|
||||
if 0 <= i < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for i, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(i)
|
||||
|
||||
item_info = {
|
||||
"index": i,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
if not items:
|
||||
self._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
# 调用AI推荐
|
||||
ai_response = await self.async_ai_recommend(items)
|
||||
|
||||
# 解析AI返回的索引
|
||||
try:
|
||||
# 使用正则提取JSON数组(非贪婪模式,避免匹配多个数组)
|
||||
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(ai_response)
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
# 映射回原始索引
|
||||
if filtered_indices:
|
||||
original_indices = [
|
||||
filtered_indices[valid_indices[i]]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= filtered_indices[valid_indices[i]] < len(results)
|
||||
]
|
||||
else:
|
||||
original_indices = [
|
||||
valid_indices[i]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= valid_indices[i] < len(results)
|
||||
]
|
||||
|
||||
# 只返回索引列表,不返回完整数据
|
||||
self._ai_recommend_result = original_indices
|
||||
|
||||
# 保存到数据库
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
|
||||
)
|
||||
self._ai_recommend_error = str(e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"AI推荐任务失败: {e}")
|
||||
self._ai_recommend_error = str(e)
|
||||
finally:
|
||||
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
|
||||
# 如果任务被取消并启动了新任务,self._ai_recommend_task 已经指向新任务,不应重置
|
||||
if self._ai_recommend_task == current_task:
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
# 创建并启动任务
|
||||
self._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
@@ -152,7 +152,8 @@ class DownloadChain(ChainBase):
|
||||
save_path: Optional[str] = None,
|
||||
userid: Union[str, int] = None,
|
||||
username: Optional[str] = None,
|
||||
label: Optional[str] = None) -> Optional[str]:
|
||||
label: Optional[str] = None,
|
||||
return_detail: bool = False) -> Union[Optional[str], Tuple[Optional[str], Optional[str]]]:
|
||||
"""
|
||||
下载及发送通知
|
||||
:param context: 资源上下文
|
||||
@@ -166,6 +167,8 @@ class DownloadChain(ChainBase):
|
||||
:param userid: 用户ID
|
||||
:param username: 调用下载的用户名/插件名
|
||||
:param label: 自定义标签
|
||||
:param return_detail: 是否返回详细结果;False 时返回下载任务 hash 或 None,True 时返回 (hash, error_msg)
|
||||
:return: return_detail=False 时返回下载任务 hash 或 None;return_detail=True 时返回 (hash, error_msg)
|
||||
"""
|
||||
_torrent = context.torrent_info
|
||||
_media = context.media_info
|
||||
@@ -195,7 +198,7 @@ class DownloadChain(ChainBase):
|
||||
logger.debug(
|
||||
f"Resource download canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
return None
|
||||
return (None, "下载被事件取消") if return_detail else None
|
||||
# 如果事件修改了下载路径,使用新路径
|
||||
if event_data.options and event_data.options.get("save_path"):
|
||||
save_path = event_data.options.get("save_path")
|
||||
@@ -227,7 +230,7 @@ class DownloadChain(ChainBase):
|
||||
torrent_content = cache_backend.get(torrent_file.as_posix(), region="torrents")
|
||||
|
||||
if not torrent_content:
|
||||
return None
|
||||
return (None, "下载种子内容为空") if return_detail else None
|
||||
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
@@ -259,7 +262,7 @@ class DownloadChain(ChainBase):
|
||||
logger.error(f"未找到下载目录:{_media.type.value} {_media.title_year}")
|
||||
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
|
||||
title="下载失败", role="system")
|
||||
return None
|
||||
return (None, "未找到下载目录") if return_detail else None
|
||||
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
|
||||
download_dir = Path(fileURI.uri)
|
||||
|
||||
@@ -327,9 +330,10 @@ class DownloadChain(ChainBase):
|
||||
if not file_meta.begin_episode \
|
||||
or file_meta.begin_episode not in episodes:
|
||||
continue
|
||||
# 只处理视频格式
|
||||
# 只处理音视频、字幕格式
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if not Path(file).suffix \
|
||||
or Path(file).suffix.lower() not in settings.RMT_MEDIAEXT:
|
||||
or Path(file).suffix.lower() not in media_exts:
|
||||
continue
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
@@ -387,6 +391,8 @@ class DownloadChain(ChainBase):
|
||||
f"错误信息:{error_msg}",
|
||||
image=_media.get_message_image(),
|
||||
userid=userid))
|
||||
if return_detail:
|
||||
return _hash, error_msg
|
||||
return _hash
|
||||
|
||||
def batch_download(self,
|
||||
|
||||
1705
app/chain/media.py
1705
app/chain/media.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,7 @@ from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, fresh
|
||||
from app.core.config import settings, global_vars
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
@@ -27,9 +27,11 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
# 推荐缓存区域
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
def refresh_recommend(self):
|
||||
def refresh_recommend(self, manual: bool = False):
|
||||
"""
|
||||
刷新推荐
|
||||
|
||||
:param manual: 手动触发
|
||||
"""
|
||||
logger.debug("Starting to refresh Recommend data.")
|
||||
|
||||
@@ -62,7 +64,9 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
if method in methods_finished:
|
||||
continue
|
||||
logger.debug(f"Fetch {method.__name__} data for page {page}.")
|
||||
data = method(page=page)
|
||||
# 手动触发的刷新,总是需要获取最新数据
|
||||
with fresh(manual):
|
||||
data = method(page=page)
|
||||
if not data:
|
||||
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
|
||||
methods_finished.add(method)
|
||||
@@ -90,7 +94,6 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
poster_path = data.get("poster_path")
|
||||
if poster_path:
|
||||
poster_url = poster_path.replace("original", "w500")
|
||||
logger.debug(f"Caching poster image: {poster_url}")
|
||||
self.__fetch_and_save_image(poster_url)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -29,6 +29,7 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
|
||||
__result_temp_file = "__search_result__"
|
||||
__ai_result_temp_file = "__ai_search_result__"
|
||||
|
||||
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -48,7 +49,7 @@ class SearchChain(ChainBase):
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
@@ -98,6 +99,18 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_ai_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次AI推荐结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__ai_result_temp_file)
|
||||
|
||||
async def async_save_ai_results(self, results: List[Context]):
|
||||
"""
|
||||
异步保存AI推荐结果
|
||||
"""
|
||||
await self.async_save_cache(results, self.__ai_result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
@@ -116,7 +129,7 @@ class SearchChain(ChainBase):
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
@@ -168,7 +181,7 @@ class SearchChain(ChainBase):
|
||||
# 过滤剧集
|
||||
season_episodes = {sea: info.episodes
|
||||
for sea, info in no_exists[mediakey].items()}
|
||||
elif mediainfo.season:
|
||||
elif mediainfo.season is not None:
|
||||
# 豆瓣只搜索当前季
|
||||
season_episodes = {mediainfo.season: []}
|
||||
else:
|
||||
@@ -267,7 +280,7 @@ class SearchChain(ChainBase):
|
||||
logger.info(f"种子名称应用识别词后发生改变:{torrent.title} => {torrent_meta.org_string}")
|
||||
# 季集数过滤
|
||||
if season_episodes \
|
||||
and not torrenthelper.match_season_episodes(torrent=torrent,
|
||||
and not TorrentHelper.match_season_episodes(torrent=torrent,
|
||||
meta=torrent_meta,
|
||||
season_episodes=season_episodes):
|
||||
continue
|
||||
|
||||
@@ -44,6 +44,7 @@ class SiteChain(ChainBase):
|
||||
"star-space.net": self.__indexphp_test,
|
||||
"yemapt.org": self.__yema_test,
|
||||
"hddolby.com": self.__hddolby_test,
|
||||
"rousi.pro": self.__rousi_test,
|
||||
}
|
||||
|
||||
def refresh_userdata(self, site: dict = None) -> Optional[SiteUserData]:
|
||||
@@ -249,6 +250,32 @@ class SiteChain(ChainBase):
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __rousi_test(site: Site) -> Tuple[bool, str]:
|
||||
"""
|
||||
判断站点是否已经登陆:rousi
|
||||
"""
|
||||
url = f"https://{StringUtils.get_url_domain(site.url)}/api/v1/profile"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {site.apikey}",
|
||||
}
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
proxies=settings.PROXY if site.proxy else None,
|
||||
timeout=site.timeout or 15
|
||||
).get_res(url=url)
|
||||
if res is None:
|
||||
return False, "无法打开网站!"
|
||||
if res.status_code == 200:
|
||||
user_info = res.json()
|
||||
if user_info and user_info.get("code") == 0:
|
||||
return True, "连接成功"
|
||||
return False, "APIKEY已过期"
|
||||
else:
|
||||
return False, f"错误:{res.status_code} {res.reason}"
|
||||
|
||||
@staticmethod
|
||||
def __parse_favicon(url: str, cookie: str, ua: str) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
@@ -462,20 +489,18 @@ class SiteChain(ChainBase):
|
||||
logger.warn(f"站点 {domain} 索引器不存在!")
|
||||
return
|
||||
# 查询站点图标
|
||||
site_icon = siteoper.get_icon_by_domain(domain)
|
||||
if not site_icon or not site_icon.base64:
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
|
||||
@eventmanager.register(EventType.SiteUpdated)
|
||||
def clear_site_data(self, event: Event):
|
||||
|
||||
@@ -31,6 +31,12 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("generate_qrcode", storage=storage)
|
||||
|
||||
def generate_auth_url(self, storage: str) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
"""
|
||||
return self.run_module("generate_auth_url", storage=storage)
|
||||
|
||||
def check_login(self, storage: str, **kwargs) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
登录确认
|
||||
@@ -133,30 +139,41 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("support_transtype", storage=storage)
|
||||
|
||||
def is_bluray_folder(self, fileitem: Optional[schemas.FileItem]) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "BDMV"):
|
||||
return True
|
||||
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "CERTIFICATE"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def contains_bluray_subdirectories(fileitems: Optional[List[schemas.FileItem]]) -> bool:
|
||||
"""
|
||||
判断是否包含蓝光必备的文件夹
|
||||
"""
|
||||
required_files = {"BDMV", "CERTIFICATE"}
|
||||
return any(
|
||||
item.type == "dir" and item.name in required_files
|
||||
for item in fileitems or []
|
||||
)
|
||||
|
||||
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
|
||||
def __is_bluray_dir(_fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
_dir_files = self.list_files(fileitem=_fileitem, recursion=False)
|
||||
if _dir_files:
|
||||
for _f in _dir_files:
|
||||
if _f.type == "dir" and _f.name in ["BDMV", "CERTIFICATE"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 根目录或一级目录不允许删除")
|
||||
return False
|
||||
if fileitem.type == "dir":
|
||||
# 本身是目录
|
||||
if __is_bluray_dir(fileitem):
|
||||
if self.is_bluray_folder(fileitem):
|
||||
logger.warn(f"正在删除蓝光原盘目录:【{fileitem.storage}】{fileitem.path}")
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
|
||||
@@ -42,7 +42,7 @@ class SubscribeChain(ChainBase):
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
|
||||
@staticmethod
|
||||
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
"""
|
||||
广播事件解析媒体信息
|
||||
"""
|
||||
@@ -144,7 +144,7 @@ class SubscribeChain(ChainBase):
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
if season is not None:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
@@ -158,7 +158,7 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
else:
|
||||
# 使用TMDBID识别
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
|
||||
@@ -169,12 +169,12 @@ class SubscribeChain(ChainBase):
|
||||
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
|
||||
elif mediaid:
|
||||
# 未知前缀,广播事件解析媒体信息
|
||||
mediainfo = self.__get_event_meida(mediaid, metainfo)
|
||||
mediainfo = self.__get_event_media(mediaid, metainfo)
|
||||
if mediainfo:
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
@@ -188,7 +188,7 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
if season is None:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
@@ -292,7 +292,7 @@ class SubscribeChain(ChainBase):
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
return sid, err_msg
|
||||
|
||||
async def async_add(self, title: str, year: str,
|
||||
mtype: MediaType = None,
|
||||
@@ -321,7 +321,7 @@ class SubscribeChain(ChainBase):
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
if season is not None:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
@@ -351,7 +351,7 @@ class SubscribeChain(ChainBase):
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
@@ -365,7 +365,7 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
if season is None:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
@@ -469,7 +469,7 @@ class SubscribeChain(ChainBase):
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
return sid, err_msg
|
||||
|
||||
@staticmethod
|
||||
def exists(mediainfo: MediaInfo, meta: MetaBase = None):
|
||||
@@ -530,7 +530,7 @@ class SubscribeChain(ChainBase):
|
||||
# 生成元数据
|
||||
meta = MetaInfo(subscribe.name)
|
||||
meta.year = subscribe.year
|
||||
meta.begin_season = subscribe.season or None
|
||||
meta.begin_season = subscribe.season if subscribe.season is not None else None
|
||||
try:
|
||||
meta.type = MediaType(subscribe.type)
|
||||
except ValueError:
|
||||
@@ -593,11 +593,17 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 洗版
|
||||
if subscribe.best_version:
|
||||
# 洗版时,非整季不要
|
||||
if torrent_mediainfo.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.info(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
# 洗版时,不符合订阅集数的不要
|
||||
if (
|
||||
torrent_mediainfo.type == MediaType.TV
|
||||
and not self._is_episode_range_covered(
|
||||
meta=torrent_meta, subscribe=subscribe
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
|
||||
)
|
||||
continue
|
||||
# 洗版时,优先级小于等于已下载优先级的不要
|
||||
if subscribe.current_priority \
|
||||
and torrent_info.pri_order <= subscribe.current_priority:
|
||||
@@ -985,11 +991,18 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 洗版时,非整季不要
|
||||
if meta.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.debug(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
# 洗版时,不符合订阅集数的不要
|
||||
if (
|
||||
meta.type == MediaType.TV
|
||||
and not self._is_episode_range_covered(
|
||||
meta=torrent_meta,
|
||||
subscribe=subscribe,
|
||||
)
|
||||
):
|
||||
logger.debug(
|
||||
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
|
||||
)
|
||||
continue
|
||||
|
||||
# 匹配订阅附加参数
|
||||
if not torrenthelper.filter_torrent(torrent_info=torrent_info,
|
||||
@@ -1119,6 +1132,19 @@ class SubscribeChain(ChainBase):
|
||||
})
|
||||
logger.info(f'{subscribe.name} 订阅元数据更新完成')
|
||||
|
||||
def get_subscribe_by_source(self, source: str) -> Optional[Subscribe]:
|
||||
"""
|
||||
从来源获取订阅
|
||||
"""
|
||||
source_keyword = self.parse_subscribe_source_keyword(source)
|
||||
if not source_keyword:
|
||||
return None
|
||||
# 只保留需要的字段动态获取订阅
|
||||
valid_fields = {k: v for k, v in source_keyword.items()
|
||||
if k in ["type", "season", "tmdbid", "doubanid", "bangumiid"]}
|
||||
# 暂时不考虑订阅历史, 若有必要再添加
|
||||
return SubscribeOper().get_by(**valid_fields)
|
||||
|
||||
@staticmethod
|
||||
def follow():
|
||||
"""
|
||||
@@ -1635,7 +1661,7 @@ class SubscribeChain(ChainBase):
|
||||
info = schemas.SubscribeEpisodeInfo()
|
||||
info.title = episode.name
|
||||
info.description = episode.overview
|
||||
info.backdrop = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/w500${episode.still_path}"
|
||||
info.backdrop = settings.TMDB_IMAGE_URL(episode.still_path, "w500")
|
||||
episodes[episode.episode_number] = info
|
||||
elif subscribe.type == MediaType.TV.value:
|
||||
# 根据开始结束集计算集信息
|
||||
@@ -1655,7 +1681,7 @@ class SubscribeChain(ChainBase):
|
||||
if download_his:
|
||||
for his in download_his:
|
||||
# 查询下载文件
|
||||
files = downloadhis.get_files_by_hash(his.download_hash)
|
||||
files = downloadhis.get_files_by_hash(his.download_hash, state=1)
|
||||
if files:
|
||||
for file in files:
|
||||
# 识别文件名
|
||||
@@ -1808,6 +1834,23 @@ class SubscribeChain(ChainBase):
|
||||
# 返回结果,表示媒体未完全下载或存在
|
||||
return False, no_exists
|
||||
|
||||
@staticmethod
|
||||
def _is_episode_range_covered(meta: MetaBase, subscribe: Subscribe) -> bool:
|
||||
"""
|
||||
判断种子是否包含指定订阅的剧集范围
|
||||
"""
|
||||
episodes = meta.episode_list
|
||||
if not episodes:
|
||||
# 没有剧集信息,表示该种子为合集
|
||||
return True
|
||||
|
||||
min_ep = min(episodes)
|
||||
max_ep = max(episodes)
|
||||
start_ep = subscribe.start_episode or 1
|
||||
end_ep = subscribe.total_episode
|
||||
|
||||
return min_ep <= start_ep and max_ep >= end_ep
|
||||
|
||||
@staticmethod
|
||||
def get_states_for_search(state: str) -> str:
|
||||
"""
|
||||
@@ -1828,8 +1871,9 @@ class SubscribeChain(ChainBase):
|
||||
def get_subscribe_source_keyword(subscribe: Subscribe) -> str:
|
||||
"""
|
||||
构造用于订阅来源的关键字字符串
|
||||
|
||||
:param subscribe: Subscribe 对象
|
||||
:return: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
:return str: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
"""
|
||||
source_keyword = {
|
||||
'id': subscribe.id,
|
||||
@@ -1844,3 +1888,24 @@ class SubscribeChain(ChainBase):
|
||||
'bangumiid': subscribe.bangumiid
|
||||
}
|
||||
return f"Subscribe|{json.dumps(source_keyword, ensure_ascii=False)}"
|
||||
|
||||
@staticmethod
|
||||
def parse_subscribe_source_keyword(source_keyword_str: str) -> Optional[dict]:
|
||||
"""
|
||||
解析订阅来源关键字字符串
|
||||
|
||||
:param source_keyword_str: 订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
:return Dict: 如果解析失败则返回None
|
||||
"""
|
||||
if not source_keyword_str or not source_keyword_str.startswith("Subscribe|"):
|
||||
return None
|
||||
|
||||
try:
|
||||
# 分割字符串获取JSON部分
|
||||
json_part = source_keyword_str.split("|", 1)[1]
|
||||
# 解析JSON字符串
|
||||
source_keyword = json.loads(json_part)
|
||||
return source_keyword
|
||||
except (IndexError, json.JSONDecodeError, TypeError) as e:
|
||||
logger.error(f"解析订阅来源关键字失败: {e}")
|
||||
return None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user