@@ -10,9 +10,6 @@ namespace CamMoeCombineNormalImpl {
10
10
// Field offsets inside one per-token source-info record. The copy loop in
// CopyBufferToShareAndSetStatus reads srcInfoLocal(base + offset) with
// base = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN, so each record holds
// three consecutive entries: source rank id, source token id, source top-k id.
constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U ;
11
11
constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U ;
12
12
constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U ;
13
// The three '-' constants below are deleted by this change. Every use of them
// that is visible in this diff (the SetValue calls in SetStatusBySrcInfo, the
// target/SumParams setup in WaitBuffCopy, the GetValue calls in
// ReadBufferAndWeightedSum) is deleted as well.
- constexpr uint32_t RANK_ID_OFFSET_IN_STATUS = 6U ;
14
- constexpr uint32_t TOKEN_IDX_OFFSET_IN_STATUS = 7U ;
15
- constexpr uint32_t STATUS_NUM = 6U ;
16
13
// Evaluates to 3 * 2^20 = 3145728. Not referenced in the visible hunks.
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL ;
17
14
// Evaluates to 975 * 2^10 = 998400. Not referenced in the visible hunks.
constexpr uint64_t MAGIC_WIN_OFFSET = 975UL * 1024UL ;
18
15
// Number of entries per source-info record (the three offsets above).
constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U ;
@@ -49,10 +46,10 @@ class CamMoeCombineNormal {
49
46
// Member declarations of CamMoeCombineNormal touched by this hunk. Only
// CopyBufferToShare and SetStatusBySrcInfo change signature; the rest are
// unchanged context lines.
__aicore__ inline void InitTilingData (const CamMoeCombineNormalTilingData *tilingData);
50
47
__aicore__ inline void InitBuffLen ();
51
48
__aicore__ inline void CopyBufferToShareAndSetStatus ();
52
// CopyBufferToShare gains the source rank/token/top-k ids ahead of tkIndex.
// The new definition uses them to compute the destination address of the copy,
// while tkIndex still selects the row read from recvXGM_.
- __aicore__ inline void CopyBufferToShare (uint32_t tkIndex);
49
+ __aicore__ inline void CopyBufferToShare (uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex);
53
50
__aicore__ inline void ReadBufferFromRemote ();
54
51
__aicore__ inline void WaitBuffCopy (uint32_t tokenIndex);
55
// SetStatusBySrcInfo drops tkIndex: the new body no longer stores the sender's
// rank id / token index into the status tensor, so the argument has no use.
- __aicore__ inline void SetStatusBySrcInfo (uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex );
52
+ __aicore__ inline void SetStatusBySrcInfo (uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId);
56
53
__aicore__ inline void ReadBufferAndWeightedSum (uint32_t tokenIndex, uint32_t startTokenIndex);
57
54
58
55
__aicore__ GM_ADDR GetStateAddrByRankId (const int32_t rankId)
@@ -226,24 +223,25 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToSha
226
223
227
224
SyncFunc<AscendC::HardEvent::MTE2_S>();
228
225
for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; tokenIndex++) {
229
- CopyBufferToShare (tokenIndex);
230
- PipeBarrier<PIPE_ALL>();
231
226
uint32_t index = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN;
232
227
uint32_t srcRankId = static_cast <uint32_t >(srcInfoLocal (index + RANK_ID_OFFSET_IN_SRC_INFO));
233
228
uint32_t srcTokenId = static_cast <uint32_t >(srcInfoLocal (index + TOKEN_IDX_OFFSET_IN_SRC_INFO));
234
229
uint32_t srcTopkId = static_cast <uint32_t >(srcInfoLocal (index + TOPK_IDX_OFFSET_IN_SRC_INFO));
235
- SetStatusBySrcInfo (srcRankId, srcTokenId, srcTopkId, tokenIndex);
230
+ CopyBufferToShare (srcRankId, srcTokenId, srcTopkId, tokenIndex);
231
+ PipeBarrier<PIPE_ALL>();
232
+ SetStatusBySrcInfo (srcRankId, srcTokenId, srcTopkId);
236
233
}
237
234
SyncFunc<AscendC::HardEvent::MTE3_S>();
238
235
}
239
236
240
237
// CopyBufferToShare: copies one row of recvXGM_ (starting at index
// tkIndex * axisH_) through localCopyQueue_ into a global-memory window.
//   srcRankId / srcTokenId / srcTopkId - (new) select the destination slot.
//   tkIndex                            - row of recvXGM_ to read.
// Old behaviour ('-' lines): the destination was localRankGM_ at slot tkIndex.
// New behaviour ('+' lines): the destination is the buffer returned by
// GetBufferAddrByRankId(srcRankId), at slot (srcTokenId * axisK_ + srcTopkId).
// The '+' lines of ReadBufferAndWeightedSum in this diff read localRankGM_ with
// the same (token * axisK_ + topk) slot formula.
// NOTE(review): presumably GetBufferAddrByRankId(own rank) and localRankGM_
// refer to the same window - confirm against the parts of the file not shown.
template <TemplateMC2TypeClass>
241
- __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShare(uint32_t tkIndex)
238
+ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId,
239
+ uint32_t srcTopkId, uint32_t tkIndex)
242
240
{
243
241
// Index of the first element of this token's row in recvXGM_.
uint32_t tokenOffset = tkIndex * axisH_;
244
- GM_ADDR rankGM = localRankGM_ + tkIndex * h512AlignRecvXLen_;
245
- GlobalTensor<XType> localRankWindow ;
246
- localRankWindow .SetGlobalBuffer ((__gm__ XType*)rankGM );
242
// Destination slot: one h512AlignRecvXLen_-sized stride per (token, top-k) pair.
// NOTE(review): the slot index is built from uint32_t parameters; if axisK_ and
// h512AlignRecvXLen_ are also 32-bit, the product can wrap before it is added
// to the address - confirm the member types, widen to uint64_t if needed.
+ GM_ADDR dstGM = GetBufferAddrByRankId (srcRankId) + (srcTokenId * axisK_ + srcTopkId) * h512AlignRecvXLen_;
243
+ GlobalTensor<XType> dstWindow ;
244
+ dstWindow .SetGlobalBuffer ((__gm__ XType*)dstGM );
247
245
// The same copy parameters are used for the load from recvXGM_ and the store
// to the destination window.
DataCopyExtParams xOutCopyParams{1U , static_cast <uint32_t >(hRecvXTypeLen_), 0U , 0U , 0U };
248
246
DataCopyPadExtParams<RecvXType> copyPadExtParams{false , 0U , 0U , 0U };
249
247
// Hunk boundary: the lines between the two hunks (where localCopyTensor is
// declared) are not part of this excerpt.
@@ -252,18 +250,16 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToSha
252
250
// Load the row, then pass the tensor through the queue (EnQue followed by
// DeQue) before it is used as the source of the store.
DataCopyPad (localCopyTensor, recvXGM_[tokenOffset], xOutCopyParams, copyPadExtParams);
253
251
localCopyQueue_.EnQue (localCopyTensor);
254
252
localCopyTensor = localCopyQueue_.DeQue <RecvXType>();
255
- DataCopyPad (localRankWindow , localCopyTensor, xOutCopyParams);
253
+ DataCopyPad (dstWindow , localCopyTensor, xOutCopyParams);
256
254
// Return the staging tensor to the queue.
localCopyQueue_.FreeTensor <RecvXType>(localCopyTensor);
257
255
}
258
256
259
257
template <TemplateMC2TypeClass>
260
258
__aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId,
261
- uint32_t srcTopkId, uint32_t tkIndex )
259
+ uint32_t srcTopkId)
262
260
{
263
261
LocalTensor<uint32_t > statusTensor = stateBuf_.AllocTensor <uint32_t >();
264
- statusTensor.SetValue (RANK_ID_OFFSET_IN_STATUS, epRankId_);
265
- statusTensor.SetValue (TOKEN_IDX_OFFSET_IN_STATUS, tkIndex);
266
- GM_ADDR stateGM = GetStateAddrByRankId (srcRankId) + srcTokenId * axisK_ * UB_32_ALIGN + srcTopkId * UB_32_ALIGN;
262
+ GM_ADDR stateGM = GetStateAddrByRankId (srcRankId) + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN;
267
263
GlobalTensor<uint32_t > stateGMTensor;
268
264
stateGMTensor.SetGlobalBuffer ((__gm__ uint32_t *)stateGM);
269
265
DataCopy<uint32_t >(stateGMTensor, statusTensor, FLOAT_NUM_PER_ALIGN);
@@ -277,18 +273,15 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::WaitBuffCopy(ui
277
273
GlobalTensor<float > stateGMTensor;
278
274
stateGMTensor.SetGlobalBuffer ((__gm__ float *)stateGM);
279
275
float current = (float )0.0 ;
280
- float target = (float )1.0 * axisK_ * STATUS_NUM;
281
- SumParams sumPerKParams{axisK_, FLOAT_NUM_PER_ALIGN, STATUS_NUM};
282
- SumParams sumTokenParams{1 , axisK_, axisK_};
276
+ float target = (float )1.0 * axisK_ * FLOAT_NUM_PER_ALIGN;
277
+ SumParams sumPerKParams{1 , calCount, calCount};
283
278
LocalTensor<float > stateTensorLocal = stateBuf_.Get <float >();
284
279
LocalTensor<float > tempStateTensorLocal = tempStateBuf_.Get <float >();
285
280
while (current != target) {
286
281
SyncFunc<AscendC::HardEvent::S_MTE2>();
287
282
DataCopy<float >(stateTensorLocal, stateGMTensor, calCount);
288
283
SyncFunc<AscendC::HardEvent::MTE2_V>();
289
284
Sum (tempStateTensorLocal, stateTensorLocal, sumPerKParams);
290
- SyncFunc<AscendC::HardEvent::V_V>();
291
- Sum (tempStateTensorLocal, tempStateTensorLocal, sumTokenParams);
292
285
SyncFunc<AscendC::HardEvent::V_S>();
293
286
current = tempStateTensorLocal (0 );
294
287
}
@@ -311,16 +304,14 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::ReadBufferAndWe
311
304
const DataCopyExtParams xOutCopyParams{1U , static_cast <uint32_t >(hRecvXTypeLen_), 0U , 0U , 0U };
312
305
313
306
for (uint32_t topkId = 0U ; topkId < axisK_; topkId++) {
314
- uint32_t remoteRankId = stateTensorLocal.GetValue (topkId * FLOAT_NUM_PER_ALIGN + RANK_ID_OFFSET_IN_STATUS);
315
- uint32_t remoteTokenIndex = stateTensorLocal.GetValue (topkId * FLOAT_NUM_PER_ALIGN + TOKEN_IDX_OFFSET_IN_STATUS);
316
307
float scale = topkWeightsLocal.GetValue ((tokenIndex - startTokenIndex) * axisK_ + topkId);
317
- GM_ADDR remoteTokenAddr = (__gm__ uint8_t *)( GetBufferAddrByRankId (remoteRankId)) + remoteTokenIndex * h512AlignRecvXLen_;
318
- GlobalTensor<XType> remoteTokenATensor ;
319
- remoteTokenATensor .SetGlobalBuffer ((__gm__ XType*)remoteTokenAddr );
308
+ GM_ADDR localTokenAddr = localRankGM_ + (tokenIndex * axisK_ + topkId) * h512AlignRecvXLen_;
309
+ GlobalTensor<XType> localTokenTensor ;
310
+ localTokenTensor .SetGlobalBuffer ((__gm__ XType*)localTokenAddr );
320
311
321
312
LocalTensor<XType> tmpToken = weightedSumQueue_.AllocTensor <XType>();
322
313
const DataCopyPadExtParams<RecvXType> copyPadExtParams{false , 0U , 0U , 0U };
323
- DataCopyPad (tmpToken, remoteTokenATensor , xOutCopyParams, copyPadExtParams);
314
+ DataCopyPad (tmpToken, localTokenTensor , xOutCopyParams, copyPadExtParams);
324
315
weightedSumQueue_.EnQue (tmpToken);
325
316
tmpToken = weightedSumQueue_.DeQue <XType>();
326
317
Cast (tokenFloatLocal, tmpToken, AscendC::RoundMode::CAST_NONE, axisH_);
@@ -383,4 +374,4 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::Process()
383
374
}
384
375
385
376
} // CamMoeCombineNormalImpl
386
- #endif // MOE_COMBINE_IMPL_H
377
+ #endif // MOE_COMBINE_IMPL_H
0 commit comments