fix(hermes): validate mcp sampling config

This commit is contained in:
晴天
2026-05-26 05:39:19 +08:00
parent 66375d2807
commit 51be3ab4ca
3 changed files with 246 additions and 1 deletions

View File

@@ -4513,6 +4513,56 @@ function normalizeHermesMcpTimeout(entry, field, key) {
entry[field] = parseHermesInteger(entry[field], key, 120, 1, 86400, true)
}
function normalizeHermesMcpSampling(value, key) {
if (!value || typeof value !== 'object' || Array.isArray(value)) {
throw new Error(`${key} 必须是 JSON 对象`)
}
const sampling = mergeConfigsPreservingFields(value, {})
if (Object.hasOwn(sampling, 'enabled')) {
if (typeof sampling.enabled !== 'boolean') throw new Error(`${key}.enabled 必须是布尔值`)
}
if (Object.hasOwn(sampling, 'model')) {
if (sampling.model == null || sampling.model === '') {
delete sampling.model
} else if (typeof sampling.model !== 'string') {
throw new Error(`${key}.model 必须是字符串`)
} else {
const model = sampling.model.trim()
if (model) sampling.model = model
else delete sampling.model
}
}
if (Object.hasOwn(sampling, 'max_tokens_cap')) {
sampling.max_tokens_cap = parseHermesInteger(sampling.max_tokens_cap, `${key}.max_tokens_cap`, 4096, 1, 1000000, true)
}
if (Object.hasOwn(sampling, 'timeout')) {
sampling.timeout = parseHermesInteger(sampling.timeout, `${key}.timeout`, 30, 1, 86400, true)
}
if (Object.hasOwn(sampling, 'max_rpm')) {
sampling.max_rpm = parseHermesInteger(sampling.max_rpm, `${key}.max_rpm`, 10, 1, 100000, true)
}
if (Object.hasOwn(sampling, 'allowed_models')) {
sampling.allowed_models = normalizeHermesStringArray(sampling.allowed_models, `${key}.allowed_models`)
}
if (Object.hasOwn(sampling, 'max_tool_rounds')) {
sampling.max_tool_rounds = parseHermesInteger(sampling.max_tool_rounds, `${key}.max_tool_rounds`, 5, 0, 1000, true)
}
if (Object.hasOwn(sampling, 'log_level')) {
if (sampling.log_level == null || sampling.log_level === '') {
delete sampling.log_level
} else if (typeof sampling.log_level !== 'string') {
throw new Error(`${key}.log_level 必须是字符串`)
} else {
const level = sampling.log_level.trim().toLowerCase()
if (!['debug', 'info', 'warning', 'error'].includes(level)) {
throw new Error(`${key}.log_level 必须是 debug、info、warning 或 error`)
}
sampling.log_level = level
}
}
return sampling
}
function validateHermesMcpServers(value) {
if (!value || typeof value !== 'object' || Array.isArray(value)) {
throw new Error('mcp_servers 必须是 JSON 对象')
@@ -4540,6 +4590,7 @@ function validateHermesMcpServers(value) {
if (Object.hasOwn(entry, 'headers')) entry.headers = normalizeHermesStringMap(entry.headers, `mcp_servers.${name}.headers`)
normalizeHermesMcpTimeout(entry, 'timeout', `mcp_servers.${name}.timeout`)
normalizeHermesMcpTimeout(entry, 'connect_timeout', `mcp_servers.${name}.connect_timeout`)
if (Object.hasOwn(entry, 'sampling')) entry.sampling = normalizeHermesMcpSampling(entry.sampling, `mcp_servers.${name}.sampling`)
normalized[name] = entry
}
return normalized

View File

@@ -4860,6 +4860,84 @@ fn normalize_hermes_mcp_timeout(
Ok(())
}
fn normalize_hermes_mcp_sampling(value: &Value, key: &str) -> Result<Value, String> {
let Some(config) = value.as_object() else {
return Err(format!("{key} 必须是 JSON 对象"));
};
let mut sampling = config.clone();
if let Some(enabled) = sampling.get("enabled") {
if !enabled.is_boolean() {
return Err(format!("{key}.enabled 必须是布尔值"));
}
}
if sampling.contains_key("model") {
let empty = sampling.get("model").is_some_and(|value| {
value.is_null() || value.as_str().is_some_and(|text| text.trim().is_empty())
});
if empty {
sampling.remove("model");
} else {
let Some(model) = sampling.get("model").and_then(|value| value.as_str()) else {
return Err(format!("{key}.model 必须是字符串"));
};
sampling.insert("model".to_string(), Value::String(model.trim().to_string()));
}
}
for (field, fallback, min, max) in [
("max_tokens_cap", 4096, 1, 1_000_000),
("timeout", 30, 1, 86400),
("max_rpm", 10, 1, 100000),
("max_tool_rounds", 5, 0, 1000),
] {
if let Some(raw) = sampling.get(field).cloned() {
let parsed = if let Some(value) = raw.as_i64() {
Some(value)
} else if let Some(value) = raw.as_u64() {
i64::try_from(value).ok()
} else if let Some(value) = raw.as_str() {
value.trim().parse::<i64>().ok()
} else {
None
};
let parsed = parsed.ok_or_else(|| format!("{key}.{field} 必须是整数"))?;
let parsed =
validate_hermes_i64(Some(parsed), &format!("{key}.{field}"), fallback, min, max)?;
sampling.insert(field.to_string(), Value::Number(parsed.into()));
}
}
if let Some(allowed_models) = sampling.get("allowed_models") {
let allowed_models =
normalize_hermes_json_string_array(allowed_models, &format!("{key}.allowed_models"))?;
sampling.insert("allowed_models".to_string(), Value::Array(allowed_models));
}
if sampling.contains_key("log_level") {
let empty = sampling.get("log_level").is_some_and(|value| {
value.is_null() || value.as_str().is_some_and(|text| text.trim().is_empty())
});
if empty {
sampling.remove("log_level");
} else {
let Some(level) = sampling.get("log_level").and_then(|value| value.as_str()) else {
return Err(format!("{key}.log_level 必须是字符串"));
};
let level = level.trim().to_ascii_lowercase();
if !matches!(level.as_str(), "debug" | "info" | "warning" | "error") {
return Err(format!(
"{key}.log_level 必须是 debug、info、warning 或 error"
));
}
sampling.insert("log_level".to_string(), Value::String(level));
}
}
Ok(Value::Object(sampling))
}
fn validate_hermes_mcp_servers(value: &Value) -> Result<serde_json::Map<String, Value>, String> {
let Some(map) = value.as_object() else {
return Err("mcp_servers 必须是 JSON 对象".to_string());
@@ -4932,6 +5010,11 @@ fn validate_hermes_mcp_servers(value: &Value) -> Result<serde_json::Map<String,
"connect_timeout",
&format!("mcp_servers.{name}.connect_timeout"),
)?;
if let Some(sampling) = entry.get("sampling").cloned() {
let sampling =
normalize_hermes_mcp_sampling(&sampling, &format!("mcp_servers.{name}.sampling"))?;
entry.insert("sampling".to_string(), sampling);
}
normalized.insert(name.to_string(), Value::Object(entry));
}
Ok(normalized)
@@ -17220,7 +17303,14 @@ memory:
"timeout": 120,
"sampling": {
"enabled": true,
"model": "gemini-3-flash"
"model": "gemini-3-flash",
"max_tokens_cap": 4096,
"timeout": 30,
"max_rpm": 10,
"allowed_models": ["gemini-3-flash", "gpt-5-mini"],
"max_tool_rounds": 5,
"log_level": "info",
"custom_flag": "keep-sampling"
}
},
"notion": {
@@ -17250,6 +17340,34 @@ memory:
config["mcp_servers"]["time"]["sampling"]["model"].as_str(),
Some("gemini-3-flash")
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["max_tokens_cap"].as_i64(),
Some(4096)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["timeout"].as_i64(),
Some(30)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["max_rpm"].as_i64(),
Some(10)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["allowed_models"][1].as_str(),
Some("gpt-5-mini")
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["max_tool_rounds"].as_i64(),
Some(5)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["log_level"].as_str(),
Some("info")
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["custom_flag"].as_str(),
Some("keep-sampling")
);
assert_eq!(
config["mcp_servers"]["notion"]["headers"]["Authorization"].as_str(),
Some("Bearer token")
@@ -17327,6 +17445,41 @@ streaming:
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.timeout"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": [] } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "enabled": "yes" } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.enabled"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "allowed_models": "gpt-5" } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.allowed_models"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "max_tool_rounds": -1 } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.max_tool_rounds"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "log_level": "trace" } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.log_level"));
}
}

