Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug where, when ignore_error is True, the corresponding assign is not skipped, causing out-of-order assigns. #586

Merged
merged 1 commit into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion ml_metrics/_src/chainables/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_sharded_sequence_data_source_resume(self):
self.assertEqual([(1 + 2) / 2], it.agg_result)
self.assertEqual([(1 + 2) / 2], it_new.agg_result)

def test_sequence_data_source_ignore_error(self):
def test_sequence_data_source_apply_ignore_error(self):
def foo(x):
if x == 2:
raise ValueError('foo')
Expand All @@ -220,6 +220,25 @@ def foo(x):
self.assertNotEmpty(it._thread_pool._threads)
self.assertTrue(all(not t.is_alive() for t in it._thread_pool._threads))

def test_sequence_data_source_assign_ignore_error(self):
  """Rows whose `assign` fn raises are dropped when ignore_error=True.

  The pipeline copies each input to key 'a' and then assigns 'b' from 'a'
  with a function that raises on the value 2. Iterating over range(4) with
  errors ignored must produce only the rows for 0, 1 and 3, each with 'a'
  and 'b' in agreement, and all worker threads must have finished afterwards.
  """

  def fail_on_two(value):
    if value != 2:
      return value
    raise ValueError('foo')

  pipeline = (
      transform.TreeTransform(num_threads=1)
      .apply(fn=lambda x: x, output_keys='a')
      .assign('b', fn=fail_on_two, input_keys='a')
  )
  iterator = pipeline.make().iterate(range(4), ignore_error=True)
  actual = list(iterator)
  # The failing input (2) is skipped entirely; 'a' and 'b' stay paired.
  self.assertEqual([{'a': v, 'b': v} for v in (0, 1, 3)], actual)
  assert iterator._thread_pool is not None
  workers = iterator._thread_pool._threads
  self.assertNotEmpty(workers)
  # Every worker thread must have terminated once iteration is exhausted.
  self.assertFalse(any(worker.is_alive() for worker in workers))

def test_mock_generator_bool_operator(self):
ds = test_utils.NoLenIter(range(3))
with self.assertRaisesRegex(ValueError, 'Cannot call len()'):
Expand Down
15 changes: 11 additions & 4 deletions ml_metrics/_src/chainables/tree_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,9 @@ def __call__(
return mit.first(self.iterate([inputs]))

def _iterate(
self, input_iterator: Iterator[tree.MapLikeTree[ValueT] | None]
self,
input_iterator: Iterator[tree.MapLikeTree[ValueT] | None],
ignore_error: bool = False,
) -> Iterator[tree.MapLikeTree[ValueT] | None]:
fn_inputs = map(self.get_inputs, input_iterator)
if self.fn_batch_size:
Expand All @@ -267,7 +269,7 @@ def _iterate(
num_columns=self.num_inputs,
)
# Only ignore function call error.
map_ = iter_utils.map_ignore_error if self.ignore_error else map
map_ = iter_utils.map_ignore_error if ignore_error else map
fn_outputs = map_(self._maybe_call_fn, fn_inputs)
if self.batch_size:
fn_outputs = iter_utils.rebatched_args(
Expand All @@ -280,7 +282,10 @@ def _iterate(
def iterate(
self, input_iterator: Iterable[tree.MapLikeTree[ValueT] | None]
) -> Iterator[tree.MapLikeTree[ValueT] | None]:
return map(self._get_outputs, self._iterate(iter(input_iterator)))
return map(
self._get_outputs,
self._iterate(iter(input_iterator), ignore_error=self.ignore_error),
)

def __getstate__(self):
state = self.__dict__.copy()
Expand All @@ -306,7 +311,9 @@ def iterate(
) -> Iterator[tree.MapLikeTree[ValueT] | None]:
return it.starmap(
self._get_outputs,
iter_utils.processed_with_inputs(self._iterate, iter(input_iterator)),
iter_utils.processed_with_inputs(
self._iterate, iter(input_iterator), ignore_error=self.ignore_error
),
)


Expand Down
18 changes: 15 additions & 3 deletions ml_metrics/_src/utils/iter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def empty(self) -> bool:


STOP_ITERATION = StopIteration()
_SKIP = '_SKIP'


def iter_ignore_error(it):
def iter_ignore_error(it, error_return=None):
"""Yields the next element from an iterator, ignoring errors.

Be careful when using this function, it can cause infinite loop if the
Expand All @@ -69,6 +70,7 @@ def iter_ignore_error(it):

Args:
it: The iterator to ignore errors from.
error_return: The value to yield in place of an element when an ignored error is encountered; if None, nothing is yielded for that element.

Yields:
The next element from the iterator, ignoring errors.
Expand All @@ -79,6 +81,8 @@ def iter_ignore_error(it):
except (StopIteration, StopAsyncIteration) as e:
return e.value
except _IGNORE_ERROR_TYPES:
if error_return is not None:
yield error_return
continue


Expand Down Expand Up @@ -1079,12 +1083,20 @@ def processed_with_inputs(
input_iterator: Iterator[_InputT],
*,
max_buffer_size: int = 0,
ignore_error: bool = False,
) -> Iterator[tuple[_ValueT, _InputT]]:
"""Zips the processed outputs with its inputs."""
iter_input = _TeeIterator(input_iterator, buffer_size=max_buffer_size)
iter_output = process_fn(iter_input)
# Note that recital iterator has to be put after the input iterator so that
# there are values to be recited.
if ignore_error:
iter_output = iter_ignore_error(iter_output, error_return=_SKIP)
# Note that recital iterator has to be put after the input iterator so that
# there are values to be recited.
return (
(output, input)
for output, input in zip(iter_output, iter_input.tee())
if output is not _SKIP
)
return zip(iter_output, iter_input.tee())


Expand Down