 from dataclasses import dataclass
 from enum import Enum, auto
 
-from nvfuser_direct import FusionDefinition, DataType
+from nvfuser_direct import FusionDefinition, DataType, TensorView
 
 
 @dataclass
@@ -28,14 +28,157 @@ class Direction(Enum):
     OUTGOING = auto()  # aka starting node
 
 
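+# A standard LayerNorm over the last (channel) dimension, computed in fp32. Assuming x is
+# the raw, unnormalized input, this should match
+# torch.nn.functional.layer_norm(x, x.shape[-1:], w, b, eps=1e-5) up to the dtype casts.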
+def layer_norm(
+    fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
+) -> TensorView:
+    io_dtype = x.dtype()
+    x = fd.ops.cast(x, dtype=DataType.Float)
+    var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
+    y = fd.ops.sub(x, mean)
+    var = fd.ops.add(var, fd.define_scalar(1e-5))
+    y = fd.ops.mul(y, fd.ops.rsqrt(var))
+    shape = fd.ops.shape(x)
+    w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.mul(y, w)
+    b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.add(y, b)
+    y = fd.ops.cast(y, dtype=io_dtype)
+    return y
+
+
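+# Sigmoid gating of a linear projection: roughly sigmoid(F.linear(z_in, w_g)) *
+# F.linear(z, w_p), cast back to the input dtype (F.linear here is shorthand for
+# torch.nn.functional.linear, used only to describe the math).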
+def gating(
+    fd: FusionDefinition,
+    z: TensorView,
+    w_p: TensorView,
+    z_in: TensorView,
+    w_g: TensorView,
+) -> TensorView:
+    io_dtype = z.dtype()
+    p = fd.ops.linear(z, w_p)
+    g = fd.ops.linear(z_in, w_g)
+    g = fd.ops.sigmoid(g)
+    z = fd.ops.mul(p, g)
+    return fd.ops.cast(z, dtype=io_dtype)
+
+
+# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.5 for details)
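+#
+# Rough sketch of the update built below (notation only): project the normalized pair
+# representation into masked "left"/"right" edge tensors a and b, combine them over the
+# third node k as einsum("bikc,bjkc->bijc", a, b) for the outgoing direction (or
+# einsum("bkic,bkjc->bijc", a, b) for incoming), then LayerNorm, project, and gate the
+# result.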
 @pytest.mark.parametrize(
     "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
 )
 def test_triangle_updates(direction):
-    pass
+    c_z = _DEFAULT_CONFIG.c_z
+
+    with FusionDefinition() as fd:
+        z_in = fd.define_tensor(
+            shape=[-1, -1, -1, c_z],
+            dtype=DataType.BFloat16,
+            contiguity=True,
+        )  # [b, i, j, c_z]
+        w_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        # Masking is used in an internal implementation: http://nv/e-4
+        mask = fd.define_tensor(
+            shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
+        )  # [b, i, j]
+
+        batch_size = fd.ops.size(z_in, 0)
+        n_tokens = fd.ops.size(z_in, 1)
+
+        z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in)
+        z = gating(fd, z_in, w_p_in, z_in, w_g_in)
+        mask = fd.ops.broadcast_in_dim(
+            mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2]
+        )
+        z = fd.ops.where(mask, z, 0.0)
+        a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
+        b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])
+
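+        # The permutes below move the channel dim next to batch so the per-channel
+        # contraction over k becomes a batched matmul (equivalent to the einsums noted
+        # in each case).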
+        match direction:
+            case Direction.OUTGOING:
+                # z_out = einsum("bikc,bjkc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
+            case Direction.INCOMING:
+                # z_out = einsum("bkic,bkjc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
+        z = fd.ops.matmul(a, b)  # [b, c, i, j]
+        z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
+
+        z = layer_norm(fd, z, w_norm_out, b_norm_out)
+        z = gating(fd, z, w_p_out, z_in, w_g_out)
+        fd.add_output(z)
+
+    batch_size = 3
+    n_tokens = 5
+    z_in = torch.testing.make_tensor(
+        batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_g_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    mask = torch.testing.make_tensor(
+        batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda"
+    )
+    (z_out,) = fd.execute(
+        [
+            z_in,
+            w_norm_in,
+            b_norm_in,
+            w_p_in,
+            w_g_in,
+            w_norm_out,
+            b_norm_out,
+            w_p_out,
+            w_g_out,
+            mask,
+        ]
+    )
+    assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
 
 
 # https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-attention
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.6 for details)
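+#
+# Roughly: each (i, j) pair attends along one axis of the pair representation, row-wise
+# around the starting node for OUTGOING and column-wise around the ending node for
+# INCOMING (handled below by transposing i and j). A per-head bias projected from the
+# pair representation (w_b) is added to the attention logits, and the output is
+# sigmoid-gated (w_g) before the final projection (w_o).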
 @pytest.mark.parametrize(
     "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
 )
@@ -52,8 +195,8 @@ def test_triangle_attention(direction):
             dtype=DataType.BFloat16,
             contiguity=True,
         )  # [b, i, j, c_z]
-        if direction == Direction.INCOMING:
-            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
+        w_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
+        b_norm = fd.define_tensor(shape=[c_z], dtype=DataType.BFloat16, contiguity=True)
         w_q = fd.define_tensor(
             shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
         )
@@ -64,8 +207,6 @@ def test_triangle_attention(direction):
         mask = fd.define_tensor(
             shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
         )  # [b, i, j]
-        if direction == Direction.INCOMING:
-            mask = fd.ops.permute(mask, [0, 2, 1])
         w_v = fd.define_tensor(
             shape=[h * c_hidden, c_z], dtype=DataType.BFloat16, contiguity=True
         )
@@ -79,6 +220,9 @@ def test_triangle_attention(direction):
         batch_size = fd.ops.size(z_in, 0)
         n_tokens = fd.ops.size(z_in, 1)
 
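+        # For INCOMING, transpose i and j up front so the row-wise attention code below
+        # can be reused unchanged.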
+        if direction == Direction.INCOMING:
+            z_in = fd.ops.permute(z_in, [0, 2, 1, 3])
+        z_in = layer_norm(fd, z_in, w_norm, b_norm)
         q = fd.ops.linear(z_in, w_q)
         q_h = fd.ops.reshape(
             q, [batch_size, n_tokens, n_tokens, h, -1]
@@ -99,6 +243,8 @@ def test_triangle_attention(direction):
             broadcast_dims=[0, 2, 3, 4],
         )  # [b, 1, h, j, k]
 
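+        # Keep the mask aligned with the (possibly transposed) pair representation.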
+        if direction == Direction.INCOMING:
+            mask = fd.ops.permute(mask, [0, 2, 1])
         mask = fd.ops.broadcast_in_dim(
             mask,
             shape=[batch_size, n_tokens, 1, 1, n_tokens],
@@ -142,6 +288,8 @@ def test_triangle_attention(direction):
     z_in = torch.testing.make_tensor(
         batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
     )
+    w_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
     w_q = torch.testing.make_tensor(
         h * c_hidden, c_z, dtype=torch.bfloat16, device="cuda"
     )
@@ -161,5 +309,5 @@ def test_triangle_attention(direction):
     w_o = torch.testing.make_tensor(
         c_z, h * c_hidden, dtype=torch.bfloat16, device="cuda"
     )
-    (z_out,) = fd.execute([z_in, w_q, w_k, w_b, mask, w_v, w_g, w_o])
+    (z_out,) = fd.execute([z_in, w_norm, b_norm, w_q, w_k, w_b, mask, w_v, w_g, w_o])
     assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)