Skip to content

Commit fdabdfb

Browse files
committed
Introduce interface for optimistic concurrency check.
1 parent 70cf75c commit fdabdfb

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

contracts.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import (
66
"errors"
77
)
88

9-
var ErrParameterCountMismatch = errors.New("the number of parameters supplied does not match the statement")
9+
var (
10+
ErrParameterCountMismatch = errors.New("the number of parameters supplied does not match the statement")
11+
ErrOptimisticConcurrencyCheckFailed = errors.New("optimistic concurrency check failed")
12+
)
1013

1114
type (
1215
logger interface {
@@ -48,6 +51,14 @@ type (
4851
RowsAffected(uint64)
4952
}
5053

54+
// OptimisticConcurrencyCheck provides an (optional) hook for a type implementing
55+
// Script to verify whether the total count of all rows affected matches the returned
56+
// value. If not, an error wrapped with ErrOptimisticConcurrencyCheckFailed will be
57+
// returned by Handle.Execute().
58+
OptimisticConcurrencyCheck interface {
59+
ExpectedRowsAffected() uint64
60+
}
61+
5162
// Query represents a SQL statement that is expected to provide rows as a result.
5263
// Rows are provided to the Scan method.
5364
Query interface {

db.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ func (this *defaultHandle) Execute(ctx context.Context, scripts ...Script) (err
6969
if placeholderCount != len(parameters) {
7070
return fmt.Errorf("%w: Expected: %d, received %d", ErrParameterCountMismatch, placeholderCount, len(parameters))
7171
}
72+
var actualRowsAffectedCount uint64
7273
for statement, params := range interleaveParameters(statements, parameters...) {
7374
prepared, err := this.prepare(ctx, statement)
7475
if err != nil {
@@ -83,10 +84,22 @@ func (this *defaultHandle) Execute(ctx context.Context, scripts ...Script) (err
8384
if err != nil {
8485
return err
8586
}
87+
affected, err := result.RowsAffected()
88+
if err != nil {
89+
return err
90+
}
91+
rowCount := uint64(max(0, affected))
92+
actualRowsAffectedCount += rowCount
8693
if rows, ok := script.(RowsAffected); ok {
87-
if affected, err := result.RowsAffected(); err == nil {
88-
rows.RowsAffected(uint64(affected))
89-
}
94+
rows.RowsAffected(rowCount)
95+
}
96+
}
97+
if check, ok := script.(OptimisticConcurrencyCheck); ok {
98+
expectedRowsAffectedCount := check.ExpectedRowsAffected()
99+
if actualRowsAffectedCount != expectedRowsAffectedCount {
100+
return fmt.Errorf("%w: expected rows affected: %d (actual: %d)",
101+
ErrOptimisticConcurrencyCheckFailed, expectedRowsAffectedCount, actualRowsAffectedCount,
102+
)
90103
}
91104
}
92105
}

integration/db_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ func (this *Fixture) Teardown() {
4646
this.So(this.db.Close(), better.BeNil)
4747
}
4848

49+
func (this *Fixture) TestScript() {
50+
script := &InsertOne{name: "foo", expectedRowsAffected: 1}
51+
err := this.DB.Execute(this.Context(), script)
52+
this.So(err, better.BeNil)
53+
row := this.db.QueryRowContext(this.Context(), "select name from sqldb_integration_test where name = 'foo';")
54+
var foo string
55+
err = row.Scan(&foo)
56+
this.So(err, better.BeNil)
57+
this.So(foo, should.Equal, "foo")
58+
}
59+
func (this *Fixture) TestScript_OptimisticConcurrencyCheckFailure() {
60+
script := &InsertOne{name: "foo", expectedRowsAffected: 2}
61+
err := this.DB.Execute(this.Context(), script)
62+
this.So(err, should.WrapError, sqldb.ErrOptimisticConcurrencyCheckFailed)
63+
}
4964
func (this *Fixture) TestQuery() {
5065
for range 10 { // should transition to prepared statements
5166
query := &SelectAll{Result: make(map[int]string)}
@@ -74,6 +89,23 @@ func (this *Fixture) TestQueryQueryRow_NoResult() {
7489

7590
///////////////////////////////////////////////
7691

92+
type InsertOne struct {
93+
expectedRowsAffected uint64
94+
name string
95+
}
96+
97+
func (this *InsertOne) Statements() string {
98+
return "INSERT INTO sqldb_integration_test (name) VALUES (?);"
99+
}
100+
func (this *InsertOne) Parameters() []any {
101+
return []any{this.name}
102+
}
103+
func (this *InsertOne) ExpectedRowsAffected() uint64 {
104+
return this.expectedRowsAffected
105+
}
106+
107+
///////////////////////////////////////////////
108+
77109
type DDL struct {
78110
totalRows uint64
79111
}

0 commit comments

Comments
 (0)