diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index fc24d410b..297ee4c8a 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -29,6 +29,7 @@ Layers
   GLU
   GroupNorm
   GRU
+   GRUCell
   HardShrink
   HardTanh
   Hardswish
diff --git a/python/mlx/nn/layers/recurrent.py b/python/mlx/nn/layers/recurrent.py
index 3ffa7654c..87a3ff698 100644
--- a/python/mlx/nn/layers/recurrent.py
+++ b/python/mlx/nn/layers/recurrent.py
@@ -90,13 +90,12 @@ def __call__(self, x, hidden=None):
         return mx.stack(all_hidden, axis=-2)
 
 
-class GRU(Module):
-    r"""A gated recurrent unit (GRU) RNN layer.
+class GRUCell(Module):
+    r"""A gated recurrent unit (GRU) cell.
 
-    The input has shape ``NLD`` or ``LD`` where:
+    The input has shape ``ND`` or ``D`` where:
 
     * ``N`` is the optional batch dimension
-    * ``L`` is the sequence length
     * ``D`` is the input's feature dimension
 
     Concretely, for each element of the sequence, this layer computes:
@@ -110,9 +109,9 @@ class GRU(Module):
         h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
     \end{aligned}
 
     The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
-    whether the input is batched or not. Returns the hidden state at each
-    time step of shape ``NLH`` or ``LH``.
+    whether the input is batched or not. Returns the next hidden state of
+    shape ``NH`` or ``H``.
 
     Args:
         input_size (int): Dimension of the input, ``D``.
@@ -162,35 +161,87 @@ def __call__(self, x, hidden=None):
         x_rz = x[..., : -self.hidden_size]
         x_n = x[..., -self.hidden_size :]
 
-        all_hidden = []
+        rz = x_rz
+        if hidden is not None:
+            h_proj = hidden @ self.Wh.T
+            h_proj_rz = h_proj[..., : -self.hidden_size]
+            h_proj_n = h_proj[..., -self.hidden_size :]
 
-        for idx in range(x.shape[-2]):
-            rz = x_rz[..., idx, :]
-            if hidden is not None:
-                h_proj = hidden @ self.Wh.T
-                h_proj_rz = h_proj[..., : -self.hidden_size]
-                h_proj_n = h_proj[..., -self.hidden_size :]
+            if self.bhn is not None:
+                h_proj_n += self.bhn
 
-                if self.bhn is not None:
-                    h_proj_n += self.bhn
+            rz = rz + h_proj_rz
 
-                rz = rz + h_proj_rz
+        rz = mx.sigmoid(rz)
 
-            rz = mx.sigmoid(rz)
+        r, z = mx.split(rz, 2, axis=-1)
 
-            r, z = mx.split(rz, 2, axis=-1)
+        n = x_n
 
-            n = x_n[..., idx, :]
+        if hidden is not None:
+            n = n + r * h_proj_n
+        n = mx.tanh(n)
 
-            if hidden is not None:
-                n = n + r * h_proj_n
-            n = mx.tanh(n)
+        if hidden is not None:
+            hidden = (1 - z) * n + z * hidden
+        else:
+            hidden = (1 - z) * n
 
-            if hidden is not None:
-                hidden = (1 - z) * n + z * hidden
-            else:
-                hidden = (1 - z) * n
-
+        return hidden
+
+
+class GRU(Module):
+    r"""A gated recurrent unit (GRU) RNN layer.
+
+    The input has shape ``NLD`` or ``LD`` where:
+
+    * ``N`` is the optional batch dimension
+    * ``L`` is the sequence length
+    * ``D`` is the input's feature dimension
+
+    Concretely, for each element of the sequence, this layer computes:
+
+    .. math::
+
+        \begin{aligned}
+        r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
+        z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
+        n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
+        h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
+        \end{aligned}
+
+    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
+    whether the input is batched or not. Returns the hidden state at each
+    time step of shape ``NLH`` or ``LH``.
+
+    Args:
+        input_size (int): Dimension of the input, ``D``.
+        hidden_size (int): Dimension of the hidden state, ``H``.
+        bias (bool): Whether to use biases or not. Default: ``True``.
+ """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ): + super().__init__() + + self.hidden_size = hidden_size + self.cell = GRUCell(input_size, hidden_size, bias) + + def _extra_repr(self): + return ( + f"input_dims={self.cell.Wx.shape[1]}, " + f"hidden_size={self.hidden_size}, bias={self.cell.b is not None}" + ) + + def __call__(self, x, hidden=None): + all_hidden = [] + + for idx in range(x.shape[-2]): + hidden = self.cell(x[..., idx, :], hidden) all_hidden.append(hidden) return mx.stack(all_hidden, axis=-2) @@ -286,4 +337,4 @@ def __call__(self, x, hidden=None, cell=None): all_cell.append(cell) all_hidden.append(hidden) - return mx.stack(all_hidden, axis=-2), mx.stack(all_cell, axis=-2) + return mx.stack(all_hidden, axis[-2]), mx.stack(all_cell, axis[-2]) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 4c891545f..758df8d8c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1651,6 +1651,26 @@ def test_gru(self): h_out = layer(inp, h_out[-1, :]) self.assertEqual(h_out.shape, (44, 12)) + def test_gru_cell(self): + cell = nn.GRUCell(5, 12, bias=True) + inp = mx.random.normal((2, 5)) + hidden = mx.random.normal((2, 12)) + + h_out = cell(inp, hidden) + self.assertEqual(h_out.shape, (2, 12)) + + h_out = cell(inp) + self.assertEqual(h_out.shape, (2, 12)) + + inp = mx.random.normal((5,)) + hidden = mx.random.normal((12,)) + + h_out = cell(inp, hidden) + self.assertEqual(h_out.shape, (12,)) + + h_out = cell(inp) + self.assertEqual(h_out.shape, (12,)) + def test_lstm(self): layer = nn.LSTM(5, 12) inp = mx.random.normal((2, 25, 5))