
[Draft] Fea dynamicemb table fusion #313

Open
shijieliu wants to merge 11 commits into NVIDIA:main from shijieliu:fea-dynamicemb_table_fusion

Conversation

@shijieliu
Collaborator

@shijieliu shijieliu commented Feb 26, 2026

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

ci
CI after the fix

@shijieliu shijieliu changed the title Fea dynamicemb table fusion [Draft] Fea dynamicemb table fusion Feb 26, 2026
@greptile-apps

greptile-apps bot commented Feb 26, 2026

Greptile Summary

This PR implements dynamic embedding table fusion — a major architectural refactor that replaces the previous per-table List[Storage] / List[Cache] design with a single fused DynamicEmbStorage / DynamicEmbCache / HybridStorage object that owns all tables simultaneously. The CUDA kernels, optimizer update paths, prefetch pipeline, and admission-counter logic are all updated accordingly.

Key changes:

  • DynamicEmbeddingTable monolith split into a DynamicEmbTableState dataclass plus three storage classes (DynamicEmbStorage, DynamicEmbCache, HybridStorage); flat multi-table load/store kernels replace the old combined dev/UVM layout.
  • DynamicEmbeddingFunction.forward now accepts a PrefetchState struct (produced by dynamicemb_prefetch) instead of raw indices, enabling a proper prefetch pipeline with an outstanding_keys_ref overflow guard.
  • Optimizer interface migrated from fused_update_with_index(dev_table, uvm_table) to fused_update_for_flat_table(table_ptrs, table_ids, …) plus a new update_for_padded_buffer path for the DEFAULT storage mode.
  • KVCounter simplified to a config holder; runtime logic moved to MultiTableKVCounter which wraps a single fused hash table for all admission counters.
  • Table lookup/insert CUDA kernels now route lookups via table_bucket_offsets per table-id and support compile-time overflow buckets (EnableOverflow).
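The prefetch pipeline with an outstanding_keys_ref overflow guard described above could look roughly like the following minimal pure-Python sketch. The class and attribute names (PrefetchState, outstanding_keys_ref) follow the summary; the depth guard, deque of states, and integer keys are illustrative assumptions, not the PR's actual code:

```python
from collections import deque
from dataclasses import dataclass
from typing import List


@dataclass
class PrefetchState:
    # Deduplicated keys whose cache slots are pinned until forward consumes them.
    unique_keys: List[int]


class PrefetchPipeline:
    def __init__(self, max_depth: int = 2):
        self.max_depth = max_depth
        self._states: deque = deque()
        # Counts keys pinned by in-flight prefetches (hypothetical guard).
        self.outstanding_keys_ref = 0

    def prefetch(self, indices: List[int]) -> None:
        if len(self._states) >= self.max_depth:
            raise RuntimeError(
                "prefetch depth exceeded: outstanding keys would overflow"
            )
        unique = sorted(set(indices))
        self.outstanding_keys_ref += len(unique)
        self._states.append(PrefetchState(unique_keys=unique))

    def forward(self) -> List[int]:
        # Consume the oldest prefetch; unpin its keys at the end of forward.
        state = self._states.popleft()
        self.outstanding_keys_ref -= len(state.unique_keys)
        return state.unique_keys
```

Under this sketch, a forward that has no matching prefetch raises immediately (deque underflow), which mirrors the one-prefetch-per-forward contract the summary implies.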

Issues found:

  • Logic bug (batched_dynamicemb_function.py:225): In _apply_admission, missing_scores (typed Optional[torch.Tensor], None for LRU eviction) is unconditionally placed in the list passed to flagged_compact, which will raise at runtime when scores are absent and an admission strategy is active.
  • Dead code (batched_dynamicemb_function.py:892): unique_keys.numel() in _generic_forward_path is a standalone expression whose return value is silently discarded — almost certainly a leftover h_num_total = … assignment.
  • Unresolved TODO (batched_dynamicemb_function.py:661): # todo: double check on the frequency_counts_int64 conversion path feeding LFU eviction.
  • assert used for runtime validation (batched_dynamicemb_tables.py:668): Mixed admission-counter configurations are checked with assert, which is silently skipped under python -O; should be a ValueError.

Confidence Score: 2/5

  • Not safe to merge in current state — a confirmed logic bug can cause a runtime crash when LRU eviction and admission filtering are used together.
  • The None-in-flagged_compact issue in _apply_admission is a reproducible crash path (not a theoretical edge case), and the PR is still marked Draft with an open todo: double check comment. The overall architectural direction is sound and the CUDA kernel changes look correct, but the Python-layer bugs need resolution before this is production-ready.
  • corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py contains the logic bug and dead code; corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py has the assert vs ValueError issue.

Important Files Changed

Filename Overview
corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py Core refactor: merges per-table prefetch/forward/backward into a unified fused path. Contains a discarded unique_keys.numel() return value, a potential None-in-flagged_compact crash when LRU is paired with an admission strategy, and an unresolved todo: double check on frequency conversion.
corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py Replaces per-table _storages/_caches lists with a single fused _storage/_cache. Uses assert instead of ValueError for mixed admission-counter validation, which silently passes under -O. Storage mode selection logic (HBM-only vs hybrid vs cache) is significantly restructured.
corelib/dynamicemb/dynamicemb/key_value_table.py Old DynamicEmbeddingTable monolith replaced by DynamicEmbTableState dataclass + DynamicEmbStorage/DynamicEmbCache/HybridStorage classes; flat-table load/store helpers added. Logic is well-structured; no critical issues found.
corelib/dynamicemb/dynamicemb/optimizer.py Optimizer interface migrated from fused_update_with_index(dev_table, uvm_table) to fused_update_for_flat_table(table_ptrs, table_ids, …) and update_for_padded_buffer. All four optimizer implementations updated consistently.
corelib/dynamicemb/dynamicemb/embedding_admission.py KVCounter simplified to a config-only class; new MultiTableKVCounter manages a single fused scored hash table for all tables. API change from per-table add(keys, frequencies, inplace) to multi-table add(keys, table_ids, frequencies) is clean and consistent.
corelib/dynamicemb/src/dynamic_emb_op.cu Load/store kernels migrated from split dev/UVM combined-table layout to flat multi-table layout with table_ptrs, table_ids, and per-table dims. Stride is now passed explicitly. Template parameter NumRegions cleanly distinguishes contiguous/emb-only/two-region copy modes.
corelib/dynamicemb/src/table_operation/kernels.cuh Table lookup kernel refactored to support multi-table routing via table_bucket_offsets and optional overflow bucket. ScorePolicyType promoted to a compile-time template parameter. Overflow path (EnableOverflow) added cleanly with __forceinline__ insert helpers.
corelib/dynamicemb/dynamicemb/types.py Added CopyMode enum and updated Storage/Cache abstract interfaces to expose find, insert, increment_counter, decrement_counter, and table-dimension queries. Breaking API change but contained within the library.
corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py Tests updated to use new single-storage API. Pipeline prefetch tests added. Coverage looks reasonable; no issues found.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant BDET as BatchedDynamicEmbeddingTablesV2
    participant Prefetch as dynamicemb_prefetch
    participant FwdFn as DynamicEmbeddingFunction.forward
    participant Storage as DynamicEmbStorage / DynamicEmbCache
    participant Optimizer as BaseDynamicEmbeddingOptimizer

    Caller->>BDET: prefetch(indices, offsets)
    BDET->>BDET: set cache/storage training=True, set_score
    BDET->>Prefetch: dynamicemb_prefetch(indices, offsets, cache, storage, ...)
    Prefetch->>Prefetch: segmented_unique → unique_keys
    alt StorageMode.CACHE
        Prefetch->>Storage: cache.lookup(unique_keys)
        Prefetch->>Storage: storage.find(miss_keys)
        Prefetch->>Storage: cache.insert_and_evict(admitted_keys)
        Prefetch->>Storage: storage.insert(evicted_keys)
    else StorageMode.HBM_DIRECT
        Prefetch->>Storage: _find_keys(state, unique_keys)
        Prefetch->>Storage: state.key_index_map.insert(admitted_keys)
    end
    Prefetch-->>BDET: PrefetchState
    BDET->>BDET: _prefetch_states.append(PrefetchState)
    BDET->>BDET: _update_score()

    Caller->>BDET: forward(indices, offsets)
    BDET->>BDET: _prefetch_states.popleft() → PrefetchState
    BDET->>FwdFn: DynamicEmbeddingFunction.apply(prefetch_state, ...)
    alt use_counter (CACHE or HBM_DIRECT)
        FwdFn->>Storage: load_from_flat(state, slot_indices)
    else DEFAULT
        FwdFn->>FwdFn: _generic_forward_path(storage, unique_keys)
    end
    FwdFn->>FwdFn: gather_embedding[_pooled](unique_embs → output_embs)
    FwdFn->>FwdFn: decrement outstanding_keys_ref
    FwdFn-->>BDET: output_embs
    BDET->>BDET: set cache/storage training=False

    Caller->>FwdFn: backward(grads)
    FwdFn->>FwdFn: reduce_grads → unique_grads
    FwdFn->>Optimizer: optimizer.step()
    alt use_counter
        FwdFn->>Optimizer: fused_update_for_flat_table(unique_grads, update_slot_indices)
        FwdFn->>Storage: decrement_counter(update_slot_indices)
    else DEFAULT
        FwdFn->>Optimizer: update_for_padded_buffer(unique_grads, unique_values)
        FwdFn->>Storage: storage.insert(unique_keys, unique_values, preserve_existing=True)
    end

Comments Outside Diff (4)

  1. corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 892 (link)

    Discarded return value — dead code

    unique_keys.numel() is called as a standalone statement and its return value is silently discarded. This is almost certainly a leftover from an earlier draft where the intent was something like h_num_total = unique_keys.numel(). As-is it triggers an unnecessary (but implicit) GPU→CPU sync in some backends and is misleading for readers.

  2. corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 225-228 (link)

    None passed to flagged_compact when scores are absent

    missing_scores is typed as Optional[torch.Tensor] and will be None when the eviction strategy does not use scores (e.g. LRU). At this call site admit_strategy is not None (the early-return guard on line 179 has already been passed), so missing_scores is unconditionally included in the list. flagged_compact expects tensor arguments and will raise at runtime when missing_scores is None.

    A minimal guard is needed:

    tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
    if missing_scores is not None:
        tensors_to_compact.append(missing_scores)
    
    count, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
    
    if missing_scores is not None:
        keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
    else:
        keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
        scores_to_insert = None
  3. corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 661-663 (link)

    Unresolved todo: double check on frequency conversion

    The comment # todo: double check signals that this logic was not fully validated. Given that frequency_counters feeds directly into segmented_unique_cuda and influences LFU eviction decisions, an unverified conversion path can silently corrupt scores. Please resolve the TODO before merging.

  4. corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py, line 668-670 (link)

    assert instead of ValueError for runtime validation

    assert statements are silently removed when Python is run with the -O / -OO optimisation flags (common in production deployments). A user who configures some tables with admission counters and others without will get a confusing AssertionError at best, and a silent misbehaving model at worst. Use an explicit ValueError to make this a reliable check in all modes.
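The hazard can be demonstrated directly: compiling with optimize=1 emulates `python -O` and strips the assert entirely, while an explicit ValueError check survives. This is a minimal standalone demonstration, not the PR's code; validate_counters and its flag argument are hypothetical stand-ins for the mixed admission-counter configuration check:

```python
def validate_counters(use_counter_flags):
    """Explicit validation that survives -O. The flags list is a
    hypothetical stand-in for per-table admission-counter settings."""
    if len(set(use_counter_flags)) > 1:
        raise ValueError(
            "Mixed admission-counter configuration: either all tables "
            "use admission counters or none of them."
        )


src = "assert len(set(flags)) <= 1, 'mixed config'"
flags = [True, False]

# optimize=1 emulates `python -O`: the assert is compiled away,
# so this exec silently accepts the invalid mixed configuration.
exec(compile(src, "<demo>", "exec", optimize=1), {"flags": flags})

# The explicit check still fires regardless of optimization level.
try:
    validate_counters(flags)
    raised = False
except ValueError:
    raised = True
```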

Last reviewed commit: 47cffb4


@greptile-apps greptile-apps bot left a comment


37 files reviewed, no comments


@shijieliu shijieliu force-pushed the fea-dynamicemb_table_fusion branch from 2c3371a to a36d744 on February 27, 2026 02:29
training,
EvictStrategy(evict_strategy.value) if evict_strategy else None,
lfu_accumulated_frequency_per_table,
if prefetch_state is None:
Collaborator


What will happen when the prefetch depth is > 1? Is there any protection for the prefetch_state?

host_options: List[DynamicEmbTableOptions],
optimizer: BaseDynamicEmbeddingOptimizer,
):
self._hbm = create_table_state(hbm_options, optimizer)
Collaborator


Will there be two duplicated table states for the hybrid state?

Collaborator


I mean the key_index_map

@shijieliu shijieliu force-pushed the fea-dynamicemb_table_fusion branch from a36d744 to c7590e7 on February 28, 2026 01:46
@greptile-apps

greptile-apps bot commented Feb 28, 2026

Additional Comments (5)

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 787
Dead-code / no-op statements in dynamicemb_eval_forward

Three consecutive statements are bare expressions whose values are immediately discarded:

emb_dtype = storage.embedding_dtype()
storage.max_embedding_dim()   # return value is thrown away
cache is not None              # boolean is thrown away

These look like forgotten assignments. Based on how the rest of the function uses DynamicEmbeddingFunction.forward (which does emb_dim = storage.max_embedding_dim() and caching = cache is not None), the likely intent was:

        emb_dtype = storage.embedding_dtype()
        emb_dim = storage.max_embedding_dim()
        caching = cache is not None

Even if emb_dim and caching are not subsequently needed in this function's current form, leaving bare expression statements as dead code is confusing and will trip linters/reviewers.


corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 817
No-op boolean expression

dims is not None and max_D > min(dims) evaluates mixed_D but discards the result. In DynamicEmbeddingFunction.forward, the same expression is assigned to ctx.mixed_D and used for out_dim. Without the assignment here, gather_embedding_pooled always uses the full max_D-wide unique_embs regardless of whether the tables have uniform or mixed dimensions.

        mixed_D = dims is not None and max_D > min(dims)

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 232
flagged_compact called with None in the tensor list

When missing_scores is None (e.g. with LRU eviction strategy, which produces no score output), the call becomes:

flagged_compact(admit_mask, [missing_keys, missing_indices, missing_table_ids, None])

flagged_compact is a C++ extension that almost certainly expects every element of the second argument to be a torch.Tensor. Passing None would raise a TypeError at runtime.

This path is reachable any time:

  • admit_strategy is not None (admission control is active), and
  • missing_keys.numel() > 0, and
  • eviction strategy doesn't produce scores (e.g. KLru), making missing_scores=None

The list passed to flagged_compact should guard against None:

tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)

_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)

if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None

corelib/dynamicemb/dynamicemb/key_value_table.py, line 1714
Dead-code attribute access in flush_cache

state.value_dim is evaluated but the result is immediately discarded. This is a no-op and appears to be a leftover from an earlier draft (perhaps intended as value_dim = state.value_dim for use later in the function, but never used).

Simply remove line 1714.


corelib/dynamicemb/dynamicemb/types.py, line 174
Abstract Storage.find return-type annotation is missing missing_table_ids

The concrete implementations (DynamicEmbStorage.find, HybridStorage.find) return an 8-tuple:

(num_missing, missing_keys, missing_indices, missing_table_ids,
 missing_scores, founds, output_scores, values)

The abstract signature declares only 7 elements; missing_table_ids is absent after missing_indices. Any future Storage subclass that faithfully matches the abstract annotation will return 7 values and break every call-site that unpacks 8. The same issue exists for Cache.find at lines 258–266.

# Storage.find – correct 8-element annotation
) -> Tuple[
    int,
    torch.Tensor,           # missing_keys
    torch.Tensor,           # missing_indices
    torch.Tensor,           # missing_table_ids  ← add this
    Optional[torch.Tensor], # missing_scores
    torch.Tensor,           # founds
    torch.Tensor,           # output_scores
    torch.Tensor,           # values
]:

Cache.find (lines 258-266) should similarly include missing_table_ids as the 4th element.

@shijieliu shijieliu force-pushed the fea-dynamicemb_table_fusion branch 2 times, most recently from a4097d9 to 4bc15aa on March 3, 2026 07:21
@greptile-apps

greptile-apps bot commented Mar 3, 2026

Additional Comments (5)

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 241
flagged_compact called with None in tensor list

In _apply_admission, when an admission strategy is active but score tracking is disabled (e.g., LRU eviction mode), missing_scores will be None. The call:

flagged_compact(
    admit_mask,
    [missing_keys, missing_indices, missing_table_ids, missing_scores],
)

passes None as the fourth element in the tensor list. Since flagged_compact is a C++ pybind11 extension that expects torch.Tensor objects, this will raise a TypeError at runtime whenever admission filtering is combined with non-LFU eviction (e.g., LRU).

The fix is to guard the compact call and only include scores when they are available:

tensors_to_compact = [missing_keys, missing_indices, missing_table_ids]
if missing_scores is not None:
    tensors_to_compact.append(missing_scores)

_, _, compacted = flagged_compact(admit_mask, tensors_to_compact)
if missing_scores is not None:
    keys_to_insert, positions_in_unique, table_ids_to_insert, scores_to_insert = compacted
else:
    keys_to_insert, positions_in_unique, table_ids_to_insert = compacted
    scores_to_insert = None

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 897
Dead expression — result of numel() is discarded

unique_keys.numel() is called here but its return value is never used. This appears to be a leftover debug or planning artifact.

        device = unique_keys.device

corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py, line 1087
STEP-strategy prefetch score offset is no longer applied

The original _get_prefetch_score method correctly offset STEP-strategy scores by num_prefetch_ahead - 1 to account for the pipelined prefetch depth:

new_score = cur_score + self.num_prefetch_ahead - 1

The replacement _reduce_table_scores simply returns max(scores) with no adjustment. This means that when using DynamicEmbScoreStrategy.STEP with prefetch_pipeline=True, freshly prefetched entries may receive an incorrect (low) eviction score and get evicted prematurely before they are consumed in the corresponding forward pass. This behaviour change should be intentional and documented, or the offset should be preserved here using self._prefetch_outstanding_keys.
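A minimal sketch of a score reduction that preserves the offset, assuming the semantics described in this comment (max-reduce, then advance by the pipeline depth); the function name is hypothetical and the real code operates on tensors rather than Python lists:

```python
def reduce_table_scores_step(scores, num_prefetch_ahead):
    """Max-reduce per-table STEP scores, then offset by the prefetch
    depth so freshly prefetched entries are not evicted before the
    forward pass that consumes them."""
    cur_score = max(scores)
    # With num_prefetch_ahead == 1 (no pipelining) this is the identity.
    return cur_score + num_prefetch_ahead - 1
```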


corelib/dynamicemb/dynamicemb/key_value_table.py, line 2816
Dead statement — state.value_dim result is unused

state.value_dim is a property access whose return value is immediately discarded. This is likely a remnant of a variable assignment that was deleted. It should either be removed, or the result should be assigned to a variable and used (e.g., as the batch_size argument to the export loop).

def flush_cache(cache: DynamicEmbCache, storage: Storage) -> None:
    state = cache._state
    batch_size = state.threads_in_wave

corelib/dynamicemb/src/dynamic_emb_op.cu, line 337
Potential null pointer dereference with NumRegions == 0 ternary

In load_from_flat_table_kernel_vec4 (and the scalar variant), when NumRegions == 0, table_ids is documented as unused (callers pass nullptr) and scalar_table_id is used directly. However the ternary:

int64_t table_id = NumRegions == 0 ? scalar_table_id : table_ids[emb_id];

is not a constexpr if, so NVCC may still emit the dead table_ids[emb_id] read in the generated PTX for the NumRegions == 0 instantiation, producing undefined behaviour. Replace with a constexpr if to guarantee the null pointer path is never reached:

int64_t table_id;
if constexpr (NumRegions == 0) {
    table_id = scalar_table_id;
} else {
    table_id = table_ids[emb_id];
}

The same pattern appears in load_from_flat_table_kernel (scalar version) and should be fixed there too.

@greptile-apps

greptile-apps bot commented Mar 3, 2026

Additional Comments (7)

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 234
None passed to flagged_compact when missing_scores is absent

missing_scores is typed Optional[torch.Tensor] and will be None whenever the eviction strategy is LRU or CSTM (i.e., not LFU). When admit_strategy is not None and the code reaches this flagged_compact call, the list [missing_keys, missing_indices, missing_table_ids, None] is passed to the C++ extension, which will crash because it expects actual tensors.

The early-return path on line 184–192 bypasses this only when admit_strategy is None. Fix by guarding missing_scores before passing it:

scores_input = (
    missing_scores
    if missing_scores is not None
    else torch.empty(0, dtype=torch.int64, device=missing_keys.device)
)
(
    _,
    _,
    (
        keys_to_insert,
        positions_in_unique,
        table_ids_to_insert,
        scores_to_insert,
    ),
) = flagged_compact(
    admit_mask,
    [missing_keys, missing_indices, missing_table_ids, scores_input],
)
scores_to_insert = scores_to_insert if missing_scores is not None else None

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 828
Dead statement — result of numel() is unused

unique_keys.numel() computes the number of elements but discards the result. This is likely a leftover from a refactor where the variable was previously named h_num_total. Either remove the line or assign it for documentation purposes:

        h_num_total = unique_keys.numel()

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 662
Only initializers[0] used for all tables in fused path

_prefetch_cache_path, _prefetch_hbm_direct_path, _generic_forward_path, and DynamicEmbeddingFunction.forward all pass initializers[0] regardless of which table a key belongs to. In a multi-table setup where tables have different per-table initializers (e.g., different embedding dimensions or initialization schemes), keys from tables 1…N will be incorrectly initialized using table 0's initializer.

The old per-table path used initializers[i] within the table loop. The fused path needs to either (1) ensure a single shared initializer is always valid for all tables, or (2) dispatch to the correct per-table initializer based on unique_table_ids.

This issue also appears in dynamicemb_eval_forward (around lines 844 and 1120), so all code paths are affected.
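Option (2) could be sketched as follows. The names (initializers, table_ids) follow the comment above; the per-key loop is purely illustrative, since the real fused path would dispatch on the GPU in batches rather than per key:

```python
def init_missing_values(missing_keys, table_ids, initializers):
    """Initialize each missing key with its own table's initializer,
    instead of applying initializers[0] to every table."""
    values = [None] * len(missing_keys)
    for i, (key, tid) in enumerate(zip(missing_keys, table_ids)):
        # Per-table dispatch: table_ids[i] selects the right initializer.
        values[i] = initializers[tid](key)
    return values
```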


corelib/dynamicemb/dynamicemb/embedding_admission.py, line 78
bucket_capacity and key_type from non-first counters are silently ignored

MultiTableKVCounter only uses kv_counters[0].bucket_capacity and kv_counters[0].key_type when constructing the underlying hash table, silently ignoring the same settings on counters 1…N. If a user specifies different bucket capacities or key types per table, their configuration will be disregarded without any warning.

Consider adding validation:

if len({kv.bucket_capacity for kv in kv_counters}) > 1:
    raise ValueError(
        "All KVCounter configs must share the same bucket_capacity in fused mode."
    )
if len({kv.key_type for kv in kv_counters}) > 1:
    raise ValueError(
        "All KVCounter configs must share the same key_type in fused mode."
    )

corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py, line 601
Only storage_options[0].external_storage is checked for all tables

When creating external PS storage in caching mode, storage_options[0].external_storage determines the storage class for the fused storage object. If tables have different external storage configurations (e.g., table 0 uses default storage while another has a custom PS), only the first table's setting is honored and others are silently ignored.

Add validation to ensure consistency:

ext_storages = {opt.external_storage for opt in storage_options}
if len(ext_storages) > 1:
    raise ValueError(
        "All tables must share the same external_storage class in fused caching mode."
    )
PS = storage_options[0].external_storage

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 603
Unresolved # todo: double check comment

This TODO should be resolved before the PR is merged. The comment is on a frequency_counters.long() conversion — it would be helpful to document whether the long conversion is definitively required or if the callers already guarantee int64 dtype, then remove the TODO.


examples/hstu/test_utils.py, line 578
Misleading comment: actual value is 8 MiB, not 4M

1024 * 1024 * 8 = 8,388,608 bytes = 8 MiB. The inline comment says # 4M HBM (maybe cached) which is incorrect.

                global_hbm_for_values=1024 * 1024 * 8,  # 8MiB HBM (maybe cached)

@greptile-apps

greptile-apps bot commented Mar 3, 2026

Additional Comments (4)

corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py, line 828
unique_keys.numel() is called but its return value is never stored or used. This appears to be a leftover from refactoring. Remove this dead code:

        device = unique_keys.device

corelib/dynamicemb/dynamicemb/key_value_table.py, line 1523
state.value_dim is accessed but the result is never stored or used. This is a no-op, likely a leftover from earlier development. Remove this dead code:

    batch_size = state.threads_in_wave

corelib/dynamicemb/dynamicemb/types.py, line 158
The abstract Storage.find return-type annotation declares a 7-tuple, but all concrete implementations (DynamicEmbStorage.find, HybridStorage.find) return 8 values—the missing_table_ids element is present but missing from the abstract signature. All call-sites already correctly unpack 8 values (e.g., _prefetch_cache_path line 309-323). Update the annotation to match:

    ) -> Tuple[
        int,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Optional[torch.Tensor],
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        num_missing: int
        missing_keys: torch.Tensor
        missing_indices: torch.Tensor
        missing_table_ids: torch.Tensor
        missing_scores: torch.Tensor
        founds: torch.Tensor
        output_scores: torch.Tensor
        values: torch.Tensor

corelib/dynamicemb/dynamicemb/key_value_table.py, line 736
When score_file_path is None, the message reads "Score file None does not exist." which implies a missing file rather than communicating that score loading is intentionally skipped. This can be confusing during debugging. Clarify the intent:

    if score_file_path is None:
        print("score_file_path is None. Scores will not be loaded.")

@shijieliu
Collaborator Author

  1. resolve comments, fix expansion
  2. benchmark; add support for a multiple-table benchmark
  3. validate the effectiveness of removing h2d from the critical path (the kernel-level change)
  4. @JacoCheung verify perf impact on e2e

shijieliu and others added 9 commits March 11, 2026 01:24
* Avoid outstanding keys overflow: decrement at the end of fwd

* Fix seq-emb bw test; fix ref_counter in/decrement bug

1. Fix sequence embedding backward test:
issue:
 there are two forward calls and one backward call in one iteration,
 which increments the ref_counter twice but decrements it only once.
fix:
 switch to eval mode when only evaluating the model.
alternative:
 move the ref_counter decrement to the end of fwd, but this may unlock a
key too early when prefetch overlaps with backward.

2. Fix ref_counter increment/decrement bug
issue:
  the slot_indices argument begins at 0 for each table, but a flat index
is needed.
  besides, the flat indices within one iteration are unique, as two keys
cannot share the same slot.
fix:
  make increment/decrement compute flat slot_indices for each table using
table_ids.

* Route to the correct ref_counter table in insert kernel

* Update score in the end of prefetch;and only update it for STEP

* Fix expected score in test as we update score in prefetch

* Remove default value for table_ids in in/decrement_counter and make it non-optional
@shijieliu shijieliu force-pushed the fea-dynamicemb_table_fusion branch from 2590bf7 to 8cc97ab on March 11, 2026 08:26