Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion named.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ package sqlx
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
//
import (
"bytes"
"database/sql"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"unicode"

Expand Down Expand Up @@ -206,6 +208,50 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper)
return bound, arglist, nil
}

var (
EndBracketsReg = regexp.MustCompile(`\([^()]*\)\s*$`)
)

func fixBound(bound string, loop int) string {
endBrackets := EndBracketsReg.FindString(bound)
if endBrackets == "" {
return bound
}
var buffer bytes.Buffer
buffer.WriteString(bound)
for i := 0; i < loop-1; i++ {
buffer.WriteString(",")
buffer.WriteString(endBrackets)
}
return buffer.String()
}

// bindArray binds a named parameter query with fields from an array or slice of
// structs argument.
func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}
arrayValue := reflect.ValueOf(arg)
arrayLen := arrayValue.Len()
if arrayLen == 0 {
return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg)
}
var arglist []interface{}
for i := 0; i < arrayLen; i++ {
elemArglist, err := bindArgs(names, arrayValue.Index(i).Interface(), m)
if err != nil {
return "", []interface{}{}, err
}
arglist = append(arglist, elemArglist...)
}
if arrayLen > 1 {
bound = fixBound(bound, arrayLen)
}
return bound, arglist, nil
}

// bindMap binds a named parameter query with a map of arguments.
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
Expand Down Expand Up @@ -318,7 +364,12 @@ func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Ma
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMap(bindType, query, maparg)
}
return bindStruct(bindType, query, arg, m)
switch reflect.TypeOf(arg).Kind() {
case reflect.Array, reflect.Slice:
return bindArray(bindType, query, arg, m)
default:
return bindStruct(bindType, query, arg, m)
}
}

// NamedQuery binds a named query and then runs Query on the result using the
Expand Down