diff --git a/bench_test.go b/bench_test.go index 601be77a2..ebeed5d18 100644 --- a/bench_test.go +++ b/bench_test.go @@ -486,7 +486,7 @@ func Benchmark_sortBy(b *testing.B) { env["arr"].([]Foo)[i] = Foo{Value: v.(int)} } - program, err := expr.Compile(`sortBy(arr, "Value")`, expr.Env(env)) + program, err := expr.Compile(`sortBy(arr, .Value)`, expr.Env(env)) require.NoError(b, err) var out any diff --git a/builtin/builtin.go b/builtin/builtin.go index f7b7bdeb1..70a8c09c7 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -87,6 +87,11 @@ var Builtins = []*Function{ Predicate: true, Types: types(new(func([]any, func(any) any) map[any][]any)), }, + { + Name: "sortBy", + Predicate: true, + Types: types(new(func([]any, func(any) bool, string) []any)), + }, { Name: "reduce", Predicate: true, @@ -905,109 +910,65 @@ var Builtins = []*Function{ }, { Name: "sort", - Func: func(args ...any) (any, error) { + Safe: func(args ...any) (any, uint, error) { if len(args) != 1 && len(args) != 2 { - return nil, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args)) + return nil, 0, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args)) } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sort %s", v.Kind()) - } + var array []any - orderBy := OrderBy{} - if len(args) == 2 { - dir, err := ascOrDesc(args[1]) - if err != nil { - return nil, err + switch in := args[0].(type) { + case []any: + array = make([]any, len(in)) + copy(array, in) + case []int: + array = make([]any, len(in)) + for i, v := range in { + array[i] = v + } + case []float64: + array = make([]any, len(in)) + for i, v := range in { + array[i] = v + } + case []string: + array = make([]any, len(in)) + for i, v := range in { + array[i] = v } - orderBy.Desc = dir } - sortable, err := copyArray(v, orderBy) - if err != nil { - return nil, err - } - sort.Sort(sortable) - return sortable.Array, nil - }, - Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 && len(args) != 2 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sort %s", args[0]) - } + var desc bool if len(args) == 2 { - switch kind(args[1]) { - case reflect.String, reflect.Interface: + switch args[1].(string) { + case "asc": + desc = false + case "desc": + desc = true default: - return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1]) + return nil, 0, fmt.Errorf("invalid order %s, expected asc or desc", args[1]) } } - return arrayType, nil - }, - }, - { - Name: "sortBy", - Func: func(args ...any) (any, error) { - if len(args) != 2 && len(args) != 3 { - return nil, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sort %s", v.Kind()) - } - - orderBy := OrderBy{} - - field, ok := args[1].(string) - if !ok { - return nil, fmt.Errorf("invalid argument for sort (expected string, got %s)", reflect.TypeOf(args[1])) - } - orderBy.Field = field - - if len(args) == 3 { - dir, err := ascOrDesc(args[2]) - if err != nil { - return nil, err - } - orderBy.Desc = dir - } - - sortable, err := copyArray(v, orderBy) - if err != nil { - return nil, err + sortable := &runtime.Sort{ + Desc: desc, + Array: array, } sort.Sort(sortable) - return sortable.Array, nil - }, - Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 2 && len(args) != 3 { - return anyType, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sort %s", args[0]) - } - switch kind(args[1]) { - case reflect.String, reflect.Interface: - default: - return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1]) - } - if len(args) == 3 { - switch kind(args[2]) { - case reflect.String, reflect.Interface: - default: - return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1]) - } - } - return arrayType, nil + + return sortable.Array, uint(len(array)), nil }, + Types: types( + new(func([]any, string) []any), + new(func([]int, string) []any), + new(func([]float64, string) []any), + new(func([]string, string) []any), + + new(func([]any) []any), + new(func([]float64) []any), + new(func([]string) []any), + new(func([]int) []any), + ), }, bitFunc("bitand", func(x, y int) (any, error) { return x & y, nil diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index d6b967d13..87e2a7c9d 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -530,8 +530,8 @@ func TestBuiltin_sort(t *testing.T) { {`sort(ArrayOfInt)`, []any{1, 2, 3}}, {`sort(ArrayOfFloat)`, []any{1.0, 2.0, 3.0}}, {`sort(ArrayOfInt, 'desc')`, []any{3, 2, 1}}, - {`sortBy(ArrayOfFoo, 'Value')`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}}, - {`sortBy([{id: "a"}, {id: "b"}], "id", "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}}, + {`sortBy(ArrayOfFoo, .Value)`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}}, + {`sortBy([{id: "a"}, {id: "b"}], .id, "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}}, } for _, test := range tests { @@ -546,6 +546,20 @@ func TestBuiltin_sort(t *testing.T) { } } +func TestBuiltin_sort_i64(t *testing.T) { + env := map[string]any{ + "array": []int{1, 2, 3}, + "i64": int64(1), + } + + program, err := expr.Compile(`sort(map(array, i64))`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, []any{int64(1), int64(1), int64(1)}, out) +} + func TestBuiltin_bitOpsFunc(t *testing.T) { tests := []struct { input string diff --git a/builtin/sort.go b/builtin/sort.go deleted file mode 100644 index 9b9ddc165..000000000 --- a/builtin/sort.go +++ /dev/null @@ -1,96 +0,0 @@ -package builtin - -import ( - "fmt" - "reflect" -) - -type Sortable struct { - Array []any - Values []reflect.Value - OrderBy -} - -type OrderBy struct { - Field string - Desc bool -} - -func (s *Sortable) Len() int { - return len(s.Array) -} - -func (s *Sortable) Swap(i, j int) { - s.Array[i], s.Array[j] = s.Array[j], s.Array[i] - s.Values[i], s.Values[j] = s.Values[j], s.Values[i] -} - -func (s *Sortable) Less(i, j int) bool { - a, b := s.Values[i], s.Values[j] - switch a.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if s.Desc { - return a.Int() > b.Int() - } - return a.Int() < b.Int() - case reflect.Float64, reflect.Float32: - if s.Desc { - return a.Float() > b.Float() - } - return a.Float() < b.Float() - case reflect.String: - if s.Desc { - return a.String() > b.String() - } - return a.String() < b.String() - default: - panic(fmt.Sprintf("sort: unsupported type %s", a.Kind())) - } -} - -func copyArray(v reflect.Value, orderBy OrderBy) (*Sortable, error) { - s := &Sortable{ - Array: make([]any, v.Len()), - Values: make([]reflect.Value, v.Len()), - OrderBy: orderBy, - } - var prev reflect.Value - for i := 0; i < s.Len(); i++ { - elem := deref(v.Index(i)) - var value reflect.Value - switch elem.Kind() { - case reflect.Struct: - value = elem.FieldByName(s.Field) - case reflect.Map: - value = elem.MapIndex(reflect.ValueOf(s.Field)) - default: - value = elem - } - value = deref(value) - - s.Array[i] = elem.Interface() - s.Values[i] = value - - if i == 0 { - prev = value - } else if value.Type() != prev.Type() { - return nil, fmt.Errorf("cannot sort array of different types (%s and %s)", value.Type(), prev.Type()) - } - } - return s, nil -} - -func ascOrDesc(arg any) (bool, error) { - dir, ok := arg.(string) - if !ok { - return false, fmt.Errorf("invalid argument for sort (expected string, got %s)", reflect.TypeOf(arg)) - } - switch dir { - case "desc": - return true, nil - case "asc": - return false, nil - default: - return false, fmt.Errorf(`invalid argument for sort (expected "asc" or "desc", got %q)`, dir) - } -} diff --git a/checker/checker.go b/checker/checker.go index 3dc4e95ad..11e4eee3c 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -633,7 +633,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { if isAny(collection) { return arrayType, info{} } - return reflect.SliceOf(collection.Elem()), info{} + return arrayType, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -651,7 +651,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { closure.NumOut() == 1 && closure.NumIn() == 1 && isAny(closure.In(0)) { - return reflect.SliceOf(closure.Out(0)), info{} + return arrayType, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -739,6 +739,28 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { } return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "sortBy": + collection, _ := v.visit(node.Arguments[0]) + if !isArray(collection) && !isAny(collection) { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + } + + v.begin(collection) + closure, _ := v.visit(node.Arguments[1]) + v.end() + + if len(node.Arguments) == 3 { + _, _ = v.visit(node.Arguments[2]) + } + + if isFunc(closure) && + closure.NumOut() == 1 && + closure.NumIn() == 1 && isAny(closure.In(0)) { + + return reflect.TypeOf([]any{}), info{} + } + return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "reduce": collection, _ := v.visit(node.Arguments[0]) if !isArray(collection) && !isAny(collection) { diff --git a/checker/checker_test.go b/checker/checker_test.go index 2bf5ec864..bab9a0a67 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -400,11 +400,6 @@ invalid operation: < (mismatched types mock.Bar and int) (1:29) | all(ArrayOfFoo, {#.Method() < 0}) | ............................^ -map(Any, {0})[0] + "str" -invalid operation: + (mismatched types int and string) (1:18) - | map(Any, {0})[0] + "str" - | .................^ - Variadic() not enough arguments to call Variadic (1:1) | Variadic() @@ -445,11 +440,6 @@ builtin map takes only array (got int) (1:5) | map(1, {2}) | ....^ -map(filter(ArrayOfFoo, {true}), {.Not}) -type mock.Foo has no field Not (1:35) - | map(filter(ArrayOfFoo, {true}), {.Not}) - | ..................................^ - ArrayOfFoo[Foo] array elements can only be selected using an integer (got mock.Foo) (1:12) | ArrayOfFoo[Foo] diff --git a/compiler/compiler.go b/compiler/compiler.go index 252699859..a4f189e6b 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -873,11 +873,31 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "groupBy": c.compile(node.Arguments[0]) c.emit(OpBegin) + c.emit(OpCreate, 1) + c.emit(OpSetAcc) c.emitLoop(func() { c.compile(node.Arguments[1]) c.emit(OpGroupBy) }) - c.emit(OpGetGroupBy) + c.emit(OpGetAcc) + c.emit(OpEnd) + return + + case "sortBy": + c.compile(node.Arguments[0]) + c.emit(OpBegin) + if len(node.Arguments) == 3 { + c.compile(node.Arguments[2]) + } else { + c.emit(OpPush, c.addConstant("asc")) + } + c.emit(OpCreate, 2) + c.emit(OpSetAcc) + c.emitLoop(func() { + c.compile(node.Arguments[1]) + c.emit(OpSortBy) + }) + c.emit(OpSort) c.emit(OpEnd) return diff --git a/debug/debugger.go b/debug/debugger.go index 6f341e30e..5676fc1b9 100644 --- a/debug/debugger.go +++ b/debug/debugger.go @@ -134,9 +134,6 @@ func StartDebugger(program *Program, env any) { keys = append(keys, pair{"Index", s.Index}) keys = append(keys, pair{"Len", s.Len}) keys = append(keys, pair{"Count", s.Count}) - if s.GroupBy != nil { - keys = append(keys, pair{"GroupBy", s.GroupBy}) - } if s.Acc != nil { keys = append(keys, pair{"Acc", s.Acc}) } diff --git a/parser/parser.go b/parser/parser.go index bc620ac68..1eabdebe2 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -39,6 +39,7 @@ var predicates = map[string]struct { "findLast": {[]arg{expr, closure}}, "findLastIndex": {[]arg{expr, closure}}, "groupBy": {[]arg{expr, closure}}, + "sortBy": {[]arg{expr, closure, expr | optional}}, "reduce": {[]arg{expr, closure, expr | optional}}, } diff --git a/vm/opcodes.go b/vm/opcodes.go index dc3c97d0c..0417dab61 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -70,15 +70,17 @@ const ( OpDecrementIndex OpIncrementCount OpGetIndex - OpSetIndex OpGetCount OpGetLen - OpGetGroupBy OpGetAcc + OpSetAcc + OpSetIndex OpPointer OpThrow + OpCreate OpGroupBy - OpSetAcc + OpSortBy + OpSort OpBegin OpEnd // This opcode must be at the end of this list. ) diff --git a/vm/program.go b/vm/program.go index fa9fefd8f..4a878267b 100644 --- a/vm/program.go +++ b/vm/program.go @@ -327,32 +327,38 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpGetIndex: code("OpGetIndex") - case OpSetIndex: - code("OpSetIndex") - case OpGetCount: code("OpGetCount") case OpGetLen: code("OpGetLen") - case OpGetGroupBy: - code("OpGetGroupBy") - case OpGetAcc: code("OpGetAcc") + case OpSetAcc: + code("OpSetAcc") + + case OpSetIndex: + code("OpSetIndex") + case OpPointer: code("OpPointer") case OpThrow: code("OpThrow") + case OpCreate: + argument("OpCreate") + case OpGroupBy: code("OpGroupBy") - case OpSetAcc: - code("OpSetAcc") + case OpSortBy: + code("OpSortBy") + + case OpSort: + code("OpSort") case OpBegin: code("OpBegin") diff --git a/vm/runtime/sort.go b/vm/runtime/sort.go new file mode 100644 index 000000000..fb1f340d7 --- /dev/null +++ b/vm/runtime/sort.go @@ -0,0 +1,45 @@ +package runtime + +type SortBy struct { + Desc bool + Array []any + Values []any +} + +func (s *SortBy) Len() int { + return len(s.Array) +} + +func (s *SortBy) Swap(i, j int) { + s.Array[i], s.Array[j] = s.Array[j], s.Array[i] + s.Values[i], s.Values[j] = s.Values[j], s.Values[i] +} + +func (s *SortBy) Less(i, j int) bool { + a, b := s.Values[i], s.Values[j] + if s.Desc { + return Less(b, a) + } + return Less(a, b) +} + +type Sort struct { + Desc bool + Array []any +} + +func (s *Sort) Len() int { + return len(s.Array) +} + +func (s *Sort) Swap(i, j int) { + s.Array[i], s.Array[j] = s.Array[j], s.Array[i] +} + +func (s *Sort) Less(i, j int) bool { + a, b := s.Array[i], s.Array[j] + if s.Desc { + return Less(b, a) + } + return Less(a, b) +} diff --git a/vm/utils.go b/vm/utils.go index dc7f56944..d7db2a52a 100644 --- a/vm/utils.go +++ b/vm/utils.go @@ -9,7 +9,19 @@ type ( SafeFunction = func(params ...any) (any, uint, error) ) -// MemoryBudget represents an upper limit of memory usage. -var MemoryBudget uint = 1e6 +var ( + // MemoryBudget represents an upper limit of memory usage. + MemoryBudget uint = 1e6 -var errorType = reflect.TypeOf((*error)(nil)).Elem() + errorType = reflect.TypeOf((*error)(nil)).Elem() +) + +type Scope struct { + Array reflect.Value + Index int + Len int + Count int + Acc any +} + +type groupBy = map[any][]any diff --git a/vm/vm.go b/vm/vm.go index f3dc4ab15..56d5fc2ee 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "regexp" + "sort" "strings" "github.com/expr-lang/expr/builtin" @@ -43,15 +44,6 @@ type VM struct { curr chan int } -type Scope struct { - Array reflect.Value - Index int - Len int - Count int - GroupBy map[any][]any - Acc any -} - func (vm *VM) Run(program *Program, env any) (_ any, err error) { defer func() { if r := recover(); r != nil { @@ -460,10 +452,6 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case OpGetIndex: vm.push(vm.scope().Index) - case OpSetIndex: - scope := vm.scope() - scope.Index = vm.pop().(int) - case OpGetCount: scope := vm.scope() vm.push(scope.Count) @@ -472,15 +460,16 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { scope := vm.scope() vm.push(scope.Len) - case OpGetGroupBy: - vm.push(vm.scope().GroupBy) - case OpGetAcc: vm.push(vm.scope().Acc) case OpSetAcc: vm.scope().Acc = vm.pop() + case OpSetIndex: + scope := vm.scope() + scope.Index = vm.pop().(int) + case OpPointer: scope := vm.scope() vm.push(scope.Array.Index(scope.Index).Interface()) @@ -488,14 +477,50 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case OpThrow: panic(vm.pop().(error)) + case OpCreate: + switch arg { + case 1: + vm.push(make(groupBy)) + case 2: + scope := vm.scope() + var desc bool + switch vm.pop().(string) { + case "asc": + desc = false + case "desc": + desc = true + default: + panic("unknown order, use asc or desc") + } + vm.push(&runtime.SortBy{ + Desc: desc, + Array: make([]any, 0, scope.Len), + Values: make([]any, 0, scope.Len), + }) + default: + panic(fmt.Sprintf("unknown OpCreate argument %v", arg)) + } + case OpGroupBy: scope := vm.scope() - if scope.GroupBy == nil { - scope.GroupBy = make(map[any][]any) - } - it := scope.Array.Index(scope.Index).Interface() key := vm.pop() - scope.GroupBy[key] = append(scope.GroupBy[key], it) + item := scope.Array.Index(scope.Index).Interface() + scope.Acc.(groupBy)[key] = append(scope.Acc.(groupBy)[key], item) + + case OpSortBy: + scope := vm.scope() + value := vm.pop() + item := scope.Array.Index(scope.Index).Interface() + sortable := scope.Acc.(*runtime.SortBy) + sortable.Array = append(sortable.Array, item) + sortable.Values = append(sortable.Values, value) + + case OpSort: + scope := vm.scope() + sortable := scope.Acc.(*runtime.SortBy) + sort.Sort(sortable) + vm.memGrow(uint(scope.Len)) + vm.push(sortable.Array) case OpBegin: a := vm.pop()