From 51be3ab4ca5a0e94b136f8fd2fb6bde438a90eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=B4=E5=A4=A9?= Date: Tue, 26 May 2026 05:39:19 +0800 Subject: [PATCH] fix(hermes): validate mcp sampling config --- scripts/dev-api.js | 51 ++++++++ src-tauri/src/commands/hermes.rs | 155 +++++++++++++++++++++++- tests/hermes-mcp-servers-config.test.js | 41 +++++++ 3 files changed, 246 insertions(+), 1 deletion(-) diff --git a/scripts/dev-api.js b/scripts/dev-api.js index 749cfba..cb66ca5 100644 --- a/scripts/dev-api.js +++ b/scripts/dev-api.js @@ -4513,6 +4513,56 @@ function normalizeHermesMcpTimeout(entry, field, key) { entry[field] = parseHermesInteger(entry[field], key, 120, 1, 86400, true) } +function normalizeHermesMcpSampling(value, key) { + if (!value || typeof value !== 'object' || Array.isArray(value)) { + throw new Error(`${key} 必须是 JSON 对象`) + } + const sampling = mergeConfigsPreservingFields(value, {}) + if (Object.hasOwn(sampling, 'enabled')) { + if (typeof sampling.enabled !== 'boolean') throw new Error(`${key}.enabled 必须是布尔值`) + } + if (Object.hasOwn(sampling, 'model')) { + if (sampling.model == null || sampling.model === '') { + delete sampling.model + } else if (typeof sampling.model !== 'string') { + throw new Error(`${key}.model 必须是字符串`) + } else { + const model = sampling.model.trim() + if (model) sampling.model = model + else delete sampling.model + } + } + if (Object.hasOwn(sampling, 'max_tokens_cap')) { + sampling.max_tokens_cap = parseHermesInteger(sampling.max_tokens_cap, `${key}.max_tokens_cap`, 4096, 1, 1000000, true) + } + if (Object.hasOwn(sampling, 'timeout')) { + sampling.timeout = parseHermesInteger(sampling.timeout, `${key}.timeout`, 30, 1, 86400, true) + } + if (Object.hasOwn(sampling, 'max_rpm')) { + sampling.max_rpm = parseHermesInteger(sampling.max_rpm, `${key}.max_rpm`, 10, 1, 100000, true) + } + if (Object.hasOwn(sampling, 'allowed_models')) { + sampling.allowed_models = normalizeHermesStringArray(sampling.allowed_models, `${key}.allowed_models`) + } + if (Object.hasOwn(sampling, 'max_tool_rounds')) { + sampling.max_tool_rounds = parseHermesInteger(sampling.max_tool_rounds, `${key}.max_tool_rounds`, 5, 0, 1000, true) + } + if (Object.hasOwn(sampling, 'log_level')) { + if (sampling.log_level == null || sampling.log_level === '') { + delete sampling.log_level + } else if (typeof sampling.log_level !== 'string') { + throw new Error(`${key}.log_level 必须是字符串`) + } else { + const level = sampling.log_level.trim().toLowerCase() + if (!['debug', 'info', 'warning', 'error'].includes(level)) { + throw new Error(`${key}.log_level 必须是 debug、info、warning 或 error`) + } + sampling.log_level = level + } + } + return sampling +} + function validateHermesMcpServers(value) { if (!value || typeof value !== 'object' || Array.isArray(value)) { throw new Error('mcp_servers 必须是 JSON 对象') @@ -4540,6 +4590,7 @@ function validateHermesMcpServers(value) { if (Object.hasOwn(entry, 'headers')) entry.headers = normalizeHermesStringMap(entry.headers, `mcp_servers.${name}.headers`) normalizeHermesMcpTimeout(entry, 'timeout', `mcp_servers.${name}.timeout`) normalizeHermesMcpTimeout(entry, 'connect_timeout', `mcp_servers.${name}.connect_timeout`) + if (Object.hasOwn(entry, 'sampling')) entry.sampling = normalizeHermesMcpSampling(entry.sampling, `mcp_servers.${name}.sampling`) normalized[name] = entry } return normalized diff --git a/src-tauri/src/commands/hermes.rs b/src-tauri/src/commands/hermes.rs index ba28e5b..775db4d 100644 --- a/src-tauri/src/commands/hermes.rs +++ b/src-tauri/src/commands/hermes.rs @@ -4860,6 +4860,84 @@ fn normalize_hermes_mcp_timeout( Ok(()) } +fn normalize_hermes_mcp_sampling(value: &Value, key: &str) -> Result { + let Some(config) = value.as_object() else { + return Err(format!("{key} 必须是 JSON 对象")); + }; + let mut sampling = config.clone(); + + if let Some(enabled) = sampling.get("enabled") { + if !enabled.is_boolean() { + return Err(format!("{key}.enabled 必须是布尔值")); + } + } + + if sampling.contains_key("model") { + let empty = sampling.get("model").is_some_and(|value| { + value.is_null() || value.as_str().is_some_and(|text| text.trim().is_empty()) + }); + if empty { + sampling.remove("model"); + } else { + let Some(model) = sampling.get("model").and_then(|value| value.as_str()) else { + return Err(format!("{key}.model 必须是字符串")); + }; + sampling.insert("model".to_string(), Value::String(model.trim().to_string())); + } + } + + for (field, fallback, min, max) in [ + ("max_tokens_cap", 4096, 1, 1_000_000), + ("timeout", 30, 1, 86400), + ("max_rpm", 10, 1, 100000), + ("max_tool_rounds", 5, 0, 1000), + ] { + if let Some(raw) = sampling.get(field).cloned() { + let parsed = if let Some(value) = raw.as_i64() { + Some(value) + } else if let Some(value) = raw.as_u64() { + i64::try_from(value).ok() + } else if let Some(value) = raw.as_str() { + value.trim().parse::().ok() + } else { + None + }; + let parsed = parsed.ok_or_else(|| format!("{key}.{field} 必须是整数"))?; + let parsed = + validate_hermes_i64(Some(parsed), &format!("{key}.{field}"), fallback, min, max)?; + sampling.insert(field.to_string(), Value::Number(parsed.into())); + } + } + + if let Some(allowed_models) = sampling.get("allowed_models") { + let allowed_models = + normalize_hermes_json_string_array(allowed_models, &format!("{key}.allowed_models"))?; + sampling.insert("allowed_models".to_string(), Value::Array(allowed_models)); + } + + if sampling.contains_key("log_level") { + let empty = sampling.get("log_level").is_some_and(|value| { + value.is_null() || value.as_str().is_some_and(|text| text.trim().is_empty()) + }); + if empty { + sampling.remove("log_level"); + } else { + let Some(level) = sampling.get("log_level").and_then(|value| value.as_str()) else { + return Err(format!("{key}.log_level 必须是字符串")); + }; + let level = level.trim().to_ascii_lowercase(); + if !matches!(level.as_str(), "debug" | "info" | "warning" | "error") { + return Err(format!( + "{key}.log_level 必须是 debug、info、warning 或 error" + )); + } + sampling.insert("log_level".to_string(), Value::String(level)); + } + } + + Ok(Value::Object(sampling)) +} + fn validate_hermes_mcp_servers(value: &Value) -> Result, String> { let Some(map) = value.as_object() else { return Err("mcp_servers 必须是 JSON 对象".to_string()); @@ -4932,6 +5010,11 @@ fn validate_hermes_mcp_servers(value: &Value) -> Result mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', timeout: 0 } }) }), /mcp_servers\.time\.timeout/, ) + assert.throws( + () => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: [] } }) }), + /mcp_servers\.time\.sampling/, + ) + assert.throws( + () => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { enabled: 'yes' } } }) }), + /mcp_servers\.time\.sampling\.enabled/, + ) + assert.throws( + () => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { allowed_models: 'gpt-5' } } }) }), + /mcp_servers\.time\.sampling\.allowed_models/, + ) + assert.throws( + () => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { max_tool_rounds: -1 } } }) }), + /mcp_servers\.time\.sampling\.max_tool_rounds/, + ) + assert.throws( + () => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { log_level: 'trace' } } }) }), + /mcp_servers\.time\.sampling\.log_level/, + ) })