[Draft] Feature: inference embedding with C++ export #324
geoffreyQiu wants to merge 8 commits into NVIDIA:main from
Conversation
Greptile Summary

This draft PR introduces C++ export support for inference-only embedding lookup in

Key changes:
Issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant App
    participant InferenceEmbeddingTable
    participant InferenceLinearBucketTable
    participant INFERENCE_EMB_ops as torch.ops.INFERENCE_EMB
    participant linear_mem_table as linear_mem_table (buffer)
    App->>InferenceEmbeddingTable: forward(indices, offsets)
    InferenceEmbeddingTable->>INFERENCE_EMB_ops: get_table_range(offsets, feature_offsets)
    Note over INFERENCE_EMB_ops: returns table_ranges (1D, num_tables+1)<br/>⚠️ result is never used
    InferenceEmbeddingTable->>INFERENCE_EMB_ops: expand_table_ids(offsets, table_offsets, ...)
    INFERENCE_EMB_ops-->>InferenceEmbeddingTable: table_ids (num_elements,)
    InferenceEmbeddingTable->>InferenceLinearBucketTable: lookup(keys, table_ids, score_policy)
    InferenceLinearBucketTable->>INFERENCE_EMB_ops: table_lookup(table_storage, bucket_offsets, ...)
    INFERENCE_EMB_ops-->>InferenceLinearBucketTable: score_out, founds, indices
    Note over InferenceLinearBucketTable: indices[i] == -1 when key not found ⚠️
    InferenceLinearBucketTable-->>InferenceEmbeddingTable: score_out, founds, table_indices
    InferenceEmbeddingTable->>linear_mem_table: torch.index_select(table, 0, table_indices)
    Note over linear_mem_table: ⚠️ CRASH if any table_indices[i] == -1
    linear_mem_table-->>InferenceEmbeddingTable: embeddings (num_elements, emb_dim)
    InferenceEmbeddingTable-->>App: embeddings
```

Last reviewed commit: f0fb78d
```python
# Step 2: Expand table IDs from offsets
# expand_table_ids(offsets, table_offsets, num_tables, local_batch_size, num_elements)
# Returns (num_elements,) int64 table_ids indicating which table each element belongs to
num_features = offsets.shape[0] - 1
num_elements = indices.shape[0]

# Prepare table_offsets_in_feature: where in feature space each table starts
table_offsets = torch.arange(
    num_features + 1, dtype=torch.int64, device=self.device
)
```
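The offsets-to-table-ids expansion that `expand_table_ids` performs can be sketched in plain Python. This is an illustrative reimplementation of the contract documented in the comments above (shapes and names follow the diff), not the actual CUDA op:

```python
def expand_table_ids(offsets, table_offsets):
    """Assign each element its table id.

    offsets: length num_features + 1; positions offsets[f]..offsets[f+1]
    belong to feature f.
    table_offsets: length num_tables + 1; table t owns features
    table_offsets[t]..table_offsets[t+1].
    Returns a list of length num_elements with the table id per element.
    """
    num_tables = len(table_offsets) - 1
    table_ids = []
    for t in range(num_tables):
        for f in range(table_offsets[t], table_offsets[t + 1]):
            table_ids.extend([t] * (offsets[f + 1] - offsets[f]))
    return table_ids

# Two features in table 0, one feature in table 1; 5 elements total.
print(expand_table_ids([0, 2, 3, 5], [0, 2, 3]))  # [0, 0, 0, 1, 1]
```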
Invalid indices from unfound keys passed to `torch.index_select`

When the hash table lookup does not find a key, the CUDA kernel writes -1 into `table_indices` for that position (as confirmed by the kernel code in kernels.cuh, which sets `indices[i] = -1` on a miss). The forward then calls:

```python
embeddings = torch.index_select(self.linear_mem_table, 0, table_indices)
```

`torch.index_select` does not support negative indices: passing -1 will raise `RuntimeError: index -1 is out of bounds for dimension 0 with size N` at runtime whenever any key is missing from the table.

The `torch.where` block below this line was written exactly to handle this case (zeroing out embeddings for unfound items), but it is commented out. At minimum, the indices should be clamped before the gather, and then the where-mask applied:

```python
safe_indices = table_indices.clamp(min=0)
embeddings = torch.index_select(self.linear_mem_table, 0, safe_indices)
embeddings = torch.where(
    founds.unsqueeze(-1),
    embeddings,
    torch.zeros_like(embeddings),
)
```

```python
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
    offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
) -> torch.Tensor:
```
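The clamp-then-mask pattern recommended for the lookup-miss case can be exercised without PyTorch or the CUDA op; a stdlib-only sketch (the helper name and toy table are illustrative, with -1 as the miss sentinel as in the kernel):

```python
def safe_gather(table, indices, fill=0.0):
    """Gather rows of `table`; positions with index -1 (lookup miss)
    yield a zero-filled row instead of an out-of-bounds access."""
    founds = [i >= 0 for i in indices]       # mask of found keys
    safe = [max(i, 0) for i in indices]      # clamp(min=0) before the gather
    rows = [table[i] for i in safe]          # the gather itself
    dim = len(table[0])
    return [row if ok else [fill] * dim      # where(founds, rows, zeros)
            for row, ok in zip(rows, founds)]

table = [[1.0, 2.0], [3.0, 4.0]]
print(safe_gather(table, [1, -1, 0]))
# [[3.0, 4.0], [0.0, 0.0], [1.0, 2.0]]
```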
Computed table_ranges result is never used
`table_ranges` is computed by calling `torch.ops.INFERENCE_EMB.get_table_range(...)` but is never referenced again in the `forward()` method; it is dead code. This is either an incomplete implementation (the range was intended to be used for per-feature routing) or the variable can be removed entirely. Leaving it in place will produce a graph node in the exported model that wastes computation.
Suggested change:

```python
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
    offsets: torch.Tensor,  # (num_features + 1,) batch offsets for pooling
) -> torch.Tensor:
    # Step 2: Expand table IDs from offsets
```
```python
if os.path.exists(_path):
    try:
        torch.ops.load_library(_path)
        print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
        _ops_loaded = True
    except Exception as _e:
        print(f"[WARN] Failed to load {_path}: {_e}")
    break  # stop after first found path, whether load succeeded or not

if not _ops_loaded:
```
Library loading stops at first found path even on failure

The `break` at line 55 exits the search loop after the first path where `os.path.exists()` returns True, regardless of whether `torch.ops.load_library` succeeds or fails. If the file exists but cannot be loaded (e.g., missing CUDA runtime symbols, wrong ABI), `_ops_loaded` remains False and none of the other fallback paths are tried.

Consider only `break`-ing when the load actually succeeds:

```python
for _path in _SEARCH_PATHS:
    if os.path.exists(_path):
        try:
            torch.ops.load_library(_path)
            print(f"[INFO] Loaded inference_emb_ops.so from {_path}")
            _ops_loaded = True
            break  # stop only on success
        except Exception as _e:
            print(f"[WARN] Failed to load {_path}: {_e}")
```

```python
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
```
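The break-only-on-success loop generalizes to any ordered fallback search. A stdlib-only sketch with a stubbed loader (the function, paths, and error here are all hypothetical, standing in for `torch.ops.load_library` and the `.so` search paths):

```python
def load_first_working(paths, load, exists):
    """Try each candidate path in order; stop at the first that loads
    successfully, falling through to later paths when an existing
    file fails to load (e.g. ABI mismatch, missing symbols)."""
    for path in paths:
        if not exists(path):
            continue
        try:
            load(path)
            return path  # break only on success
        except Exception as e:
            print(f"[WARN] failed to load {path}: {e}")
    return None

# Stub: the first existing path fails to load; the second succeeds.
def fake_load(path):
    if path == "/a/ops.so":
        raise OSError("undefined symbol")

print(load_first_working(["/a/ops.so", "/b/ops.so"],
                         fake_load, lambda p: True))  # /b/ops.so
```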
Misleading comment about get_table_range output shape
The inline comment says the output shape is (num_features, 2), but the actual C++ implementation (index_calculation.cu) returns `at::empty_like(feature_offsets)`: a 1D tensor of shape (num_tables + 1,), the same shape as `feature_offsets`. This is also confirmed by the fake kernel in index_range_meta.py, which returns `feature_offsets.new_empty(feature_offsets.shape)`.
Suggested change:

```python
)
self.register_buffer("linear_mem_table", linear_mem_table)

def forward(
    self,
    indices: torch.Tensor,  # (batch_size,) indices to lookup
    table_ranges = torch.ops.INFERENCE_EMB.get_table_range(
        offsets, self.feature_offsets
    )  # (num_tables + 1,) – same shape as feature_offsets
```
```diff
-def find_source_files(directory, extension_pattern, exclude_dirs=[]):
+def find_source_files(
+    directory,
+    extension_pattern,
+    exclude_dirs=[],
+    exclude_files=[],
```
Mutable default argument introduced for `exclude_files`

`exclude_files=[]` is a mutable default argument, a well-known Python footgun: the same list object is shared across all calls that rely on the default. While harmless in this particular setup.py usage (called at install time, never mutated), adopting `None` with an explicit fallback is the conventional safe pattern:

```python
def find_source_files(
    directory,
    extension_pattern,
    exclude_dirs=None,
    exclude_files=None,
):
    if exclude_dirs is None:
        exclude_dirs = []
    if exclude_files is None:
        exclude_files = []
```
Add C++ export for inference embedding, based on:
dynamicemb