Skip to content

Commit 240e668

Browse files
committed
support cp
1 parent 75dfe98 commit 240e668

37 files changed

+491
-54
lines changed

lmdeploy/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ class TurbomindEngineConfig:
233233
dp: int = 1
234234
device_num: int = None
235235
attn_tp_size: int = None
236+
attn_cp_size: int = None
236237
attn_dp_size: int = None
237238
mlp_tp_size: int = None
238239
mlp_dp_size: int = None

lmdeploy/turbomind/deploy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class ModelConfig:
6767
weight_type: str = None
6868
session_len: int = None
6969
attn_tp_size: int = 1
70+
attn_cp_size: int = 1
7071
mlp_tp_size: int = 1
7172
model_format: str = 'hf'
7273
expert_num: List[int] = ()

lmdeploy/turbomind/deploy/converter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,7 @@ def get_tm_model(model_path,
173173
model_cls=Transformer,
174174
out_dir=out_dir)
175175

176+
engine_config.attn_tp_size = output_model.tm_config.model_config.attn_tp_size
177+
engine_config.attn_cp_size = output_model.tm_config.model_config.attn_cp_size
178+
176179
return output_model

lmdeploy/turbomind/deploy/target_model/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model
5252
self.attention_config = cfg.attention_config
5353
self.lora_config = cfg.lora_config
5454
self.attn_tp_size = self.model_config.attn_tp_size
55+
self.attn_cp_size = self.model_config.attn_cp_size
5556
self.mlp_tp_size = self.model_config.mlp_tp_size
5657
self.out_dir = out_dir
5758
self.to_file = True if out_dir else False
@@ -74,8 +75,14 @@ def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model
7475
self.repeat_kv = 0
7576
if (self.attn_tp_size > self.model_config.kv_head_num
7677
and self.attn_tp_size % self.model_config.kv_head_num == 0):
77-
self.repeat_kv = (self.attn_tp_size // self.model_config.kv_head_num)
78-
self.model_config.kv_head_num = self.attn_tp_size
78+
self.attn_cp_size = self.attn_tp_size // self.model_config.kv_head_num
79+
self.attn_tp_size //= self.attn_cp_size
80+
self.model_config.attn_tp_size = self.attn_tp_size
81+
self.model_config.attn_cp_size = self.attn_cp_size
82+
# if (self.attn_tp_size > self.model_config.kv_head_num
83+
# and self.attn_tp_size % self.model_config.kv_head_num == 0):
84+
# self.repeat_kv = (self.attn_tp_size // self.model_config.kv_head_num)
85+
# self.model_config.kv_head_num = self.attn_tp_size
7986

8087
self.model_config.verify()
8188
assert self.model_config.kv_head_num % self.attn_tp_size == 0

src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class CudaIpcCommImpl: public DeviceCommImpl {
7777
DataType type,
7878
int group0,
7979
int group1,
80+
int cp_size,
8081
const int* local_token_nums,
8182
cudaStream_t stream) override;
8283

src/turbomind/comm/cuda_ipc/fused_allreduce_ex.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,11 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnormEx(void* hidden,
189189
DataType dtype,
190190
int group0,
191191
int group1,
192+
int cp_size,
192193
const int* local_token_nums,
193194
cudaStream_t stream)
194195
{
196+
FT_CHECK(cp_size == 1);
195197
FT_CHECK(group0 * group1 == 0);
196198

197199
const auto& g0 = groups_.at(group0);

src/turbomind/comm/device_comm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class DeviceCommImpl {
8787
DataType type,
8888
int group0,
8989
int group1,
90+
int cp_size,
9091
const int* local_token_nums,
9192
cudaStream_t stream)
9293
{

src/turbomind/comm/nccl/nccl.cu

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ public:
237237
DataType type,
238238
int group0,
239239
int group1,
240+
int cp_size,
240241
const int* local_token_nums,
241242
cudaStream_t stream) override
242243
{
@@ -252,9 +253,8 @@ public:
252253
NCCLCHECK(ncclCommCount(comm0, &tp0));
253254
NCCLCHECK(ncclCommCount(comm1, &tp1));
254255

255-
const int inner_tp = std::min(tp0, tp1);
256-
257-
FT_CHECK(tp0 % inner_tp == 0 && tp1 % inner_tp == 0);
256+
FT_CHECK(std::max(tp0, tp1) % std::min(tp0, tp1) == 0);
257+
const int inner_tp = std::min(tp0, tp1) * cp_size;
258258

259259
std::vector<std::tuple<int, int, int>> tasks;
260260
tasks.reserve(global_n_ranks_);
@@ -289,7 +289,18 @@ public:
289289
sync_check_cuda_error();
290290
}
291291

292-
if (tp1 > 1) {
292+
if (cp_size > 1 && tp0 > tp1) {
293+
NCCLCHECK(ncclGroupStart());
294+
for (int i = 0; i < global_n_ranks_; ++i) {
295+
if (auto& [offset, first, num] = tasks[i]; num > 0) {
296+
char* buff = (char*)hidden + elem_size * (offset + first) * dim;
297+
NCCLCHECK(ncclBroadcast(buff, buff, (size_t)num * dim, nccl_type, i % tp0, comm0, stream));
298+
}
299+
}
300+
NCCLCHECK(ncclGroupEnd());
301+
sync_check_cuda_error();
302+
}
303+
else if (tp1 > 1) {
293304
NCCLCHECK(ncclGroupStart());
294305
for (int i = 0; i < global_n_ranks_; ++i) {
295306
if (auto& [offset, first, num] = tasks[i]; num > 0) {

src/turbomind/comm/test_comm.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,7 @@ struct TestComm {
796796
dtype,
797797
group0,
798798
group1,
799+
1,
799800
local_token_nums.data(),
800801
stream);
801802
});

src/turbomind/kernels/attention/attention_params.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ struct AttentionParams {
7575
float* partial_L;
7676
int* locks;
7777

78+
// cp
79+
int cp_rank{0};
80+
int cp_size{1};
81+
float* cp_O{nullptr};
82+
float* cp_M{nullptr};
83+
float* cp_L{nullptr};
84+
7885
int arch;
7986
cudaStream_t stream;
8087

0 commit comments

Comments
 (0)