@@ -131,13 +131,11 @@ def ab_comparsion_input_generation(self):
     q_sa = q.clone().detach().requires_grad_(True)
     k_sa = k.clone().detach().requires_grad_(True)
     v_sa = v.clone().detach().requires_grad_(True)
-    q_sa.retain_grad()
-    k_sa.retain_grad()
-    v_sa.retain_grad()
-
     # Repeat the kv tensors to match the q tensor heads. This is required for flash
     k = self.maybe_expend_kv(k)
+    k.retain_grad()
     v = self.maybe_expend_kv(v)
+    v.retain_grad()
     torch_xla.sync()
     return q, k, v, q_sa, k_sa, v_sa
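A note on the retain_grad() moves above: tensors built with clone().detach().requires_grad_(True) are autograd leaves, so backward() fills their .grad without any extra call, while k and v stop being leaves once maybe_expend_kv produces new tensors from them and therefore need retain_grad() to keep a gradient. A minimal CPU-only sketch of that PyTorch rule, with repeat_interleave standing in for the kv-head expansion (it is not the test's helper):

    import torch

    # A leaf tensor created with requires_grad accumulates .grad automatically.
    q = torch.randn(2, 4, 8, requires_grad=True)

    # A tensor produced by an op is a non-leaf; it only keeps .grad after
    # backward() if retain_grad() is called on it first.
    k_expanded = q.repeat_interleave(2, dim=1)  # stand-in for maybe_expend_kv
    k_expanded.retain_grad()

    loss = q.sum() + k_expanded.sum()
    loss.backward()

    print(q.grad.shape)           # filled in automatically (leaf)
    print(k_expanded.grad.shape)  # filled in because of retain_grad()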
@@ -165,20 +163,17 @@ def test_splash_attention_base(self):
 
     o = self._attention(q, k, v, attn_mask=attention_mask)
     torch_xla.sync()
-    for i in [q, k, v]:
-      i.retain_grad()
     loss = torch.sum(o)
     loss.backward()
-    torch_xla.sync()
     q_grad, k_grad, v_grad = q.grad, k.grad, v.grad
+    torch_xla.sync()
 
     o_sa = splash_attention(q_sa, k_sa, v_sa, self.config.to_json())
     torch_xla.sync()
-    [i.retain_grad() for i in [q_sa, k_sa, v_sa]]
     loss_sa = torch.sum(o_sa)
     loss_sa.backward()
-    torch_xla.sync()
     q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad
+    torch_xla.sync()
 
     with torch.no_grad():
       k_grad = self.maybe_reduce_kv_grad(k_grad)
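The sync() placement above follows XLA's lazy-tensor model: reading q.grad only hands back a lazy tensor handle, and a single torch_xla.sync() afterwards compiles and runs forward, backward, and the gradient reads as one graph instead of two. A small sketch of that pattern, assuming a recent torch_xla where torch_xla.device() and torch_xla.sync() are available (the toy loss is illustrative, not the test's attention path):

    import torch
    import torch_xla

    device = torch_xla.device()  # assumes the current torch_xla top-level API

    # Leaf tensor created directly on the XLA device, so .grad is kept for it.
    q = torch.randn(2, 128, 64, device=device, requires_grad=True)

    loss = (q * q).sum()   # placeholder computation instead of the attention kernel
    loss.backward()

    q_grad = q.grad        # grabs a handle to the still-lazy gradient tensor
    torch_xla.sync()       # materializes forward, backward and the gradient together
    print(q_grad.cpu())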
@@ -228,13 +223,10 @@ def test_splash_attention_segment_id(self):
                             v_sa,
                             self.config.to_json(),
                             decoder_segment_ids=segment_ids_sa.to("xla"))
-    torch_xla.sync()
-    for i in [q_sa, k_sa, v_sa]:
-      i.retain_grad()
     loss_sa = torch.sum(o_sa)
     loss_sa.backward()
-    torch_xla.sync()
     q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad
+    torch_xla.sync()
     torch.testing.assert_close(self.o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5)
     for org_grad, sa_grad in zip([self.q_grad, self.k_grad, self.v_grad],
                                  [q_grad_sa, k_grad_sa, v_grad_sa],
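The tail of this hunk compares each reference gradient against its splash-attention counterpart with torch.testing.assert_close at rtol=1e-3 / atol=1e-5. A self-contained CPU sketch of that comparison pattern, using F.scaled_dot_product_attention as a stand-in for the kernel under test (none of the test's fixtures are used):

    import torch
    import torch.nn.functional as F

    def reference_attention(q, k, v):
      # Plain softmax attention, playing the role of the reference path.
      scores = torch.einsum('bnld,bnmd->bnlm', q, k) / q.shape[-1] ** 0.5
      return torch.einsum('bnlm,bnmd->bnld', scores.softmax(dim=-1), v)

    shape = (1, 2, 16, 8)
    q = torch.randn(shape, requires_grad=True)
    k = torch.randn(shape, requires_grad=True)
    v = torch.randn(shape, requires_grad=True)
    q2, k2, v2 = (t.clone().detach().requires_grad_(True) for t in (q, k, v))

    reference_attention(q, k, v).sum().backward()
    F.scaled_dot_product_attention(q2, k2, v2).sum().backward()

    # Compare the gradients pairwise with the same tolerances the test uses.
    for org_grad, sa_grad in zip((q.grad, k.grad, v.grad), (q2.grad, k2.grad, v2.grad)):
      torch.testing.assert_close(org_grad, sa_grad, rtol=1e-3, atol=1e-5)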