@@ -34,8 +34,12 @@ def get_train_val_loader(
34
34
35
35
train , valid = iqon_outfits .get_trainval_data (label_dir_name )
36
36
feature_dir = iqon_outfits .feature_dir
37
- train_dataset = MultisetSplitDataset (train , feature_dir , n_comb = n_comb , n_drops = None )
38
- valid_dataset = MultisetSplitDataset (valid , feature_dir , n_comb = n_comb , n_drops = None )
37
+ train_dataset = MultisetSplitDataset (
38
+ train , feature_dir , n_comb = n_comb , n_drops = None
39
+ )
40
+ valid_dataset = MultisetSplitDataset (
41
+ valid , feature_dir , n_comb = n_comb , n_drops = None
42
+ )
39
43
return (
40
44
get_loader (train_dataset , batch_size , num_workers = num_workers , is_train = True ),
41
45
get_loader (valid_dataset , batch_size , num_workers = num_workers , is_train = False ),
@@ -71,10 +75,10 @@ def main(args):
71
75
72
76
# dataset
73
77
train_loader , valid_loader = get_train_val_loader (
74
- train_year = args .train_year ,
75
- valid_year = args .valid_year ,
76
- split = args .split ,
77
- batch_size = args .batchsize ,
78
+ train_year = args .train_year ,
79
+ valid_year = args .valid_year ,
80
+ split = args .split ,
81
+ batch_size = args .batchsize ,
78
82
n_comb = args .n_comb ,
79
83
)
80
84
0 commit comments