common.py
import collections
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

# Container for a batch of transitions sampled from the replay buffer.
Batch = collections.namedtuple(
    'Batch',
    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])
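
# Illustrative sketch (not part of the original file): a Batch as a replay
# buffer might produce it; the shapes below are assumptions.
#
#     batch = Batch(observations=jnp.zeros((256, 17)),
#                   actions=jnp.zeros((256, 6)),
#                   rewards=jnp.zeros((256,)),
#                   masks=jnp.ones((256,)),
#                   next_observations=jnp.zeros((256, 17)))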

# Orthogonal kernel initializer; scale sqrt(2) is the usual default for
# ReLU networks.
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


# Type aliases used throughout the codebase.
PRNGKey = Any
Params = flax.core.FrozenDict[str, Any]
Shape = Sequence[int]
Dtype = Any  # This could be a real type.
InfoDict = Dict[str, float]

class MLP(nn.Module):
    """Multi-layer perceptron with orthogonal init and optional dropout."""
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        for i, size in enumerate(self.hidden_dims):
            x = nn.Dense(size, kernel_init=default_init())(x)
            # Apply the activation (and dropout) after every layer except,
            # optionally, the final one.
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        return x
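
# Illustrative sketch (not part of the original file): how MLP might be
# instantiated and applied; the sizes and PRNG seed are assumptions.
#
#     mlp = MLP(hidden_dims=(256, 256))
#     variables = mlp.init(jax.random.PRNGKey(0), jnp.ones((1, 17)))
#     out = mlp.apply(variables, jnp.ones((32, 17)))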

@flax.struct.dataclass
class Model:
    """Bundles a Flax module, its parameters, and an optional optimizer state.

    Instances are immutable; updates return a new Model via .replace().
    """
    step: int
    apply_fn: nn.Module = flax.struct.field(pytree_node=False)
    params: Params
    tx: Optional[optax.GradientTransformation] = flax.struct.field(
        pytree_node=False)
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(cls,
               model_def: nn.Module,
               inputs: Sequence[jnp.ndarray],
               tx: Optional[optax.GradientTransformation] = None) -> 'Model':
        # Initialize variables from the example inputs (the first element is
        # a PRNG key) and keep only the 'params' collection.
        variables = model_def.init(*inputs)
        _, params = variables.pop('params')

        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None

        return cls(step=1,
                   apply_fn=model_def,
                   params=params,
                   tx=tx,
                   opt_state=opt_state)

    def __call__(self, *args, **kwargs):
        # Forward pass with the stored parameters.
        return self.apply_fn.apply({'params': self.params}, *args, **kwargs)

    def apply(self, *args, **kwargs):
        # Forward pass with caller-supplied variables (e.g. target parameters).
        return self.apply_fn.apply(*args, **kwargs)

    def apply_gradient(self, loss_fn) -> Tuple['Model', Any]:
        # loss_fn maps params to (loss, aux_info); differentiate it and apply
        # one optimizer update.
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, info = grad_fn(self.params)

        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(step=self.step + 1,
                            params=new_params,
                            opt_state=new_opt_state), info

    def save(self, save_path: str):
        # Serialize only the parameters (optimizer state is not saved).
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(flax.serialization.to_bytes(self.params))

    def load(self, load_path: str) -> 'Model':
        with open(load_path, 'rb') as f:
            params = flax.serialization.from_bytes(self.params, f.read())
        return self.replace(params=params)
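
# Illustrative sketch (not part of the original file): creating a Model and
# taking one gradient step; the module, optimizer, and loss below are
# assumptions, not part of this codebase.
#
#     model = Model.create(MLP((256, 256, 1)),
#                          inputs=[jax.random.PRNGKey(0), jnp.ones((1, 17))],
#                          tx=optax.adam(3e-4))
#
#     def loss_fn(params):
#         preds = model.apply_fn.apply({'params': params}, jnp.ones((32, 17)))
#         loss = jnp.mean(preds ** 2)
#         return loss, {'loss': loss}
#
#     new_model, info = model.apply_gradient(loss_fn)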