
Optimize balancer and setup debug logger.#308

Merged
shijieliu merged 22 commits into NVIDIA:main from JacoCheung:junzhang/opt_balancer
Mar 13, 2026

Conversation


@JacoCheung JacoCheung commented Feb 11, 2026

Description

This PR optimizes the HSTU batch balancer pipeline to fully hide the KK (Karmarkar-Karp) algorithm overhead and reduce AllGather communication cost.

Core Changes

Async 2-phase KK pipeline

  • Split shuffle() into start_shuffle_async() (Phase 1) + finish_shuffle() (Phase 2).
  • Phase 1 runs AllGather workloads (NCCL) on main thread, then submits KK partitioning to a background ThreadPoolExecutor (pure CPU, no GPU access).
  • Phase 2 waits for KK result, then redistributes batch data (AllGather or All2All).
  • KK computation fully overlaps with forward/backward on the GPU.
  • ShuffleHandle provides stable per-batch tracking for the prefetch pipeline (multiple in-flight batches).
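The 2-phase split above can be sketched in a few lines. This is a hypothetical, self-contained illustration: the real balancer runs the Karmarkar-Karp algorithm, so a greedy longest-processing-time partition stands in here to keep the example runnable; names like `_partition` are illustrative, not the PR's API.

```python
import heapq
from concurrent.futures import Future, ThreadPoolExecutor

_kk_executor = ThreadPoolExecutor(max_workers=1)

def _partition(workloads, num_ranks):
    # Pure-CPU work, safe to run off the main thread (no GPU access).
    # Greedy LPT stand-in for Karmarkar-Karp: heaviest item to lightest bin.
    heap = [(0, r) for r in range(num_ranks)]
    heapq.heapify(heap)
    parts = [[] for _ in range(num_ranks)]
    for idx in sorted(range(len(workloads)), key=lambda i: -workloads[i]):
        load, r = heapq.heappop(heap)
        parts[r].append(idx)
        heapq.heappush(heap, (load + workloads[idx], r))
    return parts

def start_shuffle_async(workloads, num_ranks) -> Future:
    # Phase 1: after the workloads AllGather, hand partitioning to the
    # background thread and return a handle the pipeline can hold per batch.
    return _kk_executor.submit(_partition, workloads, num_ranks)

def finish_shuffle(handle: Future):
    # Phase 2: block on the partition result, then redistribute the batch.
    return handle.result()
```

Because Phase 1 returns immediately, the main thread is free to launch forward/backward while the partitioning runs, which is the overlap the PR exploits.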

All2All shuffle path (batch_all2all.py)

  • New O(B) communication alternative to AllGather+index_select.
  • Toggled via SHUFFLE_WITH_ALL2ALL=1 env var.
  • Local dst_rank derivation (no extra communication when partitions are symmetric).
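The "local dst_rank derivation" works because every rank holds the same partition result after the workloads AllGather, so each rank can compute every sample's destination without another collective. A minimal sketch (function name hypothetical):

```python
def build_dst_rank(partitions):
    # partitions[r] lists the global sample indices assigned to rank r.
    # Since the partition result is replicated on every rank, dst_rank
    # can be derived locally with no extra all_to_all_single call.
    num_samples = sum(len(p) for p in partitions)
    dst = [None] * num_samples
    for rank, idxs in enumerate(partitions):
        for i in idxs:
            dst[i] = rank
    return dst
```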

Fused multi-KJT AllGather (collective_ops.py)

  • keyed_jagged_tensor_list_allgather: gathers all KJTs with only 2 NCCL calls (1 for lengths, 1 for values), regardless of KJT count.
  • Batch size and dtype validation added for input KJTs.
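The fusion idea is pack-then-split: concatenate the per-KJT lengths into one flat buffer so a single collective moves all of them, then split the gathered result back out. A hypothetical pure-Python sketch of that pack/unpack step (the real code operates on tensors and issues the NCCL calls):

```python
def pack_lengths(lengths_per_kjt):
    # Fuse the lengths of all KJTs into one flat buffer; with a second
    # buffer for values, one AllGather per buffer gives the "2 NCCL calls
    # regardless of KJT count" property.
    offsets, flat = [0], []
    for lengths in lengths_per_kjt:
        flat.extend(lengths)
        offsets.append(len(flat))
    return flat, offsets

def unpack_lengths(flat, offsets):
    # Split the gathered buffer back into per-KJT lengths.
    return [flat[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]
```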

Pipeline integration (train_pipeline.py)

  • Both JaggedMegatronTrainPipelineSparseDist and JaggedMegatronPrefetchTrainPipelineSparseDist updated with the 2-phase async API.
  • NCCL ordering guards (wait_stream) prevent deadlocks between shuffle and DDP/loss all-reduces.

Performance & Debug Utilities

  • Attention perf tracker (attn_perf_tracker.py, hooks.py): CUDA-event-based per-layer TFLOPS/MFU logging via module hooks (PRINT_HSTU_PERF=1).
  • GPU device spec & MFU (perf.py): DeviceSpec auto-discovery, peak TFLOPS calculation, MFU reporting.
  • Balancer logger: load-balance before/after stats, controlled by PRINT_LOAD_BALANCE=1 with start/stop window.
  • Dataset distribution support: UNIFORM, NORMAL, ZIPF, LOGNORMAL seqlen distributions for HSTURandom*Dataset.
  • Balanced attention benchmark: standalone benchmark for HSTU kernel SoL and balancer speedup analysis.
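The heavy-tailed seqlen distributions are what make the balancer interesting: a few long sequences dominate the attention cost. A hypothetical sampler sketch using only the standard library (ZIPF is omitted because `random` has no Zipf sampler; parameter names and defaults are illustrative, not the HSTURandom*Dataset API):

```python
import random

def sample_seqlens(n, dist="LOGNORMAL", max_len=256, seed=0):
    # Illustrative stand-in for the dataset's configurable distributions.
    rng = random.Random(seed)
    if dist == "UNIFORM":
        return [rng.randint(1, max_len) for _ in range(n)]
    if dist == "NORMAL":
        return [min(max_len, max(1, int(rng.gauss(max_len / 2, max_len / 8))))
                for _ in range(n)]
    if dist == "LOGNORMAL":
        # heavy tail: a handful of long sequences dominate per-sample cost,
        # exactly the skew the balancer spreads across ranks
        return [min(max_len, max(1, int(rng.lognormvariate(3.0, 1.0))))
                for _ in range(n)]
    raise ValueError(f"unsupported distribution: {dist}")
```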

Bug Fixes (from review)

  • Removed accidental pdb.set_trace() left in test_dataset.py.
  • Fixed FLOP double-counting: removed inline perf tracking from FusedHSTULayerFunction, keeping only module-level hooks.
  • Converted _batch_counter from class variable to instance variable.
  • Fixed off-by-one in async path print window check (pass handle idx explicitly).
  • Removed unused linear_dim parameter from _compute_attn_fwd_flops.
  • Added actual_batch_size documentation note in pad_and_allgather_batch.


greptile-apps bot commented Feb 11, 2026

Greptile Summary

This PR delivers a significant refactor of the HSTU batch balancer pipeline: the Karmarkar-Karp (KK) algorithm is moved to a background ThreadPoolExecutor so its CPU-only cost overlaps with GPU forward/backward, a new O(B) All2All shuffle path avoids the AllGather+index_select overhead, and a fused multi-KJT AllGather reduces NCCL call count from O(num_keys) to 2. It also adds CUDA-event-based attention TFLOPS/MFU logging, configurable sequence-length / value distributions for synthetic datasets, and several bug fixes from the previous review round (instance-variable _batch_counter, off-by-one print window, super().__init__() calls).

Key changes and notable findings:

  • 2-phase async KK API (start_shuffle_async / finish_shuffle): well-designed; ShuffleHandle prevents id-reuse collisions; NCCL ordering is correctly guarded with wait_stream calls in both pipeline variants.
  • Division-by-zero risk in the new dense-tensor padding helper (batch_allgather.py line 65, batch_all2all.py line 446): when actual_batch_size == 0 (fully-empty tail batch at epoch end), the calculation t.numel() // actual_batch_size raises ZeroDivisionError.
  • Zero-workload real samples misclassified as padding (batch_shuffler.py lines 275–278): the has_padding heuristic checks only the last element of each rank's chunk; a legitimate user with workload == 0 (empty history) placed at the end of a chunk will silently trigger the padding-stripping path, corrupting actual_batch_size and dropping dense tensor rows.
  • Fragile positional argument access (hooks.py line 28): args[3] in the forward pre-hook assumes offsets is always the fourth positional argument of HSTUAttention.forward; an assertion or comment anchoring this assumption would prevent silent corruption on signature changes.
  • Silent early return in batched_dynamicemb_tables.py drops the logging.warning that previously informed users when prefetch was called with caching disabled.

Confidence Score: 3/5

  • The PR is functionally sound for the common case but contains two logic bugs (division-by-zero on empty tail batches; zero-workload samples treated as padding) that can cause crashes or silent data corruption under edge-case inputs.
  • The architecture and NCCL ordering guards are well thought-out, and several previous review issues have been fixed. However, the division-by-zero in the dense padding helpers is a crash-level bug for any training run that exhausts a rank's data mid-epoch, and the zero-workload padding heuristic can silently corrupt batch state for users with empty interaction histories — both of which are realistic production scenarios.
  • examples/commons/distributed/batch_allgather.py (line 65), examples/commons/distributed/batch_all2all.py (line 446), examples/commons/distributed/batch_shuffler.py (lines 275–278, 358–362)

Important Files Changed

Filename Overview
examples/commons/distributed/batch_shuffler.py Core async KK pipeline implementation; has_padding detection has an undocumented assumption (workload==0 → padding) that could misclassify zero-workload real samples, causing data loss in _strip_dense_padding.
examples/commons/distributed/batch_allgather.py Refactored to use fused KJT AllGather; dense-tensor padding computation at line 66 performs integer division by actual_batch_size, which raises ZeroDivisionError when that value is 0.
examples/commons/distributed/batch_all2all.py New All2All shuffle path; same division-by-zero risk in dense padding computation (line 446) when actual_batch_size == 0; otherwise well-structured with correct fused NCCL approach.
examples/commons/utils/hooks.py New CUDA-event-based perf hooks; _fwd_pre_hook accesses args[3] by positional index to obtain the offsets tensor — fragile if HSTUAttention.forward's parameter order ever changes.
examples/commons/ops/collective_ops.py New fused keyed_jagged_tensor_list_allgather reduces NCCL calls from O(num_keys) to 2; batch_size and dtype validation added per previous review; logic is sound.
examples/commons/pipeline/train_pipeline.py Both pipeline variants updated with 2-phase async KK; wait_sparse_data_dist correctly moved before Phase 1 to prevent concurrent NCCL on the DP communicator; NCCL ordering guards are well-documented.
corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py Changed prefetch_next_batch to silently return early when caching is disabled instead of logging a warning; loses a useful diagnostic but avoids spurious prefetch_stream initialization.

Sequence Diagram

sequenceDiagram
    participant DL as DataLoader
    participant Pipe as TrainPipeline (main thread)
    participant GPU as GPU (_memcpy_stream)
    participant KK as KK Thread (background)
    participant NCCL as NCCL (DP group)

    Note over Pipe,NCCL: Iteration N

    Pipe->>Pipe: start_sparse_data_dist(batch[N+1])
    Pipe->>GPU: H2D copy batch[N+2]
    Pipe->>Pipe: wait_sparse_data_dist(batch[N+1])

    Note over Pipe,NCCL: Phase 1 – on _memcpy_stream
    GPU->>NCCL: AllGather workloads (batch[N+2])
    NCCL-->>GPU: allgather_workloads_cpu (D2H)
    GPU->>KK: submit KK(allgather_workloads_cpu)

    Note over Pipe,GPU: Forward / Backward (overlaps KK)
    Pipe->>GPU: forward(batch[N])
    Pipe->>GPU: backward + optimizer step

    Note over Pipe,NCCL: Phase 2 – on _memcpy_stream
    KK-->>GPU: partitions_indices (future.result())
    GPU->>NCCL: AllGather batch[N+2] (or All2All)
    NCCL-->>GPU: redistributed batch[N+2]
    GPU->>GPU: index_select / _strip_dense_padding

    Pipe->>NCCL: wait_stream(_memcpy_stream) → loss AllReduce
    Pipe->>NCCL: DDP AllReduce gradients

Comments Outside Diff (5)

  1. examples/commons/distributed/batch_allgather.py, line 65-68 (link)

    Division by zero when actual_batch_size == 0

    t.numel() // orig_actual_bs raises ZeroDivisionError when orig_actual_bs == 0. This can occur at the tail of a training epoch when a rank exhausts its data and produces a batch where every sample is "padding" (i.e. actual_batch_size is set to 0). The condition batch.actual_batch_size < batch.batch_size would be satisfied, so execution reaches the division without any guard.

    The same pattern is present in batch_all2all.py at the equivalent line.
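A minimal guard for this pattern, sketched as a hypothetical helper (the real fix belongs inline in the padding helpers):

```python
def rows_per_sample(numel: int, actual_batch_size: int) -> int:
    # Guard for the `t.numel() // actual_batch_size` pattern: a fully-empty
    # tail batch reports actual_batch_size == 0, and there is nothing to
    # derive or pad in that case.
    if actual_batch_size == 0:
        return 0
    return numel // actual_batch_size
```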

  2. examples/commons/distributed/batch_all2all.py, line 444-449 (link)

    Division by zero when actual_batch_size == 0

    Same issue as in batch_allgather.py: t.numel() // batch.actual_batch_size raises ZeroDivisionError when batch.actual_batch_size == 0. The condition pad_dense = batch.actual_batch_size < local_batch_size is True for a fully-empty batch (actual_batch_size == 0), so this path is reached before any guard.

  3. examples/commons/distributed/batch_shuffler.py, line 275-278 (link)

    Zero-workload real samples silently treated as padding

    The check allgather_workloads_cpu[(i + 1) * local_batch_size - 1] == 0 only inspects the last element of each rank's chunk to decide whether any padding is present. If a legitimate user sample (real, non-padding) happens to have workload == 0 — for example, a user with an empty interaction history — and is placed last in a rank's chunk, has_padding will be set to True.

    This then cascades into the actual_bs computation at line 359:

    actual_bs = (allgather_wl[indices_this_rank] > 0).sum().item()

    …which would exclude all zero-workload real samples from actual_bs, causing _strip_dense_padding to silently drop their dense-tensor rows.

    At minimum this assumption should be documented in the function docstring and guarded in get_workloads implementations. Consider using a dedicated sentinel value (e.g. negative workload or a separate boolean mask) to distinguish true padding from zero-cost real samples.
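The sentinel idea can be sketched as follows; the constant and helper are hypothetical, purely to show how a negative sentinel keeps zero-cost real samples distinguishable from padding:

```python
PADDING_WORKLOAD = -1  # hypothetical sentinel; real samples always report >= 0

def count_real_samples(workloads):
    # Zero-workload real samples (e.g. users with empty interaction
    # histories) are no longer indistinguishable from padding entries.
    return sum(1 for w in workloads if w != PADDING_WORKLOAD)
```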

  4. examples/commons/utils/hooks.py, line 27-30 (link)

    Fragile positional argument access in forward pre-hook

    args[3] hard-codes the position of the offsets parameter in HSTUAttention.forward(self, tq, tk, tv, offsets, ...). If a new parameter is ever inserted before offsets (e.g. scaling or a mask), this line will silently pick up the wrong tensor — the type is still torch.Tensor so there's no type-error at hook registration time, only silent numerical corruption at runtime.

    Consider binding by name instead, which is safe regardless of order changes:

    def _fwd_pre_hook(module, args, kwargs):
        # Registered with register_forward_pre_hook(..., with_kwargs=True).
        # Prefer the keyword if the caller passed offsets by name; otherwise
        # fall back to the documented positional slot, with a guard.
        if "offsets" in kwargs:
            offsets = kwargs["offsets"]
        else:
            assert len(args) >= 4, "Expected at least 4 positional args (tq, tk, tv, offsets)"
            offsets = args[3]  # HSTUAttention.forward arg order: tq, tk, tv, offsets
  5. corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py, line 897-898 (link)

    Silent early return drops useful misconfiguration warning

    The previous code emitted logging.warning("Caching is not enabled, prefetch will do nothing.") before falling through, giving users a clear signal that they've enabled prefetch without caching (a likely misconfiguration). The new return silently swallows the call with only a comment, making it harder to diagnose the mistake at runtime.

    Consider preserving the warning:

    if not self._caching:
        # caching is not enabled, nothing to prefetch
        warnings.warn(
            "prefetch_next_batch called but caching is not enabled — skipping.",
            stacklevel=2,
        )
        return

    (warnings is already imported in this module.)

Last reviewed commit: f92e576

@greptile-apps greptile-apps bot left a comment

34 files reviewed, 4 comments

@JacoCheung JacoCheung force-pushed the junzhang/opt_balancer branch from 8b68ea4 to 2e3a262 on February 12, 2026 12:02
@greptile-apps greptile-apps bot left a comment

24 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment

24 files reviewed, 1 comment

@JacoCheung JacoCheung requested a review from shijieliu February 25, 2026 07:02
@greptile-apps greptile-apps bot left a comment

30 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

31 files reviewed, 2 comments

@JacoCheung JacoCheung force-pushed the junzhang/opt_balancer branch from 6d3e1a6 to 5a0eff0 on February 25, 2026 09:10
@greptile-apps greptile-apps bot left a comment

31 files reviewed, 1 comment

@JacoCheung JacoCheung force-pushed the junzhang/opt_balancer branch from 90abbe0 to a0019e0 on March 10, 2026 10:26
@JacoCheung JacoCheung force-pushed the junzhang/opt_balancer branch 3 times, most recently from 709bb20 to bf3ee02 on March 12, 2026 08:13
…ecv_ids.numel()

all2all_batch was incorrectly overwriting batch_size with recv_ids.numel()
and setting actual_batch_size = batch_size. Since _apply_to_tensors_or_kjt
already copies batch_size from the source batch, no reassignment is needed.
actual_batch_size is now set to recv_ids.numel() directly.

Also adds test_all2all_vs_allgather_incomplete_batch with per-rank random
actual_batch_size to verify data correctness under heterogeneous incomplete
batches.

Made-with: Cursor
…l path

Read SHUFFLE_WITH_ALL2ALL from environment at module load. When set to "1",
finish_shuffle and shuffle use the All2All communication path regardless of
the use_all2all parameter (OR logic).

Made-with: Cursor
…dst_rank

- pad_and_allgather_batch: zero-pad dense tensors to batch_size before
  AllGather; optionally return is_padding mask via return_padding_flag
- pad_and_all2all_batch: zero-pad dense tensors before All2All; make
  dst_rank and recv_counts required params; remove _build_dst_rank
  (was the only consumer of an extra all_to_all_single call)
- _strip_dense_padding: new module-level utility to truncate dense
  tensors to actual_batch_size rows while keeping KJTs intact
- finish_shuffle / shuffle: apply _strip_dense_padding when padding
  is present so output batch satisfies the 4 invariants:
    batch_size = KJT dim (includes padding)
    actual_batch_size = dense dim (real samples only)
    Dense tensor dim-0 = actual_batch_size
    KJT lengths count = batch_size * num_keys
- BaseBatch: update docstring and field comments to document invariants;
  index_select_dense_tensor uses actual_batch_size for reshape
- Tests: add _strip_dense_padding verification to allgather and all2all
  tests; merge incomplete batch test into test_all2all_vs_allgather;
  use _build_dst_rank_local to pre-compute dst_rank/recv_counts

Made-with: Cursor
… after shuffle

After balanced shuffler, _strip_dense_padding trims dense tensors to
actual_batch_size while KJTs retain batch_size entries (padding samples
have zero-length sequences). This causes a dimension mismatch in
_get_high_inds when max_num_candidates > 0. Re-pad num_candidates with
zeros to align with the KJT-derived sequence_embeddings_lengths.

Made-with: Cursor
- Remove duplicate FLOP counting in FusedHSTULayerFunction (keep module hooks only)
- Remove accidental pdb.set_trace() in test_dataset.py
- Convert _batch_counter from class variable to instance variable
- Fix off-by-one in async path print window check by passing handle idx
- Remove unused linear_dim parameter from _compute_attn_fwd_flops
- Add batch_size and values dtype validation in KJT list allgather
- Add NOTE comment for actual_batch_size in pad_and_allgather_batch

Made-with: Cursor
When world_size==1 and actual_batch_size < batch_size (incomplete
batch), pad_and_all2all_batch previously returned the original batch
without padding dense tensors. This caused shape mismatches with
pad_and_allgather_batch which always pads.

Unify both functions to follow the same world_size==1 contract:
- Complete batch: return as-is (no-op)
- Incomplete batch: pad dense tensors, update actual_batch_size, skip
  all collective communication
- KJT fields returned unchanged in both cases

Document the fast-path behaviour in both docstrings.

Made-with: Cursor
@JacoCheung JacoCheung force-pushed the junzhang/opt_balancer branch from bf3ee02 to f92e576 on March 13, 2026 02:10
@shijieliu shijieliu merged commit 7cf4575 into NVIDIA:main Mar 13, 2026