Skip to content

Commit 87d1a7a

Browse files
committed
Variadic methods.
1 parent b8b68ca commit 87d1a7a

File tree

2 files changed

+66
-55
lines changed

2 files changed

+66
-55
lines changed

contracts.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ type (
2525
// Handle is a high level approach to common database operations, where each operation implements either
2626
// the Query or Script interface.
2727
Handle interface {
28-
Execute(ctx context.Context, script Script) error
29-
Populate(ctx context.Context, query Query) error
30-
PopulateRow(ctx context.Context, query Query) error
28+
Execute(context.Context, ...Script) error
29+
Populate(context.Context, ...Query) error
30+
PopulateRow(context.Context, ...Query) error
3131
}
3232

3333
// Script represents SQL statements that aren't expected to provide rows as a result.

db.go

Lines changed: 63 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -60,81 +60,92 @@ func checksum(x []byte) (hash uint64) {
6060
return h.Sum64()
6161
}
6262

63-
func (this *defaultHandle) Execute(ctx context.Context, script Script) (err error) {
63+
func (this *defaultHandle) Execute(ctx context.Context, scripts ...Script) (err error) {
6464
defer func() { err = normalizeErr(err) }()
65-
statements := script.Statements()
66-
parameters := script.Parameters()
67-
placeholderCount := strings.Count(statements, "?")
68-
if placeholderCount != len(parameters) {
69-
return fmt.Errorf("%w: Expected: %d, received %d", ErrParameterCountMismatch, placeholderCount, len(parameters))
65+
for _, script := range scripts {
66+
statements := script.Statements()
67+
parameters := script.Parameters()
68+
placeholderCount := strings.Count(statements, "?")
69+
if placeholderCount != len(parameters) {
70+
return fmt.Errorf("%w: Expected: %d, received %d", ErrParameterCountMismatch, placeholderCount, len(parameters))
71+
}
72+
for statement, params := range interleaveParameters(statements, parameters...) {
73+
prepared, err := this.prepare(ctx, statement)
74+
if err != nil {
75+
return err
76+
}
77+
var result sql.Result
78+
if prepared != nil {
79+
result, err = prepared.ExecContext(ctx, params...)
80+
} else {
81+
result, err = this.pool.ExecContext(ctx, statement, params...)
82+
}
83+
if err != nil {
84+
return err
85+
}
86+
if rows, ok := script.(RowsAffected); ok {
87+
if affected, err := result.RowsAffected(); err == nil {
88+
rows.RowsAffected(uint64(affected))
89+
}
90+
}
91+
}
7092
}
71-
for statement, params := range interleaveParameters(statements, parameters...) {
93+
return nil
94+
}
95+
func (this *defaultHandle) Populate(ctx context.Context, queries ...Query) (err error) {
96+
defer func() { err = normalizeErr(err) }()
97+
for _, query := range queries {
98+
statement := query.Statement()
7299
prepared, err := this.prepare(ctx, statement)
73100
if err != nil {
74101
return err
75102
}
76-
var result sql.Result
103+
parameters := query.Parameters()
104+
var rows *sql.Rows
77105
if prepared != nil {
78-
result, err = prepared.ExecContext(ctx, params...)
106+
rows, err = prepared.QueryContext(ctx, parameters...)
79107
} else {
80-
result, err = this.pool.ExecContext(ctx, statement, params...)
108+
rows, err = this.pool.QueryContext(ctx, statement, parameters...)
81109
}
82110
if err != nil {
83111
return err
84112
}
85-
if rows, ok := script.(RowsAffected); ok {
86-
if affected, err := result.RowsAffected(); err == nil {
87-
rows.RowsAffected(uint64(affected))
113+
for rows.Next() {
114+
err = query.Scan(rows)
115+
if err != nil {
116+
_ = rows.Close()
117+
return err
88118
}
89119
}
120+
_ = rows.Close()
90121
}
91122
return nil
92123
}
93-
func (this *defaultHandle) Populate(ctx context.Context, query Query) (err error) {
124+
func (this *defaultHandle) PopulateRow(ctx context.Context, queries ...Query) (err error) {
94125
defer func() { err = normalizeErr(err) }()
95-
statement := query.Statement()
96-
prepared, err := this.prepare(ctx, statement)
97-
if err != nil {
98-
return err
99-
}
100-
parameters := query.Parameters()
101-
var rows *sql.Rows
102-
if prepared != nil {
103-
rows, err = prepared.QueryContext(ctx, parameters...)
104-
} else {
105-
rows, err = this.pool.QueryContext(ctx, statement, parameters...)
106-
}
107-
if err != nil {
108-
return err
109-
}
110-
defer func() { _ = rows.Close() }()
111-
for rows.Next() {
112-
err = query.Scan(rows)
126+
for _, query := range queries {
127+
statement := query.Statement()
128+
prepared, err := this.prepare(ctx, statement)
113129
if err != nil {
114130
return err
115131
}
116-
}
117-
return rows.Err()
118-
}
119-
func (this *defaultHandle) PopulateRow(ctx context.Context, query Query) (err error) {
120-
defer func() { err = normalizeErr(err) }()
121-
statement := query.Statement()
122-
prepared, err := this.prepare(ctx, statement)
123-
if err != nil {
132+
parameters := query.Parameters()
133+
var row *sql.Row
134+
if prepared != nil {
135+
row = prepared.QueryRowContext(ctx, parameters...)
136+
} else {
137+
row = this.pool.QueryRowContext(ctx, statement, parameters...)
138+
}
139+
err = query.Scan(row)
140+
if err == nil {
141+
continue
142+
}
143+
if errors.Is(err, sql.ErrNoRows) {
144+
continue
145+
}
124146
return err
125147
}
126-
parameters := query.Parameters()
127-
var row *sql.Row
128-
if prepared != nil {
129-
row = prepared.QueryRowContext(ctx, parameters...)
130-
} else {
131-
row = this.pool.QueryRowContext(ctx, statement, parameters...)
132-
}
133-
err = query.Scan(row)
134-
if errors.Is(err, sql.ErrNoRows) {
135-
return nil
136-
}
137-
return err
148+
return nil
138149
}
139150

140151
// interleaveParameters splits the statements (on ';') and yields each with its corresponding parameters.

0 commit comments

Comments
 (0)