
all_reduce_one_shot / all_reduce_two_shot use hardcoded lock value 1 #465

@aamarnat

Description


Bug

In iris/x/all_reduce.py, all_reduce_one_shot and all_reduce_two_shot use a hardcoded value of 1 to signal "tile ready":

  • Producers write: tl.atomic_xchg(lock_ptr, 1, sem="release")
  • Consumers spin: while iris.atomic_add(lock_ptr, 0, ...) != 1: pass

Between calls, the lock array must be reset to 0 with a collective shmem.zeros followed by a barrier, which adds overhead to every kernel invocation. If the lock array is not properly zeroed (e.g., because the workspace is reused or an earlier call failed), consumers see lock == 1 left over from a previous call and read stale data.
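The hazard can be reproduced without a GPU. Below is a minimal single-threaded Python model, not Iris or Triton code: plain lists stand in for the lock array and tile buffers, and the hypothetical helpers `produce` and `consumer_ready` mimic the release `atomic_xchg(lock_ptr, 1)` and one poll of the `atomic_add(lock_ptr, 0)` spin loop. Memory ordering is not modeled; the point is only that a lock left at 1 makes a tile look ready before the new producer has written it.

```python
# Single-threaded model of the current "lock == 1 means ready" protocol.
# Plain lists stand in for the lock array and the tile buffers; this is
# not Iris or Triton code, and produce/consumer_ready are hypothetical.

NUM_TILES = 4
READY = 1  # hardcoded "tile ready" value used by both functions today

locks = [0] * NUM_TILES   # what shmem.zeros would give on the first call
tiles = [None] * NUM_TILES


def produce(tile_id, value):
    """Write the tile, then publish it (models atomic_xchg(lock_ptr, 1))."""
    tiles[tile_id] = value
    locks[tile_id] = READY


def consumer_ready(tile_id):
    """One poll of the spin loop (models atomic_add(lock_ptr, 0) == 1)."""
    return locks[tile_id] == READY


# Call 1: every producer publishes its tile.
for t in range(NUM_TILES):
    produce(t, ("call1", t))

# Call 2 reuses the workspace WITHOUT zeroing the locks.
# Before any call-2 producer runs, a consumer polls tile 0:
assert consumer_ready(0)      # lock is still 1 from call 1
stale = tiles[0]              # so the consumer reads call-1 data
print(stale)                  # ('call1', 0)
```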

Impact

  • Per-call overhead from mandatory lock zeroing + barrier between invocations
  • Fragile: skipping the zeroing step silently produces wrong results
  • Prevents efficient workspace reuse across calls

Fix

Add a call_number parameter to both functions:

  • Producers signal with: tl.atomic_xchg(lock_ptr, call_number, sem="release", scope="sys")
  • Consumers spin on: while iris.atomic_add(lock_ptr, 0, ...) != call_number: pass
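The proposed versioning can be sketched with the same kind of single-threaded Python model, again not Iris or Triton code. The hypothetical `produce` and `consumer_ready` helpers stand in for the `atomic_xchg(lock_ptr, call_number, ...)` publish and one poll of the `!= call_number` spin loop; release semantics and `scope="sys"` are not modeled. It shows that a lock still holding the previous call's number is simply not accepted.

```python
# Single-threaded model of the proposed versioned protocol: producers
# publish call_number instead of 1, consumers wait for that exact value.
# Not Iris/Triton code; produce/consumer_ready are hypothetical helpers.

NUM_TILES = 4

locks = [0] * NUM_TILES   # zeroed once at allocation, never again
tiles = [None] * NUM_TILES


def produce(tile_id, value, call_number):
    """Write the tile, then publish it (models atomic_xchg(lock_ptr, call_number))."""
    tiles[tile_id] = value
    locks[tile_id] = call_number


def consumer_ready(tile_id, call_number):
    """One poll of the spin loop (models atomic_add(lock_ptr, 0) == call_number)."""
    return locks[tile_id] == call_number


# Call 1 publishes every tile with version 1.
for t in range(NUM_TILES):
    produce(t, ("call1", t), call_number=1)

# Call 2 reuses the same locks with no zeroing step.
assert not consumer_ready(0, call_number=2)   # stale version 1 is ignored
produce(0, ("call2", 0), call_number=2)
assert consumer_ready(0, call_number=2)       # only the fresh publish matches
print(tiles[0])                               # ('call2', 0)
```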

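Add a monotonically increasing call_counter field to FusedWorkspace and increment it on every matmul_all_reduce call. Because each call uses a new version number, locks left over from previous calls hold an older value that never equals the current call_number, so they are ignored without any zeroing. The first call should use 1 rather than 0, since freshly zeroed locks hold 0.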
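A host-side sketch of the counter follows. `FusedWorkspaceSketch`, `next_call_number`, and `matmul_all_reduce_sketch` are hypothetical stand-ins, not the real `FusedWorkspace` or `matmul_all_reduce` API. The sketch also assumes the lock slots are 32-bit signed integers (not stated in the issue) and therefore wraps the version back to 1 instead of overflowing, skipping 0 so a zeroed lock can never look ready.

```python
from dataclasses import dataclass

# Assumption: lock slots are 32-bit signed integers, so versions must stay
# in [1, 2**31 - 1]. 0 is reserved because freshly zeroed locks hold 0.
MAX_CALL_NUMBER = 2**31 - 1


@dataclass
class FusedWorkspaceSketch:
    """Hypothetical stand-in for FusedWorkspace; only the counter is modeled."""
    call_counter: int = 0  # number of the most recent call; 0 = never called

    def next_call_number(self) -> int:
        """Advance the counter and return the version for this call (never 0)."""
        nxt = self.call_counter + 1
        if nxt > MAX_CALL_NUMBER:
            nxt = 1  # wrap past 0 so a zeroed lock can never look ready
        self.call_counter = nxt
        return nxt


def matmul_all_reduce_sketch(workspace: FusedWorkspaceSketch) -> int:
    """Hypothetical host-side wrapper: pick a version, no lock zeroing needed."""
    call_number = workspace.next_call_number()
    # ... launch the kernel here, forwarding call_number to
    #     all_reduce_one_shot / all_reduce_two_shot ...
    return call_number


ws = FusedWorkspaceSketch()
print([matmul_all_reduce_sketch(ws) for _ in range(3)])  # [1, 2, 3]
```

The wraparound is only safe if every lock slot is republished on each call, so that a stale slot always holds the previous call's number rather than one from billions of calls ago.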

Component

iris/x/all_reduce.py, iris/ops/workspace.py

Metadata

Labels

bug (Something isn't working), iris (Iris project issue)