From 3f3fa8791962bbcc901ab469ec37e7a07d961ac6 Mon Sep 17 00:00:00 2001 From: JGSweets Date: Fri, 28 Jan 2022 14:08:44 -0600 Subject: [PATCH] Add is_pred_labels to struct labeler and increment version (#435) * feat: add is_pred_labels to struct labeler * feat: increment version --- dataprofiler/labelers/data_processing.py | 30 ++++++++++++------- .../tests/labelers/test_data_processing.py | 20 +++++++++++++ dataprofiler/version.py | 2 +- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/dataprofiler/labelers/data_processing.py b/dataprofiler/labelers/data_processing.py index 79fd26b81..9019683b5 100644 --- a/dataprofiler/labelers/data_processing.py +++ b/dataprofiler/labelers/data_processing.py @@ -830,10 +830,10 @@ def _word_level_argmax(self, data, predictions, label_mapping, is_end = (idx == len(sample)-1 and start_idx > 0) if not is_separator: - label = entities_in_sample[idx] + label = entities_in_sample[idx] if label not in label_count: label_count[label] = 0 - label_count[label] += 1 + label_count[label] += 1 if is_separator or is_end: @@ -867,7 +867,7 @@ def _word_level_argmax(self, data, predictions, label_mapping, label_count = {background_label: 0} if char_pred[idx] == background_label and \ sample[idx] in separator_dict: - continue + continue word_level_predictions.append(entities_in_sample) return word_level_predictions @@ -1220,7 +1220,8 @@ class StructCharPostprocessor(BaseDataPostprocessor, metaclass=AutoSubRegistrationMeta): def __init__(self, default_label='UNKNOWN', pad_label='PAD', - flatten_separator="\x01"*5, random_state=None): + flatten_separator="\x01"*5, is_pred_labels=True, + random_state=None): """ Initialize the StructCharPostprocessor class @@ -1231,6 +1232,9 @@ def __init__(self, default_label='UNKNOWN', pad_label='PAD', :param flatten_separator: separator used to put between flattened samples. :type flatten_separator: str + :param is_pred_labels: (default: true) if true, will convert the model + indexes to the label strings given the label_mapping + :type is_pred_labels: bool :param random_state: random state setting to be used for randomly selecting a prediction when two labels have equal opportunity for a given sample. @@ -1256,6 +1260,7 @@ def __init__(self, default_label='UNKNOWN', pad_label='PAD', super().__init__(default_label=default_label, pad_label=pad_label, flatten_separator=flatten_separator, + is_pred_labels=is_pred_labels, random_state=random_state) def __eq__(self, other): @@ -1276,7 +1281,9 @@ def __eq__(self, other): or self._parameters["pad_label"] != \ other._parameters["pad_label"]\ or self._parameters["flatten_separator"] != \ - other._parameters["flatten_separator"]: + other._parameters["flatten_separator"] \ + or self._parameters["is_pred_labels"] != \ + other._parameters["is_pred_labels"]: return False return True @@ -1303,6 +1310,8 @@ def _validate_parameters(self, parameters): if param in ['default_label', 'pad_label', 'flatten_separator'] \ and not isinstance(value, str): errors.append("`{}` must be a string.".format(param)) + if param in ['is_pred_labels'] and not isinstance(value, bool): + errors.append("`{}` must be a boolean.".format(param)) if param == 'random_state' and not isinstance(value, random.Random): errors.append('`{}` must be a random.Random.'.format(param)) elif param not in allowed_parameters: @@ -1483,6 +1492,7 @@ def process(self, data, results, label_mapping): flatten_separator = self._parameters['flatten_separator'] default_label = self._parameters['default_label'] pad_label = self._parameters['pad_label'] + is_pred_labels = self._parameters['is_pred_labels'] # Format predictions # FORMER DEEPCOPY, SHALLOW AS ONLY INTERNAL @@ -1494,11 +1504,11 @@ def process(self, data, results, label_mapping): default_label=default_label, pad_label=pad_label) - reverse_label_mapping = {v: k for k, v in label_mapping.items()} - rev_label_map_vec_func = np.vectorize( - lambda x: reverse_label_mapping.get(x, None)) - - results['pred'] = rev_label_map_vec_func(results['pred']) + if is_pred_labels: + reverse_label_mapping = {v: k for k, v in label_mapping.items()} + rev_label_map_vec_func = np.vectorize( + lambda x: reverse_label_mapping.get(x, None)) + results['pred'] = rev_label_map_vec_func(results['pred']) return results def _save_processor(self, dirpath): diff --git a/dataprofiler/tests/labelers/test_data_processing.py b/dataprofiler/tests/labelers/test_data_processing.py index 9df9f0afb..664c5b630 100644 --- a/dataprofiler/tests/labelers/test_data_processing.py +++ b/dataprofiler/tests/labelers/test_data_processing.py @@ -1852,6 +1852,7 @@ def test_get_parameters(self): self.assertDictEqual(dict(default_label='UNKNOWN', pad_label='PAD', flatten_separator='\x01'*5, + is_pred_labels=True, random_state=random_state), processor.get_parameters()) @@ -1859,6 +1860,7 @@ def test_get_parameters(self): params = dict(default_label='test default', pad_label='test pad', flatten_separator='test', + is_pred_labels=False, random_state=random_state) processor = StructCharPostprocessor(**params) self.assertDictEqual(params, processor.get_parameters()) @@ -1867,6 +1869,7 @@ def test_get_parameters(self): params = dict(default_label='test default', pad_label='test pad', flatten_separator='test', + is_pred_labels=False, random_state=random_state) processor = StructCharPostprocessor(**params) self.assertDictEqual( @@ -1911,7 +1914,24 @@ def test_process(self): self.assertIn('pred', output) self.assertTrue((expected_output['pred'] == output['pred']).all()) + # test with is_pred_labels = False + processor = StructCharPostprocessor( + default_label='UNKNOWN', + pad_label='PAD', + is_pred_labels=False, + flatten_separator='\x01' * 5) + expected_output_ints = dict(pred=np.array([2, 3, 1, 3, 2])) + output = processor.process(data, results, label_mapping) + + self.assertIn('pred', output) + self.assertTrue((expected_output_ints['pred'] == output['pred']).all()) + # with confidences + processor = StructCharPostprocessor( + default_label='UNKNOWN', + pad_label='PAD', + is_pred_labels=True, + flatten_separator='\x01' * 5) confidences = [] for sample in results['pred']: confidences.append([]) diff --git a/dataprofiler/version.py b/dataprofiler/version.py index 68526a847..278261320 100644 --- a/dataprofiler/version.py +++ b/dataprofiler/version.py @@ -4,7 +4,7 @@ MAJOR = 0 MINOR = 7 -MICRO = 4 +MICRO = 5 VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO)