| 
 | 1 | +package sqldb  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"context"  | 
 | 5 | +	"database/sql"  | 
 | 6 | +	"errors"  | 
 | 7 | +	"fmt"  | 
 | 8 | +	"runtime/debug"  | 
 | 9 | +	"strings"  | 
 | 10 | +)  | 
 | 11 | + | 
 | 12 | +var ErrArgumentCountMismatch = errors.New("the number of arguments supplied does not match the statement")  | 
 | 13 | + | 
 | 14 | +type Binder func(Scanner) error  | 
 | 15 | + | 
 | 16 | +type Scanner interface {  | 
 | 17 | +	Scan(...any) error  | 
 | 18 | +}  | 
 | 19 | + | 
 | 20 | +// DBTx is either a *sql.DB or a *sql.Tx  | 
 | 21 | +type DBTx interface {  | 
 | 22 | +	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)  | 
 | 23 | +	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)  | 
 | 24 | +	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)  | 
 | 25 | +	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row  | 
 | 26 | +}  | 
 | 27 | + | 
 | 28 | +// BindAll receives the *sql.Rows + error from the QueryContext method of either  | 
 | 29 | +// a *sql.DB, a *sql.Tx, or a *sql.Stmt, as well as a binder callback, to be called  | 
 | 30 | +// for each record, which gives the caller the opportunity to scan and aggregate values.  | 
 | 31 | +func BindAll(rows *sql.Rows, err error, binder Binder) error {  | 
 | 32 | +	if err != nil {  | 
 | 33 | +		return normalize(err)  | 
 | 34 | +	}  | 
 | 35 | +	defer func() { _ = rows.Close() }()  | 
 | 36 | +	for rows.Next() {  | 
 | 37 | +		err = binder(rows)  | 
 | 38 | +		if err != nil {  | 
 | 39 | +			return normalize(err)  | 
 | 40 | +		}  | 
 | 41 | +	}  | 
 | 42 | +	return nil  | 
 | 43 | +}  | 
 | 44 | + | 
 | 45 | +// ExecuteStatements receives a *sql.DB or *sql.Tx as well as one or more SQL statements (separated by ';')  | 
 | 46 | +// and executes each one with the arguments corresponding to that statement.  | 
 | 47 | +func ExecuteStatements(ctx context.Context, db DBTx, statements string, args ...any) (uint64, error) {  | 
 | 48 | +	placeholderCount := strings.Count(statements, "?")  | 
 | 49 | +	if placeholderCount != len(args) {  | 
 | 50 | +		return 0, fmt.Errorf("%w: Expected: %d, received %d", ErrArgumentCountMismatch, placeholderCount, len(args))  | 
 | 51 | +	}  | 
 | 52 | +	var count uint64  | 
 | 53 | +	index := 0  | 
 | 54 | +	for _, statement := range strings.Split(statements, ";") {  | 
 | 55 | +		if len(strings.TrimSpace(statement)) == 0 {  | 
 | 56 | +			continue  | 
 | 57 | +		}  | 
 | 58 | +		statement += ";" // terminate the statement  | 
 | 59 | +		indexOffset := strings.Count(statement, "?")  | 
 | 60 | +		result, err := db.ExecContext(ctx, statement, args[index:index+indexOffset]...)  | 
 | 61 | +		rows, err := RowsAffected(result, err)  | 
 | 62 | +		if err != nil {  | 
 | 63 | +			return 0, err // already normalized  | 
 | 64 | +		}  | 
 | 65 | +		count += rows  | 
 | 66 | +		index += indexOffset  | 
 | 67 | +	}  | 
 | 68 | +	return count, nil  | 
 | 69 | +}  | 
 | 70 | + | 
 | 71 | +// RowsAffected returns the rows affected from a sql.Result. This is generally only needed  | 
 | 72 | +// by external callers when dealing with the result of a prepared statement.  | 
 | 73 | +func RowsAffected(result sql.Result, err error) (uint64, error) {  | 
 | 74 | +	if err != nil {  | 
 | 75 | +		return 0, normalize(err)  | 
 | 76 | +	}  | 
 | 77 | +	rows, err := result.RowsAffected()  | 
 | 78 | +	if err != nil {  | 
 | 79 | +		return 0, normalize(err)  | 
 | 80 | +	}  | 
 | 81 | +	return uint64(rows), nil  | 
 | 82 | +}  | 
 | 83 | +func normalize(err error) error {  | 
 | 84 | +	if err == nil {  | 
 | 85 | +		return nil  | 
 | 86 | +	}  | 
 | 87 | +	err = fmt.Errorf("%w\nStack Trace:\n%s", err, string(debug.Stack()))  | 
 | 88 | +	if strings.Contains(err.Error(), "operation was canceled") {  | 
 | 89 | +		return fmt.Errorf("%w: %w", context.Canceled, err)  | 
 | 90 | +	}  | 
 | 91 | +	return err  | 
 | 92 | +}  | 
0 commit comments