Skip to content
4 changes: 3 additions & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,8 +990,10 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
elif isinstance(target_type, TypeVarLikeType):
return self.get_sequence_type_from_type(target_type.upper_bound)
elif isinstance(target_type, TupleType):
items = target_type.items
assert items, "This function does not support empty tuples"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary but would have been marginally helpful

# Tuple might have elements of different types.
rtypes = {self.mapper.type_to_rtype(item) for item in target_type.items}
rtypes = set(map(self.mapper.type_to_rtype, items))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this looks more readable than the brackets when you're scanning quickly

if len(rtypes) == 1:
return rtypes.pop()
else:
Expand Down
42 changes: 31 additions & 11 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

from collections.abc import Callable
from typing import ClassVar
from typing import ClassVar, cast

from mypy.nodes import (
ARG_POS,
Expand Down Expand Up @@ -242,25 +242,45 @@ def sequence_from_generator_preallocate_helper(
rtype = builder.node_type(sequence_expr)
if not (is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple)):
return None
sequence = builder.accept(sequence_expr)
length = get_expr_length_value(builder, sequence_expr, sequence, line, use_pyssize_t=True)

if isinstance(rtype, RTuple):
# If input is RTuple, box it to tuple_rprimitive for generic iteration
# TODO: this can be optimized a bit better with an unrolled ForRTuple helper
proper_type = get_proper_type(builder.types[sequence_expr])
assert isinstance(proper_type, TupleType), proper_type

get_item_ops = [
(
LoadLiteral(typ.value, object_rprimitive)
if isinstance(typ, LiteralType)
else TupleGet(sequence, i, line)
)
for i, typ in enumerate(get_proper_types(proper_type.items))
]
# the for_loop_helper_with_index crashes for empty tuples, bail out
if not proper_type.items:
return None

proper_types = get_proper_types(proper_type.items)

get_item_ops: list[LoadLiteral | TupleGet]
if all(isinstance(typ, LiteralType) for typ in proper_types):
get_item_ops = [
LoadLiteral(cast(LiteralType, typ).value, object_rprimitive)
for typ in proper_types
]

else:
sequence = builder.accept(sequence_expr)
get_item_ops = [
(
LoadLiteral(typ.value, object_rprimitive)
if isinstance(typ, LiteralType)
else TupleGet(sequence, i, line)
)
for i, typ in enumerate(proper_types)
]

items = list(map(builder.add, get_item_ops))
sequence = builder.new_tuple(items, line)

else:
sequence = builder.accept(sequence_expr)

length = get_expr_length_value(builder, sequence_expr, sequence, line, use_pyssize_t=True)

target_op = empty_op_llbuilder(length, line)

def set_item(item_index: Value) -> None:
Expand Down
8 changes: 8 additions & 0 deletions mypyc/test-data/run-generators.test
Original file line number Diff line number Diff line change
Expand Up @@ -936,3 +936,11 @@ def test_generator_override() -> None:
assert base1_foo(Base1()) == [1]
assert base1_foo(Derived1()) == [2, 3]
assert derived1_foo(Derived1()) == [2, 3]

[case testGeneratorEmptyTuple]
from collections.abc import Generator
from typing import Optional, Union

def test_compiledGeneratorEmptyTuple() -> None:
jobs: Generator[Optional[str], None, None] = (_ for _ in ())
assert list(jobs) == []
9 changes: 7 additions & 2 deletions mypyc/test-data/run-loops.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Test cases for "range" objects, "for" and "while" loops (compile and run)

[case testFor]
from typing import List, Tuple
from typing import Any, List, Tuple
def count(n: int) -> None:
for i in range(n):
print(i)
Expand All @@ -21,6 +21,10 @@ def list_iter(l: List[int]) -> None:
def tuple_iter(l: Tuple[int, ...]) -> None:
for i in l:
print(i)
def empty_tuple_iter(l: Tuple[()]) -> None:
i: Any
for i in l:
print(i)
def str_iter(l: str) -> None:
for i in l:
print(i)
Expand All @@ -39,7 +43,7 @@ def count_down_short() -> None:
[file driver.py]
from native import (
count, list_iter, list_rev_iter, list_rev_iter_lol, count_between, count_down, count_double,
count_down_short, tuple_iter, str_iter,
count_down_short, tuple_iter, empty_tuple_iter, str_iter,
)
count(5)
list_iter(list(reversed(range(5))))
Expand All @@ -52,6 +56,7 @@ count_down_short()
print('==')
list_rev_iter_lol(list(reversed(range(5))))
tuple_iter((1, 2, 3))
empty_tuple_iter(())
str_iter("abc")
[out]
0
Expand Down