pytorch_to_jax.py
"""Convert a PyTorch state dict into a Flax parameter tree.

Renames PyTorch parameter keys to their Flax equivalents, transposes
weight matrices where the two layouts differ, and validates the result
against the target Flax parameters.
"""
from typing import Dict

import flax


def convert_from_pytorch(pt_state: Dict, params_flatten: Dict) -> Dict:
    """Convert a PyTorch state dict to a nested Flax parameter dict.

    Args:
        pt_state (Dict): PyTorch state dict mapping parameter names to tensors.
        params_flatten (Dict): Flattened Flax parameter dict (as produced by
            flax.traverse_util.flatten_dict with sep=".") whose shapes and
            dtypes are the conversion targets.

    Returns:
        Dict: Nested Flax parameter dict suitable for passing to model.apply.
    """
    jax_state = dict(pt_state)
    for key, tensor in pt_state.items():
        # Move every tensor to host memory and convert it to a NumPy array,
        # so all later shape/dtype comparisons run against NumPy values.
        tensor = tensor.cpu().numpy()
        jax_state[key] = tensor

        # Flax's nn.Embed stores its table under "embedding", not "weight".
        if "embedding.weight" in key:
            del jax_state[key]
            key = key.replace("embedding.weight", "embedding.embedding")
            jax_state[key] = tensor

        # PyTorch addresses submodules as "layers.0.*"; the flattened Flax
        # tree uses "layers_0.*".
        if "layers." in key:
            del jax_state[key]
            key = key.replace("layers.", "layers_")
            jax_state[key] = tensor

        # Flax Dense layers name their weight matrix "kernel", not "weight".
        if "proj.weight" in key:
            del jax_state[key]
            key = key.replace("proj.weight", "proj.kernel")
            jax_state[key] = tensor

        # Likewise for convolution weights.
        if "conv1d.weight" in key:
            del jax_state[key]
            key = key.replace("conv1d.weight", "conv1d.kernel")
            jax_state[key] = tensor

        # The LM head is not part of the Flax parameter tree (e.g. its
        # weights are tied to the embedding), so drop it.
        if "lm_head" in key:
            del jax_state[key]

    jax_state_transposed = {}
    for key in params_flatten.keys():
        # PyTorch stores linear weights as (out_features, in_features);
        # Flax kernels are (in_features, out_features). Transpose on
        # shape mismatch.
        if params_flatten[key].shape != jax_state[key].shape:
            jax_state_transposed[key] = jax_state[key].T
        else:
            jax_state_transposed[key] = jax_state[key]

        # Cast to the target dtype where it differs.
        if params_flatten[key].dtype != jax_state[key].dtype:
            jax_state_transposed[key] = jax_state_transposed[key].astype(
                params_flatten[key].dtype
            )

        assert params_flatten[key].shape == jax_state_transposed[key].shape, (
            f"Shape mismatch for {key}: expected {params_flatten[key].shape}, "
            f"got {jax_state_transposed[key].shape}"
        )
        assert params_flatten[key].dtype == jax_state_transposed[key].dtype, (
            f"Dtype mismatch for {key}: expected {params_flatten[key].dtype}, "
            f"got {jax_state_transposed[key].dtype}"
        )

    params = flax.traverse_util.unflatten_dict(jax_state_transposed, sep=".")
    return params
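

# --- Usage sketch (illustrative, not part of the original script) ---
# A minimal, self-contained example assuming a hypothetical one-layer model:
# a PyTorch nn.Linear named "proj" is converted into a Flax Dense layer via
# the "proj.weight" -> "proj.kernel" rename and the transpose above.
if __name__ == "__main__":
    import jax
    import jax.numpy as jnp
    import torch
    from flax import linen as nn

    class FlaxProj(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(4, name="proj")(x)

    # PyTorch side: a Linear(8 -> 4) whose state dict keys are prefixed
    # with "proj." to mirror the Flax module path.
    pt_proj = torch.nn.Linear(8, 4)
    pt_state = {f"proj.{k}": v for k, v in pt_proj.state_dict().items()}

    # Flax side: initialize once to obtain the target shapes and dtypes.
    variables = FlaxProj().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    params_flatten = flax.traverse_util.flatten_dict(variables["params"], sep=".")

    params = convert_from_pytorch(pt_state, params_flatten)
    out = FlaxProj().apply({"params": params}, jnp.ones((1, 8)))
    print(out.shape)  # (1, 4)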