Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 585885441
  • Loading branch information
SeqIO Team authored and SeqIO committed Nov 28, 2023
1 parent 64c56b8 commit 515d917
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
8 changes: 5 additions & 3 deletions seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,10 +716,12 @@ def add_kwargs_to_transform(transform, **kwargs):
is_dataclass = dataclasses.is_dataclass(transform)
# Filter kwargs by attributes of the dataclass/arguments of the function.
if is_dataclass:
avaialabe_arg_names = [f.name for f in dataclasses.fields(transform)]
available_arg_names = [f.name for f in dataclasses.fields(transform)]
else:
avaialabe_arg_names = set(inspect.signature(transform).parameters.keys())
kwargs = {k: v for k, v in kwargs.items() if k in avaialabe_arg_names}
available_arg_names = set(inspect.signature(transform).parameters.keys())
if isinstance(transform, functools.partial):
available_arg_names -= set(transform.keywords.keys())
kwargs = {k: v for k, v in kwargs.items() if k in available_arg_names}
if not kwargs:
return transform
# Add attributes/arguments.
Expand Down
13 changes: 13 additions & 0 deletions seqio/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,19 @@ def __call__(self, x):
fn = utils.add_kwargs_to_transform(fn, y=2, z=10)
self.assertEqual(60, fn(3))

def test_add_kwargs_to_transform_partial(self):
"""Test add_kwargs_to_transform() with partial.
Ensure not to overwrite the keyword argument once it has been predefined
by functools.partial.
"""

def fn(x, y):
return x * y

fn = utils.add_kwargs_to_transform(functools.partial(fn, x=1), x=2, y=3)
self.assertEqual(3, fn())


class MapOverDatasetTest(parameterized.TestCase):

Expand Down

0 comments on commit 515d917

Please sign in to comment.