
[PR1] Port output distribution classes to DynamicEmb #319

Closed
ShaobinChen-AH wants to merge 8 commits into NVIDIA:main from
ShaobinChen-AH:fix-issue-296

Conversation

@ShaobinChen-AH
Contributor

Description

Checklist

@greptile-apps

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR ports RwSequenceEmbeddingDist and RwPooledEmbeddingDist output distribution classes to DynamicEmb, overrides create_output_dist() in both RwSequenceDynamicEmbeddingSharding and RwPooledDynamicEmbeddingSharding, and exports the new classes via __init__.py. The design enables future optimization of the unbucketize_permute operation for non-contiguous (round-robin) distribution patterns.
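The override pattern the summary describes can be sketched with stand-in classes. Only the class and method names below come from the PR itself; the bodies, parameters, and base class are illustrative placeholders, not the real torchrec/DynamicEmb code.

```python
class BaseEmbeddingDist:
    """Stand-in for torchrec's BaseEmbeddingDist base class."""


class RwSequenceEmbeddingDist(BaseEmbeddingDist):
    """Stand-in for the ported sequence output-distribution module."""

    def __init__(self, device=None):
        self.device = device


class RwSequenceDynamicEmbeddingSharding:
    """Stand-in sharding class illustrating the create_output_dist() override."""

    def create_output_dist(self, device=None) -> BaseEmbeddingDist:
        # Return the DynamicEmb-ported dist instead of the torchrec default,
        # so unbucketize_permute can later be optimized for round-robin
        # (non-contiguous) distribution patterns.
        return RwSequenceEmbeddingDist(device=device)
```

The same shape applies to the pooled pair (`RwPooledDynamicEmbeddingSharding` returning `RwPooledEmbeddingDist`); exporting both dist classes from `__init__.py` is what makes them replaceable public API symbols.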

Several issues flagged in the previous review round have been addressed:

  • Union and all required imports (BaseEmbeddingDist, EmbeddingShardingContext, SequenceShardingContext) are now present
  • Both new classes are added to __all__ in __init__.py
  • The previous assert sharding_ctx is not None guard has been replaced with an explicit ValueError
  • The _validate_sharding_ctx_consistency docstring is now in English
  • The _qcomm_codecs_registry dead attribute has been removed from RwPooledEmbeddingDist

The following items from earlier rounds remain unresolved:

  • Both forward methods are annotated -> torch.Tensor, but SequenceEmbeddingsAllToAll and PooledEmbeddingsReduceScatter both return Awaitable[torch.Tensor] — the annotation contradicts both the abstract base class contract and the actual runtime return value
  • _validate_sharding_ctx_consistency still lacks parameter and return type annotations in a # pyre-strict file
  • RwPooledDynamicEmbeddingSharding.create_output_dist retains an inconsistent 12-space parameter indent with a trailing comma, diverging from every other method in the class
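The return-annotation mismatch in the first bullet can be illustrated with a minimal, torch-free sketch. The `Awaitable` wrapper and `Tensor` placeholder below are stand-ins, not the real torchrec types.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class Awaitable(Generic[T]):
    """Stand-in for torchrec's Awaitable: holds a result until .wait()."""

    def __init__(self, value: T) -> None:
        self._value = value

    def wait(self) -> T:
        return self._value


class Tensor:
    """Placeholder for torch.Tensor."""


def forward_mismatched(local_embs: Tensor) -> Tensor:
    # Annotated as returning Tensor, but the runtime value is an
    # Awaitable[Tensor] -- exactly the contract violation a pyre-strict
    # checker (and downstream callers) would trip over.
    return Awaitable(local_embs)  # type: ignore[return-value]


def forward_consistent(local_embs: Tensor) -> Awaitable[Tensor]:
    # Annotation matches both the base-class contract and the actual value.
    return Awaitable(local_embs)
```

A caller typed against the mismatched signature would try to use the result as a tensor and never call `.wait()`, which is why the annotation is an API-contract bug even though it does not crash at runtime.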

Confidence Score: 3/5

  • The PR is functionally sound but has a few unresolved issues from prior review cycles that should be addressed before merging.
  • The core logic is correct and the main blocking issues (missing imports, missing __all__ entries, assert-based null guard) have been fixed. The remaining open items — incorrect forward return type annotations, missing type annotations on _validate_sharding_ctx_consistency, and formatting inconsistencies in rw_sharding.py — are not runtime bugs under typical usage but will cause Pyre type errors in a pyre-strict codebase and create incorrect API contracts for downstream consumers.
  • Pay attention to corelib/dynamicemb/dynamicemb/output_dist.py for the return type annotation mismatch on both forward methods, and to corelib/dynamicemb/dynamicemb/planner/rw_sharding.py for the formatting issues in RwPooledDynamicEmbeddingSharding.create_output_dist.

Important Files Changed

Filename Overview
corelib/dynamicemb/dynamicemb/output_dist.py New file adding RwSequenceEmbeddingDist and RwPooledEmbeddingDist. Key improvements over the initial PR: Union is properly imported, explicit ValueError replaces assert for null sharding_ctx, and _validate_sharding_ctx_consistency now uses English. However, both forward methods are still annotated as returning torch.Tensor instead of Awaitable[torch.Tensor], and _validate_sharding_ctx_consistency still lacks parameter/return type annotations in a pyre-strict file.
corelib/dynamicemb/dynamicemb/planner/rw_sharding.py Overrides create_output_dist() in both sharding classes. Required imports (BaseEmbeddingDist, EmbeddingShardingContext, SequenceShardingContext) have been correctly added. The RwPooledDynamicEmbeddingSharding.create_output_dist method still has inconsistent hanging-indent style and a trailing comma in its parameter list.
corelib/dynamicemb/dynamicemb/__init__.py Correctly imports RwSequenceEmbeddingDist and RwPooledEmbeddingDist and adds both to __all__, making them first-class public API symbols. No issues found.

Sequence Diagram

sequenceDiagram
    participant C as Caller
    participant RS as RwSequenceDynamicEmbeddingSharding
    participant RP as RwPooledDynamicEmbeddingSharding
    participant SD as RwSequenceEmbeddingDist
    participant PD as RwPooledEmbeddingDist
    participant A2A as SequenceEmbeddingsAllToAll
    participant RS2 as PooledEmbeddingsReduceScatter
    participant VB as VariableBatchPooledEmbeddingsReduceScatter

    C->>RS: create_output_dist(device)
    RS->>SD: RwSequenceEmbeddingDist(pg, num_features, device, codecs)
    SD->>A2A: SequenceEmbeddingsAllToAll(pg, features, device, codecs)
    RS-->>C: RwSequenceEmbeddingDist

    C->>RP: create_output_dist(device)
    RP->>PD: RwPooledEmbeddingDist(pg, embedding_dims, codecs)
    Note over PD: _dist=None (lazy init)
    RP-->>C: RwPooledEmbeddingDist

    C->>SD: forward(local_embs, sharding_ctx)
    SD->>A2A: forward(local_embs, lengths, splits, ...)
    A2A-->>C: Awaitable[Tensor]

    C->>PD: forward(local_embs, sharding_ctx)
    alt _dist is None
        PD->>PD: _create_output_dist_module(sharding_ctx)
        alt variable_batch_per_feature
            PD->>VB: create VariableBatchPooledEmbeddingsReduceScatter
        else
            PD->>RS2: create PooledEmbeddingsReduceScatter
        end
    end
    PD->>PD: _validate_sharding_ctx_consistency(sharding_ctx)
    alt sharding_ctx is None
        PD->>RS2: forward(local_embs)
    else variable_batch
        PD->>VB: forward(local_embs, batch_size_per_rank_per_feature, embedding_dims)
    else normal
        PD->>RS2: forward(local_embs, input_splits)
    end
    PD-->>C: Awaitable[Tensor]

Last reviewed commit: c5c82d5

Comment on lines +218 to +220
    def _validate_sharding_ctx_consistency(self, sharding_ctx):
        if self._dist_type is None:
            return

Unreachable early-return branch — dead code

self._dist_type starts as None in __init__, but _validate_sharding_ctx_consistency is only ever called from forward after _create_output_dist_module has already set _dist_type to either "variable_batch" or "normal". The if self._dist_type is None: return guard can therefore never be reached in practice and silently hides future maintenance errors (e.g. if someone calls the validator before the module is initialized).

Either remove the guard and let the string comparisons below behave correctly for any future None state, or add an assertion to make the invariant explicit:

    def _validate_sharding_ctx_consistency(self, sharding_ctx: Optional[EmbeddingShardingContext]) -> None:
        assert self._dist_type is not None, (
            "_validate_sharding_ctx_consistency called before _create_output_dist_module"
        )
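The invariant behind that suggestion can be demonstrated with a minimal stand-in class. The names mirror the PR's methods, but the bodies are illustrative, not the real `RwPooledEmbeddingDist`.

```python
class RwPooledEmbeddingDistSketch:
    """Stand-in illustrating the lazy-init invariant, not the real class."""

    def __init__(self) -> None:
        self._dist_type = None  # set lazily on first forward()

    def _create_output_dist_module(self, variable_batch: bool) -> None:
        self._dist_type = "variable_batch" if variable_batch else "normal"

    def _validate_sharding_ctx_consistency(self) -> None:
        # Explicit assertion instead of a silent early return: calling the
        # validator before lazy init is a programming error, so fail loudly.
        assert self._dist_type is not None, (
            "_validate_sharding_ctx_consistency called before "
            "_create_output_dist_module"
        )

    def forward(self, variable_batch: bool = False) -> str:
        if self._dist_type is None:
            self._create_output_dist_module(variable_batch)
        # By the time validation runs, _dist_type is always set, so the
        # original `if self._dist_type is None: return` guard is dead code.
        self._validate_sharding_ctx_consistency()
        return self._dist_type
```

Calling the validator directly on a fresh instance now fails immediately instead of silently returning, which is the maintenance safety the review comment is asking for.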

@z52527
Collaborator

z52527 commented Mar 11, 2026

Hi Shaobin, we have addressed this issue at #297.
The reason it wasn't merged is that we previously identified poor kernel performance in permute2D and sought to optimize it, but ultimately failed to reproduce the problem.
May I ask what you intended to follow up with regarding this PR?

@ShaobinChen-AH
Contributor Author


Thank you for the update. I wasn't aware that this issue had already been addressed. I'll go ahead and close this PR for now.

That said, I'd be very interested to see your ongoing efforts to optimize the kernel performance of permute2D. Please keep me posted if you make further progress — I'm happy to review or test any updates in the future.

Thanks again!
