diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go
index 542ddb1ae..e9d87574c 100644
--- a/cmd/serve_cmd.go
+++ b/cmd/serve_cmd.go
@@ -44,7 +44,9 @@ func serve(ctx context.Context) {
         logrus.WithError(err).Fatal("unable to load config")
     }
 
-    db, err := storage.Dial(config)
+    // Include the serve ctx, which carries cancellation signals, so that
+    // DialContext does not hang indefinitely at startup.
+    db, err := storage.DialContext(ctx, config)
     if err != nil {
         logrus.Fatalf("error opening database: %+v", err)
     }
@@ -53,6 +55,10 @@ func serve(ctx context.Context) {
     baseCtx, baseCancel := context.WithCancel(context.Background())
     defer baseCancel()
 
+    // Attach the base context to the db so that during the shutdown sequence
+    // the DB remains available while connections drain.
+    db = db.WithContext(baseCtx)
+
     var wg sync.WaitGroup
     defer wg.Wait() // Do not return to caller until this goroutine is done.
 
@@ -79,7 +85,7 @@ func serve(ctx context.Context) {
     log := logrus.WithField("component", "api")
 
     wrkLog := logrus.WithField("component", "apiworker")
-    wrk := apiworker.New(config, mrCache, wrkLog)
+    wrk := apiworker.New(config, mrCache, db, wrkLog)
     wg.Add(1)
     go func() {
         defer wg.Done()
diff --git a/internal/api/apiworker/apiworker.go b/internal/api/apiworker/apiworker.go
index e6627c9f4..d4d2d5cd1 100644
--- a/internal/api/apiworker/apiworker.go
+++ b/internal/api/apiworker/apiworker.go
@@ -9,12 +9,15 @@ import (
     "github.com/sirupsen/logrus"
     "github.com/supabase/auth/internal/conf"
     "github.com/supabase/auth/internal/mailer/templatemailer"
+    "github.com/supabase/auth/internal/storage"
+    "golang.org/x/sync/errgroup"
 )
 
 // Worker is a simple background worker for async tasks.
 type Worker struct {
     le *logrus.Entry
     tc *templatemailer.Cache
+    db *storage.Connection
 
     // Notifies worker the cfg has been updated.
     cfgCh chan struct{}
@@ -31,12 +34,14 @@ type Worker struct {
 func New(
     cfg *conf.GlobalConfiguration,
     tc *templatemailer.Cache,
+    db *storage.Connection,
     le *logrus.Entry,
 ) *Worker {
     return &Worker{
         le:    le,
         cfg:   cfg,
         tc:    tc,
+        db:    db,
         cfgCh: make(chan struct{}, 1),
     }
 }
@@ -63,14 +68,85 @@ func (o *Worker) ReloadConfig(cfg *conf.GlobalConfiguration) {
     }
 }
 
-// Work will periodically reload the templates in the background as long as the
-// system remains active.
+// Work runs the background workers until ctx is canceled.
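+//
+// The flow, as wired below: ReloadConfig signals cfgCh (a buffered channel of
+// capacity 1), configNotifier fans that signal out to one buffered channel
+// per worker, and each worker wakes up and re-reads the config. Because every
+// send is non-blocking, rapid config updates coalesce into a single wakeup.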
 func (o *Worker) Work(ctx context.Context) error {
     if ok := o.workMu.TryLock(); !ok {
         return errors.New("apiworker: concurrent calls to Work are invalid")
     }
     defer o.workMu.Unlock()
 
+    var (
+        eg        errgroup.Group
+        notifyTpl = make(chan struct{}, 1)
+        notifyDb  = make(chan struct{}, 1)
+    )
+    eg.Go(func() error {
+        return o.configNotifier(ctx, notifyTpl, notifyDb)
+    })
+    eg.Go(func() error {
+        return o.templateWorker(ctx, notifyTpl)
+    })
+    eg.Go(func() error {
+        return o.dbWorker(ctx, notifyDb)
+    })
+    return eg.Wait()
+}
+
+func (o *Worker) configNotifier(
+    ctx context.Context,
+    notifyCh ...chan<- struct{},
+) error {
+    le := o.le.WithFields(logrus.Fields{
+        "worker_type": "apiworker_config_notifier",
+    })
+    le.Info("apiworker: config notifier started")
+    defer le.Info("apiworker: config notifier exited")
+
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        case <-o.cfgCh:
+
+            // When we get a config update, notify each worker to wake up.
+            for _, ch := range notifyCh {
+                select {
+                case ch <- struct{}{}:
+                default:
+                }
+            }
+        }
+    }
+}
+
+func (o *Worker) dbWorker(ctx context.Context, cfgCh <-chan struct{}) error {
+    le := o.le.WithFields(logrus.Fields{
+        "worker_type": "apiworker_db_worker",
+    })
+    le.Info("apiworker: db worker started")
+    defer le.Info("apiworker: db worker exited")
+
+    if err := o.db.ApplyConfig(ctx, o.getConfig(), le); err != nil {
+        le.WithError(err).Error(
+            "failure applying config connection limits to db")
+    }
+
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        case <-cfgCh:
+            if err := o.db.ApplyConfig(ctx, o.getConfig(), le); err != nil {
+                le.WithError(err).Error(
+                    "failure applying config connection limits to db")
+            }
+        }
+    }
+}
+
+// templateWorker will periodically reload the templates in the background as
+// long as the system remains active.
+func (o *Worker) templateWorker(ctx context.Context, cfgCh <-chan struct{}) error {
     le := o.le.WithFields(logrus.Fields{
         "worker_type": "apiworker_template_cache",
     })
@@ -91,7 +167,7 @@ func (o *Worker) Work(ctx context.Context) error {
         select {
         case <-ctx.Done():
             return ctx.Err()
-        case <-o.cfgCh:
+        case <-cfgCh:
             tr.Reset(ival())
         case <-tr.C:
         }
diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go
index 224c4f9eb..72b39b99d 100644
--- a/internal/conf/configuration.go
+++ b/internal/conf/configuration.go
@@ -104,6 +104,11 @@ type DBConfiguration struct {
     Driver    string `json:"driver" required:"true"`
     URL       string `json:"url" envconfig:"DATABASE_URL" required:"true"`
     Namespace string `json:"namespace" envconfig:"DB_NAMESPACE" default:"auth"`
+
+    // Percentage of the server's max connections the auth server may use, as
+    // an integer in [1, 100] (i.e. 1% to 100%). 0 disables percentage-based
+    // limits.
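+    //
+    // For example, setting the corresponding environment variable (assuming
+    // the project's usual GOTRUE_ prefix; the exact name is inferred from the
+    // split_words tag, not shown in this diff) to
+    // GOTRUE_DB_CONN_PERCENTAGE=30 caps the pool at 30% of max_connections.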
+    ConnPercentage int `json:"conn_percentage" split_words:"true"`
+
+    // MaxPoolSize defaults to 0 (unlimited).
     MaxPoolSize     int `json:"max_pool_size" split_words:"true"`
     MaxIdlePoolSize int `json:"max_idle_pool_size" split_words:"true"`
@@ -117,6 +122,7 @@ type DBConfiguration struct {
 }
 
 func (c *DBConfiguration) Validate() error {
+    c.ConnPercentage = min(max(c.ConnPercentage, 0), 100)
     return nil
 }
diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go
index 82982fb26..ada5c0759 100644
--- a/internal/conf/configuration_test.go
+++ b/internal/conf/configuration_test.go
@@ -262,6 +262,30 @@ func TestGlobal(t *testing.T) {
         err := populateGlobal(cfg)
         require.NoError(t, err)
     }
+
+    // ConnPercentage
+    {
+        tests := []struct {
+            from int
+            exp  int
+        }{
+            {-2, 0},
+            {-1, 0},
+            {0, 0},
+            {1, 1},
+            {25, 25},
+            {99, 99},
+            {100, 100},
+            {101, 100},
+            {102, 100},
+        }
+        for _, test := range tests {
+            cfg := &DBConfiguration{ConnPercentage: test.from}
+            err := cfg.Validate()
+            require.NoError(t, err)
+            require.Equal(t, test.exp, cfg.ConnPercentage)
+        }
+    }
 }
 
 func TestPasswordRequiredCharactersDecode(t *testing.T) {
diff --git a/internal/observability/request-logger.go b/internal/observability/request-logger.go
index ca1f6b704..c180bdc6d 100644
--- a/internal/observability/request-logger.go
+++ b/internal/observability/request-logger.go
@@ -73,6 +73,11 @@ type logEntry struct {
     Entry *logrus.Entry
 }
 
+// NewLogEntry returns a new chimiddleware.LogEntry from a *logrus.Entry.
+func NewLogEntry(le *logrus.Entry) chimiddleware.LogEntry {
+    return &logEntry{le}
+}
+
 func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) {
     fields := logrus.Fields{
         "status": status,
diff --git a/internal/storage/dial.go b/internal/storage/dial.go
index 6181d3c70..03d7205e8 100644
--- a/internal/storage/dial.go
+++ b/internal/storage/dial.go
@@ -3,6 +3,7 @@ package storage
 import (
     "context"
     "database/sql"
+    "fmt"
     "net/url"
     "reflect"
     "time"
@@ -16,17 +17,92 @@ import (
     "github.com/supabase/auth/internal/conf"
 )
 
-// Connection is the interface a storage provider must implement.
+// Connection is the interface a storage provider must implement. Do not copy
+// a Connection; use Copy instead.
 type Connection struct {
     *pop.Connection
     sqldb *sql.DB
 }
 
 // Dial will connect to that storage engine
 func Dial(config *conf.GlobalConfiguration) (*Connection, error) {
+    return DialContext(context.TODO(), config)
+}
+
+func DialContext(
+    ctx context.Context,
+    config *conf.GlobalConfiguration,
+) (*Connection, error) {
+    cd, err := newConnectionDetails(config)
+    if err != nil {
+        return nil, err
+    }
+
+    db, err := pop.NewConnection(cd)
+    if err != nil {
+        return nil, errors.Wrap(err, "opening database connection")
+    }
+    if err := db.Open(); err != nil {
+        return nil, errors.Wrap(err, "checking database connection")
+    }
+
+    sqldb, ok := popConnToStd(db)
+    if ok && config.Metrics.Enabled {
+        registerOpenTelemetryDatabaseStats(config, sqldb)
+    }
+
+    conn := &Connection{
+        Connection: db,
+        sqldb:      sqldb,
+    }
+    return conn, nil
+}
+
+// // GetSqlDB returns the underlying *sql.DB and true, or nil and false if no
+// // db could be obtained.
+// func (c *Connection) GetSqlDB() (*sql.DB, bool) { return c.sqldb, c.sqldb != nil }
+
+// Copy will return a copy of this connection. It must be used instead of a
+// struct literal from external packages.
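+//
+// The copy is shallow: the cached sqldb handle is carried along while the
+// embedded *pop.Connection may be swapped out, which is what Transaction and
+// WithContext below rely on:
+//
+//	conn := c.Copy()
+//	conn.Connection = tx // swap in the transaction's *pop.Connection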
+func (c *Connection) Copy() *Connection {
+    cpy := *c
+    return &cpy
+}
+
+func newConnectionDetails(
+    config *conf.GlobalConfiguration,
+) (*pop.ConnectionDetails, error) {
+    cd := &pop.ConnectionDetails{
+        Dialect:         config.DB.Driver,
+        URL:             config.DB.URL,
+        Pool:            config.DB.MaxPoolSize,
+        IdlePool:        config.DB.MaxIdlePoolSize,
+        ConnMaxLifetime: config.DB.ConnMaxLifetime,
+        ConnMaxIdleTime: config.DB.ConnMaxIdleTime,
+        Options:         make(map[string]string),
+    }
+    if err := applyDBDriver(config, cd); err != nil {
+        return nil, err
+    }
+    if config.DB.HealthCheckPeriod != time.Duration(0) {
+        cd.Options["pool_health_check_period"] = config.DB.HealthCheckPeriod.String()
+    }
+    if config.DB.ConnMaxIdleTime != time.Duration(0) {
+        cd.Options["pool_max_conn_idle_time"] = config.DB.ConnMaxIdleTime.String()
+    }
+    return cd, nil
+}
+
+// TODO(cstockton): I'm preserving the mutation here for now because I'm not
+// sure what side effects changing this could have. But it should probably go
+// inside the Validate() function in the conf package or somewhere else.
+func applyDBDriver(
+    config *conf.GlobalConfiguration,
+    cd *pop.ConnectionDetails,
+) error {
     if config.DB.Driver == "" && config.DB.URL != "" {
         u, err := url.Parse(config.DB.URL)
         if err != nil {
-            return nil, errors.Wrap(err, "parsing db connection url")
+            return errors.Wrap(err, "parsing db connection url")
         }
         config.DB.Driver = u.Scheme
     }
@@ -55,57 +131,203 @@ func Dial(config *conf.GlobalConfiguration) (*Connection, error) {
         }
     }
 
-    options := make(map[string]string)
+    cd.Driver = driver
+    return nil
+}
 
-    if config.DB.HealthCheckPeriod != time.Duration(0) {
-        options["pool_health_check_period"] = config.DB.HealthCheckPeriod.String()
+// NOTE: I couldn't find any way to obtain the store when wrapped with context
+// due to the private store field in pop.contextStore.
+func popConnToStd(db *pop.Connection) (sqldb *sql.DB, ok bool) {
+    defer func() {
+        if rec := recover(); rec != nil {
+            sqldb, ok = nil, false
+        }
+    }()
+
+    // Get the element stored in the pop.store interface within field db.Store.
+    dbval := reflect.ValueOf(db.Store).Elem() // *pop.dB
+
+    // dbval should contain a pointer to a struct with the layout of pop.dB:
+    //
+    //	type dB struct {
+    //		*sqlx.DB
+    //	}
+    //
+    dbval = dbval.Field(0) // *sqlx.DB
+
+    // dbval should now be a pointer to a struct with a layout like sqlx.DB:
+    //
+    //	type DB struct {
+    //		*sql.DB
+    //	}
+    //
+    dbval = dbval.Elem().Field(0) // *sql.DB
+
+    // dbval should now be a (*sql.DB); get an interface value and try a cast.
+    sqldb, ok = dbval.Interface().(*sql.DB)
+    return
+}
+
+// ApplyConfig will apply the given config to this *Connection, potentially
+// adjusting the underlying *sql.DB's current settings.
+//
+// When config.DB.ConnPercentage is set to a non-zero value ApplyConfig attempts
+// to set MaxOpenConns and MaxIdleConns to a percentage-based value. It does
+// this by opening a connection to the server and calling
+// `SHOW max_connections;` to determine the connection limits. If this operation
+// fails it applies no configuration changes at all and returns an error.
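+//
+// A minimal usage sketch (the caller and logger here are assumptions, not
+// part of this diff):
+//
+//	le := logrus.NewEntry(logrus.StandardLogger())
+//	if err := conn.ApplyConfig(ctx, config, le); err != nil {
+//		// No settings were changed; keep running with the previous limits.
+//	}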
+func (c *Connection) ApplyConfig(
+    ctx context.Context,
+    config *conf.GlobalConfiguration,
+    le *logrus.Entry,
+) error {
+    sqldb := c.sqldb
+    if sqldb == nil {
+        return errors.New("storage: ApplyConfig: unable to access underlying *sql.DB")
     }
-    if config.DB.ConnMaxIdleTime != time.Duration(0) {
-        options["pool_max_conn_idle_time"] = config.DB.ConnMaxIdleTime.String()
+    cl, err := c.getConnLimits(ctx, &config.DB)
+    if err != nil {
+        return fmt.Errorf("storage: ApplyConfig: %w", err)
     }
 
-    db, err := pop.NewConnection(&pop.ConnectionDetails{
-        Dialect:         config.DB.Driver,
-        Driver:          driver,
-        URL:             config.DB.URL,
-        Pool:            config.DB.MaxPoolSize,
-        IdlePool:        config.DB.MaxIdlePoolSize,
-        ConnMaxLifetime: config.DB.ConnMaxLifetime,
-        ConnMaxIdleTime: config.DB.ConnMaxIdleTime,
-        Options:         options,
-    })
+    le.WithFields(logrus.Fields{
+        // Config values
+        "config_max_pool_size":      config.DB.MaxPoolSize,
+        "config_max_idle_pool_size": config.DB.MaxIdlePoolSize,
+        "config_conn_max_lifetime":  config.DB.ConnMaxLifetime.String(),
+        "config_conn_max_idle_time": config.DB.ConnMaxIdleTime.String(),
+        "config_conn_percentage":    config.DB.ConnPercentage,
+
+        // Server values
+        "server_max_conns": cl.ServerMaxConns,
+
+        // Limit values
+        "limit_max_open_conns":     cl.MaxOpenConns,
+        "limit_max_idle_conns":     cl.MaxIdleConns,
+        "limit_conn_max_lifetime":  cl.ConnMaxLifetime.String(),
+        "limit_conn_max_idle_time": cl.ConnMaxIdleTime.String(),
+        "limit_strategy":           cl.Strategy,
+    }).Infof("applying connection limits to db using the %q strategy", cl.Strategy)
+
+    sqldb.SetMaxOpenConns(cl.MaxOpenConns)
+    sqldb.SetMaxIdleConns(cl.MaxIdleConns)
+    sqldb.SetConnMaxLifetime(cl.ConnMaxLifetime)
+    sqldb.SetConnMaxIdleTime(cl.ConnMaxIdleTime)
+    return nil
+}
+
+func (c *Connection) getConnLimits(
+    ctx context.Context,
+    dbCfg *conf.DBConfiguration,
+) (*ConnLimits, error) {
+    // Set the connection limits to the fixed values in config.
+    cl := newConnLimitsFromConfig(dbCfg)
+
+    // Always fetch max conns because it is useful for logging.
+    maxConns, err := c.showMaxConns(ctx)
     if err != nil {
-        return nil, errors.Wrap(err, "opening database connection")
+        return nil, err
     }
-    if err := db.Open(); err != nil {
-        return nil, errors.Wrap(err, "checking database connection")
+    cl.ServerMaxConns = maxConns
+
+    if dbCfg.ConnPercentage == 0 {
+        // Percentage-based conn limits are disabled.
+        cl.Strategy = connLimitsFixedStrategy
+        return cl, nil
+    }
+
+    // Percentage-based conn limits are enabled; try to determine what they
+    // should be.
+    if err := c.applyPercentageLimits(dbCfg, maxConns, cl); err != nil {
+        return nil, err
+    }
+
+    return cl, nil
+}
+
+func (c *Connection) applyPercentageLimits(
+    dbCfg *conf.DBConfiguration,
+    maxConns int,
+    cl *ConnLimits,
+) error {
+    cl.ServerMaxConns = maxConns // set this here too for unit tests
+
+    if dbCfg.ConnPercentage == 0 {
+        // Percentage-based conn limits are disabled.
+        cl.Strategy = connLimitsFixedStrategy
+        return nil
+    }
+
+    if maxConns <= 0 {
+        // A maxConns of 0 means our role or db is not allowing conns right
+        // now; return an error so no changes are applied.
+        return errors.New("db reported a maximum of 0 connections")
     }
 
-    if config.Metrics.Enabled {
-        registerOpenTelemetryDatabaseStats(db, config)
+    // Ensure the conn pct isn't out of bounds.
+    if dbCfg.ConnPercentage <= 0 || dbCfg.ConnPercentage > 100 {
+        return errors.New("db conn percentage must be between 1 and 100")
     }
 
-    return &Connection{db}, nil
+    // maxConns > 0 so we may calculate the percentage.
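+    // For example, ConnPercentage=30 with max_connections=100 yields
+    // int(max(1, 0.30*100)) = 30 open conns, while ConnPercentage=10 with
+    // max_connections=4 yields int(max(1, 0.4)) = 1, the one-conn floor.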
+    pct := float64(dbCfg.ConnPercentage)
+    cl.MaxOpenConns = int(max(1, (pct/100)*float64(maxConns)))
+
+    // We set max idle conns to the max open conns.
+    cl.MaxIdleConns = cl.MaxOpenConns
+
+    // return the percentage based conn limits
+    cl.Strategy = connLimitsPercentageStrategy
+    return nil
+}
+
+// showMaxConns retrieves the max_connections from the db.
+func (c *Connection) showMaxConns(ctx context.Context) (int, error) {
+    db := c.WithContext(ctx)
+
+    var maxConns int
+    err := db.Transaction(func(tx *Connection) error {
+        return tx.RawQuery("SHOW max_connections;").First(&maxConns)
+    })
+    if err != nil {
+        return 0, err
+    }
+    return maxConns, nil
+}
 
-func registerOpenTelemetryDatabaseStats(db *pop.Connection, config *conf.GlobalConfiguration) {
+const (
+    connLimitsErrorStrategy      = "error"
+    connLimitsFixedStrategy      = "fixed"
+    connLimitsPercentageStrategy = "percentage"
+)
+
+// ConnLimits represents the connection limits for the underlying *sql.DB.
+type ConnLimits struct {
+    MaxOpenConns    int
+    MaxIdleConns    int
+    ConnMaxLifetime time.Duration
+    ConnMaxIdleTime time.Duration
+    ServerMaxConns  int
+    Strategy        string
+}
+
+func newConnLimitsFromConfig(dbCfg *conf.DBConfiguration) *ConnLimits {
+    return &ConnLimits{
+        MaxOpenConns:    dbCfg.MaxPoolSize,
+        MaxIdleConns:    dbCfg.MaxIdlePoolSize,
+        ConnMaxLifetime: dbCfg.ConnMaxLifetime,
+        ConnMaxIdleTime: dbCfg.ConnMaxIdleTime,
+        Strategy:        connLimitsErrorStrategy,
+    }
+}
+
+func registerOpenTelemetryDatabaseStats(config *conf.GlobalConfiguration, sqldb *sql.DB) {
     defer func() {
         if rec := recover(); rec != nil {
             logrus.WithField("error", rec).Error("registerOpenTelemetryDatabaseStats is not able to determine database object with reflection -- panicked")
         }
     }()
 
-    dbval := reflect.Indirect(reflect.ValueOf(db.Store))
-    dbfield := dbval.Field(0)
-    sqldbfield := reflect.Indirect(dbfield).Field(0)
-
-    sqldb, ok := sqldbfield.Interface().(*sql.DB)
-    if !ok || sqldb == nil {
-        logrus.Error("registerOpenTelemetryDatabaseStats is not able to determine database object with reflection")
-        return
-    }
-
     if err := otelsql.RegisterDBStatsMetrics(sqldb); err != nil {
         logrus.WithError(err).Error("unable to register OpenTelemetry stats metrics for databse")
     } else {
@@ -154,7 +376,10 @@ func (c *Connection) Transaction(fn func(*Connection) error) error {
     if c.TX == nil {
         var returnErr error
         if terr := c.Connection.Transaction(func(tx *pop.Connection) error {
-            err := fn(&Connection{tx})
+            conn := c.Copy()
+            conn.Connection = tx
+
+            err := fn(conn)
             switch err.(type) {
             case *CommitWithError:
                 returnErr = err
@@ -179,7 +404,9 @@ func (c *Connection) Transaction(fn func(*Connection) error) error {
 // WithContext returns a new connection with an updated context. This is
 // typically used for tracing as the context contains trace span information.
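+//
+// A small usage sketch (the caller names are assumptions, not part of this
+// diff):
+//
+//	db := conn.WithContext(r.Context())
+//	err := db.Transaction(func(tx *Connection) error { ... })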
 func (c *Connection) WithContext(ctx context.Context) *Connection {
-    return &Connection{c.Connection.WithContext(ctx)}
+    cpy := c.Copy()
+    cpy.Connection = cpy.Connection.WithContext(ctx)
+    return cpy
 }
 
 func getExcludedColumns(model interface{}, includeColumns ...string) ([]string, error) {
diff --git a/internal/storage/dial_test.go b/internal/storage/dial_test.go
index 078b6d57a..fac2f4d72 100644
--- a/internal/storage/dial_test.go
+++ b/internal/storage/dial_test.go
@@ -1,12 +1,18 @@
 package storage
 
 import (
+    "context"
     "errors"
+    "fmt"
     "testing"
+    "time"
+
+    "github.com/gobuffalo/pop/v6"
     "github.com/gofrs/uuid"
     "github.com/stretchr/testify/require"
     "github.com/supabase/auth/internal/conf"
+    "github.com/supabase/auth/internal/observability"
+    "golang.org/x/sync/errgroup"
 )
 
 type TestUser struct {
@@ -58,3 +64,440 @@ func TestTransaction(t *testing.T) {
     require.NoError(t, err)
     require.Empty(t, data)
 }
+
+func TestPopConnToStd(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+    defer cancel()
+
+    apiTestConfig := "../../hack/test.env"
+    config, err := conf.LoadGlobal(apiTestConfig)
+    require.NoError(t, err)
+
+    cd, err := newConnectionDetails(config)
+    require.NoError(t, err)
+
+    conn, err := pop.NewConnection(cd)
+    require.NoError(t, err)
+    require.NoError(t, conn.Open())
+
+    t.Run("connToDB", func(t *testing.T) {
+        sdb, ok := popConnToStd(conn)
+        require.NotNil(t, sdb)
+        require.True(t, ok)
+        require.Equal(t, "*sql.DB", fmt.Sprintf("%T", sdb))
+    })
+
+    // Could not find a way to do this (without unsafe) due to the struct
+    // layout of pop.contextStore (contextStore { store: store }).
+    t.Run("connWithContextToDB", func(t *testing.T) {
+        sdb, ok := popConnToStd(conn.WithContext(ctx))
+        require.Nil(t, sdb)
+        require.False(t, ok)
+    })
+}
+
+func TestConnection(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+    defer cancel()
+
+    t.Run("DialContext", func(t *testing.T) {
+        config := mustConfig(t)
+        dbCfg := config.DB
+
+        cd, err := newConnectionDetails(config)
+        require.NoError(t, err)
+        require.Equal(t, dbCfg.Driver, cd.Dialect)
+        require.Equal(t, dbCfg.URL, cd.URL)
+        require.Equal(t, dbCfg.MaxPoolSize, cd.Pool)
+        require.Equal(t, dbCfg.MaxIdlePoolSize, cd.IdlePool)
+        require.Equal(t, dbCfg.ConnMaxLifetime, cd.ConnMaxLifetime)
+        require.Equal(t, dbCfg.ConnMaxIdleTime, cd.ConnMaxIdleTime)
+
+        db, err := DialContext(ctx, config)
+        require.NoError(t, err)
+        require.NotNil(t, db)
+        defer db.Close()
+    })
+
+    t.Run("DialContextInvalidDriver", func(t *testing.T) {
+        config := mustConfig(t)
+
+        // set an invalid db url
+        config.DB.URL = string([]byte("\x00"))
+        config.DB.Driver = ""
+
+        const errStr = "invalid control character in URL"
+        db, err := DialContext(ctx, config)
+        require.Nil(t, db)
+        require.Error(t, err)
+        require.Contains(t, err.Error(), errStr)
+    })
+}
+
+func TestConnLimits(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+    defer cancel()
+
+    t.Run("ApplyConfig", func(t *testing.T) {
+        config := mustConfig(t)
+        config.DB.MaxPoolSize = 50
+
+        db, err := DialContext(ctx, config)
+        require.NoError(t, err)
+        require.NotNil(t, db)
+        defer db.Close()
+
+        maxConns, err := db.showMaxConns(ctx)
+        require.NoError(t, err)
+
+        // 100 is the current default in local testing; we rely on it for
+        // unit tests.
+        require.True(t, maxConns == 100)
+
+        openConns := func(n int) {
+            var eg errgroup.Group
+            for range n {
+                eg.Go(func() error {
+                    return db.Transaction(func(tx *Connection) error {
+                        return tx.RawQuery("SELECT pg_sleep(0.1);").Exec()
+                    })
+                })
+            }
+            require.NoError(t, eg.Wait())
+        }
+
+        // baseline: should have a single open conn
+        {
+            stats := db.sqldb.Stats()
+            require.Equal(t, 1, stats.OpenConnections,
+                "expected a single open connection")
+        }
+
+        // stats should show max pool size since we have no conn pct
+        {
+            // open max pool size * 2
+            openConns(config.DB.MaxPoolSize * 2)
+
+            // after the blocking calls we should be at max pool size if
+            // applying worked.
+            stats := db.sqldb.Stats()
+            require.Equal(t, config.DB.MaxPoolSize, stats.OpenConnections,
+                "expected open connections to equal MaxPoolSize")
+        }
+
+        // apply percentage based limits now
+        {
+            newConfig := mustConfig(t)
+            newConfig.DB.ConnPercentage = 30
+
+            le := observability.GetLogEntryFromContext(ctx).Entry
+            err := db.ApplyConfig(ctx, newConfig, le)
+            require.NoError(t, err)
+        }
+
+        // stats should show 30 since we allocated 30% of our 100 available
+        // conns
+        {
+            // open max pool size * 2
+            openConns(config.DB.MaxPoolSize * 2)
+
+            // after the blocking calls we should be at the percentage-based
+            // limit if applying worked.
+            stats := db.sqldb.Stats()
+
+            require.Equal(t, 30, stats.OpenConnections,
+                "expected open connections to equal the 30% limit")
+        }
+
+        // exp error when sqldb is nil
+        {
+            db.sqldb = nil
+            newConfig := mustConfig(t)
+            newConfig.DB.ConnPercentage = 50
+
+            le := observability.GetLogEntryFromContext(ctx).Entry
+            err := db.ApplyConfig(ctx, newConfig, le)
+            require.Error(t, err)
+            require.Contains(t, err.Error(), "unable to access underlying *sql.DB")
+        }
+    })
+
+    t.Run("getConnLimits", func(t *testing.T) {
+        config := mustConfig(t)
+        config.DB.MaxPoolSize = 50
+
+        db, err := DialContext(ctx, config)
+        require.NoError(t, err)
+        require.NotNil(t, db)
+        defer db.Close()
+
+        const maxConns = 100
+        {
+            serverMaxConns, err := db.showMaxConns(ctx)
+            require.NoError(t, err)
+
+            // 100 is the current default in local testing; we rely on it for
+            // unit tests.
+            require.True(t, serverMaxConns == maxConns)
+        }
+
+        t.Run("PercentageEnabled", func(t *testing.T) {
+            dbCfg := conf.DBConfiguration{
+                ConnPercentage:  10,
+                MaxPoolSize:     50,
+                MaxIdlePoolSize: 50,
+                ConnMaxIdleTime: time.Second * 60,
+                ConnMaxLifetime: 0,
+            }
+            exp := ConnLimits{
+                MaxOpenConns:    10,
+                MaxIdleConns:    10,
+                ConnMaxIdleTime: time.Second * 60,
+                ConnMaxLifetime: 0,
+                Strategy:        connLimitsPercentageStrategy,
+            }
+
+            cl, err := db.getConnLimits(ctx, &dbCfg)
+            require.NoError(t, err)
+            require.NotNil(t, cl)
+
+            require.Equal(t, exp.MaxOpenConns, cl.MaxOpenConns)
+            require.Equal(t, exp.MaxIdleConns, cl.MaxIdleConns)
+            require.Equal(t, exp.ConnMaxLifetime, cl.ConnMaxLifetime)
+            require.Equal(t, exp.ConnMaxIdleTime, cl.ConnMaxIdleTime)
+            require.Equal(t, exp.Strategy, cl.Strategy)
+            require.Equal(t, maxConns, cl.ServerMaxConns)
+        })
+
+        t.Run("PercentageDisabled", func(t *testing.T) {
+            dbCfg := conf.DBConfiguration{
+                ConnPercentage:  0,
+                MaxPoolSize:     50,
+                MaxIdlePoolSize: 50,
+                ConnMaxIdleTime: time.Second * 60,
+                ConnMaxLifetime: 0,
+            }
+            exp := ConnLimits{
+                MaxOpenConns:    50,
+                MaxIdleConns:    50,
+                ConnMaxIdleTime: time.Second * 60,
+                ConnMaxLifetime: 0,
+                Strategy:        connLimitsFixedStrategy,
+            }
+
+            cl, err := db.getConnLimits(ctx, &dbCfg)
+            require.NoError(t, err)
+            require.NotNil(t, cl)
+
+            require.Equal(t, exp.MaxOpenConns, cl.MaxOpenConns)
+            require.Equal(t, exp.MaxIdleConns, cl.MaxIdleConns)
+            require.Equal(t, exp.ConnMaxLifetime, cl.ConnMaxLifetime)
+            require.Equal(t, exp.ConnMaxIdleTime, cl.ConnMaxIdleTime)
+            require.Equal(t, exp.Strategy, cl.Strategy)
+            require.Equal(t, maxConns, cl.ServerMaxConns)
+        })
+    })
+
+    t.Run("applyPercentageLimits", func(t *testing.T) {
+        config := mustConfig(t)
+        config.DB.MaxPoolSize = 50
+
+        db, err := DialContext(ctx, config)
+        require.NoError(t, err)
+        require.NotNil(t, db)
+        defer db.Close()
+
+        const maxConns = 100
+        {
+            serverMaxConns, err := db.showMaxConns(ctx)
+            require.NoError(t, err)
+
+            // 100 is the current default in local testing; we rely on it for
+            // unit tests.
+            require.True(t, serverMaxConns == maxConns)
+        }
+
+        type testCase struct {
+            desc     string
+            maxConns int
+            cfg      conf.DBConfiguration
+            exp      ConnLimits
+            err      string
+        }
+        tests := []testCase{
+
+            {
+                desc:     "exp fallback to max pool size",
+                maxConns: maxConns,
+                cfg: conf.DBConfiguration{
+                    ConnPercentage:  0,
+                    MaxPoolSize:     50,
+                    MaxIdlePoolSize: 50,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                },
+                exp: ConnLimits{
+                    MaxOpenConns:    50,
+                    MaxIdleConns:    50,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                    ServerMaxConns:  maxConns,
+                    Strategy:        connLimitsFixedStrategy,
+                },
+            },
+
+            {
+                desc:     "exp conn pct to take precedence over max pool size",
+                maxConns: maxConns,
+                cfg: conf.DBConfiguration{
+                    ConnPercentage:  30,
+                    MaxPoolSize:     50,
+                    MaxIdlePoolSize: 50,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                },
+                exp: ConnLimits{
+                    MaxOpenConns:    30,
+                    MaxIdleConns:    30,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                    ServerMaxConns:  maxConns,
+                    Strategy:        connLimitsPercentageStrategy,
+                },
+            },
+
+            {
+                desc:     "exp conn pct to ignore fixed values",
+                maxConns: maxConns,
+                cfg: conf.DBConfiguration{
+                    ConnPercentage:  30,
+                    MaxPoolSize:     0,
+                    MaxIdlePoolSize: 0,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                },
+                exp: ConnLimits{
+                    MaxOpenConns:    30,
+                    MaxIdleConns:    30,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                    ServerMaxConns:  maxConns,
+                    Strategy:        connLimitsPercentageStrategy,
+                },
+            },
+
+            {
+                desc:     "exp conn pct to not be set lower than 1 for small max conns",
+                maxConns: 4,
+                cfg: conf.DBConfiguration{
+                    ConnPercentage:  10,
+                    MaxPoolSize:     50,
+                    MaxIdlePoolSize: 50,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                },
+                exp: ConnLimits{
+                    MaxOpenConns:    1,
+                    MaxIdleConns:    1,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                    ServerMaxConns:  4,
+                    Strategy:        connLimitsPercentageStrategy,
+                },
+            },
+
+            {
+                desc:     "exp error for a negative conn percentage",
+                err:      "percentage must be between 1 and 100",
+                maxConns: maxConns,
+                cfg: conf.DBConfiguration{
+                    ConnPercentage:  -1,
+                    MaxPoolSize:     50,
+                    MaxIdlePoolSize: 25,
+                },
+                exp: ConnLimits{
+                    MaxOpenConns:   50,
+                    MaxIdleConns:   50,
+                    ServerMaxConns: maxConns,
+                    Strategy:       connLimitsErrorStrategy,
+                },
+            },
+
+            {
+                desc:     "exp error when the db reports 0 max connections",
+                err:      "db reported a maximum of 0 connections",
+                maxConns: 0,
+                cfg: conf.DBConfiguration{
+                    ConnPercentage:  30,
+                    MaxPoolSize:     50,
+                    MaxIdlePoolSize: 50,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                },
+                exp: ConnLimits{
+                    MaxOpenConns:    50,
+                    MaxIdleConns:    50,
+                    ConnMaxIdleTime: time.Second * 60,
+                    ConnMaxLifetime: 0,
+                    ServerMaxConns:  0,
+                    Strategy:        connLimitsErrorStrategy,
+                },
+            },
+        }
+
+        tcStr := func(tc testCase) string {
+            str := fmt.Sprintf("%v when server maxConns is %d", tc.desc, tc.maxConns)
+            str += fmt.Sprintf(" and cfg(pct: %v max: %v)",
+                tc.cfg.ConnPercentage, tc.cfg.MaxPoolSize)
+            str += fmt.Sprintf(" exp %v", tc.exp)
+            return str
+        }
+
+        for idx, tc := range tests {
+            t.Logf("test #%v - %v", idx, tcStr(tc))
+
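+            // Mirror getConnLimits: build the limits from the fixed config
+            // values first, then let applyPercentageLimits override them.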
+            dbCfg := &tc.cfg
+            cl := newConnLimitsFromConfig(dbCfg)
+
+            err := db.applyPercentageLimits(dbCfg, tc.maxConns, cl)
+            if tc.err != "" {
+                require.Error(t, err)
+                require.Contains(t, err.Error(), tc.err)
+                continue
+            }
+            require.NoError(t, err)
+            require.NotNil(t, cl)
+
+            require.Equal(t, tc.exp.MaxOpenConns, cl.MaxOpenConns)
+            require.Equal(t, tc.exp.MaxIdleConns, cl.MaxIdleConns)
+            require.Equal(t, tc.exp.ConnMaxLifetime, cl.ConnMaxLifetime)
+            require.Equal(t, tc.exp.ConnMaxIdleTime, cl.ConnMaxIdleTime)
+            require.Equal(t, tc.exp.Strategy, cl.Strategy)
+            require.Equal(t, tc.exp.ServerMaxConns, cl.ServerMaxConns)
+        }
+    })
+}
+
+func mustConfig(t *testing.T) *conf.GlobalConfiguration {
+    apiTestConfig := "../../hack/test.env"
+    config, err := conf.LoadGlobal(apiTestConfig)
+    require.NoError(t, err)
+
+    config.Tracing.Enabled = true
+    config.Metrics.Enabled = true
+
+    dbCfg := &conf.DBConfiguration{
+        Driver:    config.DB.Driver,
+        URL:       config.DB.URL,
+        Namespace: config.DB.Namespace,
+
+        ConnPercentage:  0,
+        MaxPoolSize:     50,
+        MaxIdlePoolSize: 50,
+        ConnMaxIdleTime: time.Second * 60,
+        ConnMaxLifetime: 0,
+
+        HealthCheckPeriod: 0,
+        CleanupEnabled:    false,
+        MigrationsPath:    config.DB.MigrationsPath,
+    }
+    config.DB = *dbCfg
+    return config
+}