diff --git a/named.go b/named.go index dd899d35..a886558f 100644 --- a/named.go +++ b/named.go @@ -12,10 +12,12 @@ package sqlx // * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist // import ( + "bytes" "database/sql" "errors" "fmt" "reflect" + "regexp" "strconv" "unicode" @@ -206,6 +208,50 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) return bound, arglist, nil } +var ( + EndBracketsReg = regexp.MustCompile(`\([^()]*\)\s*$`) +) + +func fixBound(bound string, loop int) string { + endBrackets := EndBracketsReg.FindString(bound) + if endBrackets == "" { + return bound + } + var buffer bytes.Buffer + buffer.WriteString(bound) + for i := 0; i < loop-1; i++ { + buffer.WriteString(",") + buffer.WriteString(endBrackets) + } + return buffer.String() +} + +// bindArray binds a named parameter query with fields from an array or slice of +// structs argument. +func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { + bound, names, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return "", []interface{}{}, err + } + arrayValue := reflect.ValueOf(arg) + arrayLen := arrayValue.Len() + if arrayLen == 0 { + return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) + } + var arglist []interface{} + for i := 0; i < arrayLen; i++ { + elemArglist, err := bindArgs(names, arrayValue.Index(i).Interface(), m) + if err != nil { + return "", []interface{}{}, err + } + arglist = append(arglist, elemArglist...) + } + if arrayLen > 1 { + bound = fixBound(bound, arrayLen) + } + return bound, arglist, nil +} + // bindMap binds a named parameter query with a map of arguments. func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { bound, names, err := compileNamedQuery([]byte(query), bindType) @@ -318,7 +364,12 @@ func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Ma if maparg, ok := arg.(map[string]interface{}); ok { return bindMap(bindType, query, maparg) } - return bindStruct(bindType, query, arg, m) + switch reflect.TypeOf(arg).Kind() { + case reflect.Array, reflect.Slice: + return bindArray(bindType, query, arg, m) + default: + return bindStruct(bindType, query, arg, m) + } } // NamedQuery binds a named query and then runs Query on the result using the