Skip to content

在scaled_dot_product_attention函数中,加入3D的输入和输出 #7353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ scaled_dot_product_attention
参数
::::::::::

- **query** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
- **key** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
- **value** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
- **query** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
- **key** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
- **value** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
- **attn_mask** (Tensor, 可选) - 与添加到注意力分数的 ``query``、 ``key``、 ``value`` 类型相同的浮点掩码, 默认值为空。
- **dropout_p** (float) - ``dropout`` 的比例, 默认值为 0.00 即不进行正则化。
- **is_causal** (bool) - 是否启用因果关系, 默认值为 False 即不启用。
Expand All @@ -30,7 +30,7 @@ scaled_dot_product_attention
返回
::::::::::

- ``out`` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim]`` 的 4 维张量。数据类型可以是 float16 或 bfloat16。
- ``out`` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim]`` 的 4 维张量或者形状为 ``[seq_len, num_heads, head_dim]`` 的 3 维张量。数据类型可以是 float16 或 bfloat16。


代码示例
Expand Down