[PR1] Port output distribution classes to DynamicEmb#319
ShaobinChen-AH wants to merge 8 commits into NVIDIA:main from
Conversation
Greptile Summary

This PR ports output distribution classes to DynamicEmb. Several issues flagged in the previous review round have been addressed:
The following items from earlier rounds remain unresolved:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant C as Caller
    participant RS as RwSequenceDynamicEmbeddingSharding
    participant RP as RwPooledDynamicEmbeddingSharding
    participant SD as RwSequenceEmbeddingDist
    participant PD as RwPooledEmbeddingDist
    participant A2A as SequenceEmbeddingsAllToAll
    participant RS2 as PooledEmbeddingsReduceScatter
    participant VB as VariableBatchPooledEmbeddingsReduceScatter
    C->>RS: create_output_dist(device)
    RS->>SD: RwSequenceEmbeddingDist(pg, num_features, device, codecs)
    SD->>A2A: SequenceEmbeddingsAllToAll(pg, features, device, codecs)
    RS-->>C: RwSequenceEmbeddingDist
    C->>RP: create_output_dist(device)
    RP->>PD: RwPooledEmbeddingDist(pg, embedding_dims, codecs)
    Note over PD: _dist=None (lazy init)
    RP-->>C: RwPooledEmbeddingDist
    C->>SD: forward(local_embs, sharding_ctx)
    SD->>A2A: forward(local_embs, lengths, splits, ...)
    A2A-->>C: Awaitable[Tensor]
    C->>PD: forward(local_embs, sharding_ctx)
    alt _dist is None
        PD->>PD: _create_output_dist_module(sharding_ctx)
        alt variable_batch_per_feature
            PD->>VB: create VariableBatchPooledEmbeddingsReduceScatter
        else normal
            PD->>RS2: create PooledEmbeddingsReduceScatter
        end
    end
    PD->>PD: _validate_sharding_ctx_consistency(sharding_ctx)
    alt sharding_ctx is None
        PD->>RS2: forward(local_embs)
    else variable_batch
        PD->>VB: forward(local_embs, batch_size_per_rank_per_feature, embedding_dims)
    else normal
        PD->>RS2: forward(local_embs, input_splits)
    end
    PD-->>C: Awaitable[Tensor]
```
Last reviewed commit: c5c82d5
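The pooled path in the diagram builds its reduce-scatter module lazily on the first forward call. A minimal runnable sketch of that dispatch, using stub classes in place of TorchRec's reduce-scatter modules (all class names and bodies here are illustrative stand-ins, not the actual implementation):

```python
# Stubs standing in for PooledEmbeddingsReduceScatter and
# VariableBatchPooledEmbeddingsReduceScatter; the real modules
# return an Awaitable[Tensor] from forward().
class _ReduceScatterStub:
    def forward(self, local_embs):
        return local_embs


class _VariableBatchReduceScatterStub:
    def forward(self, local_embs):
        return local_embs


class RwPooledEmbeddingDistSketch:
    def __init__(self):
        self._dist = None        # lazy: built on first forward
        self._dist_type = None   # "variable_batch" or "normal" once built

    def _create_output_dist_module(self, sharding_ctx):
        # Choose the dist module based on the sharding context,
        # mirroring the alt branches in the sequence diagram.
        if sharding_ctx is not None and getattr(
            sharding_ctx, "variable_batch_per_feature", False
        ):
            self._dist = _VariableBatchReduceScatterStub()
            self._dist_type = "variable_batch"
        else:
            self._dist = _ReduceScatterStub()
            self._dist_type = "normal"

    def forward(self, local_embs, sharding_ctx=None):
        if self._dist is None:
            self._create_output_dist_module(sharding_ctx)
        return self._dist.forward(local_embs)
```

The point of the sketch is the ordering: `_dist_type` is always set by `_create_output_dist_module` before any later consistency check runs, which is what the review comment below relies on.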
Force-pushed from d3e469b to df835b9
```python
def _validate_sharding_ctx_consistency(self, sharding_ctx):
    if self._dist_type is None:
        return
```
Unreachable early-return branch (dead code)
self._dist_type starts as None in __init__, but _validate_sharding_ctx_consistency is only ever called from forward after _create_output_dist_module has already set _dist_type to either "variable_batch" or "normal". The if self._dist_type is None: return guard can therefore never be reached in practice and silently hides future maintenance errors (e.g. if someone calls the validator before the module is initialized).
Either remove the guard and let the string comparisons below behave correctly for any future None state, or add an assertion to make the invariant explicit:
```python
def _validate_sharding_ctx_consistency(
    self, sharding_ctx: Optional[EmbeddingShardingContext]
) -> None:
    assert self._dist_type is not None, (
        "_validate_sharding_ctx_consistency called before _create_output_dist_module"
    )
```
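For illustration, a hypothetical minimal repro showing the proposed assertion failing when the validator runs before `_create_output_dist_module` (the class here is a stub that only models the `_dist_type` attribute, not the real module):

```python
# Stub modeling only the invariant under discussion: _dist_type is None
# until _create_output_dist_module runs, and the validator must not be
# called before that.
class DistStub:
    def __init__(self):
        self._dist_type = None  # set later by _create_output_dist_module

    def _validate_sharding_ctx_consistency(self, sharding_ctx):
        assert self._dist_type is not None, (
            "_validate_sharding_ctx_consistency called before "
            "_create_output_dist_module"
        )


d = DistStub()
try:
    d._validate_sharding_ctx_consistency(None)
    raised = False
except AssertionError:
    raised = True
print(raised)  # True: the misuse fails loudly instead of silently returning
```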
Hi Shaobin, we have addressed this issue at #297.
Thank you for the update. I wasn't aware that this issue had already been addressed, so I'll go ahead and close this PR for now. That said, I'd be very interested to see your ongoing efforts to optimize the kernel performance of permute2D. Please keep me posted if you make further progress; I'm happy to review or test any updates in the future. Thanks again!
Description
Checklist
I am familiar with the Contributing Guidelines.
New or existing tests cover these changes.
The documentation is up to date with these changes.
- Add RwSequenceEmbeddingDist and RwPooledEmbeddingDist classes in output_dist.py
- Override create_output_dist() in RwSequenceDynamicEmbeddingSharding and RwPooledDynamicEmbeddingSharding
- Enable future optimization for the unbucketize_permute operation in row-wise sharding
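The override described in the bullets above can be sketched as follows. This is a shape-only illustration with stub constructors; the real classes take a process group, feature metadata, device, and codecs, so every body here is an assumption:

```python
# Stub mirroring the ported output-dist class; the real class wraps
# SequenceEmbeddingsAllToAll and takes (pg, num_features, device, codecs).
class RwSequenceEmbeddingDist:
    def __init__(self, device=None):
        self._device = device


# Stub sharding class showing only the create_output_dist override
# the description refers to.
class RwSequenceDynamicEmbeddingSharding:
    def create_output_dist(self, device=None):
        # Return the ported dist class instead of the base implementation.
        return RwSequenceEmbeddingDist(device)
```

A caller then obtains the ported dist transparently through the usual factory method, which is what lets the row-wise unbucketize_permute optimization slot in later without changing call sites.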
Verified with test_sequence_embedding_fw.py (100 iterations passed)
Related: [FEA] Slow unbucketize permute operation in SequenceEmbeddingsAllToAll for row-wise sharding #296