@@ -47,7 +47,7 @@ export const defaultDataLoadingCode: defaultDataLoadingCode_t = {
4747 else:
4848 return {'image': image, 'label': label}
4949
50- def get_image_classification_loader (root, batch_size=32, split='train', shuffle=True, transforms=None, **kwargs):
50+ def custom_loader (root, batch_size=32, split='train', shuffle=True, transforms=None, **kwargs):
5151 dataset = ImageClassificationDataset(root=root, split=split, transforms=transforms, **kwargs)
5252 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
5353 'generation' : `class ImageGenerationDataset(Dataset):
@@ -94,7 +94,7 @@ def get_image_classification_loader(root, batch_size=32, split='train', shuffle=
9494 else:
9595 return item
9696
97- def get_image_generation_loader (root, batch_size=16, split='train', shuffle=True, transforms=None, **kwargs):
97+ def custom_loader (root, batch_size=16, split='train', shuffle=True, transforms=None, **kwargs):
9898 dataset = ImageGenerationDataset(root=root, split=split, transforms=transforms, **kwargs)
9999 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
100100 'image-segmentation' : `class SegmentationDataset(Dataset):
@@ -132,7 +132,7 @@ def get_image_generation_loader(root, batch_size=16, split='train', shuffle=True
132132 else:
133133 return {'image': image, 'mask': mask}
134134
135- def get_segmentation_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
135+ def custom_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
136136 dataset = SegmentationDataset(root=root, split=split, transforms=transforms, **kwargs)
137137 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
138138 'object-detection' : `class ObjectDetectionDataset(Dataset):
@@ -178,7 +178,7 @@ def get_segmentation_loader(root, batch_size=8, split='train', shuffle=True, tra
178178 else:
179179 return {'image': image, **target}
180180
181- def get_object_detection_loader (root, batch_size=4, split='train', shuffle=True, transforms=None, **kwargs):
181+ def custom_loader (root, batch_size=4, split='train', shuffle=True, transforms=None, **kwargs):
182182 dataset = ObjectDetectionDataset(root=root, split=split, transforms=transforms, **kwargs)
183183 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
184184 collate_fn=lambda x: tuple(zip(*x)))` ,
@@ -218,7 +218,7 @@ def get_object_detection_loader(root, batch_size=4, split='train', shuffle=True,
218218 else:
219219 return item
220220
221- def get_text_classification_loader (root, batch_size=32, split='train', shuffle=True, transforms=None, **kwargs):
221+ def custom_loader (root, batch_size=32, split='train', shuffle=True, transforms=None, **kwargs):
222222 dataset = TextClassificationDataset(root, split, transforms=transforms, **kwargs)
223223 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
224224 'generation' : `class TextGenerationDataset(Dataset):
@@ -254,7 +254,7 @@ def get_text_classification_loader(root, batch_size=32, split='train', shuffle=T
254254 }
255255 return format_handlers.get(self.return_format, lambda x:x)(item)
256256
257- def get_text_generation_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
257+ def custom_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
258258 dataset = TextGenerationDataset(root, split, transforms=transforms, **kwargs)
259259 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
260260 'summarization' : `class TextPairDataset(Dataset):
@@ -285,7 +285,7 @@ def get_text_generation_loader(root, batch_size=8, split='train', shuffle=True,
285285 }
286286 return format_handlers.get(self.return_format, lambda x:x)(item)
287287
288- def get_text_pair_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
288+ def custom_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
289289 dataset = TextPairDataset(root, split, transforms=transforms, **kwargs)
290290 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
291291 'translation' : `class TextPairDataset(Dataset):
@@ -316,7 +316,7 @@ def get_text_pair_loader(root, batch_size=8, split='train', shuffle=True, transf
316316 }
317317 return format_handlers.get(self.return_format, lambda x:x)(item)
318318
319- def get_text_pair_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
319+ def custom_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
320320 dataset = TextPairDataset(root, split, transforms=transforms, **kwargs)
321321 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)`
322322 } ,
@@ -379,7 +379,7 @@ def get_text_pair_loader(root, batch_size=8, split='train', shuffle=True, transf
379379 else:
380380 return {'audio': waveform, 'label': label}
381381
382- def get_audio_classification_loader (root, batch_size=16, split='train', shuffle=True, transforms=None, **kwargs):
382+ def custom_loader (root, batch_size=16, split='train', shuffle=True, transforms=None, **kwargs):
383383 dataset = AudioClassificationDataset(root=root, split=split, transforms=transforms, **kwargs)
384384 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
385385 'conversion' : `class AudioConversionDataset(Dataset):
@@ -427,7 +427,7 @@ def get_audio_classification_loader(root, batch_size=16, split='train', shuffle=
427427 else:
428428 return waveform[:, :self.clip_samples]
429429
430- def get_audio_conversion_loader (root, batch_size=4, split='train', shuffle=True, transforms=None, **kwargs):
430+ def custom_loader (root, batch_size=4, split='train', shuffle=True, transforms=None, **kwargs):
431431 dataset = AudioConversionDataset(root=root, split=split, transforms=transforms, **kwargs)
432432 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
433433 'generation' : `class AudioGenerationDataset(Dataset):
@@ -513,7 +513,7 @@ def get_audio_conversion_loader(root, batch_size=4, split='train', shuffle=True,
513513 else:
514514 return item
515515
516- def get_audio_generation_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
516+ def custom_loader (root, batch_size=8, split='train', shuffle=True, transforms=None, **kwargs):
517517 dataset = AudioGenerationDataset(root=root, split=split, transforms=transforms, **kwargs)
518518 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
519519 'recognition' : `class AudioClassificationDataset(Dataset):
@@ -574,7 +574,7 @@ def get_audio_generation_loader(root, batch_size=8, split='train', shuffle=True,
574574 else:
575575 return {'audio': waveform, 'label': label}
576576
577- def get_audio_classification_loader (root, batch_size=16, split='train', shuffle=True, transforms=None, **kwargs):
577+ def custom_loader (root, batch_size=16, split='train', shuffle=True, transforms=None, **kwargs):
578578 dataset = AudioClassificationDataset(root=root, split=split, transforms=transforms, **kwargs)
579579 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)` ,
580580 }
0 commit comments