Skip to content

Commit

Permalink
simplify stride generation
Browse files Browse the repository at this point in the history
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Nov 14, 2023
1 parent 02c4797 commit d05c7ef
Showing 1 changed file with 31 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(

namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
Expand Down Expand Up @@ -120,30 +120,27 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;

int64_t q_stride[4];
int64_t k_stride[4];
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride,
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride,
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride,
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
std::vector<int64_t> q_strides(q_stride, q_stride + 4);
std::vector<int64_t> k_strides(k_stride, k_stride + 4);
std::vector<int64_t> v_strides(v_stride, v_stride + 4);
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_strides));
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_strides));
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_strides));
.set_stride(v_stride));

attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
Expand Down Expand Up @@ -183,8 +180,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
scaled_dot_product_flash_attention_options.set_padding_mask(is_padding)
.set_seq_len_q(seq_q)
.set_seq_len_kv(seq_kv);
.set_seq_len_q(seq_q)
.set_seq_len_kv(seq_kv);
}

if (is_dropout) {
Expand All @@ -205,11 +202,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto [O, Stats] = mha_graph->scaled_dot_product_flash_attention(
Q, K, V, scaled_dot_product_flash_attention_options);

int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride,
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
std::vector<int64_t> o_strides(o_stride, o_stride + 4);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_strides);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);

if (is_training) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
Expand Down Expand Up @@ -397,42 +393,38 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;

int64_t q_stride[4];
int64_t k_stride[4];
int64_t v_stride[4];
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride,
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride,
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride,
generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride,
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
std::vector<int64_t> q_strides(q_stride, q_stride + 4);
std::vector<int64_t> k_strides(k_stride, k_stride + 4);
std::vector<int64_t> v_strides(v_stride, v_stride + 4);
std::vector<int64_t> o_strides(o_stride, o_stride + 4);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d})
.set_stride(q_strides));
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d})
.set_stride(k_strides));
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d})
.set_stride(v_strides));
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
.set_stride(o_strides));
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
.set_stride(o_strides));
.set_stride(o_stride));
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
Expand Down Expand Up @@ -500,13 +492,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

dQ->set_output(true)
.set_dim({b, h, s_q, d})
.set_stride(q_strides);
.set_stride(q_stride);
dK->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(k_strides);
.set_stride(k_stride);
dV->set_output(true)
.set_dim({b, hg, s_kv, d})
.set_stride(v_strides);
.set_stride(v_stride);

std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
Expand Down

0 comments on commit d05c7ef

Please sign in to comment.