View File

@@ -49,6 +49,13 @@ test('Hermes MCP 服务配置保存会保留未知字段并写入 mcp_servers',
sampling: {
enabled: true,
model: 'gemini-3-flash',
max_tokens_cap: 4096,
timeout: 30,
max_rpm: 10,
allowed_models: ['gemini-3-flash', 'gpt-5-mini'],
max_tool_rounds: 5,
log_level: 'info',
custom_flag: 'keep-sampling',
},
},
},
@@ -62,6 +69,13 @@ test('Hermes MCP 服务配置保存会保留未知字段并写入 mcp_servers',
sampling: {
enabled: true,
model: 'gemini-3-flash',
max_tokens_cap: 4096,
timeout: 30,
max_rpm: 10,
allowed_models: ['gemini-3-flash', 'gpt-5-mini'],
max_tool_rounds: 5,
log_level: 'info',
custom_flag: 'keep-sampling',
},
},
notion: {
@@ -81,6 +95,13 @@ test('Hermes MCP 服务配置保存会保留未知字段并写入 mcp_servers',
assert.equal(next.mcp_servers.time.timeout, 120)
assert.equal(next.mcp_servers.time.sampling.enabled, true)
assert.equal(next.mcp_servers.time.sampling.model, 'gemini-3-flash')
assert.equal(next.mcp_servers.time.sampling.max_tokens_cap, 4096)
assert.equal(next.mcp_servers.time.sampling.timeout, 30)
assert.equal(next.mcp_servers.time.sampling.max_rpm, 10)
assert.deepEqual(next.mcp_servers.time.sampling.allowed_models, ['gemini-3-flash', 'gpt-5-mini'])
assert.equal(next.mcp_servers.time.sampling.max_tool_rounds, 5)
assert.equal(next.mcp_servers.time.sampling.log_level, 'info')
assert.equal(next.mcp_servers.time.sampling.custom_flag, 'keep-sampling')
assert.equal(next.mcp_servers.notion.url, 'https://mcp.notion.com/mcp')
assert.equal(next.mcp_servers.notion.headers.Authorization, 'Bearer token')
assert.equal(next.mcp_servers.notion.connect_timeout, 30)
@@ -129,4 +150,24 @@ test('Hermes MCP 服务配置保存会拒绝非法 JSON、名称、结构和超
() => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', timeout: 0 } }) }),
/mcp_servers\.time\.timeout/,
)
assert.throws(
() => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: [] } }) }),
/mcp_servers\.time\.sampling/,
)
assert.throws(
() => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { enabled: 'yes' } } }) }),
/mcp_servers\.time\.sampling\.enabled/,
)
assert.throws(
() => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { allowed_models: 'gpt-5' } } }) }),
/mcp_servers\.time\.sampling\.allowed_models/,
)
assert.throws(
() => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { max_tool_rounds: -1 } } }) }),
/mcp_servers\.time\.sampling\.max_tool_rounds/,
)
assert.throws(
() => mergeHermesMcpServersConfig({}, { mcpServersJson: JSON.stringify({ time: { command: 'uvx', sampling: { log_level: 'trace' } } }) }),
/mcp_servers\.time\.sampling\.log_level/,
)
})