Skip to content

Commit 95cb562

Browse files
authored
combine use dedicated memory (#63)
* combine uses dedicated memory: combine uses its own dedicated memory region to ensure that there is no overlap with the memory used by dispatch
* update the actual memory size used by MoE
* modify combine
* remove the temporary plan
1 parent fee6d02 commit 95cb562

File tree

3 files changed

+25
-35
lines changed

3 files changed

+25
-35
lines changed

csrc/deepep/deep_ep.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
namespace deep_ep {
1212
constexpr int PADDING_SIZE = 3;
1313
constexpr size_t HCOMM_NAME_LEN = 128;
14-
constexpr double SCALE_SIZE = 1.5;
1514

1615
Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode,
1716
std::string moe_all_to_all_group_name)
@@ -266,8 +265,8 @@ Buffer::intranode_dispatch(const at::Tensor& x, const std::optional<at::Tensor>&
266265
int64_t tp_size = 1;
267266
int64_t tp_rank = 0;
268267
int64_t quant_mode = 0;
269-
int64_t global_bs = static_cast<int64_t>(std::ceil(
270-
std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)) * SCALE_SIZE));
268+
int64_t global_bs = static_cast<int64_t>(
269+
std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)));
271270

272271
auto send_token_idx = send_token_idx_cpu.to(x.device());
273272
auto recv_offset = recv_offset_cpu.to(x.device());

csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,8 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext*
483483
// Dispatch data area: each token start is aligned to 512; valid token length = h_align_32b + scale (32b) + triple (3*4b)
484484
uint64_t tokenActualLen = ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_RECV_IDX_BUFFER;
485485
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
486-
uint64_t actualSize = ((maxBs * tokenNeedSizeDispatch) + (maxBs * tokenNeedSizeCombine * k) +
487-
COMBINE_STATE_WIN_OFFSET) * DOUBLE_DATA_BUFFER;
486+
uint64_t actualSize = (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET) *
487+
DOUBLE_DATA_BUFFER;
488488
OP_TILING_CHECK((actualSize > maxWindowSize),
489489
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u,"
490490
" tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE("
@@ -543,4 +543,4 @@ ge::graphStatus TilingParseForCamMoeCombineNormal(gert::TilingParseContext *cont
543543
IMPL_OP_OPTILING(CamMoeCombineNormal)
544544
.Tiling(CamMoeCombineNormalTilingFunc)
545545
.TilingParse<CamMoeCombineNormalCompileInfo>(TilingParseForCamMoeCombineNormal);
546-
} // namespace optiling
546+
} // namespace optiling

csrc/deepep/ops/op_kernel/cam_moe_combine_normal.h

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@ namespace CamMoeCombineNormalImpl {
1010
constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U;
1111
constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U;
1212
constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U;
13-
constexpr uint32_t RANK_ID_OFFSET_IN_STATUS = 6U;
14-
constexpr uint32_t TOKEN_IDX_OFFSET_IN_STATUS = 7U;
15-
constexpr uint32_t STATUS_NUM = 6U;
1613
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL;
1714
constexpr uint64_t MAGIC_WIN_OFFSET = 975UL * 1024UL;
1815
constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U;
@@ -49,10 +46,10 @@ class CamMoeCombineNormal {
4946
__aicore__ inline void InitTilingData(const CamMoeCombineNormalTilingData *tilingData);
5047
__aicore__ inline void InitBuffLen();
5148
__aicore__ inline void CopyBufferToShareAndSetStatus();
52-
__aicore__ inline void CopyBufferToShare(uint32_t tkIndex);
49+
__aicore__ inline void CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex);
5350
__aicore__ inline void ReadBufferFromRemote();
5451
__aicore__ inline void WaitBuffCopy(uint32_t tokenIndex);
55-
__aicore__ inline void SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, uint32_t tkIndex);
52+
__aicore__ inline void SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId);
5653
__aicore__ inline void ReadBufferAndWeightedSum(uint32_t tokenIndex, uint32_t startTokenIndex);
5754

