Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/api/oauthserver/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon

// Issue the refresh token and access token
var terr error
tokenResponse, terr = tokenService.IssueRefreshToken(r, tx, user, authMethod, grantParams)
tokenResponse, terr = tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -488,7 +488,7 @@ func (s *Server) handleRefreshTokenGrant(ctx context.Context, w http.ResponseWri
}

db := s.db.WithContext(ctx)
tokenResponse, err := tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{
tokenResponse, err := tokenService.RefreshTokenGrant(ctx, db, r, w.Header(), tokens.RefreshTokenGrantParams{
RefreshToken: params.RefreshToken,
ClientID: clientID,
})
Expand Down
6 changes: 3 additions & 3 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
}); terr != nil {
return terr
}
token, terr = a.tokenService.IssueRefreshToken(r, tx, user, models.PasswordGrant, grantParams)
token, terr = a.tokenService.IssueRefreshToken(r, w.Header(), tx, user, models.PasswordGrant, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -260,7 +260,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
}); terr != nil {
return terr
}
token, terr = a.tokenService.IssueRefreshToken(r, tx, user, authMethod, grantParams)
token, terr = a.tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams)
if terr != nil {
// error type is already handled in issueRefreshToken
return terr
Expand Down Expand Up @@ -295,7 +295,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
}

func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*tokens.AccessTokenResponse, error) {
return a.tokenService.IssueRefreshToken(r, conn, user, authenticationMethod, grantParams)
return a.tokenService.IssueRefreshToken(r, make(http.Header), conn, user, authenticationMethod, grantParams)
}

func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*tokens.AccessTokenResponse, error) {
Expand Down
32 changes: 31 additions & 1 deletion internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package api
import (
"context"
"net/http"
"regexp"

"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/tokens"
)

Expand All @@ -12,15 +15,42 @@ type RefreshTokenGrantParams struct {
RefreshToken string `json:"refresh_token"`
}

var legacyRefreshTokenPattern = regexp.MustCompile("^[a-z0-9]{12}$")

func (p *RefreshTokenGrantParams) Validate() error {
if len(p.RefreshToken) < 12 {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid")
}

if len(p.RefreshToken) == 12 {
if !legacyRefreshTokenPattern.MatchString(p.RefreshToken) {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid")
}

return nil
}

_, err := crypto.ParseRefreshToken(p.RefreshToken)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid").WithInternalError(err)
}

return nil
}

// RefreshTokenGrant implements the refresh_token grant type flow
func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
params := &RefreshTokenGrantParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err := params.Validate(); err != nil {
return err
}

db := a.db.WithContext(ctx)
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, w.Header(), tokens.RefreshTokenGrantParams{
RefreshToken: params.RefreshToken,
})
if err != nil {
Expand Down
31 changes: 29 additions & 2 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/models"
)

Expand Down Expand Up @@ -435,9 +436,11 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {

// ensure that the 4 refresh tokens are setup correctly
for i, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
require.NoError(ts.T(), err)

token := anyToken.(*models.RefreshToken)

if i == len(refreshTokens)-1 {
require.False(ts.T(), token.Revoked)
} else {
Expand Down Expand Up @@ -470,9 +473,10 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {

// ensure that the refresh tokens are marked as revoked in the database
for _, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
_, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false)
require.NoError(ts.T(), err)

token := anyToken.(*models.RefreshToken)
require.True(ts.T(), token.Revoked)
}

Expand Down Expand Up @@ -887,3 +891,26 @@ $$;`
})
}
}

func TestRefreshTokenGrantParamsValidate(t *testing.T) {
examples := []string{
"",
"01234567890",
"AAAAAAAAAAAA",
"------------",
"0000000000000",
}

p := &RefreshTokenGrantParams{}

for _, example := range examples {
p.RefreshToken = example
require.Error(t, p.Validate())
}

p.RefreshToken = "0123456abcde"
require.NoError(t, p.Validate())

p.RefreshToken = (&crypto.RefreshToken{}).Encode(make([]byte, 32))
require.NoError(t, p.Validate())
}
2 changes: 2 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,10 @@ func (c *DatabaseEncryptionConfiguration) Validate() error {

type SecurityConfiguration struct {
Captcha CaptchaConfiguration `json:"captcha"`
RefreshTokenAlgorithmVersion int `json:"refresh_token_algorithm_version" split_words:"true"`
RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"`
RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"`
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`

Expand Down
1 change: 1 addition & 0 deletions internal/crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func GenerateOtp(digits int) string {

return otp
}

func GenerateTokenHash(emailOrPhone, otp string) string {
return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp)))
}
Expand Down
1 change: 1 addition & 0 deletions internal/crypto/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,5 @@ func TestEncryptedStringDecryptNegative(t *testing.T) {

func TestSecureToken(t *testing.T) {
assert.Equal(t, len(SecureAlphanumeric(22)), 22)
assert.Equal(t, len(SecureAlphanumeric(7)), 8)
}
155 changes: 155 additions & 0 deletions internal/crypto/refresh_tokens.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package crypto

import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"errors"
"math"

"github.com/gofrs/uuid"
)

func GenerateRefreshTokenHmacKey() []byte {
key := make([]byte, 32)
must(rand.Read(key))

return key
}

const refreshTokenChecksumLength = 4
const refreshTokenSignatureLength = 16
const minRefreshTokenLength = 1 + 16 + 1 + refreshTokenSignatureLength + refreshTokenChecksumLength
const maxRefreshTokenLength = minRefreshTokenLength + 8

// RefreshToken is an object that encodes a cryptographically authenticated
// (signed) message containing a version, session ID and monotonically
// increasing non-negative counter.
//
// The signature is a truncated (first 128 bits) of HMAC-SHA-256, which saves
// on encoded length without sacrificing security. The checksum of 4 bytes at
// the end is to lessen the load on the server with invalid strings (those that
// are not likely to be a proper refresh token).
type RefreshToken struct {
Raw []byte

Version byte
SessionID uuid.UUID
Counter int64
Signature []byte
}

func (RefreshToken) TableName() string {
panic("crypto.RefreshToken is not meant to be saved in the database")
}

func (r *RefreshToken) CheckSignature(hmacSha256Key []byte) bool {
bytes := r.Raw[:len(r.Raw)-refreshTokenSignatureLength-refreshTokenChecksumLength]

h := hmac.New(sha256.New, hmacSha256Key)
h.Write(bytes)
signature := h.Sum(nil)[:refreshTokenSignatureLength]

return hmac.Equal(signature, r.Signature)
}

func (r *RefreshToken) Encode(hmacSha256Key []byte) string {
result := make([]byte, 0, maxRefreshTokenLength)

result = append(result, 0)
result = append(result, r.SessionID.Bytes()...)
result = binary.AppendUvarint(result, safeUint64(r.Counter))

// Note on truncating the HMAC-SHA-256 output:
// This does not impact security as the brute-force space is 2^128 and
// the collision space is 2^64, both unattainable in practice.

h := hmac.New(sha256.New, hmacSha256Key)
h.Write(result)
signature := h.Sum(nil)[:refreshTokenSignatureLength]

result = append(result, signature...)

checksum := sha256.Sum256(result)
result = append(result, checksum[:refreshTokenChecksumLength]...)

r.Version = 0
r.Raw = result
r.Signature = signature

return base64.RawURLEncoding.EncodeToString(result)
}

var (
ErrRefreshTokenLength = errors.New("crypto: refresh token length is not valid")
ErrRefreshTokenUnknownVersion = errors.New("crypto: refresh token version is not 0")
ErrRefreshTokenChecksumInvalid = errors.New("crypto: refresh token checksum is not valid")
ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid")
)

func safeInt64(v uint64) int64 {
if v > math.MaxInt64 {
return math.MaxInt64
}

return int64(v)
}

func safeUint64(v int64) uint64 {
if v < 0 {
return 0
}

return uint64(v)
}

func ParseRefreshToken(token string) (*RefreshToken, error) {
bytes, err := base64.RawURLEncoding.DecodeString(token)
if err != nil {
return nil, err
}

if len(bytes) < minRefreshTokenLength {
return nil, ErrRefreshTokenLength
}

if bytes[0] != 0 {
return nil, ErrRefreshTokenUnknownVersion
}

parseFrom := bytes[1 : len(bytes)-refreshTokenChecksumLength]

checksum256 := sha256.Sum256(bytes[:len(bytes)-refreshTokenChecksumLength])
if subtle.ConstantTimeCompare(checksum256[:refreshTokenChecksumLength], bytes[len(bytes)-refreshTokenChecksumLength:]) != 1 {
return nil, ErrRefreshTokenChecksumInvalid
}

sessionID := uuid.FromBytesOrNil(parseFrom[0:16])

parseFrom = parseFrom[16:]

counter, counterBytes := binary.Uvarint(parseFrom)
if counterBytes <= 0 {
return nil, ErrRefreshTokenCounterInvalid
}

parseFrom = parseFrom[counterBytes:]

if len(parseFrom) != 16 {
return nil, ErrRefreshTokenLength
}

signature := parseFrom

return &RefreshToken{
Raw: bytes,

Version: 0,
SessionID: sessionID,
Counter: safeInt64(counter),
Signature: signature,
}, nil
}
Loading