From 857e539ea5893ebd60f1a1e93aea82447f14b80a Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Tue, 26 Dec 2023 14:30:40 +0530 Subject: [PATCH 1/8] Add validation for tagging type --- skit_labels/cli.py | 9 +++++++-- skit_labels/constants.py | 6 +++++- skit_labels/utils.py | 36 +++++++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/skit_labels/cli.py b/skit_labels/cli.py index 9dba2df..4d6572a 100644 --- a/skit_labels/cli.py +++ b/skit_labels/cli.py @@ -319,12 +319,17 @@ def build_cli(): return parser -def upload_dataset(input_file, url, token, job_id, data_source, data_label = None): +def upload_dataset(input_file, url, token, job_id, data_source, data_label = None, tagging_type=None): input_file = utils.add_data_label(input_file, data_label) if data_source == const.SOURCE__DB: fn = commands.upload_dataset_to_db elif data_source == const.SOURCE__LABELSTUDIO: - fn = commands.upload_dataset_to_labelstudio + if tagging_type: + is_valid, error = utils.validate_input_data(tagging_type, input_file) + if not is_valid: + return error, None + + fn = commands.upload_dataset_to_labelstudio errors, df_size = asyncio.run( fn( input_file, diff --git a/skit_labels/constants.py b/skit_labels/constants.py index ae80e08..b7fe6ef 100644 --- a/skit_labels/constants.py +++ b/skit_labels/constants.py @@ -120,4 +120,8 @@ FROM_NAME_INTENT = "tag" CHOICES = "choices" TAXONOMY = "taxonomy" -VALUE = "value" \ No newline at end of file +VALUE = "value" + +EXPECTED_COLUMNS_MAPPING = { + "conversation_tagging": ['situation_id', 'situation_str', 'call'] +} \ No newline at end of file diff --git a/skit_labels/utils.py b/skit_labels/utils.py index 7979eac..793d8d9 100644 --- a/skit_labels/utils.py +++ b/skit_labels/utils.py @@ -10,7 +10,7 @@ from datetime import datetime import pandas as pd from typing import Union - +from constants import EXPECTED_COLUMNS_MAPPING LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] @@ -110,3 +110,37 @@ def add_data_label(input_file: str, data_label: Optional[str] = None) -> str: df = df.assign(data_label=data_label) df.to_csv(input_file, index=False) return input_file + + +def validate_headers(input_file, tagging_type): + expected_columns_mapping = EXPECTED_COLUMNS_MAPPING + expected_headers = expected_columns_mapping.get(tagging_type) + + df = pd.read_csv(input_file) + column_headers = df.columns.to_list() + column_headers = [header.lower() for header in column_headers] + column_headers = sorted(column_headers) + expected_headers = sorted(expected_headers) + logger.info(f"column_headers: {column_headers}") + logger.info(f"expected_headers: {expected_headers}") + + is_match = column_headers == expected_headers + mismatch_headers = [] + logger.info(f"Is match: {is_match}") + + if not is_match: + mismatch_headers_set =set(column_headers).symmetric_difference(set(expected_headers)) + mismatch_headers = list(mismatch_headers_set) + return is_match, mismatch_headers + + +def validate_input_data(tagging_type, input_file): + is_valid = True + error = '' + if tagging_type == 'conversation_tagging': + is_match, mismatch_headers = validate_headers(input_file, tagging_type) + if not is_match: + error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}' + is_valid = False + + return is_valid, error \ No newline at end of file From 3b6d0916df15aacfe4b7564deb7db45183207542 Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Tue, 26 Dec 2023 14:57:45 +0530 Subject: [PATCH 2/8] Fix import --- skit_labels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skit_labels/utils.py b/skit_labels/utils.py index 793d8d9..dea4773 100644 --- a/skit_labels/utils.py +++ b/skit_labels/utils.py @@ -10,7 +10,7 @@ from datetime import datetime import pandas as pd from typing import Union -from constants import EXPECTED_COLUMNS_MAPPING +from skit_labels import constants as const LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] @@ -113,7 +113,7 @@ def add_data_label(input_file: str, data_label: Optional[str] = None) -> str: def validate_headers(input_file, tagging_type): - expected_columns_mapping = EXPECTED_COLUMNS_MAPPING + expected_columns_mapping = const.EXPECTED_COLUMNS_MAPPING expected_headers = expected_columns_mapping.get(tagging_type) df = pd.read_csv(input_file) From a0f15665c5ed37c00f41a1a008456b37981b8f69 Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Tue, 26 Dec 2023 18:43:50 +0530 Subject: [PATCH 3/8] Update the expected fields --- skit_labels/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skit_labels/constants.py b/skit_labels/constants.py index b7fe6ef..c90e4ed 100644 --- a/skit_labels/constants.py +++ b/skit_labels/constants.py @@ -123,5 +123,5 @@ VALUE = "value" EXPECTED_COLUMNS_MAPPING = { - "conversation_tagging": ['situation_id', 'situation_str', 'call'] + "conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call'] } \ No newline at end of file From 29efa8b294d65fda794774b03f860f490a7ae5f2 Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Wed, 27 Dec 2023 14:28:12 +0530 Subject: [PATCH 4/8] Move values to constants --- skit_labels/constants.py | 4 +++- skit_labels/utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/skit_labels/constants.py b/skit_labels/constants.py index c90e4ed..4455608 100644 --- a/skit_labels/constants.py +++ b/skit_labels/constants.py @@ -124,4 +124,6 @@ EXPECTED_COLUMNS_MAPPING = { "conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call'] -} \ No newline at end of file +} + +CONVERSATION_TAGGING = 'conversation_tagging' \ No newline at end of file diff --git a/skit_labels/utils.py b/skit_labels/utils.py index dea4773..2982e56 100644 --- a/skit_labels/utils.py +++ b/skit_labels/utils.py @@ -137,7 +137,7 @@ def validate_headers(input_file, tagging_type): def validate_input_data(tagging_type, input_file): is_valid = True error = '' - if tagging_type == 'conversation_tagging': + if tagging_type == const.CONVERSATION_TAGGING: is_match, mismatch_headers = validate_headers(input_file, tagging_type) if not is_match: error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}' From 6f9dba2fafeaa2e8a38e54b4c458c7030b56317c Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Wed, 27 Dec 2023 15:06:05 +0530 Subject: [PATCH 5/8] Update the expected fields --- skit_labels/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skit_labels/constants.py b/skit_labels/constants.py index 4455608..c612d9b 100644 --- a/skit_labels/constants.py +++ b/skit_labels/constants.py @@ -123,7 +123,7 @@ VALUE = "value" EXPECTED_COLUMNS_MAPPING = { - "conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call'] + "conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call', 'data_label'] } CONVERSATION_TAGGING = 'conversation_tagging' \ No newline at end of file From f53e63f77b52fe5ba30ed7ab3c082c7914f55b9a Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Wed, 27 Dec 2023 15:35:58 +0530 Subject: [PATCH 6/8] Add support for tagging type arg in cli --- skit_labels/cli.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/skit_labels/cli.py b/skit_labels/cli.py index 4d6572a..9d805fe 100644 --- a/skit_labels/cli.py +++ b/skit_labels/cli.py @@ -227,6 +227,13 @@ def upload_dataset_to_labelstudio_command( required=True, help="The data label implying the source of data", ) + + parser.add_argument( + "--tagging-type", + type=str, + help="The tagging type for the calls being uploaded", + ) + return parser @@ -391,7 +398,7 @@ def cmd_to_str(args: argparse.Namespace) -> str: arg_id = args.job_id _ = is_valid_data_label(args.data_label) - errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label) + errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label, args.tagging_type) if errors: return ( From dd8d2ab18de5348dbcfcbe40f87986942e73eef3 Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Wed, 27 Dec 2023 15:39:02 +0530 Subject: [PATCH 7/8] Update the version --- CHANGELOG.md | 2 ++ pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7fbc8c..8fd7b56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,6 @@ # CHANGELOG +## 0.3.35 +- [x] add: validations for the input file for conversation tagging ## 0.3.34 - [x] PL-61: Add retry mechanism for uploading data to Label studio diff --git a/pyproject.toml b/pyproject.toml index db35829..810a948 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "skit-labels" -version = "0.3.34" +version = "0.3.35" description = "Command line tool for interacting with labelled datasets at skit.ai." authors = [] license = "MIT" From 3cbc67e42e65d20debebd178e71dc9af8f1d2f32 Mon Sep 17 00:00:00 2001 From: "dhanashree.s" Date: Wed, 27 Dec 2023 16:12:43 +0530 Subject: [PATCH 8/8] Refactor the column check function --- skit_labels/utils.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/skit_labels/utils.py b/skit_labels/utils.py index 2982e56..b35acb9 100644 --- a/skit_labels/utils.py +++ b/skit_labels/utils.py @@ -117,30 +117,38 @@ def validate_headers(input_file, tagging_type): expected_headers = expected_columns_mapping.get(tagging_type) df = pd.read_csv(input_file) + column_headers = df.columns.to_list() column_headers = [header.lower() for header in column_headers] column_headers = sorted(column_headers) expected_headers = sorted(expected_headers) + logger.info(f"column_headers: {column_headers}") logger.info(f"expected_headers: {expected_headers}") is_match = column_headers == expected_headers - mismatch_headers = [] logger.info(f"Is match: {is_match}") if not is_match: - mismatch_headers_set =set(column_headers).symmetric_difference(set(expected_headers)) - mismatch_headers = list(mismatch_headers_set) - return is_match, mismatch_headers + missing_headers = set(expected_headers).difference(set(column_headers)) + additional_headers = set(column_headers).difference(set(expected_headers)) + if missing_headers: + return missing_headers + elif additional_headers: + df.drop(additional_headers, axis=1, inplace=True) + df.to_csv(input_file, index=False) + is_match = True + logger.info(f"Following additional headers have been removed from the csv: {additional_headers}") + return [] def validate_input_data(tagging_type, input_file): is_valid = True error = '' if tagging_type == const.CONVERSATION_TAGGING: - is_match, mismatch_headers = validate_headers(input_file, tagging_type) - if not is_match: - error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}' + missing_headers = validate_headers(input_file, tagging_type) + if missing_headers: + error = f'Headers in the input file does not match the expected fields. Missing fields = {missing_headers}' is_valid = False return is_valid, error \ No newline at end of file