Skip to content

Commit 7295811

Browse files
committed
v3
Implement the spirit of previous versions in a single file with just a few simple functions. Adds inherent support for prepared statements, as well as QueryRowContext. It is very plausible that this module, which is now less than 100 lines of code, is a good candidate for copy-paste installation.
1 parent c42503d commit 7295811

File tree

3 files changed

+172
-1
lines changed

3 files changed

+172
-1
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
module github.com/smarty/sqldb/v2
1+
module github.com/smarty/sqldb/v3
22

33
go 1.18

sqldb.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package sqldb
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"fmt"
8+
"runtime/debug"
9+
"strings"
10+
)
11+
12+
var ErrArgumentCountMismatch = errors.New("the number of arguments supplied does not match the statement")
13+
14+
type Binder func(Scanner) error
15+
16+
type Scanner interface {
17+
Scan(...any) error
18+
}
19+
20+
// DBTx is either a *sql.DB or a *sql.Tx
21+
type DBTx interface {
22+
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
23+
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
24+
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
25+
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
26+
}
27+
28+
// BindAll receives the *sql.Rows + error from the QueryContext method of either
29+
// a *sql.DB, a *sql.Tx, or a *sql.Stmt, as well as a binder callback, to be called
30+
// for each record, which gives the caller the opportunity to scan and aggregate values.
31+
func BindAll(rows *sql.Rows, err error, binder Binder) error {
32+
if err != nil {
33+
return normalize(err)
34+
}
35+
defer func() { _ = rows.Close() }()
36+
for rows.Next() {
37+
err = binder(rows)
38+
if err != nil {
39+
return normalize(err)
40+
}
41+
}
42+
return nil
43+
}
44+
45+
// ExecuteStatements receives a *sql.DB or *sql.Tx as well as one or more SQL statements (separated by ';')
46+
// and executes each one with the arguments corresponding to that statement.
47+
func ExecuteStatements(ctx context.Context, db DBTx, statements string, args ...any) (uint64, error) {
48+
placeholderCount := strings.Count(statements, "?")
49+
if placeholderCount != len(args) {
50+
return 0, fmt.Errorf("%w: Expected: %d, received %d", ErrArgumentCountMismatch, placeholderCount, len(args))
51+
}
52+
var count uint64
53+
index := 0
54+
for _, statement := range strings.Split(statements, ";") {
55+
if len(strings.TrimSpace(statement)) == 0 {
56+
continue
57+
}
58+
statement += ";" // terminate the statement
59+
indexOffset := strings.Count(statement, "?")
60+
result, err := db.ExecContext(ctx, statement, args[index:index+indexOffset]...)
61+
rows, err := RowsAffected(result, err)
62+
if err != nil {
63+
return 0, err // already normalized
64+
}
65+
count += rows
66+
index += indexOffset
67+
}
68+
return count, nil
69+
}
70+
71+
// RowsAffected returns the rows affected from a sql.Result. This is generally only needed
72+
// by external callers when dealing with the result of a prepared statement.
73+
func RowsAffected(result sql.Result, err error) (uint64, error) {
74+
if err != nil {
75+
return 0, normalize(err)
76+
}
77+
rows, err := result.RowsAffected()
78+
if err != nil {
79+
return 0, normalize(err)
80+
}
81+
return uint64(rows), nil
82+
}
83+
func normalize(err error) error {
84+
if err == nil {
85+
return nil
86+
}
87+
err = fmt.Errorf("%w\nStack Trace:\n%s", err, string(debug.Stack()))
88+
if strings.Contains(err.Error(), "operation was canceled") {
89+
return fmt.Errorf("%w: %w", context.Canceled, err)
90+
}
91+
return err
92+
}

sqldb_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package sqldb_test
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/smarty/sqldb/v3"
8+
)
9+
10+
// Business Event:
11+
type FooEstablished struct {
12+
Timestamp time.Time
13+
FooID uint64
14+
}
15+
16+
// Storage Operation:
17+
type LoadFooName struct {
18+
FooID uint64
19+
Result struct {
20+
FooID uint64
21+
FooName string
22+
}
23+
}
24+
25+
type Mapper struct {
26+
db sqldb.DBTx
27+
}
28+
29+
func (this Mapper) fooEstablished(ctx context.Context, operation FooEstablished) (uint64, error) {
30+
return sqldb.ExecuteStatements(ctx, this.db, `
31+
INSERT
32+
INTO Foos
33+
( foo_id, created, foo_name )
34+
VALUES ( ?, ?, '' )
35+
ON DUPLICATE KEY
36+
UPDATE created = created;`,
37+
operation.FooID, operation.Timestamp,
38+
)
39+
}
40+
func (this Mapper) fooEstablished_Prepared(ctx context.Context, operation FooEstablished) (uint64, error) {
41+
statement, err := this.db.PrepareContext(ctx, `
42+
INSERT
43+
INTO Foos
44+
( foo_id, created, foo_name )
45+
VALUES ( ?, ?, '' )
46+
ON DUPLICATE KEY
47+
UPDATE created = created;`,
48+
)
49+
if err != nil {
50+
return 0, err
51+
}
52+
return sqldb.RowsAffected(statement.ExecContext(ctx, operation.FooID, operation.Timestamp))
53+
}
54+
55+
func (this Mapper) loadFooName(ctx context.Context, operation *LoadFooName) error {
56+
rows, err := this.db.QueryContext(ctx, `
57+
SELECT foo_id, foo_name
58+
FROM Foos
59+
WHERE foo_id = ?;`,
60+
operation.FooID,
61+
)
62+
return sqldb.BindAll(rows, err, func(source sqldb.Scanner) error {
63+
return source.Scan(&operation.Result.FooID, &operation.Result.FooName)
64+
})
65+
}
66+
func (this Mapper) loadFooName_Prepared(ctx context.Context, operation *LoadFooName) error {
67+
statement, err := this.db.PrepareContext(ctx, `
68+
SELECT foo_id, foo_name
69+
FROM Foos
70+
WHERE foo_id = ?;`,
71+
)
72+
if err != nil {
73+
return err
74+
}
75+
rows, err := statement.QueryContext(ctx, operation.FooID)
76+
return sqldb.BindAll(rows, err, func(source sqldb.Scanner) error {
77+
return source.Scan(&operation.Result.FooID, &operation.Result.FooName)
78+
})
79+
}

0 commit comments

Comments
 (0)