Commit

Add missed doc string for block_multihead_attention API (#60072) (#60139)

* add missed doc  test=document_fix

* test=document_fix
RichardWooSJTU authored Dec 19, 2023
1 parent 78c5e68 commit 1115890
Showing 1 changed file with 23 additions and 1 deletion.
@@ -67,13 +67,27 @@ def block_multihead_attention(
cu_seqlens_k (Tensor): The cumulative sequence lengths of key. Its shape is [batchsize + 1, 1].
block_tables (Tensor): The block tables, used to index the cache. Its shape is [batchsize, block_num_per_seq].
pre_key_cache (Tensor): The pre caches of key. Its shape is [batchsize, num_head, pre_cache_length, head_size].
pre_key_value (Tensor): The pre caches of value. Its shape is [batchsize, num_head, pre_cache_length, head_size].
pre_value_cache (Tensor): The pre caches of value. Its shape is [batchsize, num_head, pre_cache_length, head_size].
cache_k_quant_scales (Tensor): The quant scales of cache key. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
cache_v_quant_scales (Tensor): The quant scales of cache value. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
cache_k_dequant_scales (Tensor): The dequant scales of cache key. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
cache_v_dequant_scales (Tensor): The dequant scales of cache value. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
qkv_out_scale (Tensor): The dequant scale of qkv, which is the input of BLHA. If the dtype of qkv is `int32`, this input will be applied. Its shape is [3 * num_head * head_size], and its dtype should be `float32`.
qkv_bias (Tensor): The bias of qkv. Its shape is [3 * num_head * head_size].
out_shift (Tensor): Shift bias of fmha_out, which is the 1st return value. Its shape is [num_head * head_size].
out_smooth (Tensor): Smooth weight of fmha_out. Its shape is [num_head * head_size].
rope_emb (Tensor): The RoPE embedding. Its shape is [2, batchsize, max_seq_len, 1, head_size // 2].
mask (Tensor): The mask of qk_matmul in encoder. Its shape is [batchsize, 1, max_seq_len, max_seq_len].
tgt_mask (Tensor): The mask of qk_matmul in decoder. Its shape is [batchsize, 1, 1, max_seq_len].
max_seq_len (Int): The max length of the input. Default is -1.
block_size (Int): The block_size of cache. Default is 64.
use_neox_style (Bool): Whether neox_style RoPE is used or not. Default is False.
use_dynamic_cachekv_quant (Bool): Whether dynamic cache kv quantization is applied or not. Default is False.
quant_round_type (Int): The quant round type used in cache kv quantization and fmha_out quantization. If set to 0, values are rounded to the nearest integer with ties to even. If set to 1, values are rounded to the nearest integer with ties away from zero.
quant_max_bound (Float32): The max bound used when quantizing from float type to int type.
quant_min_bound (Float32): The min bound used when quantizing from float type to int type.
out_scale (Float32): The quant scale of fmha_out. Default is -1, which means do not apply quantization for fmha_out.
compute_dtype (Str): The compute dtype, used to represent the input data type. Default is "default", which means the compute dtype is determined by the input dtype. However, if the input dtype is `int32`, this value should be set to the actual dtype of the model.
Returns:
Tensor|(output, qkv_out, cache_k_out, cache_v_out), where output is the output of the
block_multihead_attention layer, qkv_out is in-place with input `qkv`, and cache_k_out and cache_v_out are in-place with inputs `cache_k` and `cache_v`.
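
As a side note on the cache KV quantization parameters documented above (cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales): their shapes follow use_dynamic_cachekv_quant. A minimal sketch of allocating them, where batchsize, num_head and the float32 dtype are illustrative assumptions rather than values taken from the diff:

    import paddle

    batchsize, num_head = 2, 8          # illustrative sizes (assumptions)
    use_dynamic_cachekv_quant = True    # dynamic: per-batch, per-head scales

    if use_dynamic_cachekv_quant:
        scale_shape = [batchsize, num_head]   # dynamic quantization
    else:
        scale_shape = [num_head]              # static quantization

    # One scale tensor each for quant and dequant, key and value.
    # The float32 dtype here is an assumption.
    cache_k_quant_scales = paddle.ones(scale_shape, dtype="float32")
    cache_v_quant_scales = paddle.ones(scale_shape, dtype="float32")
    cache_k_dequant_scales = paddle.ones(scale_shape, dtype="float32")
    cache_v_dequant_scales = paddle.ones(scale_shape, dtype="float32")

As in the example further down the diff, all four scale arguments can simply be passed as None when cache KV quantization is not used.
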
@@ -229,6 +243,14 @@ def block_multihead_attention(
... block_tables,
... None, # pre_key_cache
... None, # pre_value_cache
... None, # cache_k_quant_scales
... None, # cache_v_quant_scales
... None, # cache_k_dequant_scales
... None, # cache_v_dequant_scales
... None, # qkv_out_scale
... None, # qkv_bias
... None, # out_shift
... None, # out_smooth
... None, # rotary_embs
... None, # attn_mask
... None, # tgt_mask
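For the rounding and clipping behaviour described by quant_round_type, quant_max_bound and quant_min_bound in the docstring above, a small NumPy sketch follows. How out_scale is combined with fmha_out inside the kernel is an assumption here, not something stated in the diff, and quantize_sketch is a hypothetical helper, not part of the API:

    import numpy as np

    def quantize_sketch(x, out_scale, quant_round_type, quant_max_bound, quant_min_bound):
        # Assumption: out_scale is applied as a divisor before rounding.
        scaled = np.asarray(x, dtype=np.float32) / out_scale
        if quant_round_type == 0:
            rounded = np.rint(scaled)                                   # ties to even
        else:
            rounded = np.sign(scaled) * np.floor(np.abs(scaled) + 0.5)  # ties away from zero
        # Clip the result into the range given by the bounds.
        return np.clip(rounded, quant_min_bound, quant_max_bound)

    # Illustrative int8-style bounds (assumptions, not defaults from the API).
    print(quantize_sketch([0.5, 1.5, 2.5], out_scale=1.0,
                          quant_round_type=0, quant_max_bound=127.0, quant_min_bound=-127.0))
    # -> [0. 2. 2.]  (ties to even)
    print(quantize_sketch([0.5, 1.5, 2.5], out_scale=1.0,
                          quant_round_type=1, quant_max_bound=127.0, quant_min_bound=-127.0))
    # -> [1. 2. 3.]  (ties away from zero)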
