Files
MyGoNavi/internal/mcpserver/service.go
Syngnat 6c53fb14a6 feat(ai-mcp): 补全跨表字段探针并拆分 MCP 设置区块
- 新增 get_all_columns 内置工具和 MCP server 只读探针
- MCP 服务表单增加实际启动命令预览并补强参数提示
- 抽离 AISettingsMCPSection 并补齐前后端测试
2026-06-08 07:29:52 +08:00

928 lines
29 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mcpserver
import (
"context"
"encoding/json"
"fmt"
"strings"
"GoNavi-Wails/internal/ai"
appcore "GoNavi-Wails/internal/app"
"GoNavi-Wails/internal/connection"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
const (
defaultMaxRowsPerResult = 200
maxRowsPerResultLimit = 1000
)
type Service struct {
backend Backend
}
func NewService(backend Backend) *Service {
return &Service{backend: backend}
}
type emptyArgs struct{}
type connectionIDArgs struct {
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
}
type databaseArgs struct {
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"`
}
type tableArgs struct {
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"`
TableName string `json:"tableName" jsonschema:"目标表或视图名称"`
}
type executeSQLArgs struct {
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"`
SQL string `json:"sql" jsonschema:"待执行的 SQL 文本,可以包含多条语句"`
AllowMutating bool `json:"allowMutating,omitempty" jsonschema:"当 SQL 包含当前 AI 安全控制允许范围内的 DDL/DML 等非只读语句时,必须显式设为 true"`
MaxRowsPerResult int `json:"maxRowsPerResult,omitempty" jsonschema:"每个结果集最多返回多少行。默认 200最大 1000"`
}
type connectionDescriptor struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
Database string `json:"database,omitempty"`
Driver string `json:"driver,omitempty"`
Topology string `json:"topology,omitempty"`
Target string `json:"target,omitempty"`
UseSSH bool `json:"useSSH,omitempty"`
UseProxy bool `json:"useProxy,omitempty"`
UseHTTPTunnel bool `json:"useHttpTunnel,omitempty"`
DefaultDatabase string `json:"defaultDatabase,omitempty"`
}
type getConnectionsResult struct {
Connections []connectionDescriptor `json:"connections"`
}
type getDatabasesResult struct {
ConnectionID string `json:"connectionId"`
Databases []string `json:"databases"`
}
type getTablesResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
Tables []string `json:"tables"`
}
type getAllColumnsResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
Columns []connection.ColumnDefinitionWithTable `json:"columns"`
}
type getColumnsResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName"`
Columns []connection.ColumnDefinition `json:"columns"`
}
type getIndexesResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName"`
Indexes []connection.IndexDefinition `json:"indexes"`
}
type getForeignKeysResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName"`
ForeignKeys []connection.ForeignKeyDefinition `json:"foreignKeys"`
}
type getTriggersResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName"`
Triggers []connection.TriggerDefinition `json:"triggers"`
}
type getTableDDLResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName"`
DDL string `json:"ddl"`
}
type sqlStatementSummary struct {
Index int `json:"index"`
Keyword string `json:"keyword,omitempty"`
ReadOnly bool `json:"readOnly"`
}
type sqlResultSet struct {
StatementIndex int `json:"statementIndex,omitempty"`
Columns []string `json:"columns"`
Rows []map[string]interface{} `json:"rows"`
Messages []string `json:"messages,omitempty"`
RowCount int `json:"rowCount"`
Truncated bool `json:"truncated,omitempty"`
}
type executeSQLResult struct {
ConnectionID string `json:"connectionId"`
DBName string `json:"dbName,omitempty"`
StatementCount int `json:"statementCount"`
ReadOnly bool `json:"readOnly"`
QueryID string `json:"queryId,omitempty"`
Message string `json:"message,omitempty"`
Truncated bool `json:"truncated,omitempty"`
Statements []sqlStatementSummary `json:"statements"`
Results []sqlResultSet `json:"results"`
}
func (s *Service) GetConnections(ctx context.Context, req *mcp.CallToolRequest, args emptyArgs) (*mcp.CallToolResult, getConnectionsResult, error) {
_ = ctx
_ = req
_ = args
items, err := s.backend.GetSavedConnections()
if err != nil {
return toolError("获取已保存连接失败: %v", err), getConnectionsResult{}, nil
}
result := getConnectionsResult{
Connections: make([]connectionDescriptor, 0, len(items)),
}
for _, item := range items {
cfg := item.Config
result.Connections = append(result.Connections, connectionDescriptor{
ID: item.ID,
Name: item.Name,
Type: strings.TrimSpace(cfg.Type),
Host: strings.TrimSpace(cfg.Host),
Port: cfg.Port,
Database: strings.TrimSpace(cfg.Database),
Driver: strings.TrimSpace(cfg.Driver),
Topology: strings.TrimSpace(cfg.Topology),
Target: describeConnectionTarget(cfg),
UseSSH: cfg.UseSSH,
UseProxy: cfg.UseProxy,
UseHTTPTunnel: cfg.UseHTTPTunnel,
DefaultDatabase: strings.TrimSpace(cfg.Database),
})
}
return successResult(), result, nil
}
func (s *Service) GetDatabases(ctx context.Context, req *mcp.CallToolRequest, args connectionIDArgs) (*mcp.CallToolResult, getDatabasesResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getDatabasesResult{}, nil
}
queryResult := s.backend.DBGetDatabases(view.Config)
if !queryResult.Success {
return toolError("获取数据库列表失败: %s", strings.TrimSpace(queryResult.Message)), getDatabasesResult{}, nil
}
databases, err := decodeNamedStringSlice(queryResult.Data, "Database", "database", "name")
if err != nil {
return toolError("解析数据库列表失败: %v", err), getDatabasesResult{}, nil
}
return successResult(), getDatabasesResult{
ConnectionID: view.ID,
Databases: ensureNonNilStrings(databases),
}, nil
}
func (s *Service) GetTables(ctx context.Context, req *mcp.CallToolRequest, args databaseArgs) (*mcp.CallToolResult, getTablesResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getTablesResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBGetTables(view.Config, dbName)
if !queryResult.Success {
return toolError("获取表列表失败: %s", strings.TrimSpace(queryResult.Message)), getTablesResult{}, nil
}
tables, err := decodeNamedStringSlice(queryResult.Data, "Table", "table", "name")
if err != nil {
return toolError("解析表列表失败: %v", err), getTablesResult{}, nil
}
return successResult(), getTablesResult{
ConnectionID: view.ID,
DBName: dbName,
Tables: ensureNonNilStrings(tables),
}, nil
}
func (s *Service) GetAllColumns(ctx context.Context, req *mcp.CallToolRequest, args databaseArgs) (*mcp.CallToolResult, getAllColumnsResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getAllColumnsResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
if strings.TrimSpace(dbName) == "" {
return toolError("dbName 不能为空"), getAllColumnsResult{}, nil
}
queryResult := s.backend.DBGetAllColumns(view.Config, dbName)
if !queryResult.Success {
return toolError("获取全库字段摘要失败: %s", strings.TrimSpace(queryResult.Message)), getAllColumnsResult{}, nil
}
columns, err := decodeColumnsWithTable(queryResult.Data)
if err != nil {
return toolError("解析全库字段摘要失败: %v", err), getAllColumnsResult{}, nil
}
return successResult(), getAllColumnsResult{
ConnectionID: view.ID,
DBName: dbName,
Columns: ensureNonNilColumnsWithTable(columns),
}, nil
}
func (s *Service) GetColumns(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getColumnsResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getColumnsResult{}, nil
}
tableName := strings.TrimSpace(args.TableName)
if tableName == "" {
return toolError("tableName 不能为空"), getColumnsResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBGetColumns(view.Config, dbName, tableName)
if !queryResult.Success {
return toolError("获取字段列表失败: %s", strings.TrimSpace(queryResult.Message)), getColumnsResult{}, nil
}
columns, err := decodeColumns(queryResult.Data)
if err != nil {
return toolError("解析字段列表失败: %v", err), getColumnsResult{}, nil
}
return successResult(), getColumnsResult{
ConnectionID: view.ID,
DBName: dbName,
TableName: tableName,
Columns: ensureNonNilColumns(columns),
}, nil
}
func (s *Service) GetIndexes(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getIndexesResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getIndexesResult{}, nil
}
tableName := strings.TrimSpace(args.TableName)
if tableName == "" {
return toolError("tableName 不能为空"), getIndexesResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBGetIndexes(view.Config, dbName, tableName)
if !queryResult.Success {
return toolError("获取索引定义失败: %s", strings.TrimSpace(queryResult.Message)), getIndexesResult{}, nil
}
indexes, err := decodeIndexes(queryResult.Data)
if err != nil {
return toolError("解析索引定义失败: %v", err), getIndexesResult{}, nil
}
return successResult(), getIndexesResult{
ConnectionID: view.ID,
DBName: dbName,
TableName: tableName,
Indexes: ensureNonNilIndexes(indexes),
}, nil
}
func (s *Service) GetForeignKeys(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getForeignKeysResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getForeignKeysResult{}, nil
}
tableName := strings.TrimSpace(args.TableName)
if tableName == "" {
return toolError("tableName 不能为空"), getForeignKeysResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBGetForeignKeys(view.Config, dbName, tableName)
if !queryResult.Success {
return toolError("获取外键关系失败: %s", strings.TrimSpace(queryResult.Message)), getForeignKeysResult{}, nil
}
foreignKeys, err := decodeForeignKeys(queryResult.Data)
if err != nil {
return toolError("解析外键关系失败: %v", err), getForeignKeysResult{}, nil
}
return successResult(), getForeignKeysResult{
ConnectionID: view.ID,
DBName: dbName,
TableName: tableName,
ForeignKeys: ensureNonNilForeignKeys(foreignKeys),
}, nil
}
func (s *Service) GetTriggers(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getTriggersResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getTriggersResult{}, nil
}
tableName := strings.TrimSpace(args.TableName)
if tableName == "" {
return toolError("tableName 不能为空"), getTriggersResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBGetTriggers(view.Config, dbName, tableName)
if !queryResult.Success {
return toolError("获取触发器定义失败: %s", strings.TrimSpace(queryResult.Message)), getTriggersResult{}, nil
}
triggers, err := decodeTriggers(queryResult.Data)
if err != nil {
return toolError("解析触发器定义失败: %v", err), getTriggersResult{}, nil
}
return successResult(), getTriggersResult{
ConnectionID: view.ID,
DBName: dbName,
TableName: tableName,
Triggers: ensureNonNilTriggers(triggers),
}, nil
}
func (s *Service) GetTableDDL(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getTableDDLResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, getTableDDLResult{}, nil
}
tableName := strings.TrimSpace(args.TableName)
if tableName == "" {
return toolError("tableName 不能为空"), getTableDDLResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBShowCreateTable(view.Config, dbName, tableName)
if !queryResult.Success {
return toolError("获取建表语句失败: %s", strings.TrimSpace(queryResult.Message)), getTableDDLResult{}, nil
}
ddl, err := decodeString(queryResult.Data)
if err != nil {
return toolError("解析建表语句失败: %v", err), getTableDDLResult{}, nil
}
return successResult(), getTableDDLResult{
ConnectionID: view.ID,
DBName: dbName,
TableName: tableName,
DDL: ddl,
}, nil
}
func (s *Service) ExecuteSQL(ctx context.Context, req *mcp.CallToolRequest, args executeSQLArgs) (*mcp.CallToolResult, executeSQLResult, error) {
_ = ctx
_ = req
view, errResult := s.resolveConnection(args.ConnectionID)
if errResult != nil {
return errResult, executeSQLResult{}, nil
}
sqlText := strings.TrimSpace(args.SQL)
if sqlText == "" {
return toolError("sql 不能为空"), executeSQLResult{}, nil
}
inspection := s.backend.InspectSQL(view.Config.Type, sqlText)
if inspection.StatementCount == 0 {
return toolError("未识别到可执行的 SQL 语句"), executeSQLResult{}, nil
}
safetyLevel := normalizeSQLSafetyLevel(s.backend.GetSQLSafetyLevel())
safetyDecision := evaluateSQLSafety(safetyLevel, inspection)
if len(safetyDecision.disallowed) > 0 {
return toolError("%s", buildSafetyDeniedMessage(safetyLevel, safetyDecision.disallowed)), executeSQLResult{}, nil
}
if safetyDecision.requiresConfirm && !args.AllowMutating {
return toolError("当前 SQL 已通过 GoNavi AI 安全控制(%s但包含非只读语句 %s请显式传入 allowMutating=true 后重试", safetyLevelDisplayName(safetyLevel), formatSafetyStatements(safetyDecision.confirmRequired)), executeSQLResult{}, nil
}
dbName := effectiveDBName(args.DBName, view.Config)
queryResult := s.backend.DBQueryMulti(view.Config, dbName, sqlText, "")
if !queryResult.Success {
return toolError("SQL 执行失败: %s", strings.TrimSpace(queryResult.Message)), executeSQLResult{}, nil
}
resultSets, err := decodeResultSets(queryResult.Data)
if err != nil {
return toolError("解析 SQL 执行结果失败: %v", err), executeSQLResult{}, nil
}
normalizedResults, truncated := normalizeResultSets(resultSets, normalizeMaxRowsPerResult(args.MaxRowsPerResult))
return successResult(), executeSQLResult{
ConnectionID: view.ID,
DBName: dbName,
StatementCount: inspection.StatementCount,
ReadOnly: inspection.ReadOnly,
QueryID: strings.TrimSpace(queryResult.QueryID),
Message: strings.TrimSpace(queryResult.Message),
Truncated: truncated,
Statements: toStatementSummaries(inspection.Statements),
Results: normalizedResults,
}, nil
}
func successResult() *mcp.CallToolResult {
return &mcp.CallToolResult{}
}
func toolError(format string, args ...interface{}) *mcp.CallToolResult {
return &mcp.CallToolResult{
IsError: true,
Content: []mcp.Content{
&mcp.TextContent{Text: fmt.Sprintf(format, args...)},
},
}
}
func (s *Service) resolveConnection(connectionID string) (connection.SavedConnectionView, *mcp.CallToolResult) {
id := strings.TrimSpace(connectionID)
if id == "" {
return connection.SavedConnectionView{}, toolError("connectionId 不能为空")
}
view, err := s.backend.GetEditableSavedConnection(id)
if err != nil {
return connection.SavedConnectionView{}, toolError("加载连接 %s 失败: %v", id, err)
}
return view, nil
}
func effectiveDBName(input string, config connection.ConnectionConfig) string {
if trimmed := strings.TrimSpace(input); trimmed != "" {
return trimmed
}
return strings.TrimSpace(config.Database)
}
func describeConnectionTarget(config connection.ConnectionConfig) string {
dbType := strings.ToLower(strings.TrimSpace(config.Type))
switch dbType {
case "sqlite", "duckdb":
if path := strings.TrimSpace(config.Database); path != "" {
return path
}
}
if len(config.Hosts) > 0 {
return strings.Join(config.Hosts, ",")
}
if host := strings.TrimSpace(config.Host); host != "" {
if config.Port > 0 {
return fmt.Sprintf("%s:%d", host, config.Port)
}
return host
}
if uri := strings.TrimSpace(config.URI); uri != "" {
return uri
}
if dsn := strings.TrimSpace(config.DSN); dsn != "" {
return dsn
}
return strings.TrimSpace(config.Database)
}
func decodeNamedStringSlice(data interface{}, keys ...string) ([]string, error) {
switch items := data.(type) {
case nil:
return []string{}, nil
case []string:
return ensureNonNilStrings(append([]string(nil), items...)), nil
case []map[string]string:
result := make([]string, 0, len(items))
for _, item := range items {
result = append(result, pickNamedStringFromStringMap(item, keys...))
}
return result, nil
case []map[string]interface{}:
result := make([]string, 0, len(items))
for _, item := range items {
result = append(result, pickNamedStringFromAnyMap(item, keys...))
}
return result, nil
default:
var decoded []map[string]interface{}
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return decodeNamedStringSlice(decoded, keys...)
}
}
func pickNamedStringFromStringMap(item map[string]string, keys ...string) string {
for _, key := range keys {
if value := strings.TrimSpace(item[key]); value != "" {
return value
}
}
for _, value := range item {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func pickNamedStringFromAnyMap(item map[string]interface{}, keys ...string) string {
for _, key := range keys {
if value, ok := item[key]; ok {
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
return text
}
}
}
for _, value := range item {
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
return text
}
}
return ""
}
func decodeColumns(data interface{}) ([]connection.ColumnDefinition, error) {
switch cols := data.(type) {
case nil:
return []connection.ColumnDefinition{}, nil
case []connection.ColumnDefinition:
return ensureNonNilColumns(append([]connection.ColumnDefinition(nil), cols...)), nil
default:
var decoded []connection.ColumnDefinition
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return ensureNonNilColumns(decoded), nil
}
}
func decodeColumnsWithTable(data interface{}) ([]connection.ColumnDefinitionWithTable, error) {
switch cols := data.(type) {
case nil:
return []connection.ColumnDefinitionWithTable{}, nil
case []connection.ColumnDefinitionWithTable:
return ensureNonNilColumnsWithTable(append([]connection.ColumnDefinitionWithTable(nil), cols...)), nil
default:
var decoded []connection.ColumnDefinitionWithTable
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return ensureNonNilColumnsWithTable(decoded), nil
}
}
func decodeIndexes(data interface{}) ([]connection.IndexDefinition, error) {
switch indexes := data.(type) {
case nil:
return []connection.IndexDefinition{}, nil
case []connection.IndexDefinition:
return ensureNonNilIndexes(append([]connection.IndexDefinition(nil), indexes...)), nil
default:
var decoded []connection.IndexDefinition
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return ensureNonNilIndexes(decoded), nil
}
}
func decodeForeignKeys(data interface{}) ([]connection.ForeignKeyDefinition, error) {
switch foreignKeys := data.(type) {
case nil:
return []connection.ForeignKeyDefinition{}, nil
case []connection.ForeignKeyDefinition:
return ensureNonNilForeignKeys(append([]connection.ForeignKeyDefinition(nil), foreignKeys...)), nil
default:
var decoded []connection.ForeignKeyDefinition
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return ensureNonNilForeignKeys(decoded), nil
}
}
func decodeTriggers(data interface{}) ([]connection.TriggerDefinition, error) {
switch triggers := data.(type) {
case nil:
return []connection.TriggerDefinition{}, nil
case []connection.TriggerDefinition:
return ensureNonNilTriggers(append([]connection.TriggerDefinition(nil), triggers...)), nil
default:
var decoded []connection.TriggerDefinition
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return ensureNonNilTriggers(decoded), nil
}
}
func decodeString(data interface{}) (string, error) {
switch value := data.(type) {
case nil:
return "", nil
case string:
return value, nil
default:
return fmt.Sprint(value), nil
}
}
func decodeResultSets(data interface{}) ([]connection.ResultSetData, error) {
switch items := data.(type) {
case nil:
return []connection.ResultSetData{}, nil
case []connection.ResultSetData:
return ensureNonNilResultSets(append([]connection.ResultSetData(nil), items...)), nil
default:
var decoded []connection.ResultSetData
if err := remarshal(data, &decoded); err != nil {
return nil, err
}
return ensureNonNilResultSets(decoded), nil
}
}
func remarshal(from interface{}, to interface{}) error {
payload, err := json.Marshal(from)
if err != nil {
return err
}
return json.Unmarshal(payload, to)
}
func normalizeMaxRowsPerResult(input int) int {
if input <= 0 {
return defaultMaxRowsPerResult
}
if input > maxRowsPerResultLimit {
return maxRowsPerResultLimit
}
return input
}
func normalizeResultSets(resultSets []connection.ResultSetData, maxRows int) ([]sqlResultSet, bool) {
normalized := make([]sqlResultSet, 0, len(resultSets))
truncatedAny := false
for _, resultSet := range resultSets {
rows := ensureNonNilRows(resultSet.Rows)
rowCount := len(rows)
truncated := false
if maxRows > 0 && len(rows) > maxRows {
rows = append([]map[string]interface{}(nil), rows[:maxRows]...)
truncated = true
truncatedAny = true
}
normalized = append(normalized, sqlResultSet{
StatementIndex: resultSet.StatementIndex,
Columns: ensureNonNilStrings(append([]string(nil), resultSet.Columns...)),
Rows: rows,
Messages: ensureNonNilStrings(append([]string(nil), resultSet.Messages...)),
RowCount: rowCount,
Truncated: truncated,
})
}
return normalized, truncatedAny
}
func toStatementSummaries(items []appcore.SQLStatementInspection) []sqlStatementSummary {
result := make([]sqlStatementSummary, 0, len(items))
for _, item := range items {
result = append(result, sqlStatementSummary{
Index: item.Index,
Keyword: item.Keyword,
ReadOnly: item.ReadOnly,
})
}
return result
}
func ensureNonNilStrings(items []string) []string {
if items == nil {
return []string{}
}
return items
}
func ensureNonNilColumns(items []connection.ColumnDefinition) []connection.ColumnDefinition {
if items == nil {
return []connection.ColumnDefinition{}
}
return items
}
func ensureNonNilColumnsWithTable(items []connection.ColumnDefinitionWithTable) []connection.ColumnDefinitionWithTable {
if items == nil {
return []connection.ColumnDefinitionWithTable{}
}
return items
}
func ensureNonNilIndexes(items []connection.IndexDefinition) []connection.IndexDefinition {
if items == nil {
return []connection.IndexDefinition{}
}
return items
}
func ensureNonNilForeignKeys(items []connection.ForeignKeyDefinition) []connection.ForeignKeyDefinition {
if items == nil {
return []connection.ForeignKeyDefinition{}
}
return items
}
func ensureNonNilTriggers(items []connection.TriggerDefinition) []connection.TriggerDefinition {
if items == nil {
return []connection.TriggerDefinition{}
}
return items
}
func ensureNonNilRows(items []map[string]interface{}) []map[string]interface{} {
if items == nil {
return []map[string]interface{}{}
}
return items
}
func ensureNonNilResultSets(items []connection.ResultSetData) []connection.ResultSetData {
if items == nil {
return []connection.ResultSetData{}
}
return items
}
type sqlSafetyStatement struct {
Index int
Keyword string
OperationType ai.SQLOperationType
}
type sqlSafetyDecision struct {
requiresConfirm bool
disallowed []sqlSafetyStatement
confirmRequired []sqlSafetyStatement
}
func evaluateSQLSafety(level ai.SQLPermissionLevel, inspection appcore.SQLInspection) sqlSafetyDecision {
decision := sqlSafetyDecision{
disallowed: []sqlSafetyStatement{},
confirmRequired: []sqlSafetyStatement{},
}
for _, stmt := range inspection.Statements {
statement := sqlSafetyStatement{
Index: stmt.Index,
Keyword: strings.TrimSpace(stmt.Keyword),
OperationType: classifyStatementOperation(stmt),
}
if !isOperationAllowed(level, statement.OperationType) {
decision.disallowed = append(decision.disallowed, statement)
continue
}
if statement.OperationType != ai.SQLOpQuery {
decision.requiresConfirm = true
decision.confirmRequired = append(decision.confirmRequired, statement)
}
}
return decision
}
func classifyStatementOperation(stmt appcore.SQLStatementInspection) ai.SQLOperationType {
if stmt.ReadOnly {
return ai.SQLOpQuery
}
switch strings.ToLower(strings.TrimSpace(stmt.Keyword)) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return ai.SQLOpDML
case "create", "alter", "drop", "truncate", "rename":
return ai.SQLOpDDL
default:
return ai.SQLOpOther
}
}
func isOperationAllowed(level ai.SQLPermissionLevel, opType ai.SQLOperationType) bool {
switch normalizeSQLSafetyLevel(level) {
case ai.PermissionReadOnly:
return opType == ai.SQLOpQuery
case ai.PermissionReadWrite:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
case ai.PermissionFull:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
default:
return opType == ai.SQLOpQuery
}
}
func normalizeSQLSafetyLevel(level ai.SQLPermissionLevel) ai.SQLPermissionLevel {
switch level {
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
return level
default:
return ai.PermissionReadOnly
}
}
func buildSafetyDeniedMessage(level ai.SQLPermissionLevel, statements []sqlSafetyStatement) string {
return fmt.Sprintf("当前 GoNavi AI 安全控制为%s已阻止以下语句%s。%s", safetyLevelDisplayName(level), formatSafetyStatements(statements), safetyLevelRuleText(level))
}
func safetyLevelDisplayName(level ai.SQLPermissionLevel) string {
switch normalizeSQLSafetyLevel(level) {
case ai.PermissionReadOnly:
return "只读模式"
case ai.PermissionReadWrite:
return "读写模式"
case ai.PermissionFull:
return "完全模式"
default:
return "只读模式"
}
}
func safetyLevelRuleText(level ai.SQLPermissionLevel) string {
switch normalizeSQLSafetyLevel(level) {
case ai.PermissionReadOnly:
return "只读模式仅允许查询语句。"
case ai.PermissionReadWrite:
return "读写模式仅允许查询和 DML 语句。"
case ai.PermissionFull:
return "完全模式仅允许查询、DML 和 DDL未识别操作仍会被阻止。"
default:
return "只读模式仅允许查询语句。"
}
}
func formatSafetyStatements(statements []sqlSafetyStatement) string {
parts := make([]string, 0, len(statements))
for _, stmt := range statements {
keyword := strings.TrimSpace(stmt.Keyword)
if keyword == "" {
keyword = "unknown"
}
parts = append(parts, fmt.Sprintf("#%d %s(%s)", stmt.Index, strings.ToLower(keyword), strings.ToUpper(string(stmt.OperationType))))
}
return strings.Join(parts, "")
}