Files
BackupX/server/internal/security/webauthn.go
Wu Qing 63fde903d2 feat: add complete MFA support
Add complete MFA support with TOTP, recovery codes, WebAuthn, trusted-device cookie flow, and email/SMS OTP delivery via notification channels. Security follow-up: trusted device tokens are stored in HttpOnly cookies, and SMS OTP reuses the existing Webhook notifier to avoid introducing a new dynamic URL sink.
2026-04-25 22:14:50 +08:00

448 lines
12 KiB
Go

package security
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math/big"
"strings"
)
const (
WebAuthnChallengeBytes = 32
)
type WebAuthnCredentialMaterial struct {
CredentialID string
PublicKeyX string
PublicKeyY string
SignCount uint32
}
type WebAuthnParsedCredential struct {
CredentialID string
PublicKeyX string
PublicKeyY string
SignCount uint32
}
type WebAuthnClientData struct {
Type string `json:"type"`
Challenge string `json:"challenge"`
Origin string `json:"origin"`
}
type WebAuthnAttestationResponse struct {
ClientDataJSON string `json:"clientDataJSON"`
AttestationObject string `json:"attestationObject"`
}
type WebAuthnRegistrationResponse struct {
ID string `json:"id"`
RawID string `json:"rawId"`
Type string `json:"type"`
Response WebAuthnAttestationResponse `json:"response"`
}
type WebAuthnAssertionResponse struct {
ClientDataJSON string `json:"clientDataJSON"`
AuthenticatorData string `json:"authenticatorData"`
Signature string `json:"signature"`
UserHandle string `json:"userHandle,omitempty"`
}
type WebAuthnLoginAssertion struct {
ID string `json:"id"`
RawID string `json:"rawId"`
Type string `json:"type"`
Response WebAuthnAssertionResponse `json:"response"`
}
func GenerateWebAuthnChallenge() (string, error) {
buf := make([]byte, WebAuthnChallengeBytes)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return EncodeBase64URL(buf), nil
}
func EncodeBase64URL(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
func DecodeBase64URL(value string) ([]byte, error) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return nil, errors.New("empty base64url value")
}
if decoded, err := base64.RawURLEncoding.DecodeString(trimmed); err == nil {
return decoded, nil
}
return base64.URLEncoding.DecodeString(trimmed)
}
func VerifyWebAuthnRegistration(input WebAuthnRegistrationResponse, challenge string, rpID string, expectedOrigin string) (*WebAuthnParsedCredential, error) {
if input.Type != "public-key" {
return nil, fmt.Errorf("unexpected credential type: %s", input.Type)
}
clientDataRaw, err := DecodeBase64URL(input.Response.ClientDataJSON)
if err != nil {
return nil, fmt.Errorf("decode client data: %w", err)
}
if err := validateWebAuthnClientData(clientDataRaw, "webauthn.create", challenge, expectedOrigin); err != nil {
return nil, err
}
attestationObject, err := DecodeBase64URL(input.Response.AttestationObject)
if err != nil {
return nil, fmt.Errorf("decode attestation object: %w", err)
}
parsed, err := parseCBORExact(attestationObject)
if err != nil {
return nil, fmt.Errorf("parse attestation object: %w", err)
}
attestationMap, ok := parsed.(map[any]any)
if !ok {
return nil, errors.New("attestation object is not a map")
}
authData, ok := attestationMap["authData"].([]byte)
if !ok {
return nil, errors.New("attestation authData is missing")
}
credential, err := parseAttestedCredentialData(authData, rpID)
if err != nil {
return nil, err
}
rawID := strings.TrimSpace(input.RawID)
if rawID == "" {
rawID = strings.TrimSpace(input.ID)
}
if rawID != "" && rawID != credential.CredentialID {
return nil, errors.New("credential raw id does not match attested credential id")
}
return credential, nil
}
func VerifyWebAuthnAssertion(input WebAuthnLoginAssertion, challenge string, rpID string, expectedOrigin string, credential WebAuthnCredentialMaterial) (uint32, error) {
if input.Type != "public-key" {
return 0, fmt.Errorf("unexpected credential type: %s", input.Type)
}
rawID := strings.TrimSpace(input.RawID)
if rawID == "" {
rawID = strings.TrimSpace(input.ID)
}
if rawID != credential.CredentialID {
return 0, errors.New("credential id does not match")
}
clientDataRaw, err := DecodeBase64URL(input.Response.ClientDataJSON)
if err != nil {
return 0, fmt.Errorf("decode client data: %w", err)
}
if err := validateWebAuthnClientData(clientDataRaw, "webauthn.get", challenge, expectedOrigin); err != nil {
return 0, err
}
authData, err := DecodeBase64URL(input.Response.AuthenticatorData)
if err != nil {
return 0, fmt.Errorf("decode authenticator data: %w", err)
}
signature, err := DecodeBase64URL(input.Response.Signature)
if err != nil {
return 0, fmt.Errorf("decode signature: %w", err)
}
signCount, err := parseAssertionAuthenticatorData(authData, rpID, credential.SignCount)
if err != nil {
return 0, err
}
xBytes, err := DecodeBase64URL(credential.PublicKeyX)
if err != nil {
return 0, fmt.Errorf("decode public key x: %w", err)
}
yBytes, err := DecodeBase64URL(credential.PublicKeyY)
if err != nil {
return 0, fmt.Errorf("decode public key y: %w", err)
}
publicKey := ecdsa.PublicKey{Curve: elliptic.P256(), X: new(big.Int).SetBytes(xBytes), Y: new(big.Int).SetBytes(yBytes)}
if !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
return 0, errors.New("webauthn public key is not on P-256 curve")
}
clientDataHash := sha256.Sum256(clientDataRaw)
verifyData := append(append([]byte{}, authData...), clientDataHash[:]...)
digest := sha256.Sum256(verifyData)
if !ecdsa.VerifyASN1(&publicKey, digest[:], signature) {
return 0, errors.New("invalid webauthn signature")
}
return signCount, nil
}
func validateWebAuthnClientData(raw []byte, expectedType string, challenge string, expectedOrigin string) error {
var clientData WebAuthnClientData
if err := json.Unmarshal(raw, &clientData); err != nil {
return fmt.Errorf("parse client data: %w", err)
}
if clientData.Type != expectedType {
return fmt.Errorf("unexpected webauthn client data type: %s", clientData.Type)
}
if clientData.Challenge != challenge {
return errors.New("webauthn challenge mismatch")
}
if expectedOrigin != "" && clientData.Origin != expectedOrigin {
return fmt.Errorf("webauthn origin mismatch: %s", clientData.Origin)
}
return nil
}
func parseAttestedCredentialData(authData []byte, rpID string) (*WebAuthnParsedCredential, error) {
signCount, credentialData, err := parseAuthenticatorDataHeader(authData, rpID, true, 0)
if err != nil {
return nil, err
}
if len(credentialData) < 18 {
return nil, errors.New("attested credential data is too short")
}
offset := 16
credentialIDLength := int(binary.BigEndian.Uint16(credentialData[offset : offset+2]))
offset += 2
if credentialIDLength <= 0 || len(credentialData) < offset+credentialIDLength {
return nil, errors.New("invalid credential id length")
}
credentialID := credentialData[offset : offset+credentialIDLength]
offset += credentialIDLength
publicKeyRaw := credentialData[offset:]
publicKey, err := parseCBOR(publicKeyRaw)
if err != nil {
return nil, fmt.Errorf("parse credential public key: %w", err)
}
publicKeyMap, ok := publicKey.(map[any]any)
if !ok {
return nil, errors.New("credential public key is not a map")
}
kty, err := coseInt(publicKeyMap, 1)
if err != nil {
return nil, err
}
alg, err := coseInt(publicKeyMap, 3)
if err != nil {
return nil, err
}
crv, err := coseInt(publicKeyMap, -1)
if err != nil {
return nil, err
}
if kty != 2 || alg != -7 || crv != 1 {
return nil, fmt.Errorf("unsupported COSE key: kty=%d alg=%d crv=%d", kty, alg, crv)
}
x, err := coseBytes(publicKeyMap, -2)
if err != nil {
return nil, err
}
y, err := coseBytes(publicKeyMap, -3)
if err != nil {
return nil, err
}
if !elliptic.P256().IsOnCurve(new(big.Int).SetBytes(x), new(big.Int).SetBytes(y)) {
return nil, errors.New("credential public key is not on P-256 curve")
}
return &WebAuthnParsedCredential{
CredentialID: EncodeBase64URL(credentialID),
PublicKeyX: EncodeBase64URL(x),
PublicKeyY: EncodeBase64URL(y),
SignCount: signCount,
}, nil
}
func parseAssertionAuthenticatorData(authData []byte, rpID string, previousSignCount uint32) (uint32, error) {
signCount, _, err := parseAuthenticatorDataHeader(authData, rpID, false, previousSignCount)
if err != nil {
return 0, err
}
return signCount, nil
}
func parseAuthenticatorDataHeader(authData []byte, rpID string, requireAttestedData bool, previousSignCount uint32) (uint32, []byte, error) {
if len(authData) < 37 {
return 0, nil, errors.New("authenticator data is too short")
}
expectedRPIDHash := sha256.Sum256([]byte(rpID))
if string(authData[:32]) != string(expectedRPIDHash[:]) {
return 0, nil, errors.New("rp id hash mismatch")
}
flags := authData[32]
if flags&0x01 == 0 {
return 0, nil, errors.New("user presence flag is missing")
}
signCount := binary.BigEndian.Uint32(authData[33:37])
if previousSignCount > 0 && signCount > 0 && signCount <= previousSignCount {
return 0, nil, errors.New("authenticator sign count did not increase")
}
if requireAttestedData && flags&0x40 == 0 {
return 0, nil, errors.New("attested credential data flag is missing")
}
return signCount, authData[37:], nil
}
func coseInt(m map[any]any, key int64) (int64, error) {
value, ok := m[key]
if !ok {
return 0, fmt.Errorf("missing COSE key %d", key)
}
intValue, ok := value.(int64)
if !ok {
return 0, fmt.Errorf("invalid COSE key %d", key)
}
return intValue, nil
}
func coseBytes(m map[any]any, key int64) ([]byte, error) {
value, ok := m[key]
if !ok {
return nil, fmt.Errorf("missing COSE key %d", key)
}
bytesValue, ok := value.([]byte)
if !ok || len(bytesValue) == 0 {
return nil, fmt.Errorf("invalid COSE key %d", key)
}
return bytesValue, nil
}
func parseCBOR(data []byte) (any, error) {
reader := cborReader{data: data}
value, err := reader.read()
if err != nil {
return nil, err
}
return value, nil
}
func parseCBORExact(data []byte) (any, error) {
reader := cborReader{data: data}
value, err := reader.read()
if err != nil {
return nil, err
}
if reader.pos != len(data) {
return nil, errors.New("trailing cbor data")
}
return value, nil
}
type cborReader struct {
data []byte
pos int
}
func (r *cborReader) read() (any, error) {
if r.pos >= len(r.data) {
return nil, errors.New("unexpected cbor eof")
}
initial := r.data[r.pos]
r.pos++
major := initial >> 5
additional := initial & 0x1f
length, err := r.readLength(additional)
if err != nil {
return nil, err
}
switch major {
case 0:
return int64(length), nil
case 1:
return -1 - int64(length), nil
case 2:
return r.readBytes(length)
case 3:
raw, err := r.readBytes(length)
if err != nil {
return nil, err
}
return string(raw), nil
case 4:
out := make([]any, 0, length)
for i := uint64(0); i < length; i++ {
item, err := r.read()
if err != nil {
return nil, err
}
out = append(out, item)
}
return out, nil
case 5:
out := make(map[any]any, length)
for i := uint64(0); i < length; i++ {
key, err := r.read()
if err != nil {
return nil, err
}
value, err := r.read()
if err != nil {
return nil, err
}
out[key] = value
}
return out, nil
case 7:
switch additional {
case 20:
return false, nil
case 21:
return true, nil
case 22, 23:
return nil, nil
default:
return nil, fmt.Errorf("unsupported cbor simple value: %d", additional)
}
default:
return nil, fmt.Errorf("unsupported cbor major type: %d", major)
}
}
func (r *cborReader) readLength(additional byte) (uint64, error) {
switch {
case additional < 24:
return uint64(additional), nil
case additional == 24:
if r.pos+1 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := r.data[r.pos]
r.pos++
return uint64(value), nil
case additional == 25:
if r.pos+2 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := binary.BigEndian.Uint16(r.data[r.pos : r.pos+2])
r.pos += 2
return uint64(value), nil
case additional == 26:
if r.pos+4 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := binary.BigEndian.Uint32(r.data[r.pos : r.pos+4])
r.pos += 4
return uint64(value), nil
case additional == 27:
if r.pos+8 > len(r.data) {
return 0, errors.New("unexpected cbor eof")
}
value := binary.BigEndian.Uint64(r.data[r.pos : r.pos+8])
r.pos += 8
return value, nil
default:
return 0, fmt.Errorf("unsupported cbor additional info: %d", additional)
}
}
func (r *cborReader) readBytes(length uint64) ([]byte, error) {
if length > uint64(len(r.data)-r.pos) {
return nil, errors.New("unexpected cbor eof")
}
out := r.data[r.pos : r.pos+int(length)]
r.pos += int(length)
return out, nil
}