Commit

remove chunking logic to have simple sentence labeler. fix tests.
MattGPT-ai committed Nov 23, 2024
1 parent 24cd023 commit a711cb6
Showing 4 changed files with 117 additions and 137 deletions.
flair/data.py (2 changes: 1 addition & 1 deletion)
@@ -540,7 +540,7 @@ def __init__(
         head_id: Optional[int] = None,
         whitespace_after: int = 1,
         start_position: int = 0,
-        sentence=None,
+        sentence: Optional["Sentence"] = None,
     ) -> None:
         super().__init__(sentence=sentence)
flair/nn/model.py (2 changes: 1 addition & 1 deletion)
@@ -982,7 +982,7 @@ def _get_state_dict(self):
         state["locked_dropout"] = self.locked_dropout.dropout_rate
         state["multi_label"] = self.multi_label
         state["multi_label_threshold"] = self.multi_label_threshold
-        state["loss_weights"] = self.loss_weights
+        state["loss_weights"] = self.weight_dict
         state["train_on_gold_pairs_only"] = self.train_on_gold_pairs_only
         state["inverse_model"] = self.inverse_model
         if self._custom_decoder:
flair/training_utils.py (82 changes: 24 additions & 58 deletions)
@@ -6,7 +6,7 @@
 from functools import reduce
 from math import inf
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Union
+from typing import Literal, NamedTuple, Optional, Union
 
 from numpy import ndarray
 from scipy.stats import pearsonr, spearmanr
@@ -102,7 +102,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
         self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
         self.number_of_weights = number_of_weights
 
-    def extract_weights(self, state_dict: Dict, iteration: int) -> None:
+    def extract_weights(self, state_dict: dict, iteration: int) -> None:
         for key in state_dict:
             vec = state_dict[key]
             try:
@@ -215,7 +215,7 @@ def __init__(
             raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
         self.optimizer = optimizer
 
-        self.min_lrs: List[float]
+        self.min_lrs: list[float]
         if isinstance(min_lr, (list, tuple)):
             if len(min_lr) != len(optimizer.param_groups):
                 raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
@@ -315,10 +315,10 @@ def _init_is_better(self, mode: MinMax) -> None:
 
         self.mode = mode
 
-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
 
-    def load_state_dict(self, state_dict: Dict) -> None:
+    def load_state_dict(self, state_dict: dict) -> None:
         self.__dict__.update(state_dict)
         self._init_is_better(mode=self.mode)
 
@@ -369,7 +369,7 @@ def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.
 def store_embeddings(
     data_points: Union[list[DT], Dataset],
     storage_mode: EmbeddingStorageMode,
-    dynamic_embeddings: Optional[List[str]] = None,
+    dynamic_embeddings: Optional[list[str]] = None,
 ) -> None:
     """Stores embeddings of data points in memory or on disk.
@@ -401,7 +401,7 @@ def store_embeddings(
             data_point.to("cpu", pin_memory=pin_memory)
 
 
-def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]:
+def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
     dynamic_embeddings = []
     all_embeddings = []
     for data_point in data_points:
@@ -444,7 +444,7 @@ class CharEntity(NamedTuple):
 
 
 def create_labeled_sentence_from_tokens(
-    tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner"
+    tokens: Union[list[Token]], token_entities: list[TokenEntity], type_name: str = "ner"
 ) -> Sentence:
     """Creates a new Sentence object from a list of tokens or strings and applies entity labels.
@@ -457,20 +457,18 @@ def create_labeled_sentence_from_tokens(
     Returns:
         A labeled Sentence object
     """
-    tokens = [Token(token.text) for token in tokens]  # create new tokens that do not already belong to a sentence
-    sentence = Sentence(tokens, use_tokenizer=True)
+    tokens_ = [token.text for token in tokens]  # create new tokens that do not already belong to a sentence
+    sentence = Sentence(tokens_, use_tokenizer=True)
     for entity in token_entities:
         sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
     return sentence
 
 
-def create_sentence_chunks(
+def create_labeled_sentence(
     text: str,
-    entities: List[CharEntity],
-    token_limit: int = 512,
-    use_context: bool = True,
-    overlap: int = 0,  # TODO: implement overlap
-) -> List[Sentence]:
+    entities: list[CharEntity],
+    token_limit: float = inf,
+) -> Sentence:
     """Chunks and labels a text from a list of entity annotations.
 
     The function explicitly tokenizes the text and labels separately, ensuring entity labels are
@@ -481,48 +479,25 @@ def create_labeled_sentence(
         entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
             format (start_char_index, end_char_index, entity_class, entity_text).
         token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
-        use_context: whether to add context to the sentence
-        overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context
 
     Returns:
         A list of labeled Sentence objects representing the chunks of the original text
     """
-    chunks = []
-
-    tokens: List[Token] = []
+    tokens: list[Token] = []
     current_index = 0
-    token_entities: List[TokenEntity] = []
-    end_token_idx = 0
+    token_entities: list[TokenEntity] = []
 
     for entity in entities:
-
-        if entity.start_char_idx > current_index:  # add non-entity text
-            non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
-            while end_token_idx + len(non_entity_tokens) > token_limit:
-                num_tokens = token_limit - len(tokens)
-                tokens.extend(non_entity_tokens[:num_tokens])
-                non_entity_tokens = non_entity_tokens[num_tokens:]
-                # skip any fully negative samples, they cause fine_tune to fail with
-                # `torch.cat(): expected a non-empty list of Tensors`
-                if len(token_entities) > 0:
-                    chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-                tokens, token_entities = [], []
-                end_token_idx = 0
-            tokens.extend(non_entity_tokens)
+        if current_index < entity.start_char_idx:
+            # add tokens before the entity
+            sentence = Sentence(text[current_index : entity.start_char_idx])
+            tokens.extend(sentence)
 
         # add new entity tokens
         start_token_idx = len(tokens)
         entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
-        if len(entity_sentence) > token_limit:
-            logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
         end_token_idx = start_token_idx + len(entity_sentence)
 
-        if end_token_idx >= token_limit:  # create chunk from existing and add this entity to next chunk
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-            tokens, token_entities = [], []
-            start_token_idx, end_token_idx = 0, len(entity_sentence)
-
         token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
         token_entities.append(token_entity)
         tokens.extend(entity_sentence)
@@ -532,19 +507,10 @@
     # add any remaining tokens to a new chunk
     if current_index < len(text):
         remaining_sentence = Sentence(text[current_index:])
-        if end_token_idx + len(remaining_sentence) > token_limit:
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-            tokens, token_entities = [], []
         tokens.extend(remaining_sentence)
 
-    if tokens:
-        chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-    for chunk in chunks:
-        if len(chunk) > token_limit:
-            logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")
-
-    if use_context:
-        Sentence.set_context_for_sentences(chunks)
+    if isinstance(token_limit, int) and token_limit < len(tokens):
+        tokens = tokens[:token_limit]
+        token_entities = [entity for entity in token_entities if entity.end_token_idx <= token_limit]
 
-    return chunks
+    return create_labeled_sentence_from_tokens(tokens, token_entities)
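
For orientation, a minimal usage sketch of the new create_labeled_sentence helper (not part of the commit): it assumes the CharEntity fields referenced in the hunks above (start_char_idx, end_char_idx, label, value, score), the default token_limit of inf, and a flair version where span labels can be read back with Sentence.get_labels. The text, character offsets, and labels are illustrative only.

    from flair.training_utils import CharEntity, create_labeled_sentence

    text = "George Washington went to Washington."
    entities = [
        # character offsets are end-exclusive, matching text[start_char_idx:end_char_idx]
        CharEntity(start_char_idx=0, end_char_idx=17, label="PER", value="George Washington", score=1.0),
        CharEntity(start_char_idx=26, end_char_idx=36, label="LOC", value="Washington", score=1.0),
    ]

    # token_limit defaults to inf, so the whole text becomes one labeled Sentence;
    # an integer limit would truncate the tokens and drop entities past that index.
    sentence = create_labeled_sentence(text, entities)
    for label in sentence.get_labels("ner"):
        print(label)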
