feat: add complete MFA support

Add complete MFA support with TOTP, recovery codes, WebAuthn, trusted-device cookie flow, and email/SMS OTP delivery via notification channels. Security follow-up: trusted device tokens are stored in HttpOnly cookies, and SMS OTP reuses the existing Webhook notifier to avoid introducing a new dynamic URL sink.
This commit is contained in:
Wu Qing
2026-04-25 22:14:50 +08:00
committed by GitHub
parent 1715abfcfb
commit 5af5f97efb
47 changed files with 5718 additions and 378 deletions

View File

@@ -60,9 +60,9 @@ func New(ctx context.Context, cfg config.Config, version string) (*Application,
jwtManager := security.NewJWTManager(resolvedSecurity.JWTSecret, config.MustJWTDuration(cfg.Security))
rateLimiter := security.NewLoginRateLimiter(5, time.Minute)
authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, rateLimiter)
systemService := service.NewSystemService(cfg, version, time.Now().UTC())
configCipher := codec.NewConfigCipher(resolvedSecurity.EncryptionKey)
authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, rateLimiter, configCipher)
systemService := service.NewSystemService(cfg, version, time.Now().UTC())
storageRegistry := storage.NewRegistry(
storageRclone.NewLocalDiskFactory(),
storageRclone.NewS3Factory(),
@@ -87,6 +87,7 @@ func New(ctx context.Context, cfg config.Config, version string) (*Application,
retentionService := backupretention.NewService(backupRecordRepo)
notifyRegistry := notify.NewRegistry(notify.NewEmailNotifier(), notify.NewWebhookNotifier(), notify.NewTelegramNotifier())
notificationService := service.NewNotificationService(notificationRepo, notifyRegistry, configCipher)
authService.SetNotificationService(notificationService)
// 初始化 rclone 传输配置(重试 + 带宽限制)
rcloneCtx := storageRclone.ConfiguredContext(ctx, storageRclone.TransferConfig{
LowLevelRetries: cfg.Backup.Retries,
@@ -245,32 +246,32 @@ func New(ctx context.Context, cfg config.Config, version string) (*Application,
metricsCollector.Start(ctx)
router := aphttp.NewRouter(aphttp.RouterDependencies{
Context: ctx,
Config: cfg,
Version: version,
Logger: appLogger,
AuthService: authService,
SystemService: systemService,
StorageTargetService: storageTargetService,
BackupTaskService: backupTaskService,
BackupExecutionService: backupExecutionService,
BackupRecordService: backupRecordService,
RestoreService: restoreService,
VerificationService: verificationService,
ReplicationService: replicationService,
TaskTemplateService: taskTemplateService,
TaskExportService: taskExportService,
SearchService: searchService,
EventBroadcaster: eventBroadcaster,
UserService: userService,
ApiKeyService: apiKeyService,
NotificationService: notificationService,
DashboardService: dashboardService,
SettingsService: settingsService,
Context: ctx,
Config: cfg,
Version: version,
Logger: appLogger,
AuthService: authService,
SystemService: systemService,
StorageTargetService: storageTargetService,
BackupTaskService: backupTaskService,
BackupExecutionService: backupExecutionService,
BackupRecordService: backupRecordService,
RestoreService: restoreService,
VerificationService: verificationService,
ReplicationService: replicationService,
TaskTemplateService: taskTemplateService,
TaskExportService: taskExportService,
SearchService: searchService,
EventBroadcaster: eventBroadcaster,
UserService: userService,
ApiKeyService: apiKeyService,
NotificationService: notificationService,
DashboardService: dashboardService,
SettingsService: settingsService,
NodeService: nodeService,
AgentService: agentService,
DatabaseDiscoveryService: databaseDiscoveryService,
AuditService: auditService,
AuditService: auditService,
JWTManager: jwtManager,
UserRepository: userRepo,
SystemConfigRepo: systemConfigRepo,

View File

@@ -1,12 +1,23 @@
package http
import (
"net"
stdhttp "net/http"
"strings"
"time"
"backupx/server/internal/apperror"
"backupx/server/internal/service"
"backupx/server/pkg/response"
"github.com/gin-gonic/gin"
)
const (
trustedDeviceCookieName = "backupx_trusted_device"
trustedDeviceCookiePath = "/api/auth"
trustedDeviceCookieMaxAge = int((30 * 24 * time.Hour) / time.Second)
)
type AuthHandler struct {
authService *service.AuthService
}
@@ -44,11 +55,18 @@ func (h *AuthHandler) Login(c *gin.Context) {
response.Error(c, apperror.BadRequest("AUTH_LOGIN_INVALID", "登录参数不合法", err))
return
}
if strings.TrimSpace(input.TrustedDeviceToken) == "" {
input.TrustedDeviceToken = trustedDeviceCookieValue(c)
}
payload, err := h.authService.Login(c.Request.Context(), input, ClientKey(c))
if err != nil {
response.Error(c, err)
return
}
if payload.TrustedDeviceToken != "" {
setTrustedDeviceCookie(c, payload.TrustedDeviceToken)
payload.TrustedDeviceToken = ""
}
response.Success(c, payload)
}
@@ -83,9 +101,315 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
response.Error(c, err)
return
}
clearTrustedDeviceCookie(c)
response.Success(c, gin.H{"changed": true})
}
func (h *AuthHandler) PrepareTwoFactor(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.TwoFactorSetupInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err))
return
}
payload, err := h.authService.PrepareTwoFactor(c.Request.Context(), subject, input)
if err != nil {
response.Error(c, err)
return
}
response.Success(c, payload)
}
func (h *AuthHandler) EnableTwoFactor(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.EnableTwoFactorInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err))
return
}
user, err := h.authService.EnableTwoFactor(c.Request.Context(), subject, input)
if err != nil {
response.Error(c, err)
return
}
response.Success(c, user)
}
func (h *AuthHandler) DisableTwoFactor(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.DisableTwoFactorInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err))
return
}
user, err := h.authService.DisableTwoFactor(c.Request.Context(), subject, input)
if err != nil {
response.Error(c, err)
return
}
if !user.MFAEnabled {
clearTrustedDeviceCookie(c)
}
response.Success(c, user)
}
func (h *AuthHandler) RegenerateRecoveryCodes(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.RegenerateRecoveryCodesInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err))
return
}
payload, err := h.authService.RegenerateRecoveryCodes(c.Request.Context(), subject, input)
if err != nil {
response.Error(c, err)
return
}
response.Success(c, payload)
}
func (h *AuthHandler) ConfigureOTP(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.OTPConfigInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_OTP_INVALID", "参数不合法", err))
return
}
user, err := h.authService.ConfigureOutOfBandOTP(c.Request.Context(), subject, input)
if err != nil {
response.Error(c, err)
return
}
if !user.MFAEnabled {
clearTrustedDeviceCookie(c)
}
response.Success(c, user)
}
func (h *AuthHandler) SendLoginOTP(c *gin.Context) {
var input service.LoginOTPInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_OTP_INVALID", "参数不合法", err))
return
}
if err := h.authService.SendLoginOTP(c.Request.Context(), input, ClientKey(c)); err != nil {
response.Error(c, err)
return
}
response.Success(c, gin.H{"sent": true})
}
func (h *AuthHandler) BeginWebAuthnRegistration(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.WebAuthnRegistrationOptionsInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err))
return
}
options, err := h.authService.BeginWebAuthnRegistration(c.Request.Context(), subject, input, webAuthnRequestContext(c))
if err != nil {
response.Error(c, err)
return
}
response.Success(c, options)
}
func (h *AuthHandler) FinishWebAuthnRegistration(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.WebAuthnRegistrationFinishInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err))
return
}
user, err := h.authService.FinishWebAuthnRegistration(c.Request.Context(), subject, input)
if err != nil {
response.Error(c, err)
return
}
response.Success(c, user)
}
func (h *AuthHandler) BeginWebAuthnLogin(c *gin.Context) {
var input service.WebAuthnLoginOptionsInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err))
return
}
options, err := h.authService.BeginWebAuthnLogin(c.Request.Context(), input, webAuthnRequestContext(c), ClientKey(c))
if err != nil {
response.Error(c, err)
return
}
response.Success(c, options)
}
func (h *AuthHandler) ListWebAuthnCredentials(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
items, err := h.authService.ListWebAuthnCredentials(c.Request.Context(), subject)
if err != nil {
response.Error(c, err)
return
}
response.Success(c, items)
}
func (h *AuthHandler) DeleteWebAuthnCredential(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.WebAuthnCredentialDeleteInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err))
return
}
user, err := h.authService.DeleteWebAuthnCredential(c.Request.Context(), subject, c.Param("id"), input)
if err != nil {
response.Error(c, err)
return
}
if !user.MFAEnabled {
clearTrustedDeviceCookie(c)
}
response.Success(c, user)
}
func (h *AuthHandler) ListTrustedDevices(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
items, err := h.authService.ListTrustedDevices(c.Request.Context(), subject)
if err != nil {
response.Error(c, err)
return
}
response.Success(c, items)
}
func (h *AuthHandler) RevokeTrustedDevice(c *gin.Context) {
subjectValue, _ := c.Get(contextUserSubjectKey)
subject, err := service.SubjectFromContextValue(subjectValue)
if err != nil {
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
return
}
var input service.TrustedDeviceRevokeInput
if err := c.ShouldBindJSON(&input); err != nil {
response.Error(c, apperror.BadRequest("AUTH_TRUSTED_DEVICE_INVALID", "参数不合法", err))
return
}
if err := h.authService.RevokeTrustedDevice(c.Request.Context(), subject, c.Param("id"), input); err != nil {
response.Error(c, err)
return
}
clearTrustedDeviceCookie(c)
response.Success(c, gin.H{"deleted": true})
}
func (h *AuthHandler) Logout(c *gin.Context) {
response.Success(c, gin.H{"loggedOut": true})
}
func webAuthnRequestContext(c *gin.Context) service.WebAuthnRequestContext {
host := firstForwardedValue(c.Request.Host)
if forwardedHost := firstForwardedValue(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
host = forwardedHost
}
rpID := host
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
rpID = parsedHost
}
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
if forwardedProto := firstForwardedValue(c.GetHeader("X-Forwarded-Proto")); forwardedProto != "" {
scheme = forwardedProto
}
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin == "" {
origin = scheme + "://" + host
}
return service.WebAuthnRequestContext{RPID: rpID, Origin: origin}
}
func firstForwardedValue(value string) string {
parts := strings.Split(value, ",")
if len(parts) == 0 {
return ""
}
return strings.TrimSpace(parts[0])
}
func trustedDeviceCookieValue(c *gin.Context) string {
token, err := c.Cookie(trustedDeviceCookieName)
if err != nil {
return ""
}
return strings.TrimSpace(token)
}
func setTrustedDeviceCookie(c *gin.Context, token string) {
writeTrustedDeviceCookie(c, strings.TrimSpace(token), trustedDeviceCookieMaxAge)
}
func clearTrustedDeviceCookie(c *gin.Context) {
writeTrustedDeviceCookie(c, "", -1)
}
func writeTrustedDeviceCookie(c *gin.Context, value string, maxAge int) {
c.SetSameSite(stdhttp.SameSiteLaxMode)
c.SetCookie(trustedDeviceCookieName, value, maxAge, trustedDeviceCookiePath, "", requestIsSecure(c), true)
}
func requestIsSecure(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
return strings.EqualFold(firstForwardedValue(c.GetHeader("X-Forwarded-Proto")), "https")
}

View File