5855
__aicore__ GM_ADDR GetStateAddrByRankId(const int32_t rankId)
@@ -226,24 +223,25 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToSha
226223

227224
SyncFunc<AscendC::HardEvent::MTE2_S>();
228225
for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; tokenIndex++) {
229-
CopyBufferToShare(tokenIndex);
230-
PipeBarrier<PIPE_ALL>();
231226
uint32_t index = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN;
232227
uint32_t srcRankId = static_cast<uint32_t>(srcInfoLocal(index + RANK_ID_OFFSET_IN_SRC_INFO));
233228
uint32_t srcTokenId = static_cast<uint32_t>(srcInfoLocal(index + TOKEN_IDX_OFFSET_IN_SRC_INFO));
234229
uint32_t srcTopkId = static_cast<uint32_t>(srcInfoLocal(index + TOPK_IDX_OFFSET_IN_SRC_INFO));
235-
SetStatusBySrcInfo(srcRankId, srcTokenId, srcTopkId, tokenIndex);
230+
CopyBufferToShare(srcRankId, srcTokenId, srcTopkId, tokenIndex);
231+
PipeBarrier<PIPE_ALL>();
232+
SetStatusBySrcInfo(srcRankId, srcTokenId, srcTopkId);
236233
}
237234
SyncFunc<AscendC::HardEvent::MTE3_S>();
238235
}
239236

240237
template <TemplateMC2TypeClass>
241-
__aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShare(uint32_t tkIndex)
238+
__aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId,
239+
uint32_t srcTopkId, uint32_t tkIndex)
242240
{
243241
uint32_t tokenOffset = tkIndex * axisH_;
244-
GM_ADDR rankGM = localRankGM_ + tkIndex * h512AlignRecvXLen_;
245-
GlobalTensor<XType> localRankWindow;
246-
localRankWindow.SetGlobalBuffer((__gm__ XType*)rankGM);
242+
GM_ADDR dstGM = GetBufferAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * h512AlignRecvXLen_;
243+
GlobalTensor<XType> dstWindow;
244+
dstWindow.SetGlobalBuffer((__gm__ XType*)dstGM);
247245
DataCopyExtParams xOutCopyParams{1U, static_cast<uint32_t>(hRecvXTypeLen_), 0U, 0U, 0U};
248246
DataCopyPadExtParams<RecvXType> copyPadExtParams{false, 0U, 0U, 0U};
249247

@@ -252,18 +250,16 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::CopyBufferToSha
252250
DataCopyPad(localCopyTensor, recvXGM_[tokenOffset], xOutCopyParams, copyPadExtParams);
253251
localCopyQueue_.EnQue(localCopyTensor);
254252
localCopyTensor = localCopyQueue_.DeQue<RecvXType>();
255-
DataCopyPad(localRankWindow, localCopyTensor, xOutCopyParams);
253+
DataCopyPad(dstWindow, localCopyTensor, xOutCopyParams);
256254
localCopyQueue_.FreeTensor<RecvXType>(localCopyTensor);
257255
}
258256

