Skip to content

Commit

Permalink
feat: to_device >> device_put
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Sep 13, 2024
1 parent 6c97c8d commit df0e52b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions jflux/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def prepare(
vec = repeat(vec, "1 ... -> bs ...", bs=bs)

return {
"img": jax.device_put(img, device=device),
"img_ids": jax.device_put(img_ids, device=device),
"txt": jax.device_put(txt, device=device),
"txt_ids": jax.device_put(txt_ids, device=device),
"vec": jax.device_put(vec, device=device),
"img": img.to_device(device, stream=None),
"img_ids": img_ids.to_device(device, stream=None),
"txt": txt.to_device(device, stream=None),
"txt_ids": txt_ids.to_device(device, stream=None),
"vec": vec.to_device(device, stream=None),
}


Expand Down

0 comments on commit df0e52b

Please sign in to comment.