Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,14 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
c.emit(OpNot)

case "or", "||":
if c.config != nil && c.config.DisableSC {
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpOr)
break
}
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfTrue, placeholder)
Expand All @@ -455,6 +463,14 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
c.patchJump(end)

case "and", "&&":
if c.config != nil && c.config.DisableSC {
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpAnd)
break
}
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfFalse, placeholder)
Expand Down
2 changes: 2 additions & 0 deletions conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Config struct {
Builtins FunctionsTable
Disabled map[string]bool // disabled builtins
NtCache nature.Cache
DisableSC bool
}

// CreateNew creates new config with default values.
Expand All @@ -46,6 +47,7 @@ func CreateNew() *Config {
Functions: make(map[string]*builtin.Function),
Builtins: make(map[string]*builtin.Function),
Disabled: make(map[string]bool),
DisableSC: false,
}
for _, f := range builtin.Builtins {
c.Builtins[f.Name] = f
Expand Down
7 changes: 7 additions & 0 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ func Optimize(b bool) Option {
}
}

// DisableShortCircuit turns short circuit off.
func DisableShortCircuit() Option {
return func(c *conf.Config) {
c.DisableSC = true
}
}

// Patch adds visitor to list of visitors what will be applied before compiling AST to bytecode.
func Patch(visitor ast.Visitor) Option {
return func(c *conf.Config) {
Expand Down
52 changes: 52 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2883,3 +2883,55 @@ func TestIssue807(t *testing.T) {
t.Fatalf("expected 'in' operator to return false for unexported field")
}
}

func ExampleDisableShortCircuit() {
OR := func(a, b bool) bool {
return a || b
}

env := map[string]any{
"foo": func() bool {
fmt.Println("foo")
return false
},
"bar": func() bool {
fmt.Println("bar")
return false
},
"OR": OR,
}

program, _ := expr.Compile("true || foo() or bar()", expr.Env(env), expr.Operator("or", "OR"), expr.Operator("||", "OR"))
got, _ := expr.Run(program, env)
fmt.Println(got)

// Output:
// foo
// bar
// true
}

func TestDisableShortCircuit(t *testing.T) {
count := 0
exprStr := "foo() or bar()"
env := map[string]any{
"foo": func() bool {
count++
return true
},
"bar": func() bool {
count++
return true
},
}

program, _ := expr.Compile(exprStr, expr.DisableShortCircuit())
got, _ := expr.Run(program, env)
assert.Equal(t, 2, count)
assert.True(t, got.(bool))

program, _ = expr.Compile(exprStr)
got, _ = expr.Run(program, env)
assert.Equal(t, 3, count)
assert.True(t, got.(bool))
}
2 changes: 2 additions & 0 deletions vm/opcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,7 @@ const (
OpProfileStart
OpProfileEnd
OpBegin
OpAnd
OpOr
OpEnd // This opcode must be at the end of this list.
)
6 changes: 6 additions & 0 deletions vm/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,12 @@ func (program *Program) DisassembleWriter(w io.Writer) {
case OpBegin:
code("OpBegin")

case OpAnd:
code("OpAnd")

case OpOr:
code("OpOr")

case OpEnd:
code("OpEnd")

Expand Down
10 changes: 10 additions & 0 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,16 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
Len: array.Len(),
})

case OpAnd:
a := vm.pop()
b := vm.pop()
vm.push(a.(bool) && b.(bool))

case OpOr:
a := vm.pop()
b := vm.pop()
vm.push(a.(bool) || b.(bool))

case OpEnd:
vm.Scopes = vm.Scopes[:len(vm.Scopes)-1]

Expand Down
Loading