fix(hermes): validate mcp sampling config

This commit is contained in:
晴天
2026-05-26 05:39:19 +08:00
parent 66375d2807
commit 51be3ab4ca
3 changed files with 246 additions and 1 deletions

View File

@@ -4860,6 +4860,84 @@ fn normalize_hermes_mcp_timeout(
Ok(())
}
/// Validates and normalizes the `sampling` object of an MCP server entry.
///
/// `value` is the raw JSON value found under the `sampling` key and `key` is
/// the dotted path used as the prefix of every error message (the visible
/// caller passes `mcp_servers.<name>.sampling`).
///
/// The input object is cloned and only the recognized keys are touched, so
/// any other key is carried over to the result unchanged:
///
/// * `enabled` - must be a boolean when present.
/// * `model` - `null` or a blank string drops the key; otherwise it must be
///   a string and is stored trimmed.
/// * `max_tokens_cap`, `timeout`, `max_rpm`, `max_tool_rounds` - accepted as
///   a JSON integer or as a string that parses to an `i64`, then passed to
///   `validate_hermes_i64` together with the fallback/min/max listed below
///   and stored back as a JSON number.
/// * `allowed_models` - passed through `normalize_hermes_json_string_array`.
/// * `log_level` - `null` or a blank string drops the key; otherwise it must
///   be a string and, once trimmed and ASCII-lowercased, one of `debug`,
///   `info`, `warning` or `error`.
///
/// # Errors
///
/// Returns a message prefixed with `key` when `value` is not a JSON object
/// or when any of the fields above fails its check. Fields are checked in
/// the order listed, and the first failure is returned.
fn normalize_hermes_mcp_sampling(value: &Value, key: &str) -> Result<Value, String> {
    let mut normalized = match value.as_object() {
        Some(object) => object.clone(),
        None => return Err(format!("{key} 必须是 JSON 对象")),
    };

    if normalized
        .get("enabled")
        .is_some_and(|flag| !flag.is_boolean())
    {
        return Err(format!("{key}.enabled 必须是布尔值"));
    }

    if let Some(raw_model) = normalized.get("model") {
        // Classify first into an owned value so the map can be mutated afterwards.
        let trimmed = match raw_model {
            Value::Null => None,
            Value::String(text) => Some(text.trim().to_string()),
            _ => return Err(format!("{key}.model 必须是字符串")),
        };
        match trimmed.filter(|name| !name.is_empty()) {
            Some(name) => {
                normalized.insert("model".to_string(), Value::String(name));
            }
            None => {
                normalized.remove("model");
            }
        }
    }

    // (field, fallback, min, max) forwarded to `validate_hermes_i64`.
    for (field, fallback, min, max) in [
        ("max_tokens_cap", 4096, 1, 1_000_000),
        ("timeout", 30, 1, 86400),
        ("max_rpm", 10, 1, 100000),
        ("max_tool_rounds", 5, 0, 1000),
    ] {
        let Some(raw) = normalized.get(field) else {
            continue;
        };
        // Signed integer first, then an unsigned one that fits in i64, then a numeric string.
        let candidate = raw
            .as_i64()
            .or_else(|| raw.as_u64().and_then(|unsigned| i64::try_from(unsigned).ok()))
            .or_else(|| raw.as_str().and_then(|text| text.trim().parse::<i64>().ok()));
        let Some(number) = candidate else {
            return Err(format!("{key}.{field} 必须是整数"));
        };
        let checked =
            validate_hermes_i64(Some(number), &format!("{key}.{field}"), fallback, min, max)?;
        normalized.insert(field.to_string(), Value::Number(checked.into()));
    }

    if let Some(raw_models) = normalized.get("allowed_models") {
        let models =
            normalize_hermes_json_string_array(raw_models, &format!("{key}.allowed_models"))?;
        normalized.insert("allowed_models".to_string(), Value::Array(models));
    }

    if let Some(raw_level) = normalized.get("log_level") {
        let lowered = match raw_level {
            Value::Null => None,
            Value::String(text) => Some(text.trim().to_ascii_lowercase()),
            _ => return Err(format!("{key}.log_level 必须是字符串")),
        };
        match lowered.filter(|name| !name.is_empty()) {
            None => {
                normalized.remove("log_level");
            }
            Some(name) => {
                let known = ["debug", "info", "warning", "error"].contains(&name.as_str());
                if !known {
                    return Err(format!(
                        "{key}.log_level 必须是 debug、info、warning 或 error"
                    ));
                }
                normalized.insert("log_level".to_string(), Value::String(name));
            }
        }
    }

    Ok(Value::Object(normalized))
}
fn validate_hermes_mcp_servers(value: &Value) -> Result<serde_json::Map<String, Value>, String> {
let Some(map) = value.as_object() else {
return Err("mcp_servers 必须是 JSON 对象".to_string());
@@ -4932,6 +5010,11 @@ fn validate_hermes_mcp_servers(value: &Value) -> Result<serde_json::Map<String,
"connect_timeout",
&format!("mcp_servers.{name}.connect_timeout"),
)?;
if let Some(sampling) = entry.get("sampling").cloned() {
let sampling =
normalize_hermes_mcp_sampling(&sampling, &format!("mcp_servers.{name}.sampling"))?;
entry.insert("sampling".to_string(), sampling);
}
normalized.insert(name.to_string(), Value::Object(entry));
}
Ok(normalized)
@@ -17220,7 +17303,14 @@ memory:
"timeout": 120,
"sampling": {
"enabled": true,
"model": "gemini-3-flash"
"model": "gemini-3-flash",
"max_tokens_cap": 4096,
"timeout": 30,
"max_rpm": 10,
"allowed_models": ["gemini-3-flash", "gpt-5-mini"],
"max_tool_rounds": 5,
"log_level": "info",
"custom_flag": "keep-sampling"
}
},
"notion": {
@@ -17250,6 +17340,34 @@ memory:
config["mcp_servers"]["time"]["sampling"]["model"].as_str(),
Some("gemini-3-flash")
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["max_tokens_cap"].as_i64(),
Some(4096)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["timeout"].as_i64(),
Some(30)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["max_rpm"].as_i64(),
Some(10)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["allowed_models"][1].as_str(),
Some("gpt-5-mini")
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["max_tool_rounds"].as_i64(),
Some(5)
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["log_level"].as_str(),
Some("info")
);
assert_eq!(
config["mcp_servers"]["time"]["sampling"]["custom_flag"].as_str(),
Some("keep-sampling")
);
assert_eq!(
config["mcp_servers"]["notion"]["headers"]["Authorization"].as_str(),
Some("Bearer token")
@@ -17327,6 +17445,41 @@ streaming:
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.timeout"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": [] } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "enabled": "yes" } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.enabled"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "allowed_models": "gpt-5" } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.allowed_models"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "max_tool_rounds": -1 } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.max_tool_rounds"));
let err = merge_hermes_mcp_servers_config(
&mut config,
&json!({ "mcpServersJson": serde_json::to_string(&json!({ "time": { "command": "uvx", "sampling": { "log_level": "trace" } } })).unwrap() }),
)
.unwrap_err();
assert!(err.contains("mcp_servers.time.sampling.log_level"));
}
}