mirror of
https://github.com/Awuqing/BackupX.git
synced 2026-06-01 15:59:42 +08:00
feat: add complete MFA support
Add complete MFA support with TOTP, recovery codes, WebAuthn, trusted-device cookie flow, and email/SMS OTP delivery via notification channels. Security follow-up: trusted device tokens are stored in HttpOnly cookies, and SMS OTP reuses the existing Webhook notifier to avoid introducing a new dynamic URL sink.
This commit is contained in:
179
server/internal/service/auth_methods.go
Normal file
179
server/internal/service/auth_methods.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
mfaChallengeTTL = 5 * time.Minute
|
||||
trustedDeviceTTL = 30 * 24 * time.Hour
|
||||
maxTrustedDeviceName = 128
|
||||
maxTrustedDevices = 10
|
||||
)
|
||||
|
||||
type WebAuthnCredentialRecord struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CredentialID string `json:"credentialId"`
|
||||
PublicKeyX string `json:"publicKeyX"`
|
||||
PublicKeyY string `json:"publicKeyY"`
|
||||
SignCount uint32 `json:"signCount"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
LastUsedAt string `json:"lastUsedAt,omitempty"`
|
||||
}
|
||||
|
||||
type WebAuthnCredentialOutput struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
LastUsedAt string `json:"lastUsedAt,omitempty"`
|
||||
}
|
||||
|
||||
type webAuthnChallengeState struct {
|
||||
Type string `json:"type"`
|
||||
Challenge string `json:"challenge"`
|
||||
RPID string `json:"rpId"`
|
||||
Origin string `json:"origin"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
}
|
||||
|
||||
type TrustedDeviceRecord struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
TokenHash string `json:"tokenHash"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
LastUsedAt time.Time `json:"lastUsedAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
LastIP string `json:"lastIp"`
|
||||
}
|
||||
|
||||
type TrustedDeviceOutput struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
LastUsedAt string `json:"lastUsedAt"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
LastIP string `json:"lastIp"`
|
||||
}
|
||||
|
||||
type pendingOutOfBandOTP struct {
|
||||
Channel string `json:"channel"`
|
||||
CodeHash string `json:"codeHash"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
}
|
||||
|
||||
func userMFAEnabled(user *model.User) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
return user.TwoFactorEnabled ||
|
||||
strings.TrimSpace(user.WebAuthnCredentials) != "" ||
|
||||
user.EmailOTPEnabled ||
|
||||
user.SMSOTPEnabled
|
||||
}
|
||||
|
||||
func clearTrustedDevicesIfMFAOff(user *model.User) {
|
||||
if user == nil || userMFAEnabled(user) {
|
||||
return
|
||||
}
|
||||
user.TrustedDevices = ""
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
user.WebAuthnChallengeCiphertext = ""
|
||||
}
|
||||
|
||||
func parseWebAuthnCredentials(value string) ([]WebAuthnCredentialRecord, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var credentials []WebAuthnCredentialRecord
|
||||
if err := json.Unmarshal([]byte(value), &credentials); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func encodeWebAuthnCredentials(credentials []WebAuthnCredentialRecord) (string, error) {
|
||||
if len(credentials) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
encoded, err := json.Marshal(credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
func webAuthnCredentialCount(user *model.User) int {
|
||||
if user == nil {
|
||||
return 0
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return len(credentials)
|
||||
}
|
||||
|
||||
func parseTrustedDevices(value string) ([]TrustedDeviceRecord, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var devices []TrustedDeviceRecord
|
||||
if err := json.Unmarshal([]byte(value), &devices); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func encodeTrustedDevices(devices []TrustedDeviceRecord) (string, error) {
|
||||
if len(devices) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
encoded, err := json.Marshal(devices)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
func trustedDeviceCount(user *model.User) int {
|
||||
if user == nil {
|
||||
return 0
|
||||
}
|
||||
devices, err := parseTrustedDevices(user.TrustedDevices)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
count := 0
|
||||
for _, device := range devices {
|
||||
if device.ExpiresAt.After(now) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func toWebAuthnCredentialOutput(record WebAuthnCredentialRecord) WebAuthnCredentialOutput {
|
||||
return WebAuthnCredentialOutput{
|
||||
ID: record.ID,
|
||||
Name: record.Name,
|
||||
CreatedAt: record.CreatedAt,
|
||||
LastUsedAt: record.LastUsedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toTrustedDeviceOutput(record TrustedDeviceRecord) TrustedDeviceOutput {
|
||||
return TrustedDeviceOutput{
|
||||
ID: record.ID,
|
||||
Name: record.Name,
|
||||
CreatedAt: record.CreatedAt.Format(time.RFC3339),
|
||||
LastUsedAt: record.LastUsedAt.Format(time.RFC3339),
|
||||
ExpiresAt: record.ExpiresAt.Format(time.RFC3339),
|
||||
LastIP: record.LastIP,
|
||||
}
|
||||
}
|
||||
252
server/internal/service/auth_otp.go
Normal file
252
server/internal/service/auth_otp.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/security"
|
||||
)
|
||||
|
||||
type OTPConfigInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
Channel string `json:"channel" binding:"required,oneof=email sms"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Email string `json:"email" binding:"omitempty,max=255"`
|
||||
Phone string `json:"phone" binding:"omitempty,max=64"`
|
||||
}
|
||||
|
||||
type LoginOTPInput struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
Channel string `json:"channel" binding:"required,oneof=email sms"`
|
||||
}
|
||||
|
||||
func (s *AuthService) ConfigureOutOfBandOTP(ctx context.Context, subject string, input OTPConfigInput) (*UserOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
channel := strings.TrimSpace(input.Channel)
|
||||
previousEmail := strings.TrimSpace(user.Email)
|
||||
previousPhone := strings.TrimSpace(user.Phone)
|
||||
contactChanged := false
|
||||
switch channel {
|
||||
case "email":
|
||||
email := strings.TrimSpace(input.Email)
|
||||
if email != "" {
|
||||
user.Email = email
|
||||
}
|
||||
contactChanged = previousEmail != strings.TrimSpace(user.Email)
|
||||
if input.Enabled && strings.TrimSpace(user.Email) == "" {
|
||||
return nil, apperror.BadRequest("AUTH_EMAIL_REQUIRED", "请先在用户资料中设置邮箱", nil)
|
||||
}
|
||||
user.EmailOTPEnabled = input.Enabled
|
||||
case "sms":
|
||||
phone := strings.TrimSpace(input.Phone)
|
||||
if phone != "" {
|
||||
user.Phone = phone
|
||||
}
|
||||
contactChanged = previousPhone != strings.TrimSpace(user.Phone)
|
||||
if input.Enabled && strings.TrimSpace(user.Phone) == "" {
|
||||
return nil, apperror.BadRequest("AUTH_PHONE_REQUIRED", "请先设置手机号", nil)
|
||||
}
|
||||
user.SMSOTPEnabled = input.Enabled
|
||||
default:
|
||||
return nil, apperror.BadRequest("AUTH_OTP_CHANNEL_INVALID", "验证码渠道不支持", nil)
|
||||
}
|
||||
if s.shouldClearPendingOTP(user, channel, contactChanged) {
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
}
|
||||
clearTrustedDevicesIfMFAOff(user)
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_OTP_CONFIG_FAILED", "无法更新 OTP 配置", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
action := "otp_disable"
|
||||
if input.Enabled {
|
||||
action = "otp_enable"
|
||||
}
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: action,
|
||||
TargetType: "otp", TargetID: channel,
|
||||
Detail: fmt.Sprintf("%s %s OTP", map[bool]string{true: "启用", false: "关闭"}[input.Enabled], channel),
|
||||
})
|
||||
}
|
||||
return ToUserOutput(user), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) SendLoginOTP(ctx context.Context, input LoginOTPInput, clientKey string) error {
|
||||
user, err := s.verifyPasswordForMFAStart(ctx, input.Username, input.Password, clientKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channel := strings.TrimSpace(input.Channel)
|
||||
if channel == "email" && !user.EmailOTPEnabled {
|
||||
return apperror.BadRequest("AUTH_EMAIL_OTP_DISABLED", "当前账号未启用邮件验证码", nil)
|
||||
}
|
||||
if channel == "sms" && !user.SMSOTPEnabled {
|
||||
return apperror.BadRequest("AUTH_SMS_OTP_DISABLED", "当前账号未启用短信验证码", nil)
|
||||
}
|
||||
code, err := security.GenerateNumericOTP()
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_OTP_GENERATE_FAILED", "无法生成登录验证码", err)
|
||||
}
|
||||
hash, err := security.HashPassword(code)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_OTP_GENERATE_FAILED", "无法处理登录验证码", err)
|
||||
}
|
||||
pending := pendingOutOfBandOTP{
|
||||
Channel: channel,
|
||||
CodeHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL),
|
||||
}
|
||||
ciphertext, err := s.twoFactorCipher.EncryptJSON(pending)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_OTP_SAVE_FAILED", "无法保存登录验证码状态", err)
|
||||
}
|
||||
user.OutOfBandOTPCiphertext = ciphertext
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return apperror.Internal("AUTH_OTP_SAVE_FAILED", "无法保存登录验证码状态", err)
|
||||
}
|
||||
if err := s.deliverLoginOTP(ctx, user, channel, code); err != nil {
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
if updateErr := s.users.Update(ctx, user); updateErr != nil {
|
||||
return apperror.Internal("AUTH_OTP_SAVE_FAILED", "登录验证码发送失败,且无法回滚验证码状态", updateErr)
|
||||
}
|
||||
return apperror.BadRequest("AUTH_OTP_DELIVERY_FAILED", "登录验证码发送失败", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "otp_send",
|
||||
TargetType: "otp", TargetID: channel,
|
||||
Detail: "发送登录 OTP", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) consumeOutOfBandOTP(ctx context.Context, user *model.User, code string, clientKey string) (bool, error) {
|
||||
if strings.TrimSpace(user.OutOfBandOTPCiphertext) == "" {
|
||||
return false, nil
|
||||
}
|
||||
var pending pendingOutOfBandOTP
|
||||
if err := s.twoFactorCipher.DecryptJSON(user.OutOfBandOTPCiphertext, &pending); err != nil {
|
||||
return false, apperror.Internal("AUTH_OTP_INVALID", "登录验证码状态异常", err)
|
||||
}
|
||||
if pending.ExpiresAt.Before(time.Now().UTC()) {
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法更新登录验证码状态", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if !outOfBandOTPChannelEnabled(user, pending.Channel) {
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法更新登录验证码状态", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if security.ComparePassword(pending.CodeHash, security.NormalizeNumericOTP(code)) != nil {
|
||||
return false, nil
|
||||
}
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return false, apperror.Internal("AUTH_OTP_CONSUME_FAILED", "无法使用登录验证码", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "otp_used",
|
||||
TargetType: "otp", TargetID: pending.Channel,
|
||||
Detail: "使用登录 OTP 完成登录", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) deliverLoginOTP(ctx context.Context, user *model.User, channel string, code string) error {
|
||||
if s.notificationService == nil {
|
||||
return fmt.Errorf("notification service is not configured")
|
||||
}
|
||||
switch channel {
|
||||
case "email":
|
||||
email := strings.TrimSpace(user.Email)
|
||||
if email == "" {
|
||||
return fmt.Errorf("user email is empty")
|
||||
}
|
||||
return s.notificationService.SendAuthEmailOTP(ctx, email, code)
|
||||
case "sms":
|
||||
phone := strings.TrimSpace(user.Phone)
|
||||
if phone == "" {
|
||||
return fmt.Errorf("user phone is empty")
|
||||
}
|
||||
return s.notificationService.SendAuthSMSOTP(ctx, phone, code)
|
||||
default:
|
||||
return fmt.Errorf("unsupported otp channel: %s", channel)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) verifyPasswordForMFAStart(ctx context.Context, username string, password string, clientKey string) (*model.User, error) {
|
||||
if clientKey == "" {
|
||||
clientKey = "unknown"
|
||||
}
|
||||
if !s.rateLimiter.Allow(clientKey) {
|
||||
return nil, apperror.TooManyRequests("AUTH_RATE_LIMITED", "登录尝试过于频繁,请稍后再试", nil)
|
||||
}
|
||||
user, err := s.users.FindByUsername(ctx, strings.TrimSpace(username))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_LOOKUP_FAILED", "无法执行登录校验", err)
|
||||
}
|
||||
if user == nil || user.Disabled {
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", nil)
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, password); err != nil {
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "login_failed",
|
||||
Detail: "密码错误", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err)
|
||||
}
|
||||
if !userMFAEnabled(user) {
|
||||
return nil, apperror.BadRequest("AUTH_MFA_NOT_ENABLED", "当前账号未启用多因素验证", nil)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func outOfBandOTPChannelEnabled(user *model.User, channel string) bool {
|
||||
switch channel {
|
||||
case "email":
|
||||
return user.EmailOTPEnabled
|
||||
case "sms":
|
||||
return user.SMSOTPEnabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) shouldClearPendingOTP(user *model.User, changedChannel string, contactChanged bool) bool {
|
||||
if !user.EmailOTPEnabled && !user.SMSOTPEnabled {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(user.OutOfBandOTPCiphertext) == "" {
|
||||
return false
|
||||
}
|
||||
var pending pendingOutOfBandOTP
|
||||
if err := s.twoFactorCipher.DecryptJSON(user.OutOfBandOTPCiphertext, &pending); err != nil {
|
||||
return true
|
||||
}
|
||||
return pending.Channel == changedChannel && (contactChanged || !outOfBandOTPChannelEnabled(user, changedChannel))
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
type SetupInput struct {
|
||||
@@ -20,28 +22,47 @@ type SetupInput struct {
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
TwoFactorCode string `json:"twoFactorCode" binding:"omitempty,min=6,max=32"`
|
||||
WebAuthnAssertion *security.WebAuthnLoginAssertion `json:"webAuthnAssertion"`
|
||||
TrustedDeviceToken string `json:"trustedDeviceToken"`
|
||||
RememberDevice bool `json:"rememberDevice"`
|
||||
TrustedDeviceName string `json:"trustedDeviceName" binding:"omitempty,max=128"`
|
||||
}
|
||||
|
||||
type AuthPayload struct {
|
||||
Token string `json:"token"`
|
||||
User *UserOutput `json:"user"`
|
||||
Token string `json:"token"`
|
||||
User *UserOutput `json:"user"`
|
||||
TrustedDeviceToken string `json:"trustedDeviceToken,omitempty"`
|
||||
TrustedDevice *TrustedDeviceOutput `json:"trustedDevice,omitempty"`
|
||||
}
|
||||
|
||||
type UserOutput struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Role string `json:"role"`
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
Role string `json:"role"`
|
||||
MFAEnabled bool `json:"mfaEnabled"`
|
||||
TwoFactorEnabled bool `json:"twoFactorEnabled"`
|
||||
TwoFactorRecoveryCodesRemaining int `json:"twoFactorRecoveryCodesRemaining"`
|
||||
WebAuthnEnabled bool `json:"webAuthnEnabled"`
|
||||
WebAuthnCredentialCount int `json:"webAuthnCredentialCount"`
|
||||
TrustedDeviceCount int `json:"trustedDeviceCount"`
|
||||
EmailOTPEnabled bool `json:"emailOtpEnabled"`
|
||||
SMSOTPEnabled bool `json:"smsOtpEnabled"`
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
users repository.UserRepository
|
||||
configs repository.SystemConfigRepository
|
||||
jwtManager *security.JWTManager
|
||||
rateLimiter *security.LoginRateLimiter
|
||||
auditService *AuditService
|
||||
users repository.UserRepository
|
||||
configs repository.SystemConfigRepository
|
||||
jwtManager *security.JWTManager
|
||||
rateLimiter *security.LoginRateLimiter
|
||||
twoFactorCipher *codec.ConfigCipher
|
||||
auditService *AuditService
|
||||
notificationService *NotificationService
|
||||
}
|
||||
|
||||
func NewAuthService(
|
||||
@@ -49,14 +70,25 @@ func NewAuthService(
|
||||
configs repository.SystemConfigRepository,
|
||||
jwtManager *security.JWTManager,
|
||||
rateLimiter *security.LoginRateLimiter,
|
||||
twoFactorCipher *codec.ConfigCipher,
|
||||
) *AuthService {
|
||||
return &AuthService{users: users, configs: configs, jwtManager: jwtManager, rateLimiter: rateLimiter}
|
||||
return &AuthService{
|
||||
users: users,
|
||||
configs: configs,
|
||||
jwtManager: jwtManager,
|
||||
rateLimiter: rateLimiter,
|
||||
twoFactorCipher: twoFactorCipher,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) SetAuditService(auditService *AuditService) {
|
||||
s.auditService = auditService
|
||||
}
|
||||
|
||||
func (s *AuthService) SetNotificationService(notificationService *NotificationService) {
|
||||
s.notificationService = notificationService
|
||||
}
|
||||
|
||||
func (s *AuthService) SetupStatus(ctx context.Context) (bool, error) {
|
||||
count, err := s.users.Count(ctx)
|
||||
if err != nil {
|
||||
@@ -130,7 +162,7 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
Category: "auth", Action: "login_failed",
|
||||
Detail: fmt.Sprintf("用户名不存在: %s", strings.TrimSpace(input.Username)),
|
||||
Detail: fmt.Sprintf("用户名不存在: %s", strings.TrimSpace(input.Username)),
|
||||
ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
@@ -156,6 +188,20 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
|
||||
}
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err)
|
||||
}
|
||||
mfaRequired := userMFAEnabled(user)
|
||||
trustedDeviceUsed := false
|
||||
if mfaRequired {
|
||||
trusted, err := s.verifyTrustedDevice(ctx, user, input.TrustedDeviceToken, clientKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
trustedDeviceUsed = trusted
|
||||
if !trusted {
|
||||
if err := s.verifyLoginMFA(ctx, user, input, clientKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.rateLimiter.Reset(clientKey)
|
||||
token, err := s.jwtManager.Generate(user)
|
||||
@@ -163,6 +209,16 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
|
||||
return nil, apperror.Internal("AUTH_TOKEN_FAILED", "无法生成访问令牌", err)
|
||||
}
|
||||
|
||||
payload := &AuthPayload{Token: token, User: ToUserOutput(user)}
|
||||
if mfaRequired && !trustedDeviceUsed && input.RememberDevice {
|
||||
deviceToken, device, err := s.issueTrustedDevice(ctx, user, input.TrustedDeviceName, clientKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload.TrustedDeviceToken = deviceToken
|
||||
payload.TrustedDevice = device
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
@@ -171,10 +227,72 @@ func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey str
|
||||
})
|
||||
}
|
||||
|
||||
return &AuthPayload{Token: token, User: ToUserOutput(user)}, nil
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) {
|
||||
func (s *AuthService) verifyLoginMFA(ctx context.Context, user *model.User, input LoginInput, clientKey string) error {
|
||||
if input.WebAuthnAssertion != nil {
|
||||
if err := s.VerifyWebAuthnLogin(ctx, user, *input.WebAuthnAssertion, clientKey); err != nil {
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "login_failed",
|
||||
Detail: "通行密钥校验失败", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
code := strings.TrimSpace(input.TwoFactorCode)
|
||||
if code == "" {
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "two_factor_required",
|
||||
Detail: "登录需要多因素验证", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return apperror.Unauthorized("AUTH_2FA_REQUIRED", "请输入验证码、恢复码或使用通行密钥", nil)
|
||||
}
|
||||
if user.TwoFactorEnabled {
|
||||
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
|
||||
}
|
||||
ok, err := security.ValidateTOTPCode(secret, code)
|
||||
if err == nil && ok {
|
||||
return nil
|
||||
}
|
||||
if consumed, err := s.consumeRecoveryCode(ctx, user, code); err != nil {
|
||||
return err
|
||||
} else if consumed {
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "two_factor_recovery_code_used",
|
||||
Detail: "使用恢复码完成登录", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if consumed, err := s.consumeOutOfBandOTP(ctx, user, code, clientKey); err != nil {
|
||||
return err
|
||||
} else if consumed {
|
||||
return nil
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "login_failed",
|
||||
Detail: "多因素验证码错误", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return apperror.Unauthorized("AUTH_2FA_INVALID", "验证码、恢复码或通行密钥错误", nil)
|
||||
}
|
||||
|
||||
func (s *AuthService) userBySubject(ctx context.Context, subject string) (*model.User, error) {
|
||||
userID, err := strconv.ParseUint(subject, 10, 64)
|
||||
if err != nil {
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err)
|
||||
@@ -186,6 +304,14 @@ func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*User
|
||||
if user == nil {
|
||||
return nil, apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found"))
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ToUserOutput(user), nil
|
||||
}
|
||||
|
||||
@@ -195,16 +321,9 @@ type ChangePasswordInput struct {
|
||||
}
|
||||
|
||||
func (s *AuthService) ChangePassword(ctx context.Context, subject string, input ChangePasswordInput) error {
|
||||
userID, err := strconv.ParseUint(subject, 10, 64)
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err)
|
||||
}
|
||||
user, err := s.users.FindByID(ctx, uint(userID))
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_LOOKUP_FAILED", "无法获取当前用户", err)
|
||||
}
|
||||
if user == nil {
|
||||
return apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found"))
|
||||
return err
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.OldPassword); err != nil {
|
||||
return apperror.BadRequest("AUTH_WRONG_PASSWORD", "旧密码不正确", err)
|
||||
@@ -214,6 +333,9 @@ func (s *AuthService) ChangePassword(ctx context.Context, subject string, input
|
||||
return apperror.Internal("AUTH_HASH_FAILED", "无法处理密码", err)
|
||||
}
|
||||
user.PasswordHash = hash
|
||||
user.TrustedDevices = ""
|
||||
user.OutOfBandOTPCiphertext = ""
|
||||
user.WebAuthnChallengeCiphertext = ""
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return apperror.Internal("AUTH_UPDATE_FAILED", "密码修改失败", err)
|
||||
}
|
||||
@@ -229,15 +351,338 @@ func (s *AuthService) ChangePassword(ctx context.Context, subject string, input
|
||||
return nil
|
||||
}
|
||||
|
||||
type TwoFactorSetupInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
type TwoFactorSetupOutput struct {
|
||||
Secret string `json:"secret"`
|
||||
OTPAuthURL string `json:"otpAuthUrl"`
|
||||
QRCodeDataURL string `json:"qrCodeDataUrl"`
|
||||
TwoFactorEnabled bool `json:"twoFactorEnabled"`
|
||||
TwoFactorConfirmed bool `json:"twoFactorConfirmed"`
|
||||
}
|
||||
|
||||
type EnableTwoFactorInput struct {
|
||||
Code string `json:"code" binding:"required,min=6,max=10"`
|
||||
}
|
||||
|
||||
type EnableTwoFactorOutput struct {
|
||||
User *UserOutput `json:"user"`
|
||||
RecoveryCodes []string `json:"recoveryCodes"`
|
||||
}
|
||||
|
||||
type DisableTwoFactorInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
Code string `json:"code" binding:"required,min=6,max=32"`
|
||||
}
|
||||
|
||||
type RegenerateRecoveryCodesInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
Code string `json:"code" binding:"required,min=6,max=10"`
|
||||
}
|
||||
|
||||
type RecoveryCodesOutput struct {
|
||||
User *UserOutput `json:"user"`
|
||||
RecoveryCodes []string `json:"recoveryCodes"`
|
||||
}
|
||||
|
||||
func (s *AuthService) PrepareTwoFactor(ctx context.Context, subject string, input TwoFactorSetupInput) (*TwoFactorSetupOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user.TwoFactorEnabled {
|
||||
return nil, apperror.Conflict("AUTH_2FA_ALREADY_ENABLED", "TOTP 已启用", nil)
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
|
||||
enrollment, err := security.GenerateTOTPEnrollment(user.Username)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_SETUP_FAILED", "无法生成 TOTP 密钥", err)
|
||||
}
|
||||
ciphertext, err := s.encryptTwoFactorSecret(enrollment.Secret)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_SAVE_FAILED", "无法保存 TOTP 密钥", err)
|
||||
}
|
||||
user.TwoFactorSecretCiphertext = ciphertext
|
||||
user.TwoFactorEnabled = false
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_SAVE_FAILED", "无法保存 TOTP 密钥", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "two_factor_setup",
|
||||
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
|
||||
Detail: "生成 TOTP 密钥",
|
||||
})
|
||||
}
|
||||
|
||||
return &TwoFactorSetupOutput{
|
||||
Secret: enrollment.Secret,
|
||||
OTPAuthURL: enrollment.OTPAuthURL,
|
||||
QRCodeDataURL: enrollment.QRCodeDataURL,
|
||||
TwoFactorEnabled: false,
|
||||
TwoFactorConfirmed: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) EnableTwoFactor(ctx context.Context, subject string, input EnableTwoFactorInput) (*EnableTwoFactorOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user.TwoFactorEnabled {
|
||||
return nil, apperror.Conflict("AUTH_2FA_ALREADY_ENABLED", "TOTP 已启用", nil)
|
||||
}
|
||||
if strings.TrimSpace(user.TwoFactorSecretCiphertext) == "" {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_NOT_PREPARED", "请先生成 TOTP 密钥", nil)
|
||||
}
|
||||
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
|
||||
}
|
||||
ok, err := security.ValidateTOTPCode(secret, input.Code)
|
||||
if err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil)
|
||||
}
|
||||
recoveryCodes, recoveryHashes, err := s.generateRecoveryCodeHashes()
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法生成恢复码", err)
|
||||
}
|
||||
|
||||
user.TwoFactorEnabled = true
|
||||
user.TwoFactorRecoveryCodeHashes = recoveryHashes
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_ENABLE_FAILED", "无法启用 TOTP", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "two_factor_enable",
|
||||
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
|
||||
Detail: "启用 TOTP",
|
||||
})
|
||||
}
|
||||
return &EnableTwoFactorOutput{User: ToUserOutput(user), RecoveryCodes: recoveryCodes}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) DisableTwoFactor(ctx context.Context, subject string, input DisableTwoFactorInput) (*UserOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !user.TwoFactorEnabled {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_NOT_ENABLED", "TOTP 未启用", nil)
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
|
||||
}
|
||||
ok, err := security.ValidateTOTPCode(secret, input.Code)
|
||||
if err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil)
|
||||
}
|
||||
|
||||
user.TwoFactorEnabled = false
|
||||
user.TwoFactorSecretCiphertext = ""
|
||||
user.TwoFactorRecoveryCodeHashes = ""
|
||||
clearTrustedDevicesIfMFAOff(user)
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_DISABLE_FAILED", "无法关闭 TOTP", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "two_factor_disable",
|
||||
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
|
||||
Detail: "关闭 TOTP",
|
||||
})
|
||||
}
|
||||
return ToUserOutput(user), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) verifyCurrentTOTP(user *model.User, code string) error {
|
||||
secret, err := s.decryptTwoFactorSecret(user.TwoFactorSecretCiphertext)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_2FA_SECRET_INVALID", "TOTP 配置异常", err)
|
||||
}
|
||||
ok, err := security.ValidateTOTPCode(secret, code)
|
||||
if err != nil {
|
||||
return apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码格式不正确", err)
|
||||
}
|
||||
if !ok {
|
||||
return apperror.BadRequest("AUTH_2FA_INVALID", "TOTP 验证码错误", nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) RegenerateRecoveryCodes(ctx context.Context, subject string, input RegenerateRecoveryCodesInput) (*RecoveryCodesOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !user.TwoFactorEnabled {
|
||||
return nil, apperror.BadRequest("AUTH_2FA_NOT_ENABLED", "TOTP 未启用", nil)
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
if err := s.verifyCurrentTOTP(user, input.Code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
recoveryCodes, recoveryHashes, err := s.generateRecoveryCodeHashes()
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法生成恢复码", err)
|
||||
}
|
||||
user.TwoFactorRecoveryCodeHashes = recoveryHashes
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_2FA_RECOVERY_FAILED", "无法更新恢复码", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "two_factor_recovery_codes_regenerate",
|
||||
TargetType: "user", TargetID: fmt.Sprintf("%d", user.ID), TargetName: user.Username,
|
||||
Detail: "重新生成 TOTP 恢复码",
|
||||
})
|
||||
}
|
||||
return &RecoveryCodesOutput{User: ToUserOutput(user), RecoveryCodes: recoveryCodes}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) generateRecoveryCodeHashes() ([]string, string, error) {
|
||||
codes, err := security.GenerateRecoveryCodes(security.RecoveryCodeCount)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
hashes := make([]string, 0, len(codes))
|
||||
for _, code := range codes {
|
||||
hash, err := security.HashPassword(security.NormalizeRecoveryCode(code))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
hashes = append(hashes, hash)
|
||||
}
|
||||
encoded, err := encodeRecoveryCodeHashes(hashes)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return codes, encoded, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) consumeRecoveryCode(ctx context.Context, user *model.User, code string) (bool, error) {
|
||||
if !security.IsRecoveryCodeCandidate(code) {
|
||||
return false, nil
|
||||
}
|
||||
hashes, err := parseRecoveryCodeHashes(user.TwoFactorRecoveryCodeHashes)
|
||||
if err != nil {
|
||||
return false, apperror.Internal("AUTH_2FA_RECOVERY_INVALID", "恢复码配置异常", err)
|
||||
}
|
||||
if len(hashes) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
normalized := security.NormalizeRecoveryCode(code)
|
||||
for i, hash := range hashes {
|
||||
if security.ComparePassword(hash, normalized) != nil {
|
||||
continue
|
||||
}
|
||||
hashes = append(hashes[:i], hashes[i+1:]...)
|
||||
encoded, err := encodeRecoveryCodeHashes(hashes)
|
||||
if err != nil {
|
||||
return false, apperror.Internal("AUTH_2FA_RECOVERY_INVALID", "恢复码配置异常", err)
|
||||
}
|
||||
user.TwoFactorRecoveryCodeHashes = encoded
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return false, apperror.Internal("AUTH_2FA_RECOVERY_CONSUME_FAILED", "无法使用恢复码", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) encryptTwoFactorSecret(secret string) (string, error) {
|
||||
if s.twoFactorCipher == nil {
|
||||
return "", errors.New("two-factor cipher is not configured")
|
||||
}
|
||||
return s.twoFactorCipher.Encrypt([]byte(strings.TrimSpace(secret)))
|
||||
}
|
||||
|
||||
func (s *AuthService) decryptTwoFactorSecret(ciphertext string) (string, error) {
|
||||
if s.twoFactorCipher == nil {
|
||||
return "", errors.New("two-factor cipher is not configured")
|
||||
}
|
||||
raw, err := s.twoFactorCipher.Decrypt(strings.TrimSpace(ciphertext))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(raw)), nil
|
||||
}
|
||||
|
||||
func parseRecoveryCodeHashes(encoded string) ([]string, error) {
|
||||
if strings.TrimSpace(encoded) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var hashes []string
|
||||
if err := json.Unmarshal([]byte(encoded), &hashes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hashes, nil
|
||||
}
|
||||
|
||||
func encodeRecoveryCodeHashes(hashes []string) (string, error) {
|
||||
if len(hashes) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
encoded, err := json.Marshal(hashes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
func recoveryCodeRemainingCount(user *model.User) int {
|
||||
if user == nil {
|
||||
return 0
|
||||
}
|
||||
hashes, err := parseRecoveryCodeHashes(user.TwoFactorRecoveryCodeHashes)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return len(hashes)
|
||||
}
|
||||
|
||||
func ToUserOutput(user *model.User) *UserOutput {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserOutput{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role,
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
Role: user.Role,
|
||||
MFAEnabled: userMFAEnabled(user),
|
||||
TwoFactorEnabled: user.TwoFactorEnabled,
|
||||
TwoFactorRecoveryCodesRemaining: recoveryCodeRemainingCount(user),
|
||||
WebAuthnEnabled: webAuthnCredentialCount(user) > 0,
|
||||
WebAuthnCredentialCount: webAuthnCredentialCount(user),
|
||||
TrustedDeviceCount: trustedDeviceCount(user),
|
||||
EmailOTPEnabled: user.EmailOTPEnabled,
|
||||
SMSOTPEnabled: user.SMSOTPEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/storage/codec"
|
||||
"github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
type fakeUserRepository struct {
|
||||
@@ -100,6 +103,7 @@ func TestAuthServiceSetupAndLogin(t *testing.T) {
|
||||
&fakeSystemConfigRepository{},
|
||||
security.NewJWTManager("test-secret", time.Hour),
|
||||
security.NewLoginRateLimiter(5, time.Minute),
|
||||
codec.NewConfigCipher("test-encryption-secret"),
|
||||
)
|
||||
|
||||
setupResult, err := service.Setup(context.Background(), SetupInput{
|
||||
@@ -133,6 +137,7 @@ func newTestAuthService() (*AuthService, *fakeUserRepository) {
|
||||
&fakeSystemConfigRepository{},
|
||||
security.NewJWTManager("test-secret", time.Hour),
|
||||
security.NewLoginRateLimiter(5, time.Minute),
|
||||
codec.NewConfigCipher("test-encryption-secret"),
|
||||
)
|
||||
return svc, users
|
||||
}
|
||||
@@ -188,3 +193,425 @@ func TestChangePasswordWrongOld(t *testing.T) {
|
||||
t.Fatalf("expected ChangePassword with wrong old password to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceLoginRequiresTwoFactorWhenEnabled(t *testing.T) {
|
||||
svc, _ := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
|
||||
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
|
||||
CurrentPassword: "password-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareTwoFactor: %v", err)
|
||||
}
|
||||
if setup.Secret == "" || setup.QRCodeDataURL == "" || setup.OTPAuthURL == "" {
|
||||
t.Fatalf("expected populated 2FA enrollment, got %#v", setup)
|
||||
}
|
||||
|
||||
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode: %v", err)
|
||||
}
|
||||
enabledUser, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code})
|
||||
if err != nil {
|
||||
t.Fatalf("EnableTwoFactor: %v", err)
|
||||
}
|
||||
if !enabledUser.User.TwoFactorEnabled {
|
||||
t.Fatalf("expected 2FA enabled")
|
||||
}
|
||||
if len(enabledUser.RecoveryCodes) != security.RecoveryCodeCount {
|
||||
t.Fatalf("expected %d recovery codes, got %d", security.RecoveryCodeCount, len(enabledUser.RecoveryCodes))
|
||||
}
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123",
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_REQUIRED" {
|
||||
t.Fatalf("expected AUTH_2FA_REQUIRED, got %v", err)
|
||||
}
|
||||
|
||||
loginCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode login: %v", err)
|
||||
}
|
||||
loginResult, err := svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: loginCode,
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login with 2FA: %v", err)
|
||||
}
|
||||
if loginResult.Token == "" {
|
||||
t.Fatalf("expected non-empty token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceDisableTwoFactor(t *testing.T) {
|
||||
svc, _ := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
|
||||
CurrentPassword: "password-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareTwoFactor: %v", err)
|
||||
}
|
||||
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode: %v", err)
|
||||
}
|
||||
if _, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}); err != nil {
|
||||
t.Fatalf("EnableTwoFactor: %v", err)
|
||||
}
|
||||
|
||||
disableCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode disable: %v", err)
|
||||
}
|
||||
user, err := svc.DisableTwoFactor(context.Background(), "1", DisableTwoFactorInput{
|
||||
CurrentPassword: "password-123",
|
||||
Code: disableCode,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DisableTwoFactor: %v", err)
|
||||
}
|
||||
if user.TwoFactorEnabled {
|
||||
t.Fatalf("expected 2FA disabled")
|
||||
}
|
||||
|
||||
loginResult, err := svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login after disable: %v", err)
|
||||
}
|
||||
if loginResult.Token == "" {
|
||||
t.Fatalf("expected non-empty token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceRecoveryCodeLoginConsumesCode(t *testing.T) {
|
||||
svc, _ := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
|
||||
CurrentPassword: "password-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareTwoFactor: %v", err)
|
||||
}
|
||||
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode: %v", err)
|
||||
}
|
||||
enabled, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code})
|
||||
if err != nil {
|
||||
t.Fatalf("EnableTwoFactor: %v", err)
|
||||
}
|
||||
recoveryCode := enabled.RecoveryCodes[0]
|
||||
|
||||
loginResult, err := svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: recoveryCode,
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login with recovery code: %v", err)
|
||||
}
|
||||
if loginResult.User.TwoFactorRecoveryCodesRemaining != security.RecoveryCodeCount-1 {
|
||||
t.Fatalf("expected one recovery code consumed, got remaining=%d", loginResult.User.TwoFactorRecoveryCodesRemaining)
|
||||
}
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: recoveryCode,
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
|
||||
t.Fatalf("expected consumed recovery code to fail, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceRegenerateRecoveryCodesInvalidatesOldCodes(t *testing.T) {
|
||||
svc, _ := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
|
||||
CurrentPassword: "password-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareTwoFactor: %v", err)
|
||||
}
|
||||
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode: %v", err)
|
||||
}
|
||||
enabled, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code})
|
||||
if err != nil {
|
||||
t.Fatalf("EnableTwoFactor: %v", err)
|
||||
}
|
||||
oldRecoveryCode := enabled.RecoveryCodes[0]
|
||||
|
||||
regenerateCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode regenerate: %v", err)
|
||||
}
|
||||
regenerated, err := svc.RegenerateRecoveryCodes(context.Background(), "1", RegenerateRecoveryCodesInput{
|
||||
CurrentPassword: "password-123",
|
||||
Code: regenerateCode,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RegenerateRecoveryCodes: %v", err)
|
||||
}
|
||||
if len(regenerated.RecoveryCodes) != security.RecoveryCodeCount {
|
||||
t.Fatalf("expected %d recovery codes, got %d", security.RecoveryCodeCount, len(regenerated.RecoveryCodes))
|
||||
}
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: oldRecoveryCode,
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
|
||||
t.Fatalf("expected old recovery code to fail, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceTrustedDeviceSkipsMFA(t *testing.T) {
|
||||
svc, repo := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
setup, err := svc.PrepareTwoFactor(context.Background(), "1", TwoFactorSetupInput{
|
||||
CurrentPassword: "password-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareTwoFactor: %v", err)
|
||||
}
|
||||
code, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode: %v", err)
|
||||
}
|
||||
if _, err := svc.EnableTwoFactor(context.Background(), "1", EnableTwoFactorInput{Code: code}); err != nil {
|
||||
t.Fatalf("EnableTwoFactor: %v", err)
|
||||
}
|
||||
loginCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode login: %v", err)
|
||||
}
|
||||
firstLogin, err := svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: loginCode,
|
||||
RememberDevice: true, TrustedDeviceName: "test browser",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login with 2FA: %v", err)
|
||||
}
|
||||
if firstLogin.TrustedDeviceToken == "" || firstLogin.TrustedDevice == nil {
|
||||
t.Fatalf("expected trusted device token")
|
||||
}
|
||||
secondLogin, err := svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TrustedDeviceToken: firstLogin.TrustedDeviceToken,
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login with trusted device: %v", err)
|
||||
}
|
||||
if secondLogin.Token == "" {
|
||||
t.Fatalf("expected token")
|
||||
}
|
||||
disableCode, err := totp.GenerateCode(setup.Secret, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode disable: %v", err)
|
||||
}
|
||||
if _, err := svc.DisableTwoFactor(context.Background(), "1", DisableTwoFactorInput{
|
||||
CurrentPassword: "password-123",
|
||||
Code: disableCode,
|
||||
}); err != nil {
|
||||
t.Fatalf("DisableTwoFactor: %v", err)
|
||||
}
|
||||
if repo.users[0].TrustedDevices != "" {
|
||||
t.Fatalf("expected trusted devices cleared after disabling last MFA method")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceOutOfBandOTPLoginConsumesCode(t *testing.T) {
|
||||
svc, repo := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
user := repo.users[0]
|
||||
user.Email = "admin@example.com"
|
||||
user.EmailOTPEnabled = true
|
||||
hash, err := security.HashPassword("123456")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword: %v", err)
|
||||
}
|
||||
ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{
|
||||
Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptJSON: %v", err)
|
||||
}
|
||||
user.OutOfBandOTPCiphertext = ciphertext
|
||||
if err := repo.Update(context.Background(), user); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
loginResult, err := svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: "123456",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login with email OTP: %v", err)
|
||||
}
|
||||
if loginResult.Token == "" {
|
||||
t.Fatalf("expected token")
|
||||
}
|
||||
if repo.users[0].OutOfBandOTPCiphertext != "" {
|
||||
t.Fatalf("expected OTP to be consumed")
|
||||
}
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: "123456",
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
|
||||
t.Fatalf("expected consumed OTP to fail, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceMFAStartIsRateLimited(t *testing.T) {
|
||||
svc, repo := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
repo.users[0].Email = "admin@example.com"
|
||||
repo.users[0].EmailOTPEnabled = true
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
_ = svc.SendLoginOTP(context.Background(), LoginOTPInput{
|
||||
Username: "admin", Password: "wrong-password", Channel: "email",
|
||||
}, "127.0.0.1")
|
||||
}
|
||||
err = svc.SendLoginOTP(context.Background(), LoginOTPInput{
|
||||
Username: "admin", Password: "wrong-password", Channel: "email",
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_RATE_LIMITED" {
|
||||
t.Fatalf("expected AUTH_RATE_LIMITED, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceDisabledOTPChannelCannotConsumePendingCode(t *testing.T) {
|
||||
svc, repo := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
user := repo.users[0]
|
||||
user.Email = "admin@example.com"
|
||||
user.EmailOTPEnabled = false
|
||||
user.SMSOTPEnabled = true
|
||||
hash, err := security.HashPassword("123456")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword: %v", err)
|
||||
}
|
||||
ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{
|
||||
Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptJSON: %v", err)
|
||||
}
|
||||
user.OutOfBandOTPCiphertext = ciphertext
|
||||
if err := repo.Update(context.Background(), user); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123", TwoFactorCode: "123456",
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_INVALID" {
|
||||
t.Fatalf("expected disabled OTP channel to fail, got %v", err)
|
||||
}
|
||||
if repo.users[0].OutOfBandOTPCiphertext != "" {
|
||||
t.Fatalf("expected disabled channel OTP to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceChangingOTPRecipientClearsPendingCode(t *testing.T) {
|
||||
svc, repo := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
user := repo.users[0]
|
||||
user.Email = "old@example.com"
|
||||
user.EmailOTPEnabled = true
|
||||
hash, err := security.HashPassword("123456")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword: %v", err)
|
||||
}
|
||||
ciphertext, err := svc.twoFactorCipher.EncryptJSON(pendingOutOfBandOTP{
|
||||
Channel: "email", CodeHash: hash, ExpiresAt: time.Now().UTC().Add(time.Minute),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptJSON: %v", err)
|
||||
}
|
||||
user.OutOfBandOTPCiphertext = ciphertext
|
||||
if err := repo.Update(context.Background(), user); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
updated, err := svc.ConfigureOutOfBandOTP(context.Background(), "1", OTPConfigInput{
|
||||
CurrentPassword: "password-123",
|
||||
Channel: "email",
|
||||
Enabled: true,
|
||||
Email: "new@example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConfigureOutOfBandOTP: %v", err)
|
||||
}
|
||||
if updated.Email != "new@example.com" {
|
||||
t.Fatalf("expected email updated, got %q", updated.Email)
|
||||
}
|
||||
if repo.users[0].OutOfBandOTPCiphertext != "" {
|
||||
t.Fatalf("expected pending email OTP to be cleared after recipient change")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceCorruptWebAuthnCredentialsStillRequireMFA(t *testing.T) {
|
||||
svc, repo := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
repo.users[0].WebAuthnCredentials = "{invalid-json"
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123",
|
||||
}, "127.0.0.1")
|
||||
if appErr, ok := err.(*apperror.AppError); !ok || appErr.Code != "AUTH_2FA_REQUIRED" {
|
||||
t.Fatalf("expected corrupt WebAuthn credentials to require MFA, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
221
server/internal/service/auth_trusted_device.go
Normal file
221
server/internal/service/auth_trusted_device.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/security"
|
||||
)
|
||||
|
||||
func (s *AuthService) ListTrustedDevices(ctx context.Context, subject string) ([]TrustedDeviceOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
devices, err := parseTrustedDevices(user.TrustedDevices)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
output := make([]TrustedDeviceOutput, 0, len(devices))
|
||||
for _, device := range devices {
|
||||
if device.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
output = append(output, toTrustedDeviceOutput(device))
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
type TrustedDeviceRevokeInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
func (s *AuthService) RevokeTrustedDevice(ctx context.Context, subject string, id string, input TrustedDeviceRevokeInput) error {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
devices, err := parseTrustedDevices(user.TrustedDevices)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
found := false
|
||||
filtered := make([]TrustedDeviceRecord, 0, len(devices))
|
||||
for _, device := range devices {
|
||||
if device.ID == strings.TrimSpace(id) {
|
||||
found = true
|
||||
} else {
|
||||
filtered = append(filtered, device)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return apperror.New(404, "AUTH_TRUSTED_DEVICE_NOT_FOUND", "可信设备不存在", nil)
|
||||
}
|
||||
encoded, err := encodeTrustedDevices(filtered)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
user.TrustedDevices = encoded
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return apperror.Internal("AUTH_TRUSTED_DEVICE_REVOKE_FAILED", "无法移除可信设备", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "trusted_device_revoke",
|
||||
TargetType: "trusted_device", TargetID: strings.TrimSpace(id),
|
||||
Detail: "移除可信设备",
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) verifyTrustedDevice(ctx context.Context, user *model.User, token string, clientKey string) (bool, error) {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return false, nil
|
||||
}
|
||||
devices, err := parseTrustedDevices(user.TrustedDevices)
|
||||
if err != nil {
|
||||
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
hash := trustedDeviceTokenHash(token)
|
||||
changed := false
|
||||
for i := range devices {
|
||||
device := &devices[i]
|
||||
if device.ExpiresAt.Before(now) {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(device.TokenHash), []byte(hash)) != 1 {
|
||||
continue
|
||||
}
|
||||
device.LastUsedAt = now
|
||||
device.LastIP = clientKey
|
||||
changed = true
|
||||
encoded, err := encodeTrustedDevices(filterActiveTrustedDevices(devices, now))
|
||||
if err != nil {
|
||||
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
user.TrustedDevices = encoded
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_UPDATE_FAILED", "无法更新可信设备", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "trusted_device_used",
|
||||
TargetType: "trusted_device", TargetID: device.ID, TargetName: device.Name,
|
||||
Detail: "使用可信设备跳过多因素验证", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
if changed {
|
||||
encoded, err := encodeTrustedDevices(filterActiveTrustedDevices(devices, now))
|
||||
if err != nil {
|
||||
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
user.TrustedDevices = encoded
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return false, apperror.Internal("AUTH_TRUSTED_DEVICE_UPDATE_FAILED", "无法更新可信设备", err)
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) issueTrustedDevice(ctx context.Context, user *model.User, name string, clientKey string) (string, *TrustedDeviceOutput, error) {
|
||||
token, err := randomURLToken(32)
|
||||
if err != nil {
|
||||
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法生成可信设备令牌", err)
|
||||
}
|
||||
id, err := randomURLToken(16)
|
||||
if err != nil {
|
||||
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法生成可信设备编号", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
deviceName := normalizeTrustedDeviceName(name)
|
||||
device := TrustedDeviceRecord{
|
||||
ID: id,
|
||||
Name: deviceName,
|
||||
TokenHash: trustedDeviceTokenHash(token),
|
||||
CreatedAt: now,
|
||||
LastUsedAt: now,
|
||||
ExpiresAt: now.Add(trustedDeviceTTL),
|
||||
LastIP: clientKey,
|
||||
}
|
||||
devices, err := parseTrustedDevices(user.TrustedDevices)
|
||||
if err != nil {
|
||||
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
devices = append(filterActiveTrustedDevices(devices, now), device)
|
||||
if len(devices) > maxTrustedDevices {
|
||||
devices = devices[len(devices)-maxTrustedDevices:]
|
||||
}
|
||||
encoded, err := encodeTrustedDevices(devices)
|
||||
if err != nil {
|
||||
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_INVALID", "可信设备配置异常", err)
|
||||
}
|
||||
user.TrustedDevices = encoded
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return "", nil, apperror.Internal("AUTH_TRUSTED_DEVICE_CREATE_FAILED", "无法保存可信设备", err)
|
||||
}
|
||||
output := toTrustedDeviceOutput(device)
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "trusted_device_create",
|
||||
TargetType: "trusted_device", TargetID: device.ID, TargetName: device.Name,
|
||||
Detail: fmt.Sprintf("添加可信设备,有效期至 %s", device.ExpiresAt.Format(time.RFC3339)), ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return token, &output, nil
|
||||
}
|
||||
|
||||
func filterActiveTrustedDevices(devices []TrustedDeviceRecord, now time.Time) []TrustedDeviceRecord {
|
||||
active := make([]TrustedDeviceRecord, 0, len(devices))
|
||||
for _, device := range devices {
|
||||
if device.ExpiresAt.After(now) {
|
||||
active = append(active, device)
|
||||
}
|
||||
}
|
||||
return active
|
||||
}
|
||||
|
||||
func trustedDeviceTokenHash(token string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(token)))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func randomURLToken(size int) (string, error) {
|
||||
buf := make([]byte, size)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func normalizeTrustedDeviceName(name string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return "当前设备"
|
||||
}
|
||||
if len([]rune(trimmed)) <= maxTrustedDeviceName {
|
||||
return trimmed
|
||||
}
|
||||
runes := []rune(trimmed)
|
||||
return string(runes[:maxTrustedDeviceName])
|
||||
}
|
||||
366
server/internal/service/auth_webauthn.go
Normal file
366
server/internal/service/auth_webauthn.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/security"
|
||||
)
|
||||
|
||||
type WebAuthnRequestContext struct {
|
||||
RPID string
|
||||
Origin string
|
||||
}
|
||||
|
||||
type WebAuthnRegistrationOptionsInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
type WebAuthnRegistrationFinishInput struct {
|
||||
Name string `json:"name" binding:"omitempty,max=128"`
|
||||
Credential security.WebAuthnRegistrationResponse `json:"credential" binding:"required"`
|
||||
}
|
||||
|
||||
type WebAuthnCredentialDeleteInput struct {
|
||||
CurrentPassword string `json:"currentPassword" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
type WebAuthnLoginOptionsInput struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
type webAuthnPublicKeyCredentialParam struct {
|
||||
Type string `json:"type"`
|
||||
Alg int `json:"alg"`
|
||||
}
|
||||
|
||||
type webAuthnRelyingParty struct {
|
||||
Name string `json:"name"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type webAuthnUserEntity struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
|
||||
type webAuthnCredentialDescriptor struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type webAuthnAuthenticatorSelection struct {
|
||||
UserVerification string `json:"userVerification"`
|
||||
}
|
||||
|
||||
type WebAuthnRegistrationOptions struct {
|
||||
Challenge string `json:"challenge"`
|
||||
RP webAuthnRelyingParty `json:"rp"`
|
||||
User webAuthnUserEntity `json:"user"`
|
||||
PubKeyCredParams []webAuthnPublicKeyCredentialParam `json:"pubKeyCredParams"`
|
||||
Timeout int `json:"timeout"`
|
||||
Attestation string `json:"attestation"`
|
||||
AuthenticatorSelection webAuthnAuthenticatorSelection `json:"authenticatorSelection"`
|
||||
ExcludeCredentials []webAuthnCredentialDescriptor `json:"excludeCredentials"`
|
||||
}
|
||||
|
||||
type WebAuthnLoginOptions struct {
|
||||
Challenge string `json:"challenge"`
|
||||
RPID string `json:"rpId"`
|
||||
Timeout int `json:"timeout"`
|
||||
UserVerification string `json:"userVerification"`
|
||||
AllowCredentials []webAuthnCredentialDescriptor `json:"allowCredentials"`
|
||||
}
|
||||
|
||||
func (s *AuthService) BeginWebAuthnRegistration(ctx context.Context, subject string, input WebAuthnRegistrationOptionsInput, request WebAuthnRequestContext) (*WebAuthnRegistrationOptions, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
|
||||
}
|
||||
challenge, err := security.GenerateWebAuthnChallenge()
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法生成通行密钥挑战", err)
|
||||
}
|
||||
state := webAuthnChallengeState{
|
||||
Type: "register",
|
||||
Challenge: challenge,
|
||||
RPID: request.RPID,
|
||||
Origin: request.Origin,
|
||||
ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL),
|
||||
}
|
||||
if err := s.saveWebAuthnChallenge(ctx, user, state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exclude := make([]webAuthnCredentialDescriptor, 0, len(credentials))
|
||||
for _, credential := range credentials {
|
||||
exclude = append(exclude, webAuthnCredentialDescriptor{Type: "public-key", ID: credential.CredentialID})
|
||||
}
|
||||
return &WebAuthnRegistrationOptions{
|
||||
Challenge: challenge,
|
||||
RP: webAuthnRelyingParty{Name: "BackupX", ID: request.RPID},
|
||||
User: webAuthnUserEntity{
|
||||
ID: security.EncodeBase64URL([]byte(fmt.Sprintf("%d", user.ID))),
|
||||
Name: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
},
|
||||
PubKeyCredParams: []webAuthnPublicKeyCredentialParam{
|
||||
{Type: "public-key", Alg: -7},
|
||||
},
|
||||
Timeout: int(mfaChallengeTTL / time.Millisecond),
|
||||
Attestation: "none",
|
||||
AuthenticatorSelection: webAuthnAuthenticatorSelection{UserVerification: "preferred"},
|
||||
ExcludeCredentials: exclude,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) FinishWebAuthnRegistration(ctx context.Context, subject string, input WebAuthnRegistrationFinishInput) (*UserOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, err := s.loadWebAuthnChallenge(user, "register")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := security.VerifyWebAuthnRegistration(input.Credential, state.Challenge, state.RPID, state.Origin)
|
||||
if err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WEBAUTHN_VERIFY_FAILED", "通行密钥注册校验失败", err)
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
|
||||
}
|
||||
for _, credential := range credentials {
|
||||
if credential.CredentialID == parsed.CredentialID {
|
||||
return nil, apperror.Conflict("AUTH_WEBAUTHN_EXISTS", "该通行密钥已注册", nil)
|
||||
}
|
||||
}
|
||||
id, err := randomURLToken(16)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法生成通行密钥编号", err)
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
name := strings.TrimSpace(input.Name)
|
||||
if name == "" {
|
||||
name = "通行密钥"
|
||||
}
|
||||
credentials = append(credentials, WebAuthnCredentialRecord{
|
||||
ID: id,
|
||||
Name: normalizeTrustedDeviceName(name),
|
||||
CredentialID: parsed.CredentialID,
|
||||
PublicKeyX: parsed.PublicKeyX,
|
||||
PublicKeyY: parsed.PublicKeyY,
|
||||
SignCount: parsed.SignCount,
|
||||
CreatedAt: now,
|
||||
})
|
||||
encoded, err := encodeWebAuthnCredentials(credentials)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法保存通行密钥", err)
|
||||
}
|
||||
user.WebAuthnCredentials = encoded
|
||||
user.WebAuthnChallengeCiphertext = ""
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法保存通行密钥", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "webauthn_register",
|
||||
TargetType: "webauthn_credential", TargetID: id, TargetName: name,
|
||||
Detail: "注册通行密钥",
|
||||
})
|
||||
}
|
||||
return ToUserOutput(user), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) BeginWebAuthnLogin(ctx context.Context, input WebAuthnLoginOptionsInput, request WebAuthnRequestContext, clientKey string) (*WebAuthnLoginOptions, error) {
|
||||
user, err := s.verifyPasswordForMFAStart(ctx, input.Username, input.Password, clientKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
|
||||
}
|
||||
if len(credentials) == 0 {
|
||||
return nil, apperror.BadRequest("AUTH_WEBAUTHN_NOT_ENABLED", "当前账号未注册通行密钥", nil)
|
||||
}
|
||||
challenge, err := security.GenerateWebAuthnChallenge()
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法生成通行密钥挑战", err)
|
||||
}
|
||||
state := webAuthnChallengeState{
|
||||
Type: "login",
|
||||
Challenge: challenge,
|
||||
RPID: request.RPID,
|
||||
Origin: request.Origin,
|
||||
ExpiresAt: time.Now().UTC().Add(mfaChallengeTTL),
|
||||
}
|
||||
if err := s.saveWebAuthnChallenge(ctx, user, state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowed := make([]webAuthnCredentialDescriptor, 0, len(credentials))
|
||||
for _, credential := range credentials {
|
||||
allowed = append(allowed, webAuthnCredentialDescriptor{Type: "public-key", ID: credential.CredentialID})
|
||||
}
|
||||
return &WebAuthnLoginOptions{
|
||||
Challenge: challenge,
|
||||
RPID: request.RPID,
|
||||
Timeout: int(mfaChallengeTTL / time.Millisecond),
|
||||
UserVerification: "preferred",
|
||||
AllowCredentials: allowed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) VerifyWebAuthnLogin(ctx context.Context, user *model.User, assertion security.WebAuthnLoginAssertion, clientKey string) error {
|
||||
state, err := s.loadWebAuthnChallenge(user, "login")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
|
||||
}
|
||||
rawID := strings.TrimSpace(assertion.RawID)
|
||||
if rawID == "" {
|
||||
rawID = strings.TrimSpace(assertion.ID)
|
||||
}
|
||||
for i := range credentials {
|
||||
credential := &credentials[i]
|
||||
if credential.CredentialID != rawID {
|
||||
continue
|
||||
}
|
||||
nextSignCount, err := security.VerifyWebAuthnAssertion(assertion, state.Challenge, state.RPID, state.Origin, security.WebAuthnCredentialMaterial{
|
||||
CredentialID: credential.CredentialID,
|
||||
PublicKeyX: credential.PublicKeyX,
|
||||
PublicKeyY: credential.PublicKeyY,
|
||||
SignCount: credential.SignCount,
|
||||
})
|
||||
if err != nil {
|
||||
return apperror.Unauthorized("AUTH_WEBAUTHN_INVALID", "通行密钥校验失败", err)
|
||||
}
|
||||
credential.SignCount = nextSignCount
|
||||
credential.LastUsedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
encoded, err := encodeWebAuthnCredentials(credentials)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err)
|
||||
}
|
||||
user.WebAuthnCredentials = encoded
|
||||
user.WebAuthnChallengeCiphertext = ""
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "webauthn_used",
|
||||
TargetType: "webauthn_credential", TargetID: credential.ID, TargetName: credential.Name,
|
||||
Detail: "使用通行密钥完成多因素验证", ClientIP: clientKey,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return apperror.Unauthorized("AUTH_WEBAUTHN_INVALID", "通行密钥不存在", nil)
|
||||
}
|
||||
|
||||
func (s *AuthService) ListWebAuthnCredentials(ctx context.Context, subject string) ([]WebAuthnCredentialOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
|
||||
}
|
||||
output := make([]WebAuthnCredentialOutput, 0, len(credentials))
|
||||
for _, credential := range credentials {
|
||||
output = append(output, toWebAuthnCredentialOutput(credential))
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) DeleteWebAuthnCredential(ctx context.Context, subject string, id string, input WebAuthnCredentialDeleteInput) (*UserOutput, error) {
|
||||
user, err := s.userBySubject(ctx, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.CurrentPassword); err != nil {
|
||||
return nil, apperror.BadRequest("AUTH_WRONG_PASSWORD", "当前密码不正确", err)
|
||||
}
|
||||
credentials, err := parseWebAuthnCredentials(user.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_INVALID", "通行密钥配置异常", err)
|
||||
}
|
||||
found := false
|
||||
filtered := make([]WebAuthnCredentialRecord, 0, len(credentials))
|
||||
for _, credential := range credentials {
|
||||
if credential.ID == strings.TrimSpace(id) {
|
||||
found = true
|
||||
} else {
|
||||
filtered = append(filtered, credential)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, apperror.New(404, "AUTH_WEBAUTHN_NOT_FOUND", "通行密钥不存在", nil)
|
||||
}
|
||||
encoded, err := encodeWebAuthnCredentials(filtered)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_SAVE_FAILED", "无法更新通行密钥", err)
|
||||
}
|
||||
user.WebAuthnCredentials = encoded
|
||||
clearTrustedDevicesIfMFAOff(user)
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_DELETE_FAILED", "无法删除通行密钥", err)
|
||||
}
|
||||
if s.auditService != nil {
|
||||
s.auditService.Record(AuditEntry{
|
||||
UserID: user.ID, Username: user.Username,
|
||||
Category: "auth", Action: "webauthn_delete",
|
||||
TargetType: "webauthn_credential", TargetID: strings.TrimSpace(id),
|
||||
Detail: "删除通行密钥",
|
||||
})
|
||||
}
|
||||
return ToUserOutput(user), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) saveWebAuthnChallenge(ctx context.Context, user *model.User, state webAuthnChallengeState) error {
|
||||
ciphertext, err := s.twoFactorCipher.EncryptJSON(state)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法保存通行密钥挑战", err)
|
||||
}
|
||||
user.WebAuthnChallengeCiphertext = ciphertext
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_FAILED", "无法保存通行密钥挑战", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) loadWebAuthnChallenge(user *model.User, challengeType string) (*webAuthnChallengeState, error) {
|
||||
if strings.TrimSpace(user.WebAuthnChallengeCiphertext) == "" {
|
||||
return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_MISSING", "请先发起通行密钥验证", nil)
|
||||
}
|
||||
var state webAuthnChallengeState
|
||||
if err := s.twoFactorCipher.DecryptJSON(user.WebAuthnChallengeCiphertext, &state); err != nil {
|
||||
return nil, apperror.Internal("AUTH_WEBAUTHN_CHALLENGE_INVALID", "通行密钥挑战状态异常", err)
|
||||
}
|
||||
if state.Type != challengeType {
|
||||
return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_INVALID", "通行密钥挑战类型不匹配", nil)
|
||||
}
|
||||
if state.ExpiresAt.Before(time.Now().UTC()) {
|
||||
return nil, apperror.BadRequest("AUTH_WEBAUTHN_CHALLENGE_EXPIRED", "通行密钥挑战已过期", nil)
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type NotificationUpsertInput struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=email webhook telegram"`
|
||||
Enabled bool `json:"enabled"`
|
||||
OnSuccess bool `json:"onSuccess"`
|
||||
OnFailure bool `json:"onFailure"`
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=email webhook telegram"`
|
||||
Enabled bool `json:"enabled"`
|
||||
OnSuccess bool `json:"onSuccess"`
|
||||
OnFailure bool `json:"onFailure"`
|
||||
// EventTypes 订阅的扩展事件列表。与 OnSuccess/OnFailure 并存:
|
||||
// - 两者均空时,订阅"备份成功/失败"对应原有语义(兼容)。
|
||||
// - EventTypes 显式指定时优先按清单匹配。
|
||||
@@ -186,8 +186,8 @@ func (s *NotificationService) NotifyBackupResult(ctx context.Context, event Back
|
||||
// - eventType 对应 model.NotificationEvent* 常量,用于订阅匹配
|
||||
//
|
||||
// 订阅匹配规则:
|
||||
// 1) notification.EventTypes 非空:必须包含 eventType
|
||||
// 2) notification.EventTypes 为空:沿用 OnSuccess/OnFailure 开关(仅 backup_* 事件)
|
||||
// 1. notification.EventTypes 非空:必须包含 eventType
|
||||
// 2. notification.EventTypes 为空:沿用 OnSuccess/OnFailure 开关(仅 backup_* 事件)
|
||||
func (s *NotificationService) DispatchEvent(ctx context.Context, eventType string, title string, body string, fields map[string]any) error {
|
||||
// 同步广播到 SSE 订阅者(前端 Dashboard 实时推送)。
|
||||
// 非阻塞:即便广播器未注入或订阅者已满也不影响 Notification 持久渠道。
|
||||
@@ -209,6 +209,49 @@ func (s *NotificationService) DispatchEvent(ctx context.Context, eventType strin
|
||||
return s.deliver(ctx, items, message)
|
||||
}
|
||||
|
||||
func (s *NotificationService) SendAuthEmailOTP(ctx context.Context, to string, code string) error {
|
||||
return s.sendFirstByType(ctx, "email", map[string]any{"to": strings.TrimSpace(to)}, notify.Message{
|
||||
Title: "BackupX 登录验证码",
|
||||
Body: fmt.Sprintf("您的 BackupX 登录验证码为:%s\n验证码 5 分钟内有效。若非本人操作,请立即检查账号安全。", code),
|
||||
Fields: map[string]any{
|
||||
"purpose": "login_otp",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *NotificationService) SendAuthSMSOTP(ctx context.Context, phone string, code string) error {
|
||||
return s.sendFirstByType(ctx, "webhook", nil, notify.Message{
|
||||
Title: "BackupX 登录验证码",
|
||||
Body: fmt.Sprintf("BackupX 登录验证码:%s,5 分钟内有效。", code),
|
||||
Fields: map[string]any{
|
||||
"phone": strings.TrimSpace(phone),
|
||||
"code": code,
|
||||
"purpose": "login_otp",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *NotificationService) sendFirstByType(ctx context.Context, notificationType string, override map[string]any, message notify.Message) error {
|
||||
items, err := s.notifications.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range items {
|
||||
if !item.Enabled || item.Type != notificationType {
|
||||
continue
|
||||
}
|
||||
configMap := map[string]any{}
|
||||
if err := s.cipher.DecryptJSON(item.ConfigCiphertext, &configMap); err != nil {
|
||||
return fmt.Errorf("decrypt notification %d config: %w", item.ID, err)
|
||||
}
|
||||
for key, value := range override {
|
||||
configMap[key] = value
|
||||
}
|
||||
return s.registry.Send(ctx, item.Type, configMap, message)
|
||||
}
|
||||
return fmt.Errorf("no enabled %s notification configured", notificationType)
|
||||
}
|
||||
|
||||
// collectSubscribers 按事件类型收集启用的订阅者。
|
||||
// 列出启用通知后按事件类型再过滤(避免引入新 repository 方法)。
|
||||
func (s *NotificationService) collectSubscribers(ctx context.Context, eventType string, fallbackSuccess bool) ([]model.Notification, error) {
|
||||
|
||||
@@ -22,13 +22,22 @@ func NewUserService(users repository.UserRepository) *UserService {
|
||||
|
||||
// UserSummary 用户列表项(不含密码哈希)。
|
||||
type UserSummary struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Disabled bool `json:"disabled"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
Role string `json:"role"`
|
||||
Disabled bool `json:"disabled"`
|
||||
MFAEnabled bool `json:"mfaEnabled"`
|
||||
TwoFactorEnabled bool `json:"twoFactorEnabled"`
|
||||
TwoFactorRecoveryCodesRemaining int `json:"twoFactorRecoveryCodesRemaining"`
|
||||
WebAuthnEnabled bool `json:"webAuthnEnabled"`
|
||||
WebAuthnCredentialCount int `json:"webAuthnCredentialCount"`
|
||||
TrustedDeviceCount int `json:"trustedDeviceCount"`
|
||||
EmailOTPEnabled bool `json:"emailOtpEnabled"`
|
||||
SMSOTPEnabled bool `json:"smsOtpEnabled"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}
|
||||
|
||||
// UserUpsertInput 创建/更新用户的输入。
|
||||
@@ -37,6 +46,7 @@ type UserUpsertInput struct {
|
||||
Password string `json:"password" binding:"omitempty,min=8,max=128"`
|
||||
DisplayName string `json:"displayName" binding:"required,min=1,max=128"`
|
||||
Email string `json:"email" binding:"omitempty,max=255"`
|
||||
Phone string `json:"phone" binding:"omitempty,max=64"`
|
||||
Role string `json:"role" binding:"required,oneof=admin operator viewer"`
|
||||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
@@ -76,6 +86,7 @@ func (s *UserService) Create(ctx context.Context, input UserUpsertInput) (*UserS
|
||||
PasswordHash: hash,
|
||||
DisplayName: strings.TrimSpace(input.DisplayName),
|
||||
Email: strings.TrimSpace(input.Email),
|
||||
Phone: strings.TrimSpace(input.Phone),
|
||||
Role: input.Role,
|
||||
Disabled: input.Disabled,
|
||||
}
|
||||
@@ -107,18 +118,43 @@ func (s *UserService) Update(ctx context.Context, id uint, input UserUpsertInput
|
||||
return nil, apperror.Conflict("USER_USERNAME_EXISTS", "用户名已存在", nil)
|
||||
}
|
||||
}
|
||||
passwordChanged := strings.TrimSpace(input.Password) != ""
|
||||
disabledChanged := input.Disabled && !existing.Disabled
|
||||
emailChanged := strings.TrimSpace(input.Email) != strings.TrimSpace(existing.Email)
|
||||
phoneChanged := strings.TrimSpace(input.Phone) != strings.TrimSpace(existing.Phone)
|
||||
existing.Username = strings.TrimSpace(input.Username)
|
||||
existing.DisplayName = strings.TrimSpace(input.DisplayName)
|
||||
existing.Email = strings.TrimSpace(input.Email)
|
||||
existing.Phone = strings.TrimSpace(input.Phone)
|
||||
existing.Role = input.Role
|
||||
existing.Disabled = input.Disabled
|
||||
if strings.TrimSpace(input.Password) != "" {
|
||||
if passwordChanged {
|
||||
hash, err := security.HashPassword(input.Password)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("USER_HASH_FAILED", "无法处理密码", err)
|
||||
}
|
||||
existing.PasswordHash = hash
|
||||
existing.TrustedDevices = ""
|
||||
existing.OutOfBandOTPCiphertext = ""
|
||||
existing.WebAuthnChallengeCiphertext = ""
|
||||
}
|
||||
if strings.TrimSpace(existing.Email) == "" && existing.EmailOTPEnabled {
|
||||
existing.EmailOTPEnabled = false
|
||||
existing.OutOfBandOTPCiphertext = ""
|
||||
}
|
||||
if strings.TrimSpace(existing.Phone) == "" && existing.SMSOTPEnabled {
|
||||
existing.SMSOTPEnabled = false
|
||||
existing.OutOfBandOTPCiphertext = ""
|
||||
}
|
||||
if emailChanged || phoneChanged {
|
||||
existing.OutOfBandOTPCiphertext = ""
|
||||
}
|
||||
if disabledChanged {
|
||||
existing.TrustedDevices = ""
|
||||
existing.OutOfBandOTPCiphertext = ""
|
||||
existing.WebAuthnChallengeCiphertext = ""
|
||||
}
|
||||
clearTrustedDevicesIfMFAOff(existing)
|
||||
if err := s.users.Update(ctx, existing); err != nil {
|
||||
return nil, apperror.Internal("USER_UPDATE_FAILED", "无法更新用户", err)
|
||||
}
|
||||
@@ -147,14 +183,47 @@ func (s *UserService) Delete(ctx context.Context, id uint) error {
|
||||
return s.users.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *UserService) ResetTwoFactor(ctx context.Context, id uint) (*UserSummary, error) {
|
||||
existing, err := s.users.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("USER_GET_FAILED", "无法获取用户", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil, apperror.New(404, "USER_NOT_FOUND", "用户不存在", nil)
|
||||
}
|
||||
existing.TwoFactorEnabled = false
|
||||
existing.TwoFactorSecretCiphertext = ""
|
||||
existing.TwoFactorRecoveryCodeHashes = ""
|
||||
existing.WebAuthnCredentials = ""
|
||||
existing.WebAuthnChallengeCiphertext = ""
|
||||
existing.TrustedDevices = ""
|
||||
existing.EmailOTPEnabled = false
|
||||
existing.SMSOTPEnabled = false
|
||||
existing.OutOfBandOTPCiphertext = ""
|
||||
if err := s.users.Update(ctx, existing); err != nil {
|
||||
return nil, apperror.Internal("USER_2FA_RESET_FAILED", "无法重置 MFA", err)
|
||||
}
|
||||
summary := toUserSummary(existing)
|
||||
return &summary, nil
|
||||
}
|
||||
|
||||
func toUserSummary(u *model.User) UserSummary {
|
||||
return UserSummary{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
DisplayName: u.DisplayName,
|
||||
Email: u.Email,
|
||||
Role: u.Role,
|
||||
Disabled: u.Disabled,
|
||||
CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
DisplayName: u.DisplayName,
|
||||
Email: u.Email,
|
||||
Phone: u.Phone,
|
||||
Role: u.Role,
|
||||
Disabled: u.Disabled,
|
||||
MFAEnabled: userMFAEnabled(u),
|
||||
TwoFactorEnabled: u.TwoFactorEnabled,
|
||||
TwoFactorRecoveryCodesRemaining: recoveryCodeRemainingCount(u),
|
||||
WebAuthnEnabled: webAuthnCredentialCount(u) > 0,
|
||||
WebAuthnCredentialCount: webAuthnCredentialCount(u),
|
||||
TrustedDeviceCount: trustedDeviceCount(u),
|
||||
EmailOTPEnabled: u.EmailOTPEnabled,
|
||||
SMSOTPEnabled: u.SMSOTPEnabled,
|
||||
CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
}
|
||||
|
||||
124
server/internal/service/user_service_test.go
Normal file
124
server/internal/service/user_service_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/security"
|
||||
)
|
||||
|
||||
func TestUserServiceUpdatePasswordClearsTrustedDeviceState(t *testing.T) {
|
||||
hash, err := security.HashPassword("old-password")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword: %v", err)
|
||||
}
|
||||
repo := &fakeUserRepository{users: []*model.User{{
|
||||
ID: 1,
|
||||
Username: "admin",
|
||||
PasswordHash: hash,
|
||||
DisplayName: "Admin",
|
||||
Email: "admin@example.com",
|
||||
Role: model.UserRoleAdmin,
|
||||
TwoFactorEnabled: true,
|
||||
TrustedDevices: `[{"id":"device"}]`,
|
||||
OutOfBandOTPCiphertext: "pending",
|
||||
WebAuthnChallengeCiphertext: "challenge",
|
||||
}}}
|
||||
svc := NewUserService(repo)
|
||||
|
||||
if _, err := svc.Update(context.Background(), 1, UserUpsertInput{
|
||||
Username: "admin",
|
||||
Password: "new-password",
|
||||
DisplayName: "Admin",
|
||||
Email: "admin@example.com",
|
||||
Role: model.UserRoleAdmin,
|
||||
}); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
updated := repo.users[0]
|
||||
if security.ComparePassword(updated.PasswordHash, "new-password") != nil {
|
||||
t.Fatalf("expected password hash to be updated")
|
||||
}
|
||||
if updated.TrustedDevices != "" || updated.OutOfBandOTPCiphertext != "" || updated.WebAuthnChallengeCiphertext != "" {
|
||||
t.Fatalf("expected password update to clear trusted device state, got trusted=%q otp=%q challenge=%q", updated.TrustedDevices, updated.OutOfBandOTPCiphertext, updated.WebAuthnChallengeCiphertext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserServiceUpdateContactClearsUnavailableOTP(t *testing.T) {
|
||||
hash, err := security.HashPassword("password-123")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword: %v", err)
|
||||
}
|
||||
repo := &fakeUserRepository{users: []*model.User{{
|
||||
ID: 1,
|
||||
Username: "admin",
|
||||
PasswordHash: hash,
|
||||
DisplayName: "Admin",
|
||||
Email: "admin@example.com",
|
||||
Phone: "+15550000000",
|
||||
Role: model.UserRoleAdmin,
|
||||
EmailOTPEnabled: true,
|
||||
SMSOTPEnabled: true,
|
||||
TrustedDevices: `[{"id":"device"}]`,
|
||||
OutOfBandOTPCiphertext: "pending",
|
||||
}}}
|
||||
svc := NewUserService(repo)
|
||||
|
||||
summary, err := svc.Update(context.Background(), 1, UserUpsertInput{
|
||||
Username: "admin",
|
||||
DisplayName: "Admin",
|
||||
Role: model.UserRoleAdmin,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
updated := repo.users[0]
|
||||
if updated.EmailOTPEnabled || updated.SMSOTPEnabled || summary.MFAEnabled {
|
||||
t.Fatalf("expected unavailable OTP channels to be disabled")
|
||||
}
|
||||
if updated.TrustedDevices != "" || updated.OutOfBandOTPCiphertext != "" || updated.WebAuthnChallengeCiphertext != "" {
|
||||
t.Fatalf("expected last MFA removal to clear temporary state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserServiceUpdateContactChangeClearsPendingOTP(t *testing.T) {
|
||||
hash, err := security.HashPassword("password-123")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword: %v", err)
|
||||
}
|
||||
repo := &fakeUserRepository{users: []*model.User{{
|
||||
ID: 1,
|
||||
Username: "admin",
|
||||
PasswordHash: hash,
|
||||
DisplayName: "Admin",
|
||||
Email: "old@example.com",
|
||||
Role: model.UserRoleAdmin,
|
||||
EmailOTPEnabled: true,
|
||||
OutOfBandOTPCiphertext: "pending",
|
||||
}}}
|
||||
svc := NewUserService(repo)
|
||||
|
||||
summary, err := svc.Update(context.Background(), 1, UserUpsertInput{
|
||||
Username: "admin",
|
||||
DisplayName: "Admin",
|
||||
Email: "new@example.com",
|
||||
Role: model.UserRoleAdmin,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
updated := repo.users[0]
|
||||
if updated.Email != "new@example.com" || summary.Email != "new@example.com" {
|
||||
t.Fatalf("expected email to be updated")
|
||||
}
|
||||
if !updated.EmailOTPEnabled {
|
||||
t.Fatalf("expected email OTP to remain enabled")
|
||||
}
|
||||
if updated.OutOfBandOTPCiphertext != "" {
|
||||
t.Fatalf("expected contact change to clear pending OTP")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user