mirror of
https://github.com/Awuqing/BackupX.git
synced 2026-05-31 07:59:34 +08:00
根据 Spec + Code Quality 双审查修复: 1. BatchCreate 事务保护(node_service.go/node_repository.go) 原循环 Create 在 DB 约束失败时会残留半截数据。改为预先构造所有 Node 再走 repo.BatchCreate 单一事务,任一失败整体回滚。 2. Peek 语义与 Consume 对齐(agent_install_token_repository.go) FindByToken 无条件返回任意记录,导致已消费/已过期的僵尸 token 可通过 compose 端点的 mode 检查但必然 Consume 失败,出现 410 假错。 新增 FindValidByToken,Peek 改用之。 3. MasterURL / AgentToken / AgentVersion 渲染前校验(installscript/renderer.go) 防止 YAML 注入(换行/引号逃逸 compose 配置)、shell 注入($(...))、 非法字符。加 TestRenderScriptRejects* 系列测试覆盖。 4. ipLimiter 无界增长修复(install_handler.go) 新增 gc 方法 + startGC 后台协程,每 window 周期清理过期 IP 条目。 RouterDependencies.Context 控制生命周期;app 传入 ctx,测试 t.Cleanup 取消。 5. CreateInstallToken 的 CreatedByID 从 JWT subject 解析(node_handler.go) 原硬编码 0 导致审计不可追溯。新增 resolveCurrentUserID helper, 借助 UserRepository 把 JWT subject(用户名)→ user.ID;失败退回 0。
127 lines
3.8 KiB
Go
127 lines
3.8 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"time"
|
||
|
||
"backupx/server/internal/model"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type NodeRepository interface {
|
||
List(context.Context) ([]model.Node, error)
|
||
FindByID(context.Context, uint) (*model.Node, error)
|
||
FindByToken(context.Context, string) (*model.Node, error)
|
||
FindLocal(context.Context) (*model.Node, error)
|
||
Create(context.Context, *model.Node) error
|
||
// BatchCreate 在单一事务内批量创建节点,任一失败即全部回滚。
|
||
BatchCreate(ctx context.Context, nodes []*model.Node) error
|
||
Update(context.Context, *model.Node) error
|
||
Delete(context.Context, uint) error
|
||
MarkStaleOffline(ctx context.Context, threshold time.Time) (int64, error)
|
||
}
|
||
|
||
type GormNodeRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
func NewNodeRepository(db *gorm.DB) *GormNodeRepository {
|
||
return &GormNodeRepository{db: db}
|
||
}
|
||
|
||
func (r *GormNodeRepository) List(ctx context.Context) ([]model.Node, error) {
|
||
var items []model.Node
|
||
if err := r.db.WithContext(ctx).Order("is_local desc, updated_at desc").Find(&items).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return items, nil
|
||
}
|
||
|
||
func (r *GormNodeRepository) FindByID(ctx context.Context, id uint) (*model.Node, error) {
|
||
var item model.Node
|
||
if err := r.db.WithContext(ctx).First(&item, id).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
return &item, nil
|
||
}
|
||
|
||
func (r *GormNodeRepository) FindByToken(ctx context.Context, token string) (*model.Node, error) {
|
||
var item model.Node
|
||
// 主 token 查询
|
||
err := r.db.WithContext(ctx).Where("token = ?", token).First(&item).Error
|
||
if err == nil {
|
||
return &item, nil
|
||
}
|
||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, err
|
||
}
|
||
// 回退:prev_token 且未过期
|
||
now := time.Now().UTC()
|
||
err = r.db.WithContext(ctx).
|
||
Where("prev_token = ? AND prev_token_expires IS NOT NULL AND prev_token_expires > ?", token, now).
|
||
First(&item).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
return &item, nil
|
||
}
|
||
|
||
func (r *GormNodeRepository) FindLocal(ctx context.Context) (*model.Node, error) {
|
||
var item model.Node
|
||
if err := r.db.WithContext(ctx).Where("is_local = ?", true).First(&item).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
return &item, nil
|
||
}
|
||
|
||
func (r *GormNodeRepository) Create(ctx context.Context, item *model.Node) error {
|
||
return r.db.WithContext(ctx).Create(item).Error
|
||
}
|
||
|
||
// BatchCreate 在单一事务中批量创建节点。任一记录失败即事务回滚。
|
||
// 节点 ID 在事务提交后回填到入参切片元素上。
|
||
func (r *GormNodeRepository) BatchCreate(ctx context.Context, nodes []*model.Node) error {
|
||
if len(nodes) == 0 {
|
||
return nil
|
||
}
|
||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
for _, n := range nodes {
|
||
if err := tx.Create(n).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (r *GormNodeRepository) Update(ctx context.Context, item *model.Node) error {
|
||
return r.db.WithContext(ctx).Save(item).Error
|
||
}
|
||
|
||
func (r *GormNodeRepository) Delete(ctx context.Context, id uint) error {
|
||
return r.db.WithContext(ctx).Delete(&model.Node{}, id).Error
|
||
}
|
||
|
||
// MarkStaleOffline 把最近心跳早于 threshold 的在线远程节点标记为离线。
|
||
// 本机节点 (is_local=true) 不受影响,由主程序自己维护 online 状态。
|
||
// 返回受影响行数。
|
||
func (r *GormNodeRepository) MarkStaleOffline(ctx context.Context, threshold time.Time) (int64, error) {
|
||
result := r.db.WithContext(ctx).Model(&model.Node{}).
|
||
Where("is_local = ? AND status = ? AND last_seen < ?", false, model.NodeStatusOnline, threshold).
|
||
Update("status", model.NodeStatusOffline)
|
||
if result.Error != nil {
|
||
return 0, result.Error
|
||
}
|
||
return result.RowsAffected, nil
|
||
}
|