diff --git a/src/freestylo/AlliterationAnnotation.py b/src/freestylo/AlliterationAnnotation.py index ea72a46..2919ca4 100644 --- a/src/freestylo/AlliterationAnnotation.py +++ b/src/freestylo/AlliterationAnnotation.py @@ -28,7 +28,7 @@ class AlliterationAnnotation: It uses the TextObject class to store the text and its annotations. """ - def __init__(self, text : TextObject, max_skip = 2, min_length=3, skip_tokens=[".", ",", ":", ";", "!", "?", "…", "(", ")", "[", "]", "{", "}", "„", "“", "‚", "‘:", "‘", "’"]): + def __init__(self, text : TextObject, max_skip = 2, min_length=3, skip_tokens=[".", ",", ":", ";", "!", "?", "…", "(", ")", "[", "]", "{", "}", "„", "“", "‚", "‘:", "‘", "’"], ignore_tokens=None ): """ Parameters ---------- @@ -47,6 +47,11 @@ def __init__(self, text : TextObject, max_skip = 2, min_length=3, skip_tokens=[" self.max_skip = max_skip self.min_length = min_length self.skip_tokens = skip_tokens + if ignore_tokens is not None: + self.ignore_tokens = ignore_tokens + else: + self.ignore_tokens = [] + def find_candidates(self): @@ -65,12 +70,19 @@ def find_candidates(self): if not token_char.isalpha(): continue # if not, create a new one - if token_char not in open_candidates: + if token_char not in open_candidates and token_char not in self.skip_tokens and token_char not in self.ignore_tokens: open_candidates[token_char] = [AlliterationCandidate([i], token_char), 0] continue # if yes, add the current token to the candidate - candidate = open_candidates[token_char][0] - candidate.ids.append(i) + + try: + candidate = open_candidates[token_char][0] + except KeyError: + open_candidates[token_char] = [AlliterationCandidate([i], token_char), 0] + candidate = open_candidates[token_char][0] + + if token not in self.skip_tokens and token not in self.ignore_tokens: + candidate.ids.append(i) # close candidates keys_to_delete = [] diff --git a/src/freestylo/EpiphoraAnnotation.py b/src/freestylo/EpiphoraAnnotation.py index 344a0eb..8f41881 100644 --- a/src/freestylo/EpiphoraAnnotation.py +++ b/src/freestylo/EpiphoraAnnotation.py @@ -61,12 +61,16 @@ def split_in_phrases(self): phrases = [] current_start = 0 + punct_tokens = [".", "!", "?", ":", ";", ","] for i, token in tqdm(enumerate(self.text.tokens)): - if token in self.conj or self.text.pos[i] == self.punct_pos or self.text.pos[i] in ["CONJ", "CCONJ"]: + if token in self.conj or self.text.pos[i] == self.punct_pos or self.text.pos[i] in ["CONJ", "CCONJ"] or self.text.tokens[i].lower() in punct_tokens: if i-current_start > 2: - phrases.append([current_start, i]) - current_start = i+1 - phrases.append([current_start, len(self.text.tokens)]) + phrases.append([current_start, i-1]) + current_start = i + elif token in [".", "!", "?"]: + phrases.append([current_start, i-1]) + current_start = i + phrases.append([current_start, len(self.text.tokens)-1]) return phrases @@ -77,8 +81,11 @@ def find_candidates(self): candidates = [] current_candidate = EpiphoraCandidate([], "") phrases = self.split_in_phrases() + #for p in phrases: + # print("###") + # print(" ".join(self.text.tokens[p[0]:p[1]+1])) for phrase in tqdm(phrases): - word = self.text.tokens[phrase[1]-1] + word = self.text.tokens[phrase[1]] if word != current_candidate.word: if len(current_candidate.ids) >= self.min_length: candidates.append(current_candidate) diff --git a/src/freestylo/MetaphorAnnotation.py b/src/freestylo/MetaphorAnnotation.py index 994d626..15b13bc 100644 --- a/src/freestylo/MetaphorAnnotation.py +++ b/src/freestylo/MetaphorAnnotation.py @@ -83,7 +83,10 @@ def load_model(self, model_path): The path to the model. """ model_path = get_model_path(model_path) - self.model = SimilarityNN.SimilarityNN(300, 128, 1, 128, self.device) + input_size = 300 + if self.text.language == "mgh": + input_size = 100 + self.model = SimilarityNN.SimilarityNN(input_size, 128, 1, 128, self.device) self.model.load_state_dict(torch.load(model_path, weights_only=True, map_location=self.device)) self.model = self.model.to(self.device) self.model.eval() diff --git a/src/freestylo/freestylo_main.py b/src/freestylo/freestylo_main.py index e9d9937..b87dd52 100644 --- a/src/freestylo/freestylo_main.py +++ b/src/freestylo/freestylo_main.py @@ -159,7 +159,8 @@ def add_alliteration_annotation(text, config): alliteration = aa.AlliterationAnnotation( text = text, max_skip = config["max_skip"], - min_length = config["min_length"]) + min_length = config["min_length"], + ignore_tokens=config["ignore_tokens"]) print("Finding candidates") alliteration.find_candidates() print("Done")