[CPU] SDPA supports multi-query and different input layout (#21513)
zhangYiIntel authored Dec 15, 2023
1 parent eff9ba7 commit 17fb201
Showing 11 changed files with 898 additions and 42 deletions.
@@ -139,7 +139,11 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
auto q_len = query.size(2);
auto S = query.size(3);
auto kv_len = present_key.size(2);

auto h_group_num = present_key.size(1);
size_t h_each_group_len = 1;
if (h_group_num != H) {
h_each_group_len = H / h_group_num;
}
if (d_scale == 0.0f)
d_scale = 1.0f / sqrt(S);
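
The new h_group_num / h_each_group_len logic maps each of the H query heads onto one of the h_group_num key/value head groups, which is what enables the multi-query / grouped-query path. A minimal standalone sketch of that head-to-group mapping, with illustrative head counts that are not taken from the commit:

#include <cstddef>
#include <cstdio>

int main() {
    const size_t H = 32;           // query heads (illustrative)
    const size_t h_group_num = 8;  // key/value heads, i.e. groups (illustrative)
    size_t h_each_group_len = 1;
    if (h_group_num != H)
        h_each_group_len = H / h_group_num;  // here: 4 query heads share one KV head

    // Every query head h reads the KV cache entry of group h / h_each_group_len.
    for (size_t h = 0; h < H; h++)
        std::printf("query head %zu -> kv group %zu\n", h, h / h_each_group_len);
    return 0;
}
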

@@ -149,20 +149,21 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,

bool is_abcd = present_key.stride(1) >= present_key.stride(2);
size_t dim0 = is_abcd ? B : kv_len;
size_t dim1 = is_abcd ? H : B;
size_t dim2 = is_abcd ? kv_len : H;
size_t dim1 = is_abcd ? h_group_num : B;
size_t dim2 = is_abcd ? kv_len : h_group_num;

parallel_for3d(dim0, dim1, dim2, [&](size_t d0, size_t d1, size_t d2) {
size_t b = is_abcd ? d0 : d1;
size_t h = is_abcd ? d1 : d2;
size_t h_group = is_abcd ? d1 : d2;
size_t pk = is_abcd ? d2 : d0;

// which batch item should be used at position pk?
auto b_kv = beams ? beams.at<int32_t>({b, pk}) : b;
for (size_t pq = 0; pq < q_len; pq++) {
buf_attn_w.at<float>({b, h, pq, pk}) = dot_product(&query.at<T>({b, h, pq, 0}),
&present_key.at<T2>({b_kv, h, pk, 0}, true),
S);
for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
buf_attn_w.at<float>({b, h, pq, pk}) =
dot_product(&query.at<T>({b, h, pq, 0}), &present_key.at<T2>({b_kv, h_group, pk, 0}, true), S);
}
}
});
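
The is_abcd test infers the physical layout of the KV cache from its strides: when stride(1) >= stride(2) the cache is batch/head-major (a [B,H,L,S]-style allocation), otherwise it is length-major (a [L,B,H,S]-style allocation), and the parallel_for3d loop order is chosen to match, so the slow-moving parallel dimension follows the slow-moving memory dimension. A standalone sketch of that decision with made-up strides:

#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative strides (in elements) for a [B, H, kv_len, S] = [2, 8, 16, 128] cache.
    // BHLS physical layout: strides = {H*L*S, L*S, S, 1}
    std::array<size_t, 4> stride_bhls = {8 * 16 * 128, 16 * 128, 128, 1};
    // LBHS physical layout (length outermost): strides = {H*S, S, B*H*S, 1}
    std::array<size_t, 4> stride_lbhs = {8 * 128, 128, 2 * 8 * 128, 1};

    for (const auto& stride : {stride_bhls, stride_lbhs}) {
        const bool is_abcd = stride[1] >= stride[2];
        // is_abcd: iterate (B, h_group, kv_len); otherwise (kv_len, B, h_group).
        std::printf("stride(1)=%zu stride(2)=%zu -> is_abcd=%d\n", stride[1], stride[2], (int)is_abcd);
    }
    return 0;
}
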

@@ -190,29 +195,31 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
// buf_attn_w {B, H, q_len, kv_len}
parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) {
size_t start{0}, end{0};
splitter(B * H * kv_len, nthr, ithr, start, end);
splitter(B * h_group_num * kv_len, nthr, ithr, start, end);

memset(&buf_attn_score.at<float>({ithr, 0, 0, 0, 0}), 0, buf_attn_score.stride(0) * sizeof(float));

size_t b, h, pv;
size_t b, h_group, pv;
if (start < end) {
if (is_abcd)
parallel_it_init(start, b, B, h, H, pv, kv_len);
parallel_it_init(start, b, B, h_group, h_group_num, pv, kv_len);
else
parallel_it_init(start, pv, kv_len, b, B, h, H);
parallel_it_init(start, pv, kv_len, b, B, h_group, h_group_num);
for (size_t iwork = start; iwork < end; ++iwork) {
auto b_kv = beams ? beams.at<int32_t>({b, pv}) : b;
auto* v = &present_value.at<T2>({b_kv, h, pv, 0}, true);
auto* v = &present_value.at<T2>({b_kv, h_group, pv, 0}, true);
for (size_t pq = 0; pq < q_len; pq++) {
attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
buf_attn_w.at<float>({b, h, pq, pv}),
v,
S);
for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
buf_attn_w.at<float>({b, h, pq, pv}),
v,
S);
}
}
if (is_abcd)
parallel_it_step(b, B, h, H, pv, kv_len);
parallel_it_step(b, B, h_group, h_group_num, pv, kv_len);
else
parallel_it_step(pv, kv_len, b, B, h, H);
parallel_it_step(pv, kv_len, b, B, h_group, h_group_num);
}
}
});
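
The value-accumulation pass flattens (B, h_group, kv_len) — or (kv_len, B, h_group) for the length-major layout — into one index range, splits it evenly across threads, and lets each thread accumulate every query head of its group into a private buf_attn_score slice. A sketch of that flat-index partitioning; split_range is a stand-in for the splitter/parallel_it_init helpers, not the actual implementation:

#include <cstddef>
#include <cstdio>

// Even split of [0, total) across nthr workers (sketch of what splitter() provides).
static void split_range(size_t total, size_t nthr, size_t ithr, size_t& start, size_t& end) {
    const size_t chunk = total / nthr;
    const size_t rem = total % nthr;
    start = ithr * chunk + (ithr < rem ? ithr : rem);
    end = start + chunk + (ithr < rem ? 1 : 0);
}

int main() {
    const size_t B = 2, h_group_num = 8, kv_len = 16, nthr = 4;  // illustrative sizes
    for (size_t ithr = 0; ithr < nthr; ithr++) {
        size_t start = 0, end = 0;
        split_range(B * h_group_num * kv_len, nthr, ithr, start, end);
        // Decompose the first flat index back into (b, h_group, pv) for the B-major ordering.
        const size_t b = start / (h_group_num * kv_len);
        const size_t h_group = (start / kv_len) % h_group_num;
        const size_t pv = start % kv_len;
        std::printf("thread %zu: [%zu, %zu) starts at b=%zu h_group=%zu pv=%zu\n",
                    ithr, start, end, b, h_group, pv);
    }
    return 0;
}
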
7 changes: 4 additions & 3 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
@@ -548,6 +548,7 @@ void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {

// Since this is a very specialized implementation, let's mimic SDPA precision and set cabd layout
precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
// Just use a placeholder here; the actual layout is obtained in initOptimalPrimitiveDescriptor
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});

PortConfig outPortConfig;
@@ -573,7 +574,6 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
"failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");

const auto& childConfig = childPd->getConfig();
auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision();

auto selectedPd = getSelectedPrimitiveDescriptor();
OPENVINO_ASSERT(selectedPd,
@@ -582,8 +582,9 @@
" failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");

auto config = selectedPd->getConfig();
auto memDesc = config.outConfs.front().getMemDesc();
auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision);
// The physical layout varies across models, e.g. [LBHS] chatglm, [BHLS] Llama
// The SDPA node knows the details, so trust the layout config provided by SDPA
auto newMemDesc = childConfig.inConfs.back().getMemDesc();
config.outConfs.front().setMemDesc(newMemDesc);
//bypass any checks, we enforce the child descriptor precision
selectedPd->setConfig(config);
52 changes: 41 additions & 11 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -512,10 +512,16 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
v_input.assert_dims({B, 0, L1, S}, true);
auto past_k_idx = inputs.size() - 2;
auto past_k_mem = inputs[past_k_idx + 0];
L0 = past_k_mem->getStaticDims()[2];
const auto& permute_axes = config.config.permute_axes;
L0 = permute_axes.empty() ? past_k_mem->getStaticDims()[2] : past_k_mem->getStaticDims()[permute_axes[2]];
// [B, H, L0, S]
past_k_output.reset(outputs[1]);
past_v_output.reset(outputs[2]);
if (!permute_axes.empty()) {
// [L, B, H, S] -> [B, H, L, S]
past_k_output = past_k_output.permute(permute_axes);
past_v_output = past_v_output.permute(permute_axes);
}
attn_memcpy(k_input, v_input, past_k_output.slice(2, L0, L0 + L1), past_v_output.slice(2, L0, L0 + L1));
if (!config.is_concat_inplaced) {
PlainTensor past_k_input, past_v_input;
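
concat_pastkv writes the freshly projected keys/values into the cache right after the L0 tokens already stored; slice(2, L0, L0 + L1) selects that writable window on the length axis. A generic sketch of the same append-at-offset idea on a flat [B, H, L_max, S] buffer (all shapes illustrative, PlainTensor not used):

#include <cstddef>
#include <vector>

int main() {
    // Illustrative cache of shape [B, H, L_max, S], stored contiguously.
    const size_t B = 1, H = 2, L_max = 16, S = 4;
    std::vector<float> cache(B * H * L_max * S, 0.0f);

    const size_t L0 = 3;  // tokens already in the cache
    const size_t L1 = 2;  // new tokens produced by this step
    std::vector<float> k_new(B * H * L1 * S, 1.0f);  // freshly projected keys

    // Copy the new keys into positions [L0, L0 + L1) of the length axis.
    for (size_t b = 0; b < B; b++)
        for (size_t h = 0; h < H; h++)
            for (size_t l = 0; l < L1; l++)
                for (size_t s = 0; s < S; s++)
                    cache[((b * H + h) * L_max + (L0 + l)) * S + s] =
                        k_new[((b * H + h) * L1 + l) * S + s];
    return 0;
}
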
@@ -560,12 +566,18 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
}

// q: [B, H, L1, S]
const auto & permute_axes = config.config.permute_axes;

PlainTensor present_key, present_value;
if (!permute_axes.empty()) {
q_input = q_input.permute(permute_axes);
k_input = k_input.permute(permute_axes);
v_input = v_input.permute(permute_axes);
}
B = q_input.size(0);
H = q_input.size(1);
L1 = q_input.size(2);
S = q_input.size(-1);

PlainTensor present_key, present_value;
concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);

ov::intel_cpu::PlainTensor output_emb(outputs[0]);
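
permute_axes records where each logical [B, H, L, S] dimension lives in the stored tensor, so permuting a view only remaps dimensions and strides without moving data. A standalone sketch of that remapping on a plain strided buffer (shapes illustrative; PlainTensor itself is not used):

#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    // Physical layout [L, B, H, S] = [4, 2, 3, 5], as in a chatglm-style cache.
    const std::array<size_t, 4> phys_dims = {4, 2, 3, 5};
    const std::array<size_t, 4> phys_strides = {2 * 3 * 5, 3 * 5, 5, 1};

    // permute_axes = {1, 2, 0, 3}: logical [B, H, L, S] dim i is physical dim permute_axes[i].
    const std::array<size_t, 4> permute_axes = {1, 2, 0, 3};
    std::array<size_t, 4> log_dims{}, log_strides{};
    for (size_t i = 0; i < 4; i++) {
        log_dims[i] = phys_dims[permute_axes[i]];
        log_strides[i] = phys_strides[permute_axes[i]];
    }

    // Address element (b, h, l, s) of the logical [B, H, L, S] view.
    const size_t b = 1, h = 2, l = 3, s = 4;
    const size_t offset = b * log_strides[0] + h * log_strides[1] + l * log_strides[2] + s * log_strides[3];
    std::printf("logical dims [B,H,L,S] = [%zu,%zu,%zu,%zu], offset of (1,2,3,4) = %zu\n",
                log_dims[0], log_dims[1], log_dims[2], log_dims[3], offset);
    return 0;
}
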
@@ -634,9 +646,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
rtPrecision = getOriginalInputPrecisionAtPort(0);
auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0);

size_t H_idx = 1;
if (!m_config.config.permute_axes.empty()) {
H_idx = m_config.config.permute_axes[1];
}
const auto& qDims = getInputShapeAtPort(0).getDims();
const auto& kDims = getInputShapeAtPort(1).getDims();
// if multi-query, enforce fp32 (TODO: support BF16)
if (qDims[H_idx] != kDims[H_idx]) {
rtPrecision = ov::element::f32;
}

bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && rtPrecision != ov::element::bf16;

auto kvCachePrecision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision;
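
Multi-query models are detected by comparing the head-count axis of the query and key input shapes; when they differ, the node currently falls back to f32. A hedged sketch of that check with assumed shapes and a hypothetical is_multi_query helper:

#include <cstddef>
#include <vector>

// Pick the head-count axis of the input shapes and decide whether the
// multi-query (grouped) path is taken, which currently forces f32 inference.
bool is_multi_query(const std::vector<size_t>& q_dims,
                    const std::vector<size_t>& k_dims,
                    const std::vector<size_t>& permute_axes) {
    const size_t H_idx = permute_axes.empty() ? 1 : permute_axes[1];
    return q_dims[H_idx] != k_dims[H_idx];
}

int main() {
    // [B, H, L, S] inputs: 32 query heads vs. 8 KV heads -> multi-query.
    const bool mq = is_multi_query({1, 32, 7, 128}, {1, 8, 7, 128}, {});
    // chatglm-style [L, B, H, S] inputs with permute_axes {1, 2, 0, 3}: H lives at index 2.
    const bool mq_lbhs = is_multi_query({7, 1, 32, 128}, {7, 1, 8, 128}, {1, 2, 0, 3});
    return (mq && mq_lbhs) ? 0 : 1;
}
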
@@ -669,17 +692,25 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
}

if (m_config.config.fuse_concat) {
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});

config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc(
ArbitraryOrderDescCreator layoutDescCreator({2, 0, 1, 3});
const auto& permute_axes = m_config.config.permute_axes;
if (!permute_axes.empty()) {
// [L,B,H,S] -> permute [1,2,0,3] -> [B,H,L,S]
// The actual index of B is permute[0], of H is permute[1], of L is permute[2], of S is permute[3]
layoutDescCreator = ArbitraryOrderDescCreator({static_cast<size_t>(permute_axes[2]),
static_cast<size_t>(permute_axes[0]),
static_cast<size_t>(permute_axes[1]),
static_cast<size_t>(permute_axes[3])});
}
config.inConfs[orginSDPInputNumber + 0].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
config.inConfs[orginSDPInputNumber + 1].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));

config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
config.outConfs[1].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getOutputShapeAtPort(1)));
config.outConfs[1].inPlace(orginSDPInputNumber + 0);
config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
config.outConfs[2].setMemDesc(layoutDescCreator.createSharedDesc(
kvCachePrecision, getOutputShapeAtPort(2)));
config.outConfs[2].inPlace(orginSDPInputNumber + 1);
}
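
With fuse_concat, the KV-cache descriptor order is no longer the hard-coded cabd {2, 0, 1, 3}; it is derived from permute_axes so the cache keeps the model's own physical layout. A small check of that derivation for the two layouts mentioned in the comments (values illustrative):

#include <array>
#include <cstddef>
#include <cstdio>

static void print_order(const char* name, const std::array<size_t, 4>& o) {
    std::printf("%s: {%zu, %zu, %zu, %zu}\n", name, o[0], o[1], o[2], o[3]);
}

int main() {
    // Default [B, H, L, S] input: hard-coded cabd order {2, 0, 1, 3} keeps L outermost.
    print_order("default", {2, 0, 1, 3});

    // chatglm-style [L, B, H, S] input, permute_axes = {1, 2, 0, 3}:
    // order = {permute[2], permute[0], permute[1], permute[3]} = {0, 1, 2, 3},
    // i.e. the cache keeps the model's own length-outermost physical layout.
    const std::array<size_t, 4> p = {1, 2, 0, 3};
    print_order("chatglm", {p[2], p[0], p[1], p[3]});
    return 0;
}
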
@@ -712,7 +743,6 @@ void ScaledDotProductAttention::createPrimitive() {

m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
}
auto rtPrecision = getOriginalInputPrecisionAtPort(0);

if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -54,6 +54,7 @@ class ScaledDotProductAttention : public Node {
Config m_config;
std::shared_ptr<Executor> m_executor;
template <KernelTypes KType, typename T> struct AttentionExecutor;
ov::element::Type rtPrecision;
};

} // namespace node
@@ -29,18 +29,43 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_typ
// [B, H, L0, S]
auto past_kv_ps = get_input_partial_shape(input_num - 1);

auto output_logits = q_ps;
NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false);
NODE_VALIDATION_CHECK(this, q_ps.size() >= 3);
// permute_axes from original to [B, H, L, S]
const auto& permute_axes = this->m_config.permute_axes;
if (past_kv_ps.rank().is_static()) {
const size_t length_index = permute_axes.empty() ? q_ps.size() - 2 : permute_axes[permute_axes.size() - 2];
const size_t head_num_index = permute_axes.empty() ? q_ps.size() - 3 : permute_axes[permute_axes.size() - 3];
NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size());
for (size_t i = 0; i < q_ps.size(); i++) {
if (i == q_ps.size() - 2)
if (i == head_num_index) {
if (q_ps[i].is_static() && past_kv_ps[i].is_static()) {
NODE_VALIDATION_CHECK(this,
q_ps[i].get_length() % past_kv_ps[i].get_length() == 0,
"shape not compatiable at index ",
i);
}
} else if (i == length_index) {
continue;
NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i]));
} else {
NODE_VALIDATION_CHECK(this,
q_ps[i].compatible(past_kv_ps[i]),
"shape not compatiable at index ",
i);
}
}
past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2];
past_kv_ps[length_index] += q_ps[length_index];
}
set_output_type(0, get_input_element_type(0), q_ps);
if (!permute_axes.empty()) {
if (q_ps.rank().is_static()) {
// q_ps needs to be permuted to [B, H, L, S]
for (size_t i = 0; i < q_ps.size(); i++) {
output_logits[i] = q_ps[permute_axes[i]];
}
}
}
set_output_type(0, get_input_element_type(0), output_logits);
set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps);
set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps);
}
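
The updated shape inference checks that the query head count is a multiple of the KV head count, skips the length axis when comparing shapes, grows the past-KV length by the query length, and permutes the logits shape back to [B, H, L, S]. A worked example of those rules under assumed chatglm-style shapes:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    // Assumed shapes: q = [L=7, B=2, H=32, S=128], past_kv = [16, 2, 8, 128],
    // permute_axes = {1, 2, 0, 3}, so length_index = permute[2] = 0 and head_num_index = permute[1] = 2.
    std::vector<size_t> q = {7, 2, 32, 128}, past_kv = {16, 2, 8, 128};
    const std::vector<size_t> permute = {1, 2, 0, 3};
    const size_t length_index = permute[permute.size() - 2];
    const size_t head_num_index = permute[permute.size() - 3];

    assert(q[head_num_index] % past_kv[head_num_index] == 0);  // 32 query heads over 8 KV heads
    past_kv[length_index] += q[length_index];                  // cache length grows to 16 + 7 = 23

    // Output logits are the query shape permuted to [B, H, L, S] = [2, 32, 7, 128].
    std::vector<size_t> logits(4);
    for (size_t i = 0; i < 4; i++)
        logits[i] = q[permute[i]];

    assert(past_kv[length_index] == 23 && logits[0] == 2 && logits[2] == 7);
    return 0;
}
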
@@ -52,6 +77,7 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A
visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn);
visitor.on_attribute("is_causal", m_config.is_causal);
visitor.on_attribute("fuse_concat", m_config.fuse_concat);
visitor.on_attribute("permute_axes", m_config.permute_axes);
visitor.finish_structure();
return true;
}
}
@@ -21,11 +21,13 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op {
ScaledDotProductAttentionWithKVCache() = default;

struct Config {
bool output_BLHxS = false; // true implies that output is [B,L,H*S]
bool output_BLHxS = false; // true implies that output is [B,L,H*S]

bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask
bool is_causal = false; // apply causal mask internally
bool fuse_concat = false; // fuse (concat->sdp) ==> sdp
bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask
bool is_causal = false; // apply causal mask internally
bool fuse_concat = false; // fuse (concat->sdp) ==> sdp
std::vector<size_t> permute_axes; // non-empty means the input is transposed; the permuted output is [B,H,L,S]
// e.g. [L,B,H,S] -> permute [1, 2, 0, 3] -> [B,H,L,S]
};

ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg);
@@ -47,4 +49,4 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op {
};

} // namespace intel_cpu
} // namespace ov
} // namespace ov