Mirror of https://github.com/jxxghp/MoviePilot.git
Synced 2026-05-07 16:53:03 +08:00

Compare commits — 7 commits
| SHA1 |
|---|
| c2c9950bb1 |
| ffbe348d66 |
| 6d7b0733af |
| 49a51cca25 |
| 06197144c0 |
| 62541ffe43 |
| c762628217 |
@@ -281,7 +281,10 @@ class PromptManager:
            db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
        else:
            db_password = settings.DB_POSTGRESQL_PASSWORD or ""
            db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
            db_info = (
                f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@"
                f"{settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE})"
            )

        info_lines = [
            f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
@@ -5,7 +5,8 @@ import os
import signal
import subprocess
from dataclasses import dataclass, field
from typing import Optional, Type
from tempfile import NamedTemporaryFile
from typing import Optional, TextIO, Type

from pydantic import BaseModel, Field
@@ -15,7 +16,7 @@ from app.log import logger

DEFAULT_TIMEOUT_SECONDS = 60
MAX_TIMEOUT_SECONDS = 300
MAX_OUTPUT_CHARS = 6000
MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
READ_CHUNK_SIZE = 4096
KILL_GRACE_SECONDS = 3
COMMAND_CONCURRENCY_LIMIT = 2
@@ -25,40 +26,93 @@ _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)

@dataclass
class _CommandOutput:
    """Hold limited command output without pulling large output fully into memory."""
    """Keep the first 10 KB as a preview; once over the limit, write the full output to a temp file."""

    limit: int
    stdout_chunks: list[str] = field(default_factory=list)
    stderr_chunks: list[str] = field(default_factory=list)
    captured_chars: int = 0
    truncated: bool = False
    preview_limit_bytes: int
    preview_entries: list[tuple[str, str]] = field(default_factory=list)
    captured_bytes: int = 0
    preview_truncated: bool = False
    temp_file_path: Optional[str] = None
    temp_file_handle: Optional[TextIO] = None
    last_written_stream: Optional[str] = None

    @staticmethod
    def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
        if byte_limit <= 0:
            return ""
        return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")

    def _write_chunk(self, stream_name: str, text: str) -> None:
        if not self.temp_file_handle or not text:
            return

        if self.last_written_stream != stream_name:
            if self.temp_file_handle.tell() > 0:
                self.temp_file_handle.write("\n")
            title = "标准输出" if stream_name == "stdout" else "错误输出"
            self.temp_file_handle.write(f"[{title}]\n")
            self.last_written_stream = stream_name

        self.temp_file_handle.write(text)

    def _ensure_temp_file(self) -> None:
        if self.temp_file_handle:
            return

        temp_file = NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            suffix=".log",
            prefix="moviepilot-command-",
            delete=False,
        )
        self.temp_file_path = temp_file.name
        self.temp_file_handle = temp_file
        for stream_name, chunk in self.preview_entries:
            self._write_chunk(stream_name, chunk)

    def close(self) -> None:
        if not self.temp_file_handle:
            return
        self.temp_file_handle.flush()
        self.temp_file_handle.close()
        self.temp_file_handle = None

    def append(self, stream_name: str, text: str) -> None:
        if not text:
            return

        remaining = self.limit - self.captured_chars
        if remaining <= 0:
            self.truncated = True
            if self.temp_file_handle:
                self._write_chunk(stream_name, text)
            return

        captured = text[:remaining]
        if stream_name == "stdout":
            self.stdout_chunks.append(captured)
        else:
            self.stderr_chunks.append(captured)
        chunk_bytes = len(text.encode("utf-8"))
        remaining = self.preview_limit_bytes - self.captured_bytes
        if chunk_bytes <= remaining:
            self.preview_entries.append((stream_name, text))
            self.captured_bytes += chunk_bytes
            return

        self.captured_chars += len(captured)
        if len(text) > remaining:
            self.truncated = True
        self.preview_truncated = True
        self._ensure_temp_file()
        self._write_chunk(stream_name, text)

        preview = self._clip_text_to_bytes(text, remaining)
        if preview:
            self.preview_entries.append((stream_name, preview))
            self.captured_bytes += len(preview.encode("utf-8"))

    @property
    def stdout(self) -> str:
        return "".join(self.stdout_chunks).strip()
        return "".join(
            text for stream_name, text in self.preview_entries if stream_name == "stdout"
        ).strip()

    @property
    def stderr(self) -> str:
        return "".join(self.stderr_chunks).strip()
        return "".join(
            text for stream_name, text in self.preview_entries if stream_name == "stderr"
        ).strip()


class ExecuteCommandInput(BaseModel):
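The class above implements a preview-plus-spill buffer: the first `MAX_OUTPUT_PREVIEW_BYTES` stay in memory, and the moment a chunk crosses the limit the buffered history is replayed into a lazily created temp file. A minimal standalone sketch of the same pattern (the `PreviewBuffer` name and the 32-byte limit are illustrative, not MoviePilot code):

```python
import tempfile


class PreviewBuffer:
    """Keep the first `limit` bytes in memory; spill everything to a temp file once exceeded."""

    def __init__(self, limit: int):
        self.limit = limit
        self.preview: list[str] = []
        self.captured = 0
        self.spill = None  # opened lazily, like _ensure_temp_file()

    def append(self, text: str) -> None:
        data = text.encode("utf-8")
        if self.spill:  # already over the limit: file only
            self.spill.write(text)
            return
        if self.captured + len(data) <= self.limit:
            self.preview.append(text)
            self.captured += len(data)
            return
        # Crossing the limit: open the spill file, replay the preview, keep writing.
        self.spill = tempfile.NamedTemporaryFile("w", encoding="utf-8", delete=False)
        for chunk in self.preview:
            self.spill.write(chunk)
        self.spill.write(text)


buf = PreviewBuffer(limit=32)
for _ in range(10):
    buf.append("0123456789")  # 100 bytes total, but only ~32 bytes of preview
print("preview:", "".join(buf.preview))
print("spill file:", buf.spill.name if buf.spill else None)
if buf.spill:
    buf.spill.close()
```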
@@ -78,7 +132,7 @@ class ExecuteCommandTool(MoviePilotTool):
    description: str = (
        "Safely execute shell commands on the server. Useful for system "
        "maintenance, checking status, or running custom scripts. Includes "
        "timeout, concurrency, and hard output limits."
        "timeout, concurrency, and output preview limits."
    )
    args_schema: Type[BaseModel] = ExecuteCommandInput
    require_admin: bool = True
@@ -107,7 +161,7 @@ class ExecuteCommandTool(MoviePilotTool):

    @staticmethod
    def _subprocess_kwargs() -> dict:
        """Give the subprocess its own process group so the whole tree can be cleaned up on timeout or oversized output."""
        """Give the subprocess its own process group so the whole tree can be cleaned up on timeout."""
        kwargs = {
            "stdin": subprocess.DEVNULL,
            "stdout": asyncio.subprocess.PIPE,
@@ -124,23 +178,14 @@ class ExecuteCommandTool(MoviePilotTool):
        stream: asyncio.StreamReader,
        stream_name: str,
        output: _CommandOutput,
        limit_reached: asyncio.Event,
    ) -> None:
        """Read output in chunks; notify the main flow to terminate the command once the limit is hit."""
        """Read output in chunks; only the first 10 KB is ever kept in the returned result."""
        while True:
            chunk = await stream.read(READ_CHUNK_SIZE)
            if not chunk:
                break

            if output.truncated:
                limit_reached.set()
                continue

            output.append(stream_name, chunk.decode("utf-8", errors="replace"))
            if output.truncated:
                limit_reached.set()
                # Keep draining the pipe after the limit, without saving content, so the child does not stall on pipe backpressure.
                continue

    @staticmethod
    def _terminate_process(process: asyncio.subprocess.Process, sig: int):
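The drain-after-limit comment above matters more than it looks: a child process blocked on a full stdout pipe never exits, so the reader must keep consuming even after it stops keeping data. A minimal sketch of the failure mode and the fix, assuming `python3` is on PATH (command and sizes are illustrative):

```python
import asyncio


async def main():
    # The child writes ~1 MB; OS pipe buffers are far smaller, so if the parent
    # stops reading, the child's write() blocks and wait() never returns.
    proc = await asyncio.create_subprocess_shell(
        "python3 -c \"import sys; sys.stdout.write('x' * 1_000_000)\"",
        stdout=asyncio.subprocess.PIPE,
    )
    kept = 0
    while True:
        chunk = await proc.stdout.read(4096)
        if not chunk:
            break
        if kept < 10 * 1024:  # keep only the first 10 KB...
            kept += len(chunk)
        # ...but keep reading regardless, purely to drain the pipe.
    await proc.wait()
    print("exit code:", proc.returncode, "kept bytes:", kept)


asyncio.run(main())
```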
@@ -205,27 +250,33 @@ class ExecuteCommandTool(MoviePilotTool):
        output: _CommandOutput,
        timeout: int,
        timed_out: bool,
        output_limited: bool,
        timeout_note: Optional[str],
    ) -> str:
        if timed_out:
            result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
        elif output_limited:
            result = (
                f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
                f"已截断并终止进程,退出码: {exit_code})"
            )
        else:
            result = f"命令执行完成 (退出码: {exit_code})"

        if timeout_note:
            result += f"\n\n提示:\n{timeout_note}"
        if output.temp_file_path:
            file_note = (
                "截至命令终止前的完整输出"
                if timed_out
                else "完整输出"
            )
            result += (
                "\n\n提示:\n"
                f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
                f"{file_note}已写入临时文件: {output.temp_file_path}\n"
                "如需完整内容,请继续读取该文件。"
            )
        if output.stdout:
            result += f"\n\n标准输出:\n{output.stdout}"
        if output.stderr:
            result += f"\n\n错误输出:\n{output.stderr}"
        if output.truncated:
            result += "\n\n...(输出内容过长,已截断)"
        if output.preview_truncated:
            result += "\n\n...(仅展示前 10KB 内容)"
        if not output.stdout and not output.stderr:
            result += "\n\n(无输出内容)"
        return result
@@ -252,51 +303,40 @@ class ExecuteCommandTool(MoviePilotTool):

        try:
            async with _command_semaphore:
                # Output can be very large: read and truncate as we go; communicate() would collect everything at once.
                # Output can be very large: read and flush to disk as we go; communicate() would collect everything at once.
                process = await asyncio.create_subprocess_shell(
                    command, **self._subprocess_kwargs()
                )
                output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
                limit_reached = asyncio.Event()
                output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
                wait_task = asyncio.create_task(process.wait())
                limit_task = asyncio.create_task(limit_reached.wait())
                reader_tasks = [
                    asyncio.create_task(
                        self._read_stream(
                            process.stdout, "stdout", output, limit_reached
                        )
                        self._read_stream(process.stdout, "stdout", output)
                    ),
                    asyncio.create_task(
                        self._read_stream(
                            process.stderr, "stderr", output, limit_reached
                        )
                        self._read_stream(process.stderr, "stderr", output)
                    ),
                ]

                timed_out = False
                output_limited = False
                done, _ = await asyncio.wait(
                    {wait_task, limit_task},
                    timeout=normalized_timeout,
                    return_when=asyncio.FIRST_COMPLETED,
                )

                if wait_task not in done:
                    if limit_task in done:
                        output_limited = True
                    else:
                        timed_out = True
                try:
                    await asyncio.wait_for(
                        asyncio.shield(wait_task), timeout=normalized_timeout
                    )
                except asyncio.TimeoutError:
                    timed_out = True
                    await self._cleanup_process(process, wait_task)

                limit_task.cancel()
                await self._finish_reader_tasks(reader_tasks)
                try:
                    await self._finish_reader_tasks(reader_tasks)
                finally:
                    output.close()

                return self._format_result(
                    exit_code=process.returncode,
                    output=output,
                    timeout=normalized_timeout,
                    timed_out=timed_out,
                    output_limited=output_limited,
                    timeout_note=timeout_note,
                )
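The `asyncio.wait_for(asyncio.shield(wait_task), ...)` pairing above is deliberate: `wait_for` alone would cancel `wait_task` on timeout, while `shield` lets the timeout fire yet keeps the underlying task alive for `_cleanup_process` to await. A minimal sketch of that behavior (names are illustrative):

```python
import asyncio


async def slow_job():
    await asyncio.sleep(5)
    return "done"


async def main():
    task = asyncio.create_task(slow_job())
    try:
        # Time out after 0.1s, but shield() keeps `task` itself uncancelled.
        await asyncio.wait_for(asyncio.shield(task), timeout=0.1)
    except asyncio.TimeoutError:
        print("timed out; task still alive:", not task.done())
        task.cancel()  # now cancel explicitly, on our own terms
        await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```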
@@ -352,6 +352,16 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
                return current_fileitem, None  # Return a FileItem that signals failure, plus None
            target_dir_path = Path(target_dir_item.path)
        # Images usually live in the current directory (current_fileitem)
        # Jellyfin/Kodi and friends expect generic image names inside a season directory, not season01-poster.jpg
        elif item_type == ScrapingTarget.SEASON:
            season_image_name_map = {
                ScrapingMetadata.POSTER: "poster",
                ScrapingMetadata.BANNER: "banner",
                ScrapingMetadata.THUMB: "thumb",
            }
            if season_image_name := season_image_name_map.get(metadata_type):
                hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
                final_filename = f"{season_image_name}{hint_ext}"
        # EPISODE-type images (e.g. thumb) also sit next to the media file, named after the video file
        elif (
            metadata_type in [ScrapingMetadata.THUMB]
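The season-image renaming reduces to taking the extension from the scraper's hint and swapping in the generic name; a short sketch with hypothetical inputs:

```python
from pathlib import Path

# Hypothetical inputs: the scraper suggests a season-specific hint name,
# but inside a season directory Jellyfin/Kodi expect the generic name.
season_image_name = "poster"
filename_hint = "season01-poster.jpg"

hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
print(f"{season_image_name}{hint_ext}")  # -> poster.jpg
```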
@@ -20,7 +20,7 @@ from app.core.event import eventmanager
from app.core.meta import MetaBase
from app.core.metainfo import MetaInfoPath
from app.db.downloadhistory_oper import DownloadHistoryOper
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
from app.db.models.transferhistory import TransferHistory
from app.db.systemconfig_oper import SystemConfigOper
from app.db.transferhistory_oper import TransferHistoryOper
@@ -1686,7 +1686,102 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
        ]

    @staticmethod
    def _get_shared_download_roots(file_path: Path) -> set[str]:
        """
        Get the shared download-root boundaries that contain the current file.

        The parent-directory fallback lookup must stay inside the torrent's own
        directory and never cross a shared download root; otherwise historical
        single-file / flat tasks would pollute the identification of sibling files.
        """
        shared_roots: set[str] = set()
        media_type_dirs = {mtype.value for mtype in MediaType}

        for dir_info in DirectoryHelper().get_download_dirs():
            if not dir_info.download_path:
                continue

            download_root = Path(dir_info.download_path)
            if not file_path.is_relative_to(download_root):
                continue

            shared_roots.add(download_root.as_posix())
            relative_parts = file_path.relative_to(download_root).parts
            current_root = download_root
            part_index = 0

            if (
                not dir_info.media_type
                and dir_info.download_type_folder
                and len(relative_parts) > part_index
                and relative_parts[part_index] in media_type_dirs
            ):
                current_root = current_root / relative_parts[part_index]
                shared_roots.add(current_root.as_posix())
                part_index += 1

            if (
                not dir_info.media_category
                and dir_info.download_category_folder
                and len(relative_parts) > part_index
            ):
                current_root = current_root / relative_parts[part_index]
                shared_roots.add(current_root.as_posix())

        return shared_roots

    @staticmethod
    def _match_download_file(
        download_file: DownloadFiles,
        file_path: Path,
        save_path: Path,
    ) -> bool:
        """
        Check whether a download-file record unambiguously matches the current file.
        """
        if download_file.fullpath == file_path.as_posix():
            return True

        filepath = download_file.filepath
        if not filepath:
            return False

        try:
            return (save_path / Path(filepath)).as_posix() == file_path.as_posix()
        except (TypeError, ValueError):
            return False

    def _resolve_history_from_download_files(
        self,
        downloadhis: DownloadHistoryOper,
        download_files: List[DownloadFiles],
        file_path: Optional[Path] = None,
        save_path: Optional[Path] = None,
    ) -> Optional[DownloadHistory]:
        """
        Resolve a unique download history from download-file records.
        """
        if file_path and save_path:
            download_files = [
                download_file
                for download_file in download_files
                if self._match_download_file(
                    download_file=download_file,
                    file_path=file_path,
                    save_path=save_path,
                )
            ]

        download_hashes = {
            download_file.download_hash
            for download_file in download_files
            if download_file.download_hash
        }
        if len(download_hashes) == 1:
            return downloadhis.get_by_hash(next(iter(download_hashes)))
        return None

    def _resolve_download_history(
        self,
        downloadhis: DownloadHistoryOper,
        file_path: Path,
        bluray_dir: bool = False,
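The boundary logic rests on `Path.is_relative_to` (Python 3.9+) plus a parent walk that stops at the first shared root; a minimal sketch with made-up paths:

```python
from pathlib import Path

download_root = Path("/downloads")
shared_roots = {download_root.as_posix(), (download_root / "电影").as_posix()}

file_path = Path("/downloads/电影/Some.Movie.2024/file.mkv")
assert file_path.is_relative_to(download_root)

# The torrent's own folder is fair game for fallback lookups,
# but the walk stops once it reaches a shared root.
for parent in file_path.parents:
    if parent.as_posix() in shared_roots:
        print("stop at shared root:", parent)
        break
    print("searchable parent:", parent)
```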
@@ -1707,20 +1802,35 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):

        # Subtitles / extra files inside a multi-file torrent may lack a stable fullpath record;
        # fall back to the parent directory and savepath to recover the same torrent's association.
        shared_download_roots = self._get_shared_download_roots(file_path)

        for parent_path in file_path.parents:
            parent_posix = parent_path.as_posix()
            download_files = downloadhis.get_files_by_savepath(parent_posix) or []

            if parent_posix in shared_download_roots:
                # A shared download root may only match via explicit file records,
                # so a single-file/magnet task cannot claim the whole root as one media item.
                history = self._resolve_history_from_download_files(
                    downloadhis=downloadhis,
                    download_files=download_files,
                    file_path=file_path,
                    save_path=parent_path,
                )
                if history:
                    return history
                break

            download_history = downloadhis.get_by_path(parent_posix)
            if download_history:
                return download_history

            download_files = downloadhis.get_files_by_savepath(parent_posix) or []
            download_hashes = {
                download_file.download_hash
                for download_file in download_files
                if download_file.download_hash
            }
            if len(download_hashes) == 1:
                return downloadhis.get_by_hash(next(iter(download_hashes)))
            history = self._resolve_history_from_download_files(
                downloadhis=downloadhis,
                download_files=download_files,
            )
            if history:
                return history

        return None
@@ -10,7 +10,7 @@ import threading
from asyncio import AbstractEventLoop
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type
from urllib.parse import urlparse
from urllib.parse import quote, urlencode, urlparse

from dotenv import set_key
from pydantic import BaseModel, Field, ConfigDict, model_validator
@@ -126,8 +126,8 @@ class ConfigModel(BaseModel):
    DB_SQLITE_MAX_OVERFLOW: int = 50
    # PostgreSQL host address
    DB_POSTGRESQL_HOST: str = "localhost"
    # PostgreSQL port
    DB_POSTGRESQL_PORT: int = 5432
    # PostgreSQL port; may be left empty when using a Unix socket
    DB_POSTGRESQL_PORT: str = "5432"
    # PostgreSQL database name
    DB_POSTGRESQL_DATABASE: str = "moviepilot"
    # PostgreSQL username
@@ -142,7 +142,7 @@ class ConfigModel(BaseModel):
    # ==================== Cache configuration ====================
    # Cache backend, cachetools or redis; defaults to cachetools
    CACHE_BACKEND_TYPE: str = "cachetools"
    # Cache connection string, only needed for external caches (e.g. Redis, Memcached)
    # Cache connection string, only needed for external caches (e.g. Redis, Memcached); Redis Unix socket URLs are supported
    CACHE_BACKEND_URL: Optional[str] = "redis://localhost:6379"
    # Redis max-memory limit; when unset, "1024mb" with big-memory mode enabled, "256mb" otherwise
    CACHE_REDIS_MAXMEMORY: Optional[str] = None
@@ -921,6 +921,39 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
            }
        return None

    @property
    def DB_POSTGRESQL_SOCKET_MODE(self) -> bool:
        host = (self.DB_POSTGRESQL_HOST or "").strip()
        return host.startswith("/")

    @property
    def DB_POSTGRESQL_TARGET(self) -> str:
        if self.DB_POSTGRESQL_SOCKET_MODE:
            target = f"socket {self.DB_POSTGRESQL_HOST}"
            if self.DB_POSTGRESQL_PORT:
                target = f"{target} (port {self.DB_POSTGRESQL_PORT})"
            return target
        if self.DB_POSTGRESQL_PORT:
            return f"{self.DB_POSTGRESQL_HOST}:{self.DB_POSTGRESQL_PORT}"
        return self.DB_POSTGRESQL_HOST

    def DB_POSTGRESQL_URL(self, driver: Optional[str] = None) -> str:
        scheme = "postgresql" if not driver else f"postgresql+{driver}"
        username = quote(str(self.DB_POSTGRESQL_USERNAME), safe="")
        database = quote(str(self.DB_POSTGRESQL_DATABASE), safe="")
        auth = username
        if self.DB_POSTGRESQL_PASSWORD:
            auth = f"{auth}:{quote(str(self.DB_POSTGRESQL_PASSWORD), safe='')}"

        if self.DB_POSTGRESQL_SOCKET_MODE:
            query = {"host": self.DB_POSTGRESQL_HOST}
            if self.DB_POSTGRESQL_PORT:
                query["port"] = self.DB_POSTGRESQL_PORT
            return f"{scheme}://{auth}@/{database}?{urlencode(query)}"

        port = f":{self.DB_POSTGRESQL_PORT}" if self.DB_POSTGRESQL_PORT else ""
        return f"{scheme}://{auth}@{self.DB_POSTGRESQL_HOST}{port}/{database}"

    @property
    def PROXY_SERVER(self):
        if self.PROXY_HOST:
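The URL builder's two branches can be reproduced standalone; the expected shapes below match the assertions in tests/test_postgresql_socket_config.py further down in this diff (the `pg_url` helper is a condensed sketch, not the Settings API):

```python
from urllib.parse import quote, urlencode


def pg_url(host, port, user, password, database, driver=None):
    # Mirrors Settings.DB_POSTGRESQL_URL: TCP hosts go in the authority,
    # socket directories (absolute paths) go in the ?host= query parameter.
    scheme = "postgresql" if not driver else f"postgresql+{driver}"
    auth = quote(user, safe="")
    if password:
        auth += f":{quote(password, safe='')}"
    if host.startswith("/"):  # socket mode
        query = {"host": host}
        if port:
            query["port"] = port
        return f"{scheme}://{auth}@/{database}?{urlencode(query)}"
    port_part = f":{port}" if port else ""
    return f"{scheme}://{auth}@{host}{port_part}/{database}"


print(pg_url("db", "5433", "user", "pass", "moviepilot"))
# postgresql://user:pass@db:5433/moviepilot
print(pg_url("/var/run/postgresql", "", "user", "pass", "moviepilot"))
# postgresql://user:pass@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql
```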
@@ -116,11 +116,7 @@ def _get_postgresql_engine(is_async: bool = False):
    """
    Get the PostgreSQL database engine
    """
    # Build the PostgreSQL connection URL
    if settings.DB_POSTGRESQL_PASSWORD:
        db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
    else:
        db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
    db_url = settings.DB_POSTGRESQL_URL()

    # PostgreSQL connection arguments
    _connect_args = {}
@@ -150,12 +146,11 @@ def _get_postgresql_engine(is_async: bool = False):

        # Create the database engine
        engine = create_engine(**_db_kwargs)
        print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
        print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE}")

        return engine
    else:
        # Build the async PostgreSQL connection URL
        async_db_url = f"postgresql+asyncpg://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
        async_db_url = settings.DB_POSTGRESQL_URL("asyncpg")

        # Database kwargs; only NullPool can be used
        _db_kwargs = {
@@ -168,7 +163,7 @@ def _get_postgresql_engine(is_async: bool = False):
        }
        # Create the async database engine
        async_engine = create_async_engine(**_db_kwargs)
        print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
        print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE}")

        return async_engine
@@ -28,10 +28,7 @@ def update_db():

    # Pick the URL based on the database type
    if settings.DB_TYPE.lower() == "postgresql":
        if settings.DB_POSTGRESQL_PASSWORD:
            db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
        else:
            db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
        db_url = settings.DB_POSTGRESQL_URL()
    else:
        db_location = settings.CONFIG_PATH / 'user.db'
        db_url = f"sqlite:///{db_location}"
@@ -176,86 +176,101 @@ class Alist(StorageBase, metaclass=WeakSingleton):
            if item:
                return [item]
            return []
        resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
            self.__get_api_url("/api/fs/list"),
            json={
                "path": fileitem.path,
                "password": password,
                "page": page,
                "per_page": per_page,
                "refresh": refresh,
            },
        )
        """
        {
            "path": "/t",
            "password": "",
            "page": 1,
            "per_page": 0,
            "refresh": false
        }
        ======================================
        {
            "code": 200,
            "message": "success",
            "data": {
                "content": [
                    {
                        "name": "Alist V3.md",
                        "size": 1592,
                        "is_dir": false,
                        "modified": "2024-05-17T13:47:55.4174917+08:00",
                        "created": "2024-05-17T13:47:47.5725906+08:00",
                        "sign": "",
                        "thumb": "",
                        "type": 4,
                        "hashinfo": "null",
                        "hash_info": null
                    }
                ],
                "total": 1,
                "readme": "",
                "header": "",
                "write": true,
                "provider": "Local"
        items = []
        current_page = page
        while True:
            resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
                self.__get_api_url("/api/fs/list"),
                json={
                    "path": fileitem.path,
                    "password": password,
                    "page": current_page,
                    "per_page": per_page,
                    "refresh": refresh,
                },
            )
            """
            {
                "path": "/t",
                "password": "",
                "page": 1,
                "per_page": 0,
                "refresh": false
            }
            }
        """
            ======================================
            {
                "code": 200,
                "message": "success",
                "data": {
                    "content": [
                        {
                            "name": "Alist V3.md",
                            "size": 1592,
                            "is_dir": false,
                            "modified": "2024-05-17T13:47:55.4174917+08:00",
                            "created": "2024-05-17T13:47:47.5725906+08:00",
                            "sign": "",
                            "thumb": "",
                            "type": 4,
                            "hashinfo": "null",
                            "hash_info": null
                        }
                    ],
                    "total": 1,
                    "readme": "",
                    "header": "",
                    "write": true,
                    "provider": "Local"
                }
            }
            """

        if resp is None:
            logger.warn(
                f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
            )
            return []
        if resp.status_code != 200:
            logger.warn(
                f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,状态码:{resp.status_code}"
            )
            return []
            if resp is None:
                logger.warn(
                    f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
                )
                return []
            if resp.status_code != 200:
                logger.warn(
                    f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,状态码:{resp.status_code}"
                )
                return []

        result = resp.json()
            result = resp.json()

        if result["code"] != 200:
            logger.warn(
                f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
            )
            return []
            if result["code"] != 200:
                logger.warn(
                    f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
                )
                return []

        return [
            schemas.FileItem(
                storage=self.schema.value,
                type="dir" if item["is_dir"] else "file",
                path=(Path(fileitem.path) / item["name"]).as_posix()
                + ("/" if item["is_dir"] else ""),
                name=item["name"],
                basename=Path(item["name"]).stem,
                extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
                size=item["size"] if not item["is_dir"] else None,
                modify_time=self.__parse_timestamp(item["modified"]),
                thumbnail=item["thumb"],
            page_content = result["data"].get("content") or []
            items.extend(
                [
                    schemas.FileItem(
                        storage=self.schema.value,
                        type="dir" if item["is_dir"] else "file",
                        path=(Path(fileitem.path) / item["name"]).as_posix()
                        + ("/" if item["is_dir"] else ""),
                        name=item["name"],
                        basename=Path(item["name"]).stem,
                        extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
                        size=item["size"] if not item["is_dir"] else None,
                        modify_time=self.__parse_timestamp(item["modified"]),
                        thumbnail=item["thumb"],
                    )
                    for item in page_content
                ]
            )
            for item in result["data"]["content"] or []
        ]

            if per_page > 0:
                return items

            total = result["data"].get("total") or 0
            if not page_content or len(items) >= total:
                return items

            current_page += 1

    def create_folder(
        self, fileitem: schemas.FileItem, name: str
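The auto-paging loop reduces to a generic pattern: request pages until one comes back empty or the accumulated count reaches the reported total. A minimal sketch against a fake fetcher (`fetch_page` and the page size of 100 are illustrative):

```python
def fetch_page(page: int) -> dict:
    # Stand-in for the /api/fs/list call: 205 entries served 100 per page.
    all_items = list(range(205))
    return {"content": all_items[(page - 1) * 100: page * 100], "total": 205}


def list_all() -> list:
    items, page = [], 1
    while True:
        data = fetch_page(page)
        content = data.get("content") or []
        items.extend(content)
        total = data.get("total") or 0
        # Stop on an empty page or once everything reported has been fetched.
        if not content or len(items) >= total:
            return items
        page += 1


assert len(list_all()) == 205
```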
@@ -148,7 +148,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
        # Pause first if specific files need to be selected
        is_paused = True if episodes else False
        # Add the task
        state = server.add_torrent(
        state, added_torrent_ids = server.add_torrent(
            content=content,
            download_dir=self.normalize_path(download_dir, downloader),
            is_paused=is_paused,
@@ -188,7 +188,11 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
            return None, None, None, f"添加种子任务失败:{content}"
        else:
            # Get the torrent hash
            torrent_hash = server.get_torrent_id_by_tag(tags=tag)
            torrent_hash = next(iter(added_torrent_ids), None)
            if torrent_hash:
                server.delete_torrents_tag(torrent_hash, tag)
            else:
                torrent_hash = server.get_torrent_id_by_tag(tags=tag)
            if not torrent_hash:
                return None, None, None, f"下载任务添加成功,但获取Qbittorrent任务信息失败:{content}"
            else:
@@ -1,8 +1,11 @@
import time
import traceback
from typing import Optional, Union, Tuple, List
from http.cookies import SimpleCookie
from typing import Any, Optional, Union, Tuple, List
from urllib.parse import urlparse

import qbittorrentapi
from packaging.version import InvalidVersion, Version
from qbittorrentapi import TorrentDictionary, TorrentFilesList
from qbittorrentapi.client import Client
from qbittorrentapi.transfer import TransferInfoDictionary
@@ -17,6 +20,7 @@ class Qbittorrent:
    """
    def __init__(self, host: Optional[str] = None, port: int = None,
                 username: Optional[str] = None, password: Optional[str] = None,
                 apikey: Optional[str] = None,
                 category: Optional[bool] = False, sequentail: Optional[bool] = False,
                 force_resume: Optional[bool] = False, first_last_piece=False,
                 **kwargs):
@@ -33,12 +37,122 @@ class Qbittorrent:
            return
        self._username = username
        self._password = password
        self._apikey = str(apikey or "").strip() or None
        self._category = category
        self._sequentail = sequentail
        self._force_resume = force_resume
        self._first_last_piece = first_last_piece
        self.qbc = self.__login_qbittorrent()

    @staticmethod
    def __get_mapping_value(data: Any, key: str) -> Any:
        if data is None:
            return None
        if isinstance(data, dict):
            return data.get(key)
        getter = getattr(data, "get", None)
        if callable(getter):
            try:
                return getter(key)
            except Exception:
                pass
        return getattr(data, key, None)

    def __normalize_cookie(self, cookie: Any) -> dict:
        result = {}
        for key in ("domain", "path", "name", "value", "expirationDate"):
            value = self.__get_mapping_value(cookie, key)
            if value not in (None, ""):
                result[key] = value
        return result

    @staticmethod
    def __cookie_key(cookie: dict) -> Optional[tuple]:
        name = cookie.get("name")
        domain = cookie.get("domain")
        path = cookie.get("path") or "/"
        if not name or not domain:
            return None
        return domain, path, name

    @staticmethod
    def __build_site_cookies(url: str, cookie_header: str) -> List[dict]:
        domain = urlparse(url).hostname
        if not domain:
            return []

        raw_cookies = SimpleCookie()
        raw_cookies.load(cookie_header)
        return [
            {
                "domain": domain,
                "path": "/",
                "name": morsel.key,
                "value": morsel.value,
            }
            for morsel in raw_cookies.values()
        ]

    def __parse_add_torrent_response(self, response: Any) -> Tuple[bool, List[str]]:
        if not response:
            return False, []
        if isinstance(response, str):
            return "Ok" in response, []

        success_count = self.__get_mapping_value(response, "success_count") or 0
        pending_count = self.__get_mapping_value(response, "pending_count") or 0
        added_torrent_ids = self.__get_mapping_value(response, "added_torrent_ids") or []
        if not isinstance(added_torrent_ids, list):
            added_torrent_ids = list(added_torrent_ids)
        added_torrent_ids = [str(torrent_id) for torrent_id in added_torrent_ids if torrent_id]
        if added_torrent_ids:
            return True, added_torrent_ids
        if success_count or pending_count:
            return True, []
        return "Ok" in str(response), []

    def __use_api_key_auth(self) -> bool:
        return bool(self._apikey)

    def __supports_cookie_api(self) -> bool:
        if not self.qbc:
            return False
        try:
            web_api_version = self.qbc.app_web_api_version()
            return Version(str(web_api_version)) >= Version("2.11.3")
        except (InvalidVersion, TypeError, ValueError):
            return False
        except Exception as err:
            logger.warn(f"获取 qbittorrent Web API 版本失败,跳过 Cookie API 兼容:{err}")
            return False

    def __sync_download_cookies(self, url: str, cookie_header: str) -> bool:
        if not self.qbc or not url or not cookie_header or not self.__supports_cookie_api():
            return False

        try:
            site_cookies = self.__build_site_cookies(url=url, cookie_header=cookie_header)
            if not site_cookies:
                return False

            merged_cookies = {}
            for cookie in self.qbc.app_cookies() or []:
                normalized = self.__normalize_cookie(cookie)
                cookie_key = self.__cookie_key(normalized)
                if cookie_key:
                    merged_cookies[cookie_key] = normalized

            for cookie in site_cookies:
                cookie_key = self.__cookie_key(cookie)
                if cookie_key:
                    merged_cookies[cookie_key] = cookie

            self.qbc.app_set_cookies(cookies=list(merged_cookies.values()))
            return True
        except Exception as err:
            logger.error(f"同步下载Cookie出错:{str(err)}")
            return False

    def is_inactive(self) -> bool:
        """
        Decide whether a reconnect is needed
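Merging cookies keyed on `(domain, path, name)` is what makes `__sync_download_cookies` idempotent: re-pushing a site's cookies overwrites stale entries instead of accumulating duplicates. A standalone sketch (the domain and values are made up):

```python
from http.cookies import SimpleCookie


def cookie_key(cookie: dict):
    return cookie["domain"], cookie.get("path") or "/", cookie["name"]


existing = [{"domain": "tracker.example", "path": "/", "name": "uid", "value": "old"}]

raw = SimpleCookie()
raw.load("uid=123; pass=abcdef")  # a typical site Cookie header
incoming = [
    {"domain": "tracker.example", "path": "/", "name": m.key, "value": m.value}
    for m in raw.values()
]

merged = {cookie_key(c): c for c in existing}
merged.update({cookie_key(c): c for c in incoming})
print(sorted((c["name"], c["value"]) for c in merged.values()))
# [('pass', 'abcdef'), ('uid', '123')]  -- uid overwritten, not duplicated
```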
@@ -67,14 +181,20 @@ class Qbittorrent:
                                   port=self._port,
                                   username=self._username,
                                   password=self._password,
                                   EXTRA_HEADERS={"Authorization": f"Bearer {self._apikey}"}
                                   if self.__use_api_key_auth() else None,
                                   VERIFY_WEBUI_CERTIFICATE=False,
                                   REQUESTS_ARGS={'timeout': (15, 60)})
        try:
            qbt.auth_log_in()
        except (qbittorrentapi.LoginFailed, qbittorrentapi.Forbidden403Error) as e:
            logger.error(f"qbittorrent 登录失败:{str(e).strip() or '请检查用户名和密码是否正确'}")
            return None
            if self.__use_api_key_auth():
                qbt.app_version()
            else:
                qbt.auth_log_in()
        except Exception as e:
            if e.__class__.__name__ in {"LoginFailed", "Forbidden403Error", "Unauthorized401Error"}:
                error_hint = "请检查 API Key 是否正确" if self.__use_api_key_auth() else "请检查用户名和密码是否正确"
                logger.error(f"qbittorrent 登录失败:{str(e).strip() or error_hint}")
                return None
            stack_trace = "".join(traceback.format_exception(None, e, e.__traceback__))[:2000]
            logger.error(f"qbittorrent 登录失败:{str(e)}\n{stack_trace}")
            return None
@@ -241,7 +361,7 @@ class Qbittorrent:
                    category: Optional[str] = None,
                    cookie: Optional[str] = None,
                    **kwargs
                    ) -> bool:
                    ) -> Tuple[bool, List[str]]:
        """
        Add a torrent
        :param content: torrent URLs or file content
@@ -251,10 +371,10 @@ class Qbittorrent:
        :param download_dir: download path
        :param cookie: site cookie used to assist the torrent download
        :param kwargs: optional arguments such as ignore_category_check plus qBittorrent parameters
        :return: bool
        :return: whether the add succeeded, plus the torrent ID list returned by the new API
        """
        if not self.qbc or not content:
            return False
            return False, []

        # Download the content
        if isinstance(content, str):
@@ -287,6 +407,11 @@ class Qbittorrent:
            is_auto = False
            category = None
        try:
            cookie_to_use = cookie
            if urls and cookie and not StringUtils.is_magnet_link(urls):
                if self.__sync_download_cookies(url=urls, cookie_header=cookie):
                    cookie_to_use = None

            # Add the download
            qbc_ret = self.qbc.torrents_add(urls=urls,
                                            torrent_files=torrent_files,
@@ -296,13 +421,13 @@ class Qbittorrent:
                                            use_auto_torrent_management=is_auto,
                                            is_sequential_download=self._sequentail,
                                            is_first_last_piece_priority=self._first_last_piece,
                                            cookie=cookie,
                                            cookie=cookie_to_use,
                                            category=category,
                                            **kwargs)
            return True if qbc_ret and str(qbc_ret).find("Ok") != -1 else False
            return self.__parse_add_torrent_response(qbc_ret)
        except Exception as err:
            logger.error(f"添加种子出错:{str(err)}")
            return False
            return False, []

    def start_torrents(self, ids: Union[str, list]) -> bool:
        """
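`torrents_add` may come back as a plain "Ok."-style string on older qBittorrent Web APIs and as a mapping with `success_count`/`added_torrent_ids` on newer ones, so the parser normalizes both; a condensed sketch of the contract (inputs are made up):

```python
from typing import Any, List, Tuple


def parse_add_torrent_response(response: Any) -> Tuple[bool, List[str]]:
    # Condensed from Qbittorrent.__parse_add_torrent_response above.
    if not response:
        return False, []
    if isinstance(response, str):
        return "Ok" in response, []
    ids = [str(i) for i in (response.get("added_torrent_ids") or []) if i]
    if ids:
        return True, ids
    if response.get("success_count") or response.get("pending_count"):
        return True, []
    return "Ok" in str(response), []


print(parse_add_torrent_response("Ok."))                              # (True, [])
print(parse_add_torrent_response({"success_count": 1,
                                  "added_torrent_ids": ["abc123"]}))  # (True, ['abc123'])
print(parse_add_torrent_response("Fails."))                           # (False, [])
```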
@@ -24,7 +24,7 @@ DB_TYPE=postgresql
# PostgreSQL host address
DB_POSTGRESQL_HOST=localhost

# PostgreSQL port
# PostgreSQL port; may be left empty when using a Unix socket
DB_POSTGRESQL_PORT=5432

# PostgreSQL database name
@@ -43,6 +43,21 @@ DB_POSTGRESQL_POOL_SIZE=20
DB_POSTGRESQL_MAX_OVERFLOW=30
```

### 3. Unix socket connections

If PostgreSQL is exposed over a Unix socket, set `DB_POSTGRESQL_HOST` to the socket directory.

```bash
DB_TYPE=postgresql
DB_POSTGRESQL_HOST=/var/run/postgresql
DB_POSTGRESQL_PORT=
DB_POSTGRESQL_DATABASE=moviepilot
DB_POSTGRESQL_USERNAME=moviepilot
DB_POSTGRESQL_PASSWORD=moviepilot
```

To pin an explicit socket port, keep `DB_POSTGRESQL_PORT` set; the program then generates a PostgreSQL URL with a `host=/path/to/socket` query parameter.
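With the socket directory above plus an explicit port, the generated connection URL takes this shape (the same form asserted by `tests/test_postgresql_socket_config.py` in this changeset):

```
postgresql://moviepilot:moviepilot@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql&port=5432
```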
## Docker deployment

### Using an external PostgreSQL
@@ -60,6 +75,13 @@ DB_POSTGRESQL_USERNAME=your-username
DB_POSTGRESQL_PASSWORD=your-password
```

When Redis listens on a Unix socket, set `CACHE_BACKEND_URL` directly, for example:

```bash
CACHE_BACKEND_TYPE=redis
CACHE_BACKEND_URL=unix:///var/run/redis/redis.sock?db=0
```

## Data migration

### Migrating from SQLite to PostgreSQL
@@ -28,7 +28,7 @@ APScheduler~=3.11.0
cryptography~=45.0.4
pytz~=2025.2
pycryptodome~=3.23.0
qbittorrent-api==2025.5.0
qbittorrent-api==2026.5.1
plexapi~=4.17.0
transmission-rpc~=4.3.0
Jinja2~=3.1.6
@@ -89,3 +89,4 @@ openai~=2.33.0
google-genai~=1.74.0
ddgs~=9.10.0
websocket-client~=1.8.0
pytest~=8.4.0
@@ -1312,8 +1312,9 @@ def _collect_downloader_config() -> Optional[dict[str, Any]]:
    config_name = _prompt_text("下载器名称", default=downloader_type)
    if downloader_type == "qbittorrent":
        host = _prompt_text("qBittorrent 地址", default="http://127.0.0.1:8080")
        username = _prompt_text("qBittorrent 用户名", default="admin")
        password = _prompt_text("qBittorrent 密码", secret=True)
        apikey = _prompt_text("qBittorrent API Key(可选,5.2+ 推荐)", allow_empty=True, default="")
        username = _prompt_text("qBittorrent 用户名", default="admin") if not apikey else ""
        password = _prompt_text("qBittorrent 密码", secret=True, allow_empty=bool(apikey)) if not apikey else ""
        category = _prompt_yes_no("是否启用 qBittorrent 分类", default=False)
        return {
            "name": config_name,
@@ -1322,6 +1323,7 @@ def _collect_downloader_config() -> Optional[dict[str, Any]]:
            "enabled": True,
            "config": {
                "host": host,
                "apikey": apikey,
                "username": username,
                "password": password,
                "category": category,
tests/test_alist_storage.py — new file, 205 lines
@@ -0,0 +1,205 @@
import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch


def _load_alist_module():
    module_name = "_test_alist_module"
    app_module = types.ModuleType("app")
    schemas_module = types.ModuleType("app.schemas")
    cache_module = types.ModuleType("app.core.cache")
    config_module = types.ModuleType("app.core.config")
    log_module = types.ModuleType("app.log")
    storages_module = types.ModuleType("app.modules.filemanager.storages")
    exception_module = types.ModuleType("app.schemas.exception")
    types_module = types.ModuleType("app.schemas.types")
    http_module = types.ModuleType("app.utils.http")
    singleton_module = types.ModuleType("app.utils.singleton")
    url_module = types.ModuleType("app.utils.url")

    class _FileItem:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    class _StorageSchemaValue:
        def __init__(self, value):
            self.value = value

    class _Logger:
        def debug(self, *_args, **_kwargs):
            pass

        def warn(self, *_args, **_kwargs):
            pass

        def warning(self, *_args, **_kwargs):
            pass

        def error(self, *_args, **_kwargs):
            pass

        def critical(self, *_args, **_kwargs):
            pass

        def info(self, *_args, **_kwargs):
            pass

    class _StorageBase:
        def __init__(self):
            pass

        def get_conf(self):
            return {}

    class _OperationInterrupted(Exception):
        pass

    class _RequestUtils:
        def __init__(self, *args, **kwargs):
            pass

    class _UrlUtils:
        @staticmethod
        def standardize_base_url(url):
            return url.rstrip("/") if url else ""

        @staticmethod
        def adapt_request_url(base, path):
            return f"{base() if callable(base) else base}{path}"

        @staticmethod
        def quote(path):
            return path

    def _cached(*_args, **_kwargs):
        def decorator(func):
            func.cache_clear = lambda: None
            return func

        return decorator

    schemas_module.FileItem = _FileItem
    schemas_module.StorageUsage = object
    cache_module.cached = _cached
    config_module.settings = types.SimpleNamespace(
        OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME=True,
        TEMP_PATH=Path("/tmp"),
    )
    config_module.global_vars = types.SimpleNamespace(
        is_transfer_stopped=lambda *_args, **_kwargs: False
    )
    log_module.logger = _Logger()
    storages_module.StorageBase = _StorageBase
    storages_module.transfer_process = lambda *_args, **_kwargs: (lambda *_a, **_k: None)
    exception_module.OperationInterrupted = _OperationInterrupted
    types_module.StorageSchema = types.SimpleNamespace(Alist=_StorageSchemaValue("alist"))
    http_module.RequestUtils = _RequestUtils
    singleton_module.WeakSingleton = type
    url_module.UrlUtils = _UrlUtils

    app_module.schemas = schemas_module

    stub_modules = {
        "app": app_module,
        "app.schemas": schemas_module,
        "app.core.cache": cache_module,
        "app.core.config": config_module,
        "app.log": log_module,
        "app.modules.filemanager.storages": storages_module,
        "app.schemas.exception": exception_module,
        "app.schemas.types": types_module,
        "app.utils.http": http_module,
        "app.utils.singleton": singleton_module,
        "app.utils.url": url_module,
    }
    for stub_module in stub_modules.values():
        stub_module._alist_test_stub = True

    alist_path = Path(__file__).resolve().parents[1] / "app" / "modules" / "filemanager" / "storages" / "alist.py"
    spec = importlib.util.spec_from_file_location(module_name, alist_path)
    module = importlib.util.module_from_spec(spec)
    assert spec and spec.loader
    with patch.dict(sys.modules, stub_modules):
        spec.loader.exec_module(module)
    return module


alist_module = _load_alist_module()
Alist = alist_module.Alist
FileItem = alist_module.schemas.FileItem


class _FakeResponse:
    def __init__(self, payload: dict, status_code: int = 200):
        self._payload = payload
        self.status_code = status_code

    def json(self):
        return self._payload


class AlistStorageTest(unittest.TestCase):
    def setUp(self):
        self.storage = Alist()

    @staticmethod
    def _dir_item(path: str = "/"):
        return FileItem(storage="alist", type="dir", path=path)

    @staticmethod
    def _page_payload(start: int, count: int, total: int) -> dict:
        return {
            "code": 200,
            "message": "success",
            "data": {
                "content": [
                    {
                        "name": f"dir-{index}",
                        "size": 0,
                        "is_dir": True,
                        "modified": "2024-05-17T13:47:55.4174917+08:00",
                        "thumb": "",
                    }
                    for index in range(start, start + count)
                ],
                "total": total,
            },
        }

    def test_list_fetches_all_pages_when_per_page_is_default(self):
        responses = [
            _FakeResponse(self._page_payload(0, 200, 205)),
            _FakeResponse(self._page_payload(200, 5, 205)),
        ]
        request_utils = MagicMock()
        request_utils.post_res.side_effect = responses

        with patch.object(Alist, "get_conf", return_value={"url": "http://openlist.test", "token": "token"}):
            with patch.object(alist_module, "RequestUtils", return_value=request_utils):
                items = self.storage.list(self._dir_item())

        self.assertEqual(205, len(items))
        self.assertEqual("/dir-0/", items[0].path)
        self.assertEqual("/dir-204/", items[-1].path)
        self.assertEqual(2, request_utils.post_res.call_count)
        self.assertEqual(1, request_utils.post_res.call_args_list[0].kwargs["json"]["page"])
        self.assertEqual(2, request_utils.post_res.call_args_list[1].kwargs["json"]["page"])

    def test_list_respects_explicit_per_page_without_auto_paging(self):
        request_utils = MagicMock()
        request_utils.post_res.return_value = _FakeResponse(self._page_payload(0, 50, 205))

        with patch.object(Alist, "get_conf", return_value={"url": "http://openlist.test", "token": "token"}):
            with patch.object(alist_module, "RequestUtils", return_value=request_utils):
                items = self.storage.list(self._dir_item(), per_page=50)

        self.assertEqual(50, len(items))
        self.assertEqual(1, request_utils.post_res.call_count)


if __name__ == "__main__":
    unittest.main()
@@ -1,5 +1,6 @@
import asyncio
import os
import re
import shlex
import subprocess
import sys
@@ -8,7 +9,7 @@ import unittest

from app.agent.tools.impl.execute_command import (
    ExecuteCommandTool,
    MAX_OUTPUT_CHARS,
    MAX_OUTPUT_PREVIEW_BYTES,
)
@@ -21,6 +22,11 @@ def _python_command(code: str) -> str:


class TestExecuteCommandTool(unittest.TestCase):
    def _temp_file_path_from_result(self, result: str) -> str:
        match = re.search(r"临时文件: (.+)", result)
        self.assertIsNotNone(match)
        return match.group(1).strip()

    def _run_command(self, command: str, timeout: int = 60) -> str:
        tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
        return asyncio.run(tool.run(command=command, timeout=timeout))
@@ -31,9 +37,19 @@ class TestExecuteCommandTool(unittest.TestCase):
        )

        result = self._run_command(command)
        temp_file_path = self._temp_file_path_from_result(result)

        self.assertIn("输出内容过长,已截断", result)
        self.assertLess(len(result), MAX_OUTPUT_CHARS + 500)
        self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path))
        self.assertIn("命令输出超过 10KB", result)
        self.assertIn("仅展示前 10KB 内容", result)
        self.assertIn("如需完整内容,请继续读取该文件", result)
        self.assertLess(len(result), MAX_OUTPUT_PREVIEW_BYTES + 600)

        with open(temp_file_path, encoding="utf-8") as file_handle:
            file_content = file_handle.read()

        self.assertIn("[标准输出]", file_content)
        self.assertGreater(len(file_content), 100000)

    def test_timeout_returns_partial_output_promptly(self):
        command = _python_command(
@@ -48,6 +64,24 @@ class TestExecuteCommandTool(unittest.TestCase):
        self.assertIn("命令执行超时", result)
        self.assertIn("started", result)

    def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self):
        command = _python_command(
            "import sys, time; sys.stdout.write('x' * 20000); sys.stdout.flush(); time.sleep(5)"
        )

        result = self._run_command(command, timeout=1)
        temp_file_path = self._temp_file_path_from_result(result)

        self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path))
        self.assertIn("命令执行超时", result)
        self.assertIn("截至命令终止前的完整输出已写入临时文件", result)

        with open(temp_file_path, encoding="utf-8") as file_handle:
            file_content = file_handle.read()

        self.assertIn("[标准输出]", file_content)
        self.assertGreaterEqual(file_content.count("x"), 20000)

    def test_timeout_is_capped(self):
        command = _python_command("print('ok')")
@@ -64,6 +64,28 @@ class TestMediaScrapingPaths(unittest.TestCase):
        self.assertEqual(target_item, fileitem)
        self.assertEqual(target_path, Path("/tv/Show/Season 1/season.nfo"))

    def test_season_dir_poster_path(self):
        fileitem = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
        target_item, target_path = self.media_chain._get_target_fileitem_and_path(
            current_fileitem=fileitem,
            item_type=ScrapingTarget.SEASON,
            metadata_type=ScrapingMetadata.POSTER,
            filename_hint="season01-poster.jpg"
        )
        self.assertEqual(target_item, fileitem)
        self.assertEqual(target_path, Path("/tv/Show/Season 1/poster.jpg"))

    def test_season_dir_specials_poster_path(self):
        fileitem = schemas.FileItem(path="/tv/Show/Specials", name="Specials", type="dir", storage="local")
        target_item, target_path = self.media_chain._get_target_fileitem_and_path(
            current_fileitem=fileitem,
            item_type=ScrapingTarget.SEASON,
            metadata_type=ScrapingMetadata.POSTER,
            filename_hint="season-specials-poster.jpg"
        )
        self.assertEqual(target_item, fileitem)
        self.assertEqual(target_path, Path("/tv/Show/Specials/poster.jpg"))

    def test_episode_file_nfo_path(self):
        fileitem = schemas.FileItem(path="/tv/Show/Season 1/S01E01.mp4", name="S01E01.mp4", type="file", storage="local")
        parent_item = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
@@ -171,6 +193,7 @@ class TestMediaScrapingImages(unittest.TestCase):
        calls = self.media_chain._download_and_save_image.call_args_list
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0].kwargs["url"], "http://season01")
        self.assertEqual(calls[0].kwargs["path"], Path("/tv/Show/Season 1/poster.jpg"))

    def test_scrape_episode_thumb_image_path(self):
        fileitem = schemas.FileItem(path="/tv/Show/Season 1/S01E01.mp4", name="S01E01.mp4", type="file", storage="local")
tests/test_postgresql_socket_config.py — new file, 101 lines
@@ -0,0 +1,101 @@
import sys
import unittest
from enum import Enum
from types import ModuleType


def _stub_module(name: str, **attrs):
    module = sys.modules.get(name)
    if module is None:
        module = ModuleType(name)
        sys.modules[name] = module
    for key, value in attrs.items():
        setattr(module, key, value)
    return module


class _DummyLogger:
    def __getattr__(self, _name):
        return lambda *args, **kwargs: None


_stub_module(
    "app.log",
    logger=_DummyLogger(),
    log_settings=_DummyLogger(),
    LogConfigModel=type("LogConfigModel", (), {}),
)
_stub_module("psutil")
_schemas_module = _stub_module(
    "app.schemas", MediaType=Enum("MediaType", {"Movie": "Movie", "TV": "TV"})
)
_schemas_module.__getattr__ = lambda name: type(name, (), {})
_stub_module("version", APP_VERSION="test")


from app.core.config import Settings


class PostgreSQLSocketConfigTests(unittest.TestCase):
    def test_postgresql_tcp_url_keeps_host_and_port(self):
        settings = Settings(
            DB_POSTGRESQL_HOST="db",
            DB_POSTGRESQL_PORT="5433",
            DB_POSTGRESQL_DATABASE="moviepilot",
            DB_POSTGRESQL_USERNAME="user",
            DB_POSTGRESQL_PASSWORD="pass",
        )

        self.assertFalse(settings.DB_POSTGRESQL_SOCKET_MODE)
        self.assertEqual(
            settings.DB_POSTGRESQL_URL(),
            "postgresql://user:pass@db:5433/moviepilot",
        )
        self.assertEqual(
            settings.DB_POSTGRESQL_URL("asyncpg"),
            "postgresql+asyncpg://user:pass@db:5433/moviepilot",
        )
        self.assertEqual(settings.DB_POSTGRESQL_TARGET, "db:5433")

    def test_postgresql_socket_url_uses_host_query_param(self):
        settings = Settings(
            DB_POSTGRESQL_HOST="/var/run/postgresql",
            DB_POSTGRESQL_PORT="",
            DB_POSTGRESQL_DATABASE="moviepilot",
            DB_POSTGRESQL_USERNAME="user",
            DB_POSTGRESQL_PASSWORD="pass",
        )

        self.assertTrue(settings.DB_POSTGRESQL_SOCKET_MODE)
        self.assertIsNone(settings.DB_POSTGRESQL_PORT_VALUE)
        self.assertEqual(
            settings.DB_POSTGRESQL_URL(),
            "postgresql://user:pass@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql",
        )
        self.assertEqual(
            settings.DB_POSTGRESQL_URL("asyncpg"),
            "postgresql+asyncpg://user:pass@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql",
        )
        self.assertEqual(settings.DB_POSTGRESQL_TARGET, "socket /var/run/postgresql")

    def test_postgresql_socket_url_can_keep_explicit_port(self):
        settings = Settings(
            DB_POSTGRESQL_HOST="/var/run/postgresql",
            DB_POSTGRESQL_PORT="5432",
            DB_POSTGRESQL_DATABASE="moviepilot",
            DB_POSTGRESQL_USERNAME="user",
            DB_POSTGRESQL_PASSWORD="",
        )

        self.assertEqual(
            settings.DB_POSTGRESQL_URL(),
            "postgresql://user@/moviepilot?host=%2Fvar%2Frun%2Fpostgresql&port=5432",
        )
        self.assertEqual(
            settings.DB_POSTGRESQL_TARGET,
            "socket /var/run/postgresql (port 5432)",
        )


if __name__ == "__main__":
    unittest.main()
tests/test_qbittorrent_compat.py — new file, 352 lines
@@ -0,0 +1,352 @@
import importlib.util
import sys
import types
import unittest
from enum import Enum
from pathlib import Path
from unittest.mock import MagicMock, patch


def _load_qbittorrent_modules():
    repo_root = Path(__file__).resolve().parents[1]

    app_module = types.ModuleType("app")
    app_module.__path__ = []
    core_module = types.ModuleType("app.core")
    core_module.__path__ = []
    utils_module = types.ModuleType("app.utils")
    utils_module.__path__ = []
    modules_module = types.ModuleType("app.modules")
    modules_module.__path__ = []
    qbittorrent_package_module = types.ModuleType("app.modules.qbittorrent")
    qbittorrent_package_module.__path__ = []
    log_module = types.ModuleType("app.log")
    cache_module = types.ModuleType("app.core.cache")
    config_module = types.ModuleType("app.core.config")
    metainfo_module = types.ModuleType("app.core.metainfo")
    schemas_module = types.ModuleType("app.schemas")
    schema_types_module = types.ModuleType("app.schemas.types")
    string_module = types.ModuleType("app.utils.string")
    torrentool_module = types.ModuleType("torrentool")
    torrentool_module.__path__ = []
    torrentool_torrent_module = types.ModuleType("torrentool.torrent")
    qbittorrentapi_module = types.ModuleType("qbittorrentapi")
    qbittorrentapi_client_module = types.ModuleType("qbittorrentapi.client")
    qbittorrentapi_transfer_module = types.ModuleType("qbittorrentapi.transfer")

    class _Logger:
        def info(self, *_args, **_kwargs):
            pass

        def warn(self, *_args, **_kwargs):
            pass

        def warning(self, *_args, **_kwargs):
            pass

        def error(self, *_args, **_kwargs):
            pass

    class _StringUtils:
        @staticmethod
        def get_domain_address(address, prefix=False):
            return address, 8080

        @staticmethod
        def is_magnet_link(value):
            if isinstance(value, bytes):
                return value.startswith(b"magnet:")
            return isinstance(value, str) and value.startswith("magnet:")

        @staticmethod
        def generate_random_str(_length):
            return "tmp-tag-01"

        @staticmethod
        def str_filesize(value):
            return str(value)

        @staticmethod
        def str_secends(value):
            return str(value)

    class _FileCache:
        def get(self, *_args, **_kwargs):
            return None

    class _MetaInfo:
        def __init__(self, name):
            self.name = name
            self.year = None
            self.season_episode = ""
            self.episode_list = []

    class _ModuleBase:
        pass

    class _DownloaderBase:
        def __class_getitem__(cls, _item):
            return cls

    class _Torrent:
        @staticmethod
        def from_string(content):
            return types.SimpleNamespace(name="test", total_size=len(content))

    class TorrentStatus(Enum):
        TRANSFER = "transfer"
        DOWNLOADING = "downloading"

    class ModuleType(Enum):
        Downloader = "Downloader"

    class DownloaderType(Enum):
        Qbittorrent = "Qbittorrent"

    log_module.logger = _Logger()
    cache_module.FileCache = _FileCache
    config_module.settings = types.SimpleNamespace(TORRENT_TAG="moviepilot-tag")
    metainfo_module.MetaInfo = _MetaInfo
    schemas_module.DownloaderInfo = object
    schemas_module.TransferTorrent = object
    schemas_module.DownloadingTorrent = object
    schema_types_module.TorrentStatus = TorrentStatus
    schema_types_module.ModuleType = ModuleType
    schema_types_module.DownloaderType = DownloaderType
    string_module.StringUtils = _StringUtils
    modules_module._ModuleBase = _ModuleBase
    modules_module._DownloaderBase = _DownloaderBase
    torrentool_torrent_module.Torrent = _Torrent
    qbittorrentapi_module.TorrentDictionary = dict
    qbittorrentapi_module.TorrentFilesList = list
    qbittorrentapi_module.LoginFailed = type("LoginFailed", (Exception,), {})
    qbittorrentapi_module.Forbidden403Error = type("Forbidden403Error", (Exception,), {})
    qbittorrentapi_module.Unauthorized401Error = type("Unauthorized401Error", (Exception,), {})
    qbittorrentapi_module.Client = object
    qbittorrentapi_client_module.Client = object
    qbittorrentapi_transfer_module.TransferInfoDictionary = dict

    app_module.core = core_module
    app_module.log = log_module
    app_module.modules = modules_module
    app_module.schemas = schemas_module
    app_module.utils = utils_module
    core_module.cache = cache_module
    core_module.config = config_module
    core_module.metainfo = metainfo_module
    utils_module.string = string_module
    schemas_module.types = schema_types_module
    modules_module.qbittorrent = qbittorrent_package_module
    torrentool_module.torrent = torrentool_torrent_module

    stub_modules = {
        "app": app_module,
        "app.core": core_module,
        "app.core.cache": cache_module,
        "app.core.config": config_module,
        "app.core.metainfo": metainfo_module,
        "app.log": log_module,
        "app.modules": modules_module,
        "app.modules.qbittorrent": qbittorrent_package_module,
        "app.schemas": schemas_module,
        "app.schemas.types": schema_types_module,
        "app.utils": utils_module,
        "app.utils.string": string_module,
        "qbittorrentapi": qbittorrentapi_module,
        "qbittorrentapi.client": qbittorrentapi_client_module,
        "qbittorrentapi.transfer": qbittorrentapi_transfer_module,
        "torrentool": torrentool_module,
        "torrentool.torrent": torrentool_torrent_module,
    }

    for stub_module in stub_modules.values():
        stub_module._qbittorrent_test_stub = True

    qbittorrent_path = repo_root / "app" / "modules" / "qbittorrent" / "qbittorrent.py"
    qbittorrent_spec = importlib.util.spec_from_file_location(
        "app.modules.qbittorrent.qbittorrent",
        qbittorrent_path,
    )
    qbittorrent_module = importlib.util.module_from_spec(qbittorrent_spec)
    assert qbittorrent_spec and qbittorrent_spec.loader

    module_path = repo_root / "app" / "modules" / "qbittorrent" / "__init__.py"
    qbittorrent_module_spec = importlib.util.spec_from_file_location(
        "_test_qbittorrent_module",
module_path,
|
||||
)
|
||||
module_package = importlib.util.module_from_spec(qbittorrent_module_spec)
|
||||
assert qbittorrent_module_spec and qbittorrent_module_spec.loader
|
||||
|
||||
with patch.dict(sys.modules, stub_modules):
|
||||
sys.modules[qbittorrent_spec.name] = qbittorrent_module
|
||||
qbittorrent_spec.loader.exec_module(qbittorrent_module)
|
||||
qbittorrent_package_module.qbittorrent = qbittorrent_module
|
||||
qbittorrent_module_spec.loader.exec_module(module_package)
|
||||
|
||||
return qbittorrent_module, module_package
|
||||
|
||||
|
||||
qbittorrent_module, qbittorrent_package_module = _load_qbittorrent_modules()
|
||||
Qbittorrent = qbittorrent_module.Qbittorrent
|
||||
QbittorrentModule = qbittorrent_package_module.QbittorrentModule
|
||||
|
||||
|
||||
class TestQbittorrentCompat(unittest.TestCase):
|
||||
def test_login_uses_api_key_header_without_auth_login(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.app_version.return_value = "v5.2.0"
|
||||
|
||||
with patch.object(qbittorrent_module.qbittorrentapi, "Client", return_value=fake_client) as client_cls:
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, apikey="secret-token")
|
||||
|
||||
self.assertIs(downloader.qbc, fake_client)
|
||||
fake_client.auth_log_in.assert_not_called()
|
||||
fake_client.app_version.assert_called_once_with()
|
||||
self.assertEqual(
|
||||
client_cls.call_args.kwargs["EXTRA_HEADERS"],
|
||||
{"Authorization": "Bearer secret-token"},
|
||||
)
|
||||
|
||||
def test_add_torrent_accepts_structured_success_response(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.torrents_add.return_value = {
|
||||
"success_count": 1,
|
||||
"failure_count": 0,
|
||||
"pending_count": 0,
|
||||
"added_torrent_ids": ["abc123"],
|
||||
}
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(content="https://example.com/test.torrent")
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, ["abc123"])
|
||||
|
||||
def test_add_torrent_accepts_pending_success_response_without_ids(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.torrents_add.return_value = {
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"pending_count": 1,
|
||||
"added_torrent_ids": [],
|
||||
}
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(content="https://example.com/test.torrent")
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, [])
|
||||
|
||||
def test_add_torrent_uses_cookie_api_for_qbittorrent_52(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.app_web_api_version.return_value = "2.11.3"
|
||||
fake_client.app_cookies.return_value = [
|
||||
{
|
||||
"domain": "old.example.com",
|
||||
"path": "/",
|
||||
"name": "old",
|
||||
"value": "cookie",
|
||||
}
|
||||
]
|
||||
fake_client.torrents_add.return_value = "Ok."
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(
|
||||
content="https://tracker.example.com/download?id=1",
|
||||
cookie="uid=1; passkey=abc",
|
||||
)
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, [])
|
||||
set_cookie_call = fake_client.app_set_cookies.call_args.kwargs["cookies"]
|
||||
self.assertIn(
|
||||
{
|
||||
"domain": "tracker.example.com",
|
||||
"path": "/",
|
||||
"name": "uid",
|
||||
"value": "1",
|
||||
},
|
||||
set_cookie_call,
|
||||
)
|
||||
self.assertIn(
|
||||
{
|
||||
"domain": "tracker.example.com",
|
||||
"path": "/",
|
||||
"name": "passkey",
|
||||
"value": "abc",
|
||||
},
|
||||
set_cookie_call,
|
||||
)
|
||||
self.assertIsNone(fake_client.torrents_add.call_args.kwargs["cookie"])
|
||||
|
||||
def test_add_torrent_keeps_legacy_cookie_param_for_old_webapi(self):
|
||||
fake_client = MagicMock()
|
||||
fake_client.app_web_api_version.return_value = "2.11.2"
|
||||
fake_client.torrents_add.return_value = "Ok."
|
||||
|
||||
with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client):
|
||||
downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin")
|
||||
|
||||
success, added_torrent_ids = downloader.add_torrent(
|
||||
content="https://tracker.example.com/download?id=1",
|
||||
cookie="uid=1",
|
||||
)
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(added_torrent_ids, [])
|
||||
fake_client.app_set_cookies.assert_not_called()
|
||||
self.assertEqual(fake_client.torrents_add.call_args.kwargs["cookie"], "uid=1")
|
||||
|
||||
|
||||
class TestQbittorrentModuleCompat(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _build_module(server):
|
||||
module = QbittorrentModule.__new__(QbittorrentModule)
|
||||
module.get_instance = MagicMock(return_value=server)
|
||||
module.normalize_path = MagicMock(side_effect=lambda path, _downloader: path)
|
||||
module.get_default_config_name = MagicMock(return_value="default-qb")
|
||||
return module
|
||||
|
||||
def test_download_prefers_added_torrent_ids_before_tag_lookup(self):
|
||||
fake_server = MagicMock()
|
||||
fake_server.add_torrent.return_value = (True, ["abc123"])
|
||||
fake_server.get_content_layout.return_value = "Original"
|
||||
fake_server.is_force_resume.return_value = False
|
||||
|
||||
module = self._build_module(fake_server)
|
||||
result = module.download(
|
||||
content="magnet:?xt=urn:btih:123",
|
||||
download_dir=Path("/downloads"),
|
||||
cookie="",
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
self.assertEqual(result, ("qb", "abc123", "Original", "添加下载成功"))
|
||||
fake_server.delete_torrents_tag.assert_called_once_with("abc123", "tmp-tag-01")
|
||||
fake_server.get_torrent_id_by_tag.assert_not_called()
|
||||
self.assertEqual(
|
||||
fake_server.add_torrent.call_args.kwargs["tag"],
|
||||
["tmp-tag-01", "moviepilot-tag"],
|
||||
)
|
||||
|
||||
def test_download_falls_back_to_tag_lookup_when_added_ids_missing(self):
|
||||
fake_server = MagicMock()
|
||||
fake_server.add_torrent.return_value = (True, [])
|
||||
fake_server.get_content_layout.return_value = "Original"
|
||||
fake_server.get_torrent_id_by_tag.return_value = "def456"
|
||||
fake_server.is_force_resume.return_value = False
|
||||
|
||||
module = self._build_module(fake_server)
|
||||
result = module.download(
|
||||
content="magnet:?xt=urn:btih:456",
|
||||
download_dir=Path("/downloads"),
|
||||
cookie="",
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
self.assertEqual(result, ("qb", "def456", "Original", "添加下载成功"))
|
||||
fake_server.delete_torrents_tag.assert_not_called()
|
||||
fake_server.get_torrent_id_by_tag.assert_called_once_with(tags="tmp-tag-01")
|
||||
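The file above keeps the heavyweight app.* and qbittorrentapi imports out of the test run by pre-seeding sys.modules with stand-ins before executing the real module files. The same pattern in isolation, using a hypothetical heavy_dep name and only the standard library:

    import sys
    import types
    from unittest.mock import patch

    fake = types.ModuleType("heavy_dep")
    fake.fetch = lambda: "stubbed"  # stand-in for the real, expensive call
    with patch.dict(sys.modules, {"heavy_dep": fake}):
        import heavy_dep  # resolved from sys.modules, so the stub is returned
        assert heavy_dep.fetch() == "stubbed"

Because patch.dict restores sys.modules on exit, the stubs cannot leak into other test files; the suite itself should run with the usual python -m unittest tests.test_qbittorrent_compat.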
@@ -1,6 +1,7 @@
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

from app.chain.transfer import TransferChain

@@ -32,6 +33,9 @@ class FakeDownloadHistoryOper:


class TransferDownloadHistoryLookupTest(unittest.TestCase):
    def setUp(self):
        self.chain = object.__new__(TransferChain)

    def test_resolve_download_history_falls_back_to_parent_download_path(self):
        expected = SimpleNamespace(download_hash="hash1", downloader="qb")
        oper = FakeDownloadHistoryOper(
@@ -39,7 +43,7 @@ class TransferDownloadHistoryLookupTest(unittest.TestCase):
            histories_by_path={"/downloads/season-pack": expected},
        )

        history = TransferChain._resolve_download_history(
        history = self.chain._resolve_download_history(
            downloadhis=oper,
            file_path=Path("/downloads/season-pack/Test.Show.S01E01.mkv"),
        )
@@ -58,7 +62,7 @@ class TransferDownloadHistoryLookupTest(unittest.TestCase):
            },
        )

        history = TransferChain._resolve_download_history(
        history = self.chain._resolve_download_history(
            downloadhis=oper,
            file_path=Path("/downloads/season-pack/subs/Test.Show.S01E01.zh.ass"),
        )
@@ -79,13 +83,127 @@ class TransferDownloadHistoryLookupTest(unittest.TestCase):
            },
        )

        history = TransferChain._resolve_download_history(
        history = self.chain._resolve_download_history(
            downloadhis=oper,
            file_path=Path("/downloads/shared/Test.Show.S01E01.mkv"),
        )

        self.assertIsNone(history)

    def test_resolve_download_history_stops_at_shared_download_root_path(self):
        oper = FakeDownloadHistoryOper(
            histories_by_path={
                "/downloads": SimpleNamespace(download_hash="hash1", downloader="qb")
            }
        )

        with patch(
            "app.chain.transfer.DirectoryHelper.get_download_dirs",
            return_value=[
                SimpleNamespace(
                    download_path="/downloads",
                    media_type=None,
                    download_type_folder=False,
                    media_category=None,
                    download_category_folder=False,
                )
            ],
        ):
            history = self.chain._resolve_download_history(
                downloadhis=oper,
                file_path=Path("/downloads/Ghost.Concert.mkv"),
            )

        self.assertIsNone(history)

    def test_resolve_download_history_stops_at_shared_download_root_savepath(self):
        expected = SimpleNamespace(download_hash="hash1", downloader="qb")
        oper = FakeDownloadHistoryOper(
            histories_by_hash={"hash1": expected},
            files_by_savepath={
                "/downloads": [
                    SimpleNamespace(download_hash="hash1", filepath="Other.Show.mkv"),
                ]
            },
        )

        with patch(
            "app.chain.transfer.DirectoryHelper.get_download_dirs",
            return_value=[
                SimpleNamespace(
                    download_path="/downloads",
                    media_type=None,
                    download_type_folder=False,
                    media_category=None,
                    download_category_folder=False,
                )
            ],
        ):
            history = self.chain._resolve_download_history(
                downloadhis=oper,
                file_path=Path("/downloads/Ghost.Concert.mkv"),
            )

        self.assertIsNone(history)

    def test_resolve_download_history_accepts_shared_root_savepath_for_exact_file(self):
        expected = SimpleNamespace(download_hash="hash1", downloader="qb")
        oper = FakeDownloadHistoryOper(
            histories_by_hash={"hash1": expected},
            files_by_savepath={
                "/downloads": [
                    SimpleNamespace(download_hash="hash1", filepath="Ghost.Concert.mkv"),
                ]
            },
        )

        with patch(
            "app.chain.transfer.DirectoryHelper.get_download_dirs",
            return_value=[
                SimpleNamespace(
                    download_path="/downloads",
                    media_type=None,
                    download_type_folder=False,
                    media_category=None,
                    download_category_folder=False,
                )
            ],
        ):
            history = self.chain._resolve_download_history(
                downloadhis=oper,
                file_path=Path("/downloads/Ghost.Concert.mkv"),
            )

        self.assertIs(history, expected)

    def test_resolve_download_history_stops_at_type_category_download_root(self):
        oper = FakeDownloadHistoryOper(
            histories_by_path={
                "/downloads/电视剧/动漫": SimpleNamespace(
                    download_hash="hash1", downloader="qb"
                )
            }
        )

        with patch(
            "app.chain.transfer.DirectoryHelper.get_download_dirs",
            return_value=[
                SimpleNamespace(
                    download_path="/downloads",
                    media_type=None,
                    download_type_folder=True,
                    media_category=None,
                    download_category_folder=True,
                )
            ],
        ):
            history = self.chain._resolve_download_history(
                downloadhis=oper,
                file_path=Path("/downloads/电视剧/动漫/Ghost.Concert.mkv"),
            )

        self.assertIsNone(history)


if __name__ == "__main__":
    unittest.main()
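Taken together, the assertions above pin down a lookup order for _resolve_download_history: exact download path first, then parent directories, refusing a match once the walk reaches a configured download root (including its type/category subfolders), with an exception, not sketched here, when a savepath record names the exact file. A hypothetical illustration of that path-walk order (names and helpers below are illustrative, not the actual TransferChain code):

    from pathlib import Path

    def resolve_history_sketch(lookup_by_path, expanded_roots, file_path: Path):
        # lookup_by_path: hypothetical callable mapping a path string to a history record
        # expanded_roots: download roots plus their type/category subfolders
        for candidate in (file_path, *file_path.parents):
            if str(candidate) in expanded_roots:
                return None  # a shared root holds many torrents; never attribute it to one
            record = lookup_by_path(str(candidate))
            if record is not None:
                return record
        return None

    # Mirrors test_resolve_download_history_falls_back_to_parent_download_path:
    histories = {"/downloads/season-pack": "hash1"}
    assert resolve_history_sketch(
        histories.get, {"/downloads"},
        Path("/downloads/season-pack/Test.Show.S01E01.mkv"),
    ) == "hash1"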