@@ -19,6 +19,7 @@ import (
"backupx/server/internal/repository"
"backupx/server/internal/security"
"backupx/server/internal/service"
"backupx/server/internal/storage/codec"
)
// setupInstallFlowRouter 构造一个 Node + Agent + InstallToken 全量依赖的 router
@@ -40,6 +41,13 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) {
if err != nil {
t.Fatalf("db: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("sql db: %v", err)
}
t.Cleanup(func() {
_ = sqlDB.Close()
})
userRepo := repository.NewUserRepository(db)
systemConfigRepo := repository.NewSystemConfigRepository(db)
@@ -48,7 +56,7 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) {
t.Fatalf("security: %v", err)
}
jwtMgr := security.NewJWTManager(resolved.JWTSecret, time.Hour)
authSvc := service.NewAuthService(userRepo, systemConfigRepo, jwtMgr, security.NewLoginRateLimiter(5, time.Minute))
authSvc := service.NewAuthService(userRepo, systemConfigRepo, jwtMgr, security.NewLoginRateLimiter(5, time.Minute), codec.NewConfigCipher(resolved.EncryptionKey))
systemSvc := service.NewSystemService(cfg, "test", time.Now().UTC())
nodeRepo := repository.NewNodeRepository(db)

View File

@@ -94,9 +94,22 @@ func NewRouter(deps RouterDependencies) *gin.Engine {
auth.GET("/setup/status", authHandler.SetupStatus)
auth.POST("/setup", authHandler.Setup)
auth.POST("/login", authHandler.Login)
auth.POST("/otp/send", authHandler.SendLoginOTP)
auth.POST("/webauthn/login/options", authHandler.BeginWebAuthnLogin)
auth.POST("/logout", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.Logout)
auth.GET("/profile", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.Profile)
auth.PUT("/password", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ChangePassword)
auth.POST("/2fa/setup", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.PrepareTwoFactor)
auth.POST("/2fa/enable", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.EnableTwoFactor)
auth.POST("/2fa/recovery-codes", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.RegenerateRecoveryCodes)
auth.DELETE("/2fa", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.DisableTwoFactor)
auth.PUT("/otp/config", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ConfigureOTP)
auth.POST("/webauthn/register/options", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.BeginWebAuthnRegistration)
auth.POST("/webauthn/register/finish", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.FinishWebAuthnRegistration)
auth.GET("/webauthn/credentials", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ListWebAuthnCredentials)
auth.DELETE("/webauthn/credentials/:id", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.DeleteWebAuthnCredential)
auth.GET("/trusted-devices", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ListTrustedDevices)
auth.DELETE("/trusted-devices/:id", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.RevokeTrustedDevice)
}
system := api.Group("/system")
@@ -229,6 +242,7 @@ func NewRouter(deps RouterDependencies) *gin.Engine {
users.GET("", userHandler.List)
users.POST("", userHandler.Create)
users.PUT("/:id", userHandler.Update)
users.POST("/:id/2fa/reset", userHandler.ResetTwoFactor)
users.DELETE("/:id", userHandler.Delete)
}
@@ -279,10 +293,10 @@ func NewRouter(deps RouterDependencies) *gin.Engine {
nodes.PUT("/:id", RequireRole("admin"), nodeHandler.Update)
nodes.DELETE("/:id", RequireRole("admin"), nodeHandler.Delete)
nodes.GET("/:id/fs/list", nodeHandler.ListDirectory)
nodes.POST("/batch", RequireRole("admin"), nodeHandler.BatchCreate)
nodes.POST("/:id/install-tokens", RequireRole("admin"), nodeHandler.CreateInstallToken)
nodes.POST("/:id/rotate-token", RequireRole("admin"), nodeHandler.RotateToken)
nodes.GET("/:id/install-script-preview", RequireRole("admin"), nodeHandler.PreviewScript)
nodes.POST("/batch", RequireRole("admin"), nodeHandler.BatchCreate)
nodes.POST("/:id/install-tokens", RequireRole("admin"), nodeHandler.CreateInstallToken)
nodes.POST("/:id/rotate-token", RequireRole("admin"), nodeHandler.RotateToken)
nodes.GET("/:id/install-script-preview", RequireRole("admin"), nodeHandler.PreviewScript)
// Agent APItoken 认证,无需 JWT
if deps.AgentService != nil {

View File

@@ -16,50 +16,17 @@ import (
"backupx/server/internal/repository"
"backupx/server/internal/security"
"backupx/server/internal/service"
"backupx/server/internal/storage/codec"
"github.com/pquerna/otp/totp"
)
func TestSetupLoginAndProfileFlow(t *testing.T) {
tempDir := t.TempDir()
cfg := config.Config{
Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test"},
Database: config.DatabaseConfig{Path: filepath.Join(tempDir, "backupx.db")},
Security: config.SecurityConfig{JWTExpire: "24h"},
Log: config.LogConfig{Level: "error"},
}
log, err := logger.New(cfg.Log)
if err != nil {
t.Fatalf("logger.New error: %v", err)
}
db, err := database.Open(cfg.Database, log)
if err != nil {
t.Fatalf("database.Open error: %v", err)
}
userRepo := repository.NewUserRepository(db)
systemConfigRepo := repository.NewSystemConfigRepository(db)
resolved, err := service.ResolveSecurity(context.Background(), cfg.Security, systemConfigRepo)
if err != nil {
t.Fatalf("ResolveSecurity error: %v", err)
}
jwtManager := security.NewJWTManager(resolved.JWTSecret, time.Hour)
authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, security.NewLoginRateLimiter(5, time.Minute))
systemService := service.NewSystemService(cfg, "test", time.Now().UTC())
router := NewRouter(RouterDependencies{
Config: cfg,
Version: "test",
Logger: log,
AuthService: authService,
SystemService: systemService,
JWTManager: jwtManager,
UserRepository: userRepo,
SystemConfigRepo: systemConfigRepo,
})
router, _ := newTestHTTPRouter(t)
setupBody, _ := json.Marshal(map[string]string{
"username": "admin",
"password": "password-123",
"username": "admin",
"password": "password-123",
"displayName": "Admin",
})
setupRequest := httptest.NewRequest(http.MethodPost, "/api/auth/setup", bytes.NewBuffer(setupBody))
@@ -92,3 +59,143 @@ func TestSetupLoginAndProfileFlow(t *testing.T) {
t.Fatalf("expected profile 200, got %d", profileRecorder.Code)
}
}
func TestTrustedDeviceCookieSkipsMFA(t *testing.T) {
router, authService := newTestHTTPRouter(t)
if _, err := authService.Setup(context.Background(), service.SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
}); err != nil {
t.Fatalf("Setup error: %v", err)
}
totpSetup, err := authService.PrepareTwoFactor(context.Background(), "1", service.TwoFactorSetupInput{
CurrentPassword: "password-123",
})
if err != nil {
t.Fatalf("PrepareTwoFactor error: %v", err)
}
enableCode, err := totp.GenerateCode(totpSetup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode error: %v", err)
}
if _, err := authService.EnableTwoFactor(context.Background(), "1", service.EnableTwoFactorInput{Code: enableCode}); err != nil {
t.Fatalf("EnableTwoFactor error: %v", err)
}
loginCode, err := totp.GenerateCode(totpSetup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode login error: %v", err)
}
loginBody, _ := json.Marshal(map[string]any{
"username": "admin",
"password": "password-123",
"twoFactorCode": loginCode,
"rememberDevice": true,
"trustedDeviceName": "test browser",
})
loginRequest := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(loginBody))
loginRequest.Header.Set("Content-Type", "application/json")
loginRecorder := httptest.NewRecorder()
router.ServeHTTP(loginRecorder, loginRequest)
if loginRecorder.Code != http.StatusOK {
t.Fatalf("expected login 200, got %d: %s", loginRecorder.Code, loginRecorder.Body.String())
}
trustedCookie := findCookie(loginRecorder.Result().Cookies(), trustedDeviceCookieName)
if trustedCookie == nil {
t.Fatalf("expected trusted device cookie")
}
if !trustedCookie.HttpOnly {
t.Fatalf("expected trusted device cookie to be HttpOnly")
}
if trustedCookie.Path != trustedDeviceCookiePath {
t.Fatalf("expected trusted device cookie path %q, got %q", trustedDeviceCookiePath, trustedCookie.Path)
}
var loginResponse struct {
Data struct {
Token string `json:"token"`
TrustedDeviceToken string `json:"trustedDeviceToken"`
TrustedDevice *service.TrustedDeviceOutput `json:"trustedDevice"`
} `json:"data"`
}
if err := json.Unmarshal(loginRecorder.Body.Bytes(), &loginResponse); err != nil {
t.Fatalf("unmarshal login response: %v", err)
}
if loginResponse.Data.Token == "" || loginResponse.Data.TrustedDevice == nil {
t.Fatalf("expected login token and trusted device metadata")
}
if loginResponse.Data.TrustedDeviceToken != "" {
t.Fatalf("trusted device token should not be exposed in response body")
}
secondBody, _ := json.Marshal(map[string]string{
"username": "admin",
"password": "password-123",
})
secondRequest := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(secondBody))
secondRequest.Header.Set("Content-Type", "application/json")
secondRequest.AddCookie(trustedCookie)
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondRequest)
if secondRecorder.Code != http.StatusOK {
t.Fatalf("expected trusted device login 200, got %d: %s", secondRecorder.Code, secondRecorder.Body.String())
}
}
func newTestHTTPRouter(t *testing.T) (http.Handler, *service.AuthService) {
t.Helper()
tempDir := t.TempDir()
cfg := config.Config{
Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test"},
Database: config.DatabaseConfig{Path: filepath.Join(tempDir, "backupx.db")},
Security: config.SecurityConfig{JWTExpire: "24h"},
Log: config.LogConfig{Level: "error"},
}
log, err := logger.New(cfg.Log)
if err != nil {
t.Fatalf("logger.New error: %v", err)
}
db, err := database.Open(cfg.Database, log)
if err != nil {
t.Fatalf("database.Open error: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("db.DB error: %v", err)
}
t.Cleanup(func() {
_ = sqlDB.Close()
})
userRepo := repository.NewUserRepository(db)
systemConfigRepo := repository.NewSystemConfigRepository(db)
resolved, err := service.ResolveSecurity(context.Background(), cfg.Security, systemConfigRepo)
if err != nil {
t.Fatalf("ResolveSecurity error: %v", err)
}
jwtManager := security.NewJWTManager(resolved.JWTSecret, time.Hour)
authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, security.NewLoginRateLimiter(5, time.Minute), codec.NewConfigCipher(resolved.EncryptionKey))
systemService := service.NewSystemService(cfg, "test", time.Now().UTC())
router := NewRouter(RouterDependencies{
Config: cfg,
Version: "test",
Logger: log,
AuthService: authService,
SystemService: systemService,
JWTManager: jwtManager,
UserRepository: userRepo,
SystemConfigRepo: systemConfigRepo,
})
return router, authService
}
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
for _, cookie := range cookies {
if cookie.Name == name {
return cookie
}
}
return nil
}

View File

@@ -78,3 +78,18 @@ func (h *UserHandler) Delete(c *gin.Context) {
fmt.Sprintf("删除用户 (ID: %d)", id))
response.Success(c, gin.H{"deleted": true})
}
func (h *UserHandler) ResetTwoFactor(c *gin.Context) {
id, ok := parseUintParam(c, "id")
if !ok {
return
}
item, err := h.service.ResetTwoFactor(c.Request.Context(), id)
if err != nil {
response.Error(c, err)
return
}
recordAudit(c, h.auditService, "user", "reset_two_factor", "user", fmt.Sprintf("%d", id), item.Username,
fmt.Sprintf("重置用户 %s 的 MFA", item.Username))
response.Success(c, item)
}

View File

@@ -22,12 +22,25 @@ func IsValidRole(role string) bool {
}
type User struct {
ID uint `gorm:"primaryKey" json:"id"`
Username string `gorm:"size:64;uniqueIndex;not null" json:"username"`
PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
DisplayName string `gorm:"size:128;not null" json:"displayName"`
Email string `gorm:"size:255" json:"email"`
Role string `gorm:"size:32;not null;default:admin" json:"role"`
ID uint `gorm:"primaryKey" json:"id"`
Username string `gorm:"size:64;uniqueIndex;not null" json:"username"`
PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
DisplayName string `gorm:"size:128;not null" json:"displayName"`
Email string `gorm:"size:255" json:"email"`
Phone string `gorm:"size:64" json:"phone"`
Role string `gorm:"size:32;not null;default:admin" json:"role"`
// TwoFactorSecretCiphertext 保存 TOTP 密钥密文;未启用时可作为待确认密钥。
TwoFactorEnabled bool `gorm:"column:two_factor_enabled;not null;default:false" json:"twoFactorEnabled"`
TwoFactorSecretCiphertext string `gorm:"column:two_factor_secret_ciphertext;type:text" json:"-"`
// TwoFactorRecoveryCodeHashes 保存一次性恢复码哈希的 JSON 数组。
TwoFactorRecoveryCodeHashes string `gorm:"column:two_factor_recovery_code_hashes;type:text" json:"-"`
// WebAuthnCredentials 保存通行密钥公钥元数据 JSON不包含私钥或明文密钥。
WebAuthnCredentials string `gorm:"column:webauthn_credentials;type:text" json:"-"`
WebAuthnChallengeCiphertext string `gorm:"column:webauthn_challenge_ciphertext;type:text" json:"-"`
TrustedDevices string `gorm:"column:trusted_devices;type:text" json:"-"`
EmailOTPEnabled bool `gorm:"column:email_otp_enabled;not null;default:false" json:"emailOtpEnabled"`
SMSOTPEnabled bool `gorm:"column:sms_otp_enabled;not null;default:false" json:"smsOtpEnabled"`
OutOfBandOTPCiphertext string `gorm:"column:out_of_band_otp_ciphertext;type:text" json:"-"`
// Disabled 禁用账号(不删除保留审计)。禁用后无法登录。
Disabled bool `gorm:"not null;default:false" json:"disabled"`
CreatedAt time.Time `json:"createdAt"`

View File

@@ -0,0 +1,23 @@
package security
import (
"crypto/rand"
"fmt"
"math/big"
"strings"
)
const LoginOTPDigits = 6
func GenerateNumericOTP() (string, error) {
limit := big.NewInt(1_000_000)
value, err := rand.Int(rand.Reader, limit)
if err != nil {
return "", err
}
return fmt.Sprintf("%0*d", LoginOTPDigits, value.Int64()), nil
}
func NormalizeNumericOTP(code string) string {
return strings.TrimSpace(code)
}

View File

@@ -0,0 +1,49 @@
package security
import (
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"unicode"
)
const RecoveryCodeCount = 10
func GenerateRecoveryCodes(count int) ([]string, error) {
if count <= 0 {
count = RecoveryCodeCount
}
codes := make([]string, 0, count)
for i := 0; i < count; i++ {
raw := make([]byte, 8)
if _, err := rand.Read(raw); err != nil {
return nil, fmt.Errorf("generate recovery code: %w", err)
}
encoded := strings.ToUpper(hex.EncodeToString(raw))
codes = append(codes, encoded[0:4]+"-"+encoded[4:8]+"-"+encoded[8:12]+"-"+encoded[12:16])
}
return codes, nil
}
func NormalizeRecoveryCode(code string) string {
return strings.Map(func(r rune) rune {
if unicode.IsSpace(r) || r == '-' {
return -1
}
return unicode.ToUpper(r)
}, strings.TrimSpace(code))
}
func IsRecoveryCodeCandidate(code string) bool {
normalized := NormalizeRecoveryCode(code)
if len(normalized) != 16 {
return false
}
for _, r := range normalized {
if !('0' <= r && r <= '9') && !('A' <= r && r <= 'F') {
return false
}
}
return true
}

View File

@@ -0,0 +1,68 @@
package security
import (
"bytes"
"encoding/base64"
"image/png"
"strings"
"time"
"unicode"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
)
const TOTPIssuer = "BackupX"
type TOTPEnrollment struct {
Secret string
OTPAuthURL string
QRCodeDataURL string
}
func GenerateTOTPEnrollment(accountName string) (*TOTPEnrollment, error) {
key, err := totp.Generate(totp.GenerateOpts{
Issuer: TOTPIssuer,
AccountName: accountName,
Period: 30,
SecretSize: 20,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
})
if err != nil {
return nil, err
}
image, err := key.Image(220, 220)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := png.Encode(&buf, image); err != nil {
return nil, err
}
return &TOTPEnrollment{
Secret: key.Secret(),
OTPAuthURL: key.URL(),
QRCodeDataURL: "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()),
}, nil
}
func ValidateTOTPCode(secret string, code string) (bool, error) {
return totp.ValidateCustom(NormalizeTOTPCode(code), secret, time.Now().UTC(), totp.ValidateOpts{
Period: 30,
Skew: 1,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
})
}
func NormalizeTOTPCode(code string) string {
return strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return -1
}
return r
}, strings.TrimSpace(code))
}

