@@ -42,6 +42,14 @@ def create_dir_if_not_exists(dir_path):
     dir_path = os.path.join(parent_dir_path, dir_name)
     create_dir_if_not_exists(dir_path)

+def remove_dir(dir_path):
+    '''
+    Remove an empty directory.
+    Raises `OSError` if directory is not empty.
+    '''
+    if os.path.isdir(dir_path):
+        os.rmdir(dir_path)
+
 def human_time_interval(time_seconds):
     '''
     Converts a time interval in seconds to a human-friendly
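A minimal usage sketch of the new `remove_dir()` helper; the directory name is illustrative:

    import os

    create_dir_if_not_exists('tmp_dir')    # existing helper in this module
    assert os.path.isdir('tmp_dir')
    remove_dir('tmp_dir')                  # silently skips paths that are not directories
    assert not os.path.isdir('tmp_dir')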
@@ -96,22 +104,41 @@ def save_plot(config, fig, plot_name, model_name, config_info_dict, ext='png'):

 def save_object(obj, primary_path, file_name=None, module='pickle'):
     '''
-    This is a defensive way to write (pickle/dill).dump,
-    allowing for very large files on all platforms.
+    This is a generic function to save any given
+    object using different `module`s, e.g. pickle,
+    dill, and yaml.

     Note: See `get_file_path()` for details on how
     how to set `primary_path` and `file_name`.
     '''
     file_path = get_file_path(primary_path, file_name)
     logging.info(f'Saving "{file_path}"...')
+    if module == 'yaml':
+        save_yaml(obj, file_path)
+    else:
+        save_pickle(obj, file_path, module)
+    logging.info('Done.')
+
+def save_pickle(obj, file_path, module='pickle'):
+    '''
+    This is a defensive way to write (pickle/dill).dump,
+    allowing for very large files on all platforms.
+    '''
     pickle_module = get_pickle_module(module)
     bytes_out = pickle_module.dumps(obj, protocol=pickle_module.HIGHEST_PROTOCOL)
     n_bytes = sys.getsizeof(bytes_out)
     MAX_BYTES = 2 ** 31 - 1
     with open(file_path, 'wb') as f_out:
         for idx in range(0, n_bytes, MAX_BYTES):
             f_out.write(bytes_out[idx:idx + MAX_BYTES])
-    logging.info('Done.')
+
+def save_yaml(obj, file_path):
+    '''
+    Save a given yaml file.
+    '''
+    assert isinstance(obj, dict), 'Only `dict` objects can be stored as YAML files.'
+    with open(file_path, 'w') as f_out:
+        yaml.dump(obj, f_out)

 def load_object(primary_path, file_name=None, module='pickle'):
     '''
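A short sketch of how the reworked `save_object()` dispatch could be exercised; the paths below are illustrative and assume `get_file_path()` joins `primary_path` and `file_name`:

    create_dir_if_not_exists('output')            # make sure the target directory exists
    config = {'lr': 1e-3, 'batch_size': 32}       # plain dict, so it can be stored as YAML
    save_object(config, 'output', 'config.yaml', module='yaml')

    state = {'weights': [0.1, 0.2, 0.3]}          # any picklable object
    save_object(state, 'output', 'state.pkl', module='pickle')

    assert load_object('output', 'config.yaml', module='yaml') == config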
@@ -134,10 +161,11 @@ def load_object(primary_path, file_name=None, module='pickle'):
     else:
         raise FileNotFoundError(f'Could not find "{file_path}".')

-def load_pickle(file_path, module):
+def load_pickle(file_path, module='pickle'):
     '''
     This is a defensive way to write (pickle/dill).load,
     allowing for very large files on all platforms.
+
     This function is intended to be called inside
     `load_object()`, and assumes that the file
     already exists.
@@ -155,7 +183,9 @@ def load_pickle(file_path, module):
 def load_yaml(file_path):
     '''
     Load a given yaml file.
+
     Return an empty dictionary if file is empty.
+
     This function is intended to be called inside
     `load_object()`, and assumes that the file
     already exists.
@@ -225,11 +255,13 @@ def get_string_from_dict(config_info_dict=None):
 def get_unique_config_name(primary_name, config_info_dict=None):
     '''
     Returns a unique name for the current configuration.
+
     The name will comprise the `primary_name` followed by a
     hash value uniquely generated from the `config_info_dict`.
     :param primary_name: Primary name of the object being stored.
     :param config_info_dict: An optional dict provided containing
                              information about current config.
+
     E.g.:
     `subcategory_classifier-3d02e8616cbeab37bc1bb972ecf02882`
     Each attribute in `config_info_dict` is in the "{name}_{value}"
@@ -309,6 +341,7 @@ def send_model_to_device(model, device, device_ids=None):
 def send_batch_to_device(batch, device):
     '''
     Send batch to given device.
+
     Useful when the batch tuple is of variable lengths.
     Specifically,
       - In regular multiclass setting:
@@ -383,6 +416,54 @@ def convert_numpy_to_tensor(batch, device=None):
         logging.warning(f'Type "{type(batch)}" not understood. Returning variable as-is.')
         return batch

+def compare_tensors_or_arrays(batch_a, batch_b):
+    '''
+    Compare the contents of two batches.
+    Each batch may be of type `np.ndarray` or
+    `torch.Tensor` or a list/tuple of them.
+
+    Will return True if the types of the two
+    batches are different but contents are the same.
+    '''
+    if torch.is_tensor(batch_a):
+        batch_a = convert_tensor_to_numpy(batch_a)
+    if torch.is_tensor(batch_b):
+        batch_b = convert_tensor_to_numpy(batch_b)
+
+    if isinstance(batch_a, np.ndarray) and isinstance(batch_b, np.ndarray):
+        return np.all(batch_a == batch_b)
+    elif isinstance(batch_a, (list, tuple)) and isinstance(batch_b, (list, tuple)):
+        return all(compare_tensors_or_arrays(a, b) for a, b in zip(batch_a, batch_b))
+    else: # Structure/type of batch unknown
+        raise RuntimeError(f'Types of each batch "({type(batch_a)}, {type(batch_b)})" must '
+                           f'be `np.ndarray`, `torch.Tensor` or a list/tuple of them.')
+
+def is_batch_on_gpu(batch):
+    '''
+    Check if a `batch` is on a GPU.
+
+    Similar to `send_batch_to_device()`, can take a
+    torch.Tensor or a tuple/list of them as input.
+    '''
+    if torch.is_tensor(batch):
+        return batch.is_cuda
+    elif isinstance(batch, (list, tuple)):
+        return all(is_batch_on_gpu(e) for e in batch)
+    else: # Structure/type of batch unknown
+        raise RuntimeError(f'Type "{type(batch)}" not understood.')
+
+def is_model_on_gpu(model):
+    '''
+    Check if a `model` is on a GPU.
+    '''
+    return is_batch_on_gpu(next(model.parameters()))
+
+def is_model_parallelized(model):
+    '''
+    Check if a `model` is parallelized on multiple GPUs.
+    '''
+    return is_model_on_gpu(model) and isinstance(model, DataParallel)
+
 def get_total_grad_norm(parameters, norm_type=2):
     '''
     Get the total `norm_type` norm
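A compact sketch of how the new helpers compose with the existing tensor utilities; the toy tensors and model are illustrative, and the CUDA branch is guarded since a GPU may not be available:

    x = torch.arange(6).reshape(2, 3)
    # Contents match even when one side is a tensor and the other an ndarray.
    assert compare_tensors_or_arrays(x, x.numpy())
    assert compare_tensors_or_arrays([x, x], (x.numpy(), x))

    model = nn.Linear(4, 2)
    assert not is_model_on_gpu(model)
    if torch.cuda.is_available():
        model = DataParallel(model.to('cuda'))
        assert is_model_on_gpu(model) and is_model_parallelized(model)
        assert is_batch_on_gpu(send_batch_to_device((x, x), torch.device('cuda')))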
@@ -404,6 +485,7 @@ def get_model_performance_trackers(config):
 class ModelTracker(object):
     '''
     Class for tracking model's progress.
+
     Use this for keeping track of the loss and
     any evaluation metrics (accuracy, f1, etc.)
     at each epoch.
@@ -598,6 +680,7 @@ def _epochs_eval_metrics(self):
     def epochs(self):
         '''
         Returns the total list of epochs for which history is stored.
+
         Assumes that history is stored for the same number of epochs
         for both loss and eval_metrics.
         '''
@@ -632,6 +715,7 @@ def _get_next_epoch(self, epoch, hist_type):
 class SequencePooler(nn.Module):
     '''
     Pool the sequence output for transformer-based models.
+
     Class used instead of lambda functions to remain
     compatible with `torch.save()` and `torch.load()`.
     '''
@@ -648,41 +732,36 @@ def __init__(self, model_type='bert'):
         'roberta'
         '''
         super().__init__()
-        self.model_type = model_type
-        self._set_pooler()
+        self._set_pooler(model_type)

     def __repr__(self):
         return f'{self.__class__.__name__}(model_type={self.model_type})'

     def forward(self, x):
         return self.pooler(x)

-    def _set_pooler(self):
+    def _set_pooler(self, model_type):
         '''
         Set the appropriate pooler as per the `model_type`.
         '''
-        # Import here because it's an optional dependency
-        from transformers.configuration_auto import CONFIG_MAPPING
-
-        # Get a list of all supported model types ('bert', 'distilbert', etc.)
-        self.supported_model_types = list(CONFIG_MAPPING.keys())
-
-        # Use default pooler if not supported
-        if self.model_type not in self.supported_model_types:
-            logging.warning(f'No supported sequence pooler was found for model of ' \
-                            f'type "{self.model_type}". Using the default one.')
-            self.model_type = self.DEFAULT_POOLER_TYPE
-
         # Set the appropriate pooler as per `model_type`
         self.POOLER_MAPPING = {
             'bert': self._bert_pooler,
             'distilbert': self._distilbert_pooler,
             'albert': self._albert_pooler,
             'roberta': self._roberta_pooler,
-            'electra': self._electra_pooler,
-            self.DEFAULT_POOLER_TYPE: self._default_pooler
+            'electra': self._electra_pooler
         }
-        self.pooler = self.POOLER_MAPPING[self.model_type]
+
+        # Use default pooler if not supported
+        if model_type in self.POOLER_MAPPING.keys():
+            self.model_type = model_type
+            self.pooler = self.POOLER_MAPPING[self.model_type]
+        else:
+            logging.warning(f'No supported sequence pooler was found for model of ' \
+                            f'type "{model_type}". Using the default one.')
+            self.model_type = self.DEFAULT_POOLER_TYPE
+            self.pooler = self._default_pooler

     def _default_pooler(self, x):
         return x
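A brief sketch of the refactored fallback behaviour; the unsupported model type below is made up, and the exact default depends on the class's `DEFAULT_POOLER_TYPE`:

    pooler = SequencePooler('distilbert')
    print(pooler)                               # SequencePooler(model_type=distilbert)

    pooler = SequencePooler('made-up-model')    # logs a warning and falls back
    assert pooler.model_type == pooler.DEFAULT_POOLER_TYPE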
@@ -718,6 +797,7 @@ def _electra_pooler(self, x):
 class DataParallel(nn.DataParallel):
     '''
     Custom DataParallel class inherited from nn.DataParallel.
+
     Purpose is to allow direct access to model attributes and
     methods when it is wrapped in a `module` attribute because
     of nn.DataParallel.
@@ -729,6 +809,7 @@ def __getattr__(self, name):
         '''
         Return model's own attribute if available, otherwise
         fallback to attribute of parent class.
+
         Solves the issue that when nn.DataParallel is applied,
         methods and attributes defined in BasePyTorchModel
         like `predict()` can only be accessed with