 from functools import reduce
 from math import inf
 from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Union
+from typing import Literal, NamedTuple, Optional, Union

 from numpy import ndarray
 from scipy.stats import pearsonr, spearmanr
@@ -102,7 +102,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
         self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
         self.number_of_weights = number_of_weights

-    def extract_weights(self, state_dict: Dict, iteration: int) -> None:
+    def extract_weights(self, state_dict: dict, iteration: int) -> None:
         for key in state_dict:
             vec = state_dict[key]
             try:
@@ -215,7 +215,7 @@ def __init__(
             raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
         self.optimizer = optimizer

-        self.min_lrs: List[float]
+        self.min_lrs: list[float]
         if isinstance(min_lr, (list, tuple)):
             if len(min_lr) != len(optimizer.param_groups):
                 raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
@@ -315,10 +315,10 @@ def _init_is_better(self, mode: MinMax) -> None:

         self.mode = mode

-    def state_dict(self) -> Dict:
+    def state_dict(self) -> dict:
         return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

-    def load_state_dict(self, state_dict: Dict) -> None:
+    def load_state_dict(self, state_dict: dict) -> None:
         self.__dict__.update(state_dict)
         self._init_is_better(mode=self.mode)

@@ -369,7 +369,7 @@ def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.
 def store_embeddings(
     data_points: Union[list[DT], Dataset],
     storage_mode: EmbeddingStorageMode,
-    dynamic_embeddings: Optional[List[str]] = None,
+    dynamic_embeddings: Optional[list[str]] = None,
 ) -> None:
     """Stores embeddings of data points in memory or on disk.

@@ -401,7 +401,7 @@ def store_embeddings(
         data_point.to("cpu", pin_memory=pin_memory)


-def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]:
+def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]:
     dynamic_embeddings = []
     all_embeddings = []
     for data_point in data_points:
@@ -444,7 +444,7 @@ class CharEntity(NamedTuple):


 def create_labeled_sentence_from_tokens(
-    tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner"
+    tokens: Union[list[Token]], token_entities: list[TokenEntity], type_name: str = "ner"
 ) -> Sentence:
     """Creates a new Sentence object from a list of tokens or strings and applies entity labels.

@@ -457,20 +457,18 @@ def create_labeled_sentence_from_tokens(
     Returns:
         A labeled Sentence object
     """
-    tokens = [Token(token.text) for token in tokens]  # create new tokens that do not already belong to a sentence
+    tokens = [token.text for token in tokens]  # create new tokens that do not already belong to a sentence
     sentence = Sentence(tokens, use_tokenizer=True)
     for entity in token_entities:
         sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
     return sentence


-def create_sentence_chunks(
+def create_labeled_sentence(
     text: str,
-    entities: List[CharEntity],
-    token_limit: int = 512,
-    use_context: bool = True,
-    overlap: int = 0,  # TODO: implement overlap
-) -> List[Sentence]:
+    entities: list[CharEntity],
+    token_limit: float = inf,
+) -> Sentence:
     """Chunks and labels a text from a list of entity annotations.

     The function explicitly tokenizes the text and labels separately, ensuring entity labels are
@@ -481,48 +479,25 @@ def create_sentence_chunks(
         entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
             format (start_char_index, end_char_index, entity_class, entity_text).
         token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
-        use_context: whether to add context to the sentence
-        overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context

     Returns:
         A list of labeled Sentence objects representing the chunks of the original text
     """
-    chunks = []
-
-    tokens: List[Token] = []
+    tokens: list[Token] = []
     current_index = 0
-    token_entities: List[TokenEntity] = []
-    end_token_idx = 0
+    token_entities: list[TokenEntity] = []

     for entity in entities:
-
-        if entity.start_char_idx > current_index:  # add non-entity text
-            non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
-            while end_token_idx + len(non_entity_tokens) > token_limit:
-                num_tokens = token_limit - len(tokens)
-                tokens.extend(non_entity_tokens[:num_tokens])
-                non_entity_tokens = non_entity_tokens[num_tokens:]
-                # skip any fully negative samples, they cause fine_tune to fail with
-                # `torch.cat(): expected a non-empty list of Tensors`
-                if len(token_entities) > 0:
-                    chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-                tokens, token_entities = [], []
-                end_token_idx = 0
-            tokens.extend(non_entity_tokens)
+        if current_index < entity.start_char_idx:
+            # add tokens before the entity
+            sentence = Sentence(text[current_index : entity.start_char_idx])
+            tokens.extend(sentence)

         # add new entity tokens
         start_token_idx = len(tokens)
         entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
-        if len(entity_sentence) > token_limit:
-            logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
         end_token_idx = start_token_idx + len(entity_sentence)

-        if end_token_idx >= token_limit:  # create chunk from existing and add this entity to next chunk
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-            tokens, token_entities = [], []
-            start_token_idx, end_token_idx = 0, len(entity_sentence)
-
         token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
         token_entities.append(token_entity)
         tokens.extend(entity_sentence)
@@ -532,19 +507,6 @@ def create_sentence_chunks(
     # add any remaining tokens to a new chunk
     if current_index < len(text):
         remaining_sentence = Sentence(text[current_index:])
-        if end_token_idx + len(remaining_sentence) > token_limit:
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-            tokens, token_entities = [], []
         tokens.extend(remaining_sentence)

-    if tokens:
-        chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-    for chunk in chunks:
-        if len(chunk) > token_limit:
-            logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")
-
-    if use_context:
-        Sentence.set_context_for_sentences(chunks)
-
-    return chunks
+    return create_labeled_sentence_from_tokens(tokens, token_entities)