
[Draft] Feature: inference embedding with C++ export #324

Open
geoffreyQiu wants to merge 8 commits into NVIDIA:main from geoffreyQiu:fea-inference_emb_export

Conversation

@geoffreyQiu
Collaborator

Add C++ export for inference embedding, based on:

  • ScoredHashTable ops from dynamicemb
  • Exportable module based on LinearBucketTable.
  • LinearUVMEmbedding from NVEmbedding.

@greptile-apps

greptile-apps bot commented Mar 12, 2026

Greptile Summary

This draft PR introduces C++ export support for inference-only embedding lookup in dynamicemb, built on three new INFERENCE_EMB torch-dispatch custom operators (get_table_range, expand_table_ids, table_lookup) and a new CMake build target (inference_emb_ops.so). It is paired with a large refactor of the core dynamicemb internals: multi-table handling is unified into single Storage/Cache objects, table_ids change from int32 to int64 throughout the C++ codebase, and the old pybind11-only operator bindings are replaced with macro-gated DEMB_USE_PYBIND11 guards so the same CUDA files can serve both the pybind11 extension and the new TORCH_LIBRARY_FRAGMENT shared library.

Key changes:

  • New: CMakeLists.txt builds inference_emb_ops.so from three new *_torch_binding.cu files
  • New: index_range_meta.py / lookup_meta.py register fake/meta kernels for torch.export support
  • New: examples/hstu/inference/test_export_demo.py demo for the full export-and-run flow
  • Refactor: ScoredHashTable / Storage / Cache / Counter interfaces updated to support multi-table table_ids; KVCounter split into plain config + MultiTableKVCounter
  • Refactor: segmented_unique / expand_table_ids_cuda now use int64 table IDs throughout
  • Fix: flagged_compact replaces the old select / select_index functions with a CUB-backed zipped-iterator compaction routine
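To make the data flow concrete, here is a pure-Python sketch of what the expand_table_ids op is described to produce: one table id per element, derived from the per-feature element offsets and the feature-to-table mapping. The function name mirrors the op but the signature and implementation are illustrative only; the real op is a CUDA kernel with additional parameters.

```python
# Hypothetical pure-Python reference for INFERENCE_EMB::expand_table_ids.
# offsets:        (num_features + 1,) element offsets per feature
# table_offsets:  (num_tables + 1,)   feature offsets per table
# Returns a list of length offsets[-1] giving the owning table id per element.
def expand_table_ids(offsets, table_offsets):
    table_ids = []
    num_tables = len(table_offsets) - 1
    for t in range(num_tables):
        # Features [table_offsets[t], table_offsets[t+1]) belong to table t,
        # so its elements span offsets[table_offsets[t]] .. offsets[table_offsets[t+1]).
        first_elem = offsets[table_offsets[t]]
        last_elem = offsets[table_offsets[t + 1]]
        table_ids.extend([t] * (last_elem - first_elem))
    return table_ids

# 3 features split across 2 tables: table 0 owns features 0-1, table 1 owns feature 2.
print(expand_table_ids(offsets=[0, 2, 5, 7], table_offsets=[0, 2, 3]))
# → [0, 0, 0, 0, 0, 1, 1]
```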

Issues found:

  • Critical (runtime crash): InferenceEmbeddingTable.forward() passes raw table_indices (containing -1 for missing keys) directly to torch.index_select, which does not accept negative indices. The zeroing torch.where block below is commented out and must be re-enabled, with indices clamped first.
  • Logic bug: table_ranges is computed via get_table_range at the top of forward() but is never referenced — dead code that also introduces an unnecessary graph node in the exported model.
  • Logic bug: The .so search loop breaks on the first existing path whether or not loading succeeded, silently skipping fallback paths on partial failures.
  • Style: Newly added exclude_files=[] parameter in setup.py is a mutable default argument.

Confidence Score: 2/5

  • Not safe to merge — the demo forward pass will crash at runtime whenever any hash-table lookup misses, and dead code suggests the implementation is incomplete.
  • The PR is marked as [Draft] and contains a critical runtime bug: missing keys produce index -1 which is passed to torch.index_select causing an out-of-bounds error. Additionally, the table_ranges variable computed in forward() is entirely unused, indicating the embedding routing logic may be incomplete. The broader refactor of the dynamicemb core (int64 table_ids, multi-table storage) is extensive and well-structured, but the demo/integration layer that ties it all together has correctness gaps that need to be resolved before merging.
  • examples/hstu/inference/test_export_demo.py requires the most attention — it contains the critical index-select crash and the dead table_ranges variable.

Important Files Changed

Filename Overview
examples/hstu/inference/test_export_demo.py New demo module for export-compatible inference embedding using INFERENCE_EMB custom ops. Contains three critical issues: (1) unfound hash-table keys produce -1 indices that are passed directly to torch.index_select, causing a runtime error; (2) the computed table_ranges variable is never used (dead code / possible incomplete impl); (3) the .so search loop stops at the first existing path even if loading fails, silently skipping fallback paths.
corelib/dynamicemb/dynamicemb/index_range_meta.py New file registering fake/meta kernels for INFERENCE_EMB::get_table_range and INFERENCE_EMB::expand_table_ids to support torch.export. Fake shapes correctly mirror the C++ implementation. Registration errors are surfaced as RuntimeWarnings rather than hard failures, which is appropriate for optional .so loading.
corelib/dynamicemb/dynamicemb/lookup_meta.py New file registering a fake/meta kernel for INFERENCE_EMB::table_lookup. Output shapes (score_out, founds, indices each of length n) correctly match the C++ kernel. Validation checks for mismatched key/table_ids lengths and score_input length are sound.
corelib/dynamicemb/CMakeLists.txt New CMake build file for inference_emb_ops.so shared library. Correctly auto-detects Torch cmake prefix, supports configurable CUDA architectures via SM variable, and uses TORCH_LIBRARY_FRAGMENT-based bindings (not pybind11). The referenced src/utils.cpp exists in the repository.
corelib/dynamicemb/src/table_operation/lookup_torch_binding.cu New C++ torch-dispatch binding for INFERENCE_EMB::table_lookup using TORCH_LIBRARY_FRAGMENT. Properly validates CUDA tensor requirements and provides a TORCH_CHECK-failing CPU fallback. Uses std::optional for optional overflow storage parameters.
corelib/dynamicemb/dynamicemb/types.py Refactored Storage/Cache/Counter abstract interfaces to add multi-table support via table_id parameters. Adds new CopyMode enum and export_keys_values iterator to Storage. Cache.lookup and insert_and_evict signatures updated; flush() removed from Cache. Changes look structurally sound.
corelib/dynamicemb/dynamicemb/scored_hashtable.py Major refactor: removes GroupedScoredHashTable and ScoreArg.is_return in favour of the unified multi-table approach, updates the ScoredHashTable interface to accept multi-table table_ids, and adds a get_scored_table factory. LinearBucketTable now handles batched multi-table logic.
corelib/dynamicemb/src/index_calculation.cu Replaces separate select/select_index ops with the new flagged_compact function that performs CUB stream compaction on up to 6 zipped 8-byte tensors simultaneously. Contains an intentional GPU→CPU sync (num_selected.cpu().item()) to obtain the count for output slicing. Pybind11 bindings guarded behind DEMB_USE_PYBIND11 macro.
corelib/dynamicemb/src/unique_op.cu Changes table_ids dtype throughout from int32 to int64 (kernel arguments, output allocation, TORCH_CHECK assertion). Also adds pybind11 guard macros. The sign-extension fix in pack_table_val (cast to int32_t before sign-extending to int64_t) is correct.
corelib/dynamicemb/setup.py Adds exclude_files parameter to find_source_files to omit the three new TORCH_LIBRARY_FRAGMENT binding files from the pybind11 extension build. Adds -DDEMB_USE_PYBIND11 compile flag. Newly added exclude_files=[] uses a mutable default argument (minor Python anti-pattern).
corelib/dynamicemb/dynamicemb/key_value_table.py Large refactor replacing flat-table APIs with multi-table APIs. Adds DynamicEmbStorage, DynamicEmbCache, HybridStorage classes. Replaces load_from_combined_table/store_to_combined_table with flat-table variants. Introduces flush_cache helper and load_from_flat_single_table utility.
corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py Major refactor to multi-table unified architecture. Replaces per-table _storages/_caches lists with single _storage/_cache objects. Adds HybridStorage fallback when data doesn't fit in HBM. Adds PrefetchState deque for pipeline management. Removes direct wildcard imports in favour of explicit ones.
corelib/dynamicemb/dynamicemb/embedding_admission.py KVCounter is now a plain data-holding config class (no Counter ABC). A new MultiTableKVCounter wraps a list of KVCounter configs into a single fused ScoredHashTable, providing add/erase/dump/load with table_id routing.
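The sign-extension fix in pack_table_val noted above (casting to int32_t before widening to int64_t) guards against a classic pitfall. The snippet below illustrates that class of bug in Python using ctypes; the names are illustrative and not from the PR:

```python
import ctypes

# A 32-bit field whose bit pattern represents int32 -1.
packed_low32 = 0xFFFFFFFF

# Widening without reinterpreting the field as signed zero-extends,
# silently turning -1 into 4294967295.
zero_extended = packed_low32

# Reinterpreting as a signed int32 first preserves the sign on widening,
# which is what the review says the pack_table_val fix does in C++.
sign_extended = ctypes.c_int32(packed_low32).value

print(zero_extended, sign_extended)  # → 4294967295 -1
```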

Sequence Diagram

sequenceDiagram
    participant App
    participant InferenceEmbeddingTable
    participant InferenceLinearBucketTable
    participant INFERENCE_EMB_ops as torch.ops.INFERENCE_EMB
    participant linear_mem_table as linear_mem_table (buffer)

    App->>InferenceEmbeddingTable: forward(indices, offsets)
    InferenceEmbeddingTable->>INFERENCE_EMB_ops: get_table_range(offsets, feature_offsets)
    Note over INFERENCE_EMB_ops: returns table_ranges (1D, num_tables+1)<br/>⚠️ result is never used
    InferenceEmbeddingTable->>INFERENCE_EMB_ops: expand_table_ids(offsets, table_offsets, ...)
    INFERENCE_EMB_ops-->>InferenceEmbeddingTable: table_ids (num_elements,)
    InferenceEmbeddingTable->>InferenceLinearBucketTable: lookup(keys, table_ids, score_policy)
    InferenceLinearBucketTable->>INFERENCE_EMB_ops: table_lookup(table_storage, bucket_offsets, ...)
    INFERENCE_EMB_ops-->>InferenceLinearBucketTable: score_out, founds, indices
    Note over InferenceLinearBucketTable: indices[i] == -1 when key not found ⚠️
    InferenceLinearBucketTable-->>InferenceEmbeddingTable: score_out, founds, table_indices
    InferenceEmbeddingTable->>linear_mem_table: torch.index_select(table, 0, table_indices)
    Note over linear_mem_table: ⚠️ CRASH if any table_indices[i] == -1
    linear_mem_table-->>InferenceEmbeddingTable: embeddings (num_elements, emb_dim)
    InferenceEmbeddingTable-->>App: embeddings

Last reviewed commit: f0fb78d

Comment on lines +280 to +290

# Step 2: Expand table IDs from offsets
# expand_table_ids(offsets, table_offsets, num_tables, local_batch_size, num_elements)
# Returns (num_elements,) int64 table_ids indicating which table each element belongs to
num_features = offsets.shape[0] - 1
num_elements = indices.shape[0]

# Prepare table_offsets_in_feature: where in feature space each table starts
table_offsets = torch.arange(
    num_features + 1, dtype=torch.int64, device=self.device
)

Invalid indices from unfound keys passed to torch.index_select

When the hash table lookup does not find a key, the CUDA kernel writes -1 into table_indices for that position (as confirmed by the kernel code in kernels.cuh which sets indices[i] = -1 on miss). The forward then calls:

embeddings = torch.index_select(self.linear_mem_table, 0, table_indices)

torch.index_select does not support negative indices — passing -1 will raise a RuntimeError: index -1 is out of bounds for dimension 0 with size N at runtime whenever any key is missing from the table.

The torch.where block below this line was written exactly to handle this case (zeroing out embeddings for unfound items), but it is commented out. At minimum, the indices should be clamped before the gather, and then the where-mask applied:

safe_indices = table_indices.clamp(min=0)
embeddings = torch.index_select(self.linear_mem_table, 0, safe_indices)
embeddings = torch.where(
    founds.unsqueeze(-1),
    embeddings,
    torch.zeros_like(embeddings),
)

Comment on lines +255 to +262
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
    offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
) -> torch.Tensor:

Computed table_ranges result is never used

table_ranges is computed by calling torch.ops.INFERENCE_EMB.get_table_range(...), but is never referenced again in the forward() method — it's dead code. This is either an incomplete implementation (the range was intended to be used for per-feature routing) or the variable can be removed entirely. Leaving it in place will produce a graph node in the exported model that wastes computation.

Suggested change (drop the unused table_ranges computation so forward() goes straight to the table-id expansion):

)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
    offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
) -> torch.Tensor:
    # Step 2: Expand table IDs from offsets

Comment on lines +46 to +55
if os.path.exists(_path):
    try:
        torch.ops.load_library(_path)
        print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
        _ops_loaded = True
    except Exception as _e:
        print(f"[WARN] Failed to load {_path}: {_e}")
    break  # stop after first found path, whether load succeeded or not

if not _ops_loaded:

Library loading stops at first found path even on failure

The break at line 55 exits the search loop after the first path where os.path.exists() returns True, regardless of whether torch.ops.load_library succeeds or fails. If the file exists but cannot be loaded (e.g., missing CUDA runtime symbols, wrong ABI), _ops_loaded remains False and none of the other fallback paths are tried.

Consider only break-ing when the load actually succeeds:

for _path in _SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.ops.load_library(_path)
            print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
            _ops_loaded = True
            break  # stop only on success
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
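The break-on-success pattern generalizes beyond library loading. A standalone sketch with a stubbed existence check and loader (all names hypothetical; the real code uses os.path.exists and torch.ops.load_library) shows that a path which exists but fails to load no longer blocks the fallbacks:

```python
# Hypothetical generic "try candidates in order, stop on first success" loop.
def load_first_working(paths, exists, load):
    for path in paths:
        if not exists(path):
            continue
        try:
            load(path)
            return path  # stop only once a load actually succeeds
        except Exception:
            pass         # file exists but is unloadable: keep searching
    return None

def fake_load(path):
    # Simulate a candidate that exists but cannot be loaded (e.g. wrong ABI).
    if path == "/a/bad.so":
        raise OSError("wrong ABI")

loaded = load_first_working(
    ["/missing.so", "/a/bad.so", "/b/good.so"],
    exists=lambda p: p != "/missing.so",
    load=fake_load,
)
print(loaded)  # → /b/good.so  (the bad candidate no longer ends the search)
```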

Comment on lines +255 to +260
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup

Misleading comment about get_table_range output shape

The inline comment says the output shape is (num_features, 2), but the actual C++ implementation (index_calculation.cu) returns at::empty_like(feature_offsets) — a 1D tensor of shape (num_tables + 1,), the same shape as feature_offsets. This is also confirmed by the fake kernel in index_range_meta.py which returns feature_offsets.new_empty(feature_offsets.shape).

Suggested change (correct the shape comment on the get_table_range call):

)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup

    table_ranges = torch.ops.INFERENCE_EMB.get_table_range(
        offsets, self.feature_offsets
    )  # (num_tables + 1,) – same shape as feature_offsets

Comment on lines 57 to +62

-def find_source_files(directory, extension_pattern, exclude_dirs=[]):
+def find_source_files(
+    directory,
+    extension_pattern,
+    exclude_dirs=[],
+    exclude_files=[],

Mutable default argument introduced for exclude_files

exclude_files=[] is a mutable default argument, a well-known Python footgun: the same list object is shared across all calls that rely on the default. While harmless in this particular setup.py usage (called at install time, never mutated), adopting None with an explicit fallback is the conventional safe pattern:

Suggested change

def find_source_files(
    directory,
    extension_pattern,
    exclude_dirs=None,
    exclude_files=None,
):
    if exclude_dirs is None:
        exclude_dirs = []
    if exclude_files is None:
        exclude_files = []
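A minimal reproduction of the footgun being fixed: the default list is created once at function definition time and shared across every call that relies on it, while the None-sentinel variant gets a fresh list each call. (Function names here are illustrative, not from setup.py.)

```python
# Mutable default: the same list object persists across calls.
def buggy(item, acc=[]):
    acc.append(item)
    return acc

# Conventional fix: a None sentinel plus an explicit fallback.
def safe(item, acc=None):
    if acc is None:
        acc = []
    acc.append(item)
    return acc

print(buggy("a"), buggy("b"))  # both names alias one list: ['a', 'b'] ['a', 'b']
print(safe("a"), safe("b"))    # fresh list per call: ['a'] ['b']
```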
