diff --git a/docs/api/paddle/nn/functional/scaled_dot_product_attention_cn.rst b/docs/api/paddle/nn/functional/scaled_dot_product_attention_cn.rst index 29106226370..a08d2009aa3 100644 --- a/docs/api/paddle/nn/functional/scaled_dot_product_attention_cn.rst +++ b/docs/api/paddle/nn/functional/scaled_dot_product_attention_cn.rst @@ -17,9 +17,9 @@ scaled_dot_product_attention 参数 :::::::::: - - **query** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。 - - **key** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。 - - **value** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。 + - **query** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。 + - **key** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。 + - **value** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。 - **attn_mask** (Tensor, 可选) - 与添加到注意力分数的 ``query``、 ``key``、 ``value`` 类型相同的浮点掩码, 默认值为空。 - **dropout_p** (float) - ``dropout`` 的比例, 默认值为 0.00 即不进行正则化。 - **is_causal** (bool) - 是否启用因果关系, 默认值为 False 即不启用。 @@ -30,7 +30,7 @@ scaled_dot_product_attention 返回 :::::::::: - - ``out`` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim]`` 的 4 维张量。数据类型可以是 float16 或 bfloat16。 + - ``out`` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim]`` 的 4 维张量或者形状为 ``[seq_len, num_heads, head_dim]`` 的 3 维张量。数据类型可以是 float16 或 bfloat16。 代码示例