extend setup wizard for database and agent

This commit is contained in:
jxxghp
2026-04-16 17:10:25 +08:00
parent 60996be71b
commit 5995b3f3e8
2 changed files with 204 additions and 1 deletions

View File

@@ -40,6 +40,20 @@ DEFAULT_NODE_VERSION = "20.12.1"
FRONTEND_LATEST_API = "https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/latest"
FRONTEND_TAG_API = "https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/tags/{tag}"
RESOURCES_MAIN_ZIP = "https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip"
LLM_PROVIDER_DEFAULTS = {
"deepseek": {
"model": "deepseek-chat",
"base_url": "https://api.deepseek.com",
},
"openai": {
"model": "gpt-4o-mini",
"base_url": "https://api.openai.com/v1",
},
"google": {
"model": "gemini-2.5-flash",
"base_url": "",
},
}
RUNTIME_PACKAGE = {
"name": "moviepilot-frontend-runtime",
"private": True,
@@ -218,6 +232,14 @@ def _load_env_lines() -> list[str]:
return ENV_FILE.read_text(encoding="utf-8").splitlines(keepends=True)
def _serialize_env_value(value: Any) -> str:
if isinstance(value, Path):
value = str(value)
if value is None:
return '""'
return json.dumps(value, ensure_ascii=False)
def read_env_value(key: str) -> Optional[str]:
for line in _load_env_lines():
stripped = line.strip()
@@ -232,7 +254,7 @@ def read_env_value(key: str) -> Optional[str]:
def write_env_value(key: str, value: str) -> None:
ensure_local_dirs()
lines = _load_env_lines()
new_line = f"{key}={json.dumps(str(value), ensure_ascii=False)}\n"
new_line = f"{key}={_serialize_env_value(value)}\n"
for index, line in enumerate(lines):
stripped = line.strip()
@@ -250,6 +272,11 @@ def write_env_value(key: str, value: str) -> None:
ENV_FILE.write_text("".join(lines), encoding="utf-8")
def write_env_values(values: dict[str, Any]) -> None:
for key, value in values.items():
write_env_value(key, value)
def ensure_api_token(force_token: bool = False, token: Optional[str] = None) -> str:
ensure_local_dirs()
current_token = read_env_value("API_TOKEN") or ""
@@ -546,6 +573,30 @@ def _normalize_choice(value: str) -> str:
return value.strip().lower().replace("_", "").replace("-", "")
def _env_default(key: str, default: str = "") -> str:
value = read_env_value(key)
if value is None or value == "":
return default
return value
def _env_bool(key: str, default: bool) -> bool:
value = read_env_value(key)
if value is None or value == "":
return default
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
def _env_int(key: str, default: int) -> int:
value = read_env_value(key)
if value is None or value == "":
return default
try:
return int(value)
except (TypeError, ValueError):
return default
def _prompt_text(
label: str,
*,
@@ -568,6 +619,29 @@ def _prompt_text(
print("请输入有效内容,或使用回车接受默认值。")
def _prompt_secret_text(
label: str,
*,
current_value: Optional[str] = None,
allow_empty: bool = False,
required: bool = False,
) -> str:
while True:
suffix = " [留空保持现有值]" if current_value not in (None, "") else ""
prompt = f"{label}{suffix}: "
value = getpass.getpass(prompt).strip()
if value:
return value
if current_value is not None and current_value != "":
return current_value
if allow_empty and not required:
return ""
if not required:
return ""
print("请输入有效内容。")
def _prompt_yes_no(label: str, default: bool = True) -> bool:
suffix = "Y/n" if default else "y/N"
while True:
@@ -652,6 +726,55 @@ def _collect_directory_config() -> dict[str, Any]:
}
def _collect_database_config() -> dict[str, Any]:
print_step("数据库配置")
current_db_type = _env_default("DB_TYPE", "sqlite").lower()
if current_db_type not in {"sqlite", "postgresql"}:
current_db_type = "sqlite"
db_type = _prompt_choice(
"选择数据库类型",
{
"sqlite": "SQLite",
"postgresql": "PostgreSQL",
},
default=current_db_type,
)
config: dict[str, Any] = {
"DB_TYPE": db_type,
}
if db_type == "sqlite":
return config
config.update(
{
"DB_POSTGRESQL_HOST": _prompt_text(
"PostgreSQL 主机地址",
default=_env_default("DB_POSTGRESQL_HOST", "localhost"),
),
"DB_POSTGRESQL_PORT": _prompt_text(
"PostgreSQL 端口",
default=str(_env_int("DB_POSTGRESQL_PORT", 5432)),
),
"DB_POSTGRESQL_DATABASE": _prompt_text(
"PostgreSQL 数据库名(需已创建)",
default=_env_default("DB_POSTGRESQL_DATABASE", "moviepilot"),
),
"DB_POSTGRESQL_USERNAME": _prompt_text(
"PostgreSQL 用户名",
default=_env_default("DB_POSTGRESQL_USERNAME", "moviepilot"),
),
"DB_POSTGRESQL_PASSWORD": _prompt_secret_text(
"PostgreSQL 密码",
current_value=read_env_value("DB_POSTGRESQL_PASSWORD"),
allow_empty=True,
),
}
)
return config
def _collect_downloader_config() -> Optional[dict[str, Any]]:
print_step("下载器配置")
downloader_type = _prompt_choice(
@@ -808,6 +931,73 @@ def _collect_notification_config() -> Optional[dict[str, Any]]:
}
def _collect_agent_config() -> dict[str, Any]:
print_step("AI Agent 配置")
enabled = _prompt_yes_no(
"是否启用 AI 智能体",
default=_env_bool("AI_AGENT_ENABLE", False),
)
if not enabled:
return {
"AI_AGENT_ENABLE": False,
"AI_AGENT_GLOBAL": False,
}
current_provider = _env_default("LLM_PROVIDER", "deepseek").lower()
if current_provider not in LLM_PROVIDER_DEFAULTS:
current_provider = "deepseek"
provider = _prompt_choice(
"选择 LLM 提供商",
{
"deepseek": "DeepSeek",
"openai": "OpenAI",
"google": "Google",
},
default=current_provider,
)
defaults = LLM_PROVIDER_DEFAULTS[provider]
current_model = _env_default("LLM_MODEL", defaults["model"])
current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"])
config: dict[str, Any] = {
"AI_AGENT_ENABLE": True,
"AI_AGENT_GLOBAL": _prompt_yes_no(
"是否启用全局 AI 智能体",
default=_env_bool("AI_AGENT_GLOBAL", False),
),
"LLM_PROVIDER": provider,
"LLM_MODEL": _prompt_text(
"LLM 模型名称",
default=current_model,
),
"LLM_API_KEY": _prompt_secret_text(
"LLM API Key",
current_value=read_env_value("LLM_API_KEY"),
required=True,
),
"LLM_SUPPORT_IMAGE_INPUT": _prompt_yes_no(
"是否启用图片输入支持",
default=_env_bool("LLM_SUPPORT_IMAGE_INPUT", True),
),
}
if provider == "google":
config["LLM_BASE_URL"] = _prompt_text(
"自定义 Google API Base URL可选",
default=current_base_url,
allow_empty=True,
)
else:
config["LLM_BASE_URL"] = _prompt_text(
"LLM Base URL",
default=current_base_url,
allow_empty=True,
)
return config
def run_setup_wizard(force_token: bool) -> dict[str, Any]:
if not _is_interactive():
raise RuntimeError("交互式向导需要在终端中运行,请直接执行 moviepilot setup --wizard 或 moviepilot init --wizard")
@@ -841,6 +1031,10 @@ def run_setup_wizard(force_token: bool) -> dict[str, Any]:
return {
"api_token": api_token,
"env_settings": {
**_collect_database_config(),
**_collect_agent_config(),
},
"directories": [_collect_directory_config()],
"downloader": _collect_downloader_config(),
"mediaserver": _collect_media_server_config(),
@@ -1006,6 +1200,10 @@ def init_local(
else:
ensure_api_token(force_token=force_token)
if wizard_payload and wizard_payload.get("env_settings"):
write_env_values(wizard_payload["env_settings"])
print_step(f"已写入环境配置到 {ENV_FILE}")
if skip_resources:
if resources_ready:
print_step("资源文件已完成同步")