Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
if err != nil {
return apierrors.NewInternalServerError("Database error creating anonymous user").WithInternalError(err)
}
if err := a.triggerAfterUserCreated(r, db, newUser); err != nil {
return err
}

metering.RecordLogin(metering.LoginTypeAnonymous, newUser.ID, nil)
return sendJSON(w, http.StatusOK, token)
Expand Down
13 changes: 13 additions & 0 deletions internal/api/apitask/apitask.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ type Task interface {
Run(context.Context) error
}

type taskFunc struct {
typ string
fn func(context.Context) error
}

func (o *taskFunc) Type() string { return o.typ }

func (o *taskFunc) Run(ctx context.Context) error { return o.fn(ctx) }

func Func(typ string, fn func(context.Context) error) Task {
return &taskFunc{typ: typ, fn: fn}
}

// Run will run a request-scoped background task in a separate goroutine
// immediately if the current context supports it. Otherwise it makes an
// immediate blocking call to task.Run(ctx).
Expand Down
23 changes: 5 additions & 18 deletions internal/api/apitask/apitask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,6 @@ import (
"github.com/stretchr/testify/require"
)

type taskFunc struct {
typ string
fn func(context.Context) error
}

func (o *taskFunc) Type() string { return o.typ }

func (o *taskFunc) Run(ctx context.Context) error { return o.fn(ctx) }

func taskFn(typ string, fn func(context.Context) error) Task {
return &taskFunc{typ: typ, fn: fn}
}

func TestContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()
Expand Down Expand Up @@ -59,7 +46,7 @@ func TestRun(t *testing.T) {
expCalls := 0
for range 16 {
expCalls++
task := taskFn("test.run", func(ctx context.Context) error {
task := Func("test.run", func(ctx context.Context) error {
calls.Add(1)
return nil
})
Expand All @@ -85,7 +72,7 @@ func TestRun(t *testing.T) {
sentinel := errors.New("sentinel")
for range 16 {
expCalls++
task := taskFn("test.run", func(ctx context.Context) error {
task := Func("test.run", func(ctx context.Context) error {
calls.Add(1)
return sentinel
})
Expand All @@ -110,7 +97,7 @@ func TestRun(t *testing.T) {
sentinel := errors.New("sentinel")
for range 16 {
expCalls++
task := taskFn("test.run", func(ctx context.Context) error {
task := Func("test.run", func(ctx context.Context) error {
calls.Add(1)
return sentinel
})
Expand All @@ -137,7 +124,7 @@ func TestMiddleware(t *testing.T) {
hrFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for i := range 10 {
typ := fmt.Sprintf("test-task-%v", i)
task := taskFn(typ, func(ctx context.Context) error {
task := Func(typ, func(ctx context.Context) error {
return nil
})

Expand All @@ -164,7 +151,7 @@ func TestMiddleware(t *testing.T) {
hrFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for i := range 10 {
typ := fmt.Sprintf("test-task-%v", i)
task := taskFn(typ, func(ctx context.Context) error {
task := Func(typ, func(ctx context.Context) error {
return nil
})
err := Run(r.Context(), task)
Expand Down
50 changes: 50 additions & 0 deletions internal/api/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,43 @@ func runVerifyBeforeUserCreatedHook(
return latest
}

func runVerifyAfterUserCreatedHook(
t *testing.T,
inst *e2ehooks.Instance,
expUser *models.User,
) *models.User {
var latest *models.User
t.Run("VerifyAfterUserCreatedHook", func(t *testing.T) {
defer inst.HookRecorder.AfterUserCreated.ClearCalls()

calls := inst.HookRecorder.AfterUserCreated.GetCalls()
require.Equal(t, 1, len(calls))
call := calls[0]

hookReq := &v0hooks.AfterUserCreatedInput{}
err := call.Unmarshal(hookReq)
require.NoError(t, err)
require.Equal(t, v0hooks.AfterUserCreated, hookReq.Metadata.Name)

u := hookReq.User
require.Equal(t, expUser.ID, u.ID)
require.Equal(t, expUser.Aud, u.Aud)
require.Equal(t, expUser.Email, u.Email)
require.Equal(t, expUser.AppMetaData, u.AppMetaData)

require.False(t, u.CreatedAt.IsZero())
require.False(t, u.UpdatedAt.IsZero())

err = expUser.Confirm(inst.Conn)
require.NoError(t, err)

latest, err = models.FindUserByID(inst.Conn, expUser.ID)
require.NoError(t, err)
require.NotNil(t, latest)
})
return latest
}

func getAccessToken(
ctx context.Context,
t *testing.T,
Expand Down Expand Up @@ -208,6 +245,7 @@ func TestE2EHooks(t *testing.T) {
require.Equal(t, email, res.Email.String())

runVerifyBeforeUserCreatedHook(t, inst, res)
runVerifyAfterUserCreatedHook(t, inst, res)
})

t.Run("SignupPhone", func(t *testing.T) {
Expand All @@ -224,6 +262,8 @@ func TestE2EHooks(t *testing.T) {
require.Equal(t, phone, res.Phone.String())

runVerifyBeforeUserCreatedHook(t, inst, res)
runVerifyAfterUserCreatedHook(t, inst, res)

})

t.Run("SignupAnonymously", func(t *testing.T) {
Expand All @@ -235,6 +275,8 @@ func TestE2EHooks(t *testing.T) {
require.NoError(t, err)

runVerifyBeforeUserCreatedHook(t, inst, res.User)
runVerifyAfterUserCreatedHook(t, inst, res.User)

})

t.Run("ExternalCallback", func(t *testing.T) {
Expand All @@ -246,6 +288,8 @@ func TestE2EHooks(t *testing.T) {
require.NoError(t, err)

runVerifyBeforeUserCreatedHook(t, inst, res.User)
runVerifyAfterUserCreatedHook(t, inst, res.User)

})

t.Run("AdminEndpoints", func(t *testing.T) {
Expand Down Expand Up @@ -273,6 +317,8 @@ func TestE2EHooks(t *testing.T) {
require.NoError(t, err)

runVerifyBeforeUserCreatedHook(t, inst, res)
runVerifyAfterUserCreatedHook(t, inst, res)

})

t.Run("AdminGenerateLink", func(t *testing.T) {
Expand Down Expand Up @@ -304,6 +350,7 @@ func TestE2EHooks(t *testing.T) {
require.NoError(t, err)

runVerifyBeforeUserCreatedHook(t, inst, &res.User)
runVerifyAfterUserCreatedHook(t, inst, &res.User)
})

t.Run("InviteVerification", func(t *testing.T) {
Expand Down Expand Up @@ -332,6 +379,7 @@ func TestE2EHooks(t *testing.T) {
require.NoError(t, err)

runVerifyBeforeUserCreatedHook(t, inst, &res.User)
runVerifyAfterUserCreatedHook(t, inst, &res.User)
})
})
})
Expand Down Expand Up @@ -372,6 +420,7 @@ func TestE2EHooks(t *testing.T) {
require.Equal(t, email, mfaUser.Email.String())

mfaUser = runVerifyBeforeUserCreatedHook(t, inst, mfaUser)
runVerifyAfterUserCreatedHook(t, inst, mfaUser)
require.NotNil(t, mfaUser)
mfaUserAccessToken = getAccessToken(
ctx, t, inst, string(mfaUser.Email), defaultPassword)
Expand Down Expand Up @@ -562,6 +611,7 @@ func TestE2EHooks(t *testing.T) {
require.Equal(t, email, res.Email.String())

currentUser = runVerifyBeforeUserCreatedHook(t, inst, res)
runVerifyAfterUserCreatedHook(t, inst, res)
require.NotNil(t, currentUser)
inst.HookRecorder.CustomizeAccessToken.ClearCalls()
}
Expand Down
56 changes: 31 additions & 25 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
}
}

var createdUser bool
var user *models.User
var token *AccessTokenResponse
err = db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -231,7 +232,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
return terr
}
} else {
if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType, emailOptional); terr != nil {
createdUser = true
if _, user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType, emailOptional); terr != nil {
return terr
}
}
Expand All @@ -253,10 +255,14 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
}
return nil
})

if err != nil {
return err
}
if createdUser {
if err := a.triggerAfterUserCreated(r, db, user); err != nil {
return err
}
}

// Record login for analytics - only when token is issued (not during pkce authorize)
if token != nil {
Expand Down Expand Up @@ -290,7 +296,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
return nil
}

func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string, emailOptional bool) (*models.User, error) {
func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string, emailOptional bool) (models.AccountLinkingDecision, *models.User, error) {
ctx := r.Context()
aud := a.requestAud(ctx, r)
config := a.config
Expand All @@ -304,28 +310,28 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.

decision, terr := models.DetermineAccountLinking(tx, config, userData.Emails, aud, providerType, userData.Metadata.Subject)
if terr != nil {
return nil, terr
return 0, nil, terr
}

switch decision.Decision {
case models.LinkAccount:
user = decision.User

if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil {
return nil, terr
return 0, nil, terr
}

if terr = user.UpdateUserMetaData(tx, identityData); terr != nil {
return nil, terr
return 0, nil, terr
}

if terr = user.UpdateAppMetaDataProviders(tx); terr != nil {
return nil, terr
return 0, nil, terr
}

case models.CreateAccount:
if config.DisableSignup {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
return 0, nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
}

params := &SignupParams{
Expand All @@ -352,15 +358,15 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
// transaction
user, terr = params.ToUserModel(isSSOUser)
if terr != nil {
return nil, terr
return 0, nil, terr
}

if user, terr = a.signupNewUser(tx, user); terr != nil {
return nil, terr
return 0, nil, terr
}

if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil {
return nil, terr
return 0, nil, terr
}
user.Identities = append(user.Identities, *identity)

Expand All @@ -370,24 +376,24 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.

identity.IdentityData = identityData
if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil {
return nil, terr
return 0, nil, terr
}
if terr = user.UpdateUserMetaData(tx, identityData); terr != nil {
return nil, terr
return 0, nil, terr
}
if terr = user.UpdateAppMetaDataProviders(tx); terr != nil {
return nil, terr
return 0, nil, terr
}

case models.MultipleAccounts:
return nil, apierrors.NewInternalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)
return 0, nil, apierrors.NewInternalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)

default:
return nil, apierrors.NewInternalServerError("Unknown automatic linking decision: %v", decision.Decision)
return 0, nil, apierrors.NewInternalServerError("Unknown automatic linking decision: %v", decision.Decision)
}

if user.IsBanned() {
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned")
return 0, nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned")
}

hasEmails := providerType != "web3" && !(emailOptional && decision.CandidateEmail.Email == "")
Expand All @@ -398,44 +404,44 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
// need to be removed when a new oauth identity is being added
// to prevent pre-account takeover attacks from happening.
if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil {
return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
return 0, nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
}
if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm {
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{
"provider": providerType,
}); terr != nil {
return nil, terr
return 0, nil, terr
}
// fall through to auto-confirm and issue token
if terr = user.Confirm(tx); terr != nil {
return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
return 0, nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
}
} else {
emailConfirmationSent := false
if decision.CandidateEmail.Email != "" {
if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil {
return nil, terr
return 0, nil, terr
}
emailConfirmationSent = true
}

if !config.Mailer.AllowUnverifiedEmailSignIns {
if emailConfirmationSent {
return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)))
return 0, nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)))
}

return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType)))
return 0, nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType)))
}
}
} else {
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.LoginAction, "", map[string]interface{}{
"provider": providerType,
}); terr != nil {
return nil, terr
return 0, nil, terr
}
}

return user, nil
return decision.Decision, user, nil
}

func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *provider.UserProvidedData, inviteToken, providerType string) (*models.User, error) {
Expand Down
Loading