Skip to content

Commit

Permalink
DOBPlugin: Implementation and tests (#206)
Browse files Browse the repository at this point in the history
* created DOBplugin that performs pre processing on the transcripts that go into the DucklingPlugin. (does no preprocessing, a dummy plugin to check compatibility)

* added integration tests for DOBplugin

* modifying DOBPlugin to edit "input.transcripts" instead of "input.best_transcript"

* updated integration tests for DOBplugin

* handling error class5 in DOBPlugin

* handing class5 and class6 error in DOBPlugin

* added word2number in pyproject.toml

* added logs for when there is no response from DOBPlugin

* added try-except block in checker function in test script to handle failures that involve empty transcripts

* updated poetry.lock file

* added annotations for all functions, made class6 return the transcript if match not found

* added more function annotations

* added a filtering step (if v is not None) in sorted_result inside get_transcripts_from_utterances function to handle annotation type fail

* added annotations for result_dict and transcripts inside function get_transcripts_from_utterances

* modified annotations for result_dict inside function get_transcripts_from_utterances

* edited get_transcripts_from_utterances to fix type checker issues

* modified class4 function in DOBPlugin

* wrote unit test cases for DOBPlugin

* wrote more unit test functions for DOBPlugin

* handling more variations of class4 and class7_1

* annotated test cases that fall in error class7_1

* added class7_1 unit test cases

* modified to add class7_1 integration test cases

* wrote test function to cover the 'transform' function in DOBPlugin to attain full coverage

---------

Co-authored-by: vivitsaS <[email protected]>
Co-authored-by: Manas <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2024
1 parent a651397 commit e5b8e34
Show file tree
Hide file tree
Showing 10 changed files with 3,973 additions and 1,683 deletions.
2 changes: 2 additions & 0 deletions dialogy/plugins/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dialogy.plugins.text.classification.retain_intent import RetainOriginalIntentPlugin
from dialogy.plugins.text.classification.xlmr import XLMRMultiClass
from dialogy.plugins.text.combine_date_time import CombineDateTimeOverSlots
from dialogy.plugins.text.dob_plugin import DOBPlugin
from dialogy.plugins.text.duckling_plugin import DucklingPlugin
from dialogy.plugins.text.lb_plugin import DucklingPluginLB
from dialogy.plugins.text.list_entity_plugin import ListEntityPlugin
Expand All @@ -23,6 +24,7 @@
"CanonicalizationPlugin": CanonicalizationPlugin,
"MLPMultiClass": MLPMultiClass,
"MergeASROutputPlugin": MergeASROutputPlugin,
"DOBPlugin":DOBPlugin,
"DucklingPlugin": DucklingPlugin,
"ListEntityPlugin": ListEntityPlugin,
"XLMRMultiClass": XLMRMultiClass,
Expand Down
273 changes: 273 additions & 0 deletions dialogy/plugins/text/dob_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import json
import re
import traceback
from typing import Any, Callable, Dict, List, Optional, Union

import pandas as pd
from loguru import logger
from tqdm import tqdm
from word2number import w2n

import dialogy.constants as const
from dialogy.base import Guard, Input, Output, Plugin
from dialogy.types import Utterance
from dialogy.utils import normalize

def _post_processing_clean(transformed_transcript: str) -> str:
transformed_transcript = transformed_transcript.strip()
if " "in transformed_transcript:
transformed_transcript = transformed_transcript.replace(" ", " ")
return transformed_transcript

def _class_7_1(transcript: str) -> str:
"""
replace all instances that look like this:
<1 or 2 digit number><space><4 digit number> with
<1 or 2 digit number><space><2 digit number><2 digit number>
"""
try:
pattern = r'^\b(\d{1,2})\s(\d{2})(\d{2})\b$'
match = re.search(pattern, transcript)
if match:
# Access matching groups using group() or groups()
full_match = match.group(0)
print("groups = ",match.group(1), match.group(2), match.group(3))
initial = match.group(1)
two_digit_number_1 = match.group(2)
two_digit_number_2 = match.group(3)
split_four_into_2_two_digit_numbers = initial+" "+two_digit_number_1+" "+two_digit_number_2
return _post_processing_clean(split_four_into_2_two_digit_numbers)
else:
return transcript
except:
return transcript

def _class_4(input_string: str) -> str:
"""
cleaining
"""
# Remove all instances of "."
cleaned_string = input_string.replace(".", " ")
# Remove all instances of "it's" (case insensitive)
cleaned_string = cleaned_string.replace("it's", " ")
# replace all instances of "for" with 4 (case insensitive)
cleaned_string = cleaned_string.replace("for", "4")
# replace all instances of "-" with " " (case insensitive)
cleaned_string = cleaned_string.replace("-", " ")
# replace all instances of "st/nd/rd/th" with "" (case insensitive)
cleaned_string = re.sub(r'(\d+)(?:st|nd|rd|th)', r'\1', cleaned_string)


cleaned_string = _post_processing_clean(cleaned_string)
return cleaned_string


def _class_6(transcript: str) -> str:
"""
input: transcript that looks like- " X Y" where X is a string of numbers with or without space. Y is string- which is a
number written in words.
description: let X be as is. Only if X exist, look for Y. then convert Y into numbers- numeric digits
output: X (as is) Y(converted into numbers)
"""
"""
logic: WILL ONLY HANDLE YY IN WORDS!, NOT YYYY
1. regex pattern-
2 groups-
1st: numeric characters- with or without space;
2nd: everything that follows it - alphanumeric characters
start_idx = starting index of group 1
2. let group 1 be as is
3. transform group 2:
1. only retain the part that is a word representation of some number. end_idx = last index of this word representation of number
2. convert this word representation of number to numbers = words_converted_to_number
4. new_substring = group 1 + words_converted_to_number
5. replace the start_idx to last_idx part of the original string with new_substring
6. return transformed string
"""

# Regex pattern to capture numeric part and remaining words separately
pattern = re.compile(r"(\b\d+\s*\d*\b|\b\d+\b)(.*)", re.IGNORECASE)

match = re.search(pattern, transcript)

if match:
numeric_part = match.group(1)
remaining_words = match.group(2) or ""
try:
words_converted_to_number = str(w2n.word_to_num(remaining_words))
if len(words_converted_to_number) <= 2:

# Construct the transformed substring
transformed_transcript = numeric_part + " " + words_converted_to_number
return _post_processing_clean(transformed_transcript)
except:
pass
return transcript # Return original transcript if no match


def _class_5(transcript: str) -> str:
"""
input: transcript that looks like-"xx:yy "
description: replace ":" with " "; if yy = 00 or 0y: replace yy with " " or if 0y, replace 0 with ""
output: xx yy
"""

# Your regex pattern
pattern = re.compile(r"\b(\d{1,2}):(\d{2})\b")

match = pattern.search(transcript)

if match:
hours, minutes = match.groups()
start_idx, hours_end = match.start(1), match.end(1)
minutes_start, end_idx = match.start(2), match.end(2)

# Modifying the minutes based on the conditions
if minutes == "00":
minutes = ""
elif minutes.startswith("0"):
minutes = "" + minutes[1]

# Combining the modified hours and minutes
result = hours + " " + minutes
transformed_transcript = transcript[:start_idx] + result + transcript[end_idx:]
return _post_processing_clean(transformed_transcript)

else:
# Handle invalid input
return transcript


def _transform_invalid_date(transcript: str) -> str:
"""
input: transcripts that are responses for when the user is asked their
dob for authentication and are not recognised as dates by duckling
output: trasnformed transcript recognised by duckling as date (closest valid date)
description: handling class 5 error
"""
transcript = _class_4(transcript)
transcript = _class_5(transcript)
transcript = _class_6(transcript)
transcript = _class_7_1(transcript)
return transcript


def get_transcripts_from_utterances(
utterances: List[Utterance], func_transcript: Callable[[str], str]
) -> List[str]:
"""
input: utterances = [
[{'transcript': '102998', 'confidence': None},
{'transcript': '10 29 98', 'confidence': None},
{'transcript': '1029 niniety eight', 'confidence': None}]
]
description: access each transcript, confidence score pair, get
the result of <any func(transcript)>;
get a dictionary containing all results;
order this dictionary in decreasing order of confidence score
output:
best_transcript,
"""
result_dict: Dict[str, Any] = {}
transcripts: List[str] = []

for utterance_set in utterances:
for utterance in utterance_set:
transcript = utterance.get("transcript", "")
confidence = utterance.get("confidence", 0)

confidence = (
0 if confidence is None else confidence
) # Ensure confidence is not None
result = func_transcript(str(transcript))
if result == None:
result = ""
confidence = 0
if result in result_dict:
result_dict[result] += confidence
else:
result_dict[result] = confidence

# Sort the result_dict based on confidence in descending order
sorted_result = {
k: v
for k, v in sorted(result_dict.items(), key=lambda item: item[1], reverse=True)
if v is not None
}
transcripts = sorted(sorted_result, key=lambda x: sorted_result[x], reverse=True)

return transcripts


def get_dob(utterances: List[Utterance]) -> List[str]:
try:
# print("UTTERS:", utterances)
transcripts = normalize(utterances)
invalid_transcript = len(transcripts) == 1 and any(
token.lower() in transcripts for token in const.INVALID_TOKENS
)
if invalid_transcript or not transcripts:
return []
else:
# best_transcript = _format_date(utterances)
# print("transcripts = ", transcripts)
transcripts = get_transcripts_from_utterances(
utterances=utterances, func_transcript=_transform_invalid_date
)
# print("dob output:",transcripts)
return transcripts
except TypeError as type_error:
raise TypeError("`transcript` is expected in the ASR output.") from type_error


class DOBPlugin(Plugin):
def __init__(
self,
input_column: str = const.ALTERNATIVES,
output_column: Optional[str] = None,
use_transform: bool = False,
dest: Optional[str] = None,
guards: Optional[List[Guard]] = None,
debug: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
dest=dest,
guards=guards,
input_column=input_column,
output_column=output_column,
use_transform=use_transform,
debug=debug,
**kwargs,
)

async def utility(self, input: Input, _: Output) -> Any:
# return input.best_transcript
return get_dob(input.utterances)

async def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:
if not self.use_transform:
return training_data

training_data["use"] = True
logger.debug(f"Transforming dataset via {self.__class__.__name__}")
for i, row in tqdm(training_data.iterrows(), total=len(training_data)):
asr_output = None
try:
asr_output = json.loads(row[self.input_column])
if asr_output and (dob := get_dob(asr_output)):
training_data.loc[i, self.output_column] = dob[0]
else:
training_data.loc[i, "use"] = False
except Exception as error: # pylint: disable=broad-except
training_data.loc[i, "use"] = False
logger.error(f"{error} -- {asr_output}\n{traceback.format_exc()}")

training_data_ = training_data[training_data.use].copy()
training_data_.drop("use", axis=1, inplace=True)
discarded_data = len(training_data) - len(training_data_)
if discarded_data:
logger.debug(
f"Discarding {discarded_data} samples because the alternatives couldn't be parsed."
)
return training_data_
Loading

0 comments on commit e5b8e34

Please sign in to comment.