Skip to content

Commit 21daad9

Browse files
committed
black
1 parent e132cf4 commit 21daad9

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

benchmarks/set_matching_pytorch/train_sm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@ def get_train_val_loader(
3434

3535
train, valid = iqon_outfits.get_trainval_data(label_dir_name)
3636
feature_dir = iqon_outfits.feature_dir
37-
train_dataset = MultisetSplitDataset(train, feature_dir, n_comb=n_comb, n_drops=None)
38-
valid_dataset = MultisetSplitDataset(valid, feature_dir, n_comb=n_comb, n_drops=None)
37+
train_dataset = MultisetSplitDataset(
38+
train, feature_dir, n_comb=n_comb, n_drops=None
39+
)
40+
valid_dataset = MultisetSplitDataset(
41+
valid, feature_dir, n_comb=n_comb, n_drops=None
42+
)
3943
return (
4044
get_loader(train_dataset, batch_size, num_workers=num_workers, is_train=True),
4145
get_loader(valid_dataset, batch_size, num_workers=num_workers, is_train=False),
@@ -71,10 +75,10 @@ def main(args):
7175

7276
# dataset
7377
train_loader, valid_loader = get_train_val_loader(
74-
train_year=args.train_year,
75-
valid_year=args.valid_year,
76-
split=args.split,
77-
batch_size=args.batchsize,
78+
train_year=args.train_year,
79+
valid_year=args.valid_year,
80+
split=args.split,
81+
batch_size=args.batchsize,
7882
n_comb=args.n_comb,
7983
)
8084

0 commit comments

Comments
 (0)