Optimize balancer and set up debug logger #308
Conversation
Greptile Summary

This PR delivers a significant refactor of the HSTU batch balancer pipeline: the Karmarkar-Karp (KK) algorithm is moved to a background thread.

Key changes and notable findings:
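For context, the Karmarkar-Karp (largest differencing) heuristic that the PR moves off the critical path can be sketched with stdlib `heapq`. This is an illustrative two-way sketch under assumed inputs, not the PR's implementation (which partitions workloads across DP ranks, i.e. k-way):

```python
import heapq

def kk_two_way(items):
    """Karmarkar-Karp (largest differencing) heuristic, 2-way sketch.

    Heap entries are (-diff, side_a, side_b): placing side_a and side_b on
    opposite partitions yields an imbalance of `diff`. Repeatedly cancel the
    two largest imbalances against each other.
    """
    heap = [(-v, [v], []) for v in items]
    heapq.heapify(heap)
    while len(heap) > 1:
        d1, a1, b1 = heapq.heappop(heap)  # largest imbalance (most negative)
        d2, a2, b2 = heapq.heappop(heap)  # second largest
        # Put the heavy side of one entry with the light side of the other.
        heapq.heappush(heap, (d1 - d2, a1 + b2, b1 + a2))
    _, part_a, part_b = heap[0]
    return part_a, part_b

parts = kk_two_way([8, 7, 6, 5, 4])  # final imbalance is 2 for this input
```

Because KK only touches CPU data structures like this heap, it can run in a background thread while the GPU does forward/backward, which is exactly the overlap the sequence diagram below shows.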
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant DL as DataLoader
    participant Pipe as TrainPipeline (main thread)
    participant GPU as GPU (_memcpy_stream)
    participant KK as KK Thread (background)
    participant NCCL as NCCL (DP group)
    Note over Pipe,NCCL: Iteration N
    Pipe->>Pipe: start_sparse_data_dist(batch[N+1])
    Pipe->>GPU: H2D copy batch[N+2]
    Pipe->>Pipe: wait_sparse_data_dist(batch[N+1])
    Note over Pipe,NCCL: Phase 1 – on _memcpy_stream
    GPU->>NCCL: AllGather workloads (batch[N+2])
    NCCL-->>GPU: allgather_workloads_cpu (D2H)
    GPU->>KK: submit KK(allgather_workloads_cpu)
    Note over Pipe,GPU: Forward / Backward (overlaps KK)
    Pipe->>GPU: forward(batch[N])
    Pipe->>GPU: backward + optimizer step
    Note over Pipe,NCCL: Phase 2 – on _memcpy_stream
    KK-->>GPU: partitions_indices (future.result())
    GPU->>NCCL: AllGather batch[N+2] (or All2All)
    NCCL-->>GPU: redistributed batch[N+2]
    GPU->>GPU: index_select / _strip_dense_padding
    Pipe->>NCCL: wait_stream(_memcpy_stream) → loss AllReduce
    Pipe->>NCCL: DDP AllReduce gradients
```
Force-pushed 8b68ea4 to 2e3a262 (compare)

Force-pushed 6d3e1a6 to 5a0eff0 (compare)

Force-pushed 90abbe0 to a0019e0 (compare)
Review thread on examples/hstu/training/benchmark/run_balanced_pipeline_comparison.sh (outdated, resolved)
Force-pushed 709bb20 to bf3ee02 (compare)
…ecv_ids.numel()

all2all_batch was incorrectly overwriting batch_size with recv_ids.numel() and setting actual_batch_size = batch_size. Since _apply_to_tensors_or_kjt already copies batch_size from the source batch, no reassignment is needed; actual_batch_size is now set to recv_ids.numel() directly.

Also adds test_all2all_vs_allgather_incomplete_batch with per-rank random actual_batch_size to verify data correctness under heterogeneous incomplete batches.

Made-with: Cursor
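The bookkeeping this commit fixes can be sketched with plain dicts standing in for batches; the helper name `finish_all2all_batch` is hypothetical, and the real logic lives in the PR's All2All code:

```python
def finish_all2all_batch(batch, recv_ids):
    """Sketch of the corrected bookkeeping (hypothetical helper).

    batch_size was already copied from the source batch by the
    tensor-mapping helper (_apply_to_tensors_or_kjt in the PR), so it must
    not be overwritten here; only actual_batch_size is derived from the
    received ids (recv_ids.numel() in the real tensor code).
    """
    batch["actual_batch_size"] = len(recv_ids)  # real (non-padding) samples
    return batch

# padded batch of 8 slots, of which 5 are real samples on this rank
padded = {"batch_size": 8, "actual_batch_size": None}
fixed = finish_all2all_batch(padded, recv_ids=[0, 1, 2, 3, 4])
```

Before the fix, `batch_size` would have become 5 here as well, breaking the "KJT dim includes padding" invariant documented later in this PR.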
…l path

Read SHUFFLE_WITH_ALL2ALL from the environment at module load. When set to "1", finish_shuffle and shuffle use the All2All communication path regardless of the use_all2all parameter (OR logic).

Made-with: Cursor
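A minimal sketch of the module-load read and the OR logic described above; the module body and return values here are illustrative, only the env var name and `use_all2all` parameter come from the commit:

```python
import os

# Read once at module import time, as the commit describes.
SHUFFLE_WITH_ALL2ALL = os.environ.get("SHUFFLE_WITH_ALL2ALL", "0") == "1"

def shuffle(batch, use_all2all: bool = False):
    # OR logic: the env var can force the All2All path even when the
    # caller passed use_all2all=False, but it never disables it.
    if use_all2all or SHUFFLE_WITH_ALL2ALL:
        return "all2all", batch
    return "allgather", batch
```

Reading the flag at import time means it is fixed for the process lifetime; changing the env var after import has no effect, which keeps all in-flight batches on one consistent communication path.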
…dst_rank
- pad_and_allgather_batch: zero-pad dense tensors to batch_size before
AllGather; optionally return is_padding mask via return_padding_flag
- pad_and_all2all_batch: zero-pad dense tensors before All2All; make
dst_rank and recv_counts required params; remove _build_dst_rank
(was the only consumer of an extra all_to_all_single call)
- _strip_dense_padding: new module-level utility to truncate dense
tensors to actual_batch_size rows while keeping KJTs intact
- finish_shuffle / shuffle: apply _strip_dense_padding when padding
  is present so the output batch satisfies the 4 invariants:
  - batch_size = KJT dim (includes padding)
  - actual_batch_size = dense dim (real samples only)
  - dense tensor dim-0 = actual_batch_size
  - KJT lengths count = batch_size * num_keys
- BaseBatch: update docstring and field comments to document invariants;
index_select_dense_tensor uses actual_batch_size for reshape
- Tests: add _strip_dense_padding verification to allgather and all2all
tests; merge incomplete batch test into test_all2all_vs_allgather;
use _build_dst_rank_local to pre-compute dst_rank/recv_counts
Made-with: Cursor
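The four invariants from this commit can be illustrated with plain lists standing in for torch tensors and KJTs; the dict layout and `strip_dense_padding` spelling are assumptions for the sketch, not the PR's actual types:

```python
def strip_dense_padding(batch):
    """Truncate dense fields to actual_batch_size rows; KJT fields untouched."""
    n = batch["actual_batch_size"]
    for key in batch["dense"]:
        batch["dense"][key] = batch["dense"][key][:n]
    return batch

# batch_size=4 with 2 padding rows, and a KJT with 3 keys
batch = {
    "batch_size": 4,
    "actual_batch_size": 2,
    "dense": {"labels": [1, 0, 0, 0]},
    # one length per (key, sample); padding samples have zero-length sequences
    "kjt_lengths": [2, 1, 0, 0,   3, 2, 0, 0,   1, 1, 0, 0],
}
batch = strip_dense_padding(batch)

# The 4 invariants from the commit message:
assert batch["batch_size"] == 4                    # KJT dim (includes padding)
assert batch["actual_batch_size"] == 2             # real samples only
assert len(batch["dense"]["labels"]) == 2          # dense dim-0 = actual_batch_size
assert len(batch["kjt_lengths"]) == 4 * 3          # batch_size * num_keys
```

Keeping the KJTs at the padded size is what lets every rank issue identically-shaped collectives, while the dense truncation avoids computing loss on padding rows.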
… after shuffle

After the balanced shuffler runs, _strip_dense_padding trims dense tensors to actual_batch_size rows while KJTs retain batch_size entries (padding samples have zero-length sequences). This causes a dimension mismatch in _get_high_inds when max_num_candidates > 0. Re-pad num_candidates with zeros to align with the KJT-derived sequence_embeddings_lengths.

Made-with: Cursor
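The re-padding itself is a one-liner; this sketch uses plain lists and a hypothetical helper name (the PR applies the fix inline near _get_high_inds):

```python
def repad_num_candidates(num_candidates, batch_size):
    """Zero-pad per-sample candidate counts back to batch_size.

    After _strip_dense_padding, dense-derived fields have
    actual_batch_size entries while KJT-derived per-sample lengths have
    batch_size entries; padding samples have zero-length sequences, so
    zero candidates is the correct fill value.
    """
    return num_candidates + [0] * (batch_size - len(num_candidates))
```

With a complete batch the pad count is zero and the call is a no-op, so the fix is safe on both code paths.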
Made-with: Cursor
- Remove duplicate FLOP counting in FusedHSTULayerFunction (keep module hooks only)
- Remove accidental pdb.set_trace() in test_dataset.py
- Convert _batch_counter from class variable to instance variable
- Fix off-by-one in async path print window check by passing handle idx
- Remove unused linear_dim parameter from _compute_attn_fwd_flops
- Add batch_size and values dtype validation in KJT list allgather
- Add NOTE comment for actual_batch_size in pad_and_allgather_batch

Made-with: Cursor
When world_size == 1 and actual_batch_size < batch_size (incomplete batch), pad_and_all2all_batch previously returned the original batch without padding dense tensors. This caused shape mismatches with pad_and_allgather_batch, which always pads. Unify both functions to follow the same world_size == 1 contract:

- Complete batch: return as-is (no-op)
- Incomplete batch: pad dense tensors, update actual_batch_size, skip all collective communication
- KJT fields returned unchanged in both cases

Document the fast-path behaviour in both docstrings.

Made-with: Cursor
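The unified world_size == 1 contract can be sketched with lists standing in for dense tensors; the function name `local_fast_path` is hypothetical:

```python
def local_fast_path(dense_rows, batch_size):
    """Sketch of the unified world_size == 1 contract.

    Complete batch  -> return as-is (no-op).
    Incomplete batch -> zero-pad dense rows to batch_size and report the
    real sample count as actual_batch_size; no collectives either way.
    """
    actual = len(dense_rows)
    if actual == batch_size:
        return dense_rows, actual               # complete: no-op
    padded = dense_rows + [0] * (batch_size - actual)  # zero-pad dense
    return padded, actual
```

Making both pad_* functions follow this contract means downstream code (e.g. _strip_dense_padding and the shape checks in tests) never needs to special-case single-rank runs.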
Force-pushed bf3ee02 to f92e576 (compare)
Description
This PR optimizes the HSTU batch balancer pipeline to fully hide the KK (Karmarkar-Karp) algorithm overhead and reduce AllGather communication cost.
Core Changes
Async 2-phase KK pipeline
- `shuffle()` is split into `start_shuffle_async()` (Phase 1) + `finish_shuffle()` (Phase 2).
- KK runs in a `ThreadPoolExecutor` (pure CPU, no GPU access).
- `ShuffleHandle` provides stable per-batch tracking for the prefetch pipeline (multiple in-flight batches).

All2All shuffle path (`batch_all2all.py`)

- Enabled via the `SHUFFLE_WITH_ALL2ALL=1` env var.
- Local `dst_rank` derivation (no extra communication when partitions are symmetric).

Fused multi-KJT AllGather (`collective_ops.py`)

- `keyed_jagged_tensor_list_allgather`: gathers all KJTs with only 2 NCCL calls (1 for lengths, 1 for values), regardless of KJT count.

Pipeline integration (`train_pipeline.py`)

- `JaggedMegatronTrainPipelineSparseDist` and `JaggedMegatronPrefetchTrainPipelineSparseDist` updated with the 2-phase async API.
- Stream synchronization (`wait_stream`) prevents deadlocks between shuffle and DDP/loss all-reduces.

Performance & Debug Utilities

- Per-layer perf tracking (`attn_perf_tracker.py`, `hooks.py`): CUDA-event-based per-layer TFLOPS/MFU logging via module hooks (`PRINT_HSTU_PERF=1`).
- `perf.py`: `DeviceSpec` auto-discovery, peak TFLOPS calculation, MFU reporting.
- Load-balance logging: `PRINT_LOAD_BALANCE=1` with start/stop window.
- `UNIFORM`, `NORMAL`, `ZIPF`, `LOGNORMAL` seqlen distributions for `HSTURandom*Dataset`.

Bug Fixes (from review)

- Removed accidental `pdb.set_trace()` left in `test_dataset.py`.
- Removed duplicate FLOP counting in `FusedHSTULayerFunction`, keeping only module-level hooks.
- Converted `_batch_counter` from class variable to instance variable.
- Removed unused `linear_dim` parameter from `_compute_attn_fwd_flops`.
- Added `actual_batch_size` documentation note in `pad_and_allgather_batch`.

CI
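The 2-phase API described in this PR can be sketched with `concurrent.futures`; the class bodies below are illustrative stand-ins (only the names `start_shuffle_async`, `finish_shuffle`, `ShuffleHandle`, and the ThreadPoolExecutor/KK split come from the PR):

```python
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class ShuffleHandle:
    """Stable per-batch tracking token for in-flight shuffles."""
    batch_idx: int
    kk_future: Future = None

class Shuffler:
    def __init__(self):
        # Single background worker: KK is pure CPU, no GPU access.
        self._pool = ThreadPoolExecutor(max_workers=1)

    def start_shuffle_async(self, batch_idx, workloads):
        # Phase 1: in the PR this is where workloads are AllGathered on
        # _memcpy_stream; here we just submit KK to the background thread.
        handle = ShuffleHandle(batch_idx)
        handle.kk_future = self._pool.submit(self._kk, workloads)
        return handle

    def finish_shuffle(self, handle):
        # Phase 2: block on the KK result (already done if forward/backward
        # took longer), then redistribute the batch accordingly.
        return handle.kk_future.result()

    @staticmethod
    def _kk(workloads):
        # Placeholder for Karmarkar-Karp partitioning: here we merely rank
        # samples by descending workload to stand in for its output indices.
        return sorted(range(len(workloads)), key=lambda i: -workloads[i])
```

Because each in-flight batch carries its own `ShuffleHandle`, the prefetch pipeline can have several Phase 1 submissions outstanding and finish them in order, which is what "stable per-batch tracking" above refers to.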