Skip to content

Commit 39101d9

Browse files
authored
Merge pull request #202 from st-tech/superset_matching
add n_comb options
2 parents 4de9f4a + 21daad9 commit 39101d9

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

benchmarks/set_matching_pytorch/test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_test_loader(
3131

3232
iqon_outfits = IQONOutfits(root=root, split=split)
3333

34-
test_examples = iqon_outfits.get_fitb_data(label_dir_name)
34+
test_examples = iqon_outfits.get_fitb_data(label_dir_name, n_comb=args.n_comb)
3535
feature_dir = iqon_outfits.feature_dir
3636
dataset = FINBsDataset(
3737
test_examples,
@@ -113,6 +113,7 @@ def main(args):
113113
parser.add_argument("--valid_year", type=int)
114114
parser.add_argument("--split", type=int, choices=[0, 1, 2])
115115
parser.add_argument("--model_dir", "-d", type=str)
116+
parser.add_argument("--n_comb", type=int, default=1)
116117
args = parser.parse_args()
117118

118119
main(args)

benchmarks/set_matching_pytorch/train_sm.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def get_train_val_loader(
2424
valid_year: Union[str, int],
2525
split: int,
2626
batch_size: int,
27+
n_comb: int,
2728
root: str = C.ROOT,
2829
num_workers: Optional[int] = None,
2930
) -> Tuple[Any, Any]:
@@ -33,8 +34,12 @@ def get_train_val_loader(
3334

3435
train, valid = iqon_outfits.get_trainval_data(label_dir_name)
3536
feature_dir = iqon_outfits.feature_dir
36-
train_dataset = MultisetSplitDataset(train, feature_dir, n_sets=1, n_drops=None)
37-
valid_dataset = MultisetSplitDataset(valid, feature_dir, n_sets=1, n_drops=None)
37+
train_dataset = MultisetSplitDataset(
38+
train, feature_dir, n_comb=n_comb, n_drops=None
39+
)
40+
valid_dataset = MultisetSplitDataset(
41+
valid, feature_dir, n_comb=n_comb, n_drops=None
42+
)
3843
return (
3944
get_loader(train_dataset, batch_size, num_workers=num_workers, is_train=True),
4045
get_loader(valid_dataset, batch_size, num_workers=num_workers, is_train=False),
@@ -70,7 +75,11 @@ def main(args):
7075

7176
# dataset
7277
train_loader, valid_loader = get_train_val_loader(
73-
args.train_year, args.valid_year, args.split, args.batchsize
78+
train_year=args.train_year,
79+
valid_year=args.valid_year,
80+
split=args.split,
81+
batch_size=args.batchsize,
82+
n_comb=args.n_comb,
7483
)
7584

7685
# logger
@@ -222,6 +231,7 @@ def eval_process(engine, batch):
222231
parser.add_argument("--valid_year", type=int)
223232
parser.add_argument("--split", type=int, choices=[0, 1, 2])
224233
parser.add_argument("--weight_path", "-w", type=str, default=None)
234+
parser.add_argument("--n_comb", type=int, default=1)
225235

226236
args = parser.parse_args()
227237

shift15m/datasets/outfitfeature.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,18 @@ def __init__(
3636
self,
3737
sets: List,
3838
root: pathlib.Path,
39-
n_sets: int,
39+
n_comb: int,
4040
n_drops: Optional[int] = None,
4141
max_elementnum_per_set: Optional[int] = 8,
4242
):
4343
self.sets = sets
4444
self.feat_dir = root
45-
self.n_sets = n_sets
45+
self.n_comb = n_comb
4646
self.n_drops = n_drops
4747
if n_drops is None:
4848
n_drops = max_elementnum_per_set // 2
49-
setX_size = (max_elementnum_per_set - n_drops) * n_sets
50-
setY_size = n_drops * n_sets
49+
setX_size = (max_elementnum_per_set - n_drops) * n_comb
50+
setY_size = n_drops * n_comb
5151
self.transform_x = FeatureListTransform(
5252
max_set_size=setX_size, apply_shuffle=True, apply_padding=True
5353
)
@@ -59,9 +59,9 @@ def __len__(self):
5959
return len(self.sets)
6060

6161
def __getitem__(self, i):
62-
if self.n_sets > 1: # you can conduct "superset matching" by using n_sets > 1
62+
if self.n_comb > 1: # you can conduct "superset matching" by using n_comb > 1
6363
indices = np.delete(np.arange(len(self.sets)), i)
64-
indices = np.random.choice(indices, self.n_sets - 1, replace=False)
64+
indices = np.random.choice(indices, self.n_comb - 1, replace=False)
6565
indices = [i] + list(indices)
6666
else:
6767
indices = [i]
@@ -306,7 +306,7 @@ def _make_test_examples(
306306
n_cands: int = 8,
307307
seed: int = 0,
308308
):
309-
print("Make test dataset.")
309+
print("Making test dataset.")
310310
np.random.seed(seed)
311311

312312
test_sets = json.load(open(path / "test.json"))

0 commit comments

Comments (0)