Files
BackupX/server/internal/repository/node_repository.go
Awuqing 34106556be 修复: 后端审查发现的 5 项问题
根据 Spec + Code Quality 双审查修复:

1. BatchCreate 事务保护(node_service.go/node_repository.go)
   原循环 Create 在 DB 约束失败时会残留半截数据。改为预先构造所有 Node
   再走 repo.BatchCreate 单一事务,任一失败整体回滚。

2. Peek 语义与 Consume 对齐(agent_install_token_repository.go)
   FindByToken 无条件返回任意记录,导致已消费/已过期的僵尸 token
   可通过 compose 端点的 mode 检查但必然 Consume 失败,出现 410 假错。
   新增 FindValidByToken,Peek 改用之。

3. MasterURL / AgentToken / AgentVersion 渲染前校验(installscript/renderer.go)
   防止 YAML 注入(换行/引号逃逸 compose 配置)、shell 注入($(...))、
   非法字符。加 TestRenderScriptRejects* 系列测试覆盖。

4. ipLimiter 无界增长修复(install_handler.go)
   新增 gc 方法 + startGC 后台协程,每 window 周期清理过期 IP 条目。
   RouterDependencies.Context 控制生命周期;app 传入 ctx,测试 t.Cleanup 取消。

5. CreateInstallToken 的 CreatedByID 从 JWT subject 解析(node_handler.go)
   原硬编码 0 导致审计不可追溯。新增 resolveCurrentUserID helper,
   借助 UserRepository 把 JWT subject(用户名)→ user.ID;失败退回 0。
2026-04-19 17:14:17 +08:00

127 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"errors"
"time"
"backupx/server/internal/model"
"gorm.io/gorm"
)
type NodeRepository interface {
List(context.Context) ([]model.Node, error)
FindByID(context.Context, uint) (*model.Node, error)
FindByToken(context.Context, string) (*model.Node, error)
FindLocal(context.Context) (*model.Node, error)
Create(context.Context, *model.Node) error
// BatchCreate 在单一事务内批量创建节点,任一失败即全部回滚。
BatchCreate(ctx context.Context, nodes []*model.Node) error
Update(context.Context, *model.Node) error
Delete(context.Context, uint) error
MarkStaleOffline(ctx context.Context, threshold time.Time) (int64, error)
}
type GormNodeRepository struct {
db *gorm.DB
}
func NewNodeRepository(db *gorm.DB) *GormNodeRepository {
return &GormNodeRepository{db: db}
}
func (r *GormNodeRepository) List(ctx context.Context) ([]model.Node, error) {
var items []model.Node
if err := r.db.WithContext(ctx).Order("is_local desc, updated_at desc").Find(&items).Error; err != nil {
return nil, err
}
return items, nil
}
func (r *GormNodeRepository) FindByID(ctx context.Context, id uint) (*model.Node, error) {
var item model.Node
if err := r.db.WithContext(ctx).First(&item, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &item, nil
}
func (r *GormNodeRepository) FindByToken(ctx context.Context, token string) (*model.Node, error) {
var item model.Node
// 主 token 查询
err := r.db.WithContext(ctx).Where("token = ?", token).First(&item).Error
if err == nil {
return &item, nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
// 回退prev_token 且未过期
now := time.Now().UTC()
err = r.db.WithContext(ctx).
Where("prev_token = ? AND prev_token_expires IS NOT NULL AND prev_token_expires > ?", token, now).
First(&item).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &item, nil
}
func (r *GormNodeRepository) FindLocal(ctx context.Context) (*model.Node, error) {
var item model.Node
if err := r.db.WithContext(ctx).Where("is_local = ?", true).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &item, nil
}
func (r *GormNodeRepository) Create(ctx context.Context, item *model.Node) error {
return r.db.WithContext(ctx).Create(item).Error
}
// BatchCreate 在单一事务中批量创建节点。任一记录失败即事务回滚。
// 节点 ID 在事务提交后回填到入参切片元素上。
func (r *GormNodeRepository) BatchCreate(ctx context.Context, nodes []*model.Node) error {
if len(nodes) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, n := range nodes {
if err := tx.Create(n).Error; err != nil {
return err
}
}
return nil
})
}
func (r *GormNodeRepository) Update(ctx context.Context, item *model.Node) error {
return r.db.WithContext(ctx).Save(item).Error
}
func (r *GormNodeRepository) Delete(ctx context.Context, id uint) error {
return r.db.WithContext(ctx).Delete(&model.Node{}, id).Error
}
// MarkStaleOffline 把最近心跳早于 threshold 的在线远程节点标记为离线。
// 本机节点 (is_local=true) 不受影响,由主程序自己维护 online 状态。
// 返回受影响行数。
func (r *GormNodeRepository) MarkStaleOffline(ctx context.Context, threshold time.Time) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.Node{}).
Where("is_local = ? AND status = ? AND last_seen < ?", false, model.NodeStatusOnline, threshold).
Update("status", model.NodeStatusOffline)
if result.Error != nil {
return 0, result.Error
}
return result.RowsAffected, nil
}