
[WarpSpec] add support for multiple channels sharing the same smem #9

Open · wants to merge 3 commits into base: ws

Conversation

manman-ren (Contributor)


Summary: We already have channelsGroupedByProducers and channelsGroupedByConsumers. For one-producer-multi-consumer mode, a single buffer is used; channelsGroupedByProducers captures this. channelsGroupedByConsumers exists to minimize the insertion of sync primitives: a single set of communication ops is inserted per consumer group.

For this patch, we want multiple channels that are live in different loop nests to share the same smem location. We add an allocation.shareGroup attribute to the local_allocs corresponding to the channels that reuse the same smem location.

In order to reuse the same smem location, we update bufferIdx and phase through all the loop nests that share smem locations. We handle the following cases:
for # persistent loop
  for # can be nested under if
  for # can be nested under if
Or
for # can be nested under if
for # can be nested under if
Or
for # persistent loop
  for # can be nested under if

The generated code will look like:
for(accumLoopCount)
  t1 = IfOp
    forOp # loop A
    tmpIdx = accumLoopCount + numStepsA
    yield tmpIdx
    else yield accumLoopCount
  t2 = IfOp
    forOp # loop B
    tmpIdx = t1 + numStepsB
    yield tmpIdx
    else yield t1
  yield t2 for accumLoopCount

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 14, 2024
@manman-ren (Contributor, Author)

The implementation changes appendBufferIdxArgs/createNewLoop to add an argument to the outer loop for accumLoopCount, or to add a constant as a placeholder when there is no outer loop. It also changes specializeIfOp to create a result on the if that propagates accumLoopCount.
We then use a helper function, updateAccumLoopCount, to correctly link up the values.

Phase 1:
ForOp with accumLoopCount as an argument
  If
    use accumLoopCount to set initialBufferIdx
    ForOp
      generate numSteps and create an add op for accumLoopCount + numSteps
  Yield for ForOp with accumLoopCount (this will be updated later in updateAccumLoopCount)

@htyu (Contributor)

htyu commented Dec 18, 2024

This is great work, thanks!

BTW, can you include a lit test to help understand what this PR does exactly?

if (kv.second.size() <= 1)
  continue;
bufferMap[kv.first].getDefiningOp()->setAttr(
    "allocation.shareGroup",
Contributor
A dumb question: why is this needed if the same buffer is already used in the IR?
