Conversation
Greptile Summary

This PR implements dynamic embedding table fusion, a major architectural refactor that replaces the previous per-table design. Key changes:
Issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant BDET as BatchedDynamicEmbeddingTablesV2
    participant Prefetch as dynamicemb_prefetch
    participant FwdFn as DynamicEmbeddingFunction.forward
    participant Storage as DynamicEmbStorage / DynamicEmbCache
    participant Optimizer as BaseDynamicEmbeddingOptimizer
    Caller->>BDET: prefetch(indices, offsets)
    BDET->>BDET: set cache/storage training=True, set_score
    BDET->>Prefetch: dynamicemb_prefetch(indices, offsets, cache, storage, ...)
    Prefetch->>Prefetch: segmented_unique → unique_keys
    alt StorageMode.CACHE
        Prefetch->>Storage: cache.lookup(unique_keys)
        Prefetch->>Storage: storage.find(miss_keys)
        Prefetch->>Storage: cache.insert_and_evict(admitted_keys)
        Prefetch->>Storage: storage.insert(evicted_keys)
    else StorageMode.HBM_DIRECT
        Prefetch->>Storage: _find_keys(state, unique_keys)
        Prefetch->>Storage: state.key_index_map.insert(admitted_keys)
    end
    Prefetch-->>BDET: PrefetchState
    BDET->>BDET: _prefetch_states.append(PrefetchState)
    BDET->>BDET: _update_score()
    Caller->>BDET: forward(indices, offsets)
    BDET->>BDET: _prefetch_states.popleft() → PrefetchState
    BDET->>FwdFn: DynamicEmbeddingFunction.apply(prefetch_state, ...)
    alt use_counter (CACHE or HBM_DIRECT)
        FwdFn->>Storage: load_from_flat(state, slot_indices)
    else DEFAULT
        FwdFn->>FwdFn: _generic_forward_path(storage, unique_keys)
    end
    FwdFn->>FwdFn: gather_embedding[_pooled](unique_embs → output_embs)
    FwdFn->>FwdFn: decrement outstanding_keys_ref
    FwdFn-->>BDET: output_embs
    BDET->>BDET: set cache/storage training=False
    Caller->>FwdFn: backward(grads)
    FwdFn->>FwdFn: reduce_grads → unique_grads
    FwdFn->>Optimizer: optimizer.step()
    alt use_counter
        FwdFn->>Optimizer: fused_update_for_flat_table(unique_grads, update_slot_indices)
        FwdFn->>Storage: decrement_counter(update_slot_indices)
    else DEFAULT
        FwdFn->>Optimizer: update_for_padded_buffer(unique_grads, unique_values)
        FwdFn->>Storage: storage.insert(unique_keys, unique_values, preserve_existing=True)
    end
```
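The prefetch step above begins with `segmented_unique`, which deduplicates keys per table before any storage lookups. A pure-Python sketch of the assumed semantics (the real op runs on the GPU; the names and return shape here are illustrative only):

```python
def segmented_unique(keys, table_ids):
    """Deduplicate (table_id, key) pairs, keeping first-occurrence order.

    Returns the unique keys, their table ids, and an inverse map from each
    input position to its unique index -- the kind of map later used to
    gather unique embeddings back into the full output.
    """
    seen = {}
    unique_keys, unique_table_ids, inverse = [], [], []
    for k, t in zip(keys, table_ids):
        if (t, k) not in seen:
            seen[(t, k)] = len(unique_keys)
            unique_keys.append(k)
            unique_table_ids.append(t)
        inverse.append(seen[(t, k)])
    return unique_keys, unique_table_ids, inverse
```

Note that the same key appearing in two different tables stays distinct, which is why the table ids participate in the deduplication.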
Force-pushed from 2c3371a to a36d744.
```python
training,
EvictStrategy(evict_strategy.value) if evict_strategy else None,
lfu_accumulated_frequency_per_table,
if prefetch_state is None:
```
What will happen when the prefetch depth is greater than 1? Is there any protection for the prefetch_state?
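One way to protect the state queue when the prefetch depth exceeds 1 is to bound the deque and fail loudly on mismatched prefetch/forward calls. This is a hypothetical sketch; `PrefetchGuard` and `max_depth` are illustrative names, not the PR's actual API:

```python
from collections import deque


class PrefetchGuard:
    """Sketch: cap the number of outstanding prefetch states and pair every
    forward with exactly one popleft, raising on over- or underflow."""

    def __init__(self, max_depth: int):
        self.max_depth = max_depth
        self._states = deque()

    def push(self, state):
        # Called at the end of prefetch().
        if len(self._states) >= self.max_depth:
            raise RuntimeError(
                f"prefetch depth exceeded: {len(self._states)} >= {self.max_depth}"
            )
        self._states.append(state)

    def pop(self):
        # Called at the start of forward().
        if not self._states:
            raise RuntimeError("forward called without a matching prefetch")
        return self._states.popleft()
```

States are consumed in FIFO order, matching the `_prefetch_states.append` / `popleft` pattern shown in the sequence diagram.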
```python
host_options: List[DynamicEmbTableOptions],
optimizer: BaseDynamicEmbeddingOptimizer,
):
self._hbm = create_table_state(hbm_options, optimizer)
```
Will there be two duplicated table states for the hybrid state?
Force-pushed from a36d744 to c7590e7.
Additional Comments (5)
Three consecutive statements are bare expressions whose values are immediately discarded:

```python
emb_dtype = storage.embedding_dtype()
storage.max_embedding_dim()  # return value is thrown away
cache is not None  # boolean is thrown away
```

These look like forgotten assignments; judging from how the rest of the function uses these quantities, the latter two results were probably meant to be captured.
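For concreteness, a minimal sketch of the presumably intended version, with stand-in stub objects (`_StorageStub` and the variable names `max_dim` / `has_cache` are hypothetical; the real `storage` and `cache` come from the surrounding function):

```python
class _StorageStub:
    # Hypothetical stand-in for the real storage object.
    def embedding_dtype(self):
        return "float32"

    def max_embedding_dim(self):
        return 128


storage = _StorageStub()
cache = None

# Capture each result instead of discarding it:
emb_dtype = storage.embedding_dtype()
max_dim = storage.max_embedding_dim()   # was a bare expression
has_cache = cache is not None           # was a bare expression
```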
When `missing_scores` is `None`, the call

```python
flagged_compact(admit_mask, [missing_keys, missing_indices, missing_table_ids, None])
```

passes `None` into the tensor list. This path is reachable any time scores are not provided. The fix is to build the list conditionally:

```python
tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)
_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None
```

Simply remove line 1714.
The concrete implementations return an 8-element tuple, but the abstract signature declares only 7 elements: `missing_table_ids` is missing. Corrected 8-element annotation for `Storage.find`:

```python
) -> Tuple[
    int,
    torch.Tensor,            # missing_keys
    torch.Tensor,            # missing_indices
    torch.Tensor,            # missing_table_ids  ← add this
    Optional[torch.Tensor],  # missing_scores
    torch.Tensor,            # founds
    torch.Tensor,            # output_scores
    torch.Tensor,            # values
]:
```
Force-pushed from a4097d9 to 4bc15aa.
Additional Comments (5)
The call

```python
flagged_compact(
    admit_mask,
    [missing_keys, missing_indices, missing_table_ids, missing_scores],
)
```

passes `missing_scores` even when it is `None`. The fix is to guard the compact call and only include scores when they are available:

```python
tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)
_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None
```
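A runnable pure-Python illustration of the guarded pattern (here `compact_by_mask` and `guarded_compact` are toy stand-ins for `flagged_compact` and its call site; the stand-in only filters lists by a boolean mask, whereas the real op returns extra bookkeeping):

```python
def compact_by_mask(mask, tensors):
    # Keep element i of every "tensor" (here a plain list) where mask[i]
    # is truthy. A None entry in tensors would break this, hence the guard.
    return [[x for x, keep in zip(t, mask) if keep] for t in tensors]


def guarded_compact(admit_mask, missing_keys, missing_indices,
                    missing_table_ids, missing_scores=None):
    # Only include scores in the compaction when they exist.
    tensors = [missing_keys, missing_indices, missing_table_ids]
    if missing_scores is not None:
        tensors.append(missing_scores)
    compacted = compact_by_mask(admit_mask, tensors)
    if missing_scores is not None:
        keys, positions, table_ids, scores = compacted
    else:
        keys, positions, table_ids = compacted
        scores = None
    return keys, positions, table_ids, scores
```

Callers get a uniform 4-tuple back whether or not scores were supplied, which keeps the downstream insert path unchanged.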
The original computation was `new_score = cur_score + self.num_prefetch_ahead - 1`; the replacement does not reproduce this result.
In the kernel, the ternary

```cpp
int64_t table_id = NumRegions == 0 ? scalar_table_id : table_ids[emb_id];
```

is not a compile-time branch even though `NumRegions` is a template parameter. Prefer `if constexpr`, so the dead branch is never instantiated:

```cpp
int64_t table_id;
if constexpr (NumRegions == 0) {
  table_id = scalar_table_id;
} else {
  table_id = table_ids[emb_id];
}
```

The same pattern appears in other kernels.
Additional Comments (7)
The early-return path on lines 184–192 bypasses this only in some cases. One way to keep a single call site is to substitute an empty tensor when scores are absent:

```python
scores_input = (
    missing_scores
    if missing_scores is not None
    else torch.empty(0, dtype=torch.int64, device=missing_keys.device)
)
(
    _,
    _,
    (
        keys_to_insert,
        positions_in_unique,
        table_ids_to_insert,
        scores_to_insert,
    ),
) = flagged_compact(
    admit_mask,
    [missing_keys, missing_indices, missing_table_ids, scores_input],
)
scores_to_insert = scores_to_insert if missing_scores is not None else None
```
The old per-table path handled this differently. The same issue also appears in other call sites.
Consider adding validation:

```python
if len({kv.bucket_capacity for kv in kv_counters}) > 1:
    raise ValueError(
        "All KVCounter configs must share the same bucket_capacity in fused mode."
    )
if len({kv.key_type for kv in kv_counters}) > 1:
    raise ValueError(
        "All KVCounter configs must share the same key_type in fused mode."
    )
```
When creating external PS storage in caching mode, the code assumes all tables use the same external storage class. Add validation to ensure consistency:

```python
ext_storages = {opt.external_storage for opt in storage_options}
if len(ext_storages) > 1:
    raise ValueError(
        "All tables must share the same external_storage class in fused caching mode."
    )
PS = storage_options[0].external_storage
```
This TODO should be resolved before the PR is merged.
Additional Comments (4)
* Avoid outstanding keys overflow: decrement at the end of fwd.
* Fix seq-emb backward test; fix ref_counter increment/decrement bugs.
  1. Fix the sequence embedding backward test.
     Issue: there are two forward calls and one backward call in one iteration, which increments the ref_counter twice but decrements it only once.
     Fix: switch to eval mode when only evaluating the model.
     Alternative: move the ref_counter decrement to the end of fwd, but that may unlock the key early when prefetch and backward overlap.
  2. Fix the ref_counter increment/decrement bug.
     Issue: the slot_indices argument starts from 0 for each table, but a flat index is needed. Also, the flat indices within one iteration are unique, since two keys cannot share the same slot.
     Fix: have increment/decrement flatten the slot_indices for each table using table_ids.
* Route to the correct ref_counter table in the insert kernel.
* Update the score at the end of prefetch, and only update it for STEP.
* Fix the expected score in the test, since the score is now updated in prefetch.
* Remove the default value for table_ids in increment/decrement_counter and make it non-optional.
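The flat-index fix in the commit notes above can be sketched as follows (a minimal illustration with hypothetical names; the real conversion happens inside the CUDA increment/decrement kernels using `table_ids`):

```python
def flatten_slot_indices(slot_indices, table_ids, table_capacities):
    """Convert per-table slot indices (each starting at 0) into flat
    indices into a fused ref_counter buffer.

    table_offsets[t] is the starting flat slot of table t, i.e. the
    running sum of the capacities of all preceding tables.
    """
    table_offsets = [0]
    for cap in table_capacities[:-1]:
        table_offsets.append(table_offsets[-1] + cap)
    return [table_offsets[t] + s for s, t in zip(slot_indices, table_ids)]
```

For capacities `[8, 4]`, slot 0 of table 1 maps to flat index 8, so counters of different tables can no longer collide at the same buffer position.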
Force-pushed from 2590bf7 to 8cc97ab.
Description
Checklist
ci
CI after the fix