
Commit a88bfb1

Reference implementation for triangle updates (#5732)
Other changes:
1. Add TensorView.dtype for convenience.
2. Clean up broadcast_in_dim_fn.
3. Add layernorm to triangle attention.

cc @DejunL

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 2684593 commit a88bfb1

4 files changed

Lines changed: 194 additions & 50 deletions


python/python_direct/ir.cpp

Lines changed: 18 additions & 0 deletions
@@ -231,6 +231,24 @@ Returns
 -------
 list of Val
     The shape of this tensor.
+)")
+      .def(
+          "dtype",
+          [](TensorView* self) -> PrimDataType {
+            DataType dt = self->dtype();
+            NVF_CHECK(
+                std::holds_alternative<PrimDataType>(dt.type),
+                "Expected PrimDataType but got type: ",
+                dt);
+            return std::get<PrimDataType>(dt.type);
+          },
+          R"(
+Get the data type of this tensor.
+
+Returns
+-------
+DataType
+    The data type of this tensor.
 )")
       .def("has_root", &TensorView::hasRoot, R"(
 Check if this tensor has a root domain.
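
A minimal usage sketch of the new binding from the Python side (the shape and assertion are illustrative, not from the commit; the layer_norm helper in the tests below relies on exactly this call to round-trip through Float and cast back):

from nvfuser_direct import FusionDefinition, DataType

with FusionDefinition() as fd:
    t = fd.define_tensor(shape=[-1, 128], dtype=DataType.BFloat16, contiguity=True)
    # The new TensorView.dtype binding returns the tensor's PrimDataType.
    assert t.dtype() == DataType.BFloat16
    # Typical round-trip: compute in Float, then cast back to the original dtype.
    out = fd.ops.cast(fd.ops.cast(t, dtype=DataType.Float), dtype=t.dtype())
    fd.add_output(out)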

python/python_direct/ops.cpp

Lines changed: 18 additions & 32 deletions
@@ -10,6 +10,7 @@
 #include <bindings.h>
 #include <ops/all_ops.h>
 #include <ops/arith.h>
+#include <utils.h>

 namespace nvfuser::python {

@@ -2418,46 +2419,31 @@ TensorView* expand_fn(TensorView* arg, ShapeType generic_new_shape) {

 template <class ShapeType>
 TensorView* broadcast_in_dim_fn(
-    TensorView* arg,
+    TensorView* input,
     ShapeType generic_output_shape,
-    std::vector<int64_t>& broadcast_dims) {
+    const std::vector<int64_t>& nonbroadcast_dims) {
   std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
-  NVF_CHECK(
-      output_shape.size() >= broadcast_dims.size(),
-      "broadcast_dims vector size is too big for output shape!");
+  NVF_CHECK_GE(output_shape.size(), nonbroadcast_dims.size());

-  const auto arg_ndims = static_cast<size_t>(std::ranges::distance(
-      arg->getLoopDomain() | TensorDomain::kNoReductions));
-  NVF_CHECK(
-      output_shape.size() >= broadcast_dims.size(),
-      "The new shape is expected to be greater-then-or-equal to the input: ",
-      output_shape.size(),
-      " vs ",
-      arg_ndims);
-  NVF_CHECK(
-      arg_ndims == broadcast_dims.size(),
-      "The broadcast dimensions should match the input dimensions: ",
-      arg_ndims,
-      " vs ",
-      broadcast_dims.size(),
-      ". arg = ",
-      arg->toString());
+  const auto input_ndim = std::ranges::distance(
+      input->getLogicalDomain() | TensorDomain::kNoReductions);
+  NVF_CHECK_GE(std::ssize(output_shape), input_ndim);
+  NVF_CHECK_EQ(input_ndim, std::ssize(nonbroadcast_dims));

   std::vector<bool> is_broadcast_dim(output_shape.size(), true);
-  for (const auto idx : arange(broadcast_dims.size())) {
-    if (idx > 0) {
-      NVF_CHECK(
-          broadcast_dims[idx - 1] < broadcast_dims[idx],
-          "Broadcast dimension is not greater than the previous value.");
-    }
+  for (int64_t nonbroadcast_dim : nonbroadcast_dims) {
+    nonbroadcast_dim = wrapDim(nonbroadcast_dim, std::ssize(output_shape));
     NVF_CHECK(
-        broadcast_dims[idx] < static_cast<int>(output_shape.size()),
-        "Invalid broadcast_dims value.");
-    is_broadcast_dim.at(broadcast_dims[idx]) = false;
+        is_broadcast_dim.at(nonbroadcast_dim),
+        "nonbroadcast_dim (",
+        nonbroadcast_dim,
+        ") is specified more than once.");
+    is_broadcast_dim.at(nonbroadcast_dim) = false;
   }

-  auto bcast_output = broadcast(arg, is_broadcast_dim);
-  return expand(bcast_output, output_shape);
+  TensorView* output = broadcast(input, is_broadcast_dim);
+  output = expand(output, output_shape);
+  return output;
 }

 template <class ShapeType>
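
The rewrite renames broadcast_dims to nonbroadcast_dims, replaces the ascending-order check with a duplicate check, and accepts negative indices via wrapDim. A rough Python sketch of the resulting semantics (broadcast_in_dim_ref is a hypothetical helper written for illustration; it additionally assumes the wrapped dims come out ascending, as they do in every call in the tests below):

import torch

def broadcast_in_dim_ref(x: torch.Tensor, output_shape, nonbroadcast_dims):
    # Input dim i keeps its size at output position nonbroadcast_dims[i];
    # every other output position becomes a size-1 broadcast dim, and the
    # result is then expanded to output_shape.
    ndim = len(output_shape)
    dims = [d % ndim for d in nonbroadcast_dims]  # like wrapDim: allow negatives
    assert len(dims) == x.dim(), "one entry per input dimension"
    assert len(set(dims)) == len(dims), "specified more than once"
    assert dims == sorted(dims), "this sketch assumes ascending order"
    shape = [1] * ndim
    for i, d in enumerate(dims):
        shape[d] = x.shape[i]
    return x.reshape(shape).expand(*output_shape)

# e.g. broadcasting a [c_z] weight along the last axis of a [b, i, j, c_z]
# activation with broadcast_dims=[-1], as the layer_norm helper below does:
w = torch.randn(8)
assert broadcast_in_dim_ref(w, (2, 4, 4, 8), [-1]).shape == (2, 4, 4, 8)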

tests/python/direct/test_alphafold3.py

Lines changed: 155 additions & 7 deletions
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto

-from nvfuser_direct import FusionDefinition, DataType
+from nvfuser_direct import FusionDefinition, DataType, TensorView


 @dataclass
@@ -28,14 +28,157 @@ class Direction(Enum):
     OUTGOING = auto()  # aka starting node


+def layer_norm(
+    fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
+) -> TensorView:
+    io_dtype = x.dtype()
+    x = fd.ops.cast(x, dtype=DataType.Float)
+    var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
+    y = fd.ops.sub(x, mean)
+    var = fd.ops.add(var, fd.define_scalar(1e-5))
+    y = fd.ops.mul(y, fd.ops.rsqrt(var))
+    shape = fd.ops.shape(x)
+    w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.mul(y, w)
+    b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.add(y, b)
+    y = fd.ops.cast(y, dtype=io_dtype)
+    return y
+
+
+def gating(
+    fd: FusionDefinition,
+    z: TensorView,
+    w_p: TensorView,
+    z_in: TensorView,
+    w_g: TensorView,
+) -> TensorView:
+    io_dtype = z.dtype()
+    p = fd.ops.linear(z, w_p)
+    g = fd.ops.linear(z_in, w_g)
+    g = fd.ops.sigmoid(g)
+    z = fd.ops.mul(p, g)
+    return fd.ops.cast(z, dtype=io_dtype)
+
+
+# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.5 for details)
 @pytest.mark.parametrize(
     "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
 )
 def test_triangle_updates(direction):
-    pass
+    c_z = _DEFAULT_CONFIG.c_z
+
+    with FusionDefinition() as fd:
+        z_in = fd.define_tensor(
+            shape=[-1, -1, -1, c_z],
+            dtype=DataType.BFloat16,
+            contiguity=True,
+        )  # [b, i, j, c_z]
+        w_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        # Masking is used in an internal implementation: http://nv/e-4
+        mask = fd.define_tensor(
+            shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
+        )  # [b, i, j]
+
+        batch_size = fd.ops.size(z_in, 0)
+        n_tokens = fd.ops.size(z_in, 1)
+
+        z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in)
+        z = gating(fd, z_in, w_p_in, z_in, w_g_in)
+        mask = fd.ops.broadcast_in_dim(
+            mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2]
+        )
+        z = fd.ops.where(mask, z, 0.0)
+        a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
+        b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])
+
+        match direction:
+            case Direction.OUTGOING:
+                # z_out = einsum("bikc,bjkc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
+            case Direction.INCOMING:
+                # z_out = einsum("bkic,bkjc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
+        z = fd.ops.matmul(a, b)  # [b, c, i, j]
+        z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
+
+        z = layer_norm(fd, z, w_norm_out, b_norm_out)
+        z = gating(fd, z, w_p_out, z_in, w_g_out)
+        fd.add_output(z)
+
+    batch_size = 3
+    n_tokens = 5
+    z_in = torch.testing.make_tensor(
+        batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_g_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    mask = torch.testing.make_tensor(
+        batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda"
+    )
+    (z_out,) = fd.execute(
+        [
+            z_in,
+            w_norm_in,
+            b_norm_in,
+            w_p_in,
+            w_g_in,
+            w_norm_out,
+            b_norm_out,
+            w_p_out,
+            w_g_out,
+            mask,
+        ]
+    )
+    assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
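
The permute/matmul pairs above are one way to realize the einsums quoted in the comments: move the channel axis next to the batch axis so the batched matmul contracts over k. A quick PyTorch sanity check of that equivalence (shapes are arbitrary and for illustration only):

import torch

b, n, c = 2, 4, 3
x = torch.randn(b, n, n, c)
y = torch.randn(b, n, n, c)

# OUTGOING: einsum("bikc,bjkc->bijc", a, b) as a batched matmul over k.
ref = torch.einsum("bikc,bjkc->bijc", x, y)
out = torch.matmul(x.permute(0, 3, 1, 2), y.permute(0, 3, 2, 1)).permute(0, 2, 3, 1)
assert torch.allclose(ref, out, atol=1e-5)

# INCOMING: einsum("bkic,bkjc->bijc", a, b), same trick with k leading.
ref = torch.einsum("bkic,bkjc->bijc", x, y)
out = torch.matmul(x.permute(0, 3, 2, 1), y.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
assert torch.allclose(ref, out, atol=1e-5)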


 # https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-attention
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.6 for details)
 @pytest.mark.parametrize(
     "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
 )
@@ -52,8 +195,8 @@ def test_triangle_attention(direction):
             dtype=DataType.BFloat16,
             contiguity=True,
         )  # [b, i, j, c_z]
-        if direction == Direction.INCOMING:
-            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
+        w_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
+        b_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
         w_q = fd.define_tensor(
             shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
         )
@@ -64,8 +207,6 @@ def test_triangle_attention(direction):
         mask = fd.define_tensor(
             shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
         )  # [b, i, j]
-        if direction == Direction.INCOMING:
-            mask = fd.ops.permute(mask, [0, 2, 1])
         w_v = fd.define_tensor(
             shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
         )
@@ -79,6 +220,9 @@ def test_triangle_attention(direction):
         batch_size = fd.ops.size(z_in, 0)
         n_tokens = fd.ops.size(z_in, 1)

+        if direction == Direction.INCOMING:
+            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
+        z_in = layer_norm(fd, z_in, w_norm, b_norm)
         q = fd.ops.linear(z_in, w_q)
         q_h = fd.ops.reshape(
             q, [batch_size, n_tokens, n_tokens, h, -1]
@@ -99,6 +243,8 @@ def test_triangle_attention(direction):
             broadcast_dims=[0, 2, 3, 4],
         )  # [b, 1, h, j, k]

+        if direction == Direction.INCOMING:
+            mask = fd.ops.permute(mask, [0, 2, 1])
         mask = fd.ops.broadcast_in_dim(
             mask,
             shape=[batch_size, n_tokens, 1, 1, n_tokens],
@@ -142,6 +288,8 @@ def test_triangle_attention(direction):
     z_in = torch.testing.make_tensor(
         batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
     )
+    w_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
     w_q = torch.testing.make_tensor(
         h * c_hidden, c_z, dtype=torch.bfloat16, device="cuda"
     )
@@ -161,5 +309,5 @@ def test_triangle_attention(direction):
     w_o = torch.testing.make_tensor(
         c_z, h * c_hidden, dtype=torch.bfloat16, device="cuda"
     )
-    (z_out,) = fd.execute([z_in, w_q, w_k, w_b, mask, w_v, w_g, w_o])
+    (z_out,) = fd.execute([z_in, w_norm, b_norm, w_q, w_k, w_b, mask, w_v, w_g, w_o])
     assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
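
For reference, the layer_norm and gating helpers at the top of this file compute standard building blocks. A hedged PyTorch sketch of the same math (the *_ref names are ours, written for illustration; layer_norm_ref matches the helper's biased variance, eps=1e-5, and float32 compute, while the helper's bf16 matmul precision is not reproduced exactly):

import torch
import torch.nn.functional as F

def layer_norm_ref(x, w, b):
    # Normalize over the last axis with biased variance (correction=0) and
    # eps=1e-5, computing in float32 and casting back, like the helper.
    y = F.layer_norm(x.float(), x.shape[-1:], w.float(), b.float(), eps=1e-5)
    return y.to(x.dtype)

def gating_ref(z, w_p, z_in, w_g):
    # A linear projection of z gated by a sigmoid-activated linear projection
    # of z_in, cast back to the input dtype, like the helper.
    return (F.linear(z, w_p) * torch.sigmoid(F.linear(z_in, w_g))).to(z.dtype)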

tests/python/opinfo/opinfo_input_generators.py

Lines changed: 3 additions & 11 deletions
@@ -217,21 +217,14 @@ def broadcast_in_dim_error_generator(
         "The new shape is expected to be greater-then-or-equal to the input",
     )

-    # 3. broadcast_dimensions is an ascending sequence of integers.
-    descending_broadcast_dimensions = (
-        ([2, 2], [2, 2], [1, 0]),
-        RuntimeError,
-        "Broadcast dimension is not greater than the previous value.",
-    )
-
-    # 4. Each broadcast dimension is within the new shape.
+    # 3. Each broadcast dimension is within the new shape.
     out_of_bounds_broadcast_dimensions = (
         ([2, 2], [2, 2], [0, 2]),
         RuntimeError,
         "Invalid broadcast_dims value.",
     )

-    # 5. The original tensor is not broadcastable to desired shape.
+    # 4. The original tensor is not broadcastable to desired shape.
     #    tensor.shape[idx] == 1 or tensor.shape[idx] == output_shape[new_idx]
     #
     # Jax Exception:
@@ -244,7 +237,7 @@ def broadcast_in_dim_error_generator(
         "Invalid broadcast_dims value.",
     )

-    # 6. TypeError: broadcast_in_dim shape must have every element be nonnegative, got (-1, 2, 3).
+    # 5. TypeError: broadcast_in_dim shape must have every element be nonnegative, got (-1, 2, 3).
     negative_shape = (
         ([2, 3], [2, 3, -1], [0, 1]),
         RuntimeError,
@@ -255,7 +248,6 @@ def broadcast_in_dim_error_generator(
     error_cases = [
         missing_axis_in_bcast_dims,
         fewer_dims_in_output_shape,
-        descending_broadcast_dimensions,
         out_of_bounds_broadcast_dimensions,
         # not_broadcastable,
         # negative_shape,