View File

@@ -0,0 +1,447 @@
package security
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math/big"
"strings"
)
const (
WebAuthnChallengeBytes = 32
)
type WebAuthnCredentialMaterial struct {
CredentialID string
PublicKeyX string
PublicKeyY string
SignCount uint32
}
type WebAuthnParsedCredential struct {
CredentialID string
PublicKeyX string
PublicKeyY string
SignCount uint32
}
type WebAuthnClientData struct {
Type string `json:"type"`
Challenge string `json:"challenge"`
Origin string `json:"origin"`
}
type WebAuthnAttestationResponse struct {
ClientDataJSON string `json:"clientDataJSON"`
AttestationObject string `json:"attestationObject"`
}
type WebAuthnRegistrationResponse struct {
ID string `json:"id"`
RawID string `json:"rawId"`
Type string `json:"type"`
Response WebAuthnAttestationResponse `json:"response"`
}
type WebAuthnAssertionResponse struct {
ClientDataJSON string `json:"clientDataJSON"`
AuthenticatorData string `json:"authenticatorData"`
Signature string `json:"signature"`
UserHandle string `json:"userHandle,omitempty"`
}
type WebAuthnLoginAssertion struct {
ID string `json:"id"`
RawID string `json:"rawId"`
Type string `json:"type"`
Response WebAuthnAssertionResponse `json:"response"`
}
func GenerateWebAuthnChallenge() (string, error) {
buf := make([]byte, WebAuthnChallengeBytes)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return EncodeBase64URL(buf), nil
}
func EncodeBase64URL(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
func DecodeBase64URL(value string) ([]byte, error) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return nil, errors.New("empty base64url value")
}
if decoded, err := base64.RawURLEncoding.DecodeString(trimmed); err == nil {
return decoded, nil
}
return base64.URLEncoding.DecodeString(trimmed)
}
func VerifyWebAuthnRegistration(input WebAuthnRegistrationResponse, challenge string, rpID string, expectedOrigin string) (*WebAuthnParsedCredential, error) {
if input.Type != "public-key" {
return nil, fmt.Errorf("unexpected credential type: %s", input.Type)
}
clientDataRaw, err := DecodeBase64URL(input.Response.ClientDataJSON)
if err != nil {
return nil, fmt.Errorf("decode client data: %w", err)
}
if err := validateWebAuthnClientData(clientDataRaw, "webauthn.create", challenge, expectedOrigin); err != nil {
return nil, err
}
attestationObject, err := DecodeBase64URL(input.Response.AttestationObject)
if err != nil {
return nil, fmt.Errorf("decode attestation object: %w", err)
}
parsed, err := parseCBORExact(attestationObject)
if err != nil {
return nil, fmt.Errorf("parse attestation object: %w", err)
}
attestationMap, ok := parsed.(map[any]any)
if !ok {
return nil, errors.New("attestation object is not a map")
}
authData, ok := attestationMap["authData"].([]byte)
if !ok {
return nil, errors.New("attestation authData is missing")
}
credential, err := parseAttestedCredentialData(authData, rpID)
if err != nil {
return nil, err
}
rawID := strings.TrimSpace(input.RawID)
if rawID == "" {
rawID = strings.TrimSpace(input.ID)
}
if rawID != "" && rawID != credential.CredentialID {
return nil, errors.New("credential raw id does not match attested credential id")
}
return credential, nil
}
func VerifyWebAuthnAssertion(input WebAuthnLoginAssertion, challenge string, rpID string, expectedOrigin string, credential WebAuthnCredentialMaterial) (uint32, error) {
if input.Type != "public-key" {
return 0, fmt.Errorf("unexpected credential type: %s", input.Type)
}
rawID := strings.TrimSpace(input.RawID)
if rawID == "" {
rawID = strings.TrimSpace(input.ID)
}
if rawID != credential.CredentialID {
return 0, errors.New("credential id does not match")
}
clientDataRaw, err := DecodeBase64URL(input.Response.ClientDataJSON)
if err != nil {
return 0, fmt.Errorf("decode client data: %w", err)
}
if err := validateWebAuthnClientData(clientDataRaw, "webauthn.get", challenge, expectedOrigin); err != nil {
return 0, err
}
authData, err := DecodeBase64URL(input.Response.AuthenticatorData)
if err != nil {
return 0, fmt.Errorf("decode authenticator data: %w", err)
}
signature, err := DecodeBase64URL(input.Response.Signature)
if err != nil {
return 0, fmt.Errorf("decode signature: %w", err)
}
signCount, err := parseAssertionAuthenticatorData(authData, rpID, credential.SignCount)
if err != nil {
return 0, err
}
xBytes, err := DecodeBase64URL(credential.PublicKeyX)
if err != nil {
return 0, fmt.Errorf("decode public key x: %w", err)
}
yBytes, err := DecodeBase64URL(credential.PublicKeyY)
if err != nil {
return 0, fmt.Errorf("decode public key y: %w", err)
}
publicKey := ecdsa.PublicKey{Curve: elliptic.P256(), X: new(big.Int).SetBytes(xBytes), Y: new(big.Int).SetBytes(yBytes)}
if !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
return 0, errors.New("webauthn public key is not on P-256 curve")
}
clientDataHash := sha256.Sum256(clientDataRaw)
verifyData := append(append([]byte{}, authData...), clientDataHash[:]...)
digest := sha256.Sum256(verifyData)
if !ecdsa.VerifyASN1(&publicKey, digest[:], signature) {
return 0, errors.New("invalid webauthn signature")
}
return signCount, nil
}
func validateWebAuthnClientData(raw []byte, expectedType string, challenge string, expectedOrigin string) error {
var clientData WebAuthnClientData
if err := json.Unmarshal(raw, &clientData); err != nil {
return fmt.Errorf("parse client data: %w", err)
}
if clientData.Type != expectedType {
return fmt.Errorf("unexpected webauthn client data type: %s", clientData.Type)
}
if clientData.Challenge != challenge {
return errors.New("webauthn challenge mismatch")
}
if expectedOrigin != "" && clientData.Origin != expectedOrigin {
return fmt.Errorf("webauthn origin mismatch: %s", clientData.Origin)
}
return nil
}
func parseAttestedCredentialData(authData []byte, rpID string) (*WebAuthnParsedCredential, error) {
signCount, credentialData, err := parseAuthenticatorDataHeader(authData, rpID, true, 0)
if err != nil {
return nil, err
}
if len(credentialData) < 18 {
return nil, errors.New("attested credential data is too short")
}
offset := 16
credentialIDLength := int(binary.BigEndian.Uint16(credentialData[offset : offset+2]))
offset += 2
if credentialIDLength <= 0 || len(credentialData) < offset+credentialIDLength {
return nil, errors.New("invalid credential id length")
}
credentialID := credentialData[offset : offset+credentialIDLength]
offset += credentialIDLength
publicKeyRaw := credentialData[offset:]
publicKey, err := parseCBOR(publicKeyRaw)
if err != nil {
return nil, fmt.Errorf("parse credential public key: %w", err)
}
publicKeyMap, ok := publicKey.(map[any]any)
if !ok {
return nil, errors.New("credential public key is not a map")
}
kty, err := coseInt(publicKeyMap, 1)
if err != nil {
return nil, err
}
alg, err := coseInt(publicKeyMap, 3)
if err != nil {
return nil, err
}
crv, err := coseInt(publicKeyMap, -1)
if err != nil {
return nil, err
}
if kty != 2 || alg != -7 || crv != 1 {
return nil, fmt.Errorf("unsupported COSE key: kty=%d alg=%d crv=%d", kty, alg, crv)
}
x, err := coseBytes(publicKeyMap, -2)
if err != nil {
return nil, err
}
y, err := coseBytes(publicKeyMap, -3)
if err != nil {
return nil, err
}
if !elliptic.P256().IsOnCurve(new(big.Int).SetBytes(x), new(big.Int).SetBytes(y)) {
return nil, errors.New("credential public key is not on P-256 curve")
}
return &WebAuthnParsedCredential{
CredentialID: EncodeBase64URL(credentialID),
PublicKeyX: EncodeBase64URL(x),
PublicKeyY: EncodeBase64URL(y),
SignCount: signCount,
}, nil
}
func parseAssertionAuthenticatorData(authData []byte, rpID string, previousSignCount uint32) (uint32, error) {
signCount, _, err := parseAuthenticatorDataHeader(authData, rpID, false, previousSignCount)
if err != nil {
return 0, err
}
return signCount, nil
}
func parseAuthenticatorDataHeader(authData []byte, rpID string, requireAttestedData bool, previousSignCount uint32) (uint32, []byte, error) {
if len(authData) < 37 {
return 0, nil, errors.New("authenticator data is too short")
}
expectedRPIDHash := sha256.Sum256([]byte(rpID))
if string(authData[:32]) != string(expectedRPIDHash[:]) {
return 0, nil, errors.New("rp id hash mismatch")
}
flags := authData[32]
if flags&0x01 == 0 {
return 0, nil, errors.New("user presence flag is missing")
}
signCount := binary.BigEndian.Uint32(authData[33:37])
if previousSignCount > 0 && signCount > 0 && signCount <= previousSignCount {
return 0, nil, errors.New("authenticator sign count did not increase")
}
if requireAttestedData && flags&0x40 == 0 {
return 0, nil, errors.New("attested credential data flag is missing")
}
return signCount, authData[37:], nil
}
func coseInt(m map[any]any, key int64) (int64, error) {
value, ok := m[key]
if !ok {
return 0, fmt.Errorf("missing COSE key %d", key)
}
intValue, ok := value.(int64)
if !ok {
return 0, fmt.Errorf("invalid COSE key %d", key)
}
return intValue, nil
}
func coseBytes(m map[any]any, key int64) ([]byte, error) {
value, ok := m[key]
if !ok {
return nil, fmt.Errorf("missing COSE key %d", key)
}
bytesValue, ok := value.([]byte)
if !ok || len(bytesValue) == 0 {
return nil, fmt.Errorf("invalid COSE key %d", key)
}
return bytesValue, nil
}
func parseCBOR(data []byte) (any, error) {
reader := cborReader{data: data}
value, err := reader.read()
if err != nil {
return nil, err
}
return value, nil
}
func parseCBORExact(data []byte) (any, error) {
reader := cborReader{data: data}
value, err := reader.read()
if err != nil {
return nil, err
}
if reader.pos != len(data) {
return nil, errors.New("trailing cbor data")
}
return value, nil
}
type cborReader struct {
data []byte
pos int
}
func (r *cborReader) read() (any, error) {
if r.pos >= len(r.data) {
return nil, errors.New("unexpected cbor eof")
}
initial := r.data[r.pos]
r.pos++
major := initial >> 5
additional := initial & 0x1f
length, err := r.readLength(additional)
if err != nil {
return nil, err
}
switch major {
case 0:
return int64(length), nil
case 1:
return -1 - int64(length), nil
case 2:
return r.readBytes(length)
case 3:
raw, err := r.readBytes(length)
if err != nil {
return nil, err
}
return string(raw), nil
case 4:
out := make([]any, 0, length)
for i := uint64(0); i < length; i++ {
item, err := r.read()
if err != nil {
return nil, err
}
out = append(out, item)
}
return out, nil
case 5:
out := make(map[any]any, length)
for i := uint64(0); i < length; i++ {
key, err := r.read()
if err != nil {
return nil, err
}
value, err := r.read()
if err != nil {
return nil, err
}
out[key] = value
}
return out, nil
case 7:
switch additional {
case 20:
return false, nil
case 21:
return true, nil
case 22, 23:
return nil, nil
default:
return nil, fmt.Errorf("unsupported cbor simple value: %d", additional)
}
default:
return nil, fmt.Errorf("unsupported cbor major type: %d", major)
}
}
func (r *cborReader) readLength(additional byte) (uint64, error) {
switch {
case additional < 24:
return uint64(additional), nil
case additional == 24:
if r.pos+1 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := r.data[r.pos]
r.pos++
return uint64(value), nil
case additional == 25:
if r.pos+2 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := binary.BigEndian.Uint16(r.data[r.pos : r.pos+2])
r.pos += 2
return uint64(value), nil
case additional == 26:
if r.pos+4 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := binary.BigEndian.Uint32(r.data[r.pos : r.pos+4])
r.pos += 4
return uint64(value), nil
case additional == 27:
if r.pos+8 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := binary.BigEndian.Uint64(r.data[r.pos : r.pos+8])
r.pos += 8
return value, nil
default:
return 0, fmt.Errorf("unsupported cbor additional info: %d", additional)
}
}
func (r *cborReader) readBytes(length uint64) ([]byte, error) {
if length > uint64(len(r.data)-r.pos) {
return nil, errors.New("unexpected cbor eof")
}
out := r.data[r.pos : r.pos+int(length)]
r.pos += int(length)
return out, nil
}

View File

@@ -0,0 +1,179 @@
package service
import (
"encoding/json"
"strings"
"time"
"backupx/server/internal/model"
)
const (
mfaChallengeTTL = 5 * time.Minute
trustedDeviceTTL = 30 * 24 * time.Hour
maxTrustedDeviceName = 128
maxTrustedDevices = 10
)
type WebAuthnCredentialRecord struct {
ID string `json:"id"`
Name string `json:"name"`
CredentialID string `json:"credentialId"`
PublicKeyX string `json:"publicKeyX"`
PublicKeyY string `json:"publicKeyY"`
SignCount uint32 `json:"signCount"`
CreatedAt string `json:"createdAt"`
LastUsedAt string `json:"lastUsedAt,omitempty"`
}
type WebAuthnCredentialOutput struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt string `json:"createdAt"`
LastUsedAt string `json:"lastUsedAt,omitempty"`
}
type webAuthnChallengeState struct {
Type string `json:"type"`
Challenge string `json:"challenge"`
RPID string `json:"rpId"`
Origin string `json:"origin"`
ExpiresAt time.Time `json:"expiresAt"`
}
type TrustedDeviceRecord struct {
ID string `json:"id"`
Name string `json:"name"`
TokenHash string `json:"tokenHash"`
CreatedAt time.Time `json:"createdAt"`
LastUsedAt time.Time `json:"lastUsedAt"`
ExpiresAt time.Time `json:"expiresAt"`
LastIP string `json:"lastIp"`
}
type TrustedDeviceOutput struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt string `json:"createdAt"`
LastUsedAt string `json:"lastUsedAt"`
ExpiresAt string `json:"expiresAt"`
LastIP string `json:"lastIp"`
}
type pendingOutOfBandOTP struct {
Channel string `json:"channel"`
CodeHash string `json:"codeHash"`
ExpiresAt time.Time `json:"expiresAt"`
}
func userMFAEnabled(user *model.User) bool {
if user == nil {
return false
}
return user.TwoFactorEnabled ||
strings.TrimSpace(user.WebAuthnCredentials) != "" ||
user.EmailOTPEnabled ||
user.SMSOTPEnabled
}
func clearTrustedDevicesIfMFAOff(user *model.User) {
if user == nil || userMFAEnabled(user) {
return
}
user.TrustedDevices = ""
user.OutOfBandOTPCiphertext = ""
user.WebAuthnChallengeCiphertext = ""
}
func parseWebAuthnCredentials(value string) ([]WebAuthnCredentialRecord, error) {
if strings.TrimSpace(value) == "" {
return nil, nil
}
var credentials []WebAuthnCredentialRecord
if err := json.Unmarshal([]byte(value), &credentials); err != nil {
return nil, err
}
return credentials, nil
}
func encodeWebAuthnCredentials(credentials []WebAuthnCredentialRecord) (string, error) {
if len(credentials) == 0 {
return "", nil
}
encoded, err := json.Marshal(credentials)
if err != nil {
return "", err
}
return string(encoded), nil
}
func webAuthnCredentialCount(user *model.User) int {
if user == nil {
return 0
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return 0
}
return len(credentials)
}
func parseTrustedDevices(value string) ([]TrustedDeviceRecord, error) {
if strings.TrimSpace(value) == "" {
return nil, nil
}
var devices []TrustedDeviceRecord
if err := json.Unmarshal([]byte(value), &devices); err != nil {
return nil, err
}
return devices, nil
}
func encodeTrustedDevices(devices []TrustedDeviceRecord) (string, error) {
if len(devices) == 0 {
return "", nil
}
encoded, err := json.Marshal(devices)
if err != nil {
return "", err
}
return string(encoded), nil
}
func trustedDeviceCount(user *model.User) int {
if user == nil {
return 0
}
devices, err := parseTrustedDevices(user.TrustedDevices)
if err != nil {
return 0
}
now := time.Now().UTC()
count := 0
for _, device := range devices {
if device.ExpiresAt.After(now) {
count++
}
}
return count
}
func toWebAuthnCredentialOutput(record WebAuthnCredentialRecord) WebAuthnCredentialOutput {
return WebAuthnCredentialOutput{
ID: record.ID,
Name: record.Name,
CreatedAt: record.CreatedAt,
LastUsedAt: record.LastUsedAt,
}
}
func toTrustedDeviceOutput(record TrustedDeviceRecord) TrustedDeviceOutput {
return TrustedDeviceOutput{
ID: record.ID,
Name: record.Name,
CreatedAt: record.CreatedAt.Format(time.RFC3339),
LastUsedAt: record.LastUsedAt.Format(time.RFC3339),
ExpiresAt: record.ExpiresAt.Format(time.RFC3339),
LastIP: record.LastIP,
}
}

View File

@@ -0,0 +1,252 @@
package service
import (
"context"
"fmt"
"strings"
"time"
"backupx/server/internal/apperror"
"backupx/server/internal/model"
"backupx/server/internal/security"
)
type OTPConfigInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
Channel string `json:"channel" binding:"required,oneof=email sms"`
Enabled bool `json:"enabled"`
Email string `json:"email" binding:"omitempty,max=255"`
Phone string `json:"phone" binding:"omitempty,max=64"`
}
type LoginOTPInput struct {
Username string `json:"username" binding:"required,min=3,max=64"`
Password string `json:"password" binding:"required,min=8,max=128"`
Channel string `json:"channel" binding:"required,oneof=email sms"`
}
func (s *AuthService) ConfigureOutOfBandOTP(ctx context.Context, subject string, input OTPConfigInput) (*UserOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
channel := strings.TrimSpace(input.Channel)
previousEmail := strings.TrimSpace(user.Email)
previousPhone := strings.TrimSpace(user.Phone)
contactChanged := false
switch channel {
case "email":
email := strings.TrimSpace(input.Email)
if email != "" {
user.Email = email
}
contactChanged = previousEmail != strings.TrimSpace(user.Email)
if input.Enabled && strings.TrimSpace(user.Email) == "" {
return nil, apperror.BadRequest("AUTH_EMAIL_REQUIRED", "请先在用户资料中设置邮箱", nil)
}
user.EmailOTPEnabled = input.Enabled
case "sms":
phone := strings.TrimSpace(input.Phone)
if phone != "" {
user.Phone = phone
}
contactChanged = previousPhone != strings.TrimSpace(user.Phone)
if input.Enabled && strings.TrimSpace(user.Phone) == "" {
return nil, apperror.BadRequest("AUTH_PHONE_REQUIRED", "请先设置手机号", nil)
}
user.SMSOTPEnabled = input.Enabled
default:
return nil, apperror.BadRequest("AUTH_OTP_CHANNEL_INVALID", "验证码渠道不支持", nil)
}
if s.shouldClearPendingOTP(user, channel, contactChanged) {
user.OutOfBandOTPCiphertext = ""
}
clearTrustedDevicesIfMFAOff(user)
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_OTP_CONFIG_FAILED", "无法更新 OTP 配置", err)
}
if s.auditService != nil {
action := "otp_disable"
if input.Enabled {
action = "otp_enable"
}
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: action,
TargetType: "otp", TargetID: channel,
Detail: fmt.Sprintf("%s %s OTP", map[bool]string{true: "启用", false: "关闭"}[input.Enabled], channel),
})
}
return ToUserOutput(user), nil
}
func (s *AuthService) SendLoginOTP(ctx context.Context, input LoginOTPInput, clientKey string) error {
user, err := s.verifyPasswordForMFAStart(ctx, input.Username, input.Password, clientKey)
if err != nil {
return err
}
channel := strings.TrimSpace(input.Channel)
if channel == "email" && !user.EmailOTPEnabled {
return apperror.BadRequest("AUTH_EMAIL_OTP_DISABLED", "当前账号未启用邮件验证码", nil)
}
if channel == "sms" && !user.SMSOTPEnabled {
return apperror.BadRequest("AUTH_SMS_OTP_DISABLED", "当前账号未启用短信验证码", nil)
}
code, err := security.GenerateNumericOTP()
if err != nil {
return apperror.Internal("AUTH_OTP_GENERATE_FAILED", "无法生成登录验证码", err)
}
hash, err := security.HashPassword(code)
if err != nil {
return apperror.Internal("AUTH_OTP_GENERATE_FAILED", "无法处理登录验证码", err)
}
pending := pendingOutOfBandOTP{
Channel: channel,
CodeHash: hash,
ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL),
}
ciphertext, err := s.twoFactorCipher.EncryptJSON(pending)
if err != nil {
return apperror.Internal("AUTH_OTP_SAVE_FAILED", "无法保存登录验证码状态", err)
}
user.OutOfBandOTPCiphertext = ciphertext
if err := s.users.Update(ctx, user); err != nil {
return apperror.Internal("AUTH_OTP_SAVE_FAILED", "无法保存登录验证码状态", err)
}
if err := s.deliverLoginOTP(ctx, user, channel, code); err != nil {
user.OutOfBandOTPCiphertext = ""
if updateErr := s.users.Update(ctx, user); updateErr != nil {
return apperror.Internal("AUTH_OTP_SAVE_FAILED", "登录验证码发送失败,且无法回滚验证码状态", updateErr)
}
return apperror.BadRequest("AUTH_OTP_DELIVERY_FAILED", "登录验证码发送失败", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "otp_send",
TargetType: "otp", TargetID: channel,
Detail: "发送登录 OTP", ClientIP: clientKey,
})
}
return nil
}
func (s *AuthService) consumeOutOfBandOTP(ctx context.Context, user *model.User, code string, clientKey string) (bool, error) {
if strings.TrimSpace(user.OutOfBandOTPCiphertext) == "" {
return false, nil
}
var pending pendingOutOfBandOTP
if err := s.twoFactorCipher.DecryptJSON(user.OutOfBandOTPCiphertext, &pending); err != nil {
return false, apperror.Internal("AUTH_OTP_INVALID", "登录验证码状态异常", err)
}
if pending.ExpiresAt.Before(time.Now().UTC()) {
user.OutOfBandOTPCiphertext = ""
if err := s.users.Update(ctx, user); err != nil {
return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法更新登录验证码状态", err)
}
return false, nil
}
if !outOfBandOTPChannelEnabled(user, pending.Channel) {
user.OutOfBandOTPCiphertext = ""
if err := s.users.Update(ctx, user); err != nil {
return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法更新登录验证码状态", err)
}
return false, nil
}
if security.ComparePassword(pending.CodeHash, security.NormalizeNumericOTP(code)) != nil {
return false, nil
}
user.OutOfBandOTPCiphertext = ""
if err := s.users.Update(ctx, user); err != nil {
return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法使用登录验证码", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "otp_used",
TargetType: "otp", TargetID: pending.Channel,
Detail: "使用登录 OTP 完成登录", ClientIP: clientKey,
})
}
return true, nil
}
func (s *AuthService) deliverLoginOTP(ctx context.Context, user *model.User, channel string, code string) error {
if s.notificationService == nil {
return fmt.Errorf("notification service is not configured")
}
switch channel {
case "email":
email := strings.TrimSpace(user.Email)
if email == "" {
return fmt.Errorf("user email is empty")
}
return s.notificationService.SendAuthEmailOTP(ctx, email, code)
case "sms":
phone := strings.TrimSpace(user.Phone)
if phone == "" {
return fmt.Errorf("user phone is empty")
}
return s.notificationService.SendAuthSMSOTP(ctx, phone, code)
default:
return fmt.Errorf("unsupported otp channel: %s", channel)
}
}
func (s *AuthService) verifyPasswordForMFAStart(ctx context.Context, username string, password string, clientKey string) (*model.User, error) {
if clientKey == "" {
clientKey = "unknown"
}
if !s.rateLimiter.Allow(clientKey) {
return nil, apperror.TooManyRequests("AUTH_RATE_LIMITED", "登录尝试过于频繁,请稍后再试", nil)
}
user, err := s.users.FindByUsername(ctx, strings.TrimSpace(username))
if err != nil {
return nil, apperror.Internal("AUTH_LOOKUP_FAILED", "无法执行登录校验", err)
}
if user == nil || user.Disabled {
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", nil)
}
if err := security.ComparePassword(user.PasswordHash, password); err != nil {
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "login_failed",
Detail: "密码错误", ClientIP: clientKey,
})
}
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err)
}
if !userMFAEnabled(user) {
return nil, apperror.BadRequest("AUTH_MFA_NOT_ENABLED", "当前账号未启用多因素验证", nil)
}
return user, nil
}
func outOfBandOTPChannelEnabled(user *model.User, channel string) bool {
switch channel {
case "email":
return user.EmailOTPEnabled
case "sms":
return user.SMSOTPEnabled
default:
return false
}
}
func (s *AuthService) shouldClearPendingOTP(user *model.User, changedChannel string, contactChanged bool) bool {
if !user.EmailOTPEnabled && !user.SMSOTPEnabled {
return true
}
if strings.TrimSpace(user.OutOfBandOTPCiphertext) == "" {
return false
}
var pending pendingOutOfBandOTP
if err := s.twoFactorCipher.DecryptJSON(user.OutOfBandOTPCiphertext, &pending); err != nil {
return true
}
return pending.Channel == changedChannel && (contactChanged || !outOfBandOTPChannelEnabled(user, changedChannel))
}

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
@@ -11,6 +12,7 @@ import (
"backupx/server/internal/model"
"backupx/server/internal/repository"
"backupx/server/internal/security"
"backupx/server/internal/storage/codec"
)
type SetupInput struct {
@@ -20,28 +22,47 @@ type SetupInput struct {
}
type LoginInput struct {
Username string `json:"username" binding:"required,min=3,max=64"`
Password string `json:"password" binding:"required,min=8,max=128"`
Username string `json:"username" binding:"required,min=3,max=64"`
Password string `json:"password" binding:"required,min=8,max=128"`
TwoFactorCode string `json:"twoFactorCode" binding:"omitempty,min=6,max=32"`
WebAuthnAssertion *security.WebAuthnLoginAssertion `json:"webAuthnAssertion"`
TrustedDeviceToken string `json:"trustedDeviceToken"`
RememberDevice bool `json:"rememberDevice"`
TrustedDeviceName string `json:"trustedDeviceName" binding:"omitempty,max=128"`
}
type AuthPayload struct {
Token string `json:"token"`
User *UserOutput `json:"user"`
Token string `json:"token"`
User *UserOutput `json:"user"`
TrustedDeviceToken string `json:"trustedDeviceToken,omitempty"`
TrustedDevice *TrustedDeviceOutput `json:"trustedDevice,omitempty"`
}
type UserOutput struct {
ID uint `json:"id"`
Username string `json:"username"`
DisplayName string `json:"displayName"`
Role string `json:"role"`
ID uint `json:"id"`
Username string `json:"username"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
Phone string `json:"phone"`
Role string `json:"role"`
MFAEnabled bool `json:"mfaEnabled"`
TwoFactorEnabled bool `json:"twoFactorEnabled"`
TwoFactorRecoveryCodesRemaining int `json:"twoFactorRecoveryCodesRemaining"`
WebAuthnEnabled bool `json:"webAuthnEnabled"`
WebAuthnCredentialCount int `json:"webAuthnCredentialCount"`
TrustedDeviceCount int `json:"trustedDeviceCount"`
EmailOTPEnabled bool `json:"emailOtpEnabled"`
SMSOTPEnabled bool `json:"smsOtpEnabled"`
}
type AuthService struct {
users repository.UserRepository
configs repository.SystemConfigRepository
jwtManager *security.JWTManager
rateLimiter *security.LoginRateLimiter
auditService *AuditService
users repository.UserRepository
configs repository.SystemConfigRepository
jwtManager *security.JWTManager
rateLimiter *security.LoginRateLimiter
twoFactorCipher *codec.ConfigCipher
auditService *AuditService
notificationService *NotificationService
}
func NewAuthService(
@@ -49,14 +70,25 @@ func NewAuthService(
configs repository.SystemConfigRepository,
jwtManager *security.JWTManager,
rateLimiter *security.LoginRateLimiter,
twoFactorCipher *codec.ConfigCipher,
) *AuthService {
return &AuthService{users: users, configs: configs, jwtManager: jwtManager, rateLimiter: rateLimiter}
return &AuthService{
users: users,
configs: configs,
jwtManager: jwtManager,
rateLimiter: rateLimiter,
twoFactorCipher: twoFactorCipher,
}
}
func (s *AuthService) SetAuditService(auditService *AuditService) {
s.auditService = auditService
}
func (s *AuthService) SetNotificationService(notificationService *NotificationService) {
s.notificationService = notificationService
}
func (s *AuthService) SetupStatus(ctx context.Context) (bool, error) {
count, err := s.users.Count(ctx)
if err != nil {
@@ -130,7 +162,7 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
if s.auditService != nil {
s.auditService.Record(AuditEntry{
Category: "auth", Action: "login_failed",
Detail: fmt.Sprintf("用户名不存在: %s", strings.TrimSpace(input.Username)),
Detail: fmt.Sprintf("用户名不存在: %s", strings.TrimSpace(input.Username)),
ClientIP: clientKey,
})
}
@@ -156,6 +188,20 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
}
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err)
}
mfaRequired := userMFAEnabled(user)
trustedDeviceUsed := false
if mfaRequired {
trusted, err := s.verifyTrustedDevice(ctx, user, input.TrustedDeviceToken, clientKey)
if err != nil {
return nil, err
}
trustedDeviceUsed = trusted
if !trusted {
if err := s.verifyLoginMFA(ctx, user, input, clientKey); err != nil {
return nil, err
}
}
}
s.rateLimiter.Reset(clientKey)
token, err := s.jwtManager.Generate(user)
@@ -163,6 +209,16 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
return nil, apperror.Internal("AUTH_TOKEN_FAILED", "无法生成访问令牌", err)
}
payload := &AuthPayload{Token: token, User: ToUserOutput(user)}
if mfaRequired && !trustedDeviceUsed && input.RememberDevice {
deviceToken, device, err := s.issueTrustedDevice(ctx, user, input.TrustedDeviceName, clientKey)
if err != nil {
return nil, err
}
payload.TrustedDeviceToken = deviceToken
payload.TrustedDevice = device
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
@@ -171,10 +227,72 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
})
}
return &AuthPayload{Token: token, User: ToUserOutput(user)}, nil
return payload, nil
}
func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) {
func (s *AuthService) verifyLoginMFA(ctx context.Context, user *model.User, input LoginInput, clientKey string) error {
if input.WebAuthnAssertion != nil {
if err := s.VerifyWebAuthnLogin(ctx, user, *input.WebAuthnAssertion, clientKey); err != nil {
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "login_failed",
Detail: "通行密钥校验失败", ClientIP: clientKey,
})
}
return err
}
return nil
}
code := strings.TrimSpace(input.TwoFactorCode)
if code == "" {
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "two_factor_required",
Detail: "登录需要多因素验证", ClientIP: clientKey,
})
}
return apperror.Unauthorized("AUTH_2FA_REQUIRED", "请输入验证码、恢复码或使用通行密钥", nil)
}
if user.TwoFactorEnabled {
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
if err != nil {
return apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
}
ok, err := security.ValidateTOTPCode(secret, code)
if err == nil && ok {
return nil
}
if consumed, err := s.consumeRecoveryCode(ctx, user, code); err != nil {
return err
} else if consumed {
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "two_factor_recovery_code_used",
Detail: "使用恢复码完成登录", ClientIP: clientKey,
})
}
return nil
}
}
if consumed, err := s.consumeOutOfBandOTP(ctx, user, code, clientKey); err != nil {
return err
} else if consumed {
return nil
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "login_failed",
Detail: "多因素验证码错误", ClientIP: clientKey,
})
}
return apperror.Unauthorized("AUTH_2FA_INVALID", "验证码、恢复码或通行密钥错误", nil)
}
func (s *AuthService) userBySubject(ctx context.Context, subject string) (*model.User, error) {
userID, err := strconv.ParseUint(subject, 10, 64)
if err != nil {
return nil, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err)
@@ -186,6 +304,14 @@ func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*User
if user == nil {
return nil, apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found"))
}
return user, nil
}
func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
return ToUserOutput(user), nil
}
@@ -195,16 +321,9 @@ type ChangePasswordInput struct {
}
func (s *AuthService) ChangePassword(ctx context.Context, subject string, input ChangePasswordInput) error {
userID, err := strconv.ParseUint(subject, 10, 64)
user, err := s.userBySubject(ctx, subject)
if err != nil {
return apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err)
}
user, err := s.users.FindByID(ctx, uint(userID))
if err != nil {
return apperror.Internal("AUTH_LOOKUP_FAILED", "无法获取当前用户", err)
}
if user == nil {
return apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found"))
return err
}
if err := security.ComparePassword(user.PasswordHash, input.OldPassword); err != nil {
return apperror.BadRequest("AUTH_WRONG_PASSWORD", "旧密码不正确", err)
@@ -214,6 +333,9 @@ func (s *AuthService) ChangePassword(ctx context.Context, subject string, input
return apperror.Internal("AUTH_HASH_FAILED", "无法处理密码", err)
}
user.PasswordHash = hash
user.TrustedDevices = ""
user.OutOfBandOTPCiphertext = ""
user.WebAuthnChallengeCiphertext = ""
if err := s.users.Update(ctx, user); err != nil {
return apperror.Internal("AUTH_UPDATE_FAILED", "密码修改失败", err)
}
@@ -229,15 +351,338 @@ func (s *AuthService) ChangePassword(ctx context.Context, subject string, input
return nil
}
type TwoFactorSetupInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
}
type TwoFactorSetupOutput struct {
Secret string `json:"secret"`
OTPAuthURL string `json:"otpAuthUrl"`
QRCodeDataURL string `json:"qrCodeDataUrl"`
TwoFactorEnabled bool `json:"twoFactorEnabled"`
TwoFactorConfirmed bool `json:"twoFactorConfirmed"`
}
type EnableTwoFactorInput struct {
Code string `json:"code" binding:"required,min=6,max=10"`
}
type EnableTwoFactorOutput struct {
User *UserOutput `json:"user"`
RecoveryCodes []string `json:"recoveryCodes"`
}
type DisableTwoFactorInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
Code string `json:"code" binding:"required,min=6,max=32"`
}
type RegenerateRecoveryCodesInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
Code string `json:"code" binding:"required,min=6,max=10"`
}
type RecoveryCodesOutput struct {
User *UserOutput `json:"user"`
RecoveryCodes []string `json:"recoveryCodes"`
}
func (s *AuthService) PrepareTwoFactor(ctx context.Context, subject string, input TwoFactorSetupInput) (*TwoFactorSetupOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if user.TwoFactorEnabled {
return nil, apperror.Conflict("AUTH_2FA_ALREADY_ENABLED", "TOTP 已启用", nil)
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
enrollment, err := security.GenerateTOTPEnrollment(user.Username)
if err != nil {
return nil, apperror.Internal("AUTH_2FA_SETUP_FAILED", "无法生成 TOTP 密钥", err)
}
ciphertext, err := s.encryptTwoFactorSecret(enrollment.Secret)
if err != nil {
return nil, apperror.Internal("AUTH_2FA_SAVE_FAILED", "无法保存 TOTP 密钥", err)
}
user.TwoFactorSecretCiphertext = ciphertext
user.TwoFactorEnabled = false
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_2FA_SAVE_FAILED", "无法保存 TOTP 密钥", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "two_factor_setup",
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
Detail: "生成 TOTP 密钥",
})
}
return &TwoFactorSetupOutput{
Secret: enrollment.Secret,
OTPAuthURL: enrollment.OTPAuthURL,
QRCodeDataURL: enrollment.QRCodeDataURL,
TwoFactorEnabled: false,
TwoFactorConfirmed: false,
}, nil
}
func (s *AuthService) EnableTwoFactor(ctx context.Context, subject string, input EnableTwoFactorInput) (*EnableTwoFactorOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if user.TwoFactorEnabled {
return nil, apperror.Conflict("AUTH_2FA_ALREADY_ENABLED", "TOTP 已启用", nil)
}
if strings.TrimSpace(user.TwoFactorSecretCiphertext) == "" {
return nil, apperror.BadRequest("AUTH_2FA_NOT_PREPARED", "请先生成 TOTP 密钥", nil)
}
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
if err != nil {
return nil, apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
}
ok, err := security.ValidateTOTPCode(secret, input.Code)
if err != nil {
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err)
}
if !ok {
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil)
}
recoveryCodes, recoveryHashes, err := s.generateRecoveryCodeHashes()
if err != nil {
return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法生成恢复码", err)
}
user.TwoFactorEnabled = true
user.TwoFactorRecoveryCodeHashes = recoveryHashes
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_2FA_ENABLE_FAILED", "无法启用 TOTP", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "two_factor_enable",
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
Detail: "启用 TOTP",
})
}
return &EnableTwoFactorOutput{User: ToUserOutput(user), RecoveryCodes: recoveryCodes}, nil
}
func (s *AuthService) DisableTwoFactor(ctx context.Context, subject string, input DisableTwoFactorInput) (*UserOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if !user.TwoFactorEnabled {
return nil, apperror.BadRequest("AUTH_2FA_NOT_ENABLED", "TOTP 未启用", nil)
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
if err != nil {
return nil, apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
}
ok, err := security.ValidateTOTPCode(secret, input.Code)
if err != nil {
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err)
}
if !ok {
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil)
}
user.TwoFactorEnabled = false
user.TwoFactorSecretCiphertext = ""
user.TwoFactorRecoveryCodeHashes = ""
clearTrustedDevicesIfMFAOff(user)
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_2FA_DISABLE_FAILED", "无法关闭 TOTP", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "two_factor_disable",
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
Detail: "关闭 TOTP",
})
}
return ToUserOutput(user), nil
}
func (s *AuthService) verifyCurrentTOTP(user *model.User, code string) error {
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
if err != nil {
return apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
}
ok, err := security.ValidateTOTPCode(secret, code)
if err != nil {
return apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err)
}
if !ok {
return apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil)
}
return nil
}
func (s *AuthService) RegenerateRecoveryCodes(ctx context.Context, subject string, input RegenerateRecoveryCodesInput) (*RecoveryCodesOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if !user.TwoFactorEnabled {
return nil, apperror.BadRequest("AUTH_2FA_NOT_ENABLED", "TOTP 未启用", nil)
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
if err := s.verifyCurrentTOTP(user, input.Code); err != nil {
return nil, err
}
recoveryCodes, recoveryHashes, err := s.generateRecoveryCodeHashes()
if err != nil {
return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法生成恢复码", err)
}
user.TwoFactorRecoveryCodeHashes = recoveryHashes
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法更新恢复码", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "two_factor_recovery_codes_regenerate",
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
Detail: "重新生成 TOTP 恢复码",
})
}
return &RecoveryCodesOutput{User: ToUserOutput(user), RecoveryCodes: recoveryCodes}, nil
}
func (s *AuthService) generateRecoveryCodeHashes() ([]string, string, error) {
codes, err := security.GenerateRecoveryCodes(security.RecoveryCodeCount)
if err != nil {
return nil, "", err
}
hashes := make([]string, 0, len(codes))
for _, code := range codes {
hash, err := security.HashPassword(security.NormalizeRecoveryCode(code))
if err != nil {
return nil, "", err
}
hashes = append(hashes, hash)
}
encoded, err := encodeRecoveryCodeHashes(hashes)
if err != nil {
return nil, "", err
}
return codes, encoded, nil
}
func (s *AuthService) consumeRecoveryCode(ctx context.Context, user *model.User, code string) (bool, error) {
if !security.IsRecoveryCodeCandidate(code) {
return false, nil
}
hashes, err := parseRecoveryCodeHashes(user.TwoFactorRecoveryCodeHashes)
if err != nil {
return false, apperror.Internal("AUTH_2FA_RECOVERY_INVALID", "恢复码配置异常", err)
}
if len(hashes) == 0 {
return false, nil
}
normalized := security.NormalizeRecoveryCode(code)
for i, hash := range hashes {
if security.ComparePassword(hash, normalized) != nil {
continue
}
hashes = append(hashes[:i], hashes[i+1:]...)
encoded, err := encodeRecoveryCodeHashes(hashes)
if err != nil {
return false, apperror.Internal("AUTH_2FA_RECOVERY_INVALID", "恢复码配置异常", err)
}
user.TwoFactorRecoveryCodeHashes = encoded
if err := s.users.Update(ctx, user); err != nil {
return false, apperror.Internal("AUTH_2FA_RECOVERY_CONSUME_FAILED", "无法使用恢复码", err)
}
return true, nil
}
return false, nil
}
func (s *AuthService) encryptTwoFactorSecret(secret string) (string, error) {
if s.twoFactorCipher == nil {
return "", errors.New("two-factor cipher is not configured")
}
return s.twoFactorCipher.Encrypt([]byte(strings.TrimSpace(secret)))
}
func (s *AuthService) decryptTwoFactorSecret(ciphertext string) (string, error) {
if s.twoFactorCipher == nil {
return "", errors.New("two-factor cipher is not configured")
}
raw, err := s.twoFactorCipher.Decrypt(strings.TrimSpace(ciphertext))
if err != nil {
return "", err
}
return strings.TrimSpace(string(raw)), nil
}
func parseRecoveryCodeHashes(encoded string) ([]string, error) {
if strings.TrimSpace(encoded) == "" {
return nil, nil
}
var hashes []string
if err := json.Unmarshal([]byte(encoded), &hashes); err != nil {
return nil, err
}
return hashes, nil
}
func encodeRecoveryCodeHashes(hashes []string) (string, error) {
if len(hashes) == 0 {
return "", nil
}
encoded, err := json.Marshal(hashes)
if err != nil {
return "", err
}
return string(encoded), nil
}
func recoveryCodeRemainingCount(user *model.User) int {
if user == nil {
return 0
}
hashes, err := parseRecoveryCodeHashes(user.TwoFactorRecoveryCodeHashes)
if err != nil {
return 0
}
return len(hashes)
}
func ToUserOutput(user *model.User) *UserOutput {
if user == nil {
return nil
}
return &UserOutput{
ID: user.ID,
Username: user.Username,
DisplayName: user.DisplayName,
Role: user.Role,
ID: user.ID,
Username: user.Username,
DisplayName: user.DisplayName,
Email: user.Email,
Phone: user.Phone,
Role: user.Role,
MFAEnabled: userMFAEnabled(user),
TwoFactorEnabled: user.TwoFactorEnabled,
TwoFactorRecoveryCodesRemaining: recoveryCodeRemainingCount(user),
WebAuthnEnabled: webAuthnCredentialCount(user) > 0,
WebAuthnCredentialCount: webAuthnCredentialCount(user),
TrustedDeviceCount: trustedDeviceCount(user),
EmailOTPEnabled: user.EmailOTPEnabled,
SMSOTPEnabled: user.SMSOTPEnabled,
}
}

View File

@@ -5,8 +5,11 @@ import (
"testing"
"time"
"backupx/server/internal/apperror"
"backupx/server/internal/model"
"backupx/server/internal/security"
"backupx/server/internal/storage/codec"
"github.com/pquerna/otp/totp"
)
type fakeUserRepository struct {
@@ -100,6 +103,7 @@ func TestAuthServiceSetupAndLogin(t *testing.T) {
&fakeSystemConfigRepository{},
security.NewJWTManager("test-secret", time.Hour),
security.NewLoginRateLimiter(5, time.Minute),
codec.NewConfigCipher("test-encryption-secret"),
)
setupResult, err := service.Setup(context.Background(), SetupInput{
@@ -133,6 +137,7 @@ func newTestAuthService() (*AuthService, *fakeUserRepository) {
&fakeSystemConfigRepository{},
security.NewJWTManager("test-secret", time.Hour),
security.NewLoginRateLimiter(5, time.Minute),
codec.NewConfigCipher("test-encryption-secret"),
)
return svc, users
}
@@ -188,3 +193,425 @@ func TestChangePasswordWrongOld(t *testing.T) {
t.Fatalf("expected ChangePassword with wrong old password to fail")
}
}
func TestAuthServiceLoginRequiresTwoFactorWhenEnabled(t *testing.T) {
svc, _ := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
CurrentPassword: "password-123",
})
if err != nil {
t.Fatalf("PrepareTwoFactor: %v", err)
}
if setup.Secret == "" || setup.QRCodeDataURL == "" || setup.OTPAuthURL == "" {
t.Fatalf("expected populated 2FA enrollment, got %#v", setup)
}
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode: %v", err)
}
enabledUser, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code})
if err != nil {
t.Fatalf("EnableTwoFactor: %v", err)
}
if !enabledUser.User.TwoFactorEnabled {
t.Fatalf("expected 2FA enabled")
}
if len(enabledUser.RecoveryCodes) != security.RecoveryCodeCount {
t.Fatalf("expected %d recovery codes, got %d", security.RecoveryCodeCount, len(enabledUser.RecoveryCodes))
}
_, err = svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123",
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_REQUIRED" {
t.Fatalf("expected AUTH_2FA_REQUIRED, got %v", err)
}
loginCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode login: %v", err)
}
loginResult, err := svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: loginCode,
}, "127.0.0.1")
if err != nil {
t.Fatalf("Login with 2FA: %v", err)
}
if loginResult.Token == "" {
t.Fatalf("expected non-empty token")
}
}
func TestAuthServiceDisableTwoFactor(t *testing.T) {
svc, _ := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
CurrentPassword: "password-123",
})
if err != nil {
t.Fatalf("PrepareTwoFactor: %v", err)
}
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode: %v", err)
}
if _, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}); err != nil {
t.Fatalf("EnableTwoFactor: %v", err)
}
disableCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode disable: %v", err)
}
user, err := svc.DisableTwoFactor(context.Background(), "1", DisableTwoFactorInput{
CurrentPassword: "password-123",
Code: disableCode,
})
if err != nil {
t.Fatalf("DisableTwoFactor: %v", err)
}
if user.TwoFactorEnabled {
t.Fatalf("expected 2FA disabled")
}
loginResult, err := svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123",
}, "127.0.0.1")
if err != nil {
t.Fatalf("Login after disable: %v", err)
}
if loginResult.Token == "" {
t.Fatalf("expected non-empty token")
}
}
func TestAuthServiceRecoveryCodeLoginConsumesCode(t *testing.T) {
svc, _ := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
CurrentPassword: "password-123",
})
if err != nil {
t.Fatalf("PrepareTwoFactor: %v", err)
}
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode: %v", err)
}
enabled, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code})
if err != nil {
t.Fatalf("EnableTwoFactor: %v", err)
}
recoveryCode := enabled.RecoveryCodes[0]
loginResult, err := svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: recoveryCode,
}, "127.0.0.1")
if err != nil {
t.Fatalf("Login with recovery code: %v", err)
}
if loginResult.User.TwoFactorRecoveryCodesRemaining != security.RecoveryCodeCount-1 {
t.Fatalf("expected one recovery code consumed, got remaining=%d", loginResult.User.TwoFactorRecoveryCodesRemaining)
}
_, err = svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: recoveryCode,
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
t.Fatalf("expected consumed recovery code to fail, got %v", err)
}
}
func TestAuthServiceRegenerateRecoveryCodesInvalidatesOldCodes(t *testing.T) {
svc, _ := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
CurrentPassword: "password-123",
})
if err != nil {
t.Fatalf("PrepareTwoFactor: %v", err)
}
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode: %v", err)
}
enabled, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code})
if err != nil {
t.Fatalf("EnableTwoFactor: %v", err)
}
oldRecoveryCode := enabled.RecoveryCodes[0]
regenerateCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode regenerate: %v", err)
}
regenerated, err := svc.RegenerateRecoveryCodes(context.Background(), "1", RegenerateRecoveryCodesInput{
CurrentPassword: "password-123",
Code: regenerateCode,
})
if err != nil {
t.Fatalf("RegenerateRecoveryCodes: %v", err)
}
if len(regenerated.RecoveryCodes) != security.RecoveryCodeCount {
t.Fatalf("expected %d recovery codes, got %d", security.RecoveryCodeCount, len(regenerated.RecoveryCodes))
}
_, err = svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: oldRecoveryCode,
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
t.Fatalf("expected old recovery code to fail, got %v", err)
}
}
func TestAuthServiceTrustedDeviceSkipsMFA(t *testing.T) {
svc, repo := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
CurrentPassword: "password-123",
})
if err != nil {
t.Fatalf("PrepareTwoFactor: %v", err)
}
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode: %v", err)
}
if _, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}); err != nil {
t.Fatalf("EnableTwoFactor: %v", err)
}
loginCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode login: %v", err)
}
firstLogin, err := svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: loginCode,
RememberDevice: true, TrustedDeviceName: "test browser",
}, "127.0.0.1")
if err != nil {
t.Fatalf("Login with 2FA: %v", err)
}
if firstLogin.TrustedDeviceToken == "" || firstLogin.TrustedDevice == nil {
t.Fatalf("expected trusted device token")
}
secondLogin, err := svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TrustedDeviceToken: firstLogin.TrustedDeviceToken,
}, "127.0.0.1")
if err != nil {
t.Fatalf("Login with trusted device: %v", err)
}
if secondLogin.Token == "" {
t.Fatalf("expected token")
}
disableCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
if err != nil {
t.Fatalf("GenerateCode disable: %v", err)
}
if _, err := svc.DisableTwoFactor(context.Background(), "1", DisableTwoFactorInput{
CurrentPassword: "password-123",
Code: disableCode,
}); err != nil {
t.Fatalf("DisableTwoFactor: %v", err)
}
if repo.users[0].TrustedDevices != "" {
t.Fatalf("expected trusted devices cleared after disabling last MFA method")
}
}
func TestAuthServiceOutOfBandOTPLoginConsumesCode(t *testing.T) {
svc, repo := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
user := repo.users[0]
user.Email = "admin@example.com"
user.EmailOTPEnabled = true
hash, err := security.HashPassword("123456")
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{
Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute),
})
if err != nil {
t.Fatalf("EncryptJSON: %v", err)
}
user.OutOfBandOTPCiphertext = ciphertext
if err := repo.Update(context.Background(), user); err != nil {
t.Fatalf("Update: %v", err)
}
loginResult, err := svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: "123456",
}, "127.0.0.1")
if err != nil {
t.Fatalf("Login with email OTP: %v", err)
}
if loginResult.Token == "" {
t.Fatalf("expected token")
}
if repo.users[0].OutOfBandOTPCiphertext != "" {
t.Fatalf("expected OTP to be consumed")
}
_, err = svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: "123456",
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
t.Fatalf("expected consumed OTP to fail, got %v", err)
}
}
func TestAuthServiceMFAStartIsRateLimited(t *testing.T) {
svc, repo := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
repo.users[0].Email = "admin@example.com"
repo.users[0].EmailOTPEnabled = true
for i := 0; i < 5; i++ {
_ = svc.SendLoginOTP(context.Background(), LoginOTPInput{
Username: "admin", Password: "wrong-password", Channel: "email",
}, "127.0.0.1")
}
err = svc.SendLoginOTP(context.Background(), LoginOTPInput{
Username: "admin", Password: "wrong-password", Channel: "email",
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_RATE_LIMITED" {
t.Fatalf("expected AUTH_RATE_LIMITED, got %v", err)
}
}
func TestAuthServiceDisabledOTPChannelCannotConsumePendingCode(t *testing.T) {
svc, repo := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
user := repo.users[0]
user.Email = "admin@example.com"
user.EmailOTPEnabled = false
user.SMSOTPEnabled = true
hash, err := security.HashPassword("123456")
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{
Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute),
})
if err != nil {
t.Fatalf("EncryptJSON: %v", err)
}
user.OutOfBandOTPCiphertext = ciphertext
if err := repo.Update(context.Background(), user); err != nil {
t.Fatalf("Update: %v", err)
}
_, err = svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123", TwoFactorCode: "123456",
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
t.Fatalf("expected disabled OTP channel to fail, got %v", err)
}
if repo.users[0].OutOfBandOTPCiphertext != "" {
t.Fatalf("expected disabled channel OTP to be cleared")
}
}
func TestAuthServiceChangingOTPRecipientClearsPendingCode(t *testing.T) {
svc, repo := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
user := repo.users[0]
user.Email = "old@example.com"
user.EmailOTPEnabled = true
hash, err := security.HashPassword("123456")
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{
Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute),
})
if err != nil {
t.Fatalf("EncryptJSON: %v", err)
}
user.OutOfBandOTPCiphertext = ciphertext
if err := repo.Update(context.Background(), user); err != nil {
t.Fatalf("Update: %v", err)
}
updated, err := svc.ConfigureOutOfBandOTP(context.Background(), "1", OTPConfigInput{
CurrentPassword: "password-123",
Channel: "email",
Enabled: true,
Email: "new@example.com",
})
if err != nil {
t.Fatalf("ConfigureOutOfBandOTP: %v", err)
}
if updated.Email != "new@example.com" {
t.Fatalf("expected email updated, got %q", updated.Email)
}
if repo.users[0].OutOfBandOTPCiphertext != "" {
t.Fatalf("expected pending email OTP to be cleared after recipient change")
}
}
func TestAuthServiceCorruptWebAuthnCredentialsStillRequireMFA(t *testing.T) {
svc, repo := newTestAuthService()
_, err := svc.Setup(context.Background(), SetupInput{
Username: "admin", Password: "password-123", DisplayName: "Admin",
})
if err != nil {
t.Fatalf("Setup: %v", err)
}
repo.users[0].WebAuthnCredentials = "{invalid-json"
_, err = svc.Login(context.Background(), LoginInput{
Username: "admin", Password: "password-123",
}, "127.0.0.1")
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_REQUIRED" {
t.Fatalf("expected corrupt WebAuthn credentials to require MFA, got %v", err)
}
}

View File

@@ -0,0 +1,221 @@
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"time"
"backupx/server/internal/apperror"
"backupx/server/internal/model"
"backupx/server/internal/security"
)
func (s *AuthService) ListTrustedDevices(ctx context.Context, subject string) ([]TrustedDeviceOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
devices, err := parseTrustedDevices(user.TrustedDevices)
if err != nil {
return nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
now := time.Now().UTC()
output := make([]TrustedDeviceOutput, 0, len(devices))
for _, device := range devices {
if device.ExpiresAt.Before(now) {
continue
}
output = append(output, toTrustedDeviceOutput(device))
}
return output, nil
}
type TrustedDeviceRevokeInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
}
func (s *AuthService) RevokeTrustedDevice(ctx context.Context, subject string, id string, input TrustedDeviceRevokeInput) error {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return err
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
devices, err := parseTrustedDevices(user.TrustedDevices)
if err != nil {
return apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
found := false
filtered := make([]TrustedDeviceRecord, 0, len(devices))
for _, device := range devices {
if device.ID == strings.TrimSpace(id) {
found = true
} else {
filtered = append(filtered, device)
}
}
if !found {
return apperror.New(404, "AUTH_TRUSTED_DEVICE_NOT_FOUND", "可信设备不存在", nil)
}
encoded, err := encodeTrustedDevices(filtered)
if err != nil {
return apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
user.TrustedDevices = encoded
if err := s.users.Update(ctx, user); err != nil {
return apperror.Internal("AUTH_TRUSTED_DEVICE_REVOKE_FAILED", "无法移除可信设备", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "trusted_device_revoke",
TargetType: "trusted_device", TargetID: strings.TrimSpace(id),
Detail: "移除可信设备",
})
}
return nil
}
func (s *AuthService) verifyTrustedDevice(ctx context.Context, user *model.User, token string, clientKey string) (bool, error) {
token = strings.TrimSpace(token)
if token == "" {
return false, nil
}
devices, err := parseTrustedDevices(user.TrustedDevices)
if err != nil {
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
now := time.Now().UTC()
hash := trustedDeviceTokenHash(token)
changed := false
for i := range devices {
device := &devices[i]
if device.ExpiresAt.Before(now) {
changed = true
continue
}
if subtle.ConstantTimeCompare([]byte(device.TokenHash), []byte(hash)) != 1 {
continue
}
device.LastUsedAt = now
device.LastIP = clientKey
changed = true
encoded, err := encodeTrustedDevices(filterActiveTrustedDevices(devices, now))
if err != nil {
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
user.TrustedDevices = encoded
if err := s.users.Update(ctx, user); err != nil {
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_UPDATE_FAILED", "无法更新可信设备", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "trusted_device_used",
TargetType: "trusted_device", TargetID: device.ID, TargetName: device.Name,
Detail: "使用可信设备跳过多因素验证", ClientIP: clientKey,
})
}
return true, nil
}
if changed {
encoded, err := encodeTrustedDevices(filterActiveTrustedDevices(devices, now))
if err != nil {
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
user.TrustedDevices = encoded
if err := s.users.Update(ctx, user); err != nil {
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_UPDATE_FAILED", "无法更新可信设备", err)
}
}
return false, nil
}
func (s *AuthService) issueTrustedDevice(ctx context.Context, user *model.User, name string, clientKey string) (string, *TrustedDeviceOutput, error) {
token, err := randomURLToken(32)
if err != nil {
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法生成可信设备令牌", err)
}
id, err := randomURLToken(16)
if err != nil {
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法生成可信设备编号", err)
}
now := time.Now().UTC()
deviceName := normalizeTrustedDeviceName(name)
device := TrustedDeviceRecord{
ID: id,
Name: deviceName,
TokenHash: trustedDeviceTokenHash(token),
CreatedAt: now,
LastUsedAt: now,
ExpiresAt: now.Add(trustedDeviceTTL),
LastIP: clientKey,
}
devices, err := parseTrustedDevices(user.TrustedDevices)
if err != nil {
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
devices = append(filterActiveTrustedDevices(devices, now), device)
if len(devices) > maxTrustedDevices {
devices = devices[len(devices)-maxTrustedDevices:]
}
encoded, err := encodeTrustedDevices(devices)
if err != nil {
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
}
user.TrustedDevices = encoded
if err := s.users.Update(ctx, user); err != nil {
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法保存可信设备", err)
}
output := toTrustedDeviceOutput(device)
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "trusted_device_create",
TargetType: "trusted_device", TargetID: device.ID, TargetName: device.Name,
Detail: fmt.Sprintf("添加可信设备,有效期至 %s", device.ExpiresAt.Format(time.RFC3339)), ClientIP: clientKey,
})
}
return token, &output, nil
}
func filterActiveTrustedDevices(devices []TrustedDeviceRecord, now time.Time) []TrustedDeviceRecord {
active := make([]TrustedDeviceRecord, 0, len(devices))
for _, device := range devices {
if device.ExpiresAt.After(now) {
active = append(active, device)
}
}
return active
}
func trustedDeviceTokenHash(token string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(token)))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
func randomURLToken(size int) (string, error) {
buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func normalizeTrustedDeviceName(name string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return "当前设备"
}
if len([]rune(trimmed)) <= maxTrustedDeviceName {
return trimmed
}
runes := []rune(trimmed)
return string(runes[:maxTrustedDeviceName])
}

View File

@@ -0,0 +1,366 @@
package service
import (
"context"
"fmt"
"strings"
"time"
"backupx/server/internal/apperror"
"backupx/server/internal/model"
"backupx/server/internal/security"
)
type WebAuthnRequestContext struct {
RPID string
Origin string
}
type WebAuthnRegistrationOptionsInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
}
type WebAuthnRegistrationFinishInput struct {
Name string `json:"name" binding:"omitempty,max=128"`
Credential security.WebAuthnRegistrationResponse `json:"credential" binding:"required"`
}
type WebAuthnCredentialDeleteInput struct {
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
}
type WebAuthnLoginOptionsInput struct {
Username string `json:"username" binding:"required,min=3,max=64"`
Password string `json:"password" binding:"required,min=8,max=128"`
}
type webAuthnPublicKeyCredentialParam struct {
Type string `json:"type"`
Alg int `json:"alg"`
}
type webAuthnRelyingParty struct {
Name string `json:"name"`
ID string `json:"id"`
}
type webAuthnUserEntity struct {
ID string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
}
type webAuthnCredentialDescriptor struct {
Type string `json:"type"`
ID string `json:"id"`
}
type webAuthnAuthenticatorSelection struct {
UserVerification string `json:"userVerification"`
}
type WebAuthnRegistrationOptions struct {
Challenge string `json:"challenge"`
RP webAuthnRelyingParty `json:"rp"`
User webAuthnUserEntity `json:"user"`
PubKeyCredParams []webAuthnPublicKeyCredentialParam `json:"pubKeyCredParams"`
Timeout int `json:"timeout"`
Attestation string `json:"attestation"`
AuthenticatorSelection webAuthnAuthenticatorSelection `json:"authenticatorSelection"`
ExcludeCredentials []webAuthnCredentialDescriptor `json:"excludeCredentials"`
}
type WebAuthnLoginOptions struct {
Challenge string `json:"challenge"`
RPID string `json:"rpId"`
Timeout int `json:"timeout"`
UserVerification string `json:"userVerification"`
AllowCredentials []webAuthnCredentialDescriptor `json:"allowCredentials"`
}
func (s *AuthService) BeginWebAuthnRegistration(ctx context.Context, subject string, input WebAuthnRegistrationOptionsInput, request WebAuthnRequestContext) (*WebAuthnRegistrationOptions, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
}
challenge, err := security.GenerateWebAuthnChallenge()
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法生成通行密钥挑战", err)
}
state := webAuthnChallengeState{
Type: "register",
Challenge: challenge,
RPID: request.RPID,
Origin: request.Origin,
ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL),
}
if err := s.saveWebAuthnChallenge(ctx, user, state); err != nil {
return nil, err
}
exclude := make([]webAuthnCredentialDescriptor, 0, len(credentials))
for _, credential := range credentials {
exclude = append(exclude, webAuthnCredentialDescriptor{Type: "public-key", ID: credential.CredentialID})
}
return &WebAuthnRegistrationOptions{
Challenge: challenge,
RP: webAuthnRelyingParty{Name: "BackupX", ID: request.RPID},
User: webAuthnUserEntity{
ID: security.EncodeBase64URL([]byte(fmt.Sprintf("%d", user.ID))),
Name: user.Username,
DisplayName: user.DisplayName,
},
PubKeyCredParams: []webAuthnPublicKeyCredentialParam{
{Type: "public-key", Alg: -7},
},
Timeout: int(mfaChallengeTTL / time.Millisecond),
Attestation: "none",
AuthenticatorSelection: webAuthnAuthenticatorSelection{UserVerification: "preferred"},
ExcludeCredentials: exclude,
}, nil
}
func (s *AuthService) FinishWebAuthnRegistration(ctx context.Context, subject string, input WebAuthnRegistrationFinishInput) (*UserOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
state, err := s.loadWebAuthnChallenge(user, "register")
if err != nil {
return nil, err
}
parsed, err := security.VerifyWebAuthnRegistration(input.Credential, state.Challenge, state.RPID, state.Origin)
if err != nil {
return nil, apperror.BadRequest("AUTH_WEBAUTHN_VERIFY_FAILED", "通行密钥注册校验失败", err)
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
}
for _, credential := range credentials {
if credential.CredentialID == parsed.CredentialID {
return nil, apperror.Conflict("AUTH_WEBAUTHN_EXISTS", "该通行密钥已注册", nil)
}
}
id, err := randomURLToken(16)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法生成通行密钥编号", err)
}
now := time.Now().UTC().Format(time.RFC3339)
name := strings.TrimSpace(input.Name)
if name == "" {
name = "通行密钥"
}
credentials = append(credentials, WebAuthnCredentialRecord{
ID: id,
Name: normalizeTrustedDeviceName(name),
CredentialID: parsed.CredentialID,
PublicKeyX: parsed.PublicKeyX,
PublicKeyY: parsed.PublicKeyY,
SignCount: parsed.SignCount,
CreatedAt: now,
})
encoded, err := encodeWebAuthnCredentials(credentials)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法保存通行密钥", err)
}
user.WebAuthnCredentials = encoded
user.WebAuthnChallengeCiphertext = ""
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法保存通行密钥", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "webauthn_register",
TargetType: "webauthn_credential", TargetID: id, TargetName: name,
Detail: "注册通行密钥",
})
}
return ToUserOutput(user), nil
}
func (s *AuthService) BeginWebAuthnLogin(ctx context.Context, input WebAuthnLoginOptionsInput, request WebAuthnRequestContext, clientKey string) (*WebAuthnLoginOptions, error) {
user, err := s.verifyPasswordForMFAStart(ctx, input.Username, input.Password, clientKey)
if err != nil {
return nil, err
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
}
if len(credentials) == 0 {
return nil, apperror.BadRequest("AUTH_WEBAUTHN_NOT_ENABLED", "当前账号未注册通行密钥", nil)
}
challenge, err := security.GenerateWebAuthnChallenge()
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法生成通行密钥挑战", err)
}
state := webAuthnChallengeState{
Type: "login",
Challenge: challenge,
RPID: request.RPID,
Origin: request.Origin,
ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL),
}
if err := s.saveWebAuthnChallenge(ctx, user, state); err != nil {
return nil, err
}
allowed := make([]webAuthnCredentialDescriptor, 0, len(credentials))
for _, credential := range credentials {
allowed = append(allowed, webAuthnCredentialDescriptor{Type: "public-key", ID: credential.CredentialID})
}
return &WebAuthnLoginOptions{
Challenge: challenge,
RPID: request.RPID,
Timeout: int(mfaChallengeTTL / time.Millisecond),
UserVerification: "preferred",
AllowCredentials: allowed,
}, nil
}
func (s *AuthService) VerifyWebAuthnLogin(ctx context.Context, user *model.User, assertion security.WebAuthnLoginAssertion, clientKey string) error {
state, err := s.loadWebAuthnChallenge(user, "login")
if err != nil {
return err
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
}
rawID := strings.TrimSpace(assertion.RawID)
if rawID == "" {
rawID = strings.TrimSpace(assertion.ID)
}
for i := range credentials {
credential := &credentials[i]
if credential.CredentialID != rawID {
continue
}
nextSignCount, err := security.VerifyWebAuthnAssertion(assertion, state.Challenge, state.RPID, state.Origin, security.WebAuthnCredentialMaterial{
CredentialID: credential.CredentialID,
PublicKeyX: credential.PublicKeyX,
PublicKeyY: credential.PublicKeyY,
SignCount: credential.SignCount,
})
if err != nil {
return apperror.Unauthorized("AUTH_WEBAUTHN_INVALID", "通行密钥校验失败", err)
}
credential.SignCount = nextSignCount
credential.LastUsedAt = time.Now().UTC().Format(time.RFC3339)
encoded, err := encodeWebAuthnCredentials(credentials)
if err != nil {
return apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err)
}
user.WebAuthnCredentials = encoded
user.WebAuthnChallengeCiphertext = ""
if err := s.users.Update(ctx, user); err != nil {
return apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "webauthn_used",
TargetType: "webauthn_credential", TargetID: credential.ID, TargetName: credential.Name,
Detail: "使用通行密钥完成多因素验证", ClientIP: clientKey,
})
}
return nil
}
return apperror.Unauthorized("AUTH_WEBAUTHN_INVALID", "通行密钥不存在", nil)
}
func (s *AuthService) ListWebAuthnCredentials(ctx context.Context, subject string) ([]WebAuthnCredentialOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
}
output := make([]WebAuthnCredentialOutput, 0, len(credentials))
for _, credential := range credentials {
output = append(output, toWebAuthnCredentialOutput(credential))
}
return output, nil
}
func (s *AuthService) DeleteWebAuthnCredential(ctx context.Context, subject string, id string, input WebAuthnCredentialDeleteInput) (*UserOutput, error) {
user, err := s.userBySubject(ctx, subject)
if err != nil {
return nil, err
}
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
}
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
}
found := false
filtered := make([]WebAuthnCredentialRecord, 0, len(credentials))
for _, credential := range credentials {
if credential.ID == strings.TrimSpace(id) {
found = true
} else {
filtered = append(filtered, credential)
}
}
if !found {
return nil, apperror.New(404, "AUTH_WEBAUTHN_NOT_FOUND", "通行密钥不存在", nil)
}
encoded, err := encodeWebAuthnCredentials(filtered)
if err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err)
}
user.WebAuthnCredentials = encoded
clearTrustedDevicesIfMFAOff(user)
if err := s.users.Update(ctx, user); err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_DELETE_FAILED", "无法删除通行密钥", err)
}
if s.auditService != nil {
s.auditService.Record(AuditEntry{
UserID: user.ID, Username: user.Username,
Category: "auth", Action: "webauthn_delete",
TargetType: "webauthn_credential", TargetID: strings.TrimSpace(id),
Detail: "删除通行密钥",
})
}
return ToUserOutput(user), nil
}
func (s *AuthService) saveWebAuthnChallenge(ctx context.Context, user *model.User, state webAuthnChallengeState) error {
ciphertext, err := s.twoFactorCipher.EncryptJSON(state)
if err != nil {
return apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法保存通行密钥挑战", err)
}
user.WebAuthnChallengeCiphertext = ciphertext
if err := s.users.Update(ctx, user); err != nil {
return apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法保存通行密钥挑战", err)
}
return nil
}
func (s *AuthService) loadWebAuthnChallenge(user *model.User, challengeType string) (*webAuthnChallengeState, error) {
if strings.TrimSpace(user.WebAuthnChallengeCiphertext) == "" {
return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_MISSING", "请先发起通行密钥验证", nil)
}
var state webAuthnChallengeState
if err := s.twoFactorCipher.DecryptJSON(user.WebAuthnChallengeCiphertext, &state); err != nil {
return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_INVALID", "通行密钥挑战状态异常", err)
}
if state.Type != challengeType {
return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_INVALID", "通行密钥挑战类型不匹配", nil)
}
if state.ExpiresAt.Before(time.Now().UTC()) {
return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_EXPIRED", "通行密钥挑战已过期", nil)
}
return &state, nil
}

View File

@@ -16,11 +16,11 @@ import (
)
type NotificationUpsertInput struct {
Name string `json:"name" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=email webhook telegram"`
Enabled bool `json:"enabled"`
OnSuccess bool `json:"onSuccess"`
OnFailure bool `json:"onFailure"`
Name string `json:"name" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=email webhook telegram"`
Enabled bool `json:"enabled"`
OnSuccess bool `json:"onSuccess"`
OnFailure bool `json:"onFailure"`
// EventTypes 订阅的扩展事件列表。与 OnSuccess/OnFailure 并存:
// - 两者均空时,订阅"备份成功/失败"对应原有语义(兼容)。
// - EventTypes 显式指定时优先按清单匹配。
@@ -186,8 +186,8 @@ func (s *NotificationService) NotifyBackupResult(ctx context.Context, event Back
// - eventType 对应 model.NotificationEvent* 常量,用于订阅匹配
//
// 订阅匹配规则:
// 1) notification.EventTypes 非空:必须包含 eventType
// 2) notification.EventTypes 为空:沿用 OnSuccess/OnFailure 开关(仅 backup_* 事件)
// 1. notification.EventTypes 非空:必须包含 eventType
// 2. notification.EventTypes 为空:沿用 OnSuccess/OnFailure 开关(仅 backup_* 事件)
func (s *NotificationService) DispatchEvent(ctx context.Context, eventType string, title string, body string, fields map[string]any) error {
// 同步广播到 SSE 订阅者(前端 Dashboard 实时推送)。
// 非阻塞:即便广播器未注入或订阅者已满也不影响 Notification 持久渠道。
@@ -209,6 +209,49 @@ func (s *NotificationService) DispatchEvent(ctx context.Context, eventType strin
return s.deliver(ctx, items, message)
}
func (s *NotificationService) SendAuthEmailOTP(ctx context.Context, to string, code string) error {
return s.sendFirstByType(ctx, "email", map[string]any{"to": strings.TrimSpace(to)}, notify.Message{
Title: "BackupX 登录验证码",
Body: fmt.Sprintf("您的 BackupX 登录验证码为:%s\n验证码 5 分钟内有效。若非本人操作,请立即检查账号安全。", code),
Fields: map[string]any{
"purpose": "login_otp",
},
})
}
func (s *NotificationService) SendAuthSMSOTP(ctx context.Context, phone string, code string) error {
return s.sendFirstByType(ctx, "webhook", nil, notify.Message{
Title: "BackupX 登录验证码",
Body: fmt.Sprintf("BackupX 登录验证码:%s5 分钟内有效。", code),
Fields: map[string]any{
"phone": strings.TrimSpace(phone),
"code": code,
"purpose": "login_otp",
},
})
}
func (s *NotificationService) sendFirstByType(ctx context.Context, notificationType string, override map[string]any, message notify.Message) error {
items, err := s.notifications.List(ctx)
if err != nil {
return err
}
for _, item := range items {
if !item.Enabled || item.Type != notificationType {
continue
}
configMap := map[string]any{}
if err := s.cipher.DecryptJSON(item.ConfigCiphertext, &configMap); err != nil {
return fmt.Errorf("decrypt notification %d config: %w", item.ID, err)
}
for key, value := range override {
configMap[key] = value
}
return s.registry.Send(ctx, item.Type, configMap, message)
}
return fmt.Errorf("no enabled %s notification configured", notificationType)
}
// collectSubscribers 按事件类型收集启用的订阅者。
// 列出启用通知后按事件类型再过滤(避免引入新 repository 方法)。
func (s *NotificationService) collectSubscribers(ctx context.Context, eventType string, fallbackSuccess bool) ([]model.Notification, error) {

View File

@@ -22,13 +22,22 @@ func NewUserService(users repository.UserRepository) *UserService {
// UserSummary 用户列表项(不含密码哈希)。
type UserSummary struct {
ID uint `json:"id"`
Username string `json:"username"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
Role string `json:"role"`
Disabled bool `json:"disabled"`
CreatedAt string `json:"createdAt"`
ID uint `json:"id"`
Username string `json:"username"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
Phone string `json:"phone"`
Role string `json:"role"`
Disabled bool `json:"disabled"`
MFAEnabled bool `json:"mfaEnabled"`
TwoFactorEnabled bool `json:"twoFactorEnabled"`
TwoFactorRecoveryCodesRemaining int `json:"twoFactorRecoveryCodesRemaining"`
WebAuthnEnabled bool `json:"webAuthnEnabled"`
WebAuthnCredentialCount int `json:"webAuthnCredentialCount"`
TrustedDeviceCount int `json:"trustedDeviceCount"`
EmailOTPEnabled bool `json:"emailOtpEnabled"`
SMSOTPEnabled bool `json:"smsOtpEnabled"`
CreatedAt string `json:"createdAt"`
}
// UserUpsertInput 创建/更新用户的输入。
@@ -37,6 +46,7 @@ type UserUpsertInput struct {
Password string `json:"password" binding:"omitempty,min=8,max=128"`
DisplayName string `json:"displayName" binding:"required,min=1,max=128"`
Email string `json:"email" binding:"omitempty,max=255"`
Phone string `json:"phone" binding:"omitempty,max=64"`
Role string `json:"role" binding:"required,oneof=admin operator viewer"`
Disabled bool `json:"disabled"`
}
@@ -76,6 +86,7 @@ func (s *UserService) Create(ctx context.Context, input UserUpsertInput) (*UserS
PasswordHash: hash,
DisplayName: strings.TrimSpace(input.DisplayName),
Email: strings.TrimSpace(input.Email),
Phone: strings.TrimSpace(input.Phone),
Role: input.Role,
Disabled: input.Disabled,
}
@@ -107,18 +118,43 @@ func (s *UserService) Update(ctx context.Context, id uint, input UserUpsertInput
return nil, apperror.Conflict("USER_USERNAME_EXISTS", "用户名已存在", nil)
}
}
passwordChanged := strings.TrimSpace(input.Password) != ""
disabledChanged := input.Disabled && !existing.Disabled
emailChanged := strings.TrimSpace(input.Email) != strings.TrimSpace(existing.Email)
phoneChanged := strings.TrimSpace(input.Phone) != strings.TrimSpace(existing.Phone)
existing.Username = strings.TrimSpace(input.Username)
existing.DisplayName = strings.TrimSpace(input.DisplayName)
existing.Email = strings.TrimSpace(input.Email)
existing.Phone = strings.TrimSpace(input.Phone)
existing.Role = input.Role
existing.Disabled = input.Disabled
if strings.TrimSpace(input.Password) != "" {
if passwordChanged {
hash, err := security.HashPassword(input.Password)
if err != nil {
return nil, apperror.Internal("USER_HASH_FAILED", "无法处理密码", err)
}
existing.PasswordHash = hash
existing.TrustedDevices = ""
existing.OutOfBandOTPCiphertext = ""
existing.WebAuthnChallengeCiphertext = ""
}
if strings.TrimSpace(existing.Email) == "" && existing.EmailOTPEnabled {
existing.EmailOTPEnabled = false
existing.OutOfBandOTPCiphertext = ""
}
if strings.TrimSpace(existing.Phone) == "" && existing.SMSOTPEnabled {
existing.SMSOTPEnabled = false
existing.OutOfBandOTPCiphertext = ""
}
if emailChanged || phoneChanged {
existing.OutOfBandOTPCiphertext = ""
}
if disabledChanged {
existing.TrustedDevices = ""
existing.OutOfBandOTPCiphertext = ""
existing.WebAuthnChallengeCiphertext = ""
}
clearTrustedDevicesIfMFAOff(existing)
if err := s.users.Update(ctx, existing); err != nil {
return nil, apperror.Internal("USER_UPDATE_FAILED", "无法更新用户", err)
}
@@ -147,14 +183,47 @@ func (s *UserService) Delete(ctx context.Context, id uint) error {
return s.users.Delete(ctx, id)
}
func (s *UserService) ResetTwoFactor(ctx context.Context, id uint) (*UserSummary, error) {
existing, err := s.users.FindByID(ctx, id)
if err != nil {
return nil, apperror.Internal("USER_GET_FAILED", "无法获取用户", err)
}
if existing == nil {
return nil, apperror.New(404, "USER_NOT_FOUND", "用户不存在", nil)
}
existing.TwoFactorEnabled = false
existing.TwoFactorSecretCiphertext = ""
existing.TwoFactorRecoveryCodeHashes = ""
existing.WebAuthnCredentials = ""
existing.WebAuthnChallengeCiphertext = ""
existing.TrustedDevices = ""
existing.EmailOTPEnabled = false
existing.SMSOTPEnabled = false
existing.OutOfBandOTPCiphertext = ""
if err := s.users.Update(ctx, existing); err != nil {
return nil, apperror.Internal("USER_2FA_RESET_FAILED", "无法重置 MFA", err)
}
summary := toUserSummary(existing)
return &summary, nil
}
func toUserSummary(u *model.User) UserSummary {
return UserSummary{
ID: u.ID,
Username: u.Username,
DisplayName: u.DisplayName,
Email: u.Email,
Role: u.Role,
Disabled: u.Disabled,
CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
ID: u.ID,
Username: u.Username,
DisplayName: u.DisplayName,
Email: u.Email,
Phone: u.Phone,
Role: u.Role,
Disabled: u.Disabled,
MFAEnabled: userMFAEnabled(u),
TwoFactorEnabled: u.TwoFactorEnabled,
TwoFactorRecoveryCodesRemaining: recoveryCodeRemainingCount(u),
WebAuthnEnabled: webAuthnCredentialCount(u) > 0,
WebAuthnCredentialCount: webAuthnCredentialCount(u),
TrustedDeviceCount: trustedDeviceCount(u),
EmailOTPEnabled: u.EmailOTPEnabled,
SMSOTPEnabled: u.SMSOTPEnabled,
CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}

View File

@@ -0,0 +1,124 @@
package service
import (
"context"
"testing"
"backupx/server/internal/model"
"backupx/server/internal/security"
)
func TestUserServiceUpdatePasswordClearsTrustedDeviceState(t *testing.T) {
hash, err := security.HashPassword("old-password")
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
repo := &fakeUserRepository{users: []*model.User{{
ID: 1,
Username: "admin",
PasswordHash: hash,
DisplayName: "Admin",
Email: "admin@example.com",
Role: model.UserRoleAdmin,
TwoFactorEnabled: true,
TrustedDevices: `[{"id":"device"}]`,
OutOfBandOTPCiphertext: "pending",
WebAuthnChallengeCiphertext: "challenge",
}}}
svc := NewUserService(repo)
if _, err := svc.Update(context.Background(), 1, UserUpsertInput{
Username: "admin",
Password: "new-password",
DisplayName: "Admin",
Email: "admin@example.com",
Role: model.UserRoleAdmin,
}); err != nil {
t.Fatalf("Update: %v", err)
}
updated := repo.users[0]
if security.ComparePassword(updated.PasswordHash, "new-password") != nil {
t.Fatalf("expected password hash to be updated")
}
if updated.TrustedDevices != "" || updated.OutOfBandOTPCiphertext != "" || updated.WebAuthnChallengeCiphertext != "" {
t.Fatalf("expected password update to clear trusted device state, got trusted=%q otp=%q challenge=%q", updated.TrustedDevices, updated.OutOfBandOTPCiphertext, updated.WebAuthnChallengeCiphertext)
}
}
func TestUserServiceUpdateContactClearsUnavailableOTP(t *testing.T) {
hash, err := security.HashPassword("password-123")
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
repo := &fakeUserRepository{users: []*model.User{{
ID: 1,
Username: "admin",
PasswordHash: hash,
DisplayName: "Admin",
Email: "admin@example.com",
Phone: "+15550000000",
Role: model.UserRoleAdmin,
EmailOTPEnabled: true,
SMSOTPEnabled: true,
TrustedDevices: `[{"id":"device"}]`,
OutOfBandOTPCiphertext: "pending",
}}}
svc := NewUserService(repo)
summary, err := svc.Update(context.Background(), 1, UserUpsertInput{
Username: "admin",
DisplayName: "Admin",
Role: model.UserRoleAdmin,
})
if err != nil {
t.Fatalf("Update: %v", err)
}
updated := repo.users[0]
if updated.EmailOTPEnabled || updated.SMSOTPEnabled || summary.MFAEnabled {
t.Fatalf("expected unavailable OTP channels to be disabled")
}
if updated.TrustedDevices != "" || updated.OutOfBandOTPCiphertext != "" || updated.WebAuthnChallengeCiphertext != "" {
t.Fatalf("expected last MFA removal to clear temporary state")
}
}
func TestUserServiceUpdateContactChangeClearsPendingOTP(t *testing.T) {
hash, err := security.HashPassword("password-123")
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
repo := &fakeUserRepository{users: []*model.User{{
ID: 1,
Username: "admin",
PasswordHash: hash,
DisplayName: "Admin",
Email: "old@example.com",
Role: model.UserRoleAdmin,
EmailOTPEnabled: true,
OutOfBandOTPCiphertext: "pending",
}}}
svc := NewUserService(repo)
summary, err := svc.Update(context.Background(), 1, UserUpsertInput{
Username: "admin",
DisplayName: "Admin",
Email: "new@example.com",
Role: model.UserRoleAdmin,
})
if err != nil {
t.Fatalf("Update: %v", err)
}
updated := repo.users[0]
if updated.Email != "new@example.com" || summary.Email != "new@example.com" {
t.Fatalf("expected email to be updated")
}
if !updated.EmailOTPEnabled {
t.Fatalf("expected email OTP to remain enabled")
}
if updated.OutOfBandOTPCiphertext != "" {
t.Fatalf("expected contact change to clear pending OTP")
}
}