In this part, we discuss the details of this generative-recommender, e.g.,
- data preprocess
- training process
- model architecture
Notably, this work inherits many ideas from Revisiting Neural Retrieval on Accelerators, and these work are from same authors.
Here we takes Amazon books as example.
-
format of sample of Amazon books:
index,user_id,sequence_item_ids,sequence_ratings,sequence_timestamps 428566,428566,"627204,513752,573980,594511,574513","5.0,5.0,5.0,5.0,5.0","1362355200,1366761600,1370390400,1380067200,1393977600" 606756,606756,"384084,399081,72516,336635,155917,72377,140831,347262","5.0,5.0,5.0,5.0,2.0,5.0,1.0,4.0","1339459200,1344384000,1345507200,1349395200,1349395200,1354924800,1358208000,1401062400"
The number of expected_num_unique_items is 695762.
-
preprocess procedure
- filter users and items with presence < 5
- categorize user id and item id to be memory-efficient
- sort and group items of same user by timestamp
- join user data(optional) and unify format
-
configuration
gin is applied to manage configurations. The configs(.gin files) are stored in the folder 'configs/'.
-
args parser
absl is used to manage command line. app takes charge of parsing command line args and flags takes charge of definition and management of args.
-
parallel programming
Authors use TCP initialization with torch.multiprocessing(MP) for torch.nn.parallel.DistributedDataParallel(DDP), which does not need torch.distributed.launch. Instead, the information to initial process group is required for init_process_group(). Here MP is used to simplify this procedure, and MP.spawn() is to create multiple process for manually launch them.
-
dataset
In dataset, historical item seqs is sorted in reverse chronological order. _sample_ratio decides the number kept in item seqs and sampling_kept_mask indicates the remaining items.
Main structure lies in get_sequential_encoder(), where exists 2 module type, i.e., 'SASRec' & 'HSTU'. At a high level, this model aims at generating user (behavioral seqs) embs.
- SASRec
# run a transformer block
# user (behavioral seqs) embs as input and output
def _run_one_layer(
self,
i: int,
user_embeddings: torch.Tensor,
valid_mask: torch.Tensor,
) -> torch.Tensor:
Q = F.layer_norm(
user_embeddings, normalized_shape=(self._embedding_dim,), eps=1e-8,
)
mha_outputs, _ = self.attention_layers[i](
query=Q,
key=user_embeddings,
value=user_embeddings,
attn_mask=self._attn_mask,
)
user_embeddings = self.forward_layers[i](
F.layer_norm(
Q + mha_outputs,
normalized_shape=(self._embedding_dim,),
eps=1e-8,
)
)
user_embeddings *= valid_mask
return user_embeddings
# run transformer blocks and get corresponding final user embs
def generate_user_embeddings(
self,
past_lengths: torch.Tensor,
past_ids: torch.Tensor,
past_embeddings: torch.Tensor,
past_payloads: Dict[str, torch.Tensor],
) -> torch.Tensor:
"""
Args:
past_ids: (B, N,) x int
Returns:
(B, N, D,) x float
"""
past_lengths, user_embeddings, valid_mask = self._input_features_preproc(
past_lengths=past_lengths,
past_ids=past_ids,
past_embeddings=past_embeddings,
past_payloads=past_payloads,
)
for i in range(len(self.attention_layers)):
if self._activation_checkpoint:
user_embeddings = torch.utils.checkpoint.checkpoint(
self._run_one_layer, i, user_embeddings, valid_mask,
use_reentrant=False,
)
else:
user_embeddings = self._run_one_layer(i, user_embeddings, valid_mask)
return self._output_postproc(user_embeddings)
def forward(
self,
past_lengths: torch.Tensor,
past_ids: torch.Tensor,
past_embeddings: torch.Tensor,
past_payloads: Dict[str, torch.Tensor],
batch_id: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
past_ids: [B, N] x int64 where the latest engaged ids come first. In
particular, [:, 0] should correspond to the last engaged values.
past_ratings: [B, N] x int64.
past_timestamps: [B, N] x int64.
Returns:
encoded_embeddings of [B, N, D].
"""
encoded_embeddings = self.generate_user_embeddings(
past_lengths,
past_ids,
past_embeddings,
past_payloads,
)
return encoded_embeddings
Notice: in FFN, conv1d is applied to replace linear layer to be efficient.
class StandardAttentionFF(torch.nn.Module):
def __init__(
self,
embedding_dim: int,
hidden_dim: int,
activation_fn: str,
dropout_rate: float,
) -> None:
super().__init__()
assert activation_fn == "relu" or activation_fn == "gelu", \
f"Invalid activation_fn {activation_fn}"
self._conv1d = torch.nn.Sequential(
torch.nn.Conv1d(
in_channels=embedding_dim,
out_channels=hidden_dim,
kernel_size=1,
),
torch.nn.GELU() if activation_fn == "gelu" else torch.nn.ReLU(),
torch.nn.Dropout(p=dropout_rate),
torch.nn.Conv1d(
in_channels=hidden_dim,
out_channels=embedding_dim,
kernel_size=1,
),
torch.nn.Dropout(p=dropout_rate),
)
def forward(self, inputs) -> torch.Tensor:
# Conv1D requires (B, D, N)
return self._conv1d(inputs.transpose(-1, -2)).transpose(-1, -2) + inputs
- HSTU
Specifically, this main structure is composed by several important components, e.g.,
- embedding module
- interaction module
- input features preprocessor module
- output postprocessor module
The below is the explanation of each module.
-
embedding module (implemented by CategoricalEmbeddingModule & LocalEmbeddingModule)
The embs are initialized by truncated_normal(), which implements a truncated normal distribution to avoid extreme values leading to grad vanishing. Specifically, the interval to truncates is [-2,2].
def truncated_normal(x: torch.Tensor, mean: float, std: float) -> torch.Tensor: with torch.no_grad(): size = x.shape tmp = x.new_empty(size + (4,)).normal_() valid = (tmp < 2) & (tmp > -2) ind = valid.max(-1, keepdim=True)[1] x.data.copy_(tmp.gather(-1, ind).squeeze(-1)) x.data.mul_(std).add_(mean) return x
The only different of categorical and local embs is that categorical embs need remap from item to categorical since pd.Categorical is applied at data preprocess step.
# local def get_item_embeddings(self, item_ids: torch.Tensor) -> torch.Tensor: return self._item_emb(item_ids) # categorical def get_item_embeddings(self, item_ids: torch.Tensor) -> torch.Tensor: item_ids = self._item_id_to_category_id[(item_ids - 1).clamp(min=0)] + 1 return self._item_emb(item_ids)
-
interaction module (implemented by DotProductSimilarity & create_mol_interaction_module)
The core module is to define how input embs interact with target items, possibly computing scores or similarities.
Its specific implementation is GeneralizedInteractionModule class, which computes similarity by input embs & item embs & item sideinfo. All models are this class.
Concretely, it has 2 kinds of similarities, e.g.,
-
DotProduct: applied by DotProductSimilarity class, which adjusts the shape of potential difference of batchsize & number of items.
if item_embeddings.size(0) == 1: # [B, D] x ([1, X, D] -> [D, X]) => [B, X] return torch.mm(input_embeddings, item_embeddings.squeeze(0).t()), {} # [B, X] elif input_embeddings.size(0) != item_embeddings.size(0): # (B * r, D) x (B, X, D). B, X, D = item_embeddings.size() return torch.bmm(input_embeddings.view(B, -1, D), item_embeddings.permute(0, 2, 1)).view(-1, X) else: # assert input_embeddings.size(0) == item_embeddings.size(0) # [B, X, D] x ([B, D] -> [B, D, 1]) => [B, X, 1] -> [B, X] return torch.bmm(item_embeddings, input_embeddings.unsqueeze(2)).squeeze(2)
-
MoL: applied by create_mol_interaction_module function, which implements MoL(Mixture-of-Logits)learned similarity in Revisiting Neural Retrieval on Accelerators.
DEFAULT SETTING is DotProduct.
-
-
input features preprocessor module (implemented by LearnablePositionalEmbeddingInputFeaturesPreprocessor & LearnablePositionalEmbeddingRatedInputFeaturesPreprocessor & CombinedItemAndRatingInputFeaturesPreprocessor)
This module aims to perform preprocess operations such as add positional embs. RatedPreprocessor handles additional features like rating by concating the user embs and rating embs. CombinedItemAndRatingPreprocessor handles items and corresponding rating as individual inputs, which doubles the seq length.
DEFAULT SETTING is LearnablePositionalEmbeddingInputFeaturesPreprocessor.
-
output postprocessor module (implemented by L2NormEmbeddingPostprocessor & LayerNormEmbeddingPostprocessor)
This module aims at apply norm to the output of transformer block.
DEFAULT SETTING is L2NormEmbeddingPostprocessor.
-
BCELoss
-
SampledSoftmaxLoss
DEFAULT SETTING is SampledSoftmaxLoss.
loss_module: str = "SampledSoftmaxLoss",
-
sampling_strategy == "in-batch"
InBatchNegativesSampler
-
sampling_strategy == "local"
LocalNegativesSampler
DEFAULT SETTING is in-batch.