Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,13 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
}
}

var url string
var url, otp string
now := time.Now()
otp := crypto.GenerateOtp(config.Mailer.OtpLength)
if config.Mailer.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(config.Mailer.OtpLength)
} else {
otp = crypto.GenerateOtp(config.Mailer.OtpLength)
}

hashedToken := crypto.GenerateTokenHash(params.Email, otp)

Expand Down Expand Up @@ -326,7 +330,12 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model
return err
}
oldToken := u.ConfirmationToken
otp := crypto.GenerateOtp(otpLength)
var otp string
if config.Mailer.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(otpLength)
} else {
otp = crypto.GenerateOtp(otpLength)
}

token := crypto.GenerateTokenHash(u.GetEmail(), otp)
u.ConfirmationToken = addFlowPrefixToToken(token, flowType)
Expand Down Expand Up @@ -361,7 +370,12 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User
otpLength := config.Mailer.OtpLength
var err error
oldToken := u.ConfirmationToken
otp := crypto.GenerateOtp(otpLength)
var otp string
if config.Mailer.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(otpLength)
} else {
otp = crypto.GenerateOtp(otpLength)
}

u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp)
now := time.Now()
Expand Down Expand Up @@ -403,7 +417,12 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m
}

oldToken := u.RecoveryToken
otp := crypto.GenerateOtp(otpLength)
var otp string
if config.Mailer.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(otpLength)
} else {
otp = crypto.GenerateOtp(otpLength)
}

token := crypto.GenerateTokenHash(u.GetEmail(), otp)
u.RecoveryToken = addFlowPrefixToToken(token, flowType)
Expand Down Expand Up @@ -445,7 +464,12 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u
}

oldToken := u.ReauthenticationToken
otp := crypto.GenerateOtp(otpLength)
var otp string
if config.Mailer.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(otpLength)
} else {
otp = crypto.GenerateOtp(otpLength)
}

u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp)
now := time.Now()
Expand Down Expand Up @@ -488,7 +512,12 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U
}

oldToken := u.RecoveryToken
otp := crypto.GenerateOtp(otpLength)
var otp string
if config.Mailer.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(otpLength)
} else {
otp = crypto.GenerateOtp(otpLength)
}

token := crypto.GenerateTokenHash(u.GetEmail(), otp)
u.RecoveryToken = addFlowPrefixToToken(token, flowType)
Expand Down Expand Up @@ -528,7 +557,12 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models
return err
}

otpNew := crypto.GenerateOtp(otpLength)
var otpNew string
if config.Mailer.OtpAlphaNumeric {
otpNew = crypto.GenerateAlphanumericOtp(otpLength)
} else {
otpNew = crypto.GenerateOtp(otpLength)
}

u.EmailChange = email
token := crypto.GenerateTokenHash(u.EmailChange, otpNew)
Expand Down
8 changes: 6 additions & 2 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,12 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "%s", generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency))
}
}

otp := crypto.GenerateOtp(config.MFA.Phone.OtpLength)
var otp string
if config.MFA.Phone.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(config.MFA.Phone.OtpLength)
} else {
otp = crypto.GenerateOtp(config.MFA.Phone.OtpLength)
}

challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
return "", apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
}
otp = crypto.GenerateOtp(config.Sms.OtpLength)
if config.Sms.OtpAlphaNumeric {
otp = crypto.GenerateAlphanumericOtp(config.Sms.OtpLength)
} else {
otp = crypto.GenerateOtp(config.Sms.OtpLength)
}

