[WarpSpec] add support for multiple channels sharing the same smem #9
base: ws
Conversation
The implementation changes appendBufferIdxArgs/createNewLoop to add an argument to the outer loop for accumLoopCount, or a constant placeholder when there is no outer loop. It also changes specializeIfOp to create a result on the if so that accumLoopCount is propagated. Phase 1:
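As a rough illustration (plain Python, not the actual MLIR rewrite), the key idea behind giving the if a result is that the accumulated loop count flows out of both branches: the taken branch yields an advanced count, the else branch passes it through. The function name and parameters below are hypothetical, chosen only to mirror the yield semantics described above.

```python
# Hypothetical sketch (not the actual pass): model how an if with a result
# propagates accumLoopCount, mirroring the specializeIfOp change.
def run_if(cond, accum_loop_count, num_steps):
    if cond:
        # then-branch: the guarded loop ran, so advance the count
        return accum_loop_count + num_steps   # yield tmpIdx
    # else-branch: pass the count through unchanged
    return accum_loop_count                   # yield accumLoopCount

accum = 0
accum = run_if(True, accum, 4)   # guarded loop A executed
accum = run_if(False, accum, 8)  # guarded loop B skipped
```

After the two calls, `accum` is 4: only the executed loop contributed its steps, which is exactly what the if result makes possible.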
This is great work, thanks! BTW, can you include a lit test to help understand what this PR does exactly?
1015fdc to 492969a
if (kv.second.size() <= 1)
  continue;
bufferMap[kv.first].getDefiningOp()->setAttr(
    "allocation.shareGroup",
A dumb question: why is this needed if the same buffer is already used in the IR?
Summary: We already have channelsGroupedByProducers and channelsGroupedByConsumers. For one-producer-multi-consumer mode, a single buffer is used; channelsGroupedByProducers handles this. channelsGroupedByConsumers exists to minimize the insertion of sync primitives: a single set of communication ops is inserted per consumer group.
For this patch, we want to share the same smem location for multiple channels that are live in different loop nests. We add allocation.shareGroup attributes to the local_allocs corresponding to channels that reuse the same smem location.
In order to reuse the same smem location, we update bufferIdx and phase across all the loop nests that share a smem location. We handle the following cases:
for        # persistent loop
  for      # can be nested under if
  for      # can be nested under if

Or

for        # can be nested under if
for        # can be nested under if

Or

for        # persistent loop
  for      # can be nested under if
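One way to see why a running count across all the loop nests matters: bufferIdx and phase for a shared smem location can be derived from the accumulated iteration count, so a second loop nest continues where the first left off instead of restarting at buffer 0. The sketch below assumes a simple cyclic indexing scheme with a phase bit that flips on wrap-around; numBuffers and the exact modulo arithmetic are illustrative, not taken from the patch.

```python
def buffer_idx_and_phase(accum_loop_count, num_buffers):
    # Illustrative only: the index cycles through the shared buffers,
    # and the phase flips each time the index wraps around.
    buffer_idx = accum_loop_count % num_buffers
    phase = (accum_loop_count // num_buffers) % 2
    return buffer_idx, phase

# Two loop nests sharing the same smem: if loop A ran 3 iterations,
# loop B's first iteration starts from accum_loop_count = 3.
print(buffer_idx_and_phase(0, 2))  # (0, 0)
print(buffer_idx_and_phase(3, 2))  # (1, 1)
```

If each loop nest instead reset its own counter to zero, both nests would start at buffer 0 with phase 0 and could race on the shared smem location.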
The generated code will look like:

for (accumLoopCount)
  t1 = IfOp
    forOp  # loop A
    tmpIdx = accumLoopCount + numStepsA
    yield tmpIdx
  else
    yield accumLoopCount
  t2 = IfOp
    forOp  # loop B
    tmpIdx = t1 + numStepsB
    yield tmpIdx
  else
    yield t1
  yield t2  # for accumLoopCount
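The generated structure above can be sketched in plain Python: each guarded loop conditionally advances the accumulated count, and the outer persistent loop carries the result forward into its next iteration. The function name and step counts are hypothetical stand-ins for the scheme described above.

```python
# Sketch of one iteration of the outer persistent loop: two if-guarded
# loop nests chain their accumLoopCount updates via t1 and t2.
def persistent_body(accum, run_a, run_b, num_steps_a, num_steps_b):
    t1 = accum + num_steps_a if run_a else accum  # t1 = IfOp around loop A
    t2 = t1 + num_steps_b if run_b else t1        # t2 = IfOp around loop B
    return t2                                     # yield t2 for accumLoopCount

accum_loop_count = 0
# Outer persistent loop: some iterations run both nests, some only one.
for run_a, run_b in [(True, True), (True, False)]:
    accum_loop_count = persistent_body(accum_loop_count, run_a, run_b, 3, 5)
```

After the first iteration the count is 3 + 5 = 8; after the second (loop B skipped) it is 11, so the shared-smem bufferIdx/phase keep advancing monotonically across nests.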