Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/larktools/ebnf_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,24 @@

assignment: VARNAME "=" arith_expr


logic_expr: logic_state | logic_operation | logic_comparison # removing logic_comparison here will make other logic tests succeed.

logic_state: BOOLEAN

logic_operation: logic_and | logic_or | logic_not
logic_and: logic_expr "and" logic_state
logic_or: logic_expr "or" logic_state
logic_not: "not" logic_expr

logic_comparison: logic_greater_than | logic_greater_equal | logic_equal | logic_smaller_equal | logic_smaller_than | logic_unequal
logic_greater_than: arith_expr ">" arith_expr
logic_greater_equal: arith_expr ">=" arith_expr
logic_equal: arith_expr "==" arith_expr
logic_smaller_equal: arith_expr "<=" arith_expr
logic_smaller_than: arith_expr "<" arith_expr
logic_unequal: arith_expr "!=" arith_expr


arith_expr: sum
sum: product | addition | subtraction
addition: sum "+" product
Expand All @@ -53,6 +70,7 @@
SIGNED_INT: ["+"|"-"] INT
DECIMAL: INT "." INT? | "." INT


_EXP: ("e"|"E") SIGNED_INT
FLOAT: INT _EXP | DECIMAL _EXP?
SIGNED_FLOAT: ["+"|"-"] FLOAT
Expand All @@ -63,6 +81,8 @@
LETTER: UCASE_LETTER | LCASE_LETTER
WORD: LETTER+

BOOLEAN: "True" | "False"

// Whitespace characters are filtered out before parsing.
// However, linebreaks are preserved.

Expand Down
29 changes: 22 additions & 7 deletions src/larktools/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,12 @@ def __call__(self, env):
class NumberNode:
def __init__(self, lark_node):
node_name = get_name(lark_node)
self._value = {
"SIGNED_FLOAT": float, "INT": int, "INDEX": int,
}[node_name](get_value(lark_node))
op_map = {
"SIGNED_FLOAT": float,
"INT": int,
"INDEX": int,
"BOOLEAN": lambda x: True if x == "True" else (False if x == "False" else None)}
self._value = op_map[node_name](get_value(lark_node))

def __call__(self, env):
return self._value
Expand All @@ -102,7 +105,8 @@ class UnaryOperatorNode(MappedOperatorNode):
def __init__(self, lark_node):
super().__init__(
lark_node,
op_map={"neg_atom": lambda x: -x[0]}
op_map={"neg_atom": lambda x: -x[0],
"logic_not": lambda x: not x[0]}
)


Expand All @@ -115,17 +119,28 @@ def __init__(self, lark_node):
"subtraction": lambda x: x[0] - x[1],
"multiplication": lambda x: x[0] * x[1],
"division": lambda x: x[0] / x[1],
"logic_and": lambda x: bool(x[0]) and bool(x[1]),
"logic_or": lambda x: x[0] or x[1],
"logic_greater_than": lambda x: x[0] > x[1],
"logic_greater_equal": lambda x: x[0] >= x[1],
"logic_equal": lambda x: x[0] == x[1],
"logic_smaller_equal": lambda x: x[0] <= x[1],
"logic_smaller_than": lambda x: x[0] < x[1],
"logic_unequal": lambda x: x[0] != x[1]
}
)


NODE_MAP = {
RootNode: ("multi_line_block",),
AssignNode: ("assignment",),
UnaryOperatorNode: ("neg_atom",),
BinaryOperatorNode: ("addition", "subtraction", "multiplication", "division"),
UnaryOperatorNode: ("neg_atom", "logic_not"),
BinaryOperatorNode: ("addition", "subtraction", "multiplication", "division",
"logic_and", "logic_or",
"logic_greater_than","logic_greater_equal","logic_equal",
"logic_smaller_equal","logic_smaller_than","logic_unequal"),
VariableNode: ("variable", "varname"),
NumberNode: ("INT", "SIGNED_INT", "FLOAT", "SIGNED_FLOAT", "INDEX"),
NumberNode: ("INT", "SIGNED_INT", "FLOAT", "SIGNED_FLOAT", "INDEX", "BOOLEAN"),
}

INV_NODE_MAP = {k: v for v in NODE_MAP for k in NODE_MAP[v]}
67 changes: 67 additions & 0 deletions tests/test_logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
from typing import Optional, Union

from lark import Lark

from larktools.ebnf_grammar import grammar
from larktools.evaluation import instantiate_eval_tree


class LogicParser:
def __init__(self):
self.parser = Lark(grammar, parser="lalr", start="logic_expr")

def parse_and_eval(self, expression: str, env: Optional[dict] = None) -> Union[int, float]:
tree = self.parser.parse(expression)
eval_tree = instantiate_eval_tree(tree)
res = eval_tree({} if env is None else env)
return res


def _parse_and_assert(expression: str, expected: Union[int, float], env: Optional[dict] = None) -> None:
parser = LogicParser()
res = parser.parse_and_eval(expression, env)
assert expected == res

def _parse_and_assert_collection(tests: list[str, Union[int, float]]) -> None:
for ipt, expected in tests:
_parse_and_assert(ipt, expected)


def test_comparison():
_parse_and_assert("3 > 5", False)
_parse_and_assert("3 >= 5", False)
_parse_and_assert("3 >= 3", True)
_parse_and_assert("3 >= 3", True)
_parse_and_assert("5 == 3 + 2", True)
_parse_and_assert("2 + 3 == 3 + 2", True)
_parse_and_assert("5 == 3", False)
_parse_and_assert("3 <= 5", True)
_parse_and_assert("3 <= 3", True)
_parse_and_assert("3 <= 2", False)
_parse_and_assert("3 < 5", True)
_parse_and_assert("3 != 5", True)
_parse_and_assert("5 != 5", False)


def test_logic_states():
_parse_and_assert("True", True)
_parse_and_assert("False", False)

def test_logic_operations():
_parse_and_assert("False or True", True)
_parse_and_assert("False or False", False)
_parse_and_assert("True or True", True)
_parse_and_assert("True or False", True)

_parse_and_assert("False and True", False)
_parse_and_assert("False and False", False)
_parse_and_assert("True and True", True)
_parse_and_assert("True and False", False)

def test_logic_negation():
_parse_and_assert("not True", False)
_parse_and_assert("not False", True)



Loading