mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-11 09:59:51 +08:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d2c2a01eb | ||
|
|
226f9c9318 | ||
|
|
b77b5a21c5 | ||
|
|
82b637532e | ||
|
|
c2c9950bb1 | ||
|
|
ffbe348d66 | ||
|
|
6d7b0733af | ||
|
|
49a51cca25 | ||
|
|
06197144c0 | ||
|
|
62541ffe43 | ||
|
|
c762628217 | ||
|
|
caf615f3bd | ||
|
|
27436757a0 | ||
|
|
924d54dfd3 | ||
|
|
39f9550f86 | ||
|
|
367ecafbbb | ||
|
|
10467244e0 | ||
|
|
cb6dcc6a2e | ||
|
|
43c421b0bb | ||
|
|
45d0891502 | ||
|
|
76c5f54465 | ||
|
|
bcf8116172 | ||
|
|
1f889596b7 | ||
|
|
04443fcfba |
@@ -768,12 +768,25 @@ class LLMHelper:
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_google_models(api_key or "")
|
||||
]
|
||||
model_list_base_url = base_url
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
model_list_base_url = (
|
||||
LLMProviderManager().resolve_model_list_base_url(
|
||||
provider_id=provider,
|
||||
base_url=base_url,
|
||||
)
|
||||
or base_url
|
||||
)
|
||||
except Exception:
|
||||
model_list_base_url = base_url
|
||||
return [
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_openai_compatible_models(
|
||||
provider,
|
||||
api_key or "",
|
||||
base_url,
|
||||
model_list_base_url,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -44,6 +44,16 @@ class ProviderAuthMethod:
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderUrlPreset:
|
||||
"""前端展示用的 Base URL 预设。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
model_list_base_url: Optional[str] = None
|
||||
models_dev_provider_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSpec:
|
||||
"""描述一个可接入的 LLM provider。"""
|
||||
@@ -53,6 +63,7 @@ class ProviderSpec:
|
||||
runtime: str
|
||||
models_dev_provider_id: Optional[str] = None
|
||||
default_base_url: Optional[str] = None
|
||||
base_url_presets: Tuple[ProviderUrlPreset, ...] = ()
|
||||
base_url_editable: bool = False
|
||||
requires_base_url: bool = False
|
||||
supports_api_key: bool = True
|
||||
@@ -138,7 +149,158 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
label="设备码授权",
|
||||
description="适合无回调环境,复制设备码到浏览器完成登录。",
|
||||
)
|
||||
return (
|
||||
url_preset = ProviderUrlPreset
|
||||
def openai_provider(
|
||||
provider_id: str,
|
||||
name: str,
|
||||
default_base_url: str,
|
||||
sort_order: int,
|
||||
*,
|
||||
models_dev_provider_id: Optional[str] = None,
|
||||
base_url_presets: Tuple[ProviderUrlPreset, ...] = (),
|
||||
api_key_hint: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
model_list_strategy: str = "openai_compatible",
|
||||
api_key_label: str = "API Key",
|
||||
) -> ProviderSpec:
|
||||
return ProviderSpec(
|
||||
id=provider_id,
|
||||
name=name,
|
||||
runtime="openai_compatible",
|
||||
models_dev_provider_id=models_dev_provider_id or provider_id,
|
||||
default_base_url=default_base_url,
|
||||
base_url_presets=base_url_presets,
|
||||
api_key_label=api_key_label,
|
||||
api_key_hint=api_key_hint or f"填写 {name} API Key。",
|
||||
model_list_strategy=model_list_strategy,
|
||||
description=description or f"{name} OpenAI-compatible 端点。",
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
def catalog_openai_provider(
|
||||
provider_id: str,
|
||||
name: str,
|
||||
default_base_url: str,
|
||||
sort_order: int,
|
||||
*,
|
||||
models_dev_provider_id: Optional[str] = None,
|
||||
base_url_presets: Tuple[ProviderUrlPreset, ...] = (),
|
||||
api_key_hint: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
api_key_label: str = "API Key",
|
||||
) -> ProviderSpec:
|
||||
return openai_provider(
|
||||
provider_id=provider_id,
|
||||
name=name,
|
||||
default_base_url=default_base_url,
|
||||
sort_order=sort_order,
|
||||
models_dev_provider_id=models_dev_provider_id,
|
||||
base_url_presets=base_url_presets,
|
||||
api_key_hint=api_key_hint,
|
||||
description=description,
|
||||
model_list_strategy="models_dev_only",
|
||||
api_key_label=api_key_label,
|
||||
)
|
||||
|
||||
def anthropic_provider(
|
||||
provider_id: str,
|
||||
name: str,
|
||||
default_base_url: str,
|
||||
sort_order: int,
|
||||
*,
|
||||
models_dev_provider_id: Optional[str] = None,
|
||||
base_url_presets: Tuple[ProviderUrlPreset, ...] = (),
|
||||
api_key_hint: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> ProviderSpec:
|
||||
return ProviderSpec(
|
||||
id=provider_id,
|
||||
name=name,
|
||||
runtime="anthropic_compatible",
|
||||
models_dev_provider_id=models_dev_provider_id or provider_id,
|
||||
default_base_url=default_base_url,
|
||||
base_url_presets=base_url_presets,
|
||||
api_key_hint=api_key_hint or f"填写 {name} API Key。",
|
||||
model_list_strategy="anthropic_compatible",
|
||||
description=description or f"{name} Anthropic-compatible 端点。",
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
catalog_openai_providers = (
|
||||
("302ai", "302.AI", "https://api.302.ai/v1"),
|
||||
("abacus", "Abacus", "https://routellm.abacus.ai/v1"),
|
||||
("abliteration-ai", "abliteration.ai", "https://api.abliteration.ai/v1"),
|
||||
("baseten", "Baseten", "https://inference.baseten.co/v1"),
|
||||
("berget", "Berget.AI", "https://api.berget.ai/v1"),
|
||||
("chutes", "Chutes", "https://llm.chutes.ai/v1"),
|
||||
("clarifai", "Clarifai", "https://api.clarifai.com/v2/ext/openai/v1"),
|
||||
("cloudferro-sherlock", "CloudFerro Sherlock", "https://api-sherlock.cloudferro.com/openai/v1/"),
|
||||
("cloudflare-workers-ai", "Cloudflare Workers AI", "https://api.cloudflare.com/client/v4/accounts/${CLOUDFLARE_ACCOUNT_ID}/ai/v1"),
|
||||
("cortecs", "Cortecs", "https://api.cortecs.ai/v1"),
|
||||
("digitalocean", "DigitalOcean", "https://inference.do-ai.run/v1"),
|
||||
("dinference", "DInference", "https://api.dinference.com/v1"),
|
||||
("drun", "D.Run (China)", "https://chat.d.run/v1"),
|
||||
("evroc", "evroc", "https://models.think.evroc.com/v1"),
|
||||
("fastrouter", "FastRouter", "https://go.fastrouter.ai/api/v1"),
|
||||
("fireworks-ai", "Fireworks AI", "https://api.fireworks.ai/inference/v1/"),
|
||||
("firmware", "Firmware", "https://app.frogbot.ai/api/v1"),
|
||||
("friendli", "Friendli", "https://api.friendli.ai/serverless/v1"),
|
||||
("helicone", "Helicone", "https://ai-gateway.helicone.ai/v1"),
|
||||
("hpc-ai", "HPC-AI", "https://api.hpc-ai.com/inference/v1"),
|
||||
("huggingface", "Hugging Face", "https://router.huggingface.co/v1"),
|
||||
("iflowcn", "iFlow", "https://apis.iflow.cn/v1"),
|
||||
("inception", "Inception", "https://api.inceptionlabs.ai/v1/"),
|
||||
("inference", "Inference", "https://inference.net/v1"),
|
||||
("io-net", "IO.NET", "https://api.intelligence.io.solutions/api/v1"),
|
||||
("jiekou", "Jiekou.AI", "https://api.jiekou.ai/openai"),
|
||||
("kilo", "Kilo Gateway", "https://api.kilo.ai/api/gateway"),
|
||||
("kuae-cloud-coding-plan", "KUAE Cloud Coding Plan", "https://coding-plan-endpoint.kuaecloud.net/v1"),
|
||||
("llama", "Llama", "https://api.llama.com/compat/v1/"),
|
||||
("llmgateway", "LLM Gateway", "https://api.llmgateway.io/v1"),
|
||||
("lucidquery", "LucidQuery AI", "https://lucidquery.com/api/v1"),
|
||||
("meganova", "Meganova", "https://api.meganova.ai/v1"),
|
||||
("mixlayer", "Mixlayer", "https://models.mixlayer.ai/v1"),
|
||||
("moark", "Moark", "https://moark.com/v1"),
|
||||
("modelscope", "ModelScope", "https://api-inference.modelscope.cn/v1"),
|
||||
("morph", "Morph", "https://api.morphllm.com/v1"),
|
||||
("nano-gpt", "NanoGPT", "https://nano-gpt.com/api/v1"),
|
||||
("nebius", "Nebius Token Factory", "https://api.tokenfactory.nebius.com/v1"),
|
||||
("neuralwatt", "Neuralwatt", "https://api.neuralwatt.com/v1"),
|
||||
("nova", "Nova", "https://api.nova.amazon.com/v1"),
|
||||
("novita-ai", "NovitaAI", "https://api.novita.ai/openai"),
|
||||
("ovhcloud", "OVHcloud AI Endpoints", "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1"),
|
||||
("perplexity-agent", "Perplexity Agent", "https://api.perplexity.ai/v1"),
|
||||
("poe", "Poe", "https://api.poe.com/v1"),
|
||||
("privatemode-ai", "Privatemode AI", "http://localhost:8080/v1"),
|
||||
("qihang-ai", "QiHang", "https://api.qhaigc.net/v1"),
|
||||
("qiniu-ai", "Qiniu", "https://api.qnaigc.com/v1"),
|
||||
("regolo-ai", "Regolo AI", "https://api.regolo.ai/v1"),
|
||||
("requesty", "Requesty", "https://router.requesty.ai/v1"),
|
||||
("scaleway", "Scaleway", "https://api.scaleway.ai/v1"),
|
||||
("stackit", "STACKIT", "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1"),
|
||||
("stepfun", "StepFun", "https://api.stepfun.com/v1"),
|
||||
("submodel", "submodel", "https://llm.submodel.ai/v1"),
|
||||
("synthetic", "Synthetic", "https://api.synthetic.new/openai/v1"),
|
||||
("the-grid-ai", "The Grid AI", "https://api.thegrid.ai/v1"),
|
||||
("upstage", "Upstage", "https://api.upstage.ai/v1/solar"),
|
||||
("vivgrid", "Vivgrid", "https://api.vivgrid.com/v1"),
|
||||
("vultr", "Vultr", "https://api.vultrinference.com/v1"),
|
||||
("wafer.ai", "Wafer", "https://pass.wafer.ai/v1"),
|
||||
("wandb", "Weights & Biases", "https://api.inference.wandb.ai/v1"),
|
||||
("zenmux", "ZenMux", "https://zenmux.ai/api/v1"),
|
||||
)
|
||||
catalog_openai_overrides = {
|
||||
"cloudflare-workers-ai": {
|
||||
"api_key_hint": "填写 Cloudflare API Token,并将 Base URL 中的 ${CLOUDFLARE_ACCOUNT_ID} 替换为真实账户 ID。",
|
||||
"description": "Cloudflare Workers AI OpenAI-compatible 端点,需要替换账户 ID。",
|
||||
},
|
||||
"privatemode-ai": {
|
||||
"api_key_hint": "如未启用鉴权,可填写任意占位值。",
|
||||
"description": "Privatemode AI 本地 OpenAI-compatible 端点。",
|
||||
},
|
||||
}
|
||||
|
||||
providers = [
|
||||
ProviderSpec(
|
||||
id="chatgpt",
|
||||
name="ChatGPT",
|
||||
@@ -162,6 +324,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
description="Gemini / Google AI Studio。",
|
||||
sort_order=20,
|
||||
),
|
||||
anthropic_provider(
|
||||
provider_id="anthropic",
|
||||
name="Anthropic",
|
||||
default_base_url="https://api.anthropic.com/v1",
|
||||
sort_order=25,
|
||||
api_key_hint="填写 Anthropic API Key。",
|
||||
description="Anthropic Claude 官方端点。",
|
||||
),
|
||||
ProviderSpec(
|
||||
id="deepseek",
|
||||
name="DeepSeek",
|
||||
@@ -172,6 +342,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
description="DeepSeek 官方平台。",
|
||||
sort_order=30,
|
||||
),
|
||||
catalog_openai_provider(
|
||||
provider_id="groq",
|
||||
name="Groq",
|
||||
default_base_url="https://api.groq.com/openai/v1",
|
||||
sort_order=35,
|
||||
api_key_hint="填写 Groq API Key。",
|
||||
description="Groq 官方 OpenAI-compatible 端点。",
|
||||
),
|
||||
ProviderSpec(
|
||||
id="openrouter",
|
||||
name="OpenRouter",
|
||||
@@ -182,6 +360,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
description="OpenRouter 聚合模型平台。",
|
||||
sort_order=40,
|
||||
),
|
||||
catalog_openai_provider(
|
||||
provider_id="xai",
|
||||
name="xAI",
|
||||
default_base_url="https://api.x.ai/v1",
|
||||
sort_order=45,
|
||||
api_key_hint="填写 xAI API Key。",
|
||||
description="xAI 官方 OpenAI-compatible 端点。",
|
||||
),
|
||||
ProviderSpec(
|
||||
id="github-copilot",
|
||||
name="GitHub Copilot",
|
||||
@@ -201,25 +387,140 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
description="通过 GitHub Copilot 订阅接入。",
|
||||
sort_order=50,
|
||||
),
|
||||
ProviderSpec(
|
||||
id="siliconflow",
|
||||
name="硅基流动",
|
||||
runtime="openai_compatible",
|
||||
models_dev_provider_id="siliconflow",
|
||||
default_base_url="https://api.siliconflow.cn/v1",
|
||||
api_key_hint="填写硅基流动 API Key。",
|
||||
description="SiliconFlow 官方兼容端点。",
|
||||
sort_order=60,
|
||||
catalog_openai_provider(
|
||||
provider_id="github-models",
|
||||
name="GitHub Models",
|
||||
default_base_url="https://models.github.ai/inference",
|
||||
sort_order=55,
|
||||
api_key_label="GitHub Token",
|
||||
api_key_hint="填写具有 GitHub Models 访问权限的 GitHub Token。",
|
||||
description="GitHub Models 推理端点。",
|
||||
),
|
||||
ProviderSpec(
|
||||
id="alibaba",
|
||||
openai_provider(
|
||||
provider_id="siliconflow",
|
||||
name="硅基流动",
|
||||
default_base_url="https://api.siliconflow.cn/v1",
|
||||
sort_order=60,
|
||||
models_dev_provider_id="siliconflow-cn",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="中国大陆",
|
||||
value="https://api.siliconflow.cn/v1",
|
||||
models_dev_provider_id="siliconflow-cn",
|
||||
),
|
||||
url_preset(
|
||||
label="Global",
|
||||
value="https://api.siliconflow.com/v1",
|
||||
models_dev_provider_id="siliconflow",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写硅基流动 API Key,可在中国大陆与 Global 端点间切换。",
|
||||
description="SiliconFlow 官方兼容端点。",
|
||||
),
|
||||
catalog_openai_provider(
|
||||
provider_id="moonshot",
|
||||
name="Moonshot AI",
|
||||
default_base_url="https://api.moonshot.cn/v1",
|
||||
sort_order=62,
|
||||
models_dev_provider_id="moonshotai-cn",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="中国站",
|
||||
value="https://api.moonshot.cn/v1",
|
||||
models_dev_provider_id="moonshotai-cn",
|
||||
),
|
||||
url_preset(
|
||||
label="国际站",
|
||||
value="https://api.moonshot.ai/v1",
|
||||
models_dev_provider_id="moonshotai",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 Moonshot / Kimi API Key,可在中国站与国际站端点间切换。",
|
||||
description="Moonshot / Kimi 官方兼容端点。",
|
||||
),
|
||||
anthropic_provider(
|
||||
provider_id="kimi-coding",
|
||||
name="Kimi for Coding",
|
||||
default_base_url="https://api.kimi.com/coding/v1",
|
||||
sort_order=63,
|
||||
models_dev_provider_id="kimi-for-coding",
|
||||
api_key_hint="填写 Moonshot / Kimi API Key。",
|
||||
description="Moonshot Kimi Coding Anthropic-compatible 端点。",
|
||||
),
|
||||
openai_provider(
|
||||
provider_id="zhipu",
|
||||
name="智谱 GLM",
|
||||
default_base_url="https://open.bigmodel.cn/api/paas/v4",
|
||||
sort_order=65,
|
||||
models_dev_provider_id="zhipuai",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="Token Plan / 通用 API",
|
||||
value="https://open.bigmodel.cn/api/paas/v4",
|
||||
models_dev_provider_id="zhipuai",
|
||||
),
|
||||
url_preset(
|
||||
label="Coding Plan",
|
||||
value="https://open.bigmodel.cn/api/coding/paas/v4",
|
||||
model_list_base_url="https://open.bigmodel.cn/api/paas/v4",
|
||||
models_dev_provider_id="zhipuai-coding-plan",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写智谱开放平台 API Key,可在 Token Plan / 通用 API 与 Coding Plan 端点间切换。",
|
||||
description="智谱开放平台国内站,支持通用 API 与 GLM Coding Plan 端点。",
|
||||
),
|
||||
catalog_openai_provider(
|
||||
provider_id="zai",
|
||||
name="Z.AI",
|
||||
default_base_url="https://api.z.ai/api/paas/v4",
|
||||
sort_order=66,
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="Token Plan / 通用 API",
|
||||
value="https://api.z.ai/api/paas/v4",
|
||||
models_dev_provider_id="zai",
|
||||
),
|
||||
url_preset(
|
||||
label="Coding Plan",
|
||||
value="https://api.z.ai/api/coding/paas/v4",
|
||||
models_dev_provider_id="zai-coding-plan",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 Z.AI API Key,可在通用 API 与 Coding Plan 端点间切换。",
|
||||
description="Z.AI 官方端点。",
|
||||
),
|
||||
openai_provider(
|
||||
provider_id="alibaba",
|
||||
name="阿里云百炼",
|
||||
runtime="openai_compatible",
|
||||
models_dev_provider_id="alibaba",
|
||||
default_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
api_key_hint="填写 DashScope / Alibaba API Key。",
|
||||
description="阿里云百炼兼容端点。",
|
||||
sort_order=70,
|
||||
models_dev_provider_id="alibaba-cn",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="中国内地 / 通用",
|
||||
value="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
models_dev_provider_id="alibaba-cn",
|
||||
),
|
||||
url_preset(
|
||||
label="国际站 / 通用",
|
||||
value="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
models_dev_provider_id="alibaba",
|
||||
),
|
||||
url_preset(
|
||||
label="中国内地 / Coding Plan",
|
||||
value="https://coding.dashscope.aliyuncs.com/v1",
|
||||
model_list_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
models_dev_provider_id="alibaba-coding-plan-cn",
|
||||
),
|
||||
url_preset(
|
||||
label="国际站 / Coding Plan",
|
||||
value="https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
model_list_base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
models_dev_provider_id="alibaba-coding-plan",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 DashScope / Alibaba API Key,可在中国内地、国际站与 Coding Plan 端点间切换。",
|
||||
description="阿里云百炼兼容端点。",
|
||||
),
|
||||
ProviderSpec(
|
||||
id="volcengine",
|
||||
@@ -236,7 +537,19 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
runtime="openai_compatible",
|
||||
models_dev_provider_id="tencent-tokenhub",
|
||||
default_base_url="https://tokenhub.tencentmaas.com/v1",
|
||||
api_key_hint="填写 Tencent API Key。",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="TokenHub",
|
||||
value="https://tokenhub.tencentmaas.com/v1",
|
||||
models_dev_provider_id="tencent-tokenhub",
|
||||
),
|
||||
url_preset(
|
||||
label="Coding Plan",
|
||||
value="https://api.lkeap.cloud.tencent.com/coding/v3",
|
||||
models_dev_provider_id="tencent-coding-plan",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 Tencent API Key,可在 TokenHub 与 Coding Plan 端点间切换。",
|
||||
model_list_strategy="models_dev_only",
|
||||
description="腾讯兼容端点。",
|
||||
sort_order=90,
|
||||
@@ -261,27 +574,125 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
description="Nvidia 集成推理平台。",
|
||||
sort_order=110,
|
||||
),
|
||||
ProviderSpec(
|
||||
id="minimax",
|
||||
catalog_openai_provider(
|
||||
provider_id="opencode",
|
||||
name="OpenCode",
|
||||
default_base_url="https://opencode.ai/zen/v1",
|
||||
sort_order=115,
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="Zen",
|
||||
value="https://opencode.ai/zen/v1",
|
||||
models_dev_provider_id="opencode",
|
||||
),
|
||||
url_preset(
|
||||
label="Go",
|
||||
value="https://opencode.ai/zen/go/v1",
|
||||
models_dev_provider_id="opencode-go",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 OpenCode API Key,可在 Zen 与 Go 端点间切换。",
|
||||
description="OpenCode Zen / Go 端点。",
|
||||
),
|
||||
anthropic_provider(
|
||||
provider_id="minimax",
|
||||
name="MiniMax",
|
||||
runtime="anthropic_compatible",
|
||||
models_dev_provider_id="minimax",
|
||||
default_base_url="https://api.minimaxi.com/anthropic/v1",
|
||||
api_key_hint="填写 MiniMax API Key。",
|
||||
model_list_strategy="anthropic_compatible",
|
||||
description="MiniMax Anthropic-compatible 端点。",
|
||||
sort_order=120,
|
||||
models_dev_provider_id="minimax-cn",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="中国内地 / 通用",
|
||||
value="https://api.minimaxi.com/anthropic/v1",
|
||||
models_dev_provider_id="minimax-cn",
|
||||
),
|
||||
url_preset(
|
||||
label="国际站 / 通用",
|
||||
value="https://api.minimax.io/anthropic/v1",
|
||||
models_dev_provider_id="minimax",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 MiniMax API Key,可在中国内地与国际站通用端点间切换。",
|
||||
description="MiniMax Anthropic-compatible 通用端点。",
|
||||
),
|
||||
ProviderSpec(
|
||||
id="xiaomi",
|
||||
anthropic_provider(
|
||||
provider_id="minimax-coding",
|
||||
name="MiniMax Coding Plan",
|
||||
default_base_url="https://api.minimaxi.com/anthropic/v1",
|
||||
sort_order=121,
|
||||
models_dev_provider_id="minimax-cn-coding-plan",
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="中国内地 / Coding Plan",
|
||||
value="https://api.minimaxi.com/anthropic/v1",
|
||||
models_dev_provider_id="minimax-cn-coding-plan",
|
||||
),
|
||||
url_preset(
|
||||
label="国际站 / Coding Plan",
|
||||
value="https://api.minimax.io/anthropic/v1",
|
||||
models_dev_provider_id="minimax-coding-plan",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 MiniMax API Key,可在中国内地与国际站 Coding Plan 目录间切换。",
|
||||
description="MiniMax Coding Plan Anthropic-compatible 端点。",
|
||||
),
|
||||
catalog_openai_provider(
|
||||
provider_id="xiaomi",
|
||||
name="Xiaomi",
|
||||
runtime="openai_compatible",
|
||||
models_dev_provider_id="xiaomi",
|
||||
default_base_url="https://api.xiaomimimo.com/v1",
|
||||
api_key_hint="填写 Xiaomi API Key。",
|
||||
description="小米 Mimo 兼容端点。",
|
||||
sort_order=130,
|
||||
base_url_presets=(
|
||||
url_preset(
|
||||
label="标准端点",
|
||||
value="https://api.xiaomimimo.com/v1",
|
||||
models_dev_provider_id="xiaomi",
|
||||
),
|
||||
url_preset(
|
||||
label="Token Plan / 中国",
|
||||
value="https://token-plan-cn.xiaomimimo.com/v1",
|
||||
models_dev_provider_id="xiaomi-token-plan-cn",
|
||||
),
|
||||
url_preset(
|
||||
label="Token Plan / 新加坡",
|
||||
value="https://token-plan-sgp.xiaomimimo.com/v1",
|
||||
models_dev_provider_id="xiaomi-token-plan-sgp",
|
||||
),
|
||||
url_preset(
|
||||
label="Token Plan / 欧洲",
|
||||
value="https://token-plan-ams.xiaomimimo.com/v1",
|
||||
models_dev_provider_id="xiaomi-token-plan-ams",
|
||||
),
|
||||
),
|
||||
api_key_hint="填写 Xiaomi API Key,可在标准端点与各区域 Token Plan 端点间切换。",
|
||||
description="小米 Mimo 兼容端点。",
|
||||
),
|
||||
catalog_openai_provider(
|
||||
provider_id="lmstudio",
|
||||
name="LM Studio",
|
||||
default_base_url="http://127.0.0.1:1234/v1",
|
||||
sort_order=135,
|
||||
api_key_hint="如未启用鉴权,可填写任意占位值。",
|
||||
description="LM Studio 本地 OpenAI-compatible 端点。",
|
||||
),
|
||||
]
|
||||
|
||||
for sort_order, (provider_id, name, base_url) in enumerate(
|
||||
catalog_openai_providers,
|
||||
start=200,
|
||||
):
|
||||
overrides = catalog_openai_overrides.get(provider_id, {})
|
||||
providers.append(
|
||||
catalog_openai_provider(
|
||||
provider_id=provider_id,
|
||||
name=name,
|
||||
default_base_url=base_url,
|
||||
sort_order=sort_order,
|
||||
api_key_hint=overrides.get("api_key_hint"),
|
||||
description=overrides.get("description"),
|
||||
)
|
||||
)
|
||||
|
||||
providers.append(
|
||||
ProviderSpec(
|
||||
id="openai",
|
||||
name="OpenAI Compatible",
|
||||
@@ -292,9 +703,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
supports_api_key=True,
|
||||
api_key_hint="通用 OpenAI-compatible 兜底入口,需要手动填写 Base URL。",
|
||||
description="通用 OpenAI-compatible 模型服务。",
|
||||
sort_order=200,
|
||||
),
|
||||
sort_order=1000,
|
||||
)
|
||||
)
|
||||
return tuple(providers)
|
||||
|
||||
def list_providers(self) -> list[dict[str, Any]]:
|
||||
"""返回前端可渲染的 provider 目录。"""
|
||||
@@ -305,7 +717,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
"id": spec.id,
|
||||
"name": spec.name,
|
||||
"runtime": spec.runtime,
|
||||
"default_base_url": spec.default_base_url or "",
|
||||
"default_base_url": self._default_base_url_for_provider(spec) or "",
|
||||
"base_url_presets": [
|
||||
{
|
||||
"label": preset.label,
|
||||
"value": self._sanitize_base_url(preset.value) or "",
|
||||
}
|
||||
for preset in spec.base_url_presets
|
||||
],
|
||||
"base_url_editable": spec.base_url_editable,
|
||||
"requires_base_url": spec.requires_base_url,
|
||||
"supports_api_key": spec.supports_api_key,
|
||||
@@ -344,6 +763,65 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return None
|
||||
return value.rstrip("/")
|
||||
|
||||
@classmethod
|
||||
def _default_base_url_for_provider(cls, spec: ProviderSpec) -> Optional[str]:
|
||||
default_base_url = cls._sanitize_base_url(spec.default_base_url)
|
||||
if default_base_url:
|
||||
return default_base_url
|
||||
if not spec.base_url_presets:
|
||||
return None
|
||||
return cls._sanitize_base_url(spec.base_url_presets[0].value)
|
||||
|
||||
@classmethod
|
||||
def _resolve_provider_model_list_base_url(
|
||||
cls, spec: ProviderSpec, base_url: Optional[str]
|
||||
) -> Optional[str]:
|
||||
normalized_base_url = cls._sanitize_base_url(base_url)
|
||||
if normalized_base_url:
|
||||
for preset in spec.base_url_presets:
|
||||
preset_value = cls._sanitize_base_url(preset.value)
|
||||
if normalized_base_url != preset_value:
|
||||
continue
|
||||
return cls._sanitize_base_url(preset.model_list_base_url) or preset_value
|
||||
return normalized_base_url
|
||||
|
||||
default_base_url = cls._default_base_url_for_provider(spec)
|
||||
if default_base_url:
|
||||
for preset in spec.base_url_presets:
|
||||
preset_value = cls._sanitize_base_url(preset.value)
|
||||
if preset_value != default_base_url:
|
||||
continue
|
||||
return cls._sanitize_base_url(preset.model_list_base_url) or preset_value
|
||||
return default_base_url
|
||||
|
||||
@classmethod
|
||||
def _resolve_provider_models_dev_provider_id(
|
||||
cls, spec: ProviderSpec, base_url: Optional[str]
|
||||
) -> Optional[str]:
|
||||
normalized_base_url = cls._sanitize_base_url(base_url)
|
||||
if normalized_base_url:
|
||||
for preset in spec.base_url_presets:
|
||||
preset_value = cls._sanitize_base_url(preset.value)
|
||||
if normalized_base_url != preset_value:
|
||||
continue
|
||||
return preset.models_dev_provider_id or spec.models_dev_provider_id
|
||||
return spec.models_dev_provider_id
|
||||
|
||||
default_base_url = cls._default_base_url_for_provider(spec)
|
||||
if default_base_url:
|
||||
for preset in spec.base_url_presets:
|
||||
preset_value = cls._sanitize_base_url(preset.value)
|
||||
if preset_value != default_base_url:
|
||||
continue
|
||||
return preset.models_dev_provider_id or spec.models_dev_provider_id
|
||||
return spec.models_dev_provider_id
|
||||
|
||||
def resolve_model_list_base_url(
|
||||
self, provider_id: str, base_url: Optional[str]
|
||||
) -> Optional[str]:
|
||||
spec = self.get_provider(provider_id)
|
||||
return self._resolve_provider_model_list_base_url(spec, base_url)
|
||||
|
||||
@staticmethod
|
||||
def _httpx_proxy_key() -> str:
|
||||
"""兼容不同 httpx 版本的 proxy 参数名。"""
|
||||
@@ -492,16 +970,22 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return cached
|
||||
raise LLMProviderError(f"获取 models.dev 数据失败: {err}") from err
|
||||
|
||||
async def _models_dev_provider_payload(self, provider_id: str) -> dict[str, Any]:
|
||||
async def _models_dev_provider_payload(
|
||||
self, provider_id: str, base_url: Optional[str] = None
|
||||
) -> dict[str, Any]:
|
||||
spec = self.get_provider(provider_id)
|
||||
if not spec.models_dev_provider_id:
|
||||
models_dev_provider_id = self._resolve_provider_models_dev_provider_id(
|
||||
spec,
|
||||
base_url,
|
||||
)
|
||||
if not models_dev_provider_id:
|
||||
return {}
|
||||
return (await self.get_models_dev_data()).get(spec.models_dev_provider_id, {}) or {}
|
||||
return (await self.get_models_dev_data()).get(models_dev_provider_id, {}) or {}
|
||||
|
||||
async def _models_dev_model(
|
||||
self, provider_id: str, model_id: str
|
||||
self, provider_id: str, model_id: str, base_url: Optional[str] = None
|
||||
) -> dict[str, Any] | None:
|
||||
payload = await self._models_dev_provider_payload(provider_id)
|
||||
payload = await self._models_dev_provider_payload(provider_id, base_url=base_url)
|
||||
models = payload.get("models") if isinstance(payload, dict) else None
|
||||
if not isinstance(models, dict):
|
||||
return None
|
||||
@@ -649,7 +1133,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
results = []
|
||||
response = await client.models.list()
|
||||
for model in response.data:
|
||||
metadata = await self._models_dev_model(provider_id, model.id) or {}
|
||||
metadata = await self._models_dev_model(
|
||||
provider_id,
|
||||
model.id,
|
||||
base_url=base_url,
|
||||
) or {}
|
||||
results.append(
|
||||
self._normalize_model_record(
|
||||
model_id=model.id,
|
||||
@@ -664,13 +1152,14 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
self,
|
||||
provider_id: str,
|
||||
transport: str = "openai",
|
||||
base_url: Optional[str] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
某些 provider 没有统一稳定的 models.list 行为,
|
||||
因此优先读取 models.dev 目录;若未来 provider 暴露标准 models 接口,
|
||||
再平滑补充实时刷新即可。
|
||||
"""
|
||||
payload = await self._models_dev_provider_payload(provider_id)
|
||||
payload = await self._models_dev_provider_payload(provider_id, base_url=base_url)
|
||||
models = payload.get("models") if isinstance(payload, dict) else None
|
||||
if not isinstance(models, dict):
|
||||
raise LLMProviderError(f"{provider_id} 暂无可用模型目录")
|
||||
@@ -825,10 +1314,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""返回标准化后的模型目录。"""
|
||||
spec = self.get_provider(provider_id)
|
||||
if force_refresh and spec.models_dev_provider_id:
|
||||
if self._resolve_provider_models_dev_provider_id(spec, base_url):
|
||||
# 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表”
|
||||
# 在使用目录型 provider 时也能拿到最新参数。
|
||||
await self.get_models_dev_data(force_refresh=True)
|
||||
if force_refresh:
|
||||
await self.get_models_dev_data(force_refresh=True)
|
||||
runtime = await self.resolve_runtime(
|
||||
provider_id,
|
||||
model=None,
|
||||
@@ -848,7 +1338,10 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return await self._list_models_from_openai_compatible(
|
||||
provider_id="chatgpt",
|
||||
api_key=runtime["api_key"],
|
||||
base_url=runtime["base_url"],
|
||||
base_url=self._resolve_provider_model_list_base_url(
|
||||
spec,
|
||||
runtime["base_url"],
|
||||
),
|
||||
default_headers=runtime.get("default_headers"),
|
||||
)
|
||||
|
||||
@@ -856,28 +1349,40 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return await self._list_models_from_models_dev_only(
|
||||
provider_id=provider_id,
|
||||
transport="anthropic",
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
if spec.model_list_strategy == "models_dev_only":
|
||||
return await self._list_models_from_models_dev_only(
|
||||
provider_id=provider_id,
|
||||
transport="openai",
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
# openai-compatible / deepseek 默认走官方 models 端点。
|
||||
return await self._list_models_from_openai_compatible(
|
||||
provider_id=provider_id,
|
||||
api_key=runtime["api_key"],
|
||||
base_url=runtime["base_url"],
|
||||
base_url=self._resolve_provider_model_list_base_url(
|
||||
spec,
|
||||
runtime["base_url"],
|
||||
),
|
||||
default_headers=runtime.get("default_headers"),
|
||||
)
|
||||
|
||||
async def resolve_model_metadata(
|
||||
self, provider_id: str, model_id: Optional[str]
|
||||
self,
|
||||
provider_id: str,
|
||||
model_id: Optional[str],
|
||||
base_url: Optional[str] = None,
|
||||
) -> dict[str, Any] | None:
|
||||
if not model_id:
|
||||
return None
|
||||
metadata = await self._models_dev_model(provider_id, model_id)
|
||||
metadata = await self._models_dev_model(
|
||||
provider_id,
|
||||
model_id,
|
||||
base_url=base_url,
|
||||
)
|
||||
if metadata:
|
||||
return metadata
|
||||
if provider_id == "chatgpt":
|
||||
@@ -1366,7 +1871,11 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
"runtime": spec.runtime,
|
||||
"model_id": model,
|
||||
"model_record": model_record,
|
||||
"model_metadata": await self.resolve_model_metadata(provider_id, model),
|
||||
"model_metadata": await self.resolve_model_metadata(
|
||||
provider_id,
|
||||
model,
|
||||
base_url=base_url,
|
||||
),
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
"auth_mode": "api_key",
|
||||
@@ -1401,7 +1910,8 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
{
|
||||
"runtime": "openai_compatible",
|
||||
"api_key": normalized_api_key,
|
||||
"base_url": normalized_base_url or spec.default_base_url,
|
||||
"base_url": normalized_base_url
|
||||
or self._default_base_url_for_provider(spec),
|
||||
"auth_mode": "api_key",
|
||||
}
|
||||
)
|
||||
@@ -1448,7 +1958,9 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return result
|
||||
|
||||
if spec.runtime == "anthropic_compatible":
|
||||
effective_base_url = normalized_base_url or spec.default_base_url
|
||||
effective_base_url = normalized_base_url or self._default_base_url_for_provider(
|
||||
spec
|
||||
)
|
||||
if not normalized_api_key:
|
||||
raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key")
|
||||
if not effective_base_url:
|
||||
@@ -1464,7 +1976,7 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
)
|
||||
return result
|
||||
|
||||
effective_base_url = normalized_base_url or spec.default_base_url
|
||||
effective_base_url = normalized_base_url or self._default_base_url_for_provider(spec)
|
||||
if spec.requires_base_url and not effective_base_url:
|
||||
raise LLMProviderAuthError(f"{spec.name} 需要填写 Base URL")
|
||||
if not normalized_api_key:
|
||||
|
||||
@@ -8,48 +8,64 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
|
||||
Identity and Goal:
|
||||
- You are an AI media assistant powered by MoviePilot.
|
||||
- Your primary goal is to fully resolve the user's MoviePilot-related media tasks with the available tools whenever the request is actionable.
|
||||
- Focus on MoviePilot's home media domain: search, recognition, subscriptions, downloads, library organization, file transfer, and system status.
|
||||
- Focus on MoviePilot's core home media domain: sites, search, recognition, downloads, subscriptions, library organization, file transfer, and system status.
|
||||
- Treat sites as a first-class system capability, not background detail. In MoviePilot, sites are the upstream source for search, account status, authentication, and many download or subscription decisions.
|
||||
- Understand the platform's core workflow as: site availability and configuration -> media search -> media recognition/metadata confirmation -> manual download or subscription -> transfer and library organization -> status/history confirmation.
|
||||
- Treat manual download and subscription automation as two execution modes of the same core pipeline. One is user-triggered immediate acquisition; the other is persistent site-driven monitoring and acquisition.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
|
||||
Behavior Model:
|
||||
- Prioritize task progress over conversation.
|
||||
- Check current state before making changes, then do the smallest correct action.
|
||||
- When a task depends on tracker or indexer availability, inspect site state first or as early as possible.
|
||||
- Do not stop for approval on read-only operations. Only confirm before destructive or high-impact actions such as starting downloads, deleting subscriptions, or removing history.
|
||||
- When a request can be completed by tools, prefer doing the work over explaining what you might do.
|
||||
- After an action, perform the minimum validation needed to confirm the result actually landed.
|
||||
- Keep the user anchored to the operational step that matters now: site, search, recognition, download, subscription, or transfer.
|
||||
- If the user explicitly asks to change the speaking style or persona, use the dedicated persona tools instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
|
||||
Core Capabilities:
|
||||
1. Media Search and Recognition - Identify movies, TV shows, and anime; recognize media from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management - Create rules for automated downloading and monitor trending content.
|
||||
3. Download Control - Search torrents across trackers and filter by quality, codec, and release group.
|
||||
4. System Status and Organization - Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
7. Persona Management - If the user explicitly asks to change the speaking style or persona, prefer `query_personas` and `switch_persona`; if the user asks to rewrite or create a persona definition, prefer `update_persona_definition` instead of editing runtime files manually.
|
||||
1. Site Operations - Query configured sites, understand site priority and availability, inspect account data, test connectivity, and update site authentication when the user explicitly requests site maintenance.
|
||||
2. Media Search and Recognition - Identify movies, TV shows, and anime; search media databases; recognize media from fuzzy filenames, torrent titles, or incomplete names.
|
||||
3. Torrent Search and Selection - Search torrents across configured sites and filter by quality, resolution, codec, effect, release group, and other result traits.
|
||||
4. Download Control - Add, inspect, modify, or remove download tasks and connect site results to downloader execution.
|
||||
5. Subscription Management - Create and manage subscriptions that continuously search configured sites and automatically download matching releases.
|
||||
6. Transfer and Library Organization - Transfer files into the library, trigger recognition-aware organization, and confirm post-download file landing or cleanup state.
|
||||
7. System Status and History - Monitor downloader state, site state, transfer history, subscription history, and related system health signals.
|
||||
8. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
9. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
10. Persona Management - If the user explicitly asks to change the speaking style or persona, prefer `query_personas` and `switch_persona`; if the user asks to rewrite or create a persona definition, prefer `update_persona_definition` instead of editing runtime files manually.
|
||||
|
||||
Core Workflow:
|
||||
1. Media Discovery: Identify exact media metadata such as TMDB ID and Season or Episode using search tools when needed.
|
||||
2. Context Checking: Verify whether the media already exists in the library, has already been subscribed, or has relevant history that affects the next step.
|
||||
3. Action Execution: Perform the requested task with concise user-facing output unless the operation is destructive or blocked.
|
||||
4. Final Confirmation: State the outcome briefly, including the key media facts or blocker.
|
||||
1. Site and Context Check: Determine whether site status, site scope, library state, existing subscriptions, or prior download/transfer history can affect the task.
|
||||
2. Media Identity Resolution: Confirm exact media identity such as TMDB ID, title, year, type, season, or episode using `search_media`, `query_media_detail`, or `recognize_media` as needed.
|
||||
3. Resource Discovery: Use the appropriate search path for the task. For manual acquisition, search site resources and inspect result quality. For automation, prepare subscription conditions that will search sites continuously.
|
||||
4. Action Execution: Perform the requested task, typically one of: test/query site, search torrents, add download, add or modify subscription, or transfer and organize files.
|
||||
5. Final Confirmation: State the outcome briefly, including the key media facts, chosen site or resource scope when relevant, and the next blocker if the task could not be completed.
|
||||
|
||||
Tool Calling Strategy:
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- Prefer site-aware tool paths when the task is about torrents, subscriptions, or download failures. `query_sites`, `test_site`, and `query_site_userdata` are part of the main operating flow, not edge-case tools.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- For fuzzy torrent names, filenames, or manually provided paths, prefer `recognize_media` before asking the user for a cleaner title.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when automated paths are exhausted.
|
||||
- If torrent search yields no useful result, check site scope, site health, and recognition quality before concluding that the resource is unavailable.
|
||||
- Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
|
||||
Media Management Rules:
|
||||
1. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
5. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations.
|
||||
2. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
3. Search vs Recognition: `search_media` is for database lookup, `recognize_media` is for parsing titles or paths, and `search_torrents` is for site resource lookup. Do not confuse these roles.
|
||||
4. Subscription Logic: Check for the best matching quality profile, filter groups, and site scope based on user history or defaults.
|
||||
5. Library Awareness: Check if content already exists in the library to avoid duplicates before downloading, subscribing, or transferring.
|
||||
6. Transfer Awareness: If the user asks about downloaded files landing in the library, include transfer or organization state in the reasoning, not just download completion.
|
||||
7. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative or the next best operational step.
|
||||
8. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
</agent_core>
|
||||
|
||||
<communication_runtime>
|
||||
|
||||
@@ -281,7 +281,10 @@ class PromptManager:
|
||||
db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
|
||||
else:
|
||||
db_password = settings.DB_POSTGRESQL_PASSWORD or ""
|
||||
db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
db_info = (
|
||||
f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@"
|
||||
f"{settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
)
|
||||
|
||||
info_lines = [
|
||||
f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.utils.crypto import HashUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
SEARCH_RESULT_CACHE_FILE = "__search_result__"
|
||||
TORRENT_RESULT_LIMIT = 50
|
||||
TORRENT_RESULT_LIMIT = 200
|
||||
|
||||
|
||||
def build_torrent_ref(context: Optional[Context]) -> str:
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, MessageChannel
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
@@ -101,6 +102,36 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def _resolve_subscribe_username(self) -> Optional[str]:
|
||||
"""优先映射为系统用户名,未绑定时回退当前渠道用户名。"""
|
||||
resolved_username = self._username
|
||||
if not self._channel or not self._user_id:
|
||||
return resolved_username
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return resolved_username
|
||||
|
||||
binding_keys = {
|
||||
MessageChannel.Telegram: ("telegram_userid",),
|
||||
MessageChannel.Discord: ("discord_userid",),
|
||||
MessageChannel.Wechat: ("wechat_userid",),
|
||||
MessageChannel.Slack: ("slack_userid",),
|
||||
MessageChannel.VoceChat: ("vocechat_userid",),
|
||||
MessageChannel.SynologyChat: ("synologychat_userid",),
|
||||
MessageChannel.QQ: ("qq_userid", "qq_openid"),
|
||||
}.get(channel)
|
||||
if not binding_keys:
|
||||
return resolved_username
|
||||
|
||||
mapped_username = await self.run_blocking(
|
||||
"db",
|
||||
UserOper().get_name,
|
||||
**{key: self._user_id for key in binding_keys},
|
||||
)
|
||||
return mapped_username or resolved_username
|
||||
|
||||
async def run(
|
||||
self,
|
||||
title: str,
|
||||
@@ -137,6 +168,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
if media_type_enum == MediaType.TV
|
||||
else None
|
||||
)
|
||||
subscribe_username = await self._resolve_subscribe_username()
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
@@ -162,7 +194,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
username=subscribe_username,
|
||||
**subscribe_kwargs,
|
||||
)
|
||||
if sid:
|
||||
|
||||
@@ -5,7 +5,8 @@ import os
|
||||
import signal
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, TextIO, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -15,7 +16,7 @@ from app.log import logger
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 60
|
||||
MAX_TIMEOUT_SECONDS = 300
|
||||
MAX_OUTPUT_CHARS = 6000
|
||||
MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
@@ -25,40 +26,93 @@ _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
@dataclass
|
||||
class _CommandOutput:
|
||||
"""保存受限命令输出,避免大输出一次性进入内存。"""
|
||||
"""保存前 10KB 预览,并在超限时将完整输出写入临时文件。"""
|
||||
|
||||
limit: int
|
||||
stdout_chunks: list[str] = field(default_factory=list)
|
||||
stderr_chunks: list[str] = field(default_factory=list)
|
||||
captured_chars: int = 0
|
||||
truncated: bool = False
|
||||
preview_limit_bytes: int
|
||||
preview_entries: list[tuple[str, str]] = field(default_factory=list)
|
||||
captured_bytes: int = 0
|
||||
preview_truncated: bool = False
|
||||
temp_file_path: Optional[str] = None
|
||||
temp_file_handle: Optional[TextIO] = None
|
||||
last_written_stream: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
|
||||
if byte_limit <= 0:
|
||||
return ""
|
||||
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
|
||||
|
||||
def _write_chunk(self, stream_name: str, text: str) -> None:
|
||||
if not self.temp_file_handle or not text:
|
||||
return
|
||||
|
||||
if self.last_written_stream != stream_name:
|
||||
if self.temp_file_handle.tell() > 0:
|
||||
self.temp_file_handle.write("\n")
|
||||
title = "标准输出" if stream_name == "stdout" else "错误输出"
|
||||
self.temp_file_handle.write(f"[{title}]\n")
|
||||
self.last_written_stream = stream_name
|
||||
|
||||
self.temp_file_handle.write(text)
|
||||
|
||||
def _ensure_temp_file(self) -> None:
|
||||
if self.temp_file_handle:
|
||||
return
|
||||
|
||||
temp_file = NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
suffix=".log",
|
||||
prefix="moviepilot-command-",
|
||||
delete=False,
|
||||
)
|
||||
self.temp_file_path = temp_file.name
|
||||
self.temp_file_handle = temp_file
|
||||
for stream_name, chunk in self.preview_entries:
|
||||
self._write_chunk(stream_name, chunk)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self.temp_file_handle:
|
||||
return
|
||||
self.temp_file_handle.flush()
|
||||
self.temp_file_handle.close()
|
||||
self.temp_file_handle = None
|
||||
|
||||
def append(self, stream_name: str, text: str) -> None:
|
||||
if not text:
|
||||
return
|
||||
|
||||
remaining = self.limit - self.captured_chars
|
||||
if remaining <= 0:
|
||||
self.truncated = True
|
||||
if self.temp_file_handle:
|
||||
self._write_chunk(stream_name, text)
|
||||
return
|
||||
|
||||
captured = text[:remaining]
|
||||
if stream_name == "stdout":
|
||||
self.stdout_chunks.append(captured)
|
||||
else:
|
||||
self.stderr_chunks.append(captured)
|
||||
chunk_bytes = len(text.encode("utf-8"))
|
||||
remaining = self.preview_limit_bytes - self.captured_bytes
|
||||
if chunk_bytes <= remaining:
|
||||
self.preview_entries.append((stream_name, text))
|
||||
self.captured_bytes += chunk_bytes
|
||||
return
|
||||
|
||||
self.captured_chars += len(captured)
|
||||
if len(text) > remaining:
|
||||
self.truncated = True
|
||||
self.preview_truncated = True
|
||||
self._ensure_temp_file()
|
||||
self._write_chunk(stream_name, text)
|
||||
|
||||
preview = self._clip_text_to_bytes(text, remaining)
|
||||
if preview:
|
||||
self.preview_entries.append((stream_name, preview))
|
||||
self.captured_bytes += len(preview.encode("utf-8"))
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
return "".join(self.stdout_chunks).strip()
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stdout"
|
||||
).strip()
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
return "".join(self.stderr_chunks).strip()
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stderr"
|
||||
).strip()
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
@@ -78,7 +132,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
description: str = (
|
||||
"Safely execute shell commands on the server. Useful for system "
|
||||
"maintenance, checking status, or running custom scripts. Includes "
|
||||
"timeout, concurrency, and hard output limits."
|
||||
"timeout, concurrency, and output preview limits."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
@@ -107,7 +161,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _subprocess_kwargs() -> dict:
|
||||
"""为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。"""
|
||||
"""为子进程创建独立进程组,便于超时场景清理整棵子进程。"""
|
||||
kwargs = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
@@ -124,23 +178,14 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
output: _CommandOutput,
|
||||
limit_reached: asyncio.Event,
|
||||
) -> None:
|
||||
"""按块读取输出,达到上限后通知主流程终止命令。"""
|
||||
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
|
||||
while True:
|
||||
chunk = await stream.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
continue
|
||||
|
||||
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
# 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
|
||||
@@ -205,27 +250,33 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
output: _CommandOutput,
|
||||
timeout: int,
|
||||
timed_out: bool,
|
||||
output_limited: bool,
|
||||
timeout_note: Optional[str],
|
||||
) -> str:
|
||||
if timed_out:
|
||||
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
|
||||
elif output_limited:
|
||||
result = (
|
||||
f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
|
||||
f"已截断并终止进程,退出码: {exit_code})"
|
||||
)
|
||||
else:
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
|
||||
if timeout_note:
|
||||
result += f"\n\n提示:\n{timeout_note}"
|
||||
if output.temp_file_path:
|
||||
file_note = (
|
||||
"截至命令终止前的完整输出"
|
||||
if timed_out
|
||||
else "完整输出"
|
||||
)
|
||||
result += (
|
||||
"\n\n提示:\n"
|
||||
f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
|
||||
f"{file_note}已写入临时文件: {output.temp_file_path}\n"
|
||||
"如需完整内容,请继续读取该文件。"
|
||||
)
|
||||
if output.stdout:
|
||||
result += f"\n\n标准输出:\n{output.stdout}"
|
||||
if output.stderr:
|
||||
result += f"\n\n错误输出:\n{output.stderr}"
|
||||
if output.truncated:
|
||||
result += "\n\n...(输出内容过长,已截断)"
|
||||
if output.preview_truncated:
|
||||
result += "\n\n...(仅展示前 10KB 内容)"
|
||||
if not output.stdout and not output.stderr:
|
||||
result += "\n\n(无输出内容)"
|
||||
return result
|
||||
@@ -252,51 +303,40 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
async with _command_semaphore:
|
||||
# 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。
|
||||
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, **self._subprocess_kwargs()
|
||||
)
|
||||
output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
|
||||
limit_reached = asyncio.Event()
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
limit_task = asyncio.create_task(limit_reached.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stdout, "stdout", output, limit_reached
|
||||
)
|
||||
self._read_stream(process.stdout, "stdout", output)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stderr, "stderr", output, limit_reached
|
||||
)
|
||||
self._read_stream(process.stderr, "stderr", output)
|
||||
),
|
||||
]
|
||||
|
||||
timed_out = False
|
||||
output_limited = False
|
||||
done, _ = await asyncio.wait(
|
||||
{wait_task, limit_task},
|
||||
timeout=normalized_timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if wait_task not in done:
|
||||
if limit_task in done:
|
||||
output_limited = True
|
||||
else:
|
||||
timed_out = True
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
|
||||
limit_task.cancel()
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
|
||||
return self._format_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
output_limited=output_limited,
|
||||
timeout_note=timeout_note,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.agent.llm import (
|
||||
LLMTestTimeout,
|
||||
render_auth_result_html,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import (
|
||||
get_current_active_superuser_async,
|
||||
|
||||
@@ -950,6 +950,30 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
|
||||
global_vars.stop_system()
|
||||
# 执行重启
|
||||
ret, msg = SystemHelper.restart()
|
||||
if not ret:
|
||||
global_vars.resume_system()
|
||||
return schemas.Response(success=ret, message=msg)
|
||||
|
||||
|
||||
@router.post("/upgrade", summary="升级并重启系统", response_model=schemas.Response)
|
||||
def upgrade_system(
|
||||
mode: Annotated[str | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
):
|
||||
"""
|
||||
触发系统升级并重启(仅管理员)
|
||||
|
||||
- 当前已开启自动升级时:直接重启,由启动流程完成升级。
|
||||
- 当前未开启自动升级时:写入一次性升级标记,本次重启后仅执行一次升级。
|
||||
"""
|
||||
if not SystemHelper.can_restart():
|
||||
return schemas.Response(success=False, message="当前运行环境不支持升级操作!")
|
||||
|
||||
# 标识停止事件
|
||||
global_vars.stop_system()
|
||||
ret, msg = SystemHelper.upgrade(mode=mode or "release")
|
||||
if not ret:
|
||||
global_vars.resume_system()
|
||||
return schemas.Response(success=ret, message=msg)
|
||||
|
||||
|
||||
|
||||
@@ -352,6 +352,16 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return current_fileitem, None # 返回一个表示失败的FileItem和None
|
||||
target_dir_path = Path(target_dir_item.path)
|
||||
# 图片通常是放在当前目录 (current_fileitem) 下
|
||||
# Jellyfin/Kodi 等在季目录内使用通用图片名,而不是 season01-poster.jpg
|
||||
elif item_type == ScrapingTarget.SEASON:
|
||||
season_image_name_map = {
|
||||
ScrapingMetadata.POSTER: "poster",
|
||||
ScrapingMetadata.BANNER: "banner",
|
||||
ScrapingMetadata.THUMB: "thumb",
|
||||
}
|
||||
if season_image_name := season_image_name_map.get(metadata_type):
|
||||
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
|
||||
final_filename = f"{season_image_name}{hint_ext}"
|
||||
# 如果是 EPISODE 类型的图片(如thumb),通常也是放在文件同级目录,文件名与视频文件一致
|
||||
elif (
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
|
||||
@@ -592,6 +592,66 @@ class SearchChain(ChainBase):
|
||||
torrent_list=torrent_list,
|
||||
mediainfo=mediainfo) or []
|
||||
|
||||
def __do_site_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
"""
|
||||
执行单个站点的过滤流程
|
||||
"""
|
||||
if not torrent_list:
|
||||
return []
|
||||
|
||||
filtered_torrents = torrent_list
|
||||
if filter_params:
|
||||
torrenthelper = TorrentHelper()
|
||||
filtered_torrents = [
|
||||
torrent for torrent in filtered_torrents
|
||||
if torrenthelper.filter_torrent(torrent, filter_params)
|
||||
]
|
||||
|
||||
if rule_groups and filtered_torrents:
|
||||
filtered_torrents = __do_filter(filtered_torrents)
|
||||
|
||||
return filtered_torrents
|
||||
|
||||
def __do_parallel_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
"""
|
||||
按站点并发执行过滤,保持站点内顺序不变
|
||||
"""
|
||||
if not torrent_list or (not filter_params and not rule_groups):
|
||||
return torrent_list
|
||||
|
||||
site_torrents: Dict[Tuple[Optional[int], Optional[str]], List[TorrentInfo]] = {}
|
||||
for torrent in torrent_list:
|
||||
site_key = (torrent.site, torrent.site_name)
|
||||
if site_key not in site_torrents:
|
||||
site_torrents[site_key] = []
|
||||
site_torrents[site_key].append(torrent)
|
||||
|
||||
if len(site_torrents) <= 1:
|
||||
return __do_site_filter(torrent_list)
|
||||
|
||||
finished_count = 0
|
||||
filtered_by_site: Dict[Tuple[Optional[int], Optional[str]], List[TorrentInfo]] = {}
|
||||
max_workers = min(len(site_torrents), settings.CONF.threadpool or len(site_torrents))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
all_tasks = {
|
||||
executor.submit(__do_site_filter, site_torrent_list): site_key
|
||||
for site_key, site_torrent_list in site_torrents.items()
|
||||
}
|
||||
for future in as_completed(all_tasks):
|
||||
finished_count += 1
|
||||
filtered_by_site[all_tasks[future]] = future.result() or []
|
||||
progress.update(
|
||||
value=finished_count / len(site_torrents) * 50,
|
||||
text=f'正在过滤,已完成 {finished_count} / {len(site_torrents)} 个站点 ...'
|
||||
)
|
||||
|
||||
filtered_ids = {
|
||||
id(torrent)
|
||||
for filtered_torrents in filtered_by_site.values()
|
||||
for torrent in filtered_torrents
|
||||
}
|
||||
return [torrent for torrent in torrent_list if id(torrent) in filtered_ids]
|
||||
|
||||
if not torrents:
|
||||
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
|
||||
return []
|
||||
@@ -605,14 +665,14 @@ class SearchChain(ChainBase):
|
||||
# 匹配订阅附加参数
|
||||
if filter_params:
|
||||
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
|
||||
torrents = [torrent for torrent in torrents if TorrentHelper().filter_torrent(torrent, filter_params)]
|
||||
# 开始过滤规则过滤
|
||||
if rule_groups is None:
|
||||
# 取搜索过滤规则
|
||||
rule_groups: List[str] = SystemConfigOper().get(SystemConfigKey.SearchFilterRuleGroups)
|
||||
if rule_groups:
|
||||
logger.info(f'开始过滤规则/剧集过滤,使用规则组:{rule_groups} ...')
|
||||
torrents = __do_filter(torrents)
|
||||
torrents = __do_parallel_filter(torrents)
|
||||
if rule_groups:
|
||||
if not torrents:
|
||||
logger.warn(f'{keyword or mediainfo.title} 没有符合过滤规则的资源')
|
||||
return []
|
||||
|
||||
@@ -2348,7 +2348,7 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
if subscribe.type == MediaType.TV.value:
|
||||
season_number = file_meta.begin_season
|
||||
if season_number and season_number != subscribe.season:
|
||||
if season_number is not None and season_number != subscribe.season:
|
||||
continue
|
||||
episode_number = file_meta.begin_episode
|
||||
if episode_number and episodes.get(episode_number):
|
||||
@@ -2389,7 +2389,7 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
if subscribe.type == MediaType.TV.value:
|
||||
season_number = file_meta.begin_season
|
||||
if season_number and season_number != subscribe.season:
|
||||
if season_number is not None and season_number != subscribe.season:
|
||||
continue
|
||||
episode_number = file_meta.begin_episode
|
||||
if episode_number and episodes.get(episode_number):
|
||||
|
||||
@@ -20,7 +20,7 @@ from app.core.event import eventmanager
|
||||
from app.core.meta import MetaBase
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
@@ -1686,7 +1686,102 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_shared_download_roots(file_path: Path) -> set[str]:
|
||||
"""
|
||||
获取当前文件所在的共享下载根目录边界。
|
||||
|
||||
父目录兜底回查只应在种子自身目录内进行,不能越过共享下载根目录,
|
||||
否则历史中的单文件/无子目录任务会污染同级其它文件的识别结果。
|
||||
"""
|
||||
shared_roots: set[str] = set()
|
||||
media_type_dirs = {mtype.value for mtype in MediaType}
|
||||
|
||||
for dir_info in DirectoryHelper().get_download_dirs():
|
||||
if not dir_info.download_path:
|
||||
continue
|
||||
|
||||
download_root = Path(dir_info.download_path)
|
||||
if not file_path.is_relative_to(download_root):
|
||||
continue
|
||||
|
||||
shared_roots.add(download_root.as_posix())
|
||||
relative_parts = file_path.relative_to(download_root).parts
|
||||
current_root = download_root
|
||||
part_index = 0
|
||||
|
||||
if (
|
||||
not dir_info.media_type
|
||||
and dir_info.download_type_folder
|
||||
and len(relative_parts) > part_index
|
||||
and relative_parts[part_index] in media_type_dirs
|
||||
):
|
||||
current_root = current_root / relative_parts[part_index]
|
||||
shared_roots.add(current_root.as_posix())
|
||||
part_index += 1
|
||||
|
||||
if (
|
||||
not dir_info.media_category
|
||||
and dir_info.download_category_folder
|
||||
and len(relative_parts) > part_index
|
||||
):
|
||||
current_root = current_root / relative_parts[part_index]
|
||||
shared_roots.add(current_root.as_posix())
|
||||
|
||||
return shared_roots
|
||||
|
||||
@staticmethod
|
||||
def _match_download_file(
|
||||
download_file: DownloadFiles,
|
||||
file_path: Path,
|
||||
save_path: Path,
|
||||
) -> bool:
|
||||
"""
|
||||
判断下载文件记录是否明确对应当前文件。
|
||||
"""
|
||||
if download_file.fullpath == file_path.as_posix():
|
||||
return True
|
||||
|
||||
filepath = download_file.filepath
|
||||
if not filepath:
|
||||
return False
|
||||
|
||||
try:
|
||||
return (save_path / Path(filepath)).as_posix() == file_path.as_posix()
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
def _resolve_history_from_download_files(
|
||||
self,
|
||||
downloadhis: DownloadHistoryOper,
|
||||
download_files: List[DownloadFiles],
|
||||
file_path: Optional[Path] = None,
|
||||
save_path: Optional[Path] = None,
|
||||
) -> Optional[DownloadHistory]:
|
||||
"""
|
||||
从下载文件记录中解析唯一的下载历史。
|
||||
"""
|
||||
if file_path and save_path:
|
||||
download_files = [
|
||||
download_file
|
||||
for download_file in download_files
|
||||
if self._match_download_file(
|
||||
download_file=download_file,
|
||||
file_path=file_path,
|
||||
save_path=save_path,
|
||||
)
|
||||
]
|
||||
|
||||
download_hashes = {
|
||||
download_file.download_hash
|
||||
for download_file in download_files
|
||||
if download_file.download_hash
|
||||
}
|
||||
if len(download_hashes) == 1:
|
||||
return downloadhis.get_by_hash(next(iter(download_hashes)))
|
||||
return None
|
||||
|
||||
def _resolve_download_history(
|
||||
self,
|
||||
downloadhis: DownloadHistoryOper,
|
||||
file_path: Path,
|
||||
bluray_dir: bool = False,
|
||||
@@ -1707,20 +1802,35 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 多文件种子里的字幕/附加文件可能没有稳定的 fullpath 记录,
|
||||
# 退回到父目录和 savepath 继续查找,尽量补齐同一种子的关联信息。
|
||||
shared_download_roots = self._get_shared_download_roots(file_path)
|
||||
|
||||
for parent_path in file_path.parents:
|
||||
parent_posix = parent_path.as_posix()
|
||||
download_files = downloadhis.get_files_by_savepath(parent_posix) or []
|
||||
|
||||
if parent_posix in shared_download_roots:
|
||||
# 共享下载根目录只能接受有明确文件记录的匹配,
|
||||
# 避免单文件/磁力任务把整个根目录污染成同一媒体。
|
||||
history = self._resolve_history_from_download_files(
|
||||
downloadhis=downloadhis,
|
||||
download_files=download_files,
|
||||
file_path=file_path,
|
||||
save_path=parent_path,
|
||||
)
|
||||
if history:
|
||||
return history
|
||||
break
|
||||
|
||||
download_history = downloadhis.get_by_path(parent_posix)
|
||||
if download_history:
|
||||
return download_history
|
||||
|
||||
download_files = downloadhis.get_files_by_savepath(parent_posix) or []
|
||||
download_hashes = {
|
||||
download_file.download_hash
|
||||
for download_file in download_files
|
||||
if download_file.download_hash
|
||||
}
|
||||
if len(download_hashes) == 1:
|
||||
return downloadhis.get_by_hash(next(iter(download_hashes)))
|
||||
history = self._resolve_history_from_download_files(
|
||||
downloadhis=downloadhis,
|
||||
download_files=download_files,
|
||||
)
|
||||
if history:
|
||||
return history
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import click
|
||||
import psutil
|
||||
|
||||
from app.core.config import Settings, settings
|
||||
from app.helper.system import SystemHelper
|
||||
from version import APP_VERSION
|
||||
|
||||
BACKEND_RUNTIME_FILE = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
@@ -272,7 +273,10 @@ def _git_current_branch() -> Optional[str]:
|
||||
|
||||
|
||||
def _auto_update_mode() -> str:
|
||||
return str(getattr(settings, "MOVIEPILOT_AUTO_UPDATE", "") or "").strip().lower()
|
||||
one_shot_mode = SystemHelper.consume_one_shot_update_mode()
|
||||
if one_shot_mode:
|
||||
return one_shot_mode
|
||||
return SystemHelper.get_auto_update_mode()
|
||||
|
||||
|
||||
def _resolve_auto_update_targets(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
from asyncio import AbstractEventLoop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import quote, urlencode, urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
@@ -126,8 +126,8 @@ class ConfigModel(BaseModel):
|
||||
DB_SQLITE_MAX_OVERFLOW: int = 50
|
||||
# PostgreSQL 主机地址
|
||||
DB_POSTGRESQL_HOST: str = "localhost"
|
||||
# PostgreSQL 端口
|
||||
DB_POSTGRESQL_PORT: int = 5432
|
||||
# PostgreSQL 端口;使用 Unix Socket 时可留空
|
||||
DB_POSTGRESQL_PORT: str = "5432"
|
||||
# PostgreSQL 数据库名
|
||||
DB_POSTGRESQL_DATABASE: str = "moviepilot"
|
||||
# PostgreSQL 用户名
|
||||
@@ -142,7 +142,7 @@ class ConfigModel(BaseModel):
|
||||
# ==================== 缓存配置 ====================
|
||||
# 缓存类型,支持 cachetools 和 redis,默认使用 cachetools
|
||||
CACHE_BACKEND_TYPE: str = "cachetools"
|
||||
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要
|
||||
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要,支持 Redis Unix Socket URL
|
||||
CACHE_BACKEND_URL: Optional[str] = "redis://localhost:6379"
|
||||
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
|
||||
CACHE_REDIS_MAXMEMORY: Optional[str] = None
|
||||
@@ -921,6 +921,39 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
}
|
||||
return None
|
||||
|
||||
@property
|
||||
def DB_POSTGRESQL_SOCKET_MODE(self) -> bool:
|
||||
host = (self.DB_POSTGRESQL_HOST or "").strip()
|
||||
return host.startswith("/")
|
||||
|
||||
@property
|
||||
def DB_POSTGRESQL_TARGET(self) -> str:
|
||||
if self.DB_POSTGRESQL_SOCKET_MODE:
|
||||
target = f"socket {self.DB_POSTGRESQL_HOST}"
|
||||
if self.DB_POSTGRESQL_PORT:
|
||||
target = f"{target} (port {self.DB_POSTGRESQL_PORT})"
|
||||
return target
|
||||
if self.DB_POSTGRESQL_PORT:
|
||||
return f"{self.DB_POSTGRESQL_HOST}:{self.DB_POSTGRESQL_PORT}"
|
||||
return self.DB_POSTGRESQL_HOST
|
||||
|
||||
def DB_POSTGRESQL_URL(self, driver: Optional[str] = None) -> str:
|
||||
scheme = "postgresql" if not driver else f"postgresql+{driver}"
|
||||
username = quote(str(self.DB_POSTGRESQL_USERNAME), safe="")
|
||||
database = quote(str(self.DB_POSTGRESQL_DATABASE), safe="")
|
||||
auth = username
|
||||
if self.DB_POSTGRESQL_PASSWORD:
|
||||
auth = f"{auth}:{quote(str(self.DB_POSTGRESQL_PASSWORD), safe='')}"
|
||||
|
||||
if self.DB_POSTGRESQL_SOCKET_MODE:
|
||||
query = {"host": self.DB_POSTGRESQL_HOST}
|
||||
if self.DB_POSTGRESQL_PORT:
|
||||
query["port"] = self.DB_POSTGRESQL_PORT
|
||||
return f"{scheme}://{auth}@/{database}?{urlencode(query)}"
|
||||
|
||||
port = f":{self.DB_POSTGRESQL_PORT}" if self.DB_POSTGRESQL_PORT else ""
|
||||
return f"{scheme}://{auth}@{self.DB_POSTGRESQL_HOST}{port}/{database}"
|
||||
|
||||
@property
|
||||
def PROXY_SERVER(self):
|
||||
if self.PROXY_HOST:
|
||||
@@ -1066,6 +1099,12 @@ class GlobalVar(object):
|
||||
"""
|
||||
self.STOP_EVENT.set()
|
||||
|
||||
def resume_system(self):
|
||||
"""
|
||||
恢复系统运行标记。
|
||||
"""
|
||||
self.STOP_EVENT.clear()
|
||||
|
||||
@property
|
||||
def is_system_stopped(self):
|
||||
"""
|
||||
|
||||
@@ -116,11 +116,7 @@ def _get_postgresql_engine(is_async: bool = False):
|
||||
"""
|
||||
获取PostgreSQL数据库引擎
|
||||
"""
|
||||
# 构建PostgreSQL连接URL
|
||||
if settings.DB_POSTGRESQL_PASSWORD:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
db_url = settings.DB_POSTGRESQL_URL()
|
||||
|
||||
# PostgreSQL连接参数
|
||||
_connect_args = {}
|
||||
@@ -150,12 +146,11 @@ def _get_postgresql_engine(is_async: bool = False):
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(**_db_kwargs)
|
||||
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
|
||||
return engine
|
||||
else:
|
||||
# 构建异步PostgreSQL连接URL
|
||||
async_db_url = f"postgresql+asyncpg://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
async_db_url = settings.DB_POSTGRESQL_URL("asyncpg")
|
||||
|
||||
# 数据库参数,只能使用 NullPool
|
||||
_db_kwargs = {
|
||||
@@ -168,7 +163,7 @@ def _get_postgresql_engine(is_async: bool = False):
|
||||
}
|
||||
# 创建异步数据库引擎
|
||||
async_engine = create_async_engine(**_db_kwargs)
|
||||
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
@@ -28,10 +28,7 @@ def update_db():
|
||||
|
||||
# 根据数据库类型设置不同的URL
|
||||
if settings.DB_TYPE.lower() == "postgresql":
|
||||
if settings.DB_POSTGRESQL_PASSWORD:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
db_url = settings.DB_POSTGRESQL_URL()
|
||||
else:
|
||||
db_location = settings.CONFIG_PATH / 'user.db'
|
||||
db_url = f"sqlite:///{db_location}"
|
||||
|
||||
@@ -21,6 +21,7 @@ class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
系统工具类,提供系统相关的操作和判断
|
||||
"""
|
||||
AUTO_UPDATE_ENABLED_VALUES = {"release", "dev"}
|
||||
CONFIG_WATCH = {
|
||||
"DEBUG",
|
||||
"LOG_LEVEL",
|
||||
@@ -33,6 +34,7 @@ class SystemHelper(ConfigReloadMixin):
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
__local_backend_runtime_file = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
__local_restart_log_file = settings.LOG_PATH / "moviepilot.restart.stdout.log"
|
||||
__one_shot_update_flag_file = settings.TEMP_PATH / "moviepilot.pending_update"
|
||||
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
@@ -85,6 +87,96 @@ class SystemHelper(ConfigReloadMixin):
|
||||
except (psutil.Error, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def normalize_auto_update_mode(mode: Optional[str]) -> str:
|
||||
"""
|
||||
统一自动升级模式值,兼容历史 true 表示 release。
|
||||
"""
|
||||
normalized = str(mode or "").strip().lower()
|
||||
return "release" if normalized == "true" else normalized
|
||||
|
||||
@staticmethod
|
||||
def get_auto_update_mode() -> str:
|
||||
"""
|
||||
获取当前配置中的自动升级模式。
|
||||
"""
|
||||
return SystemHelper.normalize_auto_update_mode(
|
||||
settings.MOVIEPILOT_AUTO_UPDATE
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_auto_update_enabled(mode: Optional[str] = None) -> bool:
|
||||
"""
|
||||
判断给定模式或当前配置是否启用了启动时自动升级。
|
||||
"""
|
||||
effective_mode = (
|
||||
SystemHelper.get_auto_update_mode()
|
||||
if mode is None
|
||||
else SystemHelper.normalize_auto_update_mode(mode)
|
||||
)
|
||||
return effective_mode in SystemHelper.AUTO_UPDATE_ENABLED_VALUES
|
||||
|
||||
@staticmethod
|
||||
def queue_one_shot_update(mode: str = "release") -> Tuple[bool, str]:
|
||||
"""
|
||||
写入一次性升级标记,供重启后的启动流程消费。
|
||||
"""
|
||||
effective_mode = SystemHelper.normalize_auto_update_mode(mode)
|
||||
if effective_mode not in SystemHelper.AUTO_UPDATE_ENABLED_VALUES:
|
||||
return False, "升级模式仅支持 release 或 dev"
|
||||
|
||||
try:
|
||||
SystemHelper.__one_shot_update_flag_file.parent.mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
SystemHelper.__one_shot_update_flag_file.write_text(
|
||||
effective_mode, encoding="utf-8"
|
||||
)
|
||||
logger.info(f"已写入一次性升级标记,模式: {effective_mode}")
|
||||
return True, ""
|
||||
except OSError as err:
|
||||
logger.error(f"写入一次性升级标记失败: {err}")
|
||||
return False, f"写入一次性升级标记失败:{err}"
|
||||
|
||||
@staticmethod
|
||||
def consume_one_shot_update_mode() -> Optional[str]:
|
||||
"""
|
||||
读取并清除一次性升级标记,避免后续启动重复执行。
|
||||
"""
|
||||
path = SystemHelper.__one_shot_update_flag_file
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
raw_mode = path.read_text(encoding="utf-8")
|
||||
except OSError as err:
|
||||
logger.warning(f"读取一次性升级标记失败: {err}")
|
||||
raw_mode = ""
|
||||
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError as err:
|
||||
logger.warning(f"删除一次性升级标记失败: {err}")
|
||||
|
||||
effective_mode = SystemHelper.normalize_auto_update_mode(raw_mode)
|
||||
if effective_mode not in SystemHelper.AUTO_UPDATE_ENABLED_VALUES:
|
||||
if raw_mode:
|
||||
logger.warning(f"忽略无效的一次性升级模式: {raw_mode}")
|
||||
return None
|
||||
|
||||
logger.info(f"检测到一次性升级标记,模式: {effective_mode}")
|
||||
return effective_mode
|
||||
|
||||
@staticmethod
|
||||
def clear_one_shot_update_flag() -> None:
|
||||
"""
|
||||
删除一次性升级标记。
|
||||
"""
|
||||
try:
|
||||
SystemHelper.__one_shot_update_flag_file.unlink(missing_ok=True)
|
||||
except OSError as err:
|
||||
logger.warning(f"删除一次性升级标记失败: {err}")
|
||||
|
||||
@staticmethod
|
||||
def _spawn_local_restart_helper() -> None:
|
||||
helper_code = (
|
||||
@@ -178,6 +270,8 @@ class SystemHelper(ConfigReloadMixin):
|
||||
return False, "当前实例不是由 moviepilot CLI 启动,无法执行内建重启!"
|
||||
try:
|
||||
SystemHelper._spawn_local_restart_helper()
|
||||
# 复用与 Docker 相同的优雅退出路径,确保当前后端进程真正结束。
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
return True, ""
|
||||
except Exception as err:
|
||||
logger.error(f"本地 CLI 重启失败: {str(err)}")
|
||||
@@ -204,6 +298,34 @@ class SystemHelper(ConfigReloadMixin):
|
||||
logger.warning("降级为Docker API重启...")
|
||||
return SystemHelper._docker_api_restart()
|
||||
|
||||
@staticmethod
|
||||
def upgrade(mode: str = "release") -> Tuple[bool, str]:
|
||||
"""
|
||||
触发升级并重启。
|
||||
|
||||
- 已开启自动升级时,直接重启,沿用当前配置。
|
||||
- 未开启自动升级时,写入一次性升级标记,供下次启动时执行升级。
|
||||
"""
|
||||
current_mode = SystemHelper.get_auto_update_mode()
|
||||
if SystemHelper.is_auto_update_enabled(current_mode):
|
||||
ret, msg = SystemHelper.restart()
|
||||
if not ret:
|
||||
return ret, msg
|
||||
if current_mode == "dev":
|
||||
return True, "已检测到自动升级模式 dev,正在重启并执行升级"
|
||||
return True, "已检测到自动升级已开启,正在重启并执行升级"
|
||||
|
||||
queued, message = SystemHelper.queue_one_shot_update(mode)
|
||||
if not queued:
|
||||
return False, message
|
||||
|
||||
ret, msg = SystemHelper.restart()
|
||||
if not ret:
|
||||
SystemHelper.clear_one_shot_update_flag()
|
||||
return ret, msg
|
||||
effective_mode = SystemHelper.normalize_auto_update_mode(mode)
|
||||
return True, f"已安排一次性 {effective_mode} 升级并重启"
|
||||
|
||||
@staticmethod
|
||||
def _start_graceful_shutdown_monitor():
|
||||
"""
|
||||
@@ -212,8 +334,8 @@ class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
|
||||
def monitor_thread():
|
||||
time.sleep(30) # 等待30秒
|
||||
logger.warning("优雅退出超时30秒,使用Docker API强制重启...")
|
||||
time.sleep(180) # 等待180秒
|
||||
logger.warning("优雅退出超时180秒,使用Docker API强制重启...")
|
||||
try:
|
||||
SystemHelper._docker_api_restart()
|
||||
except Exception as e:
|
||||
|
||||
@@ -176,86 +176,101 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
if item:
|
||||
return [item]
|
||||
return []
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/list"),
|
||||
json={
|
||||
"path": fileitem.path,
|
||||
"password": password,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"refresh": refresh,
|
||||
},
|
||||
)
|
||||
"""
|
||||
{
|
||||
"path": "/t",
|
||||
"password": "",
|
||||
"page": 1,
|
||||
"per_page": 0,
|
||||
"refresh": false
|
||||
}
|
||||
======================================
|
||||
{
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"content": [
|
||||
{
|
||||
"name": "Alist V3.md",
|
||||
"size": 1592,
|
||||
"is_dir": false,
|
||||
"modified": "2024-05-17T13:47:55.4174917+08:00",
|
||||
"created": "2024-05-17T13:47:47.5725906+08:00",
|
||||
"sign": "",
|
||||
"thumb": "",
|
||||
"type": 4,
|
||||
"hashinfo": "null",
|
||||
"hash_info": null
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
"readme": "",
|
||||
"header": "",
|
||||
"write": true,
|
||||
"provider": "Local"
|
||||
items = []
|
||||
current_page = page
|
||||
while True:
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/list"),
|
||||
json={
|
||||
"path": fileitem.path,
|
||||
"password": password,
|
||||
"page": current_page,
|
||||
"per_page": per_page,
|
||||
"refresh": refresh,
|
||||
},
|
||||
)
|
||||
"""
|
||||
{
|
||||
"path": "/t",
|
||||
"password": "",
|
||||
"page": 1,
|
||||
"per_page": 0,
|
||||
"refresh": false
|
||||
}
|
||||
}
|
||||
"""
|
||||
======================================
|
||||
{
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"content": [
|
||||
{
|
||||
"name": "Alist V3.md",
|
||||
"size": 1592,
|
||||
"is_dir": false,
|
||||
"modified": "2024-05-17T13:47:55.4174917+08:00",
|
||||
"created": "2024-05-17T13:47:47.5725906+08:00",
|
||||
"sign": "",
|
||||
"thumb": "",
|
||||
"type": 4,
|
||||
"hashinfo": "null",
|
||||
"hash_info": null
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
"readme": "",
|
||||
"header": "",
|
||||
"write": true,
|
||||
"provider": "Local"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
if resp is None:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
|
||||
)
|
||||
return []
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return []
|
||||
if resp is None:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
|
||||
)
|
||||
return []
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return []
|
||||
|
||||
result = resp.json()
|
||||
result = resp.json()
|
||||
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
|
||||
)
|
||||
return []
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
|
||||
)
|
||||
return []
|
||||
|
||||
return [
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir" if item["is_dir"] else "file",
|
||||
path=(Path(fileitem.path) / item["name"]).as_posix()
|
||||
+ ("/" if item["is_dir"] else ""),
|
||||
name=item["name"],
|
||||
basename=Path(item["name"]).stem,
|
||||
extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
|
||||
size=item["size"] if not item["is_dir"] else None,
|
||||
modify_time=self.__parse_timestamp(item["modified"]),
|
||||
thumbnail=item["thumb"],
|
||||
page_content = result["data"].get("content") or []
|
||||
items.extend(
|
||||
[
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir" if item["is_dir"] else "file",
|
||||
path=(Path(fileitem.path) / item["name"]).as_posix()
|
||||
+ ("/" if item["is_dir"] else ""),
|
||||
name=item["name"],
|
||||
basename=Path(item["name"]).stem,
|
||||
extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
|
||||
size=item["size"] if not item["is_dir"] else None,
|
||||
modify_time=self.__parse_timestamp(item["modified"]),
|
||||
thumbnail=item["thumb"],
|
||||
)
|
||||
for item in page_content
|
||||
]
|
||||
)
|
||||
for item in result["data"]["content"] or []
|
||||
]
|
||||
|
||||
if per_page > 0:
|
||||
return items
|
||||
|
||||
total = result["data"].get("total") or 0
|
||||
if not page_content or len(items) >= total:
|
||||
return items
|
||||
|
||||
current_page += 1
|
||||
|
||||
def create_folder(
|
||||
self, fileitem: schemas.FileItem, name: str
|
||||
|
||||
@@ -15,10 +15,6 @@ from app.utils.string import StringUtils
|
||||
|
||||
class FilterModule(_ModuleBase):
|
||||
CONFIG_WATCH = {SystemConfigKey.CustomFilterRules.value}
|
||||
# 规则解析器
|
||||
parser: RuleParser = None
|
||||
# 媒体信息
|
||||
media: MediaInfo = None
|
||||
|
||||
# 保留一份只读内置规则定义,方便查询工具准确区分“内置规则”和“自定义规则”。
|
||||
builtin_rule_set: Dict[str, dict] = deepcopy(BUILTIN_RULE_SET)
|
||||
@@ -30,7 +26,6 @@ class FilterModule(_ModuleBase):
|
||||
self.rulehelper = RuleHelper()
|
||||
|
||||
def init_module(self) -> None:
|
||||
self.parser = RuleParser()
|
||||
# 每次重载都先恢复为纯内置规则,避免旧的自定义规则残留在内存里。
|
||||
self.rule_set = deepcopy(self.builtin_rule_set)
|
||||
self.__init_custom_rules()
|
||||
@@ -90,7 +85,7 @@ class FilterModule(_ModuleBase):
|
||||
"""
|
||||
if not rule_groups:
|
||||
return torrent_list
|
||||
self.media = mediainfo
|
||||
parser = RuleParser()
|
||||
# 查询规则表详情
|
||||
groups = self.rulehelper.get_rule_group_by_media(media=mediainfo, group_names=rule_groups)
|
||||
if groups:
|
||||
@@ -99,12 +94,16 @@ class FilterModule(_ModuleBase):
|
||||
torrent_list = self.__filter_torrents(
|
||||
rule_string=group.rule_string,
|
||||
rule_name=group.name,
|
||||
torrent_list=torrent_list
|
||||
torrent_list=torrent_list,
|
||||
mediainfo=mediainfo,
|
||||
parser=parser,
|
||||
)
|
||||
return torrent_list
|
||||
|
||||
def __filter_torrents(self, rule_string: str, rule_name: str,
|
||||
torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
torrent_list: List[TorrentInfo],
|
||||
mediainfo: MediaInfo,
|
||||
parser: RuleParser) -> List[TorrentInfo]:
|
||||
"""
|
||||
过滤种子
|
||||
"""
|
||||
@@ -112,7 +111,7 @@ class FilterModule(_ModuleBase):
|
||||
ret_torrents = []
|
||||
for torrent in torrent_list:
|
||||
# 能命中优先级的才返回
|
||||
if not self.__get_order(torrent, rule_string):
|
||||
if not self.__get_order(torrent, rule_string, mediainfo, parser):
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} {torrent.description or ''} "
|
||||
f"不匹配 {rule_name} 过滤规则")
|
||||
continue
|
||||
@@ -120,7 +119,8 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
return ret_torrents
|
||||
|
||||
def __get_order(self, torrent: TorrentInfo, rule_str: str) -> Optional[TorrentInfo]:
|
||||
def __get_order(self, torrent: TorrentInfo, rule_str: str,
|
||||
mediainfo: MediaInfo, parser: RuleParser) -> Optional[TorrentInfo]:
|
||||
"""
|
||||
获取种子匹配的规则优先级,值越大越优先,未匹配时返回None
|
||||
"""
|
||||
@@ -133,8 +133,8 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
for rule_group in rule_groups:
|
||||
# 解析规则组
|
||||
parsed_group = self.parser.parse(rule_group.strip())
|
||||
if self.__match_group(torrent, parsed_group.as_list()[0]):
|
||||
parsed_group = parser.parse(rule_group.strip())
|
||||
if self.__match_group(torrent, parsed_group.as_list()[0], mediainfo):
|
||||
# 出现匹配时中断
|
||||
matched = True
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} 优先级为 {100 - res_order + 1}")
|
||||
@@ -145,27 +145,31 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
return None if not matched else torrent
|
||||
|
||||
def __match_group(self, torrent: TorrentInfo, rule_group: Union[list, str]) -> Optional[bool]:
|
||||
def __match_group(self, torrent: TorrentInfo, rule_group: Union[list, str],
|
||||
mediainfo: MediaInfo) -> Optional[bool]:
|
||||
"""
|
||||
判断种子是否匹配规则组
|
||||
"""
|
||||
if not isinstance(rule_group, list):
|
||||
# 不是列表,说明是规则名称
|
||||
return self.__match_rule(torrent, rule_group)
|
||||
return self.__match_rule(torrent, rule_group, mediainfo)
|
||||
elif isinstance(rule_group, list) and len(rule_group) == 1:
|
||||
# 只有一个规则项
|
||||
return self.__match_group(torrent, rule_group[0])
|
||||
return self.__match_group(torrent, rule_group[0], mediainfo)
|
||||
elif rule_group[0] == "not":
|
||||
# 非操作
|
||||
return not self.__match_group(torrent, rule_group[1:])
|
||||
return not self.__match_group(torrent, rule_group[1:], mediainfo)
|
||||
elif rule_group[1] == "and":
|
||||
# 与操作
|
||||
return self.__match_group(torrent, rule_group[0]) and self.__match_group(torrent, rule_group[2:])
|
||||
return self.__match_group(torrent, rule_group[0], mediainfo) \
|
||||
and self.__match_group(torrent, rule_group[2:], mediainfo)
|
||||
elif rule_group[1] == "or":
|
||||
# 或操作
|
||||
return self.__match_group(torrent, rule_group[0]) or self.__match_group(torrent, rule_group[2:])
|
||||
return self.__match_group(torrent, rule_group[0], mediainfo) \
|
||||
or self.__match_group(torrent, rule_group[2:], mediainfo)
|
||||
|
||||
def __match_rule(self, torrent: TorrentInfo, rule_name: str) -> bool:
|
||||
def __match_rule(self, torrent: TorrentInfo, rule_name: str,
|
||||
mediainfo: MediaInfo) -> bool:
|
||||
"""
|
||||
判断种子是否匹配规则项
|
||||
"""
|
||||
@@ -176,7 +180,7 @@ class FilterModule(_ModuleBase):
|
||||
# TMDB规则
|
||||
tmdb = self.rule_set[rule_name].get("tmdb")
|
||||
# 符合TMDB规则的直接返回True,即不过滤
|
||||
if tmdb and self.__match_tmdb(tmdb):
|
||||
if tmdb and self.__match_tmdb(tmdb, mediainfo):
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} 符合 {rule_name} 的TMDB规则,匹配成功")
|
||||
return True
|
||||
# 匹配项:标题、副标题、标签
|
||||
@@ -259,18 +263,19 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
return True
|
||||
|
||||
def __match_tmdb(self, tmdb: dict) -> bool:
|
||||
@staticmethod
|
||||
def __match_tmdb(tmdb: dict, mediainfo: MediaInfo) -> bool:
|
||||
"""
|
||||
判断种子是否匹配TMDB规则
|
||||
"""
|
||||
|
||||
def __get_media_value(key: str):
|
||||
try:
|
||||
return getattr(self.media, key)
|
||||
return getattr(mediainfo, key)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if not self.media:
|
||||
if not mediainfo:
|
||||
return False
|
||||
|
||||
for attr, value in tmdb.items():
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import posixpath
|
||||
from datetime import datetime
|
||||
from typing import List, Union, Optional, Dict, Generator, Tuple, Any
|
||||
|
||||
@@ -123,7 +124,12 @@ class Jellyfin:
|
||||
user = self.get_user(username)
|
||||
else:
|
||||
user = self.user
|
||||
url = f"{self._host}Users/{user}/Views"
|
||||
if not user:
|
||||
return []
|
||||
# 使用标准库路径拼接结合统一 URL 规整,避免 host 尾部斜杠缺失导致的寻址偏移。
|
||||
url = UrlUtils.combine_url(self._host, posixpath.join("Users", str(user), "Views"))
|
||||
if not url:
|
||||
return []
|
||||
params = {"api_key": self._apikey}
|
||||
try:
|
||||
res = RequestUtils().get_res(url, params)
|
||||
@@ -213,10 +219,37 @@ class Jellyfin:
|
||||
for user in users:
|
||||
if user.get("Name") == user_name:
|
||||
return user.get("Id")
|
||||
# 查询管理员
|
||||
if user_name == settings.SUPERUSER:
|
||||
logger.warning(
|
||||
"MoviePilot 当前配置的超级管理员用户名为 {},请确保Jellyfin中存在同名管理员账号,否则可能无法正常使用部分功能!".format(settings.SUPERUSER)
|
||||
)
|
||||
# 查询管理员,优先选择同时具备全库访问能力的账号,再回退到普通管理员。
|
||||
# 获取总媒体库数量
|
||||
total_library_count = len(self.get_jellyfin_folders())
|
||||
best_admin_id = None
|
||||
best_admin_name = None
|
||||
best_admin_library_count = -1
|
||||
for user in users:
|
||||
if user.get("Policy", {}).get("IsAdministrator"):
|
||||
policy = user.get("Policy") or {}
|
||||
if not policy.get("IsAdministrator"):
|
||||
continue
|
||||
if policy.get("EnableAllFolders"):
|
||||
return user.get("Id")
|
||||
else:
|
||||
enabled_folders = policy.get('EnabledFolders') or []
|
||||
current_count = len(enabled_folders)
|
||||
# 更新最佳管理员
|
||||
if best_admin_id is None or current_count > best_admin_library_count:
|
||||
best_admin_id = user.get("Id")
|
||||
best_admin_name = user.get("Name")
|
||||
best_admin_library_count = current_count
|
||||
if best_admin_id is None:
|
||||
logger.warning("未找到可用的管理员账号,无法获取管理员用户,请检查Jellyfin用户及权限配置!")
|
||||
return None
|
||||
logger.warning(
|
||||
f"未找到具备全库访问权限的管理员账号,回退使用仅可访问{best_admin_library_count}/{total_library_count}个媒体库的管理员账号{best_admin_name}!"
|
||||
)
|
||||
return best_admin_id
|
||||
else:
|
||||
logger.error(f"Users 未获取到返回数据")
|
||||
except Exception as e:
|
||||
|
||||
@@ -148,7 +148,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
# 如果要选择文件则先暂停
|
||||
is_paused = True if episodes else False
|
||||
# 添加任务
|
||||
state = server.add_torrent(
|
||||
state, added_torrent_ids = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
@@ -188,7 +188,11 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
return None, None, None, f"添加种子任务失败:{content}"
|
||||
else:
|
||||
# 获取种子Hash
|
||||
torrent_hash = server.get_torrent_id_by_tag(tags=tag)
|
||||
torrent_hash = next(iter(added_torrent_ids), None)
|
||||
if torrent_hash:
|
||||
server.delete_torrents_tag(torrent_hash, tag)
|
||||
else:
|
||||
torrent_hash = server.get_torrent_id_by_tag(tags=tag)
|
||||
if not torrent_hash:
|
||||
return None, None, None, f"下载任务添加成功,但获取Qbittorrent任务信息失败:{content}"
|
||||
else:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from http.cookies import SimpleCookie
|
||||
from typing import Any, Optional, Union, Tuple, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import qbittorrentapi
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from qbittorrentapi import TorrentDictionary, TorrentFilesList
|
||||
from qbittorrentapi.client import Client
|
||||
from qbittorrentapi.transfer import TransferInfoDictionary
|
||||
@@ -17,6 +20,7 @@ class Qbittorrent:
|
||||
"""
|
||||
def __init__(self, host: Optional[str] = None, port: int = None,
|
||||
username: Optional[str] = None, password: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
category: Optional[bool] = False, sequentail: Optional[bool] = False,
|
||||
force_resume: Optional[bool] = False, first_last_piece=False,
|
||||
**kwargs):
|
||||
@@ -33,12 +37,122 @@ class Qbittorrent:
|
||||
return
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._apikey = str(apikey or "").strip() or None
|
||||
self._category = category
|
||||
self._sequentail = sequentail
|
||||
self._force_resume = force_resume
|
||||
self._first_last_piece = first_last_piece
|
||||
self.qbc = self.__login_qbittorrent()
|
||||
|
||||
@staticmethod
|
||||
def __get_mapping_value(data: Any, key: str) -> Any:
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, dict):
|
||||
return data.get(key)
|
||||
getter = getattr(data, "get", None)
|
||||
if callable(getter):
|
||||
try:
|
||||
return getter(key)
|
||||
except Exception:
|
||||
pass
|
||||
return getattr(data, key, None)
|
||||
|
||||
def __normalize_cookie(self, cookie: Any) -> dict:
|
||||
result = {}
|
||||
for key in ("domain", "path", "name", "value", "expirationDate"):
|
||||
value = self.__get_mapping_value(cookie, key)
|
||||
if value not in (None, ""):
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def __cookie_key(cookie: dict) -> Optional[tuple]:
|
||||
name = cookie.get("name")
|
||||
domain = cookie.get("domain")
|
||||
path = cookie.get("path") or "/"
|
||||
if not name or not domain:
|
||||
return None
|
||||
return domain, path, name
|
||||
|
||||
@staticmethod
|
||||
def __build_site_cookies(url: str, cookie_header: str) -> List[dict]:
|
||||
domain = urlparse(url).hostname
|
||||
if not domain:
|
||||
return []
|
||||
|
||||
raw_cookies = SimpleCookie()
|
||||
raw_cookies.load(cookie_header)
|
||||
return [
|
||||
{
|
||||
"domain": domain,
|
||||
"path": "/",
|
||||
"name": morsel.key,
|
||||
"value": morsel.value,
|
||||
}
|
||||
for morsel in raw_cookies.values()
|
||||
]
|
||||
|
||||
def __parse_add_torrent_response(self, response: Any) -> Tuple[bool, List[str]]:
|
||||
if not response:
|
||||
return False, []
|
||||
if isinstance(response, str):
|
||||
return "Ok" in response, []
|
||||
|
||||
success_count = self.__get_mapping_value(response, "success_count") or 0
|
||||
pending_count = self.__get_mapping_value(response, "pending_count") or 0
|
||||
added_torrent_ids = self.__get_mapping_value(response, "added_torrent_ids") or []
|
||||
if not isinstance(added_torrent_ids, list):
|
||||
added_torrent_ids = list(added_torrent_ids)
|
||||
added_torrent_ids = [str(torrent_id) for torrent_id in added_torrent_ids if torrent_id]
|
||||
if added_torrent_ids:
|
||||
return True, added_torrent_ids
|
||||
if success_count or pending_count:
|
||||
return True, []
|
||||
return "Ok" in str(response), []
|
||||
|
||||
def __use_api_key_auth(self) -> bool:
|
||||
return bool(self._apikey)
|
||||
|
||||
def __supports_cookie_api(self) -> bool:
|
||||
if not self.qbc:
|
||||
return False
|
||||
try:
|
||||
web_api_version = self.qbc.app_web_api_version()
|
||||
return Version(str(web_api_version)) >= Version("2.11.3")
|
||||
except (InvalidVersion, TypeError, ValueError):
|
||||
return False
|
||||
except Exception as err:
|
||||
logger.warn(f"获取 qbittorrent Web API 版本失败,跳过 Cookie API 兼容:{err}")
|
||||
return False
|
||||
|
||||
def __sync_download_cookies(self, url: str, cookie_header: str) -> bool:
|
||||
if not self.qbc or not url or not cookie_header or not self.__supports_cookie_api():
|
||||
return False
|
||||
|
||||
try:
|
||||
site_cookies = self.__build_site_cookies(url=url, cookie_header=cookie_header)
|
||||
if not site_cookies:
|
||||
return False
|
||||
|
||||
merged_cookies = {}
|
||||
for cookie in self.qbc.app_cookies() or []:
|
||||
normalized = self.__normalize_cookie(cookie)
|
||||
cookie_key = self.__cookie_key(normalized)
|
||||
if cookie_key:
|
||||
merged_cookies[cookie_key] = normalized
|
||||
|
||||
for cookie in site_cookies:
|
||||
cookie_key = self.__cookie_key(cookie)
|
||||
if cookie_key:
|
||||
merged_cookies[cookie_key] = cookie
|
||||
|
||||
self.qbc.app_set_cookies(cookies=list(merged_cookies.values()))
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"同步下载Cookie出错:{str(err)}")
|
||||
return False
|
||||
|
||||
def is_inactive(self) -> bool:
|
||||
"""
|
||||
判断是否需要重连
|
||||
@@ -67,14 +181,20 @@ class Qbittorrent:
|
||||
port=self._port,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
EXTRA_HEADERS={"Authorization": f"Bearer {self._apikey}"}
|
||||
if self.__use_api_key_auth() else None,
|
||||
VERIFY_WEBUI_CERTIFICATE=False,
|
||||
REQUESTS_ARGS={'timeout': (15, 60)})
|
||||
try:
|
||||
qbt.auth_log_in()
|
||||
except (qbittorrentapi.LoginFailed, qbittorrentapi.Forbidden403Error) as e:
|
||||
logger.error(f"qbittorrent 登录失败:{str(e).strip() or '请检查用户名和密码是否正确'}")
|
||||
return None
|
||||
if self.__use_api_key_auth():
|
||||
qbt.app_version()
|
||||
else:
|
||||
qbt.auth_log_in()
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ in {"LoginFailed", "Forbidden403Error", "Unauthorized401Error"}:
|
||||
error_hint = "请检查 API Key 是否正确" if self.__use_api_key_auth() else "请检查用户名和密码是否正确"
|
||||
logger.error(f"qbittorrent 登录失败:{str(e).strip() or error_hint}")
|
||||
return None
|
||||
stack_trace = "".join(traceback.format_exception(None, e, e.__traceback__))[:2000]
|
||||
logger.error(f"qbittorrent 登录失败:{str(e)}\n{stack_trace}")
|
||||
return None
|
||||
@@ -241,7 +361,7 @@ class Qbittorrent:
|
||||
category: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> bool:
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
添加种子
|
||||
:param content: 种子urls或文件内容
|
||||
@@ -251,10 +371,10 @@ class Qbittorrent:
|
||||
:param download_dir: 下载路径
|
||||
:param cookie: 站点Cookie用于辅助下载种子
|
||||
:param kwargs: 可选参数,如 ignore_category_check 以及 QB相关参数
|
||||
:return: bool
|
||||
:return: 添加是否成功, 新版API返回的种子ID列表
|
||||
"""
|
||||
if not self.qbc or not content:
|
||||
return False
|
||||
return False, []
|
||||
|
||||
# 下载内容
|
||||
if isinstance(content, str):
|
||||
@@ -287,6 +407,11 @@ class Qbittorrent:
|
||||
is_auto = False
|
||||
category = None
|
||||
try:
|
||||
cookie_to_use = cookie
|
||||
if urls and cookie and not StringUtils.is_magnet_link(urls):
|
||||
if self.__sync_download_cookies(url=urls, cookie_header=cookie):
|
||||
cookie_to_use = None
|
||||
|
||||
# 添加下载
|
||||
qbc_ret = self.qbc.torrents_add(urls=urls,
|
||||
torrent_files=torrent_files,
|
||||
@@ -296,13 +421,13 @@ class Qbittorrent:
|
||||
use_auto_torrent_management=is_auto,
|
||||
is_sequential_download=self._sequentail,
|
||||
is_first_last_piece_priority=self._first_last_piece,
|
||||
cookie=cookie,
|
||||
cookie=cookie_to_use,
|
||||
category=category,
|
||||
**kwargs)
|
||||
return True if qbc_ret and str(qbc_ret).find("Ok") != -1 else False
|
||||
return self.__parse_add_torrent_response(qbc_ret)
|
||||
except Exception as err:
|
||||
logger.error(f"添加种子出错:{str(err)}")
|
||||
return False
|
||||
return False, []
|
||||
|
||||
def start_torrents(self, ids: Union[str, list]) -> bool:
|
||||
"""
|
||||
|
||||
@@ -219,6 +219,25 @@ function graceful_exit() {
|
||||
# 使用env配置
|
||||
load_config_from_app_env
|
||||
|
||||
# 一次性升级标记仅影响本次启动,避免把临时升级模式带入运行中的 Python 进程
|
||||
ONE_SHOT_UPDATE_FLAG="${CONFIG_DIR}/temp/moviepilot.pending_update"
|
||||
ONE_SHOT_UPDATE_APPLIED="false"
|
||||
MOVIEPILOT_AUTO_UPDATE_ORIGINAL="${MOVIEPILOT_AUTO_UPDATE}"
|
||||
if [ -f "${ONE_SHOT_UPDATE_FLAG}" ]; then
|
||||
ONE_SHOT_UPDATE_MODE="$(tr -d '\r\n' < "${ONE_SHOT_UPDATE_FLAG}" | tr '[:upper:]' '[:lower:]')"
|
||||
rm -f "${ONE_SHOT_UPDATE_FLAG}"
|
||||
if [ "${ONE_SHOT_UPDATE_MODE}" = "true" ]; then
|
||||
ONE_SHOT_UPDATE_MODE="release"
|
||||
fi
|
||||
if [ "${ONE_SHOT_UPDATE_MODE}" = "release" ] || [ "${ONE_SHOT_UPDATE_MODE}" = "dev" ]; then
|
||||
INFO "检测到一次性升级标记,本次启动将执行 ${ONE_SHOT_UPDATE_MODE} 升级..."
|
||||
export MOVIEPILOT_AUTO_UPDATE="${ONE_SHOT_UPDATE_MODE}"
|
||||
ONE_SHOT_UPDATE_APPLIED="true"
|
||||
elif [ -n "${ONE_SHOT_UPDATE_MODE}" ]; then
|
||||
WARN "检测到无效的一次性升级模式:${ONE_SHOT_UPDATE_MODE},已忽略"
|
||||
fi
|
||||
fi
|
||||
|
||||
# 生成HTTPS配置块
|
||||
if [ "${ENABLE_SSL}" = "true" ]; then
|
||||
export HTTPS_SERVER_CONF=$(cat <<EOF
|
||||
@@ -256,6 +275,9 @@ envsubst '${NGINX_PORT}${PORT}${NGINX_CLIENT_MAX_BODY_SIZE}${ENABLE_SSL}${HTTPS_
|
||||
# 自动更新
|
||||
cd /
|
||||
source /usr/local/bin/mp_update.sh
|
||||
if [ "${ONE_SHOT_UPDATE_APPLIED}" = "true" ]; then
|
||||
export MOVIEPILOT_AUTO_UPDATE="${MOVIEPILOT_AUTO_UPDATE_ORIGINAL}"
|
||||
fi
|
||||
cd /app || exit
|
||||
|
||||
# 更改 moviepilot userid 和 groupid
|
||||
|
||||
@@ -24,7 +24,7 @@ DB_TYPE=postgresql
|
||||
# PostgreSQL 主机地址
|
||||
DB_POSTGRESQL_HOST=localhost
|
||||
|
||||
# PostgreSQL 端口
|
||||
# PostgreSQL 端口;使用 Unix Socket 时可留空
|
||||
DB_POSTGRESQL_PORT=5432
|
||||
|
||||
# PostgreSQL 数据库名
|
||||
@@ -43,6 +43,21 @@ DB_POSTGRESQL_POOL_SIZE=20
|
||||
DB_POSTGRESQL_MAX_OVERFLOW=30
|
||||
```
|
||||
|
||||
### 3. Unix Socket 连接
|
||||
|
||||
如果 PostgreSQL 通过 Unix Socket 暴露,可以把 `DB_POSTGRESQL_HOST` 设置为套接字目录。
|
||||
|
||||
```bash
|
||||
DB_TYPE=postgresql
|
||||
DB_POSTGRESQL_HOST=/var/run/postgresql
|
||||
DB_POSTGRESQL_PORT=
|
||||
DB_POSTGRESQL_DATABASE=moviepilot
|
||||
DB_POSTGRESQL_USERNAME=moviepilot
|
||||
DB_POSTGRESQL_PASSWORD=moviepilot
|
||||
```
|
||||
|
||||
如需显式指定 socket 端口,也可以保留 `DB_POSTGRESQL_PORT`,程序会生成带 `host=/path/to/socket` 查询参数的 PostgreSQL URL。
|
||||
|
||||
## Docker 部署
|
||||
|
||||
### 使用外部 PostgreSQL
|
||||
@@ -60,6 +75,13 @@ DB_POSTGRESQL_USERNAME=your-username
|
||||
DB_POSTGRESQL_PASSWORD=your-password
|
||||
```
|
||||
|
||||
使用 Redis Unix Socket 时,可直接设置 `CACHE_BACKEND_URL`,例如:
|
||||
|
||||
```bash
|
||||
CACHE_BACKEND_TYPE=redis
|
||||
CACHE_BACKEND_URL=unix:///var/run/redis/redis.sock?db=0
|
||||
```
|
||||
|
||||
## 数据迁移
|
||||
|
||||
### 从 SQLite 迁移到 PostgreSQL
|
||||
|
||||
@@ -28,7 +28,7 @@ APScheduler~=3.11.0
|
||||
cryptography~=45.0.4
|
||||
pytz~=2025.2
|
||||
pycryptodome~=3.23.0
|
||||
qbittorrent-api==2025.5.0
|
||||
qbittorrent-api==2026.5.1
|
||||
plexapi~=4.17.0
|
||||
transmission-rpc~=4.3.0
|
||||
Jinja2~=3.1.6
|
||||
@@ -89,3 +89,4 @@ openai~=2.33.0
|
||||
google-genai~=1.74.0
|
||||
ddgs~=9.10.0
|
||||
websocket-client~=1.8.0
|
||||
pytest~=8.4.0
|
||||
|
||||
@@ -1312,8 +1312,9 @@ def _collect_downloader_config() -> Optional[dict[str, Any]]:
|
||||
config_name = _prompt_text("下载器名称", default=downloader_type)
|
||||
if downloader_type == "qbittorrent":
|
||||
host = _prompt_text("qBittorrent 地址", default="http://127.0.0.1:8080")
|
||||
username = _prompt_text("qBittorrent 用户名", default="admin")
|
||||
password = _prompt_text("qBittorrent 密码", secret=True)
|
||||
apikey = _prompt_text("qBittorrent API Key(可选,5.2+ 推荐)", allow_empty=True, default="")
|
||||
username = _prompt_text("qBittorrent 用户名", default="admin") if not apikey else ""
|
||||
password = _prompt_text("qBittorrent 密码", secret=True, allow_empty=bool(apikey)) if not apikey else ""
|
||||
category = _prompt_yes_no("是否启用 qBittorrent 分类", default=False)
|
||||
return {
|
||||
"name": config_name,
|
||||
@@ -1322,6 +1323,7 @@ def _collect_downloader_config() -> Optional[dict[str, Any]]:
|
||||
"enabled": True,
|
||||
"config": {
|
||||
"host": host,
|
||||
"apikey": apikey,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"category": category,
|
||||
|
||||
@@ -310,7 +310,7 @@ All endpoints are under the base URL `{MP_HOST}`. Path parameters are shown as `
|
||||
| POST | `/api/v1/workflow/fork` | Fork shared workflow. Body: WorkflowShare JSON |
|
||||
| GET | `/api/v1/workflow/shares` | List shared workflows. Params: `name`, `page`, `count` |
|
||||
|
||||
### System (20 endpoints)
|
||||
### System (21 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
@@ -321,6 +321,7 @@ All endpoints are under the base URL `{MP_HOST}`. Path parameters are shown as `
|
||||
| GET | `/api/v1/system/global` | Non-sensitive settings. Params: `token` (required) |
|
||||
| GET | `/api/v1/system/global/user` | User-related settings |
|
||||
| GET | `/api/v1/system/restart` | Restart system |
|
||||
| POST | `/api/v1/system/upgrade` | Upgrade and restart system. Body: `"release"` or `"dev"` |
|
||||
| GET | `/api/v1/system/runscheduler` | Run scheduled service. Params: `jobid` (required) |
|
||||
| GET | `/api/v1/system/runscheduler2` | Run scheduler (API_TOKEN, use `--token-param`). Params: `jobid` |
|
||||
| GET | `/api/v1/system/modulelist` | List loaded modules |
|
||||
|
||||
@@ -1,144 +1,74 @@
|
||||
---
|
||||
name: moviepilot-update
|
||||
version: 1
|
||||
description: Use this skill when you need to restart or upgrade MoviePilot. This skill covers system restart, version check, and manual upgrade procedures.
|
||||
version: 2
|
||||
description: Use this skill when you need to check MoviePilot versions, restart MoviePilot, or trigger a MoviePilot upgrade. Prefer the built-in system APIs instead of docker commands or manual file replacement. If auto-update on restart is already enabled, just restart. If it is disabled, call the upgrade API so MoviePilot performs a one-shot upgrade and restart.
|
||||
---
|
||||
|
||||
# MoviePilot System Update & Restart
|
||||
# MoviePilot Update
|
||||
|
||||
> All script paths are relative to this skill file.
|
||||
|
||||
This skill provides capabilities to restart MoviePilot service, check for updates, and perform manual upgrades.
|
||||
Use this skill for MoviePilot restart and upgrade operations.
|
||||
|
||||
## Restart MoviePilot
|
||||
## Setup
|
||||
|
||||
### Method 1: Using REST API (Recommended)
|
||||
This skill reuses the `moviepilot-api` client configuration.
|
||||
|
||||
Call the restart endpoint with admin authentication:
|
||||
Configure host and API key once:
|
||||
|
||||
```bash
|
||||
# Using moviepilot-api skill
|
||||
python scripts/mp-api.py GET /api/v1/system/restart
|
||||
python ../moviepilot-api/scripts/mp-api.py configure --host http://localhost:3000 --apikey <API_TOKEN>
|
||||
```
|
||||
|
||||
Or with curl:
|
||||
```bash
|
||||
curl -X GET "http://localhost:3000/api/v1/system/restart" \
|
||||
-H "X-API-KEY: <YOUR_API_TOKEN>"
|
||||
```
|
||||
## Preferred Commands
|
||||
|
||||
**Note:** This API will restart the Docker container internally. The service will be briefly unavailable during restart.
|
||||
|
||||
### Method 2: Using execute_command tool
|
||||
|
||||
If you have admin privileges, you can execute the restart command directly:
|
||||
### Check versions
|
||||
|
||||
```bash
|
||||
docker restart moviepilot
|
||||
python scripts/mp-update.py versions
|
||||
```
|
||||
|
||||
## Check for Updates
|
||||
This calls `GET /api/v1/system/versions`.
|
||||
|
||||
### Method 1: Using REST API
|
||||
### Restart MoviePilot
|
||||
|
||||
```bash
|
||||
python scripts/mp-api.py GET /api/v1/system/versions
|
||||
python scripts/mp-update.py restart
|
||||
```
|
||||
|
||||
This returns all available GitHub releases.
|
||||
This calls `GET /api/v1/system/restart`.
|
||||
|
||||
### Method 2: Check current version
|
||||
### Upgrade and restart MoviePilot
|
||||
|
||||
Release mode:
|
||||
|
||||
```bash
|
||||
# Check current version
|
||||
cat /app/version.py
|
||||
python scripts/mp-update.py upgrade
|
||||
```
|
||||
|
||||
## Upgrade MoviePilot
|
||||
Dev mode:
|
||||
|
||||
### Option 1: Automatic Update (Recommended)
|
||||
```bash
|
||||
python scripts/mp-update.py upgrade dev
|
||||
```
|
||||
|
||||
Set the environment variable `MOVIEPILOT_AUTO_UPDATE` and restart:
|
||||
This calls `POST /api/v1/system/upgrade`.
|
||||
|
||||
1. **For Docker Compose users:**
|
||||
```bash
|
||||
# Edit docker-compose.yml, add environment variable:
|
||||
environment:
|
||||
- MOVIEPILOT_AUTO_UPDATE=release # or "dev" for dev版本
|
||||
|
||||
# Then restart
|
||||
docker-compose down && docker-compose up -d
|
||||
```
|
||||
Behavior:
|
||||
|
||||
2. **For Docker run users:**
|
||||
```bash
|
||||
docker stop moviepilot
|
||||
docker rm moviepilot
|
||||
docker run -d ... -e MOVIEPILOT_AUTO_UPDATE=release jxxghp/moviepilot
|
||||
```
|
||||
- If `MOVIEPILOT_AUTO_UPDATE` is already enabled (`release` or `dev`), MoviePilot only triggers a restart and lets the normal startup flow perform the upgrade.
|
||||
- If `MOVIEPILOT_AUTO_UPDATE` is disabled, MoviePilot writes a one-shot upgrade flag, restarts itself, performs that single upgrade during startup, and then continues running without changing the persisted auto-update setting.
|
||||
|
||||
The update script (`/usr/local/bin/mp_update.sh` or `/app/docker/update.sh`) will automatically:
|
||||
- Check GitHub for latest release
|
||||
- Download new backend code
|
||||
- Update dependencies if changed
|
||||
- Download new frontend
|
||||
- Update site resources
|
||||
- Restart the service
|
||||
## Direct API Examples
|
||||
|
||||
### Option 2: Manual Upgrade
|
||||
```bash
|
||||
python ../moviepilot-api/scripts/mp-api.py GET /api/v1/system/restart
|
||||
python ../moviepilot-api/scripts/mp-api.py POST /api/v1/system/upgrade --json '"release"'
|
||||
python ../moviepilot-api/scripts/mp-api.py POST /api/v1/system/upgrade --json '"dev"'
|
||||
```
|
||||
|
||||
If you need to manually download and apply updates:
|
||||
## Notes
|
||||
|
||||
1. **Get latest release version:**
|
||||
```bash
|
||||
curl -s https://api.github.com/repos/jxxghp/MoviePilot/releases | grep '"tag_name"' | grep "v2" | head -1
|
||||
```
|
||||
|
||||
2. **Download and extract backend:**
|
||||
```bash
|
||||
# Replace v2.x.x with actual version
|
||||
curl -L -o /tmp/backend.zip https://github.com/jxxghp/MoviePilot/archive/refs/tags/v2.x.x.zip
|
||||
unzip -d /tmp/backend /tmp/backend.zip
|
||||
```
|
||||
|
||||
3. **Backup and replace:**
|
||||
```bash
|
||||
# Backup current installation
|
||||
cp -r /app /app_backup
|
||||
|
||||
# Replace files (exclude config and plugins)
|
||||
cp -r /tmp/backend/MoviePilot-*/* /app/
|
||||
```
|
||||
|
||||
4. **Restart MoviePilot:**
|
||||
```bash
|
||||
# Use API or docker restart
|
||||
python scripts/mp-api.py GET /api/v1/system/restart
|
||||
```
|
||||
|
||||
### Important Notes
|
||||
|
||||
- **Backup first:** Before upgrading, backup your configuration and database
|
||||
- **Dependencies:** Check if requirements.in has changes; if so, update virtual environment
|
||||
- **Plugins:** The update script automatically backs up and restores plugins
|
||||
- **Non-Docker:** For non-Docker installations, use `git pull` or `pip install -U moviepilot`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Restart fails | Check if Docker daemon is running; verify container has restart policy |
|
||||
| Update fails | Check network connectivity to GitHub; ensure sufficient disk space |
|
||||
| Version unchanged | Verify `MOVIEPILOT_AUTO_UPDATE` environment variable is set correctly |
|
||||
| Dependency errors | May need to rebuild virtual environment: `pip-compile requirements.in && pip install -r requirements.txt` |
|
||||
|
||||
## Environment Variables for Auto-Update
|
||||
|
||||
| Variable | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `MOVIEPILOT_AUTO_UPDATE` | `release` | Auto-update to latest stable release |
|
||||
| `MOVIEPILOT_AUTO_UPDATE` | `dev` | Auto-update to latest dev version |
|
||||
| `MOVIEPILOT_AUTO_UPDATE` | `false` | Disable auto-update (default) |
|
||||
| `GITHUB_TOKEN` | (token) | GitHub token for higher rate limits |
|
||||
| `GITHUB_PROXY` | (url) | GitHub proxy URL for China users |
|
||||
| `PROXY_HOST` | (url) | Global proxy host |
|
||||
- These operations require administrator authentication.
|
||||
- Restart or upgrade will interrupt the current agent session. Do not rely on post-restart follow-up steps in the same run.
|
||||
- Prefer the API flow above. Only fall back to manual container commands when the API is unavailable.
|
||||
|
||||
62
skills/moviepilot-update/scripts/mp-update.py
Normal file
62
skills/moviepilot-update/scripts/mp-update.py
Normal file
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
API_SCRIPT = SCRIPT_DIR.parents[1] / "moviepilot-api" / "scripts" / "mp-api.py"
|
||||
|
||||
|
||||
def run_api_call(args: list[str]) -> int:
|
||||
command = [sys.executable, str(API_SCRIPT), *args]
|
||||
return_code = __import__("subprocess").run(command, check=False).returncode
|
||||
return return_code
|
||||
|
||||
|
||||
def print_usage() -> None:
|
||||
print(
|
||||
"Usage:\n"
|
||||
f" python {Path(sys.argv[0]).name} versions\n"
|
||||
f" python {Path(sys.argv[0]).name} restart\n"
|
||||
f" python {Path(sys.argv[0]).name} upgrade [release|dev]"
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
argv = sys.argv[1:]
|
||||
if not argv or argv[0] in {"-h", "--help", "help"}:
|
||||
print_usage()
|
||||
return 0
|
||||
|
||||
command = argv[0].lower()
|
||||
if command == "versions":
|
||||
return run_api_call(["GET", "/api/v1/system/versions"])
|
||||
|
||||
if command == "restart":
|
||||
return run_api_call(["GET", "/api/v1/system/restart"])
|
||||
|
||||
if command == "upgrade":
|
||||
mode = (argv[1] if len(argv) > 1 else "release").strip().lower()
|
||||
if mode == "true":
|
||||
mode = "release"
|
||||
if mode not in {"release", "dev"}:
|
||||
print("Error: mode must be release or dev", file=sys.stderr)
|
||||
return 1
|
||||
return run_api_call([
|
||||
"POST",
|
||||
"/api/v1/system/upgrade",
|
||||
"--json",
|
||||
json.dumps(mode, ensure_ascii=False),
|
||||
])
|
||||
|
||||
print(f"Error: unknown command: {command}", file=sys.stderr)
|
||||
print_usage()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -3,15 +3,24 @@ import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentAddSubscribeTool(unittest.TestCase):
|
||||
def test_tv_subscription_without_season_reports_default_first_season(self):
|
||||
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="tg_display_name",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
|
||||
new=AsyncMock(return_value=(1, "")),
|
||||
) as async_add, patch(
|
||||
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
|
||||
return_value="moviepilot-user",
|
||||
):
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
@@ -21,9 +30,36 @@ class TestAgentAddSubscribeTool(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user")
|
||||
self.assertIn("第1季", result)
|
||||
self.assertIn("默认按第一季订阅", result)
|
||||
|
||||
def test_subscription_falls_back_to_channel_username_when_no_binding_exists(self):
|
||||
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="tg_display_name",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
|
||||
new=AsyncMock(return_value=(1, "")),
|
||||
) as async_add, patch(
|
||||
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
|
||||
return_value=None,
|
||||
):
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
title="The Matrix",
|
||||
year="1999",
|
||||
media_type="movie",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(async_add.await_args.kwargs["username"], "tg_display_name")
|
||||
self.assertIn("成功添加订阅:The Matrix (1999)", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
205
tests/test_alist_storage.py
Normal file
205
tests/test_alist_storage.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _load_alist_module():
|
||||
module_name = "_test_alist_module"
|
||||
app_module = types.ModuleType("app")
|
||||
schemas_module = types.ModuleType("app.schemas")
|
||||
cache_module = types.ModuleType("app.core.cache")
|
||||
config_module = types.ModuleType("app.core.config")
|
||||
log_module = types.ModuleType("app.log")
|
||||
storages_module = types.ModuleType("app.modules.filemanager.storages")
|
||||
exception_module = types.ModuleType("app.schemas.exception")
|
||||
types_module = types.ModuleType("app.schemas.types")
|
||||
http_module = types.ModuleType("app.utils.http")
|
||||
singleton_module = types.ModuleType("app.utils.singleton")
|
||||
url_module = types.ModuleType("app.utils.url")
|
||||
|
||||
class _FileItem:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
class _StorageSchemaValue:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
class _Logger:
|
||||
def debug(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def warn(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def warning(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def error(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def critical(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def info(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
class _StorageBase:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_conf(self):
|
||||
return {}
|
||||
|
||||
class _OperationInterrupted(Exception):
|
||||
pass
|
||||
|
||||
class _RequestUtils:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
class _UrlUtils:
|
||||
@staticmethod
|
||||
def standardize_base_url(url):
|
||||
return url.rstrip("/") if url else ""
|
||||
|
||||
@staticmethod
|
||||
def adapt_request_url(base, path):
|
||||
return f"{base() if callable(base) else base}{path}"
|
||||
|
||||
@staticmethod
|
||||
def quote(path):
|
||||
return path
|
||||
|
||||
def _cached(*_args, **_kwargs):
|
||||
def decorator(func):
|
||||
func.cache_clear = lambda: None
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
schemas_module.FileItem = _FileItem
|
||||
schemas_module.StorageUsage = object
|
||||
cache_module.cached = _cached
|
||||
config_module.settings = types.SimpleNamespace(
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME=True,
|
||||
TEMP_PATH=Path("/tmp"),
|
||||
)
|
||||
config_module.global_vars = types.SimpleNamespace(
|
||||
is_transfer_stopped=lambda *_args, **_kwargs: False
|
||||
)
|
||||
log_module.logger = _Logger()
|
||||
storages_module.StorageBase = _StorageBase
|
||||
storages_module.transfer_process = lambda *_args, **_kwargs: (lambda *_a, **_k: None)
|
||||
exception_module.OperationInterrupted = _OperationInterrupted
|
||||
types_module.StorageSchema = types.SimpleNamespace(Alist=_StorageSchemaValue("alist"))
|
||||
http_module.RequestUtils = _RequestUtils
|
||||
singleton_module.WeakSingleton = type
|
||||
url_module.UrlUtils = _UrlUtils
|
||||
|
||||
app_module.schemas = schemas_module
|
||||
|
||||
stub_modules = {
|
||||
"app": app_module,
|
||||
"app.schemas": schemas_module,
|
||||
"app.core.cache": cache_module,
|
||||
"app.core.config": config_module,
|
||||
"app.log": log_module,
|
||||
"app.modules.filemanager.storages": storages_module,
|
||||
"app.schemas.exception": exception_module,
|
||||
"app.schemas.types": types_module,
|
||||
"app.utils.http": http_module,
|
||||
"app.utils.singleton": singleton_module,
|
||||
"app.utils.url": url_module,
|
||||
}
|
||||
for stub_module in stub_modules.values():
|
||||
stub_module._alist_test_stub = True
|
||||
|
||||
alist_path = Path(__file__).resolve().parents[1] / "app" / "modules" / "filemanager" / "storages" / "alist.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, alist_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
with patch.dict(sys.modules, stub_modules):
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
alist_module = _load_alist_module()
|
||||
Alist = alist_module.Alist
|
||||
FileItem = alist_module.schemas.FileItem
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, payload: dict, status_code: int = 200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class AlistStorageTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.storage = Alist()
|
||||
|
||||
@staticmethod
|
||||
def _dir_item(path: str = "/"):
|
||||
return FileItem(storage="alist", type="dir", path=path)
|
||||
|
||||
@staticmethod
|
||||
def _page_payload(start: int, count: int, total: int) -> dict:
|
||||
return {
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"content": [
|
||||
{
|
||||
"name": f"dir-{index}",
|
||||
"size": 0,
|
||||
"is_dir": True,
|
||||
"modified": "2024-05-17T13:47:55.4174917+08:00",
|
||||
"thumb": "",
|
||||
}
|
||||
for index in range(start, start + count)
|
||||
],
|
||||
"total": total,
|
||||
},
|
||||
}
|
||||
|
||||
def test_list_fetches_all_pages_when_per_page_is_default(self):
|
||||
responses = [
|
||||
_FakeResponse(self._page_payload(0, 200, 205)),
|
||||
_FakeResponse(self._page_payload(200, 5, 205)),
|
||||
]
|
||||
request_utils = MagicMock()
|
||||
request_utils.post_res.side_effect = responses
|
||||
|
||||
with patch.object(Alist, "get_conf", return_value={"url": "http://openlist.test", "token": "token"}):
|
||||
with patch.object(alist_module, "RequestUtils", return_value=request_utils):
|
||||
items = self.storage.list(self._dir_item())
|
||||
|
||||
self.assertEqual(205, len(items))
|
||||
self.assertEqual("/dir-0/", items[0].path)
|
||||
self.assertEqual("/dir-204/", items[-1].path)
|
||||
self.assertEqual(2, request_utils.post_res.call_count)
|
||||
self.assertEqual(1, request_utils.post_res.call_args_list[0].kwargs["json"]["page"])
|
||||
self.assertEqual(2, request_utils.post_res.call_args_list[1].kwargs["json"]["page"])
|
||||
|
||||
def test_list_respects_explicit_per_page_without_auto_paging(self):
|
||||
request_utils = MagicMock()
|
||||
request_utils.post_res.return_value = _FakeResponse(self._page_payload(0, 50, 205))
|
||||
|
||||
with patch.object(Alist, "get_conf", return_value={"url": "http://openlist.test", "token": "token"}):
|
||||
with patch.object(alist_module, "RequestUtils", return_value=request_utils):
|
||||
items = self.storage.list(self._dir_item(), per_page=50)
|
||||
|
||||
self.assertEqual(50, len(items))
|
||||
self.assertEqual(1, request_utils.post_res.call_count)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -8,7 +9,7 @@ import unittest
|
||||
|
||||
from app.agent.tools.impl.execute_command import (
|
||||
ExecuteCommandTool,
|
||||
MAX_OUTPUT_CHARS,
|
||||
MAX_OUTPUT_PREVIEW_BYTES,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +22,11 @@ def _python_command(code: str) -> str:
|
||||
|
||||
|
||||
class TestExecuteCommandTool(unittest.TestCase):
|
||||
def _temp_file_path_from_result(self, result: str) -> str:
|
||||
match = re.search(r"临时文件: (.+)", result)
|
||||
self.assertIsNotNone(match)
|
||||
return match.group(1).strip()
|
||||
|
||||
def _run_command(self, command: str, timeout: int = 60) -> str:
|
||||
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
return asyncio.run(tool.run(command=command, timeout=timeout))
|
||||
@@ -31,9 +37,19 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
)
|
||||
|
||||
result = self._run_command(command)
|
||||
temp_file_path = self._temp_file_path_from_result(result)
|
||||
|
||||
self.assertIn("输出内容过长,已截断", result)
|
||||
self.assertLess(len(result), MAX_OUTPUT_CHARS + 500)
|
||||
self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path))
|
||||
self.assertIn("命令输出超过 10KB", result)
|
||||
self.assertIn("仅展示前 10KB 内容", result)
|
||||
self.assertIn("如需完整内容,请继续读取该文件", result)
|
||||
self.assertLess(len(result), MAX_OUTPUT_PREVIEW_BYTES + 600)
|
||||
|
||||
with open(temp_file_path, encoding="utf-8") as file_handle:
|
||||
file_content = file_handle.read()
|
||||
|
||||
self.assertIn("[标准输出]", file_content)
|
||||
self.assertGreater(len(file_content), 100000)
|
||||
|
||||
def test_timeout_returns_partial_output_promptly(self):
|
||||
command = _python_command(
|
||||
@@ -48,6 +64,24 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
self.assertIn("命令执行超时", result)
|
||||
self.assertIn("started", result)
|
||||
|
||||
def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self):
|
||||
command = _python_command(
|
||||
"import sys, time; sys.stdout.write('x' * 20000); sys.stdout.flush(); time.sleep(5)"
|
||||
)
|
||||
|
||||
result = self._run_command(command, timeout=1)
|
||||
temp_file_path = self._temp_file_path_from_result(result)
|
||||
|
||||
self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path))
|
||||
self.assertIn("命令执行超时", result)
|
||||
self.assertIn("截至命令终止前的完整输出已写入临时文件", result)
|
||||
|
||||
with open(temp_file_path, encoding="utf-8") as file_handle:
|
||||
file_content = file_handle.read()
|
||||
|
||||
self.assertIn("[标准输出]", file_content)
|
||||
self.assertGreaterEqual(file_content.count("x"), 20000)
|
||||
|
||||
def test_timeout_is_capped(self):
|
||||
command = _python_command("print('ok')")
|
||||
|
||||
|
||||
238
tests/test_jellyfin.py
Normal file
238
tests/test_jellyfin.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import call, patch
|
||||
|
||||
|
||||
def _load_jellyfin_module():
|
||||
module_name = "_test_jellyfin_module"
|
||||
app_module = types.ModuleType("app")
|
||||
core_module = types.ModuleType("app.core")
|
||||
utils_module = types.ModuleType("app.utils")
|
||||
log_module = types.ModuleType("app.log")
|
||||
config_module = types.ModuleType("app.core.config")
|
||||
schemas_module = types.ModuleType("app.schemas")
|
||||
http_module = types.ModuleType("app.utils.http")
|
||||
url_module = types.ModuleType("app.utils.url")
|
||||
|
||||
class _Logger:
|
||||
def info(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def warning(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def error(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def debug(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
class _RequestUtils:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get_res(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
class _UrlUtils:
|
||||
@staticmethod
|
||||
def standardize_base_url(host):
|
||||
if not host:
|
||||
return host
|
||||
if not host.endswith("/"):
|
||||
host += "/"
|
||||
if not host.startswith("http://") and not host.startswith("https://"):
|
||||
host = "http://" + host
|
||||
return host
|
||||
|
||||
@staticmethod
|
||||
def combine_url(host, path=None, query=None):
|
||||
from urllib.parse import urljoin
|
||||
|
||||
if path is None:
|
||||
path = "/"
|
||||
host = _UrlUtils.standardize_base_url(host)
|
||||
return urljoin(host, path)
|
||||
|
||||
log_module.logger = _Logger()
|
||||
config_module.settings = types.SimpleNamespace(SUPERUSER="admin", USER_AGENT="MoviePilot")
|
||||
schemas_module.MediaType = types.SimpleNamespace(MOVIE=types.SimpleNamespace(value="movie"))
|
||||
schemas_module.MediaServerItem = object
|
||||
schemas_module.MediaServerLibrary = object
|
||||
schemas_module.Statistic = object
|
||||
schemas_module.WebhookEventInfo = object
|
||||
schemas_module.MediaServerItemUserState = object
|
||||
schemas_module.MediaServerPlayItem = object
|
||||
http_module.RequestUtils = _RequestUtils
|
||||
url_module.UrlUtils = _UrlUtils
|
||||
|
||||
app_module.schemas = schemas_module
|
||||
app_module.log = log_module
|
||||
app_module.core = core_module
|
||||
app_module.utils = utils_module
|
||||
core_module.config = config_module
|
||||
utils_module.http = http_module
|
||||
utils_module.url = url_module
|
||||
|
||||
stub_modules = {
|
||||
"app": app_module,
|
||||
"app.log": log_module,
|
||||
"app.core": core_module,
|
||||
"app.core.config": config_module,
|
||||
"app.schemas": schemas_module,
|
||||
"app.utils": utils_module,
|
||||
"app.utils.http": http_module,
|
||||
"app.utils.url": url_module,
|
||||
}
|
||||
for stub_module in stub_modules.values():
|
||||
stub_module._jellyfin_test_stub = True
|
||||
|
||||
jellyfin_path = Path(__file__).resolve().parents[1] / "app" / "modules" / "jellyfin" / "jellyfin.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, jellyfin_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
with patch.dict(sys.modules, stub_modules):
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
jellyfin_module = _load_jellyfin_module()
|
||||
Jellyfin = jellyfin_module.Jellyfin
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, payload: dict):
|
||||
self._payload = payload
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class JellyfinUserResolutionTest(unittest.TestCase):
|
||||
def test_loader_does_not_leave_stub_modules_in_sys_modules(self):
|
||||
self.assertNotIn("_test_jellyfin_module", sys.modules)
|
||||
self.assertFalse(getattr(sys.modules.get("app.log"), "_jellyfin_test_stub", False))
|
||||
self.assertFalse(getattr(sys.modules.get("app.core.config"), "_jellyfin_test_stub", False))
|
||||
self.assertFalse(getattr(sys.modules.get("app.utils.http"), "_jellyfin_test_stub", False))
|
||||
|
||||
def _build_client(self) -> Jellyfin:
|
||||
client = Jellyfin.__new__(Jellyfin)
|
||||
client._host = "http://jellyfin.local:8096"
|
||||
client._apikey = "api-key"
|
||||
client._playhost = None
|
||||
client._sync_libraries = []
|
||||
client.user = "fallback-user"
|
||||
return client
|
||||
|
||||
def test_get_user_prefers_exact_username_without_warning(self):
|
||||
client = self._build_client()
|
||||
payload = [
|
||||
{"Id": "admin-id", "Name": "admin", "Policy": {"IsAdministrator": True}},
|
||||
{"Id": "alice-id", "Name": "alice", "Policy": {"IsAdministrator": False}},
|
||||
]
|
||||
|
||||
with patch.object(jellyfin_module, "RequestUtils") as request_utils_cls, patch.object(
|
||||
jellyfin_module.logger, "warning"
|
||||
) as warning_mock:
|
||||
request_utils_cls.return_value.get_res.return_value = _FakeResponse(payload)
|
||||
|
||||
user_id = client.get_user("alice")
|
||||
|
||||
self.assertEqual(user_id, "alice-id")
|
||||
warning_mock.assert_not_called()
|
||||
|
||||
def test_get_user_prefers_enable_all_folders_admin(self):
|
||||
client = self._build_client()
|
||||
payload = [
|
||||
{
|
||||
"Id": "visible-admin-id",
|
||||
"Name": "visible",
|
||||
"Policy": {"IsAdministrator": True, "EnabledFolders": ["lib-1", "lib-2", "lib-3"]},
|
||||
},
|
||||
{
|
||||
"Id": "full-admin-id",
|
||||
"Name": "full",
|
||||
"Policy": {"IsAdministrator": True, "EnableAllFolders": True},
|
||||
},
|
||||
]
|
||||
|
||||
with patch.object(jellyfin_module, "RequestUtils") as request_utils_cls:
|
||||
request_utils_cls.return_value.get_res.return_value = _FakeResponse(payload)
|
||||
|
||||
user_id = client.get_user()
|
||||
|
||||
self.assertEqual(user_id, "full-admin-id")
|
||||
|
||||
def test_get_user_warns_and_prefers_larger_visible_scope_admin(self):
|
||||
client = self._build_client()
|
||||
payload = [
|
||||
{
|
||||
"Id": "small-admin-id",
|
||||
"Name": "small",
|
||||
"Policy": {"IsAdministrator": True, "EnabledFolders": ["lib-1"]},
|
||||
},
|
||||
{
|
||||
"Id": "large-admin-id",
|
||||
"Name": "large",
|
||||
"Policy": {"IsAdministrator": True, "EnabledFolders": ["lib-1", "lib-2", "lib-3"]},
|
||||
},
|
||||
{"Id": "user-id", "Name": "normal", "Policy": {"IsAdministrator": False}},
|
||||
]
|
||||
|
||||
with patch.object(jellyfin_module, "RequestUtils") as request_utils_cls, patch.object(
|
||||
jellyfin_module.logger, "warning"
|
||||
) as warning_mock:
|
||||
request_utils_cls.return_value.get_res.return_value = _FakeResponse(payload)
|
||||
|
||||
user_id = client.get_user("admin")
|
||||
|
||||
self.assertEqual(user_id, "large-admin-id")
|
||||
self.assertGreaterEqual(warning_mock.call_count, 2)
|
||||
|
||||
warning_messages = [
|
||||
call.args[0] for call in warning_mock.call_args_list if call.args and isinstance(call.args[0], str)
|
||||
]
|
||||
self.assertTrue(any("超级管理员" in message for message in warning_messages))
|
||||
self.assertTrue(
|
||||
any(
|
||||
("部分" in message)
|
||||
or ("可见" in message)
|
||||
or ("访问范围" in message)
|
||||
or ("EnabledFolders" in message)
|
||||
for message in warning_messages
|
||||
)
|
||||
)
|
||||
self.assertTrue(any(("回退" in message) or ("fallback" in message.lower()) for message in warning_messages))
|
||||
|
||||
def test_get_jellyfin_librarys_returns_empty_when_user_missing(self):
|
||||
client = self._build_client()
|
||||
client.user = None
|
||||
|
||||
with patch.object(jellyfin_module, "RequestUtils") as request_utils_cls:
|
||||
libraries = client._Jellyfin__get_jellyfin_librarys()
|
||||
|
||||
self.assertEqual(libraries, [])
|
||||
request_utils_cls.assert_not_called()
|
||||
|
||||
def test_get_jellyfin_librarys_uses_normalized_views_url(self):
|
||||
client = self._build_client()
|
||||
client._host = "http://jellyfin.local:8096"
|
||||
client.user = "user-id"
|
||||
|
||||
with patch.object(jellyfin_module, "RequestUtils") as request_utils_cls:
|
||||
request_utils_cls.return_value.get_res.return_value = _FakeResponse({"Items": []})
|
||||
|
||||
libraries = client._Jellyfin__get_jellyfin_librarys()
|
||||
|
||||
self.assertEqual(libraries, [])
|
||||
request_utils_cls.return_value.get_res.assert_called_once_with(
|
||||
"http://jellyfin.local:8096/Users/user-id/Views",
|
||||
{"api_key": "api-key"},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -64,6 +64,28 @@ class TestMediaScrapingPaths(unittest.TestCase):
|
||||
self.assertEqual(target_item, fileitem)
|
||||
self.assertEqual(target_path, Path("/tv/Show/Season 1/season.nfo"))
|
||||
|
||||
def test_season_dir_poster_path(self):
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
|
||||
target_item, target_path = self.media_chain._get_target_fileitem_and_path(
|
||||
current_fileitem=fileitem,
|
||||
item_type=ScrapingTarget.SEASON,
|
||||
metadata_type=ScrapingMetadata.POSTER,
|
||||
filename_hint="season01-poster.jpg"
|
||||
)
|
||||
self.assertEqual(target_item, fileitem)
|
||||
self.assertEqual(target_path, Path("/tv/Show/Season 1/poster.jpg"))
|
||||
|
||||
def test_season_dir_specials_poster_path(self):
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Specials", name="Specials", type="dir", storage="local")
|
||||
target_item, target_path = self.media_chain._get_target_fileitem_and_path(
|
||||
current_fileitem=fileitem,
|
||||
item_type=ScrapingTarget.SEASON,
|
||||
metadata_type=ScrapingMetadata.POSTER,
|
||||
filename_hint="season-specials-poster.jpg"
|
||||
)
|
||||
self.assertEqual(target_item, fileitem)
|
||||
self.assertEqual(target_path, Path("/tv/Show/Specials/poster.jpg"))
|
||||
|
||||
def test_episode_file_nfo_path(self):
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Season 1/S01E01.mp4", name="S01E01.mp4", type="file", storage="local")
|
||||
parent_item = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
|
||||
@@ -171,6 +193,7 @@ class TestMediaScrapingImages(unittest.TestCase):
|
||||
calls = self.media_chain._download_and_save_image.call_args_list
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0].kwargs["url"], "http://season01")
|
||||
self.assertEqual(calls[0].kwargs["path"], Path("/tv/Show/Season 1/poster.jpg"))
|
||||
|
||||
def test_scrape_episode_thumb_image_path(self):
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Season 1/S01E01.mp4", name="S01E01.mp4", type="file", storage="local")
|
||||
|
||||
101
tests/test_postgresql_socket_config.py
Normal file
101
tests/test_postgresql_socket_config.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import sys
|
||||
import unittest
|
||||
from enum import Enum
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=_DummyLogger(),
|
||||
log_settings=_DummyLogger(),
|
||||
LogConfigModel=type("LogConfigModel", (), {}),
|
||||
)
|
||||
_stub_module("psutil")
|
||||
_schemas_module = _stub_module(
|
||||
"app.schemas", MediaType=Enum("MediaType", {"Movie": "Movie", "TV": "TV"})
|
||||
)
|
||||
_schemas_module.__getattr__ = lambda name: type(name, (), {})
|
||||
_stub_module("version", APP_VERSION="test")
|
||||
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
|
||||
class PostgreSQLSocketConfigTests(unittest.TestCase):
|
||||
def test_postgresql_tcp_url_keeps_host_and_port(self):
|
||||
settings = Settings(
|
||||
DB_POSTGRESQL_HOST="db",
|
||||
DB_POSTGRESQL_PORT="5433",
|
||||
DB_POSTGRESQL_DATABASE="moviepilot",
|
||||
DB_POSTGRESQL_USERNAME="user",
|
||||
DB_POSTGRESQL_PASSWORD="pass",
|
||||
)
|
||||
|
||||
self.assertFalse(settings.DB_POSTGRESQL_SOCKET_MODE)
|
||||
self.assertEqual(
|
||||
settings.DB_POSTGRESQL_URL(),
|
||||
"postgresql://user:pass@db:5433/moviepilot",
|
||||
)
|
||||
self.assertEqual(
|
||||
settings.DB_POSTGRESQL_URL("asyncpg"),
|
||||
"postgresql+asyncpg://user:pass@db:5433/moviepilot",
|
||||
)
|
||||
self.assertEqual(settings.DB_POSTGRESQL_TARGET, "db:5433")
|
||||
|
||||
def test_postgresql_socket_url_uses_host_query_param(self):
|
||||
settings = Settings(
|
||||
DB_POSTGRESQL_HOST="/var/run/postgresql",
|
||||
DB_POSTGRESQL_PORT="",
|
||||
DB_POSTGRESQL_DATABASE="moviepilot",
|
||||
DB_POSTGRESQL_USERNAME="user",
|
||||
DB_POSTGRESQL_PASSWORD="pass",
|
||||
)
|
||||
|
||||
self.assertTrue(settings.DB_POSTGRESQL_SOCKET_MODE)
|
||||
self.assertIsNone(settings.DB_POSTGRESQL_PORT_VALUE)
|
||||
self.assertEqual(
|
||||
settings.DB_POSTGRESQL_URL(),
|
||||
"postgresql://user:pass@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql",
|
||||
)
|
||||
self.assertEqual(
|
||||
settings.DB_POSTGRESQL_URL("asyncpg"),
|
||||
"postgresql+asyncpg://user:pass@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql",
|
||||
)
|
||||
self.assertEqual(settings.DB_POSTGRESQL_TARGET, "socket /var/run/postgresql")
|
||||
|
||||
def test_postgresql_socket_url_can_keep_explicit_port(self):
|
||||
settings = Settings(
|
||||
DB_POSTGRESQL_HOST="/var/run/postgresql",
|
||||
DB_POSTGRESQL_PORT="5432",
|
||||
DB_POSTGRESQL_DATABASE="moviepilot",
|
||||
DB_POSTGRESQL_USERNAME="user",
|
||||
DB_POSTGRESQL_PASSWORD="",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
settings.DB_POSTGRESQL_URL(),
|
||||
"postgresql://user@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql&port=5432",
|
||||
)
|
||||
self.assertEqual(
|
||||
settings.DB_POSTGRESQL_TARGET,
|
||||
"socket /var/run/postgresql (port 5432)",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
352
tests/test_qbittorrent_compat.py
Normal file
352
tests/test_qbittorrent_compat.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _load_qbittorrent_modules():
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
|
||||
app_module = types.ModuleType("app")
|
||||
app_module.__path__ = []
|
||||
core_module = types.ModuleType("app.core")
|
||||
core_module.__path__ = []
|
||||
utils_module = types.ModuleType("app.utils")
|
||||
utils_module.__path__ = []
|
||||
modules_module = types.ModuleType("app.modules")
|
||||
modules_module.__path__ = []
|
||||
qbittorrent_package_module = types.ModuleType("app.modules.qbittorrent")
|
||||
qbittorrent_package_module.__path__ = []
|
||||
log_module = types.ModuleType("app.log")
|
||||
cache_module = types.ModuleType("app.core.cache")
|
||||
config_module = types.ModuleType("app.core.config")
|
||||
metainfo_module = types.ModuleType("app.core.metainfo")
|
||||
schemas_module = types.ModuleType("app.schemas")
|
||||
schema_types_module = types.ModuleType("app.schemas.types")
|
||||
string_module = types.ModuleType("app.utils.string")
|
||||
torrentool_module = types.ModuleType("torrentool")
|
||||
torrentool_module.__path__ = []
|
||||
torrentool_torrent_module = types.ModuleType("torrentool.torrent")
|
||||
qbittorrentapi_module = types.ModuleType("qbittorrentapi")
|
||||
qbittorrentapi_client_module = types.ModuleType("qbittorrentapi.client")
|
||||
qbittorrentapi_transfer_module = types.ModuleType("qbittorrentapi.transfer")
|
||||
|
||||
class _Logger:
|
||||
def info(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def warn(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def warning(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def error(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
class _StringUtils:
|
||||
@staticmethod
|
||||
def get_domain_address(address, prefix=False):
|
||||
return address, 8080
|
||||
|
||||
@staticmethod
|
||||
def is_magnet_link(value):
|
||||
if isinstance(value, bytes):
|
||||
return value.startswith(b"magnet:")
|
||||
return isinstance(value, str) and value.startswith("magnet:")
|
||||
|
||||
@staticmethod
|
||||
def generate_random_str(_length):
|
||||
return "tmp-tag-01"
|
||||
|
||||
@staticmethod
|
||||
def str_filesize(value):
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def str_secends(value):
|
||||
return str(value)
|
||||
|
||||
class _FileCache:
|
||||
def get(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
class _MetaInfo:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.year = None
|
||||
self.season_episode = ""
|
||||
self.episode_list = []
|
||||
|
||||
class _ModuleBase:
|
||||
pass
|
||||
|
||||
class _DownloaderBase:
|
||||
def __class_getitem__(cls, _item):
|
||||
return cls
|
||||
|
||||
class _Torrent:
|
||||
@staticmethod
|
||||
def from_string(content):
|
||||
return types.SimpleNamespace(name="test", total_size=len(content))
|
||||
|
||||
class TorrentStatus(Enum):
|
||||
TRANSFER = "transfer"
|
||||
DOWNLOADING = "downloading"
|
||||
|
||||
class ModuleType(Enum):
|
||||
Downloader = "Downloader"
|
||||
|
||||
class DownloaderType(Enum):
|
||||
Qbittorrent = "Qbittorrent"
|
||||
|
||||
log_module.logger = _Logger()
|
||||
cache_module.FileCache = _FileCache
|
||||
config_module.settings = types.SimpleNamespace(TORRENT_TAG="moviepilot-tag")
|
||||
metainfo_module.MetaInfo = _MetaInfo
|
||||
schemas_module.DownloaderInfo = object
|
||||
schemas_module.TransferTorrent = object
|
||||
schemas_module.DownloadingTorrent = object
|
||||
schema_types_module.TorrentStatus = TorrentStatus
|
||||
schema_types_module.ModuleType = ModuleType
|
||||
schema_types_module.DownloaderType = DownloaderType
|
||||
string_module.StringUtils = _StringUtils
|
||||
modules_module._ModuleBase = _ModuleBase
|
||||
modules_module._DownloaderBase = _DownloaderBase
|
||||
torrentool_torrent_module.Torrent = _Torrent
|
||||
qbittorrentapi_module.TorrentDictionary = dict
|
||||
qbittorrentapi_module.TorrentFilesList = list
|
||||
qbittorrentapi_module.LoginFailed = type("LoginFailed", (Exception,), {})
|
||||
qbittorrentapi_module.Forbidden403Error = type("Forbidden403Error", (Exception,), {})
|
||||
qbittorrentapi_module.Unauthorized401Error = type("Unauthorized401Error", (Exception,), {})
|
||||
qbittorrentapi_module.Client = object
|
||||
qbittorrentapi_client_module.Client = object
|
||||
qbittorrentapi_transfer_module.TransferInfoDictionary = dict
|
||||
|
||||
app_module.core = core_module
|
||||
app_module.log = log_module
|
||||
app_module.modules = modules_module
|
||||
app_module.schemas = schemas_module
|
||||
app_module.utils = utils_module
|
||||
core_module.cache = cache_module
|
||||
core_module.config = config_module
|
||||
core_module.metainfo = metainfo_module
|
||||
utils_module.string = string_module
|
||||
schemas_module.types = schema_types_module
|
||||
modules_module.qbittorrent = qbittorrent_package_module
|
||||
torrentool_module.torrent = torrentool_torrent_module
|
||||
|
||||
stub_modules = {
|
||||
"app": app_module,
|
||||
"app.core": core_module,
|
||||
"app.core.cache": cache_module,
|
||||
"app.core.config": config_module,
|
||||
"app.core.metainfo": metainfo_module,
|
||||
"app.log": log_module,
|
||||
"app.modules": modules_module,
|
||||
"app.modules.qbittorrent": qbittorrent_package_module,
|
||||
"app.schemas": schemas_module,
|
||||
"app.schemas.types": schema_types_module,
|
||||
"app.utils": utils_module,
|
||||
"app.utils.string": string_module,
|
||||
"qbittorrentapi": qbittorrentapi_module,
|
||||
"qbittorrentapi.client": qbittorrentapi_client_module,
|
||||
"qbittorrentapi.transfer": qbittorrentapi_transfer_module,
|
||||
"torrentool": torrentool_module,
|
||||
"torrentool.torrent": torrentool_torrent_module,
|
||||
}
|
||||
|
||||
for stub_module in stub_modules.values():
|
||||
stub_module._qbittorrent_test_stub = True
|
||||
|
||||
qbittorrent_path = repo_root / "app" / "modules" / "qbittorrent" / "qbittorrent.py"
|
||||
qbittorrent_spec = importlib.util.spec_from_file_location(
|
||||
"app.modules.qbittorrent.qbittorrent",
|
||||
qbittorrent_path,
|
||||
)
|
||||
qbittorrent_module = importlib.util.module_from_spec(qbittorrent_spec)
|
||||
assert qbittorrent_spec and qbittorrent_spec.loader
|
||||
|
||||
module_path = repo_root / "app" / "modules" / "qbittorrent" / "__init__.py"
|
||||
qbittorrent_module_spec = importlib.util.spec_from_file_location(
|
||||
"_test_qbittorrent_module",
|
||||
module_path,
|
||||
)
|
||||
module_package = importlib.util.module_from_spec(qbittorrent_module_spec)
|
||||
assert qbittorrent_module_spec and qbittorrent_module_spec.loader
|
||||
|
||||
with patch.dict(sys.modules, stub_modules):
|
||||
sys.modules[qbittorrent_spec.name] = qbittorrent_module
|
||||
qbittorrent_spec.loader.exec_module(qbittorrent_module)
|
||||
qbittorrent_package_module.qbittorrent = qbittorrent_module
|
||||
qbittorrent_module_spec.loader.exec_module(module_package)
|
||||
|
||||
return qbittorrent_module, module_package
|
||||
|
||||
|
||||
qbittorrent_module, qbittorrent_package_module = _load_qbittorrent_modules()
|
||||
Qbittorrent = qbittorrent_module.Qbittorrent
|
||||
QbittorrentModule = qbittorrent_package_module.QbittorrentModule
|
||||
|
||||
|
||||
class TestQbittorrentCompat(unittest.TestCase):
|
||||
def test_login_uses_api_key_header_without_auth_login(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.app_version.return_value = "v5.2.0"
|
||||
|
||||
with patch.object(qbittorrent_module.qbittorrentapi, "Client", return_value=fake_client) as client_cls:
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, apikey="secret-token")
|
||||
|
||||
self.assertIs(downloader.qbc, fake_client)
|
||||
fake_client.auth_log_in.assert_not_called()
|
||||
fake_client.app_version.assert_called_once_with()
|
||||
self.assertEqual(
|
||||
client_cls.call_args.kwargs["EXTRA_HEADERS"],
|
||||
{"Authorization": "Bearer secret-token"},
|
||||
)
|
||||
|
||||
def test_add_torrent_accepts_structured_success_response(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.torrents_add.return_value = {
|
||||
"success_count": 1,
|
||||
"failure_count": 0,
|
||||
"pending_count": 0,
|
||||
"added_torrent_ids": ["abc123"],
|
||||
}
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(content="https://example.com/test.torrent")
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, ["abc123"])
|
||||
|
||||
def test_add_torrent_accepts_pending_success_response_without_ids(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.torrents_add.return_value = {
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"pending_count": 1,
|
||||
"added_torrent_ids": [],
|
||||
}
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(content="https://example.com/test.torrent")
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, [])
|
||||
|
||||
def test_add_torrent_uses_cookie_api_for_qbittorrent_52(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.app_web_api_version.return_value = "2.11.3"
|
||||
fake_client.app_cookies.return_value = [
|
||||
{
|
||||
"domain": "old.example.com",
|
||||
"path": "/",
|
||||
"name": "old",
|
||||
"value": "cookie",
|
||||
}
|
||||
]
|
||||
fake_client.torrents_add.return_value = "Ok."
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(
|
||||
content="https://tracker.example.com/download?id=1",
|
||||
cookie="uid=1; passkey=abc",
|
||||
)
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, [])
|
||||
set_cookie_call = fake_client.app_set_cookies.call_args.kwargs["cookies"]
|
||||
self.assertIn(
|
||||
{
|
||||
"domain": "tracker.example.com",
|
||||
"path": "/",
|
||||
"name": "uid",
|
||||
"value": "1",
|
||||
},
|
||||
set_cookie_call,
|
||||
)
|
||||
self.assertIn(
|
||||
{
|
||||
"domain": "tracker.example.com",
|
||||
"path": "/",
|
||||
"name": "passkey",
|
||||
"value": "abc",
|
||||
},
|
||||
set_cookie_call,
|
||||
)
|
||||
self.assertIsNone(fake_client.torrents_add.call_args.kwargs["cookie"])
|
||||
|
||||
def test_add_torrent_keeps_legacy_cookie_param_for_old_webapi(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.app_web_api_version.return_value = "2.11.2"
|
||||
fake_client.torrents_add.return_value = "Ok."
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(
|
||||
content="https://tracker.example.com/download?id=1",
|
||||
cookie="uid=1",
|
||||
)
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, [])
|
||||
fake_client.app_set_cookies.assert_not_called()
|
||||
self.assertEqual(fake_client.torrents_add.call_args.kwargs["cookie"], "uid=1")
|
||||
|
||||
|
||||
class TestQbittorrentModuleCompat(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _build_module(server):
|
||||
module = QbittorrentModule.__new__(QbittorrentModule)
|
||||
module.get_instance = MagicMock(return_value=server)
|
||||
module.normalize_path = MagicMock(side_effect=lambda path, _downloader: path)
|
||||
module.get_default_config_name = MagicMock(return_value="default-qb")
|
||||
return module
|
||||
|
||||
def test_download_prefers_added_torrent_ids_before_tag_lookup(self):
|
||||
fake_server = MagicMock()
|
||||
fake_server.add_torrent.return_value = (True, ["abc123"])
|
||||
fake_server.get_content_layout.return_value = "Original"
|
||||
fake_server.is_force_resume.return_value = False
|
||||
|
||||
module = self._build_module(fake_server)
|
||||
result = module.download(
|
||||
content="magnet:?xt=urn:btih:123",
|
||||
download_dir=Path("/downloads"),
|
||||
cookie="",
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
self.assertEqual(result, ("qb", "abc123", "Original", "添加下载成功"))
|
||||
fake_server.delete_torrents_tag.assert_called_once_with("abc123", "tmp-tag-01")
|
||||
fake_server.get_torrent_id_by_tag.assert_not_called()
|
||||
self.assertEqual(
|
||||
fake_server.add_torrent.call_args.kwargs["tag"],
|
||||
["tmp-tag-01", "moviepilot-tag"],
|
||||
)
|
||||
|
||||
def test_download_falls_back_to_tag_lookup_when_added_ids_missing(self):
|
||||
fake_server = MagicMock()
|
||||
fake_server.add_torrent.return_value = (True, [])
|
||||
fake_server.get_content_layout.return_value = "Original"
|
||||
fake_server.get_torrent_id_by_tag.return_value = "def456"
|
||||
fake_server.is_force_resume.return_value = False
|
||||
|
||||
module = self._build_module(fake_server)
|
||||
result = module.download(
|
||||
content="magnet:?xt=urn:btih:456",
|
||||
download_dir=Path("/downloads"),
|
||||
cookie="",
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
self.assertEqual(result, ("qb", "def456", "Original", "添加下载成功"))
|
||||
fake_server.delete_torrents_tag.assert_not_called()
|
||||
fake_server.get_torrent_id_by_tag.assert_called_once_with(tags="tmp-tag-01")
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.chain.transfer import TransferChain
|
||||
|
||||
@@ -32,6 +33,9 @@ class FakeDownloadHistoryOper:
|
||||
|
||||
|
||||
class TransferDownloadHistoryLookupTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.chain = object.__new__(TransferChain)
|
||||
|
||||
def test_resolve_download_history_falls_back_to_parent_download_path(self):
|
||||
expected = SimpleNamespace(download_hash="hash1", downloader="qb")
|
||||
oper = FakeDownloadHistoryOper(
|
||||
@@ -39,7 +43,7 @@ class TransferDownloadHistoryLookupTest(unittest.TestCase):
|
||||
histories_by_path={"/downloads/season-pack": expected},
|
||||
)
|
||||
|
||||
history = TransferChain._resolve_download_history(
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/season-pack/Test.Show.S01E01.mkv"),
|
||||
)
|
||||
@@ -58,7 +62,7 @@ class TransferDownloadHistoryLookupTest(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
history = TransferChain._resolve_download_history(
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/season-pack/subs/Test.Show.S01E01.zh.ass"),
|
||||
)
|
||||
@@ -79,13 +83,127 @@ class TransferDownloadHistoryLookupTest(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
history = TransferChain._resolve_download_history(
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/shared/Test.Show.S01E01.mkv"),
|
||||
)
|
||||
|
||||
self.assertIsNone(history)
|
||||
|
||||
def test_resolve_download_history_stops_at_shared_download_root_path(self):
|
||||
oper = FakeDownloadHistoryOper(
|
||||
histories_by_path={
|
||||
"/downloads": SimpleNamespace(download_hash="hash1", downloader="qb")
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.transfer.DirectoryHelper.get_download_dirs",
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
download_path="/downloads",
|
||||
media_type=None,
|
||||
download_type_folder=False,
|
||||
media_category=None,
|
||||
download_category_folder=False,
|
||||
)
|
||||
],
|
||||
):
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/Ghost.Concert.mkv"),
|
||||
)
|
||||
|
||||
self.assertIsNone(history)
|
||||
|
||||
def test_resolve_download_history_stops_at_shared_download_root_savepath(self):
|
||||
expected = SimpleNamespace(download_hash="hash1", downloader="qb")
|
||||
oper = FakeDownloadHistoryOper(
|
||||
histories_by_hash={"hash1": expected},
|
||||
files_by_savepath={
|
||||
"/downloads": [
|
||||
SimpleNamespace(download_hash="hash1", filepath="Other.Show.mkv"),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.transfer.DirectoryHelper.get_download_dirs",
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
download_path="/downloads",
|
||||
media_type=None,
|
||||
download_type_folder=False,
|
||||
media_category=None,
|
||||
download_category_folder=False,
|
||||
)
|
||||
],
|
||||
):
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/Ghost.Concert.mkv"),
|
||||
)
|
||||
|
||||
self.assertIsNone(history)
|
||||
|
||||
def test_resolve_download_history_accepts_shared_root_savepath_for_exact_file(self):
|
||||
expected = SimpleNamespace(download_hash="hash1", downloader="qb")
|
||||
oper = FakeDownloadHistoryOper(
|
||||
histories_by_hash={"hash1": expected},
|
||||
files_by_savepath={
|
||||
"/downloads": [
|
||||
SimpleNamespace(download_hash="hash1", filepath="Ghost.Concert.mkv"),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.transfer.DirectoryHelper.get_download_dirs",
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
download_path="/downloads",
|
||||
media_type=None,
|
||||
download_type_folder=False,
|
||||
media_category=None,
|
||||
download_category_folder=False,
|
||||
)
|
||||
],
|
||||
):
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/Ghost.Concert.mkv"),
|
||||
)
|
||||
|
||||
self.assertIs(history, expected)
|
||||
|
||||
def test_resolve_download_history_stops_at_type_category_download_root(self):
|
||||
oper = FakeDownloadHistoryOper(
|
||||
histories_by_path={
|
||||
"/downloads/电视剧/动漫": SimpleNamespace(
|
||||
download_hash="hash1", downloader="qb"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.transfer.DirectoryHelper.get_download_dirs",
|
||||
return_value=[
|
||||
SimpleNamespace(
|
||||
download_path="/downloads",
|
||||
media_type=None,
|
||||
download_type_folder=True,
|
||||
media_category=None,
|
||||
download_category_folder=True,
|
||||
)
|
||||
],
|
||||
):
|
||||
history = self.chain._resolve_download_history(
|
||||
downloadhis=oper,
|
||||
file_path=Path("/downloads/电视剧/动漫/Ghost.Concert.mkv"),
|
||||
)
|
||||
|
||||
self.assertIsNone(history)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.10.9'
|
||||
FRONTEND_VERSION = 'v2.10.9'
|
||||
APP_VERSION = 'v2.10.11'
|
||||
FRONTEND_VERSION = 'v2.10.11'
|
||||
|
||||
Reference in New Issue
Block a user