Skip to content

Commit d01a4b3

Browse files
committed
Fix mask_embeddings_by_frequency edge case
1 parent dc864d4 commit d01a4b3

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

corelib/dynamicemb/src/dynamic_emb_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ if (tables[0]->evict_strategy() == EvictStrategy::kLfu && frequency_threshold >
567567
#ifdef DEBUG
568568

569569
if (frequency_threshold > 0 && mask_dims > 0) {
570-
printf("Masking enabled\n");
570+
// printf("Masking enabled\n");
571571
at::Tensor h_unique_embs = unique_embs.cpu();
572572
at::Tensor h_unique_embeddings_for_scatter = unique_embeddings_for_scatter.cpu();
573573
at::Tensor h_unique_output_scores = unique_output_scores.cpu();

corelib/dynamicemb/src/lookup_forward.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ void mask_embeddings_by_frequency(void *embeddings_ptr, void *scores_ptr,
275275
int frequency_threshold, int mask_dims,
276276
DataType emb_type, DataType score_type,
277277
int device_num_sms, cudaStream_t stream) {
278-
if (frequency_threshold <= 0 || mask_dims <= 0) {
279-
return; // No masking needed
278+
if (frequency_threshold <= 0 || mask_dims <= 0 || num_embeddings <= 0) {
279+
return; // No masking needed or no embeddings to process
280280
}
281281

282282
int num_warps_needed = num_embeddings; // each warp processes one embedding

0 commit comments

Comments
 (0)