From 7dfd12254b0accc6be345a8b8ccff74053a145d2 Mon Sep 17 00:00:00 2001 From: Awuqing <3184394176@qq.com> Date: Sat, 25 Apr 2026 21:14:39 +0800 Subject: [PATCH] feat: add complete MFA support --- server/go.mod | 2 +- server/internal/app/app.go | 53 +- server/internal/http/auth_handler.go | 270 ++++++++++ server/internal/http/install_flow_test.go | 10 +- server/internal/http/router.go | 22 +- server/internal/http/router_test.go | 18 +- server/internal/http/user_handler.go | 15 + server/internal/model/user.go | 25 +- server/internal/notify/sms.go | 64 +++ server/internal/security/otp_code.go | 23 + server/internal/security/recovery_code.go | 49 ++ server/internal/security/totp.go | 68 +++ server/internal/security/webauthn.go | 447 ++++++++++++++++ server/internal/service/auth_methods.go | 179 +++++++ server/internal/service/auth_otp.go | 252 +++++++++ server/internal/service/auth_service.go | 505 ++++++++++++++++-- server/internal/service/auth_service_test.go | 427 +++++++++++++++ .../internal/service/auth_trusted_device.go | 221 ++++++++ server/internal/service/auth_webauthn.go | 366 +++++++++++++ .../internal/service/notification_service.go | 57 +- server/internal/service/user_service.go | 99 +++- server/internal/service/user_service_test.go | 124 +++++ .../notifications/field-config.test.ts | 3 + .../components/notifications/field-config.ts | 7 + web/src/layouts/AppLayout.tsx | 385 ++++++++++++- web/src/pages/admin/UsersPage.tsx | 52 +- web/src/pages/audit/AuditLogsPage.tsx | 17 + web/src/pages/login/LoginPage.tsx | 115 +++- web/src/services/auth.ts | 229 +++++++- web/src/services/users.ts | 15 + web/src/stores/auth.ts | 4 + web/src/types/auth.ts | 15 + web/src/types/notifications.ts | 2 +- web/src/utils/webauthn.ts | 88 +++ 34 files changed, 4114 insertions(+), 114 deletions(-) create mode 100644 server/internal/notify/sms.go create mode 100644 server/internal/security/otp_code.go create mode 100644 server/internal/security/recovery_code.go create mode 100644 server/internal/security/totp.go create mode 100644 server/internal/security/webauthn.go create mode 100644 server/internal/service/auth_methods.go create mode 100644 server/internal/service/auth_otp.go create mode 100644 server/internal/service/auth_trusted_device.go create mode 100644 server/internal/service/auth_webauthn.go create mode 100644 server/internal/service/user_service_test.go create mode 100644 web/src/utils/webauthn.ts diff --git a/server/go.mod b/server/go.mod index 4f912e2..a3a289a 100644 --- a/server/go.mod +++ b/server/go.mod @@ -8,6 +8,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 github.com/natefinch/lumberjack v2.0.0+incompatible github.com/prometheus/client_golang v1.23.2 + github.com/pquerna/otp v1.5.0 github.com/rclone/rclone v1.73.3 github.com/robfig/cron/v3 v3.0.1 github.com/spf13/viper v1.20.0 @@ -181,7 +182,6 @@ require ( github.com/pkg/xattr v0.4.12 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect - github.com/pquerna/otp v1.5.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.2 // indirect github.com/prometheus/procfs v0.19.2 // indirect diff --git a/server/internal/app/app.go b/server/internal/app/app.go index 762afb6..e76c4b8 100644 --- a/server/internal/app/app.go +++ b/server/internal/app/app.go @@ -60,9 +60,9 @@ func New(ctx context.Context, cfg config.Config, version string) (*Application, jwtManager := security.NewJWTManager(resolvedSecurity.JWTSecret, config.MustJWTDuration(cfg.Security)) rateLimiter := security.NewLoginRateLimiter(5, time.Minute) - authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, rateLimiter) - systemService := service.NewSystemService(cfg, version, time.Now().UTC()) configCipher := codec.NewConfigCipher(resolvedSecurity.EncryptionKey) + authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, rateLimiter, configCipher) + systemService := service.NewSystemService(cfg, version, time.Now().UTC()) storageRegistry := storage.NewRegistry( storageRclone.NewLocalDiskFactory(), storageRclone.NewS3Factory(), @@ -85,8 +85,9 @@ func New(ctx context.Context, cfg config.Config, version string) (*Application, backupRunnerRegistry := backup.NewRegistry(backup.NewFileRunner(), backup.NewSQLiteRunner(), backup.NewMySQLRunner(nil), backup.NewPostgreSQLRunner(nil), backup.NewSAPHANARunner(nil)) logHub := backup.NewLogHub() retentionService := backupretention.NewService(backupRecordRepo) - notifyRegistry := notify.NewRegistry(notify.NewEmailNotifier(), notify.NewWebhookNotifier(), notify.NewTelegramNotifier()) + notifyRegistry := notify.NewRegistry(notify.NewEmailNotifier(), notify.NewWebhookNotifier(), notify.NewTelegramNotifier(), notify.NewSMSWebhookNotifier()) notificationService := service.NewNotificationService(notificationRepo, notifyRegistry, configCipher) + authService.SetNotificationService(notificationService) // 初始化 rclone 传输配置(重试 + 带宽限制) rcloneCtx := storageRclone.ConfiguredContext(ctx, storageRclone.TransferConfig{ LowLevelRetries: cfg.Backup.Retries, @@ -245,32 +246,32 @@ func New(ctx context.Context, cfg config.Config, version string) (*Application, metricsCollector.Start(ctx) router := aphttp.NewRouter(aphttp.RouterDependencies{ - Context: ctx, - Config: cfg, - Version: version, - Logger: appLogger, - AuthService: authService, - SystemService: systemService, - StorageTargetService: storageTargetService, - BackupTaskService: backupTaskService, - BackupExecutionService: backupExecutionService, - BackupRecordService: backupRecordService, - RestoreService: restoreService, - VerificationService: verificationService, - ReplicationService: replicationService, - TaskTemplateService: taskTemplateService, - TaskExportService: taskExportService, - SearchService: searchService, - EventBroadcaster: eventBroadcaster, - UserService: userService, - ApiKeyService: apiKeyService, - NotificationService: notificationService, - DashboardService: dashboardService, - SettingsService: settingsService, + Context: ctx, + Config: cfg, + Version: version, + Logger: appLogger, + AuthService: authService, + SystemService: systemService, + StorageTargetService: storageTargetService, + BackupTaskService: backupTaskService, + BackupExecutionService: backupExecutionService, + BackupRecordService: backupRecordService, + RestoreService: restoreService, + VerificationService: verificationService, + ReplicationService: replicationService, + TaskTemplateService: taskTemplateService, + TaskExportService: taskExportService, + SearchService: searchService, + EventBroadcaster: eventBroadcaster, + UserService: userService, + ApiKeyService: apiKeyService, + NotificationService: notificationService, + DashboardService: dashboardService, + SettingsService: settingsService, NodeService: nodeService, AgentService: agentService, DatabaseDiscoveryService: databaseDiscoveryService, - AuditService: auditService, + AuditService: auditService, JWTManager: jwtManager, UserRepository: userRepo, SystemConfigRepo: systemConfigRepo, diff --git a/server/internal/http/auth_handler.go b/server/internal/http/auth_handler.go index 6b25b73..3c49e86 100644 --- a/server/internal/http/auth_handler.go +++ b/server/internal/http/auth_handler.go @@ -1,6 +1,9 @@ package http import ( + "net" + "strings" + "backupx/server/internal/apperror" "backupx/server/internal/service" "backupx/server/pkg/response" @@ -86,6 +89,273 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) { response.Success(c, gin.H{"changed": true}) } +func (h *AuthHandler) PrepareTwoFactor(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.TwoFactorSetupInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err)) + return + } + payload, err := h.authService.PrepareTwoFactor(c.Request.Context(), subject, input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, payload) +} + +func (h *AuthHandler) EnableTwoFactor(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.EnableTwoFactorInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err)) + return + } + user, err := h.authService.EnableTwoFactor(c.Request.Context(), subject, input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, user) +} + +func (h *AuthHandler) DisableTwoFactor(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.DisableTwoFactorInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err)) + return + } + user, err := h.authService.DisableTwoFactor(c.Request.Context(), subject, input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, user) +} + +func (h *AuthHandler) RegenerateRecoveryCodes(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.RegenerateRecoveryCodesInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_2FA_INVALID", "参数不合法", err)) + return + } + payload, err := h.authService.RegenerateRecoveryCodes(c.Request.Context(), subject, input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, payload) +} + +func (h *AuthHandler) ConfigureOTP(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.OTPConfigInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_OTP_INVALID", "参数不合法", err)) + return + } + user, err := h.authService.ConfigureOutOfBandOTP(c.Request.Context(), subject, input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, user) +} + +func (h *AuthHandler) SendLoginOTP(c *gin.Context) { + var input service.LoginOTPInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_OTP_INVALID", "参数不合法", err)) + return + } + if err := h.authService.SendLoginOTP(c.Request.Context(), input, ClientKey(c)); err != nil { + response.Error(c, err) + return + } + response.Success(c, gin.H{"sent": true}) +} + +func (h *AuthHandler) BeginWebAuthnRegistration(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.WebAuthnRegistrationOptionsInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err)) + return + } + options, err := h.authService.BeginWebAuthnRegistration(c.Request.Context(), subject, input, webAuthnRequestContext(c)) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, options) +} + +func (h *AuthHandler) FinishWebAuthnRegistration(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.WebAuthnRegistrationFinishInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err)) + return + } + user, err := h.authService.FinishWebAuthnRegistration(c.Request.Context(), subject, input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, user) +} + +func (h *AuthHandler) BeginWebAuthnLogin(c *gin.Context) { + var input service.WebAuthnLoginOptionsInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err)) + return + } + options, err := h.authService.BeginWebAuthnLogin(c.Request.Context(), input, webAuthnRequestContext(c), ClientKey(c)) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, options) +} + +func (h *AuthHandler) ListWebAuthnCredentials(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + items, err := h.authService.ListWebAuthnCredentials(c.Request.Context(), subject) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, items) +} + +func (h *AuthHandler) DeleteWebAuthnCredential(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.WebAuthnCredentialDeleteInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_WEBAUTHN_INVALID", "参数不合法", err)) + return + } + user, err := h.authService.DeleteWebAuthnCredential(c.Request.Context(), subject, c.Param("id"), input) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, user) +} + +func (h *AuthHandler) ListTrustedDevices(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + items, err := h.authService.ListTrustedDevices(c.Request.Context(), subject) + if err != nil { + response.Error(c, err) + return + } + response.Success(c, items) +} + +func (h *AuthHandler) RevokeTrustedDevice(c *gin.Context) { + subjectValue, _ := c.Get(contextUserSubjectKey) + subject, err := service.SubjectFromContextValue(subjectValue) + if err != nil { + response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err)) + return + } + var input service.TrustedDeviceRevokeInput + if err := c.ShouldBindJSON(&input); err != nil { + response.Error(c, apperror.BadRequest("AUTH_TRUSTED_DEVICE_INVALID", "参数不合法", err)) + return + } + if err := h.authService.RevokeTrustedDevice(c.Request.Context(), subject, c.Param("id"), input); err != nil { + response.Error(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + func (h *AuthHandler) Logout(c *gin.Context) { response.Success(c, gin.H{"loggedOut": true}) } + +func webAuthnRequestContext(c *gin.Context) service.WebAuthnRequestContext { + host := firstForwardedValue(c.Request.Host) + if forwardedHost := firstForwardedValue(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" { + host = forwardedHost + } + rpID := host + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + rpID = parsedHost + } + scheme := "http" + if c.Request.TLS != nil { + scheme = "https" + } + if forwardedProto := firstForwardedValue(c.GetHeader("X-Forwarded-Proto")); forwardedProto != "" { + scheme = forwardedProto + } + origin := strings.TrimSpace(c.GetHeader("Origin")) + if origin == "" { + origin = scheme + "://" + host + } + return service.WebAuthnRequestContext{RPID: rpID, Origin: origin} +} + +func firstForwardedValue(value string) string { + parts := strings.Split(value, ",") + if len(parts) == 0 { + return "" + } + return strings.TrimSpace(parts[0]) +} diff --git a/server/internal/http/install_flow_test.go b/server/internal/http/install_flow_test.go index c14786a..ecf39c3 100644 --- a/server/internal/http/install_flow_test.go +++ b/server/internal/http/install_flow_test.go @@ -19,6 +19,7 @@ import ( "backupx/server/internal/repository" "backupx/server/internal/security" "backupx/server/internal/service" + "backupx/server/internal/storage/codec" ) // setupInstallFlowRouter 构造一个 Node + Agent + InstallToken 全量依赖的 router, @@ -40,6 +41,13 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) { if err != nil { t.Fatalf("db: %v", err) } + sqlDB, err := db.DB() + if err != nil { + t.Fatalf("sql db: %v", err) + } + t.Cleanup(func() { + _ = sqlDB.Close() + }) userRepo := repository.NewUserRepository(db) systemConfigRepo := repository.NewSystemConfigRepository(db) @@ -48,7 +56,7 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) { t.Fatalf("security: %v", err) } jwtMgr := security.NewJWTManager(resolved.JWTSecret, time.Hour) - authSvc := service.NewAuthService(userRepo, systemConfigRepo, jwtMgr, security.NewLoginRateLimiter(5, time.Minute)) + authSvc := service.NewAuthService(userRepo, systemConfigRepo, jwtMgr, security.NewLoginRateLimiter(5, time.Minute), codec.NewConfigCipher(resolved.EncryptionKey)) systemSvc := service.NewSystemService(cfg, "test", time.Now().UTC()) nodeRepo := repository.NewNodeRepository(db) diff --git a/server/internal/http/router.go b/server/internal/http/router.go index a911b03..f1ca3a2 100644 --- a/server/internal/http/router.go +++ b/server/internal/http/router.go @@ -94,9 +94,22 @@ func NewRouter(deps RouterDependencies) *gin.Engine { auth.GET("/setup/status", authHandler.SetupStatus) auth.POST("/setup", authHandler.Setup) auth.POST("/login", authHandler.Login) + auth.POST("/otp/send", authHandler.SendLoginOTP) + auth.POST("/webauthn/login/options", authHandler.BeginWebAuthnLogin) auth.POST("/logout", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.Logout) auth.GET("/profile", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.Profile) auth.PUT("/password", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ChangePassword) + auth.POST("/2fa/setup", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.PrepareTwoFactor) + auth.POST("/2fa/enable", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.EnableTwoFactor) + auth.POST("/2fa/recovery-codes", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.RegenerateRecoveryCodes) + auth.DELETE("/2fa", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.DisableTwoFactor) + auth.PUT("/otp/config", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ConfigureOTP) + auth.POST("/webauthn/register/options", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.BeginWebAuthnRegistration) + auth.POST("/webauthn/register/finish", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.FinishWebAuthnRegistration) + auth.GET("/webauthn/credentials", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ListWebAuthnCredentials) + auth.DELETE("/webauthn/credentials/:id", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.DeleteWebAuthnCredential) + auth.GET("/trusted-devices", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.ListTrustedDevices) + auth.DELETE("/trusted-devices/:id", AuthMiddleware(deps.JWTManager, apiKeyAuth), authHandler.RevokeTrustedDevice) } system := api.Group("/system") @@ -229,6 +242,7 @@ func NewRouter(deps RouterDependencies) *gin.Engine { users.GET("", userHandler.List) users.POST("", userHandler.Create) users.PUT("/:id", userHandler.Update) + users.POST("/:id/2fa/reset", userHandler.ResetTwoFactor) users.DELETE("/:id", userHandler.Delete) } @@ -279,10 +293,10 @@ func NewRouter(deps RouterDependencies) *gin.Engine { nodes.PUT("/:id", RequireRole("admin"), nodeHandler.Update) nodes.DELETE("/:id", RequireRole("admin"), nodeHandler.Delete) nodes.GET("/:id/fs/list", nodeHandler.ListDirectory) - nodes.POST("/batch", RequireRole("admin"), nodeHandler.BatchCreate) - nodes.POST("/:id/install-tokens", RequireRole("admin"), nodeHandler.CreateInstallToken) - nodes.POST("/:id/rotate-token", RequireRole("admin"), nodeHandler.RotateToken) - nodes.GET("/:id/install-script-preview", RequireRole("admin"), nodeHandler.PreviewScript) + nodes.POST("/batch", RequireRole("admin"), nodeHandler.BatchCreate) + nodes.POST("/:id/install-tokens", RequireRole("admin"), nodeHandler.CreateInstallToken) + nodes.POST("/:id/rotate-token", RequireRole("admin"), nodeHandler.RotateToken) + nodes.GET("/:id/install-script-preview", RequireRole("admin"), nodeHandler.PreviewScript) // Agent API(token 认证,无需 JWT) if deps.AgentService != nil { diff --git a/server/internal/http/router_test.go b/server/internal/http/router_test.go index 3520817..20aece6 100644 --- a/server/internal/http/router_test.go +++ b/server/internal/http/router_test.go @@ -16,15 +16,16 @@ import ( "backupx/server/internal/repository" "backupx/server/internal/security" "backupx/server/internal/service" + "backupx/server/internal/storage/codec" ) func TestSetupLoginAndProfileFlow(t *testing.T) { tempDir := t.TempDir() cfg := config.Config{ - Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test"}, + Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test"}, Database: config.DatabaseConfig{Path: filepath.Join(tempDir, "backupx.db")}, Security: config.SecurityConfig{JWTExpire: "24h"}, - Log: config.LogConfig{Level: "error"}, + Log: config.LogConfig{Level: "error"}, } log, err := logger.New(cfg.Log) @@ -35,6 +36,13 @@ func TestSetupLoginAndProfileFlow(t *testing.T) { if err != nil { t.Fatalf("database.Open error: %v", err) } + sqlDB, err := db.DB() + if err != nil { + t.Fatalf("db.DB error: %v", err) + } + t.Cleanup(func() { + _ = sqlDB.Close() + }) userRepo := repository.NewUserRepository(db) systemConfigRepo := repository.NewSystemConfigRepository(db) @@ -43,7 +51,7 @@ func TestSetupLoginAndProfileFlow(t *testing.T) { t.Fatalf("ResolveSecurity error: %v", err) } jwtManager := security.NewJWTManager(resolved.JWTSecret, time.Hour) - authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, security.NewLoginRateLimiter(5, time.Minute)) + authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, security.NewLoginRateLimiter(5, time.Minute), codec.NewConfigCipher(resolved.EncryptionKey)) systemService := service.NewSystemService(cfg, "test", time.Now().UTC()) router := NewRouter(RouterDependencies{ @@ -58,8 +66,8 @@ func TestSetupLoginAndProfileFlow(t *testing.T) { }) setupBody, _ := json.Marshal(map[string]string{ - "username": "admin", - "password": "password-123", + "username": "admin", + "password": "password-123", "displayName": "Admin", }) setupRequest := httptest.NewRequest(http.MethodPost, "/api/auth/setup", bytes.NewBuffer(setupBody)) diff --git a/server/internal/http/user_handler.go b/server/internal/http/user_handler.go index 0dfc441..6963120 100644 --- a/server/internal/http/user_handler.go +++ b/server/internal/http/user_handler.go @@ -78,3 +78,18 @@ func (h *UserHandler) Delete(c *gin.Context) { fmt.Sprintf("删除用户 (ID: %d)", id)) response.Success(c, gin.H{"deleted": true}) } + +func (h *UserHandler) ResetTwoFactor(c *gin.Context) { + id, ok := parseUintParam(c, "id") + if !ok { + return + } + item, err := h.service.ResetTwoFactor(c.Request.Context(), id) + if err != nil { + response.Error(c, err) + return + } + recordAudit(c, h.auditService, "user", "reset_two_factor", "user", fmt.Sprintf("%d", id), item.Username, + fmt.Sprintf("重置用户 %s 的 MFA", item.Username)) + response.Success(c, item) +} diff --git a/server/internal/model/user.go b/server/internal/model/user.go index 69785a3..0d75096 100644 --- a/server/internal/model/user.go +++ b/server/internal/model/user.go @@ -22,12 +22,25 @@ func IsValidRole(role string) bool { } type User struct { - ID uint `gorm:"primaryKey" json:"id"` - Username string `gorm:"size:64;uniqueIndex;not null" json:"username"` - PasswordHash string `gorm:"column:password_hash;not null" json:"-"` - DisplayName string `gorm:"size:128;not null" json:"displayName"` - Email string `gorm:"size:255" json:"email"` - Role string `gorm:"size:32;not null;default:admin" json:"role"` + ID uint `gorm:"primaryKey" json:"id"` + Username string `gorm:"size:64;uniqueIndex;not null" json:"username"` + PasswordHash string `gorm:"column:password_hash;not null" json:"-"` + DisplayName string `gorm:"size:128;not null" json:"displayName"` + Email string `gorm:"size:255" json:"email"` + Phone string `gorm:"size:64" json:"phone"` + Role string `gorm:"size:32;not null;default:admin" json:"role"` + // TwoFactorSecretCiphertext 保存 TOTP 密钥密文;未启用时可作为待确认密钥。 + TwoFactorEnabled bool `gorm:"column:two_factor_enabled;not null;default:false" json:"twoFactorEnabled"` + TwoFactorSecretCiphertext string `gorm:"column:two_factor_secret_ciphertext;type:text" json:"-"` + // TwoFactorRecoveryCodeHashes 保存一次性恢复码哈希的 JSON 数组。 + TwoFactorRecoveryCodeHashes string `gorm:"column:two_factor_recovery_code_hashes;type:text" json:"-"` + // WebAuthnCredentials 保存通行密钥公钥元数据 JSON,不包含私钥或明文密钥。 + WebAuthnCredentials string `gorm:"column:webauthn_credentials;type:text" json:"-"` + WebAuthnChallengeCiphertext string `gorm:"column:webauthn_challenge_ciphertext;type:text" json:"-"` + TrustedDevices string `gorm:"column:trusted_devices;type:text" json:"-"` + EmailOTPEnabled bool `gorm:"column:email_otp_enabled;not null;default:false" json:"emailOtpEnabled"` + SMSOTPEnabled bool `gorm:"column:sms_otp_enabled;not null;default:false" json:"smsOtpEnabled"` + OutOfBandOTPCiphertext string `gorm:"column:out_of_band_otp_ciphertext;type:text" json:"-"` // Disabled 禁用账号(不删除保留审计)。禁用后无法登录。 Disabled bool `gorm:"not null;default:false" json:"disabled"` CreatedAt time.Time `json:"createdAt"` diff --git a/server/internal/notify/sms.go b/server/internal/notify/sms.go new file mode 100644 index 0000000..34c3700 --- /dev/null +++ b/server/internal/notify/sms.go @@ -0,0 +1,64 @@ +package notify + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +type SMSWebhookNotifier struct { + client *http.Client +} + +func NewSMSWebhookNotifier() *SMSWebhookNotifier { + return &SMSWebhookNotifier{client: &http.Client{Timeout: 10 * time.Second}} +} + +func (n *SMSWebhookNotifier) Type() string { return "sms" } +func (n *SMSWebhookNotifier) SensitiveFields() []string { return []string{"secret"} } + +func (n *SMSWebhookNotifier) Validate(config map[string]any) error { + if strings.TrimSpace(asString(config["url"])) == "" { + return fmt.Errorf("sms webhook url is required") + } + return nil +} + +func (n *SMSWebhookNotifier) Send(ctx context.Context, config map[string]any, message Message) error { + if err := n.Validate(config); err != nil { + return err + } + payload := map[string]any{ + "title": message.Title, + "body": message.Body, + "fields": message.Fields, + "phone": message.Fields["phone"], + "code": message.Fields["code"], + "purpose": message.Fields["purpose"], + } + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal sms webhook payload: %w", err) + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(asString(config["url"])), bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("create sms webhook request: %w", err) + } + request.Header.Set("Content-Type", "application/json") + if secret := strings.TrimSpace(asString(config["secret"])); secret != "" { + request.Header.Set("X-BackupX-Secret", secret) + } + response, err := n.client.Do(request) + if err != nil { + return fmt.Errorf("send sms webhook request: %w", err) + } + defer response.Body.Close() + if response.StatusCode >= http.StatusBadRequest { + return fmt.Errorf("sms webhook response status: %s", response.Status) + } + return nil +} diff --git a/server/internal/security/otp_code.go b/server/internal/security/otp_code.go new file mode 100644 index 0000000..d26669c --- /dev/null +++ b/server/internal/security/otp_code.go @@ -0,0 +1,23 @@ +package security + +import ( + "crypto/rand" + "fmt" + "math/big" + "strings" +) + +const LoginOTPDigits = 6 + +func GenerateNumericOTP() (string, error) { + limit := big.NewInt(1_000_000) + value, err := rand.Int(rand.Reader, limit) + if err != nil { + return "", err + } + return fmt.Sprintf("%0*d", LoginOTPDigits, value.Int64()), nil +} + +func NormalizeNumericOTP(code string) string { + return strings.TrimSpace(code) +} diff --git a/server/internal/security/recovery_code.go b/server/internal/security/recovery_code.go new file mode 100644 index 0000000..895e4af --- /dev/null +++ b/server/internal/security/recovery_code.go @@ -0,0 +1,49 @@ +package security + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + "unicode" +) + +const RecoveryCodeCount = 10 + +func GenerateRecoveryCodes(count int) ([]string, error) { + if count <= 0 { + count = RecoveryCodeCount + } + codes := make([]string, 0, count) + for i := 0; i < count; i++ { + raw := make([]byte, 8) + if _, err := rand.Read(raw); err != nil { + return nil, fmt.Errorf("generate recovery code: %w", err) + } + encoded := strings.ToUpper(hex.EncodeToString(raw)) + codes = append(codes, encoded[0:4]+"-"+encoded[4:8]+"-"+encoded[8:12]+"-"+encoded[12:16]) + } + return codes, nil +} + +func NormalizeRecoveryCode(code string) string { + return strings.Map(func(r rune) rune { + if unicode.IsSpace(r) || r == '-' { + return -1 + } + return unicode.ToUpper(r) + }, strings.TrimSpace(code)) +} + +func IsRecoveryCodeCandidate(code string) bool { + normalized := NormalizeRecoveryCode(code) + if len(normalized) != 16 { + return false + } + for _, r := range normalized { + if !('0' <= r && r <= '9') && !('A' <= r && r <= 'F') { + return false + } + } + return true +} diff --git a/server/internal/security/totp.go b/server/internal/security/totp.go new file mode 100644 index 0000000..c21f53e --- /dev/null +++ b/server/internal/security/totp.go @@ -0,0 +1,68 @@ +package security + +import ( + "bytes" + "encoding/base64" + "image/png" + "strings" + "time" + "unicode" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" +) + +const TOTPIssuer = "BackupX" + +type TOTPEnrollment struct { + Secret string + OTPAuthURL string + QRCodeDataURL string +} + +func GenerateTOTPEnrollment(accountName string) (*TOTPEnrollment, error) { + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: TOTPIssuer, + AccountName: accountName, + Period: 30, + SecretSize: 20, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + if err != nil { + return nil, err + } + + image, err := key.Image(220, 220) + if err != nil { + return nil, err + } + var buf bytes.Buffer + if err := png.Encode(&buf, image); err != nil { + return nil, err + } + + return &TOTPEnrollment{ + Secret: key.Secret(), + OTPAuthURL: key.URL(), + QRCodeDataURL: "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()), + }, nil +} + +func ValidateTOTPCode(secret string, code string) (bool, error) { + return totp.ValidateCustom(NormalizeTOTPCode(code), secret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) +} + +func NormalizeTOTPCode(code string) string { + return strings.Map(func(r rune) rune { + if unicode.IsSpace(r) { + return -1 + } + return r + }, strings.TrimSpace(code)) +} diff --git a/server/internal/security/webauthn.go b/server/internal/security/webauthn.go new file mode 100644 index 0000000..d52eb46 --- /dev/null +++ b/server/internal/security/webauthn.go @@ -0,0 +1,447 @@ +package security + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "math/big" + "strings" +) + +const ( + WebAuthnChallengeBytes = 32 +) + +type WebAuthnCredentialMaterial struct { + CredentialID string + PublicKeyX string + PublicKeyY string + SignCount uint32 +} + +type WebAuthnParsedCredential struct { + CredentialID string + PublicKeyX string + PublicKeyY string + SignCount uint32 +} + +type WebAuthnClientData struct { + Type string `json:"type"` + Challenge string `json:"challenge"` + Origin string `json:"origin"` +} + +type WebAuthnAttestationResponse struct { + ClientDataJSON string `json:"clientDataJSON"` + AttestationObject string `json:"attestationObject"` +} + +type WebAuthnRegistrationResponse struct { + ID string `json:"id"` + RawID string `json:"rawId"` + Type string `json:"type"` + Response WebAuthnAttestationResponse `json:"response"` +} + +type WebAuthnAssertionResponse struct { + ClientDataJSON string `json:"clientDataJSON"` + AuthenticatorData string `json:"authenticatorData"` + Signature string `json:"signature"` + UserHandle string `json:"userHandle,omitempty"` +} + +type WebAuthnLoginAssertion struct { + ID string `json:"id"` + RawID string `json:"rawId"` + Type string `json:"type"` + Response WebAuthnAssertionResponse `json:"response"` +} + +func GenerateWebAuthnChallenge() (string, error) { + buf := make([]byte, WebAuthnChallengeBytes) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return EncodeBase64URL(buf), nil +} + +func EncodeBase64URL(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} + +func DecodeBase64URL(value string) ([]byte, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil, errors.New("empty base64url value") + } + if decoded, err := base64.RawURLEncoding.DecodeString(trimmed); err == nil { + return decoded, nil + } + return base64.URLEncoding.DecodeString(trimmed) +} + +func VerifyWebAuthnRegistration(input WebAuthnRegistrationResponse, challenge string, rpID string, expectedOrigin string) (*WebAuthnParsedCredential, error) { + if input.Type != "public-key" { + return nil, fmt.Errorf("unexpected credential type: %s", input.Type) + } + clientDataRaw, err := DecodeBase64URL(input.Response.ClientDataJSON) + if err != nil { + return nil, fmt.Errorf("decode client data: %w", err) + } + if err := validateWebAuthnClientData(clientDataRaw, "webauthn.create", challenge, expectedOrigin); err != nil { + return nil, err + } + attestationObject, err := DecodeBase64URL(input.Response.AttestationObject) + if err != nil { + return nil, fmt.Errorf("decode attestation object: %w", err) + } + parsed, err := parseCBORExact(attestationObject) + if err != nil { + return nil, fmt.Errorf("parse attestation object: %w", err) + } + attestationMap, ok := parsed.(map[any]any) + if !ok { + return nil, errors.New("attestation object is not a map") + } + authData, ok := attestationMap["authData"].([]byte) + if !ok { + return nil, errors.New("attestation authData is missing") + } + credential, err := parseAttestedCredentialData(authData, rpID) + if err != nil { + return nil, err + } + rawID := strings.TrimSpace(input.RawID) + if rawID == "" { + rawID = strings.TrimSpace(input.ID) + } + if rawID != "" && rawID != credential.CredentialID { + return nil, errors.New("credential raw id does not match attested credential id") + } + return credential, nil +} + +func VerifyWebAuthnAssertion(input WebAuthnLoginAssertion, challenge string, rpID string, expectedOrigin string, credential WebAuthnCredentialMaterial) (uint32, error) { + if input.Type != "public-key" { + return 0, fmt.Errorf("unexpected credential type: %s", input.Type) + } + rawID := strings.TrimSpace(input.RawID) + if rawID == "" { + rawID = strings.TrimSpace(input.ID) + } + if rawID != credential.CredentialID { + return 0, errors.New("credential id does not match") + } + clientDataRaw, err := DecodeBase64URL(input.Response.ClientDataJSON) + if err != nil { + return 0, fmt.Errorf("decode client data: %w", err) + } + if err := validateWebAuthnClientData(clientDataRaw, "webauthn.get", challenge, expectedOrigin); err != nil { + return 0, err + } + authData, err := DecodeBase64URL(input.Response.AuthenticatorData) + if err != nil { + return 0, fmt.Errorf("decode authenticator data: %w", err) + } + signature, err := DecodeBase64URL(input.Response.Signature) + if err != nil { + return 0, fmt.Errorf("decode signature: %w", err) + } + signCount, err := parseAssertionAuthenticatorData(authData, rpID, credential.SignCount) + if err != nil { + return 0, err + } + xBytes, err := DecodeBase64URL(credential.PublicKeyX) + if err != nil { + return 0, fmt.Errorf("decode public key x: %w", err) + } + yBytes, err := DecodeBase64URL(credential.PublicKeyY) + if err != nil { + return 0, fmt.Errorf("decode public key y: %w", err) + } + publicKey := ecdsa.PublicKey{Curve: elliptic.P256(), X: new(big.Int).SetBytes(xBytes), Y: new(big.Int).SetBytes(yBytes)} + if !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) { + return 0, errors.New("webauthn public key is not on P-256 curve") + } + clientDataHash := sha256.Sum256(clientDataRaw) + verifyData := append(append([]byte{}, authData...), clientDataHash[:]...) + digest := sha256.Sum256(verifyData) + if !ecdsa.VerifyASN1(&publicKey, digest[:], signature) { + return 0, errors.New("invalid webauthn signature") + } + return signCount, nil +} + +func validateWebAuthnClientData(raw []byte, expectedType string, challenge string, expectedOrigin string) error { + var clientData WebAuthnClientData + if err := json.Unmarshal(raw, &clientData); err != nil { + return fmt.Errorf("parse client data: %w", err) + } + if clientData.Type != expectedType { + return fmt.Errorf("unexpected webauthn client data type: %s", clientData.Type) + } + if clientData.Challenge != challenge { + return errors.New("webauthn challenge mismatch") + } + if expectedOrigin != "" && clientData.Origin != expectedOrigin { + return fmt.Errorf("webauthn origin mismatch: %s", clientData.Origin) + } + return nil +} + +func parseAttestedCredentialData(authData []byte, rpID string) (*WebAuthnParsedCredential, error) { + signCount, credentialData, err := parseAuthenticatorDataHeader(authData, rpID, true, 0) + if err != nil { + return nil, err + } + if len(credentialData) < 18 { + return nil, errors.New("attested credential data is too short") + } + offset := 16 + credentialIDLength := int(binary.BigEndian.Uint16(credentialData[offset : offset+2])) + offset += 2 + if credentialIDLength <= 0 || len(credentialData) < offset+credentialIDLength { + return nil, errors.New("invalid credential id length") + } + credentialID := credentialData[offset : offset+credentialIDLength] + offset += credentialIDLength + publicKeyRaw := credentialData[offset:] + publicKey, err := parseCBOR(publicKeyRaw) + if err != nil { + return nil, fmt.Errorf("parse credential public key: %w", err) + } + publicKeyMap, ok := publicKey.(map[any]any) + if !ok { + return nil, errors.New("credential public key is not a map") + } + kty, err := coseInt(publicKeyMap, 1) + if err != nil { + return nil, err + } + alg, err := coseInt(publicKeyMap, 3) + if err != nil { + return nil, err + } + crv, err := coseInt(publicKeyMap, -1) + if err != nil { + return nil, err + } + if kty != 2 || alg != -7 || crv != 1 { + return nil, fmt.Errorf("unsupported COSE key: kty=%d alg=%d crv=%d", kty, alg, crv) + } + x, err := coseBytes(publicKeyMap, -2) + if err != nil { + return nil, err + } + y, err := coseBytes(publicKeyMap, -3) + if err != nil { + return nil, err + } + if !elliptic.P256().IsOnCurve(new(big.Int).SetBytes(x), new(big.Int).SetBytes(y)) { + return nil, errors.New("credential public key is not on P-256 curve") + } + return &WebAuthnParsedCredential{ + CredentialID: EncodeBase64URL(credentialID), + PublicKeyX: EncodeBase64URL(x), + PublicKeyY: EncodeBase64URL(y), + SignCount: signCount, + }, nil +} + +func parseAssertionAuthenticatorData(authData []byte, rpID string, previousSignCount uint32) (uint32, error) { + signCount, _, err := parseAuthenticatorDataHeader(authData, rpID, false, previousSignCount) + if err != nil { + return 0, err + } + return signCount, nil +} + +func parseAuthenticatorDataHeader(authData []byte, rpID string, requireAttestedData bool, previousSignCount uint32) (uint32, []byte, error) { + if len(authData) < 37 { + return 0, nil, errors.New("authenticator data is too short") + } + expectedRPIDHash := sha256.Sum256([]byte(rpID)) + if string(authData[:32]) != string(expectedRPIDHash[:]) { + return 0, nil, errors.New("rp id hash mismatch") + } + flags := authData[32] + if flags&0x01 == 0 { + return 0, nil, errors.New("user presence flag is missing") + } + signCount := binary.BigEndian.Uint32(authData[33:37]) + if previousSignCount > 0 && signCount > 0 && signCount <= previousSignCount { + return 0, nil, errors.New("authenticator sign count did not increase") + } + if requireAttestedData && flags&0x40 == 0 { + return 0, nil, errors.New("attested credential data flag is missing") + } + return signCount, authData[37:], nil +} + +func coseInt(m map[any]any, key int64) (int64, error) { + value, ok := m[key] + if !ok { + return 0, fmt.Errorf("missing COSE key %d", key) + } + intValue, ok := value.(int64) + if !ok { + return 0, fmt.Errorf("invalid COSE key %d", key) + } + return intValue, nil +} + +func coseBytes(m map[any]any, key int64) ([]byte, error) { + value, ok := m[key] + if !ok { + return nil, fmt.Errorf("missing COSE key %d", key) + } + bytesValue, ok := value.([]byte) + if !ok || len(bytesValue) == 0 { + return nil, fmt.Errorf("invalid COSE key %d", key) + } + return bytesValue, nil +} + +func parseCBOR(data []byte) (any, error) { + reader := cborReader{data: data} + value, err := reader.read() + if err != nil { + return nil, err + } + return value, nil +} + +func parseCBORExact(data []byte) (any, error) { + reader := cborReader{data: data} + value, err := reader.read() + if err != nil { + return nil, err + } + if reader.pos != len(data) { + return nil, errors.New("trailing cbor data") + } + return value, nil +} + +type cborReader struct { + data []byte + pos int +} + +func (r *cborReader) read() (any, error) { + if r.pos >= len(r.data) { + return nil, errors.New("unexpected cbor eof") + } + initial := r.data[r.pos] + r.pos++ + major := initial >> 5 + additional := initial & 0x1f + length, err := r.readLength(additional) + if err != nil { + return nil, err + } + switch major { + case 0: + return int64(length), nil + case 1: + return -1 - int64(length), nil + case 2: + return r.readBytes(length) + case 3: + raw, err := r.readBytes(length) + if err != nil { + return nil, err + } + return string(raw), nil + case 4: + out := make([]any, 0, length) + for i := uint64(0); i < length; i++ { + item, err := r.read() + if err != nil { + return nil, err + } + out = append(out, item) + } + return out, nil + case 5: + out := make(map[any]any, length) + for i := uint64(0); i < length; i++ { + key, err := r.read() + if err != nil { + return nil, err + } + value, err := r.read() + if err != nil { + return nil, err + } + out[key] = value + } + return out, nil + case 7: + switch additional { + case 20: + return false, nil + case 21: + return true, nil + case 22, 23: + return nil, nil + default: + return nil, fmt.Errorf("unsupported cbor simple value: %d", additional) + } + default: + return nil, fmt.Errorf("unsupported cbor major type: %d", major) + } +} + +func (r *cborReader) readLength(additional byte) (uint64, error) { + switch { + case additional < 24: + return uint64(additional), nil + case additional == 24: + if r.pos+1 > len(r.data) { + return 0, errors.New("unexpected cbor eof") + } + value := r.data[r.pos] + r.pos++ + return uint64(value), nil + case additional == 25: + if r.pos+2 > len(r.data) { + return 0, errors.New("unexpected cbor eof") + } + value := binary.BigEndian.Uint16(r.data[r.pos : r.pos+2]) + r.pos += 2 + return uint64(value), nil + case additional == 26: + if r.pos+4 > len(r.data) { + return 0, errors.New("unexpected cbor eof") + } + value := binary.BigEndian.Uint32(r.data[r.pos : r.pos+4]) + r.pos += 4 + return uint64(value), nil + case additional == 27: + if r.pos+8 > len(r.data) { + return 0, errors.New("unexpected cbor eof") + } + value := binary.BigEndian.Uint64(r.data[r.pos : r.pos+8]) + r.pos += 8 + return value, nil + default: + return 0, fmt.Errorf("unsupported cbor additional info: %d", additional) + } +} + +func (r *cborReader) readBytes(length uint64) ([]byte, error) { + if length > uint64(len(r.data)-r.pos) { + return nil, errors.New("unexpected cbor eof") + } + out := r.data[r.pos : r.pos+int(length)] + r.pos += int(length) + return out, nil +} diff --git a/server/internal/service/auth_methods.go b/server/internal/service/auth_methods.go new file mode 100644 index 0000000..ed2c7fe --- /dev/null +++ b/server/internal/service/auth_methods.go @@ -0,0 +1,179 @@ +package service + +import ( + "encoding/json" + "strings" + "time" + + "backupx/server/internal/model" +) + +const ( + mfaChallengeTTL = 5 * time.Minute + trustedDeviceTTL = 30 * 24 * time.Hour + maxTrustedDeviceName = 128 + maxTrustedDevices = 10 +) + +type WebAuthnCredentialRecord struct { + ID string `json:"id"` + Name string `json:"name"` + CredentialID string `json:"credentialId"` + PublicKeyX string `json:"publicKeyX"` + PublicKeyY string `json:"publicKeyY"` + SignCount uint32 `json:"signCount"` + CreatedAt string `json:"createdAt"` + LastUsedAt string `json:"lastUsedAt,omitempty"` +} + +type WebAuthnCredentialOutput struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt string `json:"createdAt"` + LastUsedAt string `json:"lastUsedAt,omitempty"` +} + +type webAuthnChallengeState struct { + Type string `json:"type"` + Challenge string `json:"challenge"` + RPID string `json:"rpId"` + Origin string `json:"origin"` + ExpiresAt time.Time `json:"expiresAt"` +} + +type TrustedDeviceRecord struct { + ID string `json:"id"` + Name string `json:"name"` + TokenHash string `json:"tokenHash"` + CreatedAt time.Time `json:"createdAt"` + LastUsedAt time.Time `json:"lastUsedAt"` + ExpiresAt time.Time `json:"expiresAt"` + LastIP string `json:"lastIp"` +} + +type TrustedDeviceOutput struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt string `json:"createdAt"` + LastUsedAt string `json:"lastUsedAt"` + ExpiresAt string `json:"expiresAt"` + LastIP string `json:"lastIp"` +} + +type pendingOutOfBandOTP struct { + Channel string `json:"channel"` + CodeHash string `json:"codeHash"` + ExpiresAt time.Time `json:"expiresAt"` +} + +func userMFAEnabled(user *model.User) bool { + if user == nil { + return false + } + return user.TwoFactorEnabled || + strings.TrimSpace(user.WebAuthnCredentials) != "" || + user.EmailOTPEnabled || + user.SMSOTPEnabled +} + +func clearTrustedDevicesIfMFAOff(user *model.User) { + if user == nil || userMFAEnabled(user) { + return + } + user.TrustedDevices = "" + user.OutOfBandOTPCiphertext = "" + user.WebAuthnChallengeCiphertext = "" +} + +func parseWebAuthnCredentials(value string) ([]WebAuthnCredentialRecord, error) { + if strings.TrimSpace(value) == "" { + return nil, nil + } + var credentials []WebAuthnCredentialRecord + if err := json.Unmarshal([]byte(value), &credentials); err != nil { + return nil, err + } + return credentials, nil +} + +func encodeWebAuthnCredentials(credentials []WebAuthnCredentialRecord) (string, error) { + if len(credentials) == 0 { + return "", nil + } + encoded, err := json.Marshal(credentials) + if err != nil { + return "", err + } + return string(encoded), nil +} + +func webAuthnCredentialCount(user *model.User) int { + if user == nil { + return 0 + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return 0 + } + return len(credentials) +} + +func parseTrustedDevices(value string) ([]TrustedDeviceRecord, error) { + if strings.TrimSpace(value) == "" { + return nil, nil + } + var devices []TrustedDeviceRecord + if err := json.Unmarshal([]byte(value), &devices); err != nil { + return nil, err + } + return devices, nil +} + +func encodeTrustedDevices(devices []TrustedDeviceRecord) (string, error) { + if len(devices) == 0 { + return "", nil + } + encoded, err := json.Marshal(devices) + if err != nil { + return "", err + } + return string(encoded), nil +} + +func trustedDeviceCount(user *model.User) int { + if user == nil { + return 0 + } + devices, err := parseTrustedDevices(user.TrustedDevices) + if err != nil { + return 0 + } + now := time.Now().UTC() + count := 0 + for _, device := range devices { + if device.ExpiresAt.After(now) { + count++ + } + } + return count +} + +func toWebAuthnCredentialOutput(record WebAuthnCredentialRecord) WebAuthnCredentialOutput { + return WebAuthnCredentialOutput{ + ID: record.ID, + Name: record.Name, + CreatedAt: record.CreatedAt, + LastUsedAt: record.LastUsedAt, + } +} + +func toTrustedDeviceOutput(record TrustedDeviceRecord) TrustedDeviceOutput { + return TrustedDeviceOutput{ + ID: record.ID, + Name: record.Name, + CreatedAt: record.CreatedAt.Format(time.RFC3339), + LastUsedAt: record.LastUsedAt.Format(time.RFC3339), + ExpiresAt: record.ExpiresAt.Format(time.RFC3339), + LastIP: record.LastIP, + } +} diff --git a/server/internal/service/auth_otp.go b/server/internal/service/auth_otp.go new file mode 100644 index 0000000..2c4e6cc --- /dev/null +++ b/server/internal/service/auth_otp.go @@ -0,0 +1,252 @@ +package service + +import ( + "context" + "fmt" + "strings" + "time" + + "backupx/server/internal/apperror" + "backupx/server/internal/model" + "backupx/server/internal/security" +) + +type OTPConfigInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` + Channel string `json:"channel" binding:"required,oneof=email sms"` + Enabled bool `json:"enabled"` + Email string `json:"email" binding:"omitempty,max=255"` + Phone string `json:"phone" binding:"omitempty,max=64"` +} + +type LoginOTPInput struct { + Username string `json:"username" binding:"required,min=3,max=64"` + Password string `json:"password" binding:"required,min=8,max=128"` + Channel string `json:"channel" binding:"required,oneof=email sms"` +} + +func (s *AuthService) ConfigureOutOfBandOTP(ctx context.Context, subject string, input OTPConfigInput) (*UserOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + channel := strings.TrimSpace(input.Channel) + previousEmail := strings.TrimSpace(user.Email) + previousPhone := strings.TrimSpace(user.Phone) + contactChanged := false + switch channel { + case "email": + email := strings.TrimSpace(input.Email) + if email != "" { + user.Email = email + } + contactChanged = previousEmail != strings.TrimSpace(user.Email) + if input.Enabled && strings.TrimSpace(user.Email) == "" { + return nil, apperror.BadRequest("AUTH_EMAIL_REQUIRED", "请先在用户资料中设置邮箱", nil) + } + user.EmailOTPEnabled = input.Enabled + case "sms": + phone := strings.TrimSpace(input.Phone) + if phone != "" { + user.Phone = phone + } + contactChanged = previousPhone != strings.TrimSpace(user.Phone) + if input.Enabled && strings.TrimSpace(user.Phone) == "" { + return nil, apperror.BadRequest("AUTH_PHONE_REQUIRED", "请先设置手机号", nil) + } + user.SMSOTPEnabled = input.Enabled + default: + return nil, apperror.BadRequest("AUTH_OTP_CHANNEL_INVALID", "验证码渠道不支持", nil) + } + if s.shouldClearPendingOTP(user, channel, contactChanged) { + user.OutOfBandOTPCiphertext = "" + } + clearTrustedDevicesIfMFAOff(user) + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_OTP_CONFIG_FAILED", "无法更新 OTP 配置", err) + } + if s.auditService != nil { + action := "otp_disable" + if input.Enabled { + action = "otp_enable" + } + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: action, + TargetType: "otp", TargetID: channel, + Detail: fmt.Sprintf("%s %s OTP", map[bool]string{true: "启用", false: "关闭"}[input.Enabled], channel), + }) + } + return ToUserOutput(user), nil +} + +func (s *AuthService) SendLoginOTP(ctx context.Context, input LoginOTPInput, clientKey string) error { + user, err := s.verifyPasswordForMFAStart(ctx, input.Username, input.Password, clientKey) + if err != nil { + return err + } + channel := strings.TrimSpace(input.Channel) + if channel == "email" && !user.EmailOTPEnabled { + return apperror.BadRequest("AUTH_EMAIL_OTP_DISABLED", "当前账号未启用邮件验证码", nil) + } + if channel == "sms" && !user.SMSOTPEnabled { + return apperror.BadRequest("AUTH_SMS_OTP_DISABLED", "当前账号未启用短信验证码", nil) + } + code, err := security.GenerateNumericOTP() + if err != nil { + return apperror.Internal("AUTH_OTP_GENERATE_FAILED", "无法生成登录验证码", err) + } + hash, err := security.HashPassword(code) + if err != nil { + return apperror.Internal("AUTH_OTP_GENERATE_FAILED", "无法处理登录验证码", err) + } + pending := pendingOutOfBandOTP{ + Channel: channel, + CodeHash: hash, + ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL), + } + ciphertext, err := s.twoFactorCipher.EncryptJSON(pending) + if err != nil { + return apperror.Internal("AUTH_OTP_SAVE_FAILED", "无法保存登录验证码状态", err) + } + user.OutOfBandOTPCiphertext = ciphertext + if err := s.users.Update(ctx, user); err != nil { + return apperror.Internal("AUTH_OTP_SAVE_FAILED", "无法保存登录验证码状态", err) + } + if err := s.deliverLoginOTP(ctx, user, channel, code); err != nil { + user.OutOfBandOTPCiphertext = "" + if updateErr := s.users.Update(ctx, user); updateErr != nil { + return apperror.Internal("AUTH_OTP_SAVE_FAILED", "登录验证码发送失败,且无法回滚验证码状态", updateErr) + } + return apperror.BadRequest("AUTH_OTP_DELIVERY_FAILED", "登录验证码发送失败", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "otp_send", + TargetType: "otp", TargetID: channel, + Detail: "发送登录 OTP", ClientIP: clientKey, + }) + } + return nil +} + +func (s *AuthService) consumeOutOfBandOTP(ctx context.Context, user *model.User, code string, clientKey string) (bool, error) { + if strings.TrimSpace(user.OutOfBandOTPCiphertext) == "" { + return false, nil + } + var pending pendingOutOfBandOTP + if err := s.twoFactorCipher.DecryptJSON(user.OutOfBandOTPCiphertext, &pending); err != nil { + return false, apperror.Internal("AUTH_OTP_INVALID", "登录验证码状态异常", err) + } + if pending.ExpiresAt.Before(time.Now().UTC()) { + user.OutOfBandOTPCiphertext = "" + if err := s.users.Update(ctx, user); err != nil { + return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法更新登录验证码状态", err) + } + return false, nil + } + if !outOfBandOTPChannelEnabled(user, pending.Channel) { + user.OutOfBandOTPCiphertext = "" + if err := s.users.Update(ctx, user); err != nil { + return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法更新登录验证码状态", err) + } + return false, nil + } + if security.ComparePassword(pending.CodeHash, security.NormalizeNumericOTP(code)) != nil { + return false, nil + } + user.OutOfBandOTPCiphertext = "" + if err := s.users.Update(ctx, user); err != nil { + return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法使用登录验证码", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "otp_used", + TargetType: "otp", TargetID: pending.Channel, + Detail: "使用登录 OTP 完成登录", ClientIP: clientKey, + }) + } + return true, nil +} + +func (s *AuthService) deliverLoginOTP(ctx context.Context, user *model.User, channel string, code string) error { + if s.notificationService == nil { + return fmt.Errorf("notification service is not configured") + } + switch channel { + case "email": + email := strings.TrimSpace(user.Email) + if email == "" { + return fmt.Errorf("user email is empty") + } + return s.notificationService.SendAuthEmailOTP(ctx, email, code) + case "sms": + phone := strings.TrimSpace(user.Phone) + if phone == "" { + return fmt.Errorf("user phone is empty") + } + return s.notificationService.SendAuthSMSOTP(ctx, phone, code) + default: + return fmt.Errorf("unsupported otp channel: %s", channel) + } +} + +func (s *AuthService) verifyPasswordForMFAStart(ctx context.Context, username string, password string, clientKey string) (*model.User, error) { + if clientKey == "" { + clientKey = "unknown" + } + if !s.rateLimiter.Allow(clientKey) { + return nil, apperror.TooManyRequests("AUTH_RATE_LIMITED", "登录尝试过于频繁,请稍后再试", nil) + } + user, err := s.users.FindByUsername(ctx, strings.TrimSpace(username)) + if err != nil { + return nil, apperror.Internal("AUTH_LOOKUP_FAILED", "无法执行登录校验", err) + } + if user == nil || user.Disabled { + return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", nil) + } + if err := security.ComparePassword(user.PasswordHash, password); err != nil { + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "login_failed", + Detail: "密码错误", ClientIP: clientKey, + }) + } + return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err) + } + if !userMFAEnabled(user) { + return nil, apperror.BadRequest("AUTH_MFA_NOT_ENABLED", "当前账号未启用多因素验证", nil) + } + return user, nil +} + +func outOfBandOTPChannelEnabled(user *model.User, channel string) bool { + switch channel { + case "email": + return user.EmailOTPEnabled + case "sms": + return user.SMSOTPEnabled + default: + return false + } +} + +func (s *AuthService) shouldClearPendingOTP(user *model.User, changedChannel string, contactChanged bool) bool { + if !user.EmailOTPEnabled && !user.SMSOTPEnabled { + return true + } + if strings.TrimSpace(user.OutOfBandOTPCiphertext) == "" { + return false + } + var pending pendingOutOfBandOTP + if err := s.twoFactorCipher.DecryptJSON(user.OutOfBandOTPCiphertext, &pending); err != nil { + return true + } + return pending.Channel == changedChannel && (contactChanged || !outOfBandOTPChannelEnabled(user, changedChannel)) +} diff --git a/server/internal/service/auth_service.go b/server/internal/service/auth_service.go index fd2633a..0b3ce63 100644 --- a/server/internal/service/auth_service.go +++ b/server/internal/service/auth_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "strconv" @@ -11,6 +12,7 @@ import ( "backupx/server/internal/model" "backupx/server/internal/repository" "backupx/server/internal/security" + "backupx/server/internal/storage/codec" ) type SetupInput struct { @@ -20,28 +22,47 @@ type SetupInput struct { } type LoginInput struct { - Username string `json:"username" binding:"required,min=3,max=64"` - Password string `json:"password" binding:"required,min=8,max=128"` + Username string `json:"username" binding:"required,min=3,max=64"` + Password string `json:"password" binding:"required,min=8,max=128"` + TwoFactorCode string `json:"twoFactorCode" binding:"omitempty,min=6,max=32"` + WebAuthnAssertion *security.WebAuthnLoginAssertion `json:"webAuthnAssertion"` + TrustedDeviceToken string `json:"trustedDeviceToken"` + RememberDevice bool `json:"rememberDevice"` + TrustedDeviceName string `json:"trustedDeviceName" binding:"omitempty,max=128"` } type AuthPayload struct { - Token string `json:"token"` - User *UserOutput `json:"user"` + Token string `json:"token"` + User *UserOutput `json:"user"` + TrustedDeviceToken string `json:"trustedDeviceToken,omitempty"` + TrustedDevice *TrustedDeviceOutput `json:"trustedDevice,omitempty"` } type UserOutput struct { - ID uint `json:"id"` - Username string `json:"username"` - DisplayName string `json:"displayName"` - Role string `json:"role"` + ID uint `json:"id"` + Username string `json:"username"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + Phone string `json:"phone"` + Role string `json:"role"` + MFAEnabled bool `json:"mfaEnabled"` + TwoFactorEnabled bool `json:"twoFactorEnabled"` + TwoFactorRecoveryCodesRemaining int `json:"twoFactorRecoveryCodesRemaining"` + WebAuthnEnabled bool `json:"webAuthnEnabled"` + WebAuthnCredentialCount int `json:"webAuthnCredentialCount"` + TrustedDeviceCount int `json:"trustedDeviceCount"` + EmailOTPEnabled bool `json:"emailOtpEnabled"` + SMSOTPEnabled bool `json:"smsOtpEnabled"` } type AuthService struct { - users repository.UserRepository - configs repository.SystemConfigRepository - jwtManager *security.JWTManager - rateLimiter *security.LoginRateLimiter - auditService *AuditService + users repository.UserRepository + configs repository.SystemConfigRepository + jwtManager *security.JWTManager + rateLimiter *security.LoginRateLimiter + twoFactorCipher *codec.ConfigCipher + auditService *AuditService + notificationService *NotificationService } func NewAuthService( @@ -49,14 +70,25 @@ func NewAuthService( configs repository.SystemConfigRepository, jwtManager *security.JWTManager, rateLimiter *security.LoginRateLimiter, + twoFactorCipher *codec.ConfigCipher, ) *AuthService { - return &AuthService{users: users, configs: configs, jwtManager: jwtManager, rateLimiter: rateLimiter} + return &AuthService{ + users: users, + configs: configs, + jwtManager: jwtManager, + rateLimiter: rateLimiter, + twoFactorCipher: twoFactorCipher, + } } func (s *AuthService) SetAuditService(auditService *AuditService) { s.auditService = auditService } +func (s *AuthService) SetNotificationService(notificationService *NotificationService) { + s.notificationService = notificationService +} + func (s *AuthService) SetupStatus(ctx context.Context) (bool, error) { count, err := s.users.Count(ctx) if err != nil { @@ -130,7 +162,7 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str if s.auditService != nil { s.auditService.Record(AuditEntry{ Category: "auth", Action: "login_failed", - Detail: fmt.Sprintf("用户名不存在: %s", strings.TrimSpace(input.Username)), + Detail: fmt.Sprintf("用户名不存在: %s", strings.TrimSpace(input.Username)), ClientIP: clientKey, }) } @@ -156,6 +188,20 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str } return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err) } + mfaRequired := userMFAEnabled(user) + trustedDeviceUsed := false + if mfaRequired { + trusted, err := s.verifyTrustedDevice(ctx, user, input.TrustedDeviceToken, clientKey) + if err != nil { + return nil, err + } + trustedDeviceUsed = trusted + if !trusted { + if err := s.verifyLoginMFA(ctx, user, input, clientKey); err != nil { + return nil, err + } + } + } s.rateLimiter.Reset(clientKey) token, err := s.jwtManager.Generate(user) @@ -163,6 +209,16 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str return nil, apperror.Internal("AUTH_TOKEN_FAILED", "无法生成访问令牌", err) } + payload := &AuthPayload{Token: token, User: ToUserOutput(user)} + if mfaRequired && !trustedDeviceUsed && input.RememberDevice { + deviceToken, device, err := s.issueTrustedDevice(ctx, user, input.TrustedDeviceName, clientKey) + if err != nil { + return nil, err + } + payload.TrustedDeviceToken = deviceToken + payload.TrustedDevice = device + } + if s.auditService != nil { s.auditService.Record(AuditEntry{ UserID: user.ID, Username: user.Username, @@ -171,10 +227,72 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str }) } - return &AuthPayload{Token: token, User: ToUserOutput(user)}, nil + return payload, nil } -func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) { +func (s *AuthService) verifyLoginMFA(ctx context.Context, user *model.User, input LoginInput, clientKey string) error { + if input.WebAuthnAssertion != nil { + if err := s.VerifyWebAuthnLogin(ctx, user, *input.WebAuthnAssertion, clientKey); err != nil { + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "login_failed", + Detail: "通行密钥校验失败", ClientIP: clientKey, + }) + } + return err + } + return nil + } + code := strings.TrimSpace(input.TwoFactorCode) + if code == "" { + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "two_factor_required", + Detail: "登录需要多因素验证", ClientIP: clientKey, + }) + } + return apperror.Unauthorized("AUTH_2FA_REQUIRED", "请输入验证码、恢复码或使用通行密钥", nil) + } + if user.TwoFactorEnabled { + secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext) + if err != nil { + return apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err) + } + ok, err := security.ValidateTOTPCode(secret, code) + if err == nil && ok { + return nil + } + if consumed, err := s.consumeRecoveryCode(ctx, user, code); err != nil { + return err + } else if consumed { + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "two_factor_recovery_code_used", + Detail: "使用恢复码完成登录", ClientIP: clientKey, + }) + } + return nil + } + } + if consumed, err := s.consumeOutOfBandOTP(ctx, user, code, clientKey); err != nil { + return err + } else if consumed { + return nil + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "login_failed", + Detail: "多因素验证码错误", ClientIP: clientKey, + }) + } + return apperror.Unauthorized("AUTH_2FA_INVALID", "验证码、恢复码或通行密钥错误", nil) +} + +func (s *AuthService) userBySubject(ctx context.Context, subject string) (*model.User, error) { userID, err := strconv.ParseUint(subject, 10, 64) if err != nil { return nil, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err) @@ -186,6 +304,14 @@ func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*User if user == nil { return nil, apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found")) } + return user, nil +} + +func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } return ToUserOutput(user), nil } @@ -195,16 +321,9 @@ type ChangePasswordInput struct { } func (s *AuthService) ChangePassword(ctx context.Context, subject string, input ChangePasswordInput) error { - userID, err := strconv.ParseUint(subject, 10, 64) + user, err := s.userBySubject(ctx, subject) if err != nil { - return apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err) - } - user, err := s.users.FindByID(ctx, uint(userID)) - if err != nil { - return apperror.Internal("AUTH_LOOKUP_FAILED", "无法获取当前用户", err) - } - if user == nil { - return apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found")) + return err } if err := security.ComparePassword(user.PasswordHash, input.OldPassword); err != nil { return apperror.BadRequest("AUTH_WRONG_PASSWORD", "旧密码不正确", err) @@ -214,6 +333,9 @@ func (s *AuthService) ChangePassword(ctx context.Context, subject string, input return apperror.Internal("AUTH_HASH_FAILED", "无法处理密码", err) } user.PasswordHash = hash + user.TrustedDevices = "" + user.OutOfBandOTPCiphertext = "" + user.WebAuthnChallengeCiphertext = "" if err := s.users.Update(ctx, user); err != nil { return apperror.Internal("AUTH_UPDATE_FAILED", "密码修改失败", err) } @@ -229,15 +351,338 @@ func (s *AuthService) ChangePassword(ctx context.Context, subject string, input return nil } +type TwoFactorSetupInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` +} + +type TwoFactorSetupOutput struct { + Secret string `json:"secret"` + OTPAuthURL string `json:"otpAuthUrl"` + QRCodeDataURL string `json:"qrCodeDataUrl"` + TwoFactorEnabled bool `json:"twoFactorEnabled"` + TwoFactorConfirmed bool `json:"twoFactorConfirmed"` +} + +type EnableTwoFactorInput struct { + Code string `json:"code" binding:"required,min=6,max=10"` +} + +type EnableTwoFactorOutput struct { + User *UserOutput `json:"user"` + RecoveryCodes []string `json:"recoveryCodes"` +} + +type DisableTwoFactorInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` + Code string `json:"code" binding:"required,min=6,max=32"` +} + +type RegenerateRecoveryCodesInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` + Code string `json:"code" binding:"required,min=6,max=10"` +} + +type RecoveryCodesOutput struct { + User *UserOutput `json:"user"` + RecoveryCodes []string `json:"recoveryCodes"` +} + +func (s *AuthService) PrepareTwoFactor(ctx context.Context, subject string, input TwoFactorSetupInput) (*TwoFactorSetupOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if user.TwoFactorEnabled { + return nil, apperror.Conflict("AUTH_2FA_ALREADY_ENABLED", "TOTP 已启用", nil) + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + + enrollment, err := security.GenerateTOTPEnrollment(user.Username) + if err != nil { + return nil, apperror.Internal("AUTH_2FA_SETUP_FAILED", "无法生成 TOTP 密钥", err) + } + ciphertext, err := s.encryptTwoFactorSecret(enrollment.Secret) + if err != nil { + return nil, apperror.Internal("AUTH_2FA_SAVE_FAILED", "无法保存 TOTP 密钥", err) + } + user.TwoFactorSecretCiphertext = ciphertext + user.TwoFactorEnabled = false + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_2FA_SAVE_FAILED", "无法保存 TOTP 密钥", err) + } + + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "two_factor_setup", + TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username, + Detail: "生成 TOTP 密钥", + }) + } + + return &TwoFactorSetupOutput{ + Secret: enrollment.Secret, + OTPAuthURL: enrollment.OTPAuthURL, + QRCodeDataURL: enrollment.QRCodeDataURL, + TwoFactorEnabled: false, + TwoFactorConfirmed: false, + }, nil +} + +func (s *AuthService) EnableTwoFactor(ctx context.Context, subject string, input EnableTwoFactorInput) (*EnableTwoFactorOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if user.TwoFactorEnabled { + return nil, apperror.Conflict("AUTH_2FA_ALREADY_ENABLED", "TOTP 已启用", nil) + } + if strings.TrimSpace(user.TwoFactorSecretCiphertext) == "" { + return nil, apperror.BadRequest("AUTH_2FA_NOT_PREPARED", "请先生成 TOTP 密钥", nil) + } + secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext) + if err != nil { + return nil, apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err) + } + ok, err := security.ValidateTOTPCode(secret, input.Code) + if err != nil { + return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err) + } + if !ok { + return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil) + } + recoveryCodes, recoveryHashes, err := s.generateRecoveryCodeHashes() + if err != nil { + return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法生成恢复码", err) + } + + user.TwoFactorEnabled = true + user.TwoFactorRecoveryCodeHashes = recoveryHashes + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_2FA_ENABLE_FAILED", "无法启用 TOTP", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "two_factor_enable", + TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username, + Detail: "启用 TOTP", + }) + } + return &EnableTwoFactorOutput{User: ToUserOutput(user), RecoveryCodes: recoveryCodes}, nil +} + +func (s *AuthService) DisableTwoFactor(ctx context.Context, subject string, input DisableTwoFactorInput) (*UserOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if !user.TwoFactorEnabled { + return nil, apperror.BadRequest("AUTH_2FA_NOT_ENABLED", "TOTP 未启用", nil) + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext) + if err != nil { + return nil, apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err) + } + ok, err := security.ValidateTOTPCode(secret, input.Code) + if err != nil { + return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err) + } + if !ok { + return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil) + } + + user.TwoFactorEnabled = false + user.TwoFactorSecretCiphertext = "" + user.TwoFactorRecoveryCodeHashes = "" + clearTrustedDevicesIfMFAOff(user) + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_2FA_DISABLE_FAILED", "无法关闭 TOTP", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "two_factor_disable", + TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username, + Detail: "关闭 TOTP", + }) + } + return ToUserOutput(user), nil +} + +func (s *AuthService) verifyCurrentTOTP(user *model.User, code string) error { + secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext) + if err != nil { + return apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err) + } + ok, err := security.ValidateTOTPCode(secret, code) + if err != nil { + return apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err) + } + if !ok { + return apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil) + } + return nil +} + +func (s *AuthService) RegenerateRecoveryCodes(ctx context.Context, subject string, input RegenerateRecoveryCodesInput) (*RecoveryCodesOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if !user.TwoFactorEnabled { + return nil, apperror.BadRequest("AUTH_2FA_NOT_ENABLED", "TOTP 未启用", nil) + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + if err := s.verifyCurrentTOTP(user, input.Code); err != nil { + return nil, err + } + recoveryCodes, recoveryHashes, err := s.generateRecoveryCodeHashes() + if err != nil { + return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法生成恢复码", err) + } + user.TwoFactorRecoveryCodeHashes = recoveryHashes + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法更新恢复码", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "two_factor_recovery_codes_regenerate", + TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username, + Detail: "重新生成 TOTP 恢复码", + }) + } + return &RecoveryCodesOutput{User: ToUserOutput(user), RecoveryCodes: recoveryCodes}, nil +} + +func (s *AuthService) generateRecoveryCodeHashes() ([]string, string, error) { + codes, err := security.GenerateRecoveryCodes(security.RecoveryCodeCount) + if err != nil { + return nil, "", err + } + hashes := make([]string, 0, len(codes)) + for _, code := range codes { + hash, err := security.HashPassword(security.NormalizeRecoveryCode(code)) + if err != nil { + return nil, "", err + } + hashes = append(hashes, hash) + } + encoded, err := encodeRecoveryCodeHashes(hashes) + if err != nil { + return nil, "", err + } + return codes, encoded, nil +} + +func (s *AuthService) consumeRecoveryCode(ctx context.Context, user *model.User, code string) (bool, error) { + if !security.IsRecoveryCodeCandidate(code) { + return false, nil + } + hashes, err := parseRecoveryCodeHashes(user.TwoFactorRecoveryCodeHashes) + if err != nil { + return false, apperror.Internal("AUTH_2FA_RECOVERY_INVALID", "恢复码配置异常", err) + } + if len(hashes) == 0 { + return false, nil + } + normalized := security.NormalizeRecoveryCode(code) + for i, hash := range hashes { + if security.ComparePassword(hash, normalized) != nil { + continue + } + hashes = append(hashes[:i], hashes[i+1:]...) + encoded, err := encodeRecoveryCodeHashes(hashes) + if err != nil { + return false, apperror.Internal("AUTH_2FA_RECOVERY_INVALID", "恢复码配置异常", err) + } + user.TwoFactorRecoveryCodeHashes = encoded + if err := s.users.Update(ctx, user); err != nil { + return false, apperror.Internal("AUTH_2FA_RECOVERY_CONSUME_FAILED", "无法使用恢复码", err) + } + return true, nil + } + return false, nil +} + +func (s *AuthService) encryptTwoFactorSecret(secret string) (string, error) { + if s.twoFactorCipher == nil { + return "", errors.New("two-factor cipher is not configured") + } + return s.twoFactorCipher.Encrypt([]byte(strings.TrimSpace(secret))) +} + +func (s *AuthService) decryptTwoFactorSecret(ciphertext string) (string, error) { + if s.twoFactorCipher == nil { + return "", errors.New("two-factor cipher is not configured") + } + raw, err := s.twoFactorCipher.Decrypt(strings.TrimSpace(ciphertext)) + if err != nil { + return "", err + } + return strings.TrimSpace(string(raw)), nil +} + +func parseRecoveryCodeHashes(encoded string) ([]string, error) { + if strings.TrimSpace(encoded) == "" { + return nil, nil + } + var hashes []string + if err := json.Unmarshal([]byte(encoded), &hashes); err != nil { + return nil, err + } + return hashes, nil +} + +func encodeRecoveryCodeHashes(hashes []string) (string, error) { + if len(hashes) == 0 { + return "", nil + } + encoded, err := json.Marshal(hashes) + if err != nil { + return "", err + } + return string(encoded), nil +} + +func recoveryCodeRemainingCount(user *model.User) int { + if user == nil { + return 0 + } + hashes, err := parseRecoveryCodeHashes(user.TwoFactorRecoveryCodeHashes) + if err != nil { + return 0 + } + return len(hashes) +} + func ToUserOutput(user *model.User) *UserOutput { if user == nil { return nil } return &UserOutput{ - ID: user.ID, - Username: user.Username, - DisplayName: user.DisplayName, - Role: user.Role, + ID: user.ID, + Username: user.Username, + DisplayName: user.DisplayName, + Email: user.Email, + Phone: user.Phone, + Role: user.Role, + MFAEnabled: userMFAEnabled(user), + TwoFactorEnabled: user.TwoFactorEnabled, + TwoFactorRecoveryCodesRemaining: recoveryCodeRemainingCount(user), + WebAuthnEnabled: webAuthnCredentialCount(user) > 0, + WebAuthnCredentialCount: webAuthnCredentialCount(user), + TrustedDeviceCount: trustedDeviceCount(user), + EmailOTPEnabled: user.EmailOTPEnabled, + SMSOTPEnabled: user.SMSOTPEnabled, } } diff --git a/server/internal/service/auth_service_test.go b/server/internal/service/auth_service_test.go index 940072b..33048b9 100644 --- a/server/internal/service/auth_service_test.go +++ b/server/internal/service/auth_service_test.go @@ -5,8 +5,11 @@ import ( "testing" "time" + "backupx/server/internal/apperror" "backupx/server/internal/model" "backupx/server/internal/security" + "backupx/server/internal/storage/codec" + "github.com/pquerna/otp/totp" ) type fakeUserRepository struct { @@ -100,6 +103,7 @@ func TestAuthServiceSetupAndLogin(t *testing.T) { &fakeSystemConfigRepository{}, security.NewJWTManager("test-secret", time.Hour), security.NewLoginRateLimiter(5, time.Minute), + codec.NewConfigCipher("test-encryption-secret"), ) setupResult, err := service.Setup(context.Background(), SetupInput{ @@ -133,6 +137,7 @@ func newTestAuthService() (*AuthService, *fakeUserRepository) { &fakeSystemConfigRepository{}, security.NewJWTManager("test-secret", time.Hour), security.NewLoginRateLimiter(5, time.Minute), + codec.NewConfigCipher("test-encryption-secret"), ) return svc, users } @@ -188,3 +193,425 @@ func TestChangePasswordWrongOld(t *testing.T) { t.Fatalf("expected ChangePassword with wrong old password to fail") } } + +func TestAuthServiceLoginRequiresTwoFactorWhenEnabled(t *testing.T) { + svc, _ := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + + setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{ + CurrentPassword: "password-123", + }) + if err != nil { + t.Fatalf("PrepareTwoFactor: %v", err) + } + if setup.Secret == "" || setup.QRCodeDataURL == "" || setup.OTPAuthURL == "" { + t.Fatalf("expected populated 2FA enrollment, got %#v", setup) + } + + code, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode: %v", err) + } + enabledUser, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}) + if err != nil { + t.Fatalf("EnableTwoFactor: %v", err) + } + if !enabledUser.User.TwoFactorEnabled { + t.Fatalf("expected 2FA enabled") + } + if len(enabledUser.RecoveryCodes) != security.RecoveryCodeCount { + t.Fatalf("expected %d recovery codes, got %d", security.RecoveryCodeCount, len(enabledUser.RecoveryCodes)) + } + + _, err = svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_REQUIRED" { + t.Fatalf("expected AUTH_2FA_REQUIRED, got %v", err) + } + + loginCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode login: %v", err) + } + loginResult, err := svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: loginCode, + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login with 2FA: %v", err) + } + if loginResult.Token == "" { + t.Fatalf("expected non-empty token") + } +} + +func TestAuthServiceDisableTwoFactor(t *testing.T) { + svc, _ := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{ + CurrentPassword: "password-123", + }) + if err != nil { + t.Fatalf("PrepareTwoFactor: %v", err) + } + code, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode: %v", err) + } + if _, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}); err != nil { + t.Fatalf("EnableTwoFactor: %v", err) + } + + disableCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode disable: %v", err) + } + user, err := svc.DisableTwoFactor(context.Background(), "1", DisableTwoFactorInput{ + CurrentPassword: "password-123", + Code: disableCode, + }) + if err != nil { + t.Fatalf("DisableTwoFactor: %v", err) + } + if user.TwoFactorEnabled { + t.Fatalf("expected 2FA disabled") + } + + loginResult, err := svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login after disable: %v", err) + } + if loginResult.Token == "" { + t.Fatalf("expected non-empty token") + } +} + +func TestAuthServiceRecoveryCodeLoginConsumesCode(t *testing.T) { + svc, _ := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{ + CurrentPassword: "password-123", + }) + if err != nil { + t.Fatalf("PrepareTwoFactor: %v", err) + } + code, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode: %v", err) + } + enabled, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}) + if err != nil { + t.Fatalf("EnableTwoFactor: %v", err) + } + recoveryCode := enabled.RecoveryCodes[0] + + loginResult, err := svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: recoveryCode, + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login with recovery code: %v", err) + } + if loginResult.User.TwoFactorRecoveryCodesRemaining != security.RecoveryCodeCount-1 { + t.Fatalf("expected one recovery code consumed, got remaining=%d", loginResult.User.TwoFactorRecoveryCodesRemaining) + } + + _, err = svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: recoveryCode, + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" { + t.Fatalf("expected consumed recovery code to fail, got %v", err) + } +} + +func TestAuthServiceRegenerateRecoveryCodesInvalidatesOldCodes(t *testing.T) { + svc, _ := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{ + CurrentPassword: "password-123", + }) + if err != nil { + t.Fatalf("PrepareTwoFactor: %v", err) + } + code, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode: %v", err) + } + enabled, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}) + if err != nil { + t.Fatalf("EnableTwoFactor: %v", err) + } + oldRecoveryCode := enabled.RecoveryCodes[0] + + regenerateCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode regenerate: %v", err) + } + regenerated, err := svc.RegenerateRecoveryCodes(context.Background(), "1", RegenerateRecoveryCodesInput{ + CurrentPassword: "password-123", + Code: regenerateCode, + }) + if err != nil { + t.Fatalf("RegenerateRecoveryCodes: %v", err) + } + if len(regenerated.RecoveryCodes) != security.RecoveryCodeCount { + t.Fatalf("expected %d recovery codes, got %d", security.RecoveryCodeCount, len(regenerated.RecoveryCodes)) + } + + _, err = svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: oldRecoveryCode, + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" { + t.Fatalf("expected old recovery code to fail, got %v", err) + } +} + +func TestAuthServiceTrustedDeviceSkipsMFA(t *testing.T) { + svc, repo := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{ + CurrentPassword: "password-123", + }) + if err != nil { + t.Fatalf("PrepareTwoFactor: %v", err) + } + code, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode: %v", err) + } + if _, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}); err != nil { + t.Fatalf("EnableTwoFactor: %v", err) + } + loginCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode login: %v", err) + } + firstLogin, err := svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: loginCode, + RememberDevice: true, TrustedDeviceName: "test browser", + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login with 2FA: %v", err) + } + if firstLogin.TrustedDeviceToken == "" || firstLogin.TrustedDevice == nil { + t.Fatalf("expected trusted device token") + } + secondLogin, err := svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TrustedDeviceToken: firstLogin.TrustedDeviceToken, + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login with trusted device: %v", err) + } + if secondLogin.Token == "" { + t.Fatalf("expected token") + } + disableCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC()) + if err != nil { + t.Fatalf("GenerateCode disable: %v", err) + } + if _, err := svc.DisableTwoFactor(context.Background(), "1", DisableTwoFactorInput{ + CurrentPassword: "password-123", + Code: disableCode, + }); err != nil { + t.Fatalf("DisableTwoFactor: %v", err) + } + if repo.users[0].TrustedDevices != "" { + t.Fatalf("expected trusted devices cleared after disabling last MFA method") + } +} + +func TestAuthServiceOutOfBandOTPLoginConsumesCode(t *testing.T) { + svc, repo := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + user := repo.users[0] + user.Email = "admin@example.com" + user.EmailOTPEnabled = true + hash, err := security.HashPassword("123456") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{ + Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute), + }) + if err != nil { + t.Fatalf("EncryptJSON: %v", err) + } + user.OutOfBandOTPCiphertext = ciphertext + if err := repo.Update(context.Background(), user); err != nil { + t.Fatalf("Update: %v", err) + } + + loginResult, err := svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: "123456", + }, "127.0.0.1") + if err != nil { + t.Fatalf("Login with email OTP: %v", err) + } + if loginResult.Token == "" { + t.Fatalf("expected token") + } + if repo.users[0].OutOfBandOTPCiphertext != "" { + t.Fatalf("expected OTP to be consumed") + } + + _, err = svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: "123456", + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" { + t.Fatalf("expected consumed OTP to fail, got %v", err) + } +} + +func TestAuthServiceMFAStartIsRateLimited(t *testing.T) { + svc, repo := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + repo.users[0].Email = "admin@example.com" + repo.users[0].EmailOTPEnabled = true + + for i := 0; i < 5; i++ { + _ = svc.SendLoginOTP(context.Background(), LoginOTPInput{ + Username: "admin", Password: "wrong-password", Channel: "email", + }, "127.0.0.1") + } + err = svc.SendLoginOTP(context.Background(), LoginOTPInput{ + Username: "admin", Password: "wrong-password", Channel: "email", + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_RATE_LIMITED" { + t.Fatalf("expected AUTH_RATE_LIMITED, got %v", err) + } +} + +func TestAuthServiceDisabledOTPChannelCannotConsumePendingCode(t *testing.T) { + svc, repo := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + user := repo.users[0] + user.Email = "admin@example.com" + user.EmailOTPEnabled = false + user.SMSOTPEnabled = true + hash, err := security.HashPassword("123456") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{ + Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute), + }) + if err != nil { + t.Fatalf("EncryptJSON: %v", err) + } + user.OutOfBandOTPCiphertext = ciphertext + if err := repo.Update(context.Background(), user); err != nil { + t.Fatalf("Update: %v", err) + } + + _, err = svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", TwoFactorCode: "123456", + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" { + t.Fatalf("expected disabled OTP channel to fail, got %v", err) + } + if repo.users[0].OutOfBandOTPCiphertext != "" { + t.Fatalf("expected disabled channel OTP to be cleared") + } +} + +func TestAuthServiceChangingOTPRecipientClearsPendingCode(t *testing.T) { + svc, repo := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + user := repo.users[0] + user.Email = "old@example.com" + user.EmailOTPEnabled = true + hash, err := security.HashPassword("123456") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{ + Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute), + }) + if err != nil { + t.Fatalf("EncryptJSON: %v", err) + } + user.OutOfBandOTPCiphertext = ciphertext + if err := repo.Update(context.Background(), user); err != nil { + t.Fatalf("Update: %v", err) + } + + updated, err := svc.ConfigureOutOfBandOTP(context.Background(), "1", OTPConfigInput{ + CurrentPassword: "password-123", + Channel: "email", + Enabled: true, + Email: "new@example.com", + }) + if err != nil { + t.Fatalf("ConfigureOutOfBandOTP: %v", err) + } + if updated.Email != "new@example.com" { + t.Fatalf("expected email updated, got %q", updated.Email) + } + if repo.users[0].OutOfBandOTPCiphertext != "" { + t.Fatalf("expected pending email OTP to be cleared after recipient change") + } +} + +func TestAuthServiceCorruptWebAuthnCredentialsStillRequireMFA(t *testing.T) { + svc, repo := newTestAuthService() + _, err := svc.Setup(context.Background(), SetupInput{ + Username: "admin", Password: "password-123", DisplayName: "Admin", + }) + if err != nil { + t.Fatalf("Setup: %v", err) + } + repo.users[0].WebAuthnCredentials = "{invalid-json" + + _, err = svc.Login(context.Background(), LoginInput{ + Username: "admin", Password: "password-123", + }, "127.0.0.1") + if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_REQUIRED" { + t.Fatalf("expected corrupt WebAuthn credentials to require MFA, got %v", err) + } +} diff --git a/server/internal/service/auth_trusted_device.go b/server/internal/service/auth_trusted_device.go new file mode 100644 index 0000000..310adc2 --- /dev/null +++ b/server/internal/service/auth_trusted_device.go @@ -0,0 +1,221 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + "time" + + "backupx/server/internal/apperror" + "backupx/server/internal/model" + "backupx/server/internal/security" +) + +func (s *AuthService) ListTrustedDevices(ctx context.Context, subject string) ([]TrustedDeviceOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + devices, err := parseTrustedDevices(user.TrustedDevices) + if err != nil { + return nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + now := time.Now().UTC() + output := make([]TrustedDeviceOutput, 0, len(devices)) + for _, device := range devices { + if device.ExpiresAt.Before(now) { + continue + } + output = append(output, toTrustedDeviceOutput(device)) + } + return output, nil +} + +type TrustedDeviceRevokeInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` +} + +func (s *AuthService) RevokeTrustedDevice(ctx context.Context, subject string, id string, input TrustedDeviceRevokeInput) error { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return err + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + devices, err := parseTrustedDevices(user.TrustedDevices) + if err != nil { + return apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + found := false + filtered := make([]TrustedDeviceRecord, 0, len(devices)) + for _, device := range devices { + if device.ID == strings.TrimSpace(id) { + found = true + } else { + filtered = append(filtered, device) + } + } + if !found { + return apperror.New(404, "AUTH_TRUSTED_DEVICE_NOT_FOUND", "可信设备不存在", nil) + } + encoded, err := encodeTrustedDevices(filtered) + if err != nil { + return apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + user.TrustedDevices = encoded + if err := s.users.Update(ctx, user); err != nil { + return apperror.Internal("AUTH_TRUSTED_DEVICE_REVOKE_FAILED", "无法移除可信设备", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "trusted_device_revoke", + TargetType: "trusted_device", TargetID: strings.TrimSpace(id), + Detail: "移除可信设备", + }) + } + return nil +} + +func (s *AuthService) verifyTrustedDevice(ctx context.Context, user *model.User, token string, clientKey string) (bool, error) { + token = strings.TrimSpace(token) + if token == "" { + return false, nil + } + devices, err := parseTrustedDevices(user.TrustedDevices) + if err != nil { + return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + now := time.Now().UTC() + hash := trustedDeviceTokenHash(token) + changed := false + for i := range devices { + device := &devices[i] + if device.ExpiresAt.Before(now) { + changed = true + continue + } + if subtle.ConstantTimeCompare([]byte(device.TokenHash), []byte(hash)) != 1 { + continue + } + device.LastUsedAt = now + device.LastIP = clientKey + changed = true + encoded, err := encodeTrustedDevices(filterActiveTrustedDevices(devices, now)) + if err != nil { + return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + user.TrustedDevices = encoded + if err := s.users.Update(ctx, user); err != nil { + return false, apperror.Internal("AUTH_TRUSTED_DEVICE_UPDATE_FAILED", "无法更新可信设备", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "trusted_device_used", + TargetType: "trusted_device", TargetID: device.ID, TargetName: device.Name, + Detail: "使用可信设备跳过多因素验证", ClientIP: clientKey, + }) + } + return true, nil + } + if changed { + encoded, err := encodeTrustedDevices(filterActiveTrustedDevices(devices, now)) + if err != nil { + return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + user.TrustedDevices = encoded + if err := s.users.Update(ctx, user); err != nil { + return false, apperror.Internal("AUTH_TRUSTED_DEVICE_UPDATE_FAILED", "无法更新可信设备", err) + } + } + return false, nil +} + +func (s *AuthService) issueTrustedDevice(ctx context.Context, user *model.User, name string, clientKey string) (string, *TrustedDeviceOutput, error) { + token, err := randomURLToken(32) + if err != nil { + return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法生成可信设备令牌", err) + } + id, err := randomURLToken(16) + if err != nil { + return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法生成可信设备编号", err) + } + now := time.Now().UTC() + deviceName := normalizeTrustedDeviceName(name) + device := TrustedDeviceRecord{ + ID: id, + Name: deviceName, + TokenHash: trustedDeviceTokenHash(token), + CreatedAt: now, + LastUsedAt: now, + ExpiresAt: now.Add(trustedDeviceTTL), + LastIP: clientKey, + } + devices, err := parseTrustedDevices(user.TrustedDevices) + if err != nil { + return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + devices = append(filterActiveTrustedDevices(devices, now), device) + if len(devices) > maxTrustedDevices { + devices = devices[len(devices)-maxTrustedDevices:] + } + encoded, err := encodeTrustedDevices(devices) + if err != nil { + return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err) + } + user.TrustedDevices = encoded + if err := s.users.Update(ctx, user); err != nil { + return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法保存可信设备", err) + } + output := toTrustedDeviceOutput(device) + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "trusted_device_create", + TargetType: "trusted_device", TargetID: device.ID, TargetName: device.Name, + Detail: fmt.Sprintf("添加可信设备,有效期至 %s", device.ExpiresAt.Format(time.RFC3339)), ClientIP: clientKey, + }) + } + return token, &output, nil +} + +func filterActiveTrustedDevices(devices []TrustedDeviceRecord, now time.Time) []TrustedDeviceRecord { + active := make([]TrustedDeviceRecord, 0, len(devices)) + for _, device := range devices { + if device.ExpiresAt.After(now) { + active = append(active, device) + } + } + return active +} + +func trustedDeviceTokenHash(token string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(token))) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +func randomURLToken(size int) (string, error) { + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func normalizeTrustedDeviceName(name string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "当前设备" + } + if len([]rune(trimmed)) <= maxTrustedDeviceName { + return trimmed + } + runes := []rune(trimmed) + return string(runes[:maxTrustedDeviceName]) +} diff --git a/server/internal/service/auth_webauthn.go b/server/internal/service/auth_webauthn.go new file mode 100644 index 0000000..e2d827e --- /dev/null +++ b/server/internal/service/auth_webauthn.go @@ -0,0 +1,366 @@ +package service + +import ( + "context" + "fmt" + "strings" + "time" + + "backupx/server/internal/apperror" + "backupx/server/internal/model" + "backupx/server/internal/security" +) + +type WebAuthnRequestContext struct { + RPID string + Origin string +} + +type WebAuthnRegistrationOptionsInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` +} + +type WebAuthnRegistrationFinishInput struct { + Name string `json:"name" binding:"omitempty,max=128"` + Credential security.WebAuthnRegistrationResponse `json:"credential" binding:"required"` +} + +type WebAuthnCredentialDeleteInput struct { + CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"` +} + +type WebAuthnLoginOptionsInput struct { + Username string `json:"username" binding:"required,min=3,max=64"` + Password string `json:"password" binding:"required,min=8,max=128"` +} + +type webAuthnPublicKeyCredentialParam struct { + Type string `json:"type"` + Alg int `json:"alg"` +} + +type webAuthnRelyingParty struct { + Name string `json:"name"` + ID string `json:"id"` +} + +type webAuthnUserEntity struct { + ID string `json:"id"` + Name string `json:"name"` + DisplayName string `json:"displayName"` +} + +type webAuthnCredentialDescriptor struct { + Type string `json:"type"` + ID string `json:"id"` +} + +type webAuthnAuthenticatorSelection struct { + UserVerification string `json:"userVerification"` +} + +type WebAuthnRegistrationOptions struct { + Challenge string `json:"challenge"` + RP webAuthnRelyingParty `json:"rp"` + User webAuthnUserEntity `json:"user"` + PubKeyCredParams []webAuthnPublicKeyCredentialParam `json:"pubKeyCredParams"` + Timeout int `json:"timeout"` + Attestation string `json:"attestation"` + AuthenticatorSelection webAuthnAuthenticatorSelection `json:"authenticatorSelection"` + ExcludeCredentials []webAuthnCredentialDescriptor `json:"excludeCredentials"` +} + +type WebAuthnLoginOptions struct { + Challenge string `json:"challenge"` + RPID string `json:"rpId"` + Timeout int `json:"timeout"` + UserVerification string `json:"userVerification"` + AllowCredentials []webAuthnCredentialDescriptor `json:"allowCredentials"` +} + +func (s *AuthService) BeginWebAuthnRegistration(ctx context.Context, subject string, input WebAuthnRegistrationOptionsInput, request WebAuthnRequestContext) (*WebAuthnRegistrationOptions, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err) + } + challenge, err := security.GenerateWebAuthnChallenge() + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法生成通行密钥挑战", err) + } + state := webAuthnChallengeState{ + Type: "register", + Challenge: challenge, + RPID: request.RPID, + Origin: request.Origin, + ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL), + } + if err := s.saveWebAuthnChallenge(ctx, user, state); err != nil { + return nil, err + } + exclude := make([]webAuthnCredentialDescriptor, 0, len(credentials)) + for _, credential := range credentials { + exclude = append(exclude, webAuthnCredentialDescriptor{Type: "public-key", ID: credential.CredentialID}) + } + return &WebAuthnRegistrationOptions{ + Challenge: challenge, + RP: webAuthnRelyingParty{Name: "BackupX", ID: request.RPID}, + User: webAuthnUserEntity{ + ID: security.EncodeBase64URL([]byte(fmt.Sprintf("%d", user.ID))), + Name: user.Username, + DisplayName: user.DisplayName, + }, + PubKeyCredParams: []webAuthnPublicKeyCredentialParam{ + {Type: "public-key", Alg: -7}, + }, + Timeout: int(mfaChallengeTTL / time.Millisecond), + Attestation: "none", + AuthenticatorSelection: webAuthnAuthenticatorSelection{UserVerification: "preferred"}, + ExcludeCredentials: exclude, + }, nil +} + +func (s *AuthService) FinishWebAuthnRegistration(ctx context.Context, subject string, input WebAuthnRegistrationFinishInput) (*UserOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + state, err := s.loadWebAuthnChallenge(user, "register") + if err != nil { + return nil, err + } + parsed, err := security.VerifyWebAuthnRegistration(input.Credential, state.Challenge, state.RPID, state.Origin) + if err != nil { + return nil, apperror.BadRequest("AUTH_WEBAUTHN_VERIFY_FAILED", "通行密钥注册校验失败", err) + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err) + } + for _, credential := range credentials { + if credential.CredentialID == parsed.CredentialID { + return nil, apperror.Conflict("AUTH_WEBAUTHN_EXISTS", "该通行密钥已注册", nil) + } + } + id, err := randomURLToken(16) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法生成通行密钥编号", err) + } + now := time.Now().UTC().Format(time.RFC3339) + name := strings.TrimSpace(input.Name) + if name == "" { + name = "通行密钥" + } + credentials = append(credentials, WebAuthnCredentialRecord{ + ID: id, + Name: normalizeTrustedDeviceName(name), + CredentialID: parsed.CredentialID, + PublicKeyX: parsed.PublicKeyX, + PublicKeyY: parsed.PublicKeyY, + SignCount: parsed.SignCount, + CreatedAt: now, + }) + encoded, err := encodeWebAuthnCredentials(credentials) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法保存通行密钥", err) + } + user.WebAuthnCredentials = encoded + user.WebAuthnChallengeCiphertext = "" + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法保存通行密钥", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "webauthn_register", + TargetType: "webauthn_credential", TargetID: id, TargetName: name, + Detail: "注册通行密钥", + }) + } + return ToUserOutput(user), nil +} + +func (s *AuthService) BeginWebAuthnLogin(ctx context.Context, input WebAuthnLoginOptionsInput, request WebAuthnRequestContext, clientKey string) (*WebAuthnLoginOptions, error) { + user, err := s.verifyPasswordForMFAStart(ctx, input.Username, input.Password, clientKey) + if err != nil { + return nil, err + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err) + } + if len(credentials) == 0 { + return nil, apperror.BadRequest("AUTH_WEBAUTHN_NOT_ENABLED", "当前账号未注册通行密钥", nil) + } + challenge, err := security.GenerateWebAuthnChallenge() + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法生成通行密钥挑战", err) + } + state := webAuthnChallengeState{ + Type: "login", + Challenge: challenge, + RPID: request.RPID, + Origin: request.Origin, + ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL), + } + if err := s.saveWebAuthnChallenge(ctx, user, state); err != nil { + return nil, err + } + allowed := make([]webAuthnCredentialDescriptor, 0, len(credentials)) + for _, credential := range credentials { + allowed = append(allowed, webAuthnCredentialDescriptor{Type: "public-key", ID: credential.CredentialID}) + } + return &WebAuthnLoginOptions{ + Challenge: challenge, + RPID: request.RPID, + Timeout: int(mfaChallengeTTL / time.Millisecond), + UserVerification: "preferred", + AllowCredentials: allowed, + }, nil +} + +func (s *AuthService) VerifyWebAuthnLogin(ctx context.Context, user *model.User, assertion security.WebAuthnLoginAssertion, clientKey string) error { + state, err := s.loadWebAuthnChallenge(user, "login") + if err != nil { + return err + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err) + } + rawID := strings.TrimSpace(assertion.RawID) + if rawID == "" { + rawID = strings.TrimSpace(assertion.ID) + } + for i := range credentials { + credential := &credentials[i] + if credential.CredentialID != rawID { + continue + } + nextSignCount, err := security.VerifyWebAuthnAssertion(assertion, state.Challenge, state.RPID, state.Origin, security.WebAuthnCredentialMaterial{ + CredentialID: credential.CredentialID, + PublicKeyX: credential.PublicKeyX, + PublicKeyY: credential.PublicKeyY, + SignCount: credential.SignCount, + }) + if err != nil { + return apperror.Unauthorized("AUTH_WEBAUTHN_INVALID", "通行密钥校验失败", err) + } + credential.SignCount = nextSignCount + credential.LastUsedAt = time.Now().UTC().Format(time.RFC3339) + encoded, err := encodeWebAuthnCredentials(credentials) + if err != nil { + return apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err) + } + user.WebAuthnCredentials = encoded + user.WebAuthnChallengeCiphertext = "" + if err := s.users.Update(ctx, user); err != nil { + return apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "webauthn_used", + TargetType: "webauthn_credential", TargetID: credential.ID, TargetName: credential.Name, + Detail: "使用通行密钥完成多因素验证", ClientIP: clientKey, + }) + } + return nil + } + return apperror.Unauthorized("AUTH_WEBAUTHN_INVALID", "通行密钥不存在", nil) +} + +func (s *AuthService) ListWebAuthnCredentials(ctx context.Context, subject string) ([]WebAuthnCredentialOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err) + } + output := make([]WebAuthnCredentialOutput, 0, len(credentials)) + for _, credential := range credentials { + output = append(output, toWebAuthnCredentialOutput(credential)) + } + return output, nil +} + +func (s *AuthService) DeleteWebAuthnCredential(ctx context.Context, subject string, id string, input WebAuthnCredentialDeleteInput) (*UserOutput, error) { + user, err := s.userBySubject(ctx, subject) + if err != nil { + return nil, err + } + if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil { + return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err) + } + credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err) + } + found := false + filtered := make([]WebAuthnCredentialRecord, 0, len(credentials)) + for _, credential := range credentials { + if credential.ID == strings.TrimSpace(id) { + found = true + } else { + filtered = append(filtered, credential) + } + } + if !found { + return nil, apperror.New(404, "AUTH_WEBAUTHN_NOT_FOUND", "通行密钥不存在", nil) + } + encoded, err := encodeWebAuthnCredentials(filtered) + if err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err) + } + user.WebAuthnCredentials = encoded + clearTrustedDevicesIfMFAOff(user) + if err := s.users.Update(ctx, user); err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_DELETE_FAILED", "无法删除通行密钥", err) + } + if s.auditService != nil { + s.auditService.Record(AuditEntry{ + UserID: user.ID, Username: user.Username, + Category: "auth", Action: "webauthn_delete", + TargetType: "webauthn_credential", TargetID: strings.TrimSpace(id), + Detail: "删除通行密钥", + }) + } + return ToUserOutput(user), nil +} + +func (s *AuthService) saveWebAuthnChallenge(ctx context.Context, user *model.User, state webAuthnChallengeState) error { + ciphertext, err := s.twoFactorCipher.EncryptJSON(state) + if err != nil { + return apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法保存通行密钥挑战", err) + } + user.WebAuthnChallengeCiphertext = ciphertext + if err := s.users.Update(ctx, user); err != nil { + return apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法保存通行密钥挑战", err) + } + return nil +} + +func (s *AuthService) loadWebAuthnChallenge(user *model.User, challengeType string) (*webAuthnChallengeState, error) { + if strings.TrimSpace(user.WebAuthnChallengeCiphertext) == "" { + return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_MISSING", "请先发起通行密钥验证", nil) + } + var state webAuthnChallengeState + if err := s.twoFactorCipher.DecryptJSON(user.WebAuthnChallengeCiphertext, &state); err != nil { + return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_INVALID", "通行密钥挑战状态异常", err) + } + if state.Type != challengeType { + return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_INVALID", "通行密钥挑战类型不匹配", nil) + } + if state.ExpiresAt.Before(time.Now().UTC()) { + return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_EXPIRED", "通行密钥挑战已过期", nil) + } + return &state, nil +} diff --git a/server/internal/service/notification_service.go b/server/internal/service/notification_service.go index be241bc..dbfa283 100644 --- a/server/internal/service/notification_service.go +++ b/server/internal/service/notification_service.go @@ -16,11 +16,11 @@ import ( ) type NotificationUpsertInput struct { - Name string `json:"name" binding:"required,min=1,max=100"` - Type string `json:"type" binding:"required,oneof=email webhook telegram"` - Enabled bool `json:"enabled"` - OnSuccess bool `json:"onSuccess"` - OnFailure bool `json:"onFailure"` + Name string `json:"name" binding:"required,min=1,max=100"` + Type string `json:"type" binding:"required,oneof=email webhook telegram sms"` + Enabled bool `json:"enabled"` + OnSuccess bool `json:"onSuccess"` + OnFailure bool `json:"onFailure"` // EventTypes 订阅的扩展事件列表。与 OnSuccess/OnFailure 并存: // - 两者均空时,订阅"备份成功/失败"对应原有语义(兼容)。 // - EventTypes 显式指定时优先按清单匹配。 @@ -186,8 +186,8 @@ func (s *NotificationService) NotifyBackupResult(ctx context.Context, event Back // - eventType 对应 model.NotificationEvent* 常量,用于订阅匹配 // // 订阅匹配规则: -// 1) notification.EventTypes 非空:必须包含 eventType -// 2) notification.EventTypes 为空:沿用 OnSuccess/OnFailure 开关(仅 backup_* 事件) +// 1. notification.EventTypes 非空:必须包含 eventType +// 2. notification.EventTypes 为空:沿用 OnSuccess/OnFailure 开关(仅 backup_* 事件) func (s *NotificationService) DispatchEvent(ctx context.Context, eventType string, title string, body string, fields map[string]any) error { // 同步广播到 SSE 订阅者(前端 Dashboard 实时推送)。 // 非阻塞:即便广播器未注入或订阅者已满也不影响 Notification 持久渠道。 @@ -209,6 +209,49 @@ func (s *NotificationService) DispatchEvent(ctx context.Context, eventType strin return s.deliver(ctx, items, message) } +func (s *NotificationService) SendAuthEmailOTP(ctx context.Context, to string, code string) error { + return s.sendFirstByType(ctx, "email", map[string]any{"to": strings.TrimSpace(to)}, notify.Message{ + Title: "BackupX 登录验证码", + Body: fmt.Sprintf("您的 BackupX 登录验证码为:%s\n验证码 5 分钟内有效。若非本人操作,请立即检查账号安全。", code), + Fields: map[string]any{ + "purpose": "login_otp", + }, + }) +} + +func (s *NotificationService) SendAuthSMSOTP(ctx context.Context, phone string, code string) error { + return s.sendFirstByType(ctx, "sms", nil, notify.Message{ + Title: "BackupX 登录验证码", + Body: fmt.Sprintf("BackupX 登录验证码:%s,5 分钟内有效。", code), + Fields: map[string]any{ + "phone": strings.TrimSpace(phone), + "code": code, + "purpose": "login_otp", + }, + }) +} + +func (s *NotificationService) sendFirstByType(ctx context.Context, notificationType string, override map[string]any, message notify.Message) error { + items, err := s.notifications.List(ctx) + if err != nil { + return err + } + for _, item := range items { + if !item.Enabled || item.Type != notificationType { + continue + } + configMap := map[string]any{} + if err := s.cipher.DecryptJSON(item.ConfigCiphertext, &configMap); err != nil { + return fmt.Errorf("decrypt notification %d config: %w", item.ID, err) + } + for key, value := range override { + configMap[key] = value + } + return s.registry.Send(ctx, item.Type, configMap, message) + } + return fmt.Errorf("no enabled %s notification configured", notificationType) +} + // collectSubscribers 按事件类型收集启用的订阅者。 // 列出启用通知后按事件类型再过滤(避免引入新 repository 方法)。 func (s *NotificationService) collectSubscribers(ctx context.Context, eventType string, fallbackSuccess bool) ([]model.Notification, error) { diff --git a/server/internal/service/user_service.go b/server/internal/service/user_service.go index 6107950..1df1f3e 100644 --- a/server/internal/service/user_service.go +++ b/server/internal/service/user_service.go @@ -22,13 +22,22 @@ func NewUserService(users repository.UserRepository) *UserService { // UserSummary 用户列表项(不含密码哈希)。 type UserSummary struct { - ID uint `json:"id"` - Username string `json:"username"` - DisplayName string `json:"displayName"` - Email string `json:"email"` - Role string `json:"role"` - Disabled bool `json:"disabled"` - CreatedAt string `json:"createdAt"` + ID uint `json:"id"` + Username string `json:"username"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + Phone string `json:"phone"` + Role string `json:"role"` + Disabled bool `json:"disabled"` + MFAEnabled bool `json:"mfaEnabled"` + TwoFactorEnabled bool `json:"twoFactorEnabled"` + TwoFactorRecoveryCodesRemaining int `json:"twoFactorRecoveryCodesRemaining"` + WebAuthnEnabled bool `json:"webAuthnEnabled"` + WebAuthnCredentialCount int `json:"webAuthnCredentialCount"` + TrustedDeviceCount int `json:"trustedDeviceCount"` + EmailOTPEnabled bool `json:"emailOtpEnabled"` + SMSOTPEnabled bool `json:"smsOtpEnabled"` + CreatedAt string `json:"createdAt"` } // UserUpsertInput 创建/更新用户的输入。 @@ -37,6 +46,7 @@ type UserUpsertInput struct { Password string `json:"password" binding:"omitempty,min=8,max=128"` DisplayName string `json:"displayName" binding:"required,min=1,max=128"` Email string `json:"email" binding:"omitempty,max=255"` + Phone string `json:"phone" binding:"omitempty,max=64"` Role string `json:"role" binding:"required,oneof=admin operator viewer"` Disabled bool `json:"disabled"` } @@ -76,6 +86,7 @@ func (s *UserService) Create(ctx context.Context, input UserUpsertInput) (*UserS PasswordHash: hash, DisplayName: strings.TrimSpace(input.DisplayName), Email: strings.TrimSpace(input.Email), + Phone: strings.TrimSpace(input.Phone), Role: input.Role, Disabled: input.Disabled, } @@ -107,18 +118,43 @@ func (s *UserService) Update(ctx context.Context, id uint, input UserUpsertInput return nil, apperror.Conflict("USER_USERNAME_EXISTS", "用户名已存在", nil) } } + passwordChanged := strings.TrimSpace(input.Password) != "" + disabledChanged := input.Disabled && !existing.Disabled + emailChanged := strings.TrimSpace(input.Email) != strings.TrimSpace(existing.Email) + phoneChanged := strings.TrimSpace(input.Phone) != strings.TrimSpace(existing.Phone) existing.Username = strings.TrimSpace(input.Username) existing.DisplayName = strings.TrimSpace(input.DisplayName) existing.Email = strings.TrimSpace(input.Email) + existing.Phone = strings.TrimSpace(input.Phone) existing.Role = input.Role existing.Disabled = input.Disabled - if strings.TrimSpace(input.Password) != "" { + if passwordChanged { hash, err := security.HashPassword(input.Password) if err != nil { return nil, apperror.Internal("USER_HASH_FAILED", "无法处理密码", err) } existing.PasswordHash = hash + existing.TrustedDevices = "" + existing.OutOfBandOTPCiphertext = "" + existing.WebAuthnChallengeCiphertext = "" } + if strings.TrimSpace(existing.Email) == "" && existing.EmailOTPEnabled { + existing.EmailOTPEnabled = false + existing.OutOfBandOTPCiphertext = "" + } + if strings.TrimSpace(existing.Phone) == "" && existing.SMSOTPEnabled { + existing.SMSOTPEnabled = false + existing.OutOfBandOTPCiphertext = "" + } + if emailChanged || phoneChanged { + existing.OutOfBandOTPCiphertext = "" + } + if disabledChanged { + existing.TrustedDevices = "" + existing.OutOfBandOTPCiphertext = "" + existing.WebAuthnChallengeCiphertext = "" + } + clearTrustedDevicesIfMFAOff(existing) if err := s.users.Update(ctx, existing); err != nil { return nil, apperror.Internal("USER_UPDATE_FAILED", "无法更新用户", err) } @@ -147,14 +183,47 @@ func (s *UserService) Delete(ctx context.Context, id uint) error { return s.users.Delete(ctx, id) } +func (s *UserService) ResetTwoFactor(ctx context.Context, id uint) (*UserSummary, error) { + existing, err := s.users.FindByID(ctx, id) + if err != nil { + return nil, apperror.Internal("USER_GET_FAILED", "无法获取用户", err) + } + if existing == nil { + return nil, apperror.New(404, "USER_NOT_FOUND", "用户不存在", nil) + } + existing.TwoFactorEnabled = false + existing.TwoFactorSecretCiphertext = "" + existing.TwoFactorRecoveryCodeHashes = "" + existing.WebAuthnCredentials = "" + existing.WebAuthnChallengeCiphertext = "" + existing.TrustedDevices = "" + existing.EmailOTPEnabled = false + existing.SMSOTPEnabled = false + existing.OutOfBandOTPCiphertext = "" + if err := s.users.Update(ctx, existing); err != nil { + return nil, apperror.Internal("USER_2FA_RESET_FAILED", "无法重置 MFA", err) + } + summary := toUserSummary(existing) + return &summary, nil +} + func toUserSummary(u *model.User) UserSummary { return UserSummary{ - ID: u.ID, - Username: u.Username, - DisplayName: u.DisplayName, - Email: u.Email, - Role: u.Role, - Disabled: u.Disabled, - CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + ID: u.ID, + Username: u.Username, + DisplayName: u.DisplayName, + Email: u.Email, + Phone: u.Phone, + Role: u.Role, + Disabled: u.Disabled, + MFAEnabled: userMFAEnabled(u), + TwoFactorEnabled: u.TwoFactorEnabled, + TwoFactorRecoveryCodesRemaining: recoveryCodeRemainingCount(u), + WebAuthnEnabled: webAuthnCredentialCount(u) > 0, + WebAuthnCredentialCount: webAuthnCredentialCount(u), + TrustedDeviceCount: trustedDeviceCount(u), + EmailOTPEnabled: u.EmailOTPEnabled, + SMSOTPEnabled: u.SMSOTPEnabled, + CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), } } diff --git a/server/internal/service/user_service_test.go b/server/internal/service/user_service_test.go new file mode 100644 index 0000000..9d9b28a --- /dev/null +++ b/server/internal/service/user_service_test.go @@ -0,0 +1,124 @@ +package service + +import ( + "context" + "testing" + + "backupx/server/internal/model" + "backupx/server/internal/security" +) + +func TestUserServiceUpdatePasswordClearsTrustedDeviceState(t *testing.T) { + hash, err := security.HashPassword("old-password") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + repo := &fakeUserRepository{users: []*model.User{{ + ID: 1, + Username: "admin", + PasswordHash: hash, + DisplayName: "Admin", + Email: "admin@example.com", + Role: model.UserRoleAdmin, + TwoFactorEnabled: true, + TrustedDevices: `[{"id":"device"}]`, + OutOfBandOTPCiphertext: "pending", + WebAuthnChallengeCiphertext: "challenge", + }}} + svc := NewUserService(repo) + + if _, err := svc.Update(context.Background(), 1, UserUpsertInput{ + Username: "admin", + Password: "new-password", + DisplayName: "Admin", + Email: "admin@example.com", + Role: model.UserRoleAdmin, + }); err != nil { + t.Fatalf("Update: %v", err) + } + + updated := repo.users[0] + if security.ComparePassword(updated.PasswordHash, "new-password") != nil { + t.Fatalf("expected password hash to be updated") + } + if updated.TrustedDevices != "" || updated.OutOfBandOTPCiphertext != "" || updated.WebAuthnChallengeCiphertext != "" { + t.Fatalf("expected password update to clear trusted device state, got trusted=%q otp=%q challenge=%q", updated.TrustedDevices, updated.OutOfBandOTPCiphertext, updated.WebAuthnChallengeCiphertext) + } +} + +func TestUserServiceUpdateContactClearsUnavailableOTP(t *testing.T) { + hash, err := security.HashPassword("password-123") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + repo := &fakeUserRepository{users: []*model.User{{ + ID: 1, + Username: "admin", + PasswordHash: hash, + DisplayName: "Admin", + Email: "admin@example.com", + Phone: "+15550000000", + Role: model.UserRoleAdmin, + EmailOTPEnabled: true, + SMSOTPEnabled: true, + TrustedDevices: `[{"id":"device"}]`, + OutOfBandOTPCiphertext: "pending", + }}} + svc := NewUserService(repo) + + summary, err := svc.Update(context.Background(), 1, UserUpsertInput{ + Username: "admin", + DisplayName: "Admin", + Role: model.UserRoleAdmin, + }) + if err != nil { + t.Fatalf("Update: %v", err) + } + + updated := repo.users[0] + if updated.EmailOTPEnabled || updated.SMSOTPEnabled || summary.MFAEnabled { + t.Fatalf("expected unavailable OTP channels to be disabled") + } + if updated.TrustedDevices != "" || updated.OutOfBandOTPCiphertext != "" || updated.WebAuthnChallengeCiphertext != "" { + t.Fatalf("expected last MFA removal to clear temporary state") + } +} + +func TestUserServiceUpdateContactChangeClearsPendingOTP(t *testing.T) { + hash, err := security.HashPassword("password-123") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + repo := &fakeUserRepository{users: []*model.User{{ + ID: 1, + Username: "admin", + PasswordHash: hash, + DisplayName: "Admin", + Email: "old@example.com", + Role: model.UserRoleAdmin, + EmailOTPEnabled: true, + OutOfBandOTPCiphertext: "pending", + }}} + svc := NewUserService(repo) + + summary, err := svc.Update(context.Background(), 1, UserUpsertInput{ + Username: "admin", + DisplayName: "Admin", + Email: "new@example.com", + Role: model.UserRoleAdmin, + }) + if err != nil { + t.Fatalf("Update: %v", err) + } + + updated := repo.users[0] + if updated.Email != "new@example.com" || summary.Email != "new@example.com" { + t.Fatalf("expected email to be updated") + } + if !updated.EmailOTPEnabled { + t.Fatalf("expected email OTP to remain enabled") + } + if updated.OutOfBandOTPCiphertext != "" { + t.Fatalf("expected contact change to clear pending OTP") + } +} diff --git a/web/src/components/notifications/field-config.test.ts b/web/src/components/notifications/field-config.test.ts index 78782de..13cebfb 100644 --- a/web/src/components/notifications/field-config.test.ts +++ b/web/src/components/notifications/field-config.test.ts @@ -5,15 +5,18 @@ describe('notification field config', () => { it('returns readable type labels', () => { expect(getNotificationTypeLabel('email')).toBe('Email') expect(getNotificationTypeLabel('telegram')).toBe('Telegram') + expect(getNotificationTypeLabel('sms')).toBe('SMS Webhook') }) it('returns required fields for each notification type', () => { const emailFields = getNotificationFieldConfigs('email') const webhookFields = getNotificationFieldConfigs('webhook') const telegramFields = getNotificationFieldConfigs('telegram') + const smsFields = getNotificationFieldConfigs('sms') expect(emailFields.some((field) => field.key === 'host' && field.required)).toBe(true) expect(webhookFields.some((field) => field.key === 'url' && field.required)).toBe(true) expect(telegramFields.some((field) => field.key === 'botToken' && field.required)).toBe(true) + expect(smsFields.some((field) => field.key === 'url' && field.required)).toBe(true) }) }) diff --git a/web/src/components/notifications/field-config.ts b/web/src/components/notifications/field-config.ts index 0b59f12..995587c 100644 --- a/web/src/components/notifications/field-config.ts +++ b/web/src/components/notifications/field-config.ts @@ -13,6 +13,10 @@ const FIELD_CONFIG_MAP: Record = { { key: 'url', label: 'Webhook URL', type: 'input', required: true, placeholder: 'https://hooks.example.com/backupx' }, { key: 'secret', label: '共享密钥', type: 'password', placeholder: '可选', sensitive: true }, ], + sms: [ + { key: 'url', label: 'SMS Webhook URL', type: 'input', required: true, placeholder: 'https://sms-gateway.example.com/send' }, + { key: 'secret', label: '共享密钥', type: 'password', placeholder: '可选', sensitive: true }, + ], telegram: [ { key: 'botToken', label: 'Bot Token', type: 'password', required: true, placeholder: '123456:ABC', sensitive: true }, { key: 'chatId', label: 'Chat ID', type: 'input', required: true, placeholder: '-100xxxxxxxxxx' }, @@ -23,6 +27,7 @@ export const notificationTypeOptions = [ { label: 'Email', value: 'email' }, { label: 'Webhook', value: 'webhook' }, { label: 'Telegram', value: 'telegram' }, + { label: 'SMS Webhook', value: 'sms' }, ] as const export function getNotificationTypeLabel(type: NotificationType) { @@ -33,6 +38,8 @@ export function getNotificationTypeLabel(type: NotificationType) { return 'Webhook' case 'telegram': return 'Telegram' + case 'sms': + return 'SMS Webhook' default: return type } diff --git a/web/src/layouts/AppLayout.tsx b/web/src/layouts/AppLayout.tsx index cc6b045..d19c5c7 100644 --- a/web/src/layouts/AppLayout.tsx +++ b/web/src/layouts/AppLayout.tsx @@ -1,4 +1,4 @@ -import { Avatar, Button, Dropdown, Layout, Menu, Message, Modal, Form, Input, Space, Typography } from '@arco-design/web-react' +import { Alert, Avatar, Button, Divider, Dropdown, Layout, Menu, Message, Modal, Form, Input, Space, Tag, Typography } from '@arco-design/web-react' import { IconDashboard, IconStorage, @@ -23,7 +23,27 @@ import { } from '@arco-design/web-react/icon' import { useState } from 'react' import { Outlet, useLocation, useNavigate } from 'react-router-dom' -import { changePassword, type ChangePasswordPayload } from '../services/auth' +import { + changePassword, + beginWebAuthnRegistration, + clearTrustedDeviceToken, + configureOtp, + deleteWebAuthnCredential, + disableTwoFactor, + enableTwoFactor, + finishWebAuthnRegistration, + listTrustedDevices, + listWebAuthnCredentials, + prepareTwoFactor, + regenerateRecoveryCodes, + revokeTrustedDevice, + type ChangePasswordPayload, + type TrustedDevice, + type UserInfo, + type WebAuthnCredential, + type TwoFactorSetupResult, +} from '../services/auth' +import { createWebAuthnCredential } from '../utils/webauthn' import { useAuthStore } from '../stores/auth' import { resolveErrorMessage } from '../utils/error' import { isAdmin, roleLabel } from '../utils/permissions' @@ -105,11 +125,27 @@ export function AppLayout() { const [collapsed, setCollapsed] = useState(false) const [pwdVisible, setPwdVisible] = useState(false) const [pwdLoading, setPwdLoading] = useState(false) + const [twoFactorVisible, setTwoFactorVisible] = useState(false) + const [twoFactorLoading, setTwoFactorLoading] = useState(false) + const [twoFactorSetup, setTwoFactorSetup] = useState(null) + const [recoveryCodes, setRecoveryCodes] = useState([]) + const [webAuthnCredentials, setWebAuthnCredentials] = useState([]) + const [trustedDevices, setTrustedDevices] = useState([]) + const [securityDetailsLoading, setSecurityDetailsLoading] = useState(false) const [pwdForm] = Form.useForm() + const [twoFactorForm] = Form.useForm<{ currentPassword: string; code: string; email: string; phone: string }>() const location = useLocation() const navigate = useNavigate() const user = useAuthStore((state) => state.user) const logout = useAuthStore((state) => state.logout) + const setUser = useAuthStore((state) => state.setUser) + + function applySecurityUserUpdate(updated: UserInfo) { + setUser(updated) + if (!updated.mfaEnabled) { + clearTrustedDeviceToken(updated.username) + } + } async function handleChangePassword() { try { @@ -120,6 +156,7 @@ export function AppLayout() { } setPwdLoading(true) await changePassword({ oldPassword: values.oldPassword, newPassword: values.newPassword }) + clearTrustedDeviceToken(user?.username) Message.success('密码修改成功') setPwdVisible(false) pwdForm.resetFields() @@ -132,15 +169,227 @@ export function AppLayout() { } } + function closeTwoFactorModal() { + setTwoFactorVisible(false) + setTwoFactorSetup(null) + setRecoveryCodes([]) + setWebAuthnCredentials([]) + setTrustedDevices([]) + twoFactorForm.resetFields() + } + + async function openSecurityModal() { + setTwoFactorVisible(true) + twoFactorForm.setFieldValue('email', user?.email ?? '') + twoFactorForm.setFieldValue('phone', user?.phone ?? '') + await loadSecurityDetails() + } + + async function loadSecurityDetails() { + setSecurityDetailsLoading(true) + try { + const [credentials, devices] = await Promise.all([listWebAuthnCredentials(), listTrustedDevices()]) + setWebAuthnCredentials(credentials) + setTrustedDevices(devices) + } catch (err) { + Message.error(resolveErrorMessage(err, '加载安全配置失败')) + } finally { + setSecurityDetailsLoading(false) + } + } + + async function copyRecoveryCodes() { + if (recoveryCodes.length === 0) return + try { + await navigator.clipboard.writeText(recoveryCodes.join('\n')) + Message.success('已复制到剪贴板') + } catch { + Message.info('请手动选择文本复制') + } + } + + async function handleTwoFactorSetupAction() { + try { + const values = await twoFactorForm.validate() + setTwoFactorLoading(true) + if (!twoFactorSetup) { + const setup = await prepareTwoFactor({ currentPassword: values.currentPassword }) + setTwoFactorSetup(setup) + Message.success('TOTP 密钥已生成') + return + } + const result = await enableTwoFactor({ code: values.code }) + setUser(result.user) + setRecoveryCodes(result.recoveryCodes) + Message.success('TOTP 已启用') + } catch (err) { + if (err) { + Message.error(resolveErrorMessage(err, 'TOTP 操作失败')) + } + } finally { + setTwoFactorLoading(false) + } + } + + async function handleRegenerateRecoveryCodes() { + try { + const values = await twoFactorForm.validate() + setTwoFactorLoading(true) + const result = await regenerateRecoveryCodes({ + currentPassword: values.currentPassword, + code: values.code, + }) + setUser(result.user) + setRecoveryCodes(result.recoveryCodes) + twoFactorForm.resetFields() + Message.success('恢复码已重新生成') + } catch (err) { + if (err) { + Message.error(resolveErrorMessage(err, '恢复码生成失败')) + } + } finally { + setTwoFactorLoading(false) + } + } + + async function handleDisableTwoFactor() { + try { + const values = await twoFactorForm.validate() + setTwoFactorLoading(true) + const updated = await disableTwoFactor({ + currentPassword: values.currentPassword, + code: values.code, + }) + applySecurityUserUpdate(updated) + Message.success('TOTP 已关闭') + closeTwoFactorModal() + } catch (err) { + if (err) { + Message.error(resolveErrorMessage(err, '关闭 TOTP 失败')) + } + } finally { + setTwoFactorLoading(false) + } + } + + function readCurrentPassword() { + const currentPassword = String(twoFactorForm.getFieldValue('currentPassword') ?? '') + if (currentPassword.trim().length < 8) { + Message.error('请输入当前密码') + return '' + } + return currentPassword + } + + async function handleRegisterWebAuthn() { + const currentPassword = readCurrentPassword() + if (!currentPassword) return + try { + setTwoFactorLoading(true) + const options = await beginWebAuthnRegistration({ currentPassword }) + const credential = await createWebAuthnCredential(options) + const updated = await finishWebAuthnRegistration({ name: navigator.userAgent.slice(0, 120), credential }) + applySecurityUserUpdate(updated) + await loadSecurityDetails() + Message.success('通行密钥已注册') + } catch (err) { + Message.error(resolveErrorMessage(err, '通行密钥注册失败')) + } finally { + setTwoFactorLoading(false) + } + } + + async function handleDeleteWebAuthnCredential(id: string) { + const currentPassword = readCurrentPassword() + if (!currentPassword) return + try { + setTwoFactorLoading(true) + const updated = await deleteWebAuthnCredential(id, { currentPassword }) + applySecurityUserUpdate(updated) + await loadSecurityDetails() + Message.success('通行密钥已删除') + } catch (err) { + Message.error(resolveErrorMessage(err, '删除通行密钥失败')) + } finally { + setTwoFactorLoading(false) + } + } + + async function handleConfigureOtp(channel: 'email' | 'sms', enabled: boolean) { + const currentPassword = readCurrentPassword() + if (!currentPassword) return + const email = String(twoFactorForm.getFieldValue('email') ?? '') + const phone = String(twoFactorForm.getFieldValue('phone') ?? '') + try { + setTwoFactorLoading(true) + const updated = await configureOtp({ currentPassword, channel, enabled, email, phone }) + applySecurityUserUpdate(updated) + twoFactorForm.setFieldValue('email', updated.email ?? '') + twoFactorForm.setFieldValue('phone', updated.phone ?? '') + Message.success(enabled ? 'OTP 已启用' : 'OTP 已关闭') + } catch (err) { + Message.error(resolveErrorMessage(err, 'OTP 配置失败')) + } finally { + setTwoFactorLoading(false) + } + } + + async function handleRevokeTrustedDevice(id: string) { + const currentPassword = readCurrentPassword() + if (!currentPassword) return + try { + setTwoFactorLoading(true) + await revokeTrustedDevice(id, { currentPassword }) + clearTrustedDeviceToken(user?.username) + await loadSecurityDetails() + Message.success('可信设备已移除') + } catch (err) { + Message.error(resolveErrorMessage(err, '移除可信设备失败')) + } finally { + setTwoFactorLoading(false) + } + } + + function renderTwoFactorFooter() { + if (recoveryCodes.length > 0) { + return ( + + + + + ) + } + if (user?.twoFactorEnabled) { + return ( + + + + + + ) + } + return ( + + + + + ) + } + const userDroplist = ( { if (key === 'password') { setPwdVisible(true) + } else if (key === 'two-factor') { + void openSecurityModal() } else if (key === 'logout') { logout() } }}> 修改密码 + 多因素认证 退出登录 ) @@ -217,6 +466,138 @@ export function AppLayout() { + + + {recoveryCodes.length > 0 ? ( + + + + + ) : ( +
+ {user?.twoFactorEnabled ? ( + <> + + + + + + + + + ) : ( + <> + {!twoFactorSetup ? ( + <> + + + + + + ) : ( + <> + +
+ TOTP 二维码 + + 手动密钥 + + +
+ + + + + )} + + )} + + + + 通行密钥 + 0 ? 'green' : 'gray'} bordered> + {webAuthnCredentials.length > 0 ? `${webAuthnCredentials.length} 个` : '未注册'} + + + + 支持浏览器 Passkey、平台验证器或安全密钥,用于登录时替代验证码。 + + + + {securityDetailsLoading ? 正在加载通行密钥... : null} + {webAuthnCredentials.map((item) => ( +
+ + {item.name} + {item.lastUsedAt ? `最近使用 ${item.lastUsedAt}` : `创建于 ${item.createdAt}`} + + +
+ ))} +
+
+ + + 邮件 / 短信 OTP + + + 邮件 OTP {user?.emailOtpEnabled ? '已启用' : '未启用'} + 短信 OTP {user?.smsOtpEnabled ? '已启用' : '未启用'} + + + + + + + + + + + + + + + + 可信设备 + 0 ? 'green' : 'gray'} bordered>{trustedDevices.length} 个 + + + 登录时勾选“信任此设备”后,30 天内该设备可在密码校验通过后跳过多因素验证。 + + + {trustedDevices.map((item) => ( +
+ + {item.name} + 最近使用 {item.lastUsedAt || '-'},到期 {item.expiresAt} + + +
+ ))} + {!securityDetailsLoading && trustedDevices.length === 0 ? 暂无可信设备 : null} +
+
+ + )} +
) } diff --git a/web/src/pages/admin/UsersPage.tsx b/web/src/pages/admin/UsersPage.tsx index 62bafc6..61b18c8 100644 --- a/web/src/pages/admin/UsersPage.tsx +++ b/web/src/pages/admin/UsersPage.tsx @@ -1,6 +1,7 @@ import { Alert, Button, Card, Empty, Form, Input, Message, Modal, Select, Space, Switch, Table, Tag, Typography } from '@arco-design/web-react' import { useCallback, useEffect, useState } from 'react' -import { createUser, deleteUser, listUsers, updateUser, type UserRole, type UserSummary, type UserUpsertPayload } from '../../services/users' +import { createUser, deleteUser, listUsers, resetUserTwoFactor, updateUser, type UserRole, type UserSummary, type UserUpsertPayload } from '../../services/users' +import { clearTrustedDeviceToken } from '../../services/auth' import { useAuthStore } from '../../stores/auth' import { resolveErrorMessage } from '../../utils/error' import { isAdmin, roleLabel } from '../../utils/permissions' @@ -12,12 +13,13 @@ const roleOptions = [ ] function createEmpty(): UserUpsertPayload { - return { username: '', password: '', displayName: '', email: '', role: 'operator', disabled: false } + return { username: '', password: '', displayName: '', email: '', phone: '', role: 'operator', disabled: false } } // UsersPage admin 用户管理。非 admin 角色进入路由会被路由守卫拦截。 export function UsersPage() { const user = useAuthStore((s) => s.user) + const setUser = useAuthStore((s) => s.setUser) const [items, setItems] = useState([]) const [loading, setLoading] = useState(true) const [error, setError] = useState('') @@ -55,6 +57,7 @@ export function UsersPage() { password: '', displayName: item.displayName, email: item.email, + phone: item.phone, role: item.role, disabled: item.disabled, }) @@ -73,7 +76,13 @@ export function UsersPage() { setSubmitting(true) try { if (editing) { - await updateUser(editing.id, draft) + const updated = await updateUser(editing.id, draft) + if (updated.id === user?.id) { + if (draft.password?.trim()) { + clearTrustedDeviceToken(updated.username) + } + setUser(updated) + } Message.success('用户已更新') } else { await createUser(draft) @@ -99,6 +108,21 @@ export function UsersPage() { } } + async function handleResetTwoFactor(item: UserSummary) { + if (!window.confirm(`确定重置用户「${item.username}」的全部 MFA 配置吗?该用户之后可仅凭密码登录。`)) return + try { + const updated = await resetUserTwoFactor(item.id) + if (updated.id === user?.id) { + clearTrustedDeviceToken(updated.username) + setUser(updated) + } + Message.success('MFA 已重置') + await load() + } catch (e) { + Message.error(resolveErrorMessage(e, '重置 MFA 失败')) + } + } + if (!isAdmin(user)) { return } @@ -132,12 +156,27 @@ export function UsersPage() { ) }, { title: '角色', dataIndex: 'role', render: (value: string) => {roleLabel(value)} }, - { title: '邮箱', dataIndex: 'email', render: (v: string) => v || '-' }, + { title: '邮箱 / 手机', dataIndex: 'email', render: (_: string, row: UserSummary) => ( + + {row.email || '-'} + {row.phone || '-'} + + ) }, { title: '状态', dataIndex: 'disabled', render: (disabled: boolean) => disabled ? 已停用 : 启用 }, + { title: 'MFA', dataIndex: 'mfaEnabled', render: (_: boolean, row: UserSummary) => row.mfaEnabled ? ( + + {row.twoFactorEnabled ? TOTP : null} + {row.webAuthnEnabled ? Passkey {row.webAuthnCredentialCount} : null} + {row.emailOtpEnabled ? 邮件 : null} + {row.smsOtpEnabled ? 短信 : null} + {row.twoFactorEnabled ? 恢复码 {row.twoFactorRecoveryCodesRemaining} : null} + + ) : 未启用 }, { title: '创建时间', dataIndex: 'createdAt' }, - { title: '操作', width: 180, render: (_: unknown, row: UserSummary) => ( + { title: '操作', width: 260, render: (_: unknown, row: UserSummary) => ( + {row.mfaEnabled && } ) }, @@ -163,6 +202,9 @@ export function UsersPage() { setDraft({ ...draft, email: v })} /> + + setDraft({ ...draft, phone: v })} /> + setDraft({ ...draft, password: v })} /> diff --git a/web/src/pages/audit/AuditLogsPage.tsx b/web/src/pages/audit/AuditLogsPage.tsx index 182c648..b82c8b4 100644 --- a/web/src/pages/audit/AuditLogsPage.tsx +++ b/web/src/pages/audit/AuditLogsPage.tsx @@ -26,6 +26,23 @@ const categoryLabels: Record = { const actionLabels: Record = { login_success: '登录成功', login_failed: '登录失败', + two_factor_required: '需要 MFA', + two_factor_setup: '生成 TOTP', + two_factor_enable: '启用 TOTP', + two_factor_disable: '关闭 TOTP', + two_factor_recovery_code_used: '使用恢复码', + two_factor_recovery_codes_regenerate: '重建恢复码', + webauthn_register: '注册通行密钥', + webauthn_used: '使用通行密钥', + webauthn_delete: '删除通行密钥', + trusted_device_create: '信任设备', + trusted_device_used: '使用可信设备', + trusted_device_revoke: '移除可信设备', + otp_enable: '启用 OTP', + otp_disable: '关闭 OTP', + otp_send: '发送 OTP', + otp_used: '使用 OTP', + reset_two_factor: '重置 MFA', setup: '系统初始化', change_password: '修改密码', create: '创建', diff --git a/web/src/pages/login/LoginPage.tsx b/web/src/pages/login/LoginPage.tsx index 22ad35b..038a8dd 100644 --- a/web/src/pages/login/LoginPage.tsx +++ b/web/src/pages/login/LoginPage.tsx @@ -1,10 +1,11 @@ -import { Alert, Button, Card, Form, Input, Space, Typography, Message } from '@arco-design/web-react' -import { IconCloud, IconLock, IconUser } from '@arco-design/web-react/icon' +import { Button, Checkbox, Form, Input, Space, Typography, Message } from '@arco-design/web-react' +import { IconCloud, IconLock, IconSafe, IconUser } from '@arco-design/web-react/icon' import { useEffect, useState } from 'react' import { useNavigate } from 'react-router-dom' import axios from 'axios' -import { fetchSetupStatus } from '../../services/auth' +import { beginWebAuthnLogin, fetchSetupStatus, sendLoginOtp } from '../../services/auth' import { useAuthStore } from '../../stores/auth' +import { getWebAuthnAssertion } from '../../utils/webauthn' interface SetupFormValues { username: string @@ -15,12 +16,17 @@ interface SetupFormValues { interface LoginFormValues { username: string password: string + twoFactorCode?: string + rememberDevice?: boolean } function resolveErrorMessage(error: unknown) { if (axios.isAxiosError(error)) { return error.response?.data?.message ?? '请求失败,请稍后重试' } + if (error instanceof Error) { + return error.message + } return '请求失败,请稍后重试' } @@ -29,8 +35,20 @@ export function LoginPage() { const authStatus = useAuthStore((state) => state.status) const doLogin = useAuthStore((state) => state.login) const doSetup = useAuthStore((state) => state.setup) + const [loginForm] = Form.useForm() const [initialized, setInitialized] = useState(null) const [loading, setLoading] = useState(false) + const [mfaActionLoading, setMfaActionLoading] = useState('') + const [twoFactorRequired, setTwoFactorRequired] = useState(false) + + function resetTwoFactorPrompt() { + if (!twoFactorRequired) { + return + } + setTwoFactorRequired(false) + loginForm.setFieldValue('twoFactorCode', undefined) + loginForm.setFieldValue('rememberDevice', false) + } useEffect(() => { if (authStatus === 'authenticated') { @@ -73,13 +91,77 @@ export function LoginPage() { const handleLogin = async (values: LoginFormValues) => { setLoading(true) try { - await doLogin(values) + await doLogin({ + ...values, + trustedDeviceName: values.rememberDevice ? navigator.userAgent.slice(0, 120) : undefined, + }) + setTwoFactorRequired(false) + Message.success('登录成功') + navigate('/dashboard', { replace: true }) + } catch (error) { + if (axios.isAxiosError(error)) { + const code = error.response?.data?.code + if (code === 'AUTH_2FA_REQUIRED' || code === 'AUTH_2FA_INVALID') { + setTwoFactorRequired(true) + Message.error(resolveErrorMessage(error)) + return + } + } + Message.error(resolveErrorMessage(error)) + } finally { + setLoading(false) + } + } + + function readLoginCredentials(): (LoginFormValues & { username: string; password: string }) | null { + const values = loginForm.getFieldsValue() + if (!values.username?.trim() || !values.password?.trim()) { + Message.error('请先输入用户名和密码') + return null + } + return { + ...values, + username: values.username, + password: values.password, + } + } + + async function handleSendOTP(channel: 'email' | 'sms') { + const values = readLoginCredentials() + if (!values) return + setMfaActionLoading(channel) + try { + await sendLoginOtp({ username: values.username, password: values.password, channel }) + Message.success(channel === 'email' ? '邮件验证码已发送' : '短信验证码已发送') + } catch (error) { + Message.error(resolveErrorMessage(error)) + } finally { + setMfaActionLoading('') + } + } + + async function handleWebAuthnLogin() { + const values = readLoginCredentials() + if (!values) return + setMfaActionLoading('webauthn') + try { + const options = await beginWebAuthnLogin({ username: values.username, password: values.password }) + const assertion = await getWebAuthnAssertion(options) + await doLogin({ + username: values.username, + password: values.password, + webAuthnAssertion: assertion, + trustedDeviceToken: '', + rememberDevice: values.rememberDevice, + trustedDeviceName: navigator.userAgent.slice(0, 120), + }) + setTwoFactorRequired(false) Message.success('登录成功') navigate('/dashboard', { replace: true }) } catch (error) { Message.error(resolveErrorMessage(error)) } finally { - setLoading(false) + setMfaActionLoading('') } } @@ -181,15 +263,30 @@ export function LoginPage() { ) : ( - layout="vertical" onSubmit={handleLogin}> + form={loginForm} layout="vertical" onSubmit={handleLogin}> - } size="large" /> + } size="large" onChange={resetTwoFactorPrompt} /> - } size="large" /> + } size="large" onChange={resetTwoFactorPrompt} /> + {twoFactorRequired && ( + <> + + } size="large" maxLength={32} /> + + + + + + + + 信任此设备 30 天 + + + )} )} diff --git a/web/src/services/auth.ts b/web/src/services/auth.ts index dcfbd61..bdf0b7c 100644 --- a/web/src/services/auth.ts +++ b/web/src/services/auth.ts @@ -9,18 +9,64 @@ export interface SetupPayload { export interface LoginPayload { username: string password: string + twoFactorCode?: string + webAuthnAssertion?: WebAuthnAssertion + trustedDeviceToken?: string + rememberDevice?: boolean + trustedDeviceName?: string } export interface UserInfo { id: number username: string displayName: string + email?: string + phone?: string role: string + mfaEnabled?: boolean + twoFactorEnabled?: boolean + twoFactorRecoveryCodesRemaining?: number + webAuthnEnabled?: boolean + webAuthnCredentialCount?: number + trustedDeviceCount?: number + emailOtpEnabled?: boolean + smsOtpEnabled?: boolean } export interface AuthResult { token: string user: UserInfo + trustedDeviceToken?: string + trustedDevice?: TrustedDevice +} + +const TRUSTED_DEVICE_TOKEN_KEY = 'backupx-trusted-device-token' +const TRUSTED_DEVICE_TOKEN_PREFIX = 'backupx-trusted-device-token:' + +function trustedDeviceTokenKey(username: string) { + return `${TRUSTED_DEVICE_TOKEN_PREFIX}${username.trim().toLowerCase()}` +} + +export function getTrustedDeviceToken(username?: string) { + if (username?.trim()) { + return localStorage.getItem(trustedDeviceTokenKey(username)) ?? localStorage.getItem(TRUSTED_DEVICE_TOKEN_KEY) ?? '' + } + return localStorage.getItem(TRUSTED_DEVICE_TOKEN_KEY) ?? '' +} + +export function clearTrustedDeviceToken(username?: string) { + if (username?.trim()) { + localStorage.removeItem(trustedDeviceTokenKey(username)) + localStorage.removeItem(TRUSTED_DEVICE_TOKEN_KEY) + return + } + localStorage.removeItem(TRUSTED_DEVICE_TOKEN_KEY) + for (let index = localStorage.length - 1; index >= 0; index -= 1) { + const key = localStorage.key(index) + if (key?.startsWith(TRUSTED_DEVICE_TOKEN_PREFIX)) { + localStorage.removeItem(key) + } + } } export async function fetchSetupStatus() { @@ -34,8 +80,16 @@ export async function setup(payload: SetupPayload) { } export async function login(payload: LoginPayload) { - const response = await http.post<{ code: string; message: string; data: AuthResult }>('/auth/login', payload) - return response.data.data + const response = await http.post<{ code: string; message: string; data: AuthResult }>('/auth/login', { + ...payload, + trustedDeviceToken: payload.trustedDeviceToken ?? getTrustedDeviceToken(payload.username), + }) + const result = response.data.data + if (result.trustedDeviceToken) { + localStorage.setItem(trustedDeviceTokenKey(payload.username), result.trustedDeviceToken) + localStorage.removeItem(TRUSTED_DEVICE_TOKEN_KEY) + } + return result } export async function fetchProfile() { @@ -53,6 +107,177 @@ export async function changePassword(payload: ChangePasswordPayload) { return response.data.data } +export interface TwoFactorSetupPayload { + currentPassword: string +} + +export interface TwoFactorSetupResult { + secret: string + otpAuthUrl: string + qrCodeDataUrl: string + twoFactorEnabled: boolean + twoFactorConfirmed: boolean +} + +export interface TwoFactorCodesResult { + user: UserInfo + recoveryCodes: string[] +} + +export interface EnableTwoFactorPayload { + code: string +} + +export interface DisableTwoFactorPayload { + currentPassword: string + code: string +} + +export type RegenerateRecoveryCodesPayload = DisableTwoFactorPayload + +export type OTPChannel = 'email' | 'sms' + +export interface OTPConfigPayload { + currentPassword: string + channel: OTPChannel + enabled: boolean + email?: string + phone?: string +} + +export interface SendLoginOTPPayload { + username: string + password: string + channel: OTPChannel +} + +export interface WebAuthnCredentialDescriptor { + type: 'public-key' + id: string +} + +export interface WebAuthnRegistrationOptions { + challenge: string + rp: { name: string; id: string } + user: { id: string; name: string; displayName: string } + pubKeyCredParams: Array<{ type: 'public-key'; alg: number }> + timeout: number + attestation: 'none' + authenticatorSelection: { userVerification: UserVerificationRequirement } + excludeCredentials: WebAuthnCredentialDescriptor[] +} + +export interface WebAuthnLoginOptions { + challenge: string + rpId: string + timeout: number + userVerification: UserVerificationRequirement + allowCredentials: WebAuthnCredentialDescriptor[] +} + +export interface WebAuthnAttestation { + id: string + rawId: string + type: 'public-key' + response: { + clientDataJSON: string + attestationObject: string + } +} + +export interface WebAuthnAssertion { + id: string + rawId: string + type: 'public-key' + response: { + clientDataJSON: string + authenticatorData: string + signature: string + userHandle?: string + } +} + +export interface WebAuthnCredential { + id: string + name: string + createdAt: string + lastUsedAt?: string +} + +export interface TrustedDevice { + id: string + name: string + createdAt: string + lastUsedAt: string + expiresAt: string + lastIp: string +} + +export async function prepareTwoFactor(payload: TwoFactorSetupPayload) { + const response = await http.post<{ code: string; message: string; data: TwoFactorSetupResult }>('/auth/2fa/setup', payload) + return response.data.data +} + +export async function enableTwoFactor(payload: EnableTwoFactorPayload) { + const response = await http.post<{ code: string; message: string; data: TwoFactorCodesResult }>('/auth/2fa/enable', payload) + return response.data.data +} + +export async function regenerateRecoveryCodes(payload: RegenerateRecoveryCodesPayload) { + const response = await http.post<{ code: string; message: string; data: TwoFactorCodesResult }>('/auth/2fa/recovery-codes', payload) + return response.data.data +} + +export async function disableTwoFactor(payload: DisableTwoFactorPayload) { + const response = await http.delete<{ code: string; message: string; data: UserInfo }>('/auth/2fa', { data: payload }) + return response.data.data +} + +export async function configureOtp(payload: OTPConfigPayload) { + const response = await http.put<{ code: string; message: string; data: UserInfo }>('/auth/otp/config', payload) + return response.data.data +} + +export async function sendLoginOtp(payload: SendLoginOTPPayload) { + const response = await http.post<{ code: string; message: string; data: { sent: boolean } }>('/auth/otp/send', payload) + return response.data.data +} + +export async function beginWebAuthnRegistration(payload: { currentPassword: string }) { + const response = await http.post<{ code: string; message: string; data: WebAuthnRegistrationOptions }>('/auth/webauthn/register/options', payload) + return response.data.data +} + +export async function finishWebAuthnRegistration(payload: { name?: string; credential: WebAuthnAttestation }) { + const response = await http.post<{ code: string; message: string; data: UserInfo }>('/auth/webauthn/register/finish', payload) + return response.data.data +} + +export async function beginWebAuthnLogin(payload: { username: string; password: string }) { + const response = await http.post<{ code: string; message: string; data: WebAuthnLoginOptions }>('/auth/webauthn/login/options', payload) + return response.data.data +} + +export async function listWebAuthnCredentials() { + const response = await http.get<{ code: string; message: string; data: WebAuthnCredential[] }>('/auth/webauthn/credentials') + return response.data.data +} + +export async function deleteWebAuthnCredential(id: string, payload: { currentPassword: string }) { + const response = await http.delete<{ code: string; message: string; data: UserInfo }>(`/auth/webauthn/credentials/${id}`, { data: payload }) + return response.data.data +} + +export async function listTrustedDevices() { + const response = await http.get<{ code: string; message: string; data: TrustedDevice[] }>('/auth/trusted-devices') + return response.data.data +} + +export async function revokeTrustedDevice(id: string, payload: { currentPassword: string }) { + const response = await http.delete<{ code: string; message: string; data: { deleted: boolean } }>(`/auth/trusted-devices/${id}`, { data: payload }) + return response.data.data +} + export async function logout() { const response = await http.post<{ code: string; message: string; data: { loggedOut: boolean } }>('/auth/logout') return response.data.data diff --git a/web/src/services/users.ts b/web/src/services/users.ts index 5519c86..7dab525 100644 --- a/web/src/services/users.ts +++ b/web/src/services/users.ts @@ -7,8 +7,17 @@ export interface UserSummary { username: string displayName: string email: string + phone: string role: UserRole disabled: boolean + mfaEnabled: boolean + twoFactorEnabled: boolean + twoFactorRecoveryCodesRemaining: number + webAuthnEnabled: boolean + webAuthnCredentialCount: number + trustedDeviceCount: number + emailOtpEnabled: boolean + smsOtpEnabled: boolean createdAt: string } @@ -17,6 +26,7 @@ export interface UserUpsertPayload { password?: string displayName: string email?: string + phone?: string role: UserRole disabled: boolean } @@ -40,3 +50,8 @@ export async function deleteUser(id: number) { const response = await http.delete>(`/users/${id}`) return unwrapApiEnvelope(response.data) } + +export async function resetUserTwoFactor(id: number) { + const response = await http.post>(`/users/${id}/2fa/reset`) + return unwrapApiEnvelope(response.data) +} diff --git a/web/src/stores/auth.ts b/web/src/stores/auth.ts index 8290eab..910176c 100644 --- a/web/src/stores/auth.ts +++ b/web/src/stores/auth.ts @@ -15,6 +15,7 @@ interface AuthState { setup: (payload: SetupPayload) => Promise logout: () => void applyAuth: (token: string, user: UserInfo) => void + setUser: (user: UserInfo) => void } function clearAuthState(set: (partial: Partial) => void) { @@ -65,6 +66,9 @@ export const useAuthStore = create()( setAccessToken(token) set({ token, user, status: 'authenticated', bootstrapped: true }) }, + setUser: (user) => { + set({ user }) + }, }), { name: 'backupx-auth', diff --git a/web/src/types/auth.ts b/web/src/types/auth.ts index 8850e72..c945f1e 100644 --- a/web/src/types/auth.ts +++ b/web/src/types/auth.ts @@ -2,12 +2,27 @@ export interface AuthUser { id: number; username: string; displayName: string; + email?: string; + phone?: string; role: string; + mfaEnabled?: boolean; + twoFactorEnabled?: boolean; + twoFactorRecoveryCodesRemaining?: number; + webAuthnEnabled?: boolean; + webAuthnCredentialCount?: number; + trustedDeviceCount?: number; + emailOtpEnabled?: boolean; + smsOtpEnabled?: boolean; } export interface LoginPayload { username: string; password: string; + twoFactorCode?: string; + webAuthnAssertion?: unknown; + trustedDeviceToken?: string; + rememberDevice?: boolean; + trustedDeviceName?: string; } export interface LoginResult { diff --git a/web/src/types/notifications.ts b/web/src/types/notifications.ts index 287578a..aa7bb13 100644 --- a/web/src/types/notifications.ts +++ b/web/src/types/notifications.ts @@ -1,4 +1,4 @@ -export type NotificationType = 'email' | 'webhook' | 'telegram' +export type NotificationType = 'email' | 'webhook' | 'telegram' | 'sms' export type NotificationFieldType = 'input' | 'password' | 'number' | 'textarea' export interface NotificationSummary { diff --git a/web/src/utils/webauthn.ts b/web/src/utils/webauthn.ts new file mode 100644 index 0000000..03bd670 --- /dev/null +++ b/web/src/utils/webauthn.ts @@ -0,0 +1,88 @@ +import type { WebAuthnAssertion, WebAuthnAttestation, WebAuthnLoginOptions, WebAuthnRegistrationOptions } from '../services/auth' + +function base64UrlToBuffer(value: string) { + const padded = value.replace(/-/g, '+').replace(/_/g, '/').padEnd(Math.ceil(value.length / 4) * 4, '=') + const binary = atob(padded) + const bytes = new Uint8Array(binary.length) + for (let index = 0; index < binary.length; index += 1) { + bytes[index] = binary.charCodeAt(index) + } + return bytes.buffer +} + +function bufferToBase64Url(buffer: ArrayBuffer) { + const bytes = new Uint8Array(buffer) + let binary = '' + for (let index = 0; index < bytes.byteLength; index += 1) { + binary += String.fromCharCode(bytes[index]) + } + return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/g, '') +} + +function assertWebAuthnAvailable() { + if (!window.PublicKeyCredential || !navigator.credentials) { + throw new Error('当前浏览器不支持通行密钥') + } +} + +export async function createWebAuthnCredential(options: WebAuthnRegistrationOptions): Promise { + assertWebAuthnAvailable() + const credential = await navigator.credentials.create({ + publicKey: { + ...options, + challenge: base64UrlToBuffer(options.challenge), + user: { + ...options.user, + id: base64UrlToBuffer(options.user.id), + }, + excludeCredentials: options.excludeCredentials.map((item) => ({ + ...item, + id: base64UrlToBuffer(item.id), + })), + }, + }) as PublicKeyCredential | null + if (!credential) { + throw new Error('通行密钥创建已取消') + } + const response = credential.response as AuthenticatorAttestationResponse + return { + id: credential.id, + rawId: bufferToBase64Url(credential.rawId), + type: 'public-key', + response: { + clientDataJSON: bufferToBase64Url(response.clientDataJSON), + attestationObject: bufferToBase64Url(response.attestationObject), + }, + } +} + +export async function getWebAuthnAssertion(options: WebAuthnLoginOptions): Promise { + assertWebAuthnAvailable() + const credential = await navigator.credentials.get({ + publicKey: { + challenge: base64UrlToBuffer(options.challenge), + rpId: options.rpId, + timeout: options.timeout, + userVerification: options.userVerification, + allowCredentials: options.allowCredentials.map((item) => ({ + ...item, + id: base64UrlToBuffer(item.id), + })), + }, + }) as PublicKeyCredential | null + if (!credential) { + throw new Error('通行密钥验证已取消') + } + const response = credential.response as AuthenticatorAssertionResponse + return { + id: credential.id, + rawId: bufferToBase64Url(credential.rawId), + type: 'public-key', + response: { + clientDataJSON: bufferToBase64Url(response.clientDataJSON), + authenticatorData: bufferToBase64Url(response.authenticatorData), + signature: bufferToBase64Url(response.signature), + userHandle: response.userHandle ? bufferToBase64Url(response.userHandle) : undefined, + }, + } +}