This is related to #1499: `GRUCell` is not implemented, and the existing `GRU` layer is not optimized on MPS compared to the PyTorch and TensorFlow variants.
thegodone changed the title from "[BUG] adding MPS GRUCell for efficiency" to "[FEATURE] adding MPS GRUCell for efficiency" on Oct 18, 2024
Fixes ml-explore#1500
Implement MPS-specific GRUCell and update GRU class for efficiency.
* **GRUCell Implementation:**
- Add a new `GRUCell` class in `python/mlx/nn/layers/recurrent.py` with MPS-specific optimizations.
- Define the input shape and hidden state shape for the `GRUCell`.
- Implement the forward pass for the `GRUCell` with MPS-specific optimizations.
* **GRU Class Update:**
- Update the `GRU` class in `python/mlx/nn/layers/recurrent.py` to use the new `GRUCell` for improved performance on MPS.
- Define the input shape and hidden state shape for the `GRU` class.
- Implement the forward pass for the `GRU` class using the `GRUCell`.
* **Documentation:**
- Update `docs/src/python/nn/layers.rst` to include the new `GRUCell` class.
* **Tests:**
- Add tests for the new `GRUCell` class in `python/tests/test_nn.py` to ensure correctness and performance improvements.
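For reviewers, here is a minimal, framework-free sketch of the single-step computation a `GRUCell` is expected to perform, and of how a `GRU` forward pass can loop over it. It is not the code in this PR and contains no MPS-specific optimization. It assumes the standard GRU gate equations with the hidden-side candidate bias applied inside the reset gate (the PyTorch-style convention). The names `gru_cell_step` and `gru_forward` and the parameter keys are hypothetical, chosen only for illustration.

```python
import math


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def _affine(W, v, b):
    # W is a list of rows (out_dim x in_dim), v a vector, b a bias vector.
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]


def gru_cell_step(x, h, p):
    """One GRU step for a single example (hypothetical reference helper).

    x: input vector of length D, h: hidden state of length H.
    p: dict of parameters (illustrative key names, not the MLX layout):
       W_xr, W_xz, W_xn are H x D; W_hr, W_hz, W_hn are H x H;
       b_r, b_z, b_n, b_hn have length H.
    """
    zeros = [0.0] * len(h)
    # Reset and update gates combine input and hidden projections.
    r = [_sigmoid(a + c) for a, c in zip(_affine(p["W_xr"], x, p["b_r"]),
                                         _affine(p["W_hr"], h, zeros))]
    z = [_sigmoid(a + c) for a, c in zip(_affine(p["W_xz"], x, p["b_z"]),
                                         _affine(p["W_hz"], h, zeros))]
    # Candidate state: the reset gate scales the hidden projection (with b_hn).
    n_x = _affine(p["W_xn"], x, p["b_n"])
    n_h = _affine(p["W_hn"], h, p["b_hn"])
    n = [math.tanh(a + ri * c) for a, ri, c in zip(n_x, r, n_h)]
    # New hidden state interpolates between the candidate and the old state.
    return [(1.0 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]


def gru_forward(xs, h0, p):
    """Unroll the cell over a sequence; returns the hidden state at every step."""
    h, outputs = h0, []
    for x in xs:
        h = gru_cell_step(x, h, p)
        outputs.append(h)
    return outputs
```

A reference like this can be handy in `python/tests/test_nn.py`: for example, with every weight and bias set to zero both gates equal 0.5 and the candidate is 0, so each step should simply halve the hidden state.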
Love MLX : )