if config.Hook.SendSMS.Enabled {
input := v0hooks.NewSendSMSInput(
Expand Down
15 changes: 9 additions & 6 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,11 @@ type TOTPFactorTypeConfiguration struct {
type PhoneFactorTypeConfiguration struct {
// Default to false in order to ensure Phone MFA is opt-in
MFAFactorTypeConfiguration
OtpLength int `json:"otp_length" split_words:"true"`
SMSTemplate *template.Template `json:"-"`
MaxFrequency time.Duration `json:"max_frequency" split_words:"true"`
Template string `json:"template"`
OtpLength int `json:"otp_length" split_words:"true"`
OtpAlphaNumeric bool `json:"otp_alpha_numeric" split_words:"true"`
SMSTemplate *template.Template `json:"-"`
MaxFrequency time.Duration `json:"max_frequency" split_words:"true"`
Template string `json:"template"`
}

// MFAConfiguration holds all the MFA related Configuration
Expand Down Expand Up @@ -572,8 +573,9 @@ type MailerConfiguration struct {

SecureEmailChangeEnabled bool `json:"secure_email_change_enabled" split_words:"true" default:"true"`

OtpExp uint `json:"otp_exp" split_words:"true"`
OtpLength int `json:"otp_length" split_words:"true"`
OtpExp uint `json:"otp_exp" split_words:"true"`
OtpLength int `json:"otp_length" split_words:"true"`
OtpAlphaNumeric bool `json:"otp_alpha_numeric" split_words:"true"`

ExternalHosts []string `json:"external_hosts" split_words:"true"`

Expand Down Expand Up @@ -657,6 +659,7 @@ type SmsProviderConfiguration struct {
MaxFrequency time.Duration `json:"max_frequency" split_words:"true"`
OtpExp uint `json:"otp_exp" split_words:"true"`
OtpLength int `json:"otp_length" split_words:"true"`
OtpAlphaNumeric bool `json:"otp_alpha_numeric" split_words:"true"`
Provider string `json:"provider"`
Template string `json:"template"`
TestOTP map[string]string `json:"test_otp" split_words:"true"`
Expand Down
17 changes: 17 additions & 0 deletions internal/crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ func GenerateOtp(digits int) string {
return generateOtp(rand.Reader, digits)
}

// GenerateAlphanumericOtp generates a random n digit otp with extended charset
func GenerateAlphanumericOtp(digits int) string {
return generateAlphanumericOtp(rand.Reader, digits)
}

func generateOtp(r io.Reader, digits int) string {
// TODO(cstockton): Change the code to be below and propagate errors so we
// can have non-panicing bounds checks. This is just a defensive change so
Expand All @@ -42,6 +47,18 @@ func generateOtp(r io.Reader, digits int) string {
return otp
}

func generateAlphanumericOtp(r io.Reader, length int) string {
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
charsetLen := big.NewInt(int64(len(charset)))

otp := make([]byte, length)
for i := range otp {
val := must(rand.Int(r, charsetLen))
otp[i] = charset[val.Int64()]
}
return string(otp)
}

func GenerateTokenHash(emailOrPhone, otp string) string {
return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp)))
}
Expand Down
79 changes: 79 additions & 0 deletions internal/crypto/crypto_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package crypto

import (
"crypto/rand"
"math"
"strings"
"testing"

mrand "math/rand"
mathrand "math/rand/v2"

"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -77,6 +81,81 @@ func TestGenerateOtp(t *testing.T) {
}
}

func TestGenerateAlphanumericOtp(t *testing.T) {
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
charsetSet := make(map[rune]bool)
for _, c := range charset {
charsetSet[c] = true
}

t.Run("correct length", func(t *testing.T) {
for _, length := range []int{1, 4, 6, 8, 12} {
otp := generateAlphanumericOtp(rand.Reader, length)
if len(otp) != length {
t.Errorf("length=%d: got OTP of length %d: %q", length, len(otp), otp)
}
}
})

t.Run("only valid characters", func(t *testing.T) {
for range 100 {
otp := generateAlphanumericOtp(rand.Reader, 10)
for _, c := range otp {
if !charsetSet[c] {
t.Errorf("invalid character %q in OTP %q", c, otp)
}
}
}
})

t.Run("uppercase only", func(t *testing.T) {
for range 100 {
otp := generateAlphanumericOtp(rand.Reader, 10)
if otp != strings.ToUpper(otp) {
t.Errorf("OTP contains lowercase characters: %q", otp)
}
}
})

t.Run("deterministic with fixed reader", func(t *testing.T) {
seed := [32]byte{}
r1 := mathrand.NewChaCha8(seed)
r2 := mathrand.NewChaCha8(seed)
otp1 := generateAlphanumericOtp(r1, 8)
otp2 := generateAlphanumericOtp(r2, 8)
if otp1 != otp2 {
t.Errorf("same seed produced different OTPs: %q vs %q", otp1, otp2)
}
})

t.Run("different seeds produce different OTPs", func(t *testing.T) {
r1 := mathrand.NewChaCha8([32]byte{0})
r2 := mathrand.NewChaCha8([32]byte{1})
otp1 := generateAlphanumericOtp(r1, 16)
otp2 := generateAlphanumericOtp(r2, 16)
if otp1 == otp2 {
t.Errorf("different seeds produced the same OTP: %q", otp1)
}
})

t.Run("character distribution is roughly uniform", func(t *testing.T) {
counts := make(map[rune]int)
iterations := 36 * 1000
for range iterations {
otp := generateAlphanumericOtp(rand.Reader, 1)
counts[rune(otp[0])]++
}
expected := float64(iterations) / float64(len(charset))
tolerance := expected * 0.15
for _, c := range charset {
diff := math.Abs(float64(counts[c]) - expected)
if diff > tolerance {
t.Errorf("character %q count %d deviates too far from expected %.0f", c, counts[c], expected)
}
}
})
}

func TestEncryptedStringPositive(t *testing.T) {
id := uuid.Must(uuid.NewV4()).String()

Expand Down
Loading