Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cmd/serve_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ func serve(ctx context.Context) {
logrus.WithError(err).Fatal("unable to load config")
}

db, err := storage.Dial(config)
// Include serve ctx which carries cancelation signals so DialContext does
// not hang indefinitely at startup.
db, err := storage.DialContext(ctx, config)
if err != nil {
logrus.Fatalf("error opening database: %+v", err)
}
Expand All @@ -53,6 +55,10 @@ func serve(ctx context.Context) {
baseCtx, baseCancel := context.WithCancel(context.Background())
defer baseCancel()

// Add the base context to the db, this is so during the shutdown sequence
// the DB will be available while connections drain.
db = db.WithContext(ctx)

var wg sync.WaitGroup
defer wg.Wait() // Do not return to caller until this goroutine is done.

Expand All @@ -79,7 +85,7 @@ func serve(ctx context.Context) {
log := logrus.WithField("component", "api")

wrkLog := logrus.WithField("component", "apiworker")
wrk := apiworker.New(config, mrCache, wrkLog)
wrk := apiworker.New(config, mrCache, db, wrkLog)
wg.Add(1)
go func() {
defer wg.Done()
Expand Down
82 changes: 79 additions & 3 deletions internal/api/apiworker/apiworker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/mailer/templatemailer"
"github.com/supabase/auth/internal/storage"
"golang.org/x/sync/errgroup"
)

// Worker is a simple background worker for async tasks.
type Worker struct {
le *logrus.Entry
tc *templatemailer.Cache
db *storage.Connection

// Notifies worker the cfg has been updated.
cfgCh chan struct{}
Expand All @@ -31,12 +34,14 @@ type Worker struct {
func New(
cfg *conf.GlobalConfiguration,
tc *templatemailer.Cache,
db *storage.Connection,
le *logrus.Entry,
) *Worker {
return &Worker{
le: le,
cfg: cfg,
tc: tc,
db: db,
cfgCh: make(chan struct{}, 1),
}
}
Expand All @@ -63,14 +68,85 @@ func (o *Worker) ReloadConfig(cfg *conf.GlobalConfiguration) {
}
}

// Work will periodically reload the templates in the background as long as the
// system remains active.
// Work will run background workers.
func (o *Worker) Work(ctx context.Context) error {
if ok := o.workMu.TryLock(); !ok {
return errors.New("apiworker: concurrent calls to Work are invalid")
}
defer o.workMu.Unlock()

var (
eg errgroup.Group
notifyTpl = make(chan struct{}, 1)
notifyDb = make(chan struct{}, 1)
)
eg.Go(func() error {
return o.configNotifier(ctx, notifyTpl, notifyDb)
})
eg.Go(func() error {
return o.templateWorker(ctx, notifyTpl)
})
eg.Go(func() error {
return o.dbWorker(ctx, notifyDb)
})
return eg.Wait()
}

func (o *Worker) configNotifier(
ctx context.Context,
notifyCh ...chan<- struct{},
) error {
le := o.le.WithFields(logrus.Fields{
"worker_type": "apiworker_config_notifier",
})
le.Info("apiworker: config notifier started")
defer le.Info("apiworker: config notifier exited")

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-o.cfgCh:

// When we get a config update, notify each worker to wake up
for _, ch := range notifyCh {
select {
case ch <- struct{}{}:
default:
}
}
}
}
}

func (o *Worker) dbWorker(ctx context.Context, cfgCh <-chan struct{}) error {
le := o.le.WithFields(logrus.Fields{
"worker_type": "apiworker_db_worker",
})
le.Info("apiworker: db worker started")
defer le.Info("apiworker: db worker exited")

if err := o.db.ApplyConfig(ctx, o.getConfig(), le); err != nil {
le.WithError(err).Error(
"failure applying config connection limits to db")
}

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-cfgCh:
if err := o.db.ApplyConfig(ctx, o.getConfig(), le); err != nil {
le.WithError(err).Error(
"failure applying config connection limits to db")
}
}
}
}

// templateWorker will periodically reload the templates in the background as
// long as the system remains active.
func (o *Worker) templateWorker(ctx context.Context, cfgCh <-chan struct{}) error {
le := o.le.WithFields(logrus.Fields{
"worker_type": "apiworker_template_cache",
})
Expand All @@ -91,7 +167,7 @@ func (o *Worker) Work(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-o.cfgCh:
case <-cfgCh:
tr.Reset(ival())
case <-tr.C:
}
Expand Down
6 changes: 6 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ type DBConfiguration struct {
Driver string `json:"driver" required:"true"`
URL string `json:"url" envconfig:"DATABASE_URL" required:"true"`
Namespace string `json:"namespace" envconfig:"DB_NAMESPACE" default:"auth"`

// Percentage of DB conns the auth server may use in
// integer form i.e.: [1, 100] -> [1%, 100%]
ConnPercentage int `json:"conn_percentage" split_words:"true"`

// MaxPoolSize defaults to 0 (unlimited).
MaxPoolSize int `json:"max_pool_size" split_words:"true"`
MaxIdlePoolSize int `json:"max_idle_pool_size" split_words:"true"`
Expand All @@ -117,6 +122,7 @@ type DBConfiguration struct {
}

func (c *DBConfiguration) Validate() error {
c.ConnPercentage = min(max(c.ConnPercentage, 0), 100)
return nil
}

Expand Down
24 changes: 24 additions & 0 deletions internal/conf/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,30 @@ func TestGlobal(t *testing.T) {
err := populateGlobal(cfg)
require.NoError(t, err)
}

// ConnPercentage
{
tests := []struct {
from int
exp int
}{
{-2, 0},
{-1, 0},
{0, 0},
{1, 1},
{25, 25},
{99, 99},
{100, 100},
{101, 100},
{102, 100},
}
for _, test := range tests {
cfg := &DBConfiguration{ConnPercentage: test.from}
err := cfg.Validate()
require.NoError(t, err)
require.Equal(t, test.exp, cfg.ConnPercentage)
}
}
}

func TestPasswordRequiredCharactersDecode(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions internal/observability/request-logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ type logEntry struct {
Entry *logrus.Entry
}

// NewLogEntry returns a new chimiddleware.LogEntry from a *logrus.Entry.
func NewLogEntry(le *logrus.Entry) chimiddleware.LogEntry {
return &logEntry{le}
}

func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) {
fields := logrus.Fields{
"status": status,
Expand Down
Loading