
Commit 90eabe6

retain grad

committed · 1 parent 1579911 · commit 90eabe6

File tree

1 file changed: +5 -13 lines changed

test/test_splash_attention_jax.py

Lines changed: 5 additions & 13 deletions
@@ -131,13 +131,11 @@ def ab_comparsion_input_generation(self):
     q_sa = q.clone().detach().requires_grad_(True)
     k_sa = k.clone().detach().requires_grad_(True)
     v_sa = v.clone().detach().requires_grad_(True)
-    q_sa.retain_grad()
-    k_sa.retain_grad()
-    v_sa.retain_grad()
-
     # Repeat the kv tensors to match the q tensor heads. This is required for flash
     k = self.maybe_expend_kv(k)
+    k.retain_grad()
     v = self.maybe_expend_kv(v)
+    v.retain_grad()
     torch_xla.sync()
     return q, k, v, q_sa, k_sa, v_sa
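Context for the hunk above: q_sa, k_sa, and v_sa are leaf tensors (clone().detach().requires_grad_(True)), so autograd populates their .grad after backward without retain_grad(), which is presumably why those calls are dropped. k and v, by contrast, come out of maybe_expend_kv as non-leaf tensors, so they only keep a .grad if retain_grad() is called before backward. A minimal sketch of that distinction, using repeat_interleave purely as a stand-in for whatever op maybe_expend_kv performs (an assumption for illustration):

import torch

# Leaf tensor: .grad is filled in automatically after backward,
# no retain_grad() required.
k_leaf = torch.randn(2, 4).clone().detach().requires_grad_(True)

# Non-leaf tensor: result of an op (repeat_interleave is only a stand-in for
# maybe_expend_kv); its .grad is discarded unless retain_grad() is called.
k_expanded = k_leaf.repeat_interleave(2, dim=0)
k_expanded.retain_grad()

k_expanded.sum().backward()
print(k_leaf.grad is not None)      # True: leaf tensors keep their grad
print(k_expanded.grad is not None)  # True only because of retain_grad()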

@@ -165,20 +163,17 @@ def test_splash_attention_base(self):

     o = self._attention(q, k, v, attn_mask=attention_mask)
     torch_xla.sync()
-    for i in [q, k, v]:
-      i.retain_grad()
     loss = torch.sum(o)
     loss.backward()
-    torch_xla.sync()
     q_grad, k_grad, v_grad = q.grad, k.grad, v.grad
+    torch_xla.sync()

     o_sa = splash_attention(q_sa, k_sa, v_sa, self.config.to_json())
     torch_xla.sync()
-    [i.retain_grad() for i in [q_sa, k_sa, v_sa]]
     loss_sa = torch.sum(o_sa)
     loss_sa.backward()
-    torch_xla.sync()
     q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad
+    torch_xla.sync()

     with torch.no_grad():
       k_grad = self.maybe_reduce_kv_grad(k_grad)
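The other change in this hunk (and in the segment-id test below) is purely about ordering: the torch_xla.sync() right after backward() goes away and a single sync is issued after the .grad handles are read. A minimal sketch of that ordering on an XLA device, with illustrative shapes and a trivial loss (nothing here is taken from the test beyond the sync placement):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# backward() only records work on the lazy XLA graph; taking the .grad
# references is also lazy, and the single torch_xla.sync() afterwards
# materializes the forward and backward computation together.
q = torch.randn(2, 4, device=xm.xla_device(), requires_grad=True)
loss = torch.sum(q * q)
loss.backward()
q_grad = q.grad
torch_xla.sync()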
@@ -228,13 +223,10 @@ def test_splash_attention_segment_id(self):
         v_sa,
         self.config.to_json(),
         decoder_segment_ids=segment_ids_sa.to("xla"))
-    torch_xla.sync()
-    for i in [q_sa, k_sa, v_sa]:
-      i.retain_grad()
     loss_sa = torch.sum(o_sa)
     loss_sa.backward()
-    torch_xla.sync()
     q_grad_sa, k_grad_sa, v_grad_sa = q_sa.grad, k_sa.grad, v_sa.grad
+    torch_xla.sync()
     torch.testing.assert_close(self.o.cpu(), o_sa.cpu(), rtol=1e-3, atol=1e-5)
     for org_grad, sa_grad in zip([self.q_grad, self.k_grad, self.v_grad],
                                  [q_grad_sa, k_grad_sa, v_grad_sa],
