Skip to content
Open
6 changes: 3 additions & 3 deletions configs/inference/bert4rec_inference_config.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
{
"pred_prefix": "logits",
"label_prefix": "labels",
"experiment_name": "bert4rec_beauty_grid_0-5_0-1__",
"experiment_name": "bert4rec_all_rank_beauty",
"dataset": {
"type": "sequence",
"path_to_data_dir": "../data",
"path_to_data_dir": "../data/sasrec_in_batch",
"name": "Beauty",
"max_sequence_length": 50,
"samplers": {
Expand Down Expand Up @@ -34,7 +34,7 @@
}
},
"model": {
"type": "bert4rec",
"type": "bert4rec_all_rank",
"sequence_prefix": "item",
"labels_prefix": "labels",
"candidate_prefix": "candidates",
Expand Down
4 changes: 2 additions & 2 deletions configs/train/bert4rec_train_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"best_metric": "validation/ndcg@20",
"dataset": {
"type": "sequence",
"path_to_data_dir": "../data",
"path_to_data_dir": "../data/sasrec_in_batch",
"name": "Beauty",
"max_sequence_length": 50,
"samplers": {
Expand Down Expand Up @@ -41,7 +41,7 @@
"num_heads": 2,
"num_layers": 2,
"dim_feedforward": 256,
"dropout": 0.2,
"dropout": 0.3,
"activation": "gelu",
"layer_norm_eps": 1e-9,
"initializer_range": 0.02
Expand Down
169 changes: 169 additions & 0 deletions configs/train/bert4rec_train_config_all_rank.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
{
  "experiment_name": "bert4rec_all_rank_beauty_bce_0-3_0-2",
  "best_metric": "validation/ndcg@20",
  "dataset": {
    "type": "sequence",
    "path_to_data_dir": "../data/sasrec_in_batch",
    "name": "Beauty",
    "max_sequence_length": 50,
    "samplers": {
      "mask_prob": 0.3,
      "type": "masked_item_prediction",
      "negative_sampler_type": "random"
    }
  },
  "dataloader": {
    "train": {
      "type": "torch",
      "batch_size": 256,
      "batch_processor": {
        "type": "basic"
      },
      "drop_last": true,
      "shuffle": true
    },
    "validation": {
      "type": "torch",
      "batch_size": 256,
      "batch_processor": {
        "type": "basic"
      },
      "drop_last": false,
      "shuffle": false
    }
  },
  "model": {
    "type": "bert4rec_all_rank",
    "user_prefix": "user",
    "sequence_prefix": "item",
    "labels_prefix": "labels",
    "candidate_prefix": "candidates",
    "embedding_dim": 64,
    "num_heads": 2,
    "num_layers": 2,
    "dim_feedforward": 256,
    "dropout": 0.2,
    "activation": "gelu",
    "layer_norm_eps": 1e-9,
    "initializer_range": 0.02
  },
  "optimizer": {
    "type": "basic",
    "optimizer": {
      "type": "adam",
      "lr": 0.001
    },
    "clip_grad_threshold": 5.0
  },
  "loss": {
    "type": "composite",
    "losses": [
      {
        "type": "bert4rec_sasrec",
        "predictions_prefix": "logits",
        "labels_prefix": "labels",
        "output_prefix": "downstream_loss",
        "weight": 1.0
      }
    ],
    "output_prefix": "loss"
  },
  "callback": {
    "type": "composite",
    "callbacks": [
      {
        "type": "metric",
        "on_step": 1,
        "loss_prefix": "loss"
      },
      {
        "type": "validation",
        "on_step": 64,
        "pred_prefix": "logits",
        "labels_prefix": "labels",
        "metrics": {
          "ndcg@5": {
            "type": "ndcg",
            "k": 5
          },
          "ndcg@10": {
            "type": "ndcg",
            "k": 10
          },
          "ndcg@20": {
            "type": "ndcg",
            "k": 20
          },
          "recall@5": {
            "type": "recall",
            "k": 5
          },
          "recall@10": {
            "type": "recall",
            "k": 10
          },
          "recall@20": {
            "type": "recall",
            "k": 20
          },
          "coverage@5": {
            "type": "coverage",
            "k": 5
          },
          "coverage@10": {
            "type": "coverage",
            "k": 10
          },
          "coverage@20": {
            "type": "coverage",
            "k": 20
          }
        }
      },
      {
        "type": "eval",
        "on_step": 256,
        "pred_prefix": "logits",
        "labels_prefix": "labels",
        "metrics": {
          "ndcg@5": {
            "type": "ndcg",
            "k": 5
          },
          "ndcg@10": {
            "type": "ndcg",
            "k": 10
          },
          "ndcg@20": {
            "type": "ndcg",
            "k": 20
          },
          "recall@5": {
            "type": "recall",
            "k": 5
          },
          "recall@10": {
            "type": "recall",
            "k": 10
          },
          "recall@20": {
            "type": "recall",
            "k": 20
          },
          "coverage@5": {
            "type": "coverage",
            "k": 5
          },
          "coverage@10": {
            "type": "coverage",
            "k": 10
          },
          "coverage@20": {
            "type": "coverage",
            "k": 20
          }
        }
      }
    ]
  }
}
169 changes: 169 additions & 0 deletions configs/train/bert4rec_train_config_in_batch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
{
  "experiment_name": "bert4rec_in_batch_beauty_bce_0-3_0-5",
  "best_metric": "validation/ndcg@20",
  "dataset": {
    "type": "sequence",
    "path_to_data_dir": "../data/sasrec_in_batch",
    "name": "Beauty",
    "max_sequence_length": 50,
    "samplers": {
      "mask_prob": 0.3,
      "type": "masked_item_prediction",
      "negative_sampler_type": "random"
    }
  },
  "dataloader": {
    "train": {
      "type": "torch",
      "batch_size": 256,
      "batch_processor": {
        "type": "basic"
      },
      "drop_last": true,
      "shuffle": true
    },
    "validation": {
      "type": "torch",
      "batch_size": 256,
      "batch_processor": {
        "type": "basic"
      },
      "drop_last": false,
      "shuffle": false
    }
  },
  "model": {
    "type": "bert4rec_in_batch",
    "user_prefix": "user",
    "sequence_prefix": "item",
    "labels_prefix": "labels",
    "candidate_prefix": "candidates",
    "embedding_dim": 64,
    "num_heads": 2,
    "num_layers": 2,
    "dim_feedforward": 256,
    "dropout": 0.5,
    "activation": "gelu",
    "layer_norm_eps": 1e-9,
    "initializer_range": 0.02
  },
  "optimizer": {
    "type": "basic",
    "optimizer": {
      "type": "adam",
      "lr": 0.001
    },
    "clip_grad_threshold": 5.0
  },
  "loss": {
    "type": "composite",
    "losses": [
      {
        "type": "bert4rec_sasrec",
        "predictions_prefix": "logits",
        "labels_prefix": "labels",
        "output_prefix": "downstream_loss",
        "weight": 1.0
      }
    ],
    "output_prefix": "loss"
  },
  "callback": {
    "type": "composite",
    "callbacks": [
      {
        "type": "metric",
        "on_step": 1,
        "loss_prefix": "loss"
      },
      {
        "type": "validation",
        "on_step": 64,
        "pred_prefix": "logits",
        "labels_prefix": "labels",
        "metrics": {
          "ndcg@5": {
            "type": "ndcg",
            "k": 5
          },
          "ndcg@10": {
            "type": "ndcg",
            "k": 10
          },
          "ndcg@20": {
            "type": "ndcg",
            "k": 20
          },
          "recall@5": {
            "type": "recall",
            "k": 5
          },
          "recall@10": {
            "type": "recall",
            "k": 10
          },
          "recall@20": {
            "type": "recall",
            "k": 20
          },
          "coverage@5": {
            "type": "coverage",
            "k": 5
          },
          "coverage@10": {
            "type": "coverage",
            "k": 10
          },
          "coverage@20": {
            "type": "coverage",
            "k": 20
          }
        }
      },
      {
        "type": "eval",
        "on_step": 256,
        "pred_prefix": "logits",
        "labels_prefix": "labels",
        "metrics": {
          "ndcg@5": {
            "type": "ndcg",
            "k": 5
          },
          "ndcg@10": {
            "type": "ndcg",
            "k": 10
          },
          "ndcg@20": {
            "type": "ndcg",
            "k": 20
          },
          "recall@5": {
            "type": "recall",
            "k": 5
          },
          "recall@10": {
            "type": "recall",
            "k": 10
          },
          "recall@20": {
            "type": "recall",
            "k": 20
          },
          "coverage@5": {
            "type": "coverage",
            "k": 5
          },
          "coverage@10": {
            "type": "coverage",
            "k": 10
          },
          "coverage@20": {
            "type": "coverage",
            "k": 20
          }
        }
      }
    ]
  }
}
Loading