Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2482,3 +2482,29 @@ func TestRaceCondition_variables(t *testing.T) {

wg.Wait()
}

func TestArrayComparison(t *testing.T) {
tests := []struct {
env any
code string
}{
{[]string{"A", "B"}, "foo == ['A', 'B']"},
{[]int{1, 2}, "foo == [1, 2]"},
{[]uint8{1, 2}, "foo == [1, 2]"},
{[]float64{1.1, 2.2}, "foo == [1.1, 2.2]"},
{[]any{"A", 1, 1.1, true}, "foo == ['A', 1, 1.1, true]"},
{[]string{"A", "B"}, "foo != [1, 2]"},
}

for _, tt := range tests {
t.Run(tt.code, func(t *testing.T) {
env := map[string]any{"foo": tt.env}
program, err := expr.Compile(tt.code, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, true, out)
})
}
}
46 changes: 46 additions & 0 deletions vm/runtime/helpers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func main() {
"cases_with_duration": func(op string) string {
return cases(op, uints, ints, floats, []string{"time.Duration"})
},
"array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) },
}).
Parse(helpers),
).Execute(&b, nil)
Expand Down Expand Up @@ -89,6 +90,45 @@ func cases(op string, xs ...[]string) string {
return strings.TrimRight(out, "\n")
}

func arrayEqualCases(xs ...[]string) string {
var types []string
for _, x := range xs {
types = append(types, x...)
}

_, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types)

var out string
echo := func(s string, xs ...any) {
out += fmt.Sprintf(s, xs...) + "\n"
}
echo(`case []any:`)
echo(`switch y := b.(type) {`)
for _, a := range append(types, "any") {
echo(`case []%v:`, a)
echo(`if len(x) != len(y) { return false }`)
echo(`for i := range x {`)
echo(`if !Equal(x[i], y[i]) { return false }`)
echo(`}`)
echo("return true")
}
echo(`}`)
for _, a := range types {
echo(`case []%v:`, a)
echo(`switch y := b.(type) {`)
echo(`case []any:`)
echo(`return Equal(y, x)`)
echo(`case []%v:`, a)
echo(`if len(x) != len(y) { return false }`)
echo(`for i := range x {`)
echo(`if x[i] != y[i] { return false }`)
echo(`}`)
echo("return true")
echo(`}`)
}
return strings.TrimRight(out, "\n")
}

func isFloat(t string) bool {
return strings.HasPrefix(t, "float")
}
Expand All @@ -110,6 +150,7 @@ import (
func Equal(a, b interface{}) bool {
switch x := a.(type) {
{{ cases "==" }}
{{ array_equal_cases }}
case string:
switch y := b.(type) {
case string:
Expand All @@ -125,6 +166,11 @@ func Equal(a, b interface{}) bool {
case time.Duration:
return x == y
}
case bool:
switch y := b.(type) {
case bool:
return x == y
}
}
if IsNil(a) && IsNil(b) {
return true
Expand Down
Loading