Skip to content

Commit 45a92ba

Browse files
Matthias RabeGreg Weber
authored andcommitted
UnsafeLogged added to report missing fields
Signed-off-by: Matthias Rabe <[email protected]>
1 parent 28212d4 commit 45a92ba

File tree

1 file changed

+96
-17
lines changed

1 file changed

+96
-17
lines changed

sqlx.go

Lines changed: 96 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"fmt"
88

9+
"io"
910
"io/ioutil"
1011
"path/filepath"
1112
"reflect"
@@ -144,6 +145,43 @@ func isUnsafe(i interface{}) bool {
144145
}
145146
}
146147

148+
func logFor(i interface{}) io.Writer {
149+
switch v := i.(type) {
150+
case Row:
151+
return v.log
152+
case *Row:
153+
return v.log
154+
case Rows:
155+
return v.log
156+
case *Rows:
157+
return v.log
158+
case NamedStmt:
159+
return v.Stmt.log
160+
case *NamedStmt:
161+
return v.Stmt.log
162+
case Stmt:
163+
return v.log
164+
case *Stmt:
165+
return v.log
166+
case qStmt:
167+
return v.log
168+
case *qStmt:
169+
return v.log
170+
case DB:
171+
return v.log
172+
case *DB:
173+
return v.log
174+
case Tx:
175+
return v.log
176+
case *Tx:
177+
return v.log
178+
case sql.Rows, *sql.Rows:
179+
return nil
180+
default:
181+
return nil
182+
}
183+
}
184+
147185
func mapperFor(i interface{}) *reflectx.Mapper {
148186
switch i := i.(type) {
149187
case DB:
@@ -167,6 +205,7 @@ var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
167205
type Row struct {
168206
err error
169207
unsafe bool
208+
log io.Writer
170209
rows *sql.Rows
171210
Mapper *reflectx.Mapper
172211
}
@@ -243,6 +282,7 @@ type DB struct {
243282
*sql.DB
244283
driverName string
245284
unsafe bool
285+
log io.Writer
246286
Mapper *reflectx.Mapper
247287
}
248288

@@ -291,7 +331,16 @@ func (db *DB) Rebind(query string) string {
291331
// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its
292332
// safety behavior.
293333
func (db *DB) Unsafe() *DB {
294-
return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper}
334+
return db.UnsafeLogged(nil)
335+
}
336+
337+
// Like Unsafe, UnsafeLogged returns a version of DB which will succeed to scan
338+
// when columns in the SQL result have no fields in the destination struct.
339+
// But unlike Unsafe(), this will write a short log, if it does.
340+
// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its
341+
// safety behavior.
342+
func (db *DB) UnsafeLogged(log io.Writer) *DB {
343+
return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, log: log, Mapper: db.Mapper}
295344
}
296345

297346
// BindNamed binds a query using the DB driver's bindvar type.
@@ -340,7 +389,7 @@ func (db *DB) Beginx() (*Tx, error) {
340389
if err != nil {
341390
return nil, err
342391
}
343-
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
392+
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, log: db.log, Mapper: db.Mapper}, err
344393
}
345394

346395
// Queryx queries the database and returns an *sqlx.Rows.
@@ -350,14 +399,14 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
350399
if err != nil {
351400
return nil, err
352401
}
353-
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
402+
return &Rows{Rows: r, unsafe: db.unsafe, log: db.log, Mapper: db.Mapper}, err
354403
}
355404

356405
// QueryRowx queries the database and returns an *sqlx.Row.
357406
// Any placeholder parameters are replaced with supplied args.
358407
func (db *DB) QueryRowx(query string, args ...interface{}) *Row {
359408
rows, err := db.DB.Query(query, args...)
360-
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
409+
return &Row{rows: rows, err: err, unsafe: db.unsafe, log: db.log, Mapper: db.Mapper}
361410
}
362411

363412
// MustExec (panic) runs MustExec using this database.
@@ -381,6 +430,7 @@ type Conn struct {
381430
*sql.Conn
382431
driverName string
383432
unsafe bool
433+
log io.Writer
384434
Mapper *reflectx.Mapper
385435
}
386436

@@ -389,6 +439,7 @@ type Tx struct {
389439
*sql.Tx
390440
driverName string
391441
unsafe bool
442+
log io.Writer
392443
Mapper *reflectx.Mapper
393444
}
394445

@@ -405,7 +456,14 @@ func (tx *Tx) Rebind(query string) string {
405456
// Unsafe returns a version of Tx which will silently succeed to scan when
406457
// columns in the SQL result have no fields in the destination struct.
407458
func (tx *Tx) Unsafe() *Tx {
408-
return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper}
459+
return tx.UnsafeLogged(nil)
460+
}
461+
462+
// Like Unsafe, UnsafeLogged returns a version of Tx which will succeed to
463+
// scan when columns in the SQL result have no fields in the destination struct.
464+
// But unlike Unsafe(), this will write a short log, if it does.
465+
func (tx *Tx) UnsafeLogged(log io.Writer) *Tx {
466+
return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, log: log, Mapper: tx.Mapper}
409467
}
410468

411469
// BindNamed binds a query within a transaction's bindvar type.
@@ -438,14 +496,14 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
438496
if err != nil {
439497
return nil, err
440498
}
441-
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
499+
return &Rows{Rows: r, unsafe: tx.unsafe, log: tx.log, Mapper: tx.Mapper}, err
442500
}
443501

444502
// QueryRowx within a transaction.
445503
// Any placeholder parameters are replaced with supplied args.
446504
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
447505
rows, err := tx.Tx.Query(query, args...)
448-
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
506+
return &Row{rows: rows, err: err, unsafe: tx.unsafe, log: tx.log, Mapper: tx.Mapper}
449507
}
450508

451509
// Get within a transaction.
@@ -501,13 +559,21 @@ func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) {
501559
type Stmt struct {
502560
*sql.Stmt
503561
unsafe bool
562+
log io.Writer
504563
Mapper *reflectx.Mapper
505564
}
506565

507566
// Unsafe returns a version of Stmt which will silently succeed to scan when
508567
// columns in the SQL result have no fields in the destination struct.
509568
func (s *Stmt) Unsafe() *Stmt {
510-
return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper}
569+
return s.UnsafeLogged(nil)
570+
}
571+
572+
// Like Unsafe, UnsafeLogged returns a version of Stmt which will succeed to
573+
// scan when columns in the SQL result have no fields in the destination struct.
574+
// But unlike Unsafe(), this will write a short log, if it does.
575+
func (s *Stmt) UnsafeLogged(log io.Writer) *Stmt {
576+
return &Stmt{Stmt: s.Stmt, unsafe: true, log: log, Mapper: s.Mapper}
511577
}
512578

513579
// Select using the prepared statement.
@@ -557,12 +623,12 @@ func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) {
557623
if err != nil {
558624
return nil, err
559625
}
560-
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
626+
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, log: q.Stmt.log, Mapper: q.Stmt.Mapper}, err
561627
}
562628

563629
func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row {
564630
rows, err := q.Stmt.Query(args...)
565-
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
631+
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, log: q.Stmt.log, Mapper: q.Stmt.Mapper}
566632
}
567633

568634
func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
@@ -574,6 +640,7 @@ func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
574640
type Rows struct {
575641
*sql.Rows
576642
unsafe bool
643+
log io.Writer
577644
Mapper *reflectx.Mapper
578645
// these fields cache memory use for a rows during iteration w/ structScan
579646
started bool
@@ -614,8 +681,12 @@ func (r *Rows) StructScan(dest interface{}) error {
614681

615682
r.fields = m.TraversalsByName(v.Type(), columns)
616683
// if we are not unsafe and are missing fields, return an error
617-
if f, err := missingFields(r.fields); err != nil && !r.unsafe {
618-
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
684+
if f, err := missingFields(r.fields); err != nil {
685+
if !r.unsafe {
686+
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
687+
} else if r.log != nil {
688+
fmt.Fprintf(r.log, "missing destination name %s in %T\n", columns[f], dest)
689+
}
619690
}
620691
r.values = make([]interface{}, len(columns))
621692
r.started = true
@@ -662,7 +733,7 @@ func Preparex(p Preparer, query string) (*Stmt, error) {
662733
if err != nil {
663734
return nil, err
664735
}
665-
return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
736+
return &Stmt{Stmt: s, unsafe: isUnsafe(p), log: logFor(p), Mapper: mapperFor(p)}, err
666737
}
667738

668739
// Select executes a query using the provided Queryer, and StructScans each row
@@ -776,8 +847,12 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
776847

777848
fields := m.TraversalsByName(v.Type(), columns)
778849
// if we are not unsafe and are missing fields, return an error
779-
if f, err := missingFields(fields); err != nil && !r.unsafe {
780-
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
850+
if f, err := missingFields(fields); err != nil {
851+
if !r.unsafe {
852+
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
853+
} else if r.log != nil {
854+
fmt.Fprintf(r.log, "missing destination name %s in %T\n", columns[f], dest)
855+
}
781856
}
782857
values := make([]interface{}, len(columns))
783858

@@ -944,8 +1019,12 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
9441019

9451020
fields := m.TraversalsByName(base, columns)
9461021
// if we are not unsafe and are missing fields, return an error
947-
if f, err := missingFields(fields); err != nil && !isUnsafe(rows) {
948-
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
1022+
if f, err := missingFields(fields); err != nil {
1023+
if !isUnsafe(rows) {
1024+
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
1025+
} else if log := logFor(rows); log != nil {
1026+
fmt.Fprintf(log, "missing destination name %s in %T\n", columns[f], dest)
1027+
}
9491028
}
9501029
values = make([]interface{}, len(columns))
9511030

0 commit comments

Comments
 (0)