Skip to content

Commit

Permalink
Implement flatten layer in JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Sep 5, 2024
1 parent 65b0723 commit 0142786
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions skrl/utils/model_instantiators/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,19 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s
raise ValueError(f"Invalid or unsupported 'conv2d' layer definition: {kwargs}")
# flatten
elif layer_type == "flatten":
cls = "nn.Flatten"
cls = "lambda x: jnp.reshape(x, (x.shape[0], -1))"
kwargs = None
activation = "" # don't add activation after flatten layer
kwargs = layer[layer_type]
if type(kwargs) is list:
kwargs = {k: v for k, v in zip(["start_dim", "end_dim"][:len(kwargs)], kwargs)}
elif type(kwargs) is dict:
pass
else:
raise ValueError(f"Invalid or unsupported 'flatten' layer definition: {kwargs}")
else:
raise ValueError(f"Invalid or unsupported layer: {layer_type}")
else:
raise ValueError(f"Invalid or unsupported layer definition: {layer}")
# define layer and activation function
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
modules.append(f"{cls}({kwargs})")
if kwargs is None:
modules.append(f"{cls}")
else:
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
modules.append(f"{cls}({kwargs})")
activation = _get_activation_function(activation)
if activation:
modules.append(activation)
Expand Down

0 comments on commit 0142786

Please sign in to comment.