From 717d7042158562c50686fea2b36cd89005bc10e0 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Fri, 8 Mar 2024 12:18:04 +0100 Subject: [PATCH] Process TFRecord reader binding classes only when it is enabled (#5360) Only parse the code and create the TFRecord reader API bindings/classes when the feature is enabled. Signed-off-by: Krzysztof Lecki --- .../nvidia/dali/ops/_operators/tfrecord.py | 117 +++++++++--------- 1 file changed, 59 insertions(+), 58 deletions(-) diff --git a/dali/python/nvidia/dali/ops/_operators/tfrecord.py b/dali/python/nvidia/dali/ops/_operators/tfrecord.py index 5001dfaa5f0..d2c05d1f47c 100644 --- a/dali/python/nvidia/dali/ops/_operators/tfrecord.py +++ b/dali/python/nvidia/dali/ops/_operators/tfrecord.py @@ -29,61 +29,62 @@ def tfrecord_enabled(): return False -def _get_impl(name, schema_name, internal_schema_name): - - class _TFRecordReaderImpl( - ops.python_op_factory(name, schema_name, internal_schema_name, generated=False) - ): - """custom wrappers around ops""" - - def __init__(self, path, index_path, features, **kwargs): - if isinstance(path, list): - self._path = path - else: - self._path = [path] - if isinstance(index_path, list): - self._index_path = index_path - else: - self._index_path = [index_path] - - kwargs.update({"path": self._path, "index_path": self._index_path}) - self._features = features - - super().__init__(**kwargs) - - def __call__(self, *inputs, **kwargs): - feature_names = [] - features = [] - for feature_name, feature in self._features.items(): - feature_names.append(feature_name) - features.append(feature) - if not isinstance(feature, _b.tfrecord.Feature): - raise TypeError( - "Expected `nvidia.dali.tfrecord.Feature` for the " - f'"{feature_name}", but got {type(feature)}. ' - "Use `nvidia.dali.tfrecord.FixedLenFeature` or " - "`nvidia.dali.tfrecord.VarLenFeature` to define the features to extract." - ) - - kwargs.update({"feature_names": feature_names, "features": features}) - - # We won't have MIS as this op doesn't have any inputs (Reader) - linear_outputs = super().__call__(*inputs, **kwargs) - # We may have single, flattened output - if not isinstance(linear_outputs, list): - linear_outputs = [linear_outputs] - outputs = {} - for feature_name, output in zip(feature_names, linear_outputs): - outputs[feature_name] = output - - return outputs - - return _TFRecordReaderImpl - - -class TFRecordReader(_get_impl("_TFRecordReader", "TFRecordReader", "_TFRecordReader")): - pass - - -class TFRecord(_get_impl("_TFRecord", "readers__TFRecord", "readers___TFRecord")): - pass +if tfrecord_enabled(): + + def _get_impl(name, schema_name, internal_schema_name): + + class _TFRecordReaderImpl( + ops.python_op_factory(name, schema_name, internal_schema_name, generated=False) + ): + """custom wrappers around ops""" + + def __init__(self, path, index_path, features, **kwargs): + if isinstance(path, list): + self._path = path + else: + self._path = [path] + if isinstance(index_path, list): + self._index_path = index_path + else: + self._index_path = [index_path] + + kwargs.update({"path": self._path, "index_path": self._index_path}) + self._features = features + + super().__init__(**kwargs) + + def __call__(self, *inputs, **kwargs): + feature_names = [] + features = [] + for feature_name, feature in self._features.items(): + feature_names.append(feature_name) + features.append(feature) + if not isinstance(feature, _b.tfrecord.Feature): + raise TypeError( + "Expected `nvidia.dali.tfrecord.Feature` for the " + f'"{feature_name}", but got {type(feature)}. ' + "Use `nvidia.dali.tfrecord.FixedLenFeature` or " + "`nvidia.dali.tfrecord.VarLenFeature` " + "to define the features to extract." + ) + + kwargs.update({"feature_names": feature_names, "features": features}) + + # We won't have MIS as this op doesn't have any inputs (Reader) + linear_outputs = super().__call__(*inputs, **kwargs) + # We may have single, flattened output + if not isinstance(linear_outputs, list): + linear_outputs = [linear_outputs] + outputs = {} + for feature_name, output in zip(feature_names, linear_outputs): + outputs[feature_name] = output + + return outputs + + return _TFRecordReaderImpl + + class TFRecordReader(_get_impl("_TFRecordReader", "TFRecordReader", "_TFRecordReader")): + pass + + class TFRecord(_get_impl("_TFRecord", "readers__TFRecord", "readers___TFRecord")): + pass