diff --git a/builtin/builtin.go b/builtin/builtin.go index 70a8c09c7..4aad6aa9c 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/vm/runtime" ) @@ -440,7 +441,7 @@ var Builtins = []*Function{ sum := int64(0) i := 0 for ; i < v.Len(); i++ { - it := deref(v.Index(i)) + it := deref.Value(v.Index(i)) if it.CanInt() { sum += it.Int() } else if it.CanFloat() { @@ -453,7 +454,7 @@ var Builtins = []*Function{ float: fSum := float64(sum) for ; i < v.Len(); i++ { - it := deref(v.Index(i)) + it := deref.Value(v.Index(i)) if it.CanInt() { fSum += float64(it.Int()) } else if it.CanFloat() { @@ -492,7 +493,7 @@ var Builtins = []*Function{ sum := float64(0) i := 0 for ; i < v.Len(); i++ { - it := deref(v.Index(i)) + it := deref.Value(v.Index(i)) if it.CanInt() { sum += float64(it.Int()) } else if it.CanFloat() { @@ -530,7 +531,7 @@ var Builtins = []*Function{ } s := make([]float64, v.Len()) for i := 0; i < v.Len(); i++ { - it := deref(v.Index(i)) + it := deref.Value(v.Index(i)) if it.CanInt() { s[i] = float64(it.Int()) } else if it.CanFloat() { @@ -850,7 +851,7 @@ var Builtins = []*Function{ } out := reflect.MakeMap(mapType) for i := 0; i < v.Len(); i++ { - pair := deref(v.Index(i)) + pair := deref.Value(v.Index(i)) if pair.Kind() != reflect.Array && pair.Kind() != reflect.Slice { return nil, fmt.Errorf("invalid pair %v", pair) } @@ -908,6 +909,49 @@ var Builtins = []*Function{ } }, }, + { + Name: "concat", + Safe: func(args ...any) (any, uint, error) { + if len(args) == 0 { + return nil, 0, fmt.Errorf("invalid number of arguments (expected at least 1, got 0)") + } + + var size uint + var arr []any + + for _, arg := range args { + v := reflect.ValueOf(deref.Deref(arg)) + + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, 0, fmt.Errorf("cannot concat %s", v.Kind()) + } + + size += uint(v.Len()) + + for i := 0; i < v.Len(); i++ { + item := v.Index(i) + arr = append(arr, item.Interface()) + } + } + + return arr, size, nil + }, + Validate: func(args []reflect.Type) (reflect.Type, error) { + if len(args) == 0 { + return anyType, fmt.Errorf("invalid number of arguments (expected at least 1, got 0)") + } + + for _, arg := range args { + switch kind(deref.Type(arg)) { + case reflect.Interface, reflect.Slice, reflect.Array: + default: + return anyType, fmt.Errorf("cannot concat %s", arg) + } + } + + return arrayType, nil + }, + }, { Name: "sort", Safe: func(args ...any) (any, uint, error) { diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 87e2a7c9d..3a2850071 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -19,11 +19,13 @@ import ( ) func TestBuiltin(t *testing.T) { + ArrayWithNil := []any{42} env := map[string]any{ - "ArrayOfString": []string{"foo", "bar", "baz"}, - "ArrayOfInt": []int{1, 2, 3}, - "ArrayOfAny": []any{1, "2", true}, - "ArrayOfFoo": []mock.Foo{{Value: "a"}, {Value: "b"}, {Value: "c"}}, + "ArrayOfString": []string{"foo", "bar", "baz"}, + "ArrayOfInt": []int{1, 2, 3}, + "ArrayOfAny": []any{1, "2", true}, + "ArrayOfFoo": []mock.Foo{{Value: "a"}, {Value: "b"}, {Value: "c"}}, + "PtrArrayWithNil": &ArrayWithNil, } var tests = []struct { @@ -130,6 +132,8 @@ func TestBuiltin(t *testing.T) { {`reduce(1..9, # + #acc)`, 45}, {`reduce([.5, 1.5, 2.5], # + #acc, 0)`, 4.5}, {`reduce([], 5, 0)`, 0}, + {`concat(ArrayOfString, ArrayOfInt)`, []any{"foo", "bar", "baz", 1, 2, 3}}, + {`concat(PtrArrayWithNil, [nil])`, []any{42, nil}}, } for _, test := range tests { diff --git a/builtin/utils.go b/builtin/utils.go index 70b612b04..7d3b6ee8e 100644 --- a/builtin/utils.go +++ b/builtin/utils.go @@ -35,35 +35,6 @@ func types(types ...any) []reflect.Type { return ts } -func deref(v reflect.Value) reflect.Value { - if v.Kind() == reflect.Interface { - if v.IsNil() { - return v - } - v = v.Elem() - } - -loop: - for v.Kind() == reflect.Ptr { - if v.IsNil() { - return v - } - indirect := reflect.Indirect(v) - switch indirect.Kind() { - case reflect.Struct, reflect.Map, reflect.Array, reflect.Slice: - break loop - default: - v = v.Elem() - } - } - - if v.IsValid() { - return v - } - - panic(fmt.Sprintf("cannot deref %s", v)) -} - func toInt(val any) (int, error) { switch v := val.(type) { case int: diff --git a/checker/checker.go b/checker/checker.go index 11e4eee3c..c845dd78a 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -9,6 +9,7 @@ import ( "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/parser" ) @@ -203,8 +204,7 @@ func (v *checker) ConstantNode(node *ast.ConstantNode) (reflect.Type, info) { func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) { t, _ := v.visit(node.Node) - - t = deref(t) + t = deref.Type(t) switch node.Operator { @@ -235,8 +235,8 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { l, _ := v.visit(node.Left) r, ri := v.visit(node.Right) - l = deref(l) - r = deref(r) + l = deref.Type(l) + r = deref.Type(r) // check operator overloading if fns, ok := v.config.Operators[node.Operator]; ok { diff --git a/checker/types.go b/checker/types.go index 8c0805049..d10736a77 100644 --- a/checker/types.go +++ b/checker/types.go @@ -205,25 +205,6 @@ func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { return reflect.StructField{}, false } -func deref(t reflect.Type) reflect.Type { - if t == nil { - return nil - } - if t.Kind() == reflect.Interface { - return t - } - for t != nil && t.Kind() == reflect.Ptr { - e := t.Elem() - switch e.Kind() { - case reflect.Struct, reflect.Map, reflect.Array, reflect.Slice: - return t - default: - t = e - } - } - return t -} - func kind(t reflect.Type) reflect.Kind { if t == nil { return reflect.Invalid diff --git a/conf/types_table.go b/conf/types_table.go index 8ebb76c35..738eee840 100644 --- a/conf/types_table.go +++ b/conf/types_table.go @@ -2,6 +2,8 @@ package conf import ( "reflect" + + "github.com/expr-lang/expr/internal/deref" ) type Tag struct { @@ -77,7 +79,7 @@ func CreateTypesTable(i any) TypesTable { func FieldsFromStruct(t reflect.Type) TypesTable { types := make(TypesTable) - t = dereference(t) + t = deref.Type(t) if t == nil { return types } @@ -111,23 +113,6 @@ func FieldsFromStruct(t reflect.Type) TypesTable { return types } -func dereference(t reflect.Type) reflect.Type { - if t == nil { - return nil - } - if t.Kind() == reflect.Ptr { - t = dereference(t.Elem()) - } - return t -} - -func kind(t reflect.Type) reflect.Kind { - if t == nil { - return reflect.Invalid - } - return t.Kind() -} - func FieldName(field reflect.StructField) string { if taggedName := field.Tag.Get("expr"); taggedName != "" { return taggedName diff --git a/docgen/docgen.go b/docgen/docgen.go index e9a542b8c..aed0f48f0 100644 --- a/docgen/docgen.go +++ b/docgen/docgen.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/internal/deref" ) // Kind can be any of array, map, struct, func, string, int, float, bool or any. @@ -80,7 +81,7 @@ func CreateDoc(i any) *Context { c := &Context{ Variables: make(map[Identifier]*Type), Types: make(map[TypeName]*Type), - PkgPath: dereference(reflect.TypeOf(i)).PkgPath(), + PkgPath: deref.Type(reflect.TypeOf(i)).PkgPath(), } for name, t := range conf.CreateTypesTable(i) { @@ -134,7 +135,7 @@ func (c *Context) use(t reflect.Type, ops ...option) *Type { methods = append(methods, m) } - t = dereference(t) + t = deref.Type(t) // Only named types will have methods defined on them. // It maybe not even struct, but we gonna call then @@ -253,13 +254,3 @@ func isPrivate(s string) bool { func isProtobuf(s string) bool { return strings.HasPrefix(s, "XXX_") } - -func dereference(t reflect.Type) reflect.Type { - if t == nil { - return nil - } - if t.Kind() == reflect.Ptr { - t = dereference(t.Elem()) - } - return t -} diff --git a/expr_test.go b/expr_test.go index 7c20e4899..ed08cae53 100644 --- a/expr_test.go +++ b/expr_test.go @@ -316,9 +316,9 @@ func ExampleOperator_Decimal() { code := `A + B - C` type Env struct { - A, B, C Decimal - Sub func(a, b Decimal) Decimal - Add func(a, b Decimal) Decimal + A, B, C Decimal + Sub func(a, b Decimal) Decimal + Add func(a, b Decimal) Decimal } options := []expr.Option{ @@ -334,11 +334,11 @@ func ExampleOperator_Decimal() { } env := Env{ - A: Decimal{3}, - B: Decimal{2}, - C: Decimal{1}, + A: Decimal{3}, + B: Decimal{2}, + C: Decimal{1}, Sub: func(a, b Decimal) Decimal { return Decimal{a.N - b.N} }, - Add: func(a, b Decimal) Decimal { return Decimal{a.N + b.N} }, + Add: func(a, b Decimal) Decimal { return Decimal{a.N + b.N} }, } output, err := expr.Run(program, env) @@ -1358,21 +1358,6 @@ func TestExpr_fetch_from_func(t *testing.T) { assert.Contains(t, err.Error(), "cannot fetch Value from func()") } -func TestExpr_fetch_from_interface(t *testing.T) { - type FooBar struct { - Value string - } - foobar := &FooBar{"waldo"} - var foobarAny any = foobar - var foobarPtrAny any = &foobarAny - - res, err := expr.Eval("foo.Value", map[string]any{ - "foo": foobarPtrAny, - }) - assert.NoError(t, err) - assert.Equal(t, "waldo", res) -} - func TestExpr_map_default_values(t *testing.T) { env := map[string]any{ "foo": map[string]string{}, diff --git a/internal/deref/deref.go b/internal/deref/deref.go new file mode 100644 index 000000000..acdc89811 --- /dev/null +++ b/internal/deref/deref.go @@ -0,0 +1,47 @@ +package deref + +import ( + "fmt" + "reflect" +) + +func Deref(p any) any { + if p == nil { + return nil + } + + v := reflect.ValueOf(p) + + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + if v.IsNil() { + return nil + } + v = v.Elem() + } + + if v.IsValid() { + return v.Interface() + } + + panic(fmt.Sprintf("cannot dereference %v", p)) +} + +func Type(t reflect.Type) reflect.Type { + if t == nil { + return nil + } + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func Value(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + if v.IsNil() { + return v + } + v = v.Elem() + } + return v +} diff --git a/internal/deref/deref_test.go b/internal/deref/deref_test.go new file mode 100644 index 000000000..554455f6d --- /dev/null +++ b/internal/deref/deref_test.go @@ -0,0 +1,111 @@ +package deref_test + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/expr-lang/expr/internal/deref" +) + +func TestDeref(t *testing.T) { + a := uint(42) + b := &a + c := &b + d := &c + + got := deref.Deref(d) + assert.Equal(t, uint(42), got) +} + +func TestDeref_mix_ptr_with_interface(t *testing.T) { + a := uint(42) + var b any = &a + var c any = &b + d := &c + + got := deref.Deref(d) + assert.Equal(t, uint(42), got) +} + +func TestDeref_nil(t *testing.T) { + var a *int + assert.Nil(t, deref.Deref(a)) + assert.Nil(t, deref.Deref(nil)) +} + +func TestType(t *testing.T) { + a := uint(42) + b := &a + c := &b + d := &c + + dt := deref.Type(reflect.TypeOf(d)) + assert.Equal(t, reflect.Uint, dt.Kind()) +} + +func TestType_two_ptr_with_interface(t *testing.T) { + a := uint(42) + var b any = &a + + dt := deref.Type(reflect.TypeOf(b)) + assert.Equal(t, reflect.Uint, dt.Kind()) + +} + +func TestType_three_ptr_with_interface(t *testing.T) { + a := uint(42) + var b any = &a + var c any = &b + + dt := deref.Type(reflect.TypeOf(c)) + assert.Equal(t, reflect.Interface, dt.Kind()) +} + +func TestType_nil(t *testing.T) { + assert.Nil(t, deref.Type(nil)) +} + +func TestValue(t *testing.T) { + a := uint(42) + b := &a + c := &b + d := &c + + got := deref.Value(reflect.ValueOf(d)) + assert.Equal(t, uint(42), got.Interface()) +} + +func TestValue_two_ptr_with_interface(t *testing.T) { + a := uint(42) + var b any = &a + + got := deref.Value(reflect.ValueOf(b)) + assert.Equal(t, uint(42), got.Interface()) +} + +func TestValue_three_ptr_with_interface(t *testing.T) { + a := uint(42) + var b any = &a + c := &b + + got := deref.Value(reflect.ValueOf(c)) + assert.Equal(t, uint(42), got.Interface()) +} + +func TestValue_nil(t *testing.T) { + got := deref.Value(reflect.ValueOf(nil)) + assert.False(t, got.IsValid()) +} + +func TestValue_nil_in_chain(t *testing.T) { + var a any = nil + var b any = &a + c := &b + + got := deref.Value(reflect.ValueOf(c)) + assert.True(t, got.IsValid()) + assert.True(t, got.IsNil()) + assert.Nil(t, got.Interface()) +} diff --git a/test/deref/deref_test.go b/test/deref/deref_test.go index 684794a01..2911e8e26 100644 --- a/test/deref/deref_test.go +++ b/test/deref/deref_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/expr-lang/expr" @@ -237,3 +238,18 @@ func TestDeref_сommutative(t *testing.T) { }) } } + +func TestDeref_fetch_from_interface_mix_pointer(t *testing.T) { + type FooBar struct { + Value string + } + foobar := &FooBar{"waldo"} + var foobarAny any = foobar + var foobarPtrAny any = &foobarAny + + res, err := expr.Eval("foo.Value", map[string]any{ + "foo": foobarPtrAny, + }) + assert.NoError(t, err) + assert.Equal(t, "waldo", res) +} diff --git a/vm/runtime/runtime.go b/vm/runtime/runtime.go index 09c58bdf7..d54152117 100644 --- a/vm/runtime/runtime.go +++ b/vm/runtime/runtime.go @@ -6,20 +6,13 @@ import ( "fmt" "math" "reflect" -) -func deref(kind reflect.Kind, value reflect.Value) (reflect.Kind, reflect.Value) { - for kind == reflect.Ptr || kind == reflect.Interface { - value = value.Elem() - kind = value.Kind() - } - return kind, value -} + "github.com/expr-lang/expr/internal/deref" +) func Fetch(from, i any) any { v := reflect.ValueOf(from) - kind := v.Kind() - if kind == reflect.Invalid { + if v.Kind() == reflect.Invalid { panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) } @@ -37,11 +30,9 @@ func Fetch(from, i any) any { // a value, when they are accessed through a pointer we don't want to // copy them to a value. // De-reference everything if necessary (interface and pointers) - kind, v = deref(kind, v) + v = deref.Value(v) - // TODO: We can create separate opcodes for each of the cases below to make - // the little bit faster. - switch kind { + switch v.Kind() { case reflect.Array, reflect.Slice, reflect.String: index := ToInt(i) if index < 0 { @@ -144,27 +135,6 @@ func FetchMethod(from any, method *Method) any { panic(fmt.Sprintf("cannot fetch %v from %T", method.Name, from)) } -func Deref(i any) any { - if i == nil { - return nil - } - - v := reflect.ValueOf(i) - - for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { - if v.IsNil() { - return nil - } - v = v.Elem() - } - - if v.IsValid() { - return v.Interface() - } - - panic(fmt.Sprintf("cannot dereference %v", i)) -} - func Slice(array, from, to any) any { v := reflect.ValueOf(array) diff --git a/vm/vm.go b/vm/vm.go index 56d5fc2ee..1e85893b0 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -11,6 +11,7 @@ import ( "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/vm/runtime" ) @@ -436,7 +437,7 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case OpDeref: a := vm.pop() - vm.push(runtime.Deref(a)) + vm.push(deref.Deref(a)) case OpIncrementIndex: vm.scope().Index++