Skip to content
20 changes: 15 additions & 5 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DictionaryComprehension,
Expression,
GeneratorExpr,
ListComprehension,
ListExpr,
Lvalue,
MemberExpr,
Expand Down Expand Up @@ -1220,13 +1221,22 @@ def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
and isinstance(expr.node, Var)
and expr.node.is_final
and isinstance(expr.node.final_value, str)
and expr.node.has_explicit_value
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized when looking that this line is unnecessary if the above checks pass

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#19930 renders this irrelevant

):
return len(expr.node.final_value)
# TODO: extend this, passing length of listcomp and genexp should have worthwhile
# performance boost and can be (sometimes) figured out pretty easily. set and dict
# comps *can* be done as well but will need special logic to consider the possibility
# of key conflicts. Range, enumerate, zip are all simple logic.
elif isinstance(expr, ListComprehension):
return get_expr_length(builder, expr.generator)
elif isinstance(expr, GeneratorExpr) and not expr.condlists:
sequence_lengths = [get_expr_length(builder, seq) for seq in expr.sequences]
if None not in sequence_lengths:
if len(sequence_lengths) == 1:
return sequence_lengths[0]
product = sequence_lengths[0]
for l in sequence_lengths[1:]:
product *= l # type: ignore [operator]
return product
# TODO: extend this, set and dict comps can be done as well but will
# need special logic to consider the possibility of key conflicts.
# Range, enumerate, zip are all simple logic.

# we might still be able to get the length directly from the type
rtype = builder.node_type(expr)
Expand Down
116 changes: 116 additions & 0 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,122 @@ L4:
a = r6
return 1

[case testTupleBuiltFromListComprehension]
def f(val: int) -> bool:
return val % 2 == 0

def test() -> None:
a = tuple(f(x) for x in [a * b for a in [1, 2, 3] for b in [1, 2, 3]])
[out]
def f(val):
val, r0 :: int
r1 :: bit
L0:
r0 = CPyTagged_Remainder(val, 4)
r1 = int_eq r0, 0
return r1
def test():
r0, r1 :: list
r2, r3, r4 :: object
r5 :: ptr
r6, r7 :: native_int
r8 :: bit
r9 :: object
r10, a :: int
r11 :: list
r12, r13, r14 :: object
r15 :: ptr
r16, r17 :: native_int
r18 :: bit
r19 :: object
r20, b, r21 :: int
r22 :: object
r23 :: i32
r24 :: bit
r25, r26, r27 :: native_int
r28 :: tuple
r29, r30 :: native_int
r31 :: bit
r32 :: object
r33, x :: int
r34 :: bool
r35 :: object
r36 :: native_int
a_2 :: tuple
L0:
r0 = PyList_New(0)
r1 = PyList_New(3)
r2 = object 1
r3 = object 2
r4 = object 3
r5 = list_items r1
buf_init_item r5, 0, r2
buf_init_item r5, 1, r3
buf_init_item r5, 2, r4
keep_alive r1
r6 = 0
L1:
r7 = var_object_size r1
r8 = r6 < r7 :: signed
if r8 goto L2 else goto L8 :: bool
L2:
r9 = list_get_item_unsafe r1, r6
r10 = unbox(int, r9)
a = r10
r11 = PyList_New(3)
r12 = object 1
r13 = object 2
r14 = object 3
r15 = list_items r11
buf_init_item r15, 0, r12
buf_init_item r15, 1, r13
buf_init_item r15, 2, r14
keep_alive r11
r16 = 0
L3:
r17 = var_object_size r11
r18 = r16 < r17 :: signed
if r18 goto L4 else goto L6 :: bool
L4:
r19 = list_get_item_unsafe r11, r16
r20 = unbox(int, r19)
b = r20
r21 = CPyTagged_Multiply(a, b)
r22 = box(int, r21)
r23 = PyList_Append(r0, r22)
r24 = r23 >= 0 :: signed
L5:
r25 = r16 + 1
r16 = r25
goto L3
L6:
L7:
r26 = r6 + 1
r6 = r26
goto L1
L8:
r27 = var_object_size r0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look right, should be 9

r28 = PyTuple_New(r27)
r29 = 0
L9:
r30 = var_object_size r0
r31 = r29 < r30 :: signed
if r31 goto L10 else goto L12 :: bool
L10:
r32 = list_get_item_unsafe r0, r29
r33 = unbox(int, r32)
x = r33
r34 = f(x)
r35 = box(bool, r34)
CPySequenceTuple_SetItemUnsafe(r28, r29, r35)
L11:
r36 = r29 + 1
r29 = r36
goto L9
L12:
a_2 = r28
return 1

[case testTupleBuiltFromStr]
def f2(val: str) -> str:
return val + "f2"
Expand Down