Commit 73d3e62

Added more util functions + typo fixes + minor improvements.

1 parent 0341ada
4 files changed (+108, -27 lines)

pytorch_common/models.py
Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@ def create_model(model_name, config):
         raise RuntimeError(f'Unknown model name {model_name}.')
     return model
 
-def create_transformer_model(model_name, config):
+def create_transformer_model(model_name, config=None):
     '''
     Create a transformer model (e.g. BERT) either using the
     default pretrained model or using the provided config.
@@ -27,7 +27,7 @@ def create_transformer_model(model_name, config):
 
     model_class, config_class = AutoModel, AutoConfig
 
-    if hasattr(config, 'output_dir'): # Load trained model from config
+    if config is not None and hasattr(config, 'output_dir'): # Load trained model from config
         kwargs = {
             'pretrained_model_name_or_path': os.path.join(config.output_dir,
                                                           config.model_name_or_path),
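
A minimal sketch of calling the relaxed signature (the model name string is illustrative; assumes the optional Hugging Face transformers dependency is installed):

    from pytorch_common.models import create_transformer_model

    # `config` now defaults to None, so a default pretrained model can be created
    # without building a config object first; the updated guard
    # `config is not None and hasattr(config, 'output_dir')` keeps the no-config
    # path away from `config.output_dir`.
    model = create_transformer_model('bert-base-uncased')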

pytorch_common/models_dl.py
Lines changed: 1 addition & 1 deletion

@@ -190,7 +190,7 @@ def __init__(self, config):
         self.in_dim = config.in_dim
         self.num_classes = config.num_classes
 
-        self.fc = nn.Linear(config.in_dim, self.num_classes)
+        self.fc = nn.Linear(self.in_dim, self.num_classes)
 
         self.initialize_model()
 
pytorch_common/train_utils.py
Lines changed: 2 additions & 2 deletions

@@ -334,7 +334,7 @@ def decouple_batch_train(batch):
     in the batch, and return them.
     Used commonly during training.
 
-    This is required becaue often other things
+    This is required because often other things
     are also passed in the batch for debugging.
     '''
     # Assume first two elements of batch
@@ -349,7 +349,7 @@ def decouple_batch_test(batch):
     element in the batch.
     Used commonly to make predictions.
 
-    This is required becaue often other things
+    This is required because often other things
     are also passed in the batch for debugging.
     '''
     # Only inputs are needed for making predictions
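
A rough illustration of the batch layout these helpers expect (the extra debug element and exact shapes are hypothetical):

    import torch
    from pytorch_common.train_utils import decouple_batch_train, decouple_batch_test

    # Batches may carry extra items (e.g. for debugging) beyond (inputs, targets).
    batch = (torch.randn(4, 8), torch.randint(0, 2, (4,)), {'ids': [0, 1, 2, 3]})

    inputs, targets = decouple_batch_train(batch)  # first two elements
    inputs = decouple_batch_test(batch)            # only the inputs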

pytorch_common/utils.py
Lines changed: 103 additions & 22 deletions

@@ -42,6 +42,14 @@ def create_dir_if_not_exists(dir_path):
         dir_path = os.path.join(parent_dir_path, dir_name)
         create_dir_if_not_exists(dir_path)
 
+def remove_dir(dir_path):
+    '''
+    Remove an empty directory.
+    Raises `OSError` if directory is not empty.
+    '''
+    if os.path.isdir(dir_path):
+        os.rmdir(dir_path)
+
 def human_time_interval(time_seconds):
     '''
     Converts a time interval in seconds to a human-friendly
@@ -96,22 +104,41 @@ def save_plot(config, fig, plot_name, model_name, config_info_dict, ext='png'):
 
 def save_object(obj, primary_path, file_name=None, module='pickle'):
     '''
-    This is a defensive way to write (pickle/dill).dump,
-    allowing for very large files on all platforms.
+    This is a generic function to save any given
+    object using different `module`s, e.g. pickle,
+    dill, and yaml.
 
     Note: See `get_file_path()` for details on how
     how to set `primary_path` and `file_name`.
     '''
     file_path = get_file_path(primary_path, file_name)
     logging.info(f'Saving "{file_path}"...')
+    if module == 'yaml':
+        save_yaml(obj, file_path)
+    else:
+        save_pickle(obj, file_path, module)
+    logging.info('Done.')
+
+def save_pickle(obj, file_path, module='pickle'):
+    '''
+    This is a defensive way to write (pickle/dill).dump,
+    allowing for very large files on all platforms.
+    '''
     pickle_module = get_pickle_module(module)
     bytes_out = pickle_module.dumps(obj, protocol=pickle_module.HIGHEST_PROTOCOL)
     n_bytes = sys.getsizeof(bytes_out)
     MAX_BYTES = 2**31 - 1
     with open(file_path, 'wb') as f_out:
         for idx in range(0, n_bytes, MAX_BYTES):
             f_out.write(bytes_out[idx:idx+MAX_BYTES])
-    logging.info('Done.')
+
+def save_yaml(obj, file_path):
+    '''
+    Save a given yaml file.
+    '''
+    assert isinstance(obj, dict), 'Only `dict` objects can be stored as YAML files.'
+    with open(file_path, 'w') as f_out:
+        yaml.dump(obj, f_out)
 
 def load_object(primary_path, file_name=None, module='pickle'):
     '''
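
The new YAML path in `save_object()` can be exercised roughly like this (the file path is illustrative; `load_object()` with module='yaml' goes through the existing `load_yaml()`):

    from pytorch_common.utils import save_object, load_object

    params = {'lr': 1e-3, 'batch_size': 32}

    # Dispatches to the new `save_yaml()`; non-dict objects fail its assert.
    save_object(params, '/tmp/params.yaml', module='yaml')

    loaded = load_object('/tmp/params.yaml', module='yaml')
    assert loaded == params
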
@@ -134,10 +161,11 @@ def load_object(primary_path, file_name=None, module='pickle'):
     else:
         raise FileNotFoundError(f'Could not find "{file_path}".')
 
-def load_pickle(file_path, module):
+def load_pickle(file_path, module='pickle'):
     '''
     This is a defensive way to write (pickle/dill).load,
     allowing for very large files on all platforms.
+
     This function is intended to be called inside
     `load_object()`, and assumes that the file
     already exists.
@@ -155,7 +183,9 @@ def load_pickle(file_path, module):
 def load_yaml(file_path):
     '''
     Load a given yaml file.
+
     Return an empty dictionary if file is empty.
+
     This function is intended to be called inside
     `load_object()`, and assumes that the file
     already exists.
@@ -225,11 +255,13 @@ def get_string_from_dict(config_info_dict=None):
 def get_unique_config_name(primary_name, config_info_dict=None):
     '''
     Returns a unique name for the current configuration.
+
     The name will comprise the `primary_name` followed by a
     hash value uniquely generated from the `config_info_dict`.
     :param primary_name: Primary name of the object being stored.
     :param config_info_dict: An optional dict provided containing
                              information about current config.
+
     E.g.:
     `subcategory_classifier-3d02e8616cbeab37bc1bb972ecf02882`
     Each attribute in `config_info_dict` is in the "{name}_{value}"
@@ -309,6 +341,7 @@ def send_model_to_device(model, device, device_ids=None):
 def send_batch_to_device(batch, device):
     '''
     Send batch to given device.
+
     Useful when the batch tuple is of variable lengths.
     Specifically,
       - In regular multiclass setting:
@@ -383,6 +416,54 @@ def convert_numpy_to_tensor(batch, device=None):
         logging.warning(f'Type "{type(batch)}" not understood. Returning variable as-is.')
         return batch
 
+def compare_tensors_or_arrays(batch_a, batch_b):
+    '''
+    Compare the contents of two batches.
+    Each batch may be of type `np.ndarray` or
+    `torch.Tensor` or a list/tuple of them.
+
+    Will return True if the types of the two
+    batches are different but contents are the same.
+    '''
+    if torch.is_tensor(batch_a):
+        batch_a = convert_tensor_to_numpy(batch_a)
+    if torch.is_tensor(batch_b):
+        batch_b = convert_tensor_to_numpy(batch_b)
+
+    if isinstance(batch_a, np.ndarray) and isinstance(batch_b, np.ndarray):
+        return np.all(batch_a == batch_b)
+    elif isinstance(batch_a, (list, tuple)) and isinstance(batch_b, (list, tuple)):
+        return all(compare_tensors_or_arrays(a, b) for a, b in zip(batch_a, batch_b))
+    else: # Structure/type of batch unknown
+        raise RuntimeError(f'Types of each batch "({type(batch_a)}, {type(batch_b)})" must '
+                           f'be `np.ndarray`, `torch.Tensor` or a list/tuple of them.')
+
+def is_batch_on_gpu(batch):
+    '''
+    Check if a `batch` is on a GPU.
+
+    Similar to `send_batch_to_device()`, can take a
+    torch.Tensor or a tuple/list of them as input.
+    '''
+    if torch.is_tensor(batch):
+        return batch.is_cuda
+    elif isinstance(batch, (list, tuple)):
+        return all(is_batch_on_gpu(e) for e in batch)
+    else: # Structure/type of batch unknown
+        raise RuntimeError(f'Type "{type(batch)}" not understood.')
+
+def is_model_on_gpu(model):
+    '''
+    Check if a `model` is on a GPU.
+    '''
+    return is_batch_on_gpu(next(model.parameters()))
+
+def is_model_parallelized(model):
+    '''
+    Check if a `model` is parallelized on multiple GPUs.
+    '''
+    return is_model_on_gpu(model) and isinstance(model, DataParallel)
+
 def get_total_grad_norm(parameters, norm_type=2):
     '''
     Get the total `norm_type` norm
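
A short sketch of the new comparison and device helpers (assumes a CUDA device is available for the `.cuda()` call):

    import numpy as np
    import torch
    from pytorch_common.utils import (compare_tensors_or_arrays, is_batch_on_gpu,
                                      is_model_on_gpu, is_model_parallelized)

    batch_np = (np.zeros((2, 3)), np.ones(2))
    batch_pt = (torch.zeros(2, 3), torch.ones(2))

    # True: the types differ (numpy vs. torch) but the contents match.
    compare_tensors_or_arrays(batch_np, batch_pt)

    # Device checks accept a tensor, a tuple/list of tensors, or a model.
    is_batch_on_gpu(batch_pt)               # False on CPU
    model = torch.nn.Linear(3, 2).cuda()
    is_model_on_gpu(model)                  # True
    is_model_parallelized(model)            # False unless wrapped in DataParallel
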
@@ -404,6 +485,7 @@ def get_model_performance_trackers(config):
 class ModelTracker(object):
     '''
     Class for tracking model's progress.
+
     Use this for keeping track of the loss and
     any evaluation metrics (accuracy, f1, etc.)
     at each epoch.
@@ -598,6 +680,7 @@ def _epochs_eval_metrics(self):
     def epochs(self):
         '''
         Returns the total list of epochs for which history is stored.
+
         Assumes that history is stored for the same number of epochs
         for both loss and eval_metrics.
         '''
@@ -632,6 +715,7 @@ def _get_next_epoch(self, epoch, hist_type):
 class SequencePooler(nn.Module):
     '''
     Pool the sequence output for transformer-based models.
+
     Class used instead of lambda functions to remain
     compatible with `torch.save()` and `torch.load()`.
     '''
@@ -648,41 +732,36 @@ def __init__(self, model_type='bert'):
             'roberta'
         '''
         super().__init__()
-        self.model_type = model_type
-        self._set_pooler()
+        self._set_pooler(model_type)
 
     def __repr__(self):
         return f'{self.__class__.__name__}(model_type={self.model_type})'
 
     def forward(self, x):
         return self.pooler(x)
 
-    def _set_pooler(self):
+    def _set_pooler(self, model_type):
         '''
         Set the appropriate pooler as per the `model_type`.
         '''
-        # Import here because it's an optional dependency
-        from transformers.configuration_auto import CONFIG_MAPPING
-
-        # Get a list of all supported model types ('bert', 'distilbert', etc.)
-        self.supported_model_types = list(CONFIG_MAPPING.keys())
-
-        # Use default pooler if not supported
-        if self.model_type not in self.supported_model_types:
-            logging.warning(f'No supported sequence pooler was found for model of '\
-                            f'type "{self.model_type}". Using the default one.')
-            self.model_type = self.DEFAULT_POOLER_TYPE
-
         # Set the appropriate pooler as per `model_type`
         self.POOLER_MAPPING = {
             'bert': self._bert_pooler,
             'distilbert': self._distilbert_pooler,
             'albert': self._albert_pooler,
             'roberta': self._roberta_pooler,
-            'electra': self._electra_pooler,
-            self.DEFAULT_POOLER_TYPE: self._default_pooler
+            'electra': self._electra_pooler
         }
-        self.pooler = self.POOLER_MAPPING[self.model_type]
+
+        # Use default pooler if not supported
+        if model_type in self.POOLER_MAPPING.keys():
+            self.model_type = model_type
+            self.pooler = self.POOLER_MAPPING[self.model_type]
+        else:
+            logging.warning(f'No supported sequence pooler was found for model of '\
+                            f'type "{self.model_type}". Using the default one.')
+            self.model_type = self.DEFAULT_POOLER_TYPE
+            self.pooler = self._default_pooler
 
     def _default_pooler(self, x):
         return x
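
Roughly how the refactored pooler selection is used (construction only; the pooled output format depends on the wrapped transformer):

    from pytorch_common.utils import SequencePooler

    # Supported model types map directly to their pooler via POOLER_MAPPING;
    # unsupported ones fall back to the default (identity) pooler with a warning.
    pooler = SequencePooler(model_type='distilbert')
    print(pooler)   # SequencePooler(model_type=distilbert)

One consequence of the refactor is that the lookup is now local to `POOLER_MAPPING`, so the optional `transformers.configuration_auto.CONFIG_MAPPING` import is no longer needed here.
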
@@ -718,6 +797,7 @@ def _electra_pooler(self, x):
 class DataParallel(nn.DataParallel):
     '''
     Custom DataParallel class inherited from nn.DataParallel.
+
     Purpose is to allow direct access to model attributes and
     methods when it is wrapped in a `module` attribute because
     of nn.DataParallel.
@@ -729,6 +809,7 @@ def __getattr__(self, name):
         '''
         Return model's own attribute if available, otherwise
         fallback to attribute of parent class.
+
         Solves the issue that when nn.DataParallel is applied,
         methods and attributes defined in BasePyTorchModel
         like `predict()` can only be accessed with
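
A sketch of the attribute passthrough provided by the `__getattr__` override (`base_model` and the `predict()` call are illustrative):

    from pytorch_common.utils import DataParallel

    model = DataParallel(base_model)   # base_model: e.g. a BasePyTorchModel instance

    # Methods defined on the wrapped model stay directly accessible,
    # instead of requiring model.module.predict(...).
    predictions = model.predict(batch)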
