From 515d917bf58da4103a2bbf39c3716213c36aff03 Mon Sep 17 00:00:00 2001 From: SeqIO Team Date: Tue, 28 Nov 2023 01:15:28 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 585885441 --- seqio/utils.py | 8 +++++--- seqio/utils_test.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/seqio/utils.py b/seqio/utils.py index b37c3ecf..e1d84d79 100644 --- a/seqio/utils.py +++ b/seqio/utils.py @@ -716,10 +716,12 @@ def add_kwargs_to_transform(transform, **kwargs): is_dataclass = dataclasses.is_dataclass(transform) # Filter kwargs by attributes of the dataclass/arguments of the function. if is_dataclass: - avaialabe_arg_names = [f.name for f in dataclasses.fields(transform)] + available_arg_names = [f.name for f in dataclasses.fields(transform)] else: - avaialabe_arg_names = set(inspect.signature(transform).parameters.keys()) - kwargs = {k: v for k, v in kwargs.items() if k in avaialabe_arg_names} + available_arg_names = set(inspect.signature(transform).parameters.keys()) + if isinstance(transform, functools.partial): + available_arg_names -= set(transform.keywords.keys()) + kwargs = {k: v for k, v in kwargs.items() if k in available_arg_names} if not kwargs: return transform # Add attributes/arguments. diff --git a/seqio/utils_test.py b/seqio/utils_test.py index 6ab4af81..add15221 100644 --- a/seqio/utils_test.py +++ b/seqio/utils_test.py @@ -284,6 +284,19 @@ def __call__(self, x): fn = utils.add_kwargs_to_transform(fn, y=2, z=10) self.assertEqual(60, fn(3)) + def test_add_kwargs_to_transform_partial(self): + """Test add_kwargs_to_transform() with partial. + + Ensure not to overwrite the keyword argument once it has been predefined + by functools.partial. + """ + + def fn(x, y): + return x * y + + fn = utils.add_kwargs_to_transform(functools.partial(fn, x=1), x=2, y=3) + self.assertEqual(3, fn()) + class MapOverDatasetTest(parameterized.TestCase):