From 2730d349aaf0e93623a9203950aadab1944ed9a1 Mon Sep 17 00:00:00 2001 From: JGSweets Date: Fri, 15 Oct 2021 22:33:40 -0500 Subject: [PATCH] fix: bug in confidences on output (#419) --- dataprofiler/labelers/character_level_cnn_model.py | 2 +- dataprofiler/tests/labelers/test_data_processing.py | 1 + .../test_integration_unstructured_data_labeler.py | 10 +++++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/dataprofiler/labelers/character_level_cnn_model.py b/dataprofiler/labelers/character_level_cnn_model.py index 7b22f76fb..922682c0c 100644 --- a/dataprofiler/labelers/character_level_cnn_model.py +++ b/dataprofiler/labelers/character_level_cnn_model.py @@ -966,7 +966,7 @@ def predict(self, data, batch_size=32, show_confidences=False, in enumerate(sentence_lengths[:allocation_index]): predictions_list[index] = list(predictions[index][:sentence_length]) if show_confidences: - confidences_list = list(confidences[index][:sentence_length]) + confidences_list[index] = list(confidences[index][:sentence_length]) if show_confidences: return {'pred': predictions_list, 'conf': confidences_list} diff --git a/dataprofiler/tests/labelers/test_data_processing.py b/dataprofiler/tests/labelers/test_data_processing.py index 101ccd8c3..f04767b1a 100644 --- a/dataprofiler/tests/labelers/test_data_processing.py +++ b/dataprofiler/tests/labelers/test_data_processing.py @@ -1209,6 +1209,7 @@ def test_match_sentence_lengths(self): inplace=True) self.assertEqual(results, post_process_results) + class TestPreandPostCharacterProcessorConnection(unittest.TestCase): def test_flatten_convert(self): diff --git a/dataprofiler/tests/labelers/test_integration_unstructured_data_labeler.py b/dataprofiler/tests/labelers/test_integration_unstructured_data_labeler.py index 3ecc08a37..ad99ef41b 100644 --- a/dataprofiler/tests/labelers/test_integration_unstructured_data_labeler.py +++ b/dataprofiler/tests/labelers/test_integration_unstructured_data_labeler.py @@ -115,8 +115,16 @@ def test_default_confidences(self): model_predictions_char_level, model_confidences_char_level = \ results["pred"], results["conf"] - # for now just checking that it's not empty + # for now just checking that it's not empty and appropriate size + num_labels = max(default.label_mapping.values()) + 1 + len_text = len(sample[0]) self.assertIsNotNone(model_confidences_char_level) + self.assertEqual((len_text, num_labels), + model_confidences_char_level[0].shape) + + len_text = len(sample[1]) + self.assertEqual((len_text, num_labels), + model_confidences_char_level[1].shape) def test_default_edge_cases(self): """more complicated test for edge cases for the default model"""