-
Notifications
You must be signed in to change notification settings - Fork 38
/
value_net.py
42 lines (31 loc) · 1.33 KB
/
value_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Callable, Optional, Sequence, Tuple

import jax.numpy as jnp
from flax import linen as nn

from common import MLP
class ValueCritic(nn.Module):
    """Observation-only critic: an ``MLP`` head squeezed to drop its last axis.

    Attributes:
        hidden_dims: Hidden-layer widths; a final layer of width 1 is appended
            when building the ``MLP``.
        activations: Optional activation forwarded to ``MLP`` (mirrors the
            ``activations`` field on ``Critic``/``DoubleCritic``). When None
            (the default), ``MLP`` is called exactly as before, so whatever
            activation ``MLP`` uses by default is kept.
    """

    hidden_dims: Sequence[int]
    activations: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        """Apply the value head to ``observations``.

        Args:
            observations: Input array fed directly to the ``MLP``.

        Returns:
            The ``MLP`` output with its trailing axis removed via
            ``jnp.squeeze(..., -1)``. NOTE(review): this assumes the final
            ``MLP`` layer (width 1) yields a size-1 last axis — confirm
            against ``common.MLP``.
        """
        layer_sizes = (*self.hidden_dims, 1)
        if self.activations is None:
            # Preserve the original call so existing configs/params are unchanged.
            mlp = MLP(layer_sizes)
        else:
            mlp = MLP(layer_sizes, activations=self.activations)
        value = mlp(observations)
        return jnp.squeeze(value, -1)
class Critic(nn.Module):
    """Critic over (observation, action) pairs built from a single ``MLP``.

    Attributes:
        hidden_dims: Hidden-layer widths; an output layer of width 1 is
            appended when building the ``MLP``.
        activations: Activation passed to the ``MLP`` (default ``nn.relu``).
    """

    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray,
                 actions: jnp.ndarray) -> jnp.ndarray:
        """Score each (observation, action) pair.

        Args:
            observations: Array joined with ``actions`` along the last axis;
                all other axes must agree (a ``jnp.concatenate`` requirement).
            actions: Array appended after ``observations`` on the last axis.

        Returns:
            The ``MLP`` output with its trailing axis squeezed away.
        """
        layer_sizes = tuple(self.hidden_dims) + (1,)
        joint = jnp.concatenate((observations, actions), axis=-1)
        q_head = MLP(layer_sizes, activations=self.activations)
        q = q_head(joint)
        return q.squeeze(axis=-1)
class DoubleCritic(nn.Module):
    """Two independently parameterized ``Critic`` heads on the same inputs.

    Attributes:
        hidden_dims: Hidden-layer widths shared by both ``Critic`` heads.
        activations: Activation shared by both heads (default ``nn.relu``).
    """

    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray,
                 actions: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Evaluate both critics on (observations, actions).

        Returns:
            A 2-tuple ``(first, second)`` holding each head's output, in
            construction order.
        """
        outputs = []
        # Heads are created one after another in this order; Flax assigns
        # submodule names by creation order, so keep this sequential.
        for _ in range(2):
            head = Critic(self.hidden_dims, activations=self.activations)
            outputs.append(head(observations, actions))
        first, second = outputs
        return first, second