
Commit fdae4ef

remove chunking logic to have simple sentence labeler. fix tests.
1 parent 24cd023 commit fdae4ef

3 files changed (+112, -136 lines)

flair/nn/model.py

Lines changed: 1 addition & 1 deletion
@@ -982,7 +982,7 @@ def _get_state_dict(self):
         state["locked_dropout"] = self.locked_dropout.dropout_rate
         state["multi_label"] = self.multi_label
         state["multi_label_threshold"] = self.multi_label_threshold
-        state["loss_weights"] = self.loss_weights
+        state["loss_weights"] = self.weight_dict
         state["train_on_gold_pairs_only"] = self.train_on_gold_pairs_only
         state["inverse_model"] = self.inverse_model
         if self._custom_decoder:
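
The one-line change above stores the user-supplied weight mapping rather than the derived loss tensor, so that rebuilding the model from its state dict can pass the value straight back into the constructor. A minimal sketch of that pattern, assuming the usual flair split between a raw weight_dict and a derived loss_weights tensor (illustrative only, not flair's actual classifier code):

    from typing import Optional

    import torch

    class TinyClassifier:
        """Illustrative stand-in for a classifier that weights its loss per label."""

        def __init__(self, label_names: list[str], loss_weights: Optional[dict[str, float]] = None):
            self.weight_dict = loss_weights  # keep the raw mapping the caller passed in
            if loss_weights is not None:
                # derived tensor actually consumed by the loss function
                self.loss_weights = torch.tensor([loss_weights.get(n, 1.0) for n in label_names])
            else:
                self.loss_weights = None

        def _get_state_dict(self) -> dict:
            # serialize the constructor argument, not the derived tensor, so the
            # model can be rebuilt by feeding state["loss_weights"] back into __init__
            return {"loss_weights": self.weight_dict}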

flair/training_utils.py

Lines changed: 20 additions & 58 deletions
@@ -6,7 +6,7 @@
 from functools import reduce
 from math import inf
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Union
+from typing import Literal, NamedTuple, Optional, Union
 
 from numpy import ndarray
 from scipy.stats import pearsonr, spearmanr
@@ -102,7 +102,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
         self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
         self.number_of_weights = number_of_weights
 
-    def extract_weights(self, state_dict: Dict, iteration: int) -> None:
+    def extract_weights(self, state_dict: dict, iteration: int) -> None:
         for key in state_dict:
             vec = state_dict[key]
             try:
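
For reference, a small usage sketch of the method whose annotation changed above; the output directory, model, and iteration value are illustrative:

    from pathlib import Path

    import torch
    from flair.training_utils import WeightExtractor

    model = torch.nn.Linear(10, 3)
    extractor = WeightExtractor(Path("weight_logs"), number_of_weights=5)

    # sample a handful of values from each weight tensor in the plain state dict
    extractor.extract_weights(model.state_dict(), iteration=0)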
@@ -215,7 +215,7 @@ def __init__(
             raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
         self.optimizer = optimizer
 
-        self.min_lrs: List[float]
+        self.min_lrs: list[float]
         if isinstance(min_lr, (list, tuple)):
             if len(min_lr) != len(optimizer.param_groups):
                 raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
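
The length check above allows min_lr to be either a single value or one value per optimizer parameter group. A hedged sketch, assuming this __init__ belongs to the AnnealOnPlateau scheduler defined in this module:

    import torch
    from flair.training_utils import AnnealOnPlateau

    encoder, head = torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(
        [
            {"params": encoder.parameters(), "lr": 1e-2},
            {"params": head.parameters(), "lr": 1e-1},
        ]
    )

    # one learning-rate floor per parameter group; a mismatched length raises the ValueError above
    scheduler = AnnealOnPlateau(optimizer, min_lr=[1e-5, 1e-4])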
@@ -315,10 +315,10 @@ def _init_is_better(self, mode: MinMax) -> None:
 
         self.mode = mode
 
-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
 
-    def load_state_dict(self, state_dict: Dict) -> None:
+    def load_state_dict(self, state_dict: dict) -> None:
         self.__dict__.update(state_dict)
         self._init_is_better(mode=self.mode)
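
These two methods serialize every attribute except the optimizer reference and rebuild the comparison function on load. A minimal round-trip sketch under the same AnnealOnPlateau assumption:

    import torch
    from flair.training_utils import AnnealOnPlateau

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    scheduler = AnnealOnPlateau(optimizer)
    state = scheduler.state_dict()  # plain dict, safe to pickle; the optimizer is excluded

    restored = AnnealOnPlateau(optimizer)
    restored.load_state_dict(state)  # restores attributes and re-derives the is_better comparison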

@@ -369,7 +369,7 @@ def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.
 def store_embeddings(
     data_points: Union[list[DT], Dataset],
     storage_mode: EmbeddingStorageMode,
-    dynamic_embeddings: Optional[List[str]] = None,
+    dynamic_embeddings: Optional[list[str]] = None,
 ) -> None:
     """Stores embeddings of data points in memory or on disk.
@@ -401,7 +401,7 @@ def store_embeddings(
             data_point.to("cpu", pin_memory=pin_memory)
 
 
-def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]:
+def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
     dynamic_embeddings = []
     all_embeddings = []
     for data_point in data_points:
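
A short usage sketch of the two helpers touched above; the embedding model and the "cpu" storage mode are illustrative choices:

    from flair.data import Sentence
    from flair.embeddings import TransformerWordEmbeddings
    from flair.training_utils import identify_dynamic_embeddings, store_embeddings

    sentences = [Sentence("Berlin is a city."), Sentence("Flair ships NER models.")]
    TransformerWordEmbeddings("distilbert-base-uncased").embed(sentences)

    # find embeddings that are recomputed on every forward pass (None if nothing is embedded yet)
    dynamic = identify_dynamic_embeddings(sentences)

    # keep the embedding tensors but move them to CPU memory; "none" would delete them instead
    store_embeddings(sentences, storage_mode="cpu", dynamic_embeddings=dynamic)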
@@ -444,7 +444,7 @@ class CharEntity(NamedTuple):
 
 
 def create_labeled_sentence_from_tokens(
-    tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner"
+    tokens: Union[list[Token]], token_entities: list[TokenEntity], type_name: str = "ner"
 ) -> Sentence:
     """Creates a new Sentence object from a list of tokens or strings and applies entity labels.
@@ -457,20 +457,18 @@ def create_labeled_sentence_from_tokens(
     Returns:
         A labeled Sentence object
     """
-    tokens = [Token(token.text) for token in tokens]  # create new tokens that do not already belong to a sentence
+    tokens = [token.text for token in tokens]  # create new tokens that do not already belong to a sentence
     sentence = Sentence(tokens, use_tokenizer=True)
     for entity in token_entities:
         sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
     return sentence
 
 
-def create_sentence_chunks(
+def create_labeled_sentence(
     text: str,
-    entities: List[CharEntity],
-    token_limit: int = 512,
-    use_context: bool = True,
-    overlap: int = 0,  # TODO: implement overlap
-) -> List[Sentence]:
+    entities: list[CharEntity],
+    token_limit: float = inf,
+) -> Sentence:
     """Chunks and labels a text from a list of entity annotations.
 
@@ -481,48 +479,25 @@ def create_sentence_chunks(
         entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
             format (start_char_index, end_char_index, entity_class, entity_text).
         token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
-        use_context: whether to add context to the sentence
-        overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context
 
     Returns:
         A list of labeled Sentence objects representing the chunks of the original text
     """
-    chunks = []
-
-    tokens: List[Token] = []
+    tokens: list[Token] = []
     current_index = 0
-    token_entities: List[TokenEntity] = []
-    end_token_idx = 0
+    token_entities: list[TokenEntity] = []
 
     for entity in entities:
-
-        if entity.start_char_idx > current_index:  # add non-entity text
-            non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
-            while end_token_idx + len(non_entity_tokens) > token_limit:
-                num_tokens = token_limit - len(tokens)
-                tokens.extend(non_entity_tokens[:num_tokens])
-                non_entity_tokens = non_entity_tokens[num_tokens:]
-                # skip any fully negative samples, they cause fine_tune to fail with
-                # `torch.cat(): expected a non-empty list of Tensors`
-                if len(token_entities) > 0:
-                    chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-                tokens, token_entities = [], []
-                end_token_idx = 0
-            tokens.extend(non_entity_tokens)
+        if current_index < entity.start_char_idx:
+            # add tokens before the entity
+            sentence = Sentence(text[current_index : entity.start_char_idx])
+            tokens.extend(sentence)
 
         # add new entity tokens
         start_token_idx = len(tokens)
         entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
-        if len(entity_sentence) > token_limit:
-            logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
         end_token_idx = start_token_idx + len(entity_sentence)
 
-        if end_token_idx >= token_limit:  # create chunk from existing and add this entity to next chunk
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-            tokens, token_entities = [], []
-            start_token_idx, end_token_idx = 0, len(entity_sentence)
-
         token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
         token_entities.append(token_entity)
         tokens.extend(entity_sentence)
@@ -532,19 +507,6 @@ def create_sentence_chunks(
     # add any remaining tokens to a new chunk
     if current_index < len(text):
         remaining_sentence = Sentence(text[current_index:])
-        if end_token_idx + len(remaining_sentence) > token_limit:
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-            tokens, token_entities = [], []
         tokens.extend(remaining_sentence)
 
-    if tokens:
-        chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-    for chunk in chunks:
-        if len(chunk) > token_limit:
-            logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")
-
-    if use_context:
-        Sentence.set_context_for_sentences(chunks)
-
-    return chunks
+    return create_labeled_sentence_from_tokens(tokens, token_entities)
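
With the chunking removed, the renamed create_labeled_sentence returns a single labeled Sentence. A hedged usage sketch; the example text and labels are mine, and it assumes CharEntity's score field has a default value:

    from flair.training_utils import CharEntity, create_labeled_sentence

    text = "George Washington went to Washington."
    entities = [
        CharEntity(0, 17, "PER", "George Washington"),  # end index is exclusive, as in the slicing above
        CharEntity(26, 36, "LOC", "Washington"),
    ]

    sentence = create_labeled_sentence(text, entities)
    for span in sentence.get_spans("ner"):
        print(span.text, span.get_label("ner").value, span.get_label("ner").score)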