259257
template <TemplateMC2TypeClass>
260258
__aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId,
261-
uint32_t srcTopkId, uint32_t tkIndex)
259+
uint32_t srcTopkId)
262260
{
263261
LocalTensor<uint32_t> statusTensor = stateBuf_.AllocTensor<uint32_t>();
264-
statusTensor.SetValue(RANK_ID_OFFSET_IN_STATUS, epRankId_);
265-
statusTensor.SetValue(TOKEN_IDX_OFFSET_IN_STATUS, tkIndex);
266-
GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + srcTokenId * axisK_ * UB_32_ALIGN + srcTopkId * UB_32_ALIGN;
262+
GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN;
267263
GlobalTensor<uint32_t> stateGMTensor;
268264
stateGMTensor.SetGlobalBuffer((__gm__ uint32_t*)stateGM);
269265
DataCopy<uint32_t>(stateGMTensor, statusTensor, FLOAT_NUM_PER_ALIGN);
@@ -277,18 +273,15 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::WaitBuffCopy(ui
277273
GlobalTensor<float> stateGMTensor;
278274
stateGMTensor.SetGlobalBuffer((__gm__ float*)stateGM);
279275
float current = (float)0.0;
280-
float target = (float)1.0 * axisK_ * STATUS_NUM;
281-
SumParams sumPerKParams{axisK_, FLOAT_NUM_PER_ALIGN, STATUS_NUM};
282-
SumParams sumTokenParams{1, axisK_, axisK_};
276+
float target = (float)1.0 * axisK_ * FLOAT_NUM_PER_ALIGN;
277+
SumParams sumPerKParams{1, calCount, calCount};
283278
LocalTensor<float> stateTensorLocal = stateBuf_.Get<float>();
284279
LocalTensor<float> tempStateTensorLocal = tempStateBuf_.Get<float>();
285280
while (current != target) {
286281
SyncFunc<AscendC::HardEvent::S_MTE2>();
287282
DataCopy<float>(stateTensorLocal, stateGMTensor, calCount);
288283
SyncFunc<AscendC::HardEvent::MTE2_V>();
289284
Sum(tempStateTensorLocal, stateTensorLocal, sumPerKParams);
290-
SyncFunc<AscendC::HardEvent::V_V>();
291-
Sum(tempStateTensorLocal, tempStateTensorLocal, sumTokenParams);
292285
SyncFunc<AscendC::HardEvent::V_S>();
293286
current = tempStateTensorLocal(0);
294287
}
@@ -311,16 +304,14 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::ReadBufferAndWe
311304
const DataCopyExtParams xOutCopyParams{1U, static_cast<uint32_t>(hRecvXTypeLen_), 0U, 0U, 0U};
312305

313306
for (uint32_t topkId = 0U; topkId < axisK_; topkId++) {
314-
uint32_t remoteRankId = stateTensorLocal.GetValue(topkId * FLOAT_NUM_PER_ALIGN + RANK_ID_OFFSET_IN_STATUS);
315-
uint32_t remoteTokenIndex = stateTensorLocal.GetValue(topkId * FLOAT_NUM_PER_ALIGN + TOKEN_IDX_OFFSET_IN_STATUS);
316307
float scale = topkWeightsLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId);
317-
GM_ADDR remoteTokenAddr = (__gm__ uint8_t*)(GetBufferAddrByRankId(remoteRankId)) + remoteTokenIndex * h512AlignRecvXLen_;
318-
GlobalTensor<XType> remoteTokenATensor;
319-
remoteTokenATensor.SetGlobalBuffer((__gm__ XType*)remoteTokenAddr);
308+
GM_ADDR localTokenAddr = localRankGM_ + (tokenIndex * axisK_ + topkId) * h512AlignRecvXLen_;
309+
GlobalTensor<XType> localTokenTensor;
310+
localTokenTensor.SetGlobalBuffer((__gm__ XType*)localTokenAddr);
320311

321312
LocalTensor<XType> tmpToken = weightedSumQueue_.AllocTensor<XType>();
322313
const DataCopyPadExtParams<RecvXType> copyPadExtParams{false, 0U, 0U, 0U};
323-
DataCopyPad(tmpToken, remoteTokenATensor, xOutCopyParams, copyPadExtParams);
314+
DataCopyPad(tmpToken, localTokenTensor, xOutCopyParams, copyPadExtParams);
324315
weightedSumQueue_.EnQue(tmpToken);
325316
tmpToken = weightedSumQueue_.DeQue<XType>();
326317
Cast(tokenFloatLocal, tmpToken, AscendC::RoundMode::CAST_NONE, axisH_);
@@ -383,4 +374,4 @@ __aicore__ inline void CamMoeCombineNormal<TemplateMC2TypeFunc>::Process()
383374
}
384375

385376
} // CamMoeCombineNormalImpl
386-
#endif // MOE_COMBINE_IMPL_H
377+
#endif // MOE_COMBINE_IMPL_H

0 commit comments

Comments
 (0)