diff --git a/README.md b/README.md index 4ee52395..0e7932e9 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ## The reasons why you use `pytorch-optimizer`. -* Wide range of supported optimizers. Currently, **93 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +* Wide range of supported optimizers. Currently, **94 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion` * Easy to use, clean, and tested codes * Active maintenance @@ -201,6 +201,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) | | TAM | *Torque-Aware Momentum* | | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) | | FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) | +| PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) | ## Supported LR Scheduler diff --git a/docs/changelogs/v3.3.5.md b/docs/changelogs/v3.3.5.md deleted file mode 100644 index e77f8f14..00000000 --- a/docs/changelogs/v3.3.5.md +++ /dev/null @@ -1,20 +0,0 @@ -### Change Log - -### Feature - -* Implement `FOCUS` optimizer. (#330, #331) - * [First Order Concentrated Updating Scheme](https://arxiv.org/abs/2501.12243) - -### Update - -* Support `OrthoGrad` variant to `Ranger25`. (#332) - -### Fix - -* Add the missing `state` property in `OrthoGrad` optimizer. 
(#326, #327) -* Add the missing `state_dict`, and `load_state_dict` methods to `TRAC` and `OrthoGrad` optimizers. (#332) -* Skip when the gradient is sparse in `OrthoGrad` optimizer. (#332) - -### Contributions - -thanks to @Vectorrent diff --git a/docs/changelogs/v3.4.0.md b/docs/changelogs/v3.4.0.md new file mode 100644 index 00000000..c8238008 --- /dev/null +++ b/docs/changelogs/v3.4.0.md @@ -0,0 +1,25 @@ +### Change Log + +### Feature + +* Implement `FOCUS` optimizer. (#330, #331) + * [First Order Concentrated Updating Scheme](https://arxiv.org/abs/2501.12243) +* Implement `PSGD Kron`. (#337) + * [Preconditioned Stochastic Gradient Descent with Kron preconditioner](https://arxiv.org/abs/1512.04202) + +### Update + +* Support `OrthoGrad` variant to `Ranger25`. (#332) + * `Ranger25` is my experimental optimizer, which combines many optimizer variants such as `ADOPT` + `AdEMAMix` + `Cautious` + `StableAdamW` + `Adam-Atan2` + `OrthoGrad`. + +### Fix + +* Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327) +* Add the missing `state_dict` and `load_state_dict` methods to `TRAC` and `OrthoGrad` optimizers. (#332) +* Skip when the gradient is sparse in `OrthoGrad` optimizer. (#332) +* Support alternative precision training in `SOAP` optimizer. (#333) +* Store SOAP condition matrices as the dtype of their parameters. (#335) + +### Contributions + +thanks to @Vectorrent, @kylevedder diff --git a/docs/index.md b/docs/index.md index 4ee52395..0e7932e9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ ## The reasons why you use `pytorch-optimizer`. -* Wide range of supported optimizers. Currently, **93 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion` * Easy to use, clean, and tested codes * Active maintenance @@ -201,6 +201,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) | | TAM | *Torque-Aware Momentum* | | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) | | FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) | +| PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 1cbc28d0..57768ef8 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -284,6 +284,10 @@ :docstring: :members: +::: pytorch_optimizer.Kron + :docstring: + :members: + ::: pytorch_optimizer.QHAdam :docstring: :members: diff --git a/poetry.lock b/poetry.lock index ba2772b3..6eb6b816 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,23 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. - -[[package]] -name = "bitsandbytes" -version = "0.44.1" -description = "k-bit optimizers and matrix multiplication routines." 
-optional = true -python-versions = "*" -files = [ - {file = "bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:b2f24c6cbf11fc8c5d69b3dcecee9f7011451ec59d6ac833e873c9f105259668"}, - {file = "bitsandbytes-0.44.1-py3-none-win_amd64.whl", hash = "sha256:8e68e12aa25d2cf9a1730ad72890a5d1a19daa23f459a6a4679331f353d58cb4"}, -] - -[package.dependencies] -numpy = "*" -torch = "*" - -[package.extras] -benchmark = ["matplotlib", "pandas"] -test = ["lion-pytorch", "scipy"] +# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. [[package]] name = "black" @@ -25,6 +6,8 @@ version = "24.8.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version == \"3.8\"" files = [ {file = "black-24.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09cdeb74d494ec023ded657f7092ba518e8cf78fa8386155e4a03fdcc44679e6"}, {file = "black-24.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:81c6742da39f33b08e791da38410f32e27d632260e599df7245cccee2064afeb"}, @@ -65,12 +48,62 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "black" +version = "25.1.0" +description = "The uncompromising code formatter." 
+optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.9\"" +files = [ + {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, + {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, + {file = "black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7"}, + {file = "black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9"}, + {file = "black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0"}, + {file = "black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299"}, + {file = "black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096"}, + {file = "black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2"}, + {file = "black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b"}, + {file = "black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc"}, + {file = "black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f"}, + {file = "black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba"}, + {file = 
"black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f"}, + {file = "black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3"}, + {file = "black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171"}, + {file = "black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18"}, + {file = "black-25.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1ee0a0c330f7b5130ce0caed9936a904793576ef4d2b98c40835d6a65afa6a0"}, + {file = "black-25.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3df5f1bf91d36002b0a75389ca8663510cf0531cca8aa5c1ef695b46d98655f"}, + {file = "black-25.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6827d563a2c820772b32ce8a42828dc6790f095f441beef18f96aa6f8294e"}, + {file = "black-25.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:bacabb307dca5ebaf9c118d2d2f6903da0d62c9faa82bd21a33eecc319559355"}, + {file = "black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717"}, + {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.10)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "click" version = "8.1.8" description = "Composable command 
line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -85,6 +118,8 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +markers = "(python_version >= \"3.9\" or python_version == \"3.8\") and (sys_platform == \"win32\" or platform_system == \"Windows\")" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -96,6 +131,8 @@ version = "7.6.1" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, @@ -183,6 +220,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version >= \"3.9\" and python_version < \"3.11\" or python_version == \"3.8\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = 
"sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -197,6 +236,8 @@ version = "3.16.1" description = "A platform independent file lock." optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, @@ -213,6 +254,8 @@ version = "2024.12.0" description = "File-system specification" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "fsspec-2024.12.0-py3-none-any.whl", hash = "sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2"}, {file = "fsspec-2024.12.0.tar.gz", hash = "sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f"}, @@ -252,6 +295,8 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -263,6 +308,8 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." 
optional = false python-versions = ">=3.8.0" +groups = ["dev"] +markers = "python_version == \"3.8\"" files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -271,12 +318,31 @@ files = [ [package.extras] colors = ["colorama (>=0.4.6)"] +[[package]] +name = "isort" +version = "6.0.0" +description = "A Python utility / library to sort Python imports." +optional = false +python-versions = ">=3.9.0" +groups = ["dev"] +markers = "python_version >= \"3.9\"" +files = [ + {file = "isort-6.0.0-py3-none-any.whl", hash = "sha256:567954102bb47bb12e0fae62606570faacddd441e45683968c8d1734fb1af892"}, + {file = "isort-6.0.0.tar.gz", hash = "sha256:75d9d8a1438a9432a7d7b54f2d3b45cad9a4a0fdba43617d9873379704a8bdf1"}, +] + +[package.extras] +colors = ["colorama"] +plugins = ["setuptools"] + [[package]] name = "jinja2" version = "3.1.5" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"}, {file = "jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb"}, @@ -294,6 +360,8 @@ version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." 
optional = false python-versions = ">=3.7" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, @@ -363,6 +431,8 @@ version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" optional = false python-versions = "*" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, @@ -380,6 +450,8 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
optional = false python-versions = ">=3.5" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -391,6 +463,8 @@ version = "3.1" description = "Python package for creating and manipulating graphs and networks" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, @@ -409,6 +483,8 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version == \"3.8\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -446,6 +522,8 @@ version = "2.0.2" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.9\"" files = [ {file = "numpy-2.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece"}, {file = "numpy-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04"}, @@ -500,6 +578,8 @@ version = "24.2" description = "Core utilities for Python packages" optional = false 
python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -511,6 +591,8 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -522,6 +604,8 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -538,6 +622,8 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -553,6 +639,8 @@ version = "8.3.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, @@ -575,6 +663,8 @@ version = "5.0.0" description = "Pytest plugin for measuring coverage." 
optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, @@ -589,29 +679,31 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.9.3" +version = "0.9.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ - {file = "ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624"}, - {file = "ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c"}, - {file = "ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5"}, - 
{file = "ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4"}, - {file = "ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519"}, - {file = "ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b"}, - {file = "ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c"}, - {file = "ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4"}, - {file = "ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b"}, - {file = "ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a"}, + {file = "ruff-0.9.4-py3-none-linux_armv6l.whl", hash = "sha256:64e73d25b954f71ff100bb70f39f1ee09e880728efb4250c632ceed4e4cdf706"}, + {file = "ruff-0.9.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6ce6743ed64d9afab4fafeaea70d3631b4d4b28b592db21a5c2d1f0ef52934bf"}, + {file = "ruff-0.9.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:54499fb08408e32b57360f6f9de7157a5fec24ad79cb3f42ef2c3f3f728dfe2b"}, + {file = "ruff-0.9.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37c892540108314a6f01f105040b5106aeb829fa5fb0561d2dcaf71485021137"}, + {file = 
"ruff-0.9.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:de9edf2ce4b9ddf43fd93e20ef635a900e25f622f87ed6e3047a664d0e8f810e"}, + {file = "ruff-0.9.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87c90c32357c74f11deb7fbb065126d91771b207bf9bfaaee01277ca59b574ec"}, + {file = "ruff-0.9.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:56acd6c694da3695a7461cc55775f3a409c3815ac467279dfa126061d84b314b"}, + {file = "ruff-0.9.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0c93e7d47ed951b9394cf352d6695b31498e68fd5782d6cbc282425655f687a"}, + {file = "ruff-0.9.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d4c8772670aecf037d1bf7a07c39106574d143b26cfe5ed1787d2f31e800214"}, + {file = "ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc5f1d7afeda8d5d37660eeca6d389b142d7f2b5a1ab659d9214ebd0e025231"}, + {file = "ruff-0.9.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faa935fc00ae854d8b638c16a5f1ce881bc3f67446957dd6f2af440a5fc8526b"}, + {file = "ruff-0.9.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6c634fc6f5a0ceae1ab3e13c58183978185d131a29c425e4eaa9f40afe1e6d6"}, + {file = "ruff-0.9.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:433dedf6ddfdec7f1ac7575ec1eb9844fa60c4c8c2f8887a070672b8d353d34c"}, + {file = "ruff-0.9.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d612dbd0f3a919a8cc1d12037168bfa536862066808960e0cc901404b77968f0"}, + {file = "ruff-0.9.4-py3-none-win32.whl", hash = "sha256:db1192ddda2200671f9ef61d9597fcef89d934f5d1705e571a93a67fb13a4402"}, + {file = "ruff-0.9.4-py3-none-win_amd64.whl", hash = "sha256:05bebf4cdbe3ef75430d26c375773978950bbf4ee3c95ccb5448940dc092408e"}, + {file = "ruff-0.9.4-py3-none-win_arm64.whl", hash = "sha256:585792f1e81509e38ac5123492f8875fbc36f3ede8185af0a26df348e5154f41"}, + {file = "ruff-0.9.4.tar.gz", hash = 
"sha256:6907ee3529244bb0ed066683e075f09285b38dd5b4039370df6ff06041ca19e7"}, ] [[package]] @@ -620,6 +712,8 @@ version = "75.8.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.12\"" files = [ {file = "setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3"}, {file = "setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6"}, @@ -640,6 +734,8 @@ version = "1.12.1" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version == \"3.8\"" files = [ {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, @@ -654,6 +750,8 @@ version = "1.13.1" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version >= \"3.9\"" files = [ {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, @@ -671,6 +769,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version >= \"3.9\" and python_full_version <= \"3.11.0a6\" or python_version == \"3.8\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, 
@@ -712,6 +812,8 @@ version = "2.5.1+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" +groups = ["main"] +markers = "python_version >= \"3.9\" or python_version == \"3.8\"" files = [ {file = "torch-2.5.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:7f91a2200e352745d70e22396bd501448e28350fbdbd8d8b1c83037e25451150"}, {file = "torch-2.5.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:df93157482b672892d29134d3fae9d38ba3219702faedd79f407eb36774c56ce"}, @@ -751,15 +853,14 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] - -[extras] -bitsandbytes = ["bitsandbytes"] +markers = {main = "python_version >= \"3.9\" or python_version == \"3.8\"", dev = "python_version >= \"3.9\" and python_version < \"3.11\" or python_version == \"3.8\""} [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = ">=3.8" -content-hash = "0fb4f6358622b11b8e5f09beaaa21a9622d536114716c7977535a20e549ab34d" +content-hash = "2f6674bb2c0ae9a42111da81cc06e9ed17aab1a4f2243712309ceaf71005920b" diff --git a/pyproject.toml b/pyproject.toml index 2c525e4f..62211684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ keywords = [ "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DeMo", "DiffGrad", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero", - "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", 
"PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", + "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", @@ -52,29 +52,27 @@ numpy = [ { version = "<=1.24.4", python = "<3.9" }, ] torch = { version = ">=1.10", python = ">=3.8", source = "torch" } -bitsandbytes = { version = "^0.44", optional = true } -[tool.poetry.dev-dependencies] -isort = { version = "^5", python = ">=3.8" } -black = { version = "^24", python = ">=3.8" } +[tool.poetry.group.dev.dependencies] +isort = [ + { version = "<6", python = "<3.9" }, + { version = "^6", python = ">=3.9" }, +] +black = [ + { version = "<25", python = "<3.9" }, + { version = "^25", python = ">=3.9" }, +] ruff = "*" pytest = "*" pytest-cov = "*" -[tool.poetry.extras] -bitsandbytes = ["bitsandbytes"] - [[tool.poetry.source]] name = "torch" url = "https://download.pytorch.org/whl/cpu" priority = "explicit" [tool.ruff] -src = [ - "pytorch_optimizer", - "tests", - "examples", -] +src = ["pytorch_optimizer", "tests", "examples"] target-version = "py312" line-length = 119 exclude = [ @@ -109,10 +107,10 @@ flake8-quotes.docstring-quotes = "double" flake8-quotes.inline-quotes = "single" [tool.ruff.lint.extend-per-file-ignores] -"hubconf.py" = ["D", "INP001"] "examples/visualize_optimizers.py" = ["D103", "D400", "D415"] -"**/__init__.py" = ["F401"] "{tests,examples}/*.py" = ["D", "S101"] +"**/__init__.py" = ["F401"] +"hubconf.py" = ["D", "INP001"] [tool.ruff.lint.isort] combine-as-imports = false @@ -131,9 +129,9 @@ testpaths = "tests" [tool.coverage.run] omit = [ - "./pytorch_optimizer/optimizer/rotograd.py", 
"./pytorch_optimizer/optimizer/adam_mini.py", "./pytorch_optimizer/optimizer/demo.py", + "./pytorch_optimizer/optimizer/rotograd.py", ] [build-system] diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index c0c92f01..5088dbb0 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -113,6 +113,7 @@ Gravity, GrokFastAdamW, Kate, + Kron, Lamb, LaProp, Lion, diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index f32d888a..a9ee8666 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -338,7 +338,7 @@ def validate_step(step: int, step_type: str) -> None: @staticmethod def validate_options(x: str, name: str, options: List[str]) -> None: if x not in options: - opts: str = ' or '.join([f'\'{option}\'' for option in options]).strip() + opts: str = ' or '.join([f"'{option}'" for option in options]).strip() raise ValueError(f'[-] {name} {x} must be one of ({opts})') @staticmethod diff --git a/pytorch_optimizer/optimizer/__init__.py b/pytorch_optimizer/optimizer/__init__.py index a7603b37..ba6bbb9e 100644 --- a/pytorch_optimizer/optimizer/__init__.py +++ b/pytorch_optimizer/optimizer/__init__.py @@ -70,6 +70,7 @@ from pytorch_optimizer.optimizer.pid import PID from pytorch_optimizer.optimizer.pnm import PNM from pytorch_optimizer.optimizer.prodigy import Prodigy +from pytorch_optimizer.optimizer.psgd import Kron from pytorch_optimizer.optimizer.qhadam import QHAdam from pytorch_optimizer.optimizer.qhm import QHM from pytorch_optimizer.optimizer.radam import RAdam @@ -293,6 +294,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER: FOCUS, Grams, SPAM, + Kron, Ranger25, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/optimizer/ftrl.py b/pytorch_optimizer/optimizer/ftrl.py index 3b1d917e..48e86b78 100644 --- a/pytorch_optimizer/optimizer/ftrl.py +++ 
b/pytorch_optimizer/optimizer/ftrl.py @@ -25,7 +25,7 @@ def __init__( beta: float = 0.0, lambda_1: float = 0.0, lambda_2: float = 0.0, - **kwargs + **kwargs, ): self.validate_learning_rate(lr) self.validate_non_negative(beta, 'beta') diff --git a/pytorch_optimizer/optimizer/muon.py b/pytorch_optimizer/optimizer/muon.py index 400cb095..3008df69 100644 --- a/pytorch_optimizer/optimizer/muon.py +++ b/pytorch_optimizer/optimizer/muon.py @@ -95,10 +95,10 @@ def get_parameters(params: PARAMETERS) -> List[torch.Tensor]: return new_params - def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS, threshold: int = 8192) -> None: + def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS) -> None: r"""Set use_muon flag.""" for p in params: - self.state[p]['use_muon'] = p.ndim >= 2 and p.size(0) < threshold + self.state[p]['use_muon'] = p.ndim >= 2 for p in adamw_params: self.state[p]['use_muon'] = False diff --git a/pytorch_optimizer/optimizer/psgd.py b/pytorch_optimizer/optimizer/psgd.py new file mode 100644 index 00000000..26934985 --- /dev/null +++ b/pytorch_optimizer/optimizer/psgd.py @@ -0,0 +1,375 @@ +import math +from string import ascii_lowercase, ascii_uppercase +from typing import Callable, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import CLOSURE, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.psgd_utils import norm_lower_bound + +MEMORY_SAVE_MODE_TYPE = Literal['one_diag', 'smart_one_diag', 'all_diag'] + + +def precondition_update_prob_schedule( + max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.001, flat_start: int = 500 +) -> Callable[[int], torch.Tensor]: + """Anneal pre-conditioner update probability during beginning of training. 
+ + PSGD benefits from more pre-conditioner updates at the beginning of training, but once the pre-conditioner is + learned the update probability can drop low. + + This schedule is an exponential anneal with a flat start. Default settings keep update probability at 1.0 for 500 + steps then exponentially anneal down to `min_prob` by 4000 steps. Default settings work very well for most models + and training regimes. + """ + + def _schedule(n: int) -> torch.Tensor: + """Exponential anneal with flat start.""" + n = torch.tensor(n, dtype=torch.float32) + prob = max_prob * torch.exp(-decay * (n - flat_start)) + prob.clamp_(min=min_prob, max=max_prob) + return prob + + return _schedule + + +class Kron(BaseOptimizer): + """PSGD with the Kronecker product pre-conditioner. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param momentum: float. momentum factor. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param pre_conditioner_update_probability: Optional[Tuple[Callable, float]]. Probability of updating the + pre-conditioner. If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps. + :param max_size_triangular: int. max size for dim's pre-conditioner to be triangular. + :param min_ndim_triangular: int. minimum number of dimensions a layer needs to have triangular pre-conditioners. + :param memory_save_mode: Optional[str]. None (default) keeps all pre-conditioners triangular, 'one_diag' makes + the largest dim diagonal per layer, 'smart_one_diag' does so only if that dim is strictly the largest, and + 'all_diag' sets all pre-conditioners to be diagonal. + :param momentum_into_precondition_update: bool. whether to send momentum into pre-conditioner update instead of + raw gradients. + :param mu_dtype: Optional[torch.dtype]. dtype of the momentum accumulator. 
+ :param precondition_dtype: torch.dtype. dtype of the pre-conditioner. + :param balance_prob: float. probability of performing balancing. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-3, + momentum: float = 0.9, + weight_decay: float = 0.0, + weight_decouple: bool = True, + pre_conditioner_update_probability: Optional[Tuple[Callable, float]] = None, + max_size_triangular: int = 8192, + min_ndim_triangular: int = 2, + memory_save_mode: Optional[MEMORY_SAVE_MODE_TYPE] = None, + momentum_into_precondition_update: bool = True, + mu_dtype: Optional[torch.dtype] = None, + precondition_dtype: Optional[torch.dtype] = torch.float32, + balance_prob: float = 0.01, + **kwargs, + ): + self.validate_learning_rate(lr) + self.validate_range(momentum, 'momentum', 0.0, 1.0) + self.validate_non_negative(weight_decay, 'weight_decay') + + if pre_conditioner_update_probability is None: + pre_conditioner_update_probability = precondition_update_prob_schedule() + + self.balance_prob: float = balance_prob + self.eps: float = torch.finfo(torch.bfloat16).tiny + self.prob_step: int = 0 + self.update_counter: int = 0 + + defaults = { + 'lr': lr, + 'momentum': momentum, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'pre_conditioner_update_probability': pre_conditioner_update_probability, + 'max_size_triangular': max_size_triangular, + 'min_ndim_triangular': min_ndim_triangular, + 'memory_save_mode': memory_save_mode, + 'momentum_into_precondition_update': momentum_into_precondition_update, + 'precondition_lr': 1e-1, + 'precondition_init_scale': 1.0, + 'mu_dtype': mu_dtype, + 'precondition_dtype': precondition_dtype, + } + + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'Kron' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['momentum_buffer'] = p.grad.clone() + + @torch.no_grad() + def step(self, closure: 
CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + update_prob: Union[float, Callable] = self.param_groups[0]['pre_conditioner_update_probability'] + if callable(update_prob): + update_prob = update_prob(self.prob_step) + + self.update_counter += 1 + do_update: bool = self.update_counter >= 1 / update_prob + if do_update: + self.update_counter = 0 + self.prob_step += 1 + + balance: bool = np.random.random() < self.balance_prob and do_update + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1: float = self.debias(group['momentum'], group['step']) + + mu_dtype, precondition_dtype = group['mu_dtype'], group['precondition_dtype'] + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + + if len(state) == 0: + state['momentum_buffer'] = torch.zeros_like(p, dtype=mu_dtype or p.dtype) + state['Q'], state['expressions'] = initialize_q_expressions( + p, + group['precondition_init_scale'], + group['max_size_triangular'], + group['min_ndim_triangular'], + group['memory_save_mode'], + dtype=precondition_dtype, + ) + + momentum_buffer = state['momentum_buffer'] + momentum_buffer.mul_(group['momentum']).add_(grad, alpha=1.0 - group['momentum']) + + if mu_dtype is not None: + momentum_buffer = momentum_buffer.to(dtype=mu_dtype, non_blocking=True) + + de_biased_momentum = (momentum_buffer / bias_correction1).to( + dtype=precondition_dtype, non_blocking=True + ) + + if grad.dim() > 1 and balance: + balance_q(state['Q']) + + if do_update: + update_precondition( + state['Q'], + state['expressions'], + torch.randn_like(de_biased_momentum, dtype=precondition_dtype), + de_biased_momentum if group['momentum_into_precondition_update'] else grad, + group['precondition_lr'], + self.eps, + ) + + precondition_grad = 
get_precondition_grad(state['Q'], state['expressions'], de_biased_momentum).to( + dtype=p.dtype, non_blocking=True + ) + + precondition_grad.mul_(torch.clamp(1.1 / (precondition_grad.square().mean().sqrt() + 1e-6), max=1.0)) + + if group['weight_decay'] != 0 and p.dim() >= 2: + precondition_grad.add_(p, alpha=group['weight_decay']) + + p.add_(precondition_grad, alpha=-group['lr']) + + return loss + + +def initialize_q_expressions( + t: torch.Tensor, + scale: float, + max_size: int, + min_ndim_triangular: int, + memory_save_mode: Optional[MEMORY_SAVE_MODE_TYPE], + dtype: Optional[torch.dtype] = None, +) -> Tuple[List[torch.Tensor], Tuple[str, List[str], str]]: + r"""Initialize Q expressions. + + For a scalar or tensor t, we initialize its pre-conditioner Q and reusable einsum expressions for updating Q and + pre-conditioning gradient. + """ + letters: str = ascii_lowercase + ascii_uppercase + + dtype: torch.dtype = dtype if dtype is not None else t.dtype + shape = t.shape + if len(shape) == 0: + qs: list[torch.Tensor] = [scale * torch.ones_like(t, dtype=dtype)] + expressions_a: str = ',->' + expression_gr: List[str] = [',->'] + expression_r: str = ',,->' + + return qs, (expressions_a, expression_gr, expression_r) + + if len(shape) > 13: + raise ValueError(f'got tensor with dim {len(t.shape)}. Einstein runs out of letters!') + + scale = math.pow(scale, 1.0 / len(shape)) + + if memory_save_mode is None: + dim_diag = [False for _ in shape] + elif memory_save_mode == 'one_diag': + dim_diag = [False for _ in shape] + dim_diag[np.argsort(shape)[::-1][0]] = True + elif memory_save_mode == 'smart_one_diag': + dim_diag = [False for _ in shape] + sorted_shape = sorted(shape) + if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]: + dim_diag[np.argsort(shape)[::-1][0]] = True + elif memory_save_mode == 'all_diag': + dim_diag = [True for _ in shape] + else: + raise NotImplementedError( + f'invalid memory_save_mode {memory_save_mode}. 
' + 'it must be one of [None, \'one_diag\', \'smart_one_diag\', \'all_diag\']' + ) + + qs: List[torch.Tensor] = [] + expr_gr = [] + piece_1a, piece_2a, piece_3a = [], '', '' + piece_1p, piece_2p, piece_3p, piece_4p = [], [], '', '' + for i, (size, dim_d) in enumerate(zip(shape, dim_diag)): + if size == 1 or size > max_size or len(shape) < min_ndim_triangular or dim_d: + qs.append(scale * torch.ones(size, dtype=dtype, device=t.device)) + + piece_1a.append(letters[i]) + piece_2a += letters[i] + piece_3a += letters[i] + + piece1: str = ''.join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]) + expr_gr.append(f'{piece1},{piece1}->{letters[i + 13]}') + + piece_1p.append(letters[i + 13]) + piece_2p.append(letters[i + 13]) + piece_3p += letters[i + 13] + piece_4p += letters[i + 13] + else: + qs.append(scale * torch.eye(size, dtype=dtype, device=t.device)) + + piece_1a.append(letters[i] + letters[i + 13]) + piece_2a += letters[i + 13] + piece_3a += letters[i] + + piece1: str = ''.join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]) + piece2: str = ''.join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))]) + expr_gr.append(f'{piece1},{piece2}->{letters[i + 13]}{letters[i + 26]}') + + a, b, c = letters[i], letters[i + 13], letters[i + 26] + piece_1p.append(a + b) + piece_2p.append(a + c) + piece_3p += c + piece_4p += b + + expr_a: str = ','.join(piece_1a) + f',{piece_2a}->{piece_3a}' + expr_r: str = ','.join(piece_1p) + ',' + ','.join(piece_2p) + f',{piece_3p}->{piece_4p}' + + return qs, (expr_a, expr_gr, expr_r) + + +def balance_q(q_in: List[torch.Tensor]) -> None: + r"""Balance Q.""" + norms = torch.stack([q.norm(float('inf')) for q in q_in]) + geometric_mean = norms.prod() ** (1 / len(q_in)) + norms = geometric_mean / norms + for i, q in enumerate(q_in): + q.mul_(norms[i]) + + +def solve_triangular_right(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor: + r"""Calculate X @ inv(A).""" + orig_dtype: 
torch.dtype = x.dtype + x = x.to(dtype=torch.float32, non_blocking=True) + a = a.to(dtype=torch.float32, non_blocking=True) + out = torch.linalg.solve_triangular(a, x.reshape(-1, x.size(-1)), upper=True, left=False).reshape_as(x) + return out.to(dtype=orig_dtype, non_blocking=True) + + +def get_a_and_conj_b( + expr_a: List[str], g: torch.Tensor, qs: List[torch.Tensor], v: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Get A and b.conj.""" + a = torch.einsum(expr_a, *qs, g) + + order: int = g.dim() + p = list(range(order)) + + conj_b = torch.permute(v.conj(), p[1:] + p[:1]) + for i, q in enumerate(qs): + conj_b = conj_b / q if q.dim() < 2 else solve_triangular_right(conj_b, q) + if i < order - 1: + conj_b = torch.transpose(conj_b, i, order - 1) + + return a, conj_b + + +def get_q_terms(expr_gs: List[str], a: torch.Tensor, conj_b: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]: + r"""Get Q terms.""" + terms: List = [] + for expr_g in expr_gs: + term1 = torch.einsum(expr_g, a, a.conj()) + term2 = torch.einsum(expr_g, conj_b.conj(), conj_b) + terms.append((term1, term2)) + return terms + + +def update_precondition( + qs: List[torch.Tensor], + expressions: List[Tuple[str, List[str], str]], + v: torch.Tensor, + g: torch.Tensor, + step: int, + eps: float, +) -> None: + r"""Update Kronecker product pre-conditioner Q with pair (V, G).""" + expr_a, expr_gs, _ = expressions + + a, conj_b = get_a_and_conj_b(expr_a, g, qs, v) + + q_terms: List[Tuple[torch.Tensor, torch.Tensor]] = get_q_terms(expr_gs, a, conj_b) + + for q, (term1, term2) in zip(qs, q_terms): + tmp = term1 - term2 + tmp *= step + + if q.dim() < 2: + tmp *= q + tmp.div_((term1 + term2).norm(float('inf')).add_(eps)) + else: + tmp = torch.triu(tmp) + tmp.div_(norm_lower_bound(term1 + term2).add_(eps)) + tmp @= q + + q.sub_(tmp) + + +def get_precondition_grad(qs: list[torch.Tensor], expressions: list[str], g: torch.Tensor) -> torch.Tensor: + r"""Precondition gradient G with pre-conditioner 
Q.""" + return torch.einsum(expressions[-1], *[x.conj() for x in qs], *qs, g) diff --git a/pytorch_optimizer/optimizer/psgd_utils.py b/pytorch_optimizer/optimizer/psgd_utils.py new file mode 100644 index 00000000..b80bdb35 --- /dev/null +++ b/pytorch_optimizer/optimizer/psgd_utils.py @@ -0,0 +1,94 @@ +from typing import List, Tuple + +import torch +from torch.linalg import vector_norm + + +def damped_pair_vg(g: torch.Tensor, damp: float = 2 ** -13) -> Tuple[torch.Tensor, torch.Tensor]: # fmt: skip + r"""Get damped pair v and g. + + Instead of return (v, g), it returns pair (v, g + sqrt(eps)*mean(abs(g))*v) + such that the covariance matrix of the modified g is lower bound by eps * (mean(abs(g)))**2 * I + This should damp the pre-conditioner to encourage numerical stability. + The default amount of damping is 2**(-13), slightly smaller than sqrt(eps('single')). + + If v is integrated out, let's just use the modified g; + If hvp is used, recommend to use L2 regularization to lower bound the Hessian, although this method also works. + + Please check example + https://github.com/lixilinx/psgd_torch/blob/master/misc/psgd_with_finite_precision_arithmetic.py + for the rationale to set default damping level to 2**(-13). + """ + v = torch.randn_like(g) + return v, g + damp * torch.mean(torch.abs(g)) * v + + +def norm_lower_bound(a: torch.Tensor) -> torch.Tensor: + r"""Get a cheap lower bound for the spectral norm of A. + + Numerical results on random matrices with a wide range of distributions and sizes suggest, + norm(A) <= sqrt(2) * norm_lower_bound(A) + Looks to be a very tight lower bound. 
+ """ + max_abs = torch.max(torch.abs(a)) + if max_abs <= 0: + return max_abs + + a = a / max_abs  # out-of-place: dividing in place would rescale the caller's tensor + + aa = torch.real(a * a.conj()) + value0, i = torch.max(torch.sum(aa, dim=0), 0) + value1, j = torch.max(torch.sum(aa, dim=1), 0) + + if value0 > value1: + x = a[:, i].conj() @ a + return max_abs * vector_norm((x / vector_norm(x)) @ a.H) + + x = a @ a[j].conj() + return max_abs * vector_norm(a.H @ (x / vector_norm(x))) + + +def woodbury_identity(inv_a: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> None: + r"""Get the Woodbury identity. + + inv(A + U * V) = inv(A) - inv(A) * U * inv(I + V * inv(A) * U) * V * inv(A) + + with inplace update of inv_a. + + Note that using the Woodbury identity multiple times could accumulate numerical errors. + """ + inv_au = inv_a @ u + v_inv_au = v @ inv_au + + ident = torch.eye(v_inv_au.shape[0], dtype=v_inv_au.dtype, device=v_inv_au.device) + inv_a.sub_(inv_au @ torch.linalg.solve(ident + v_inv_au, v @ inv_a)) + + +def triu_with_diagonal_and_above(a: torch.Tensor) -> torch.Tensor: + r"""Get triu with diagonal and above. + + It is useful as for a small A, the R of QR decomposition qr(I + A) is about I + triu(A, 0) + triu(A, 1) + """ + return torch.triu(a, diagonal=0) + torch.triu(a, diagonal=1) + + +def update_precondition_dense( + q: torch.Tensor, dxs: List[torch.Tensor], dgs: List[torch.Tensor], step: float = 0.01, eps: float = 1.2e-38 +) -> torch.Tensor: + r"""Update dense pre-conditioner P = Q^T * Q. + + :param q: torch.Tensor. Cholesky factor of pre-conditioner with positive diagonal entries. + :param dxs: List[torch.Tensor]. list of perturbations of parameters. + :param dgs: List[torch.Tensor]. list of perturbations of gradients. + :param step: float. update step size normalized to range [0, 1]. + :param eps: float. an offset to avoid division by zero. 
+ """ + dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs]) + dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs]) + + a = q.mm(dg) + b = torch.linalg.solve_triangular(q.t(), dx, upper=False) + + grad = torch.triu(a.mm(a.t()) - b.mm(b.t())) + + return q - (step / norm_lower_bound(grad).add_(eps)) * grad.mm(q) diff --git a/pytorch_optimizer/optimizer/rotograd.py b/pytorch_optimizer/optimizer/rotograd.py index 4bed7224..3873383d 100644 --- a/pytorch_optimizer/optimizer/rotograd.py +++ b/pytorch_optimizer/optimizer/rotograd.py @@ -378,7 +378,7 @@ def _rep_grad(self): self.initial_grads = grad_norms conv_ratios = [torch.ones((1,)) for _ in range(len(self.initial_grads))] else: - conv_ratios = [x / y for x, y, in zip(grad_norms, self.initial_grads)] + conv_ratios = [x / y for x, y in zip(grad_norms, self.initial_grads)] self.counter += 1 diff --git a/pytorch_optimizer/optimizer/shampoo_utils.py b/pytorch_optimizer/optimizer/shampoo_utils.py index d6183d0f..cb6746d7 100644 --- a/pytorch_optimizer/optimizer/shampoo_utils.py +++ b/pytorch_optimizer/optimizer/shampoo_utils.py @@ -186,7 +186,8 @@ def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor: n: int = len(indices) + 1 partitions: List[torch.Tensor] = [ - torch.cat(partitions[idx:idx + n], dim=i) for idx in range(0, len(partitions), n) # fmt: skip + torch.cat(partitions[idx:idx + n], dim=i) + for idx in range(0, len(partitions), n) # fmt: skip ] return partitions[0] @@ -367,7 +368,7 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor: self.precondition_block( partitioned_grad, self.should_precondition_dims, - self.pre_conditioners[i * self.rank:(i + 1) * self.rank] # fmt: skip + self.pre_conditioners[i * self.rank:(i + 1) * self.rank], # fmt: skip ) for i, partitioned_grad in enumerate(partitioned_grads) ] diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index befce42c..4c95e347 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ 
b/pytorch_optimizer/optimizer/utils.py @@ -48,8 +48,7 @@ def is_deepspeed_zero3_enabled() -> bool: return is_deepspeed_zero3_enabled() # pragma: no cover warnings.warn( - 'you need to install `transformers` to use `is_deepspeed_zero3_enabled` function. ' - 'it will return False.', + 'you need to install `transformers` to use `is_deepspeed_zero3_enabled` function. it will return False.', category=ImportWarning, stacklevel=2, ) diff --git a/tests/constants.py b/tests/constants.py index daecbc52..f37b8e7c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -63,6 +63,7 @@ Gravity, GrokFastAdamW, Kate, + Kron, Lamb, LaProp, Lion, @@ -557,6 +558,7 @@ (TAM, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (AdaTAM, {'lr': 1e-1, 'weight_decay': 1e-3}, 5), (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5), + (Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3), (Ranger25, {'lr': 5e0}, 2), (Ranger25, {'lr': 5e0, 't_alpha_beta3': 5}, 2), (Ranger25, {'lr': 2e-1, 'stable_adamw': False, 'orthograd': False, 'eps': None}, 3), diff --git a/tests/test_general_optimizer_parameters.py b/tests/test_general_optimizer_parameters.py index af841137..932843f3 100644 --- a/tests/test_general_optimizer_parameters.py +++ b/tests/test_general_optimizer_parameters.py @@ -55,6 +55,7 @@ def test_epsilon(optimizer_name): 'demo', 'muon', 'focus', + 'kron', ): pytest.skip(f'skip {optimizer_name} optimizer') diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index 30746147..0aafa1fc 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 91 + assert len(get_supported_optimizers()) == 92 assert len(get_supported_optimizers('adam*')) == 7 assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 10 diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 50dd73bf..32b6eb84 100644 --- 
a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -853,3 +853,21 @@ def test_spam_optimizer(): optimizer = load_optimizer('spam')([simple_parameter(True)], grad_accu_steps=0, update_proj_gap=1) optimizer.step() + + +def test_kron_optimizer(): + model = Example() + + optimizer = load_optimizer('kron')( + model.parameters(), + weight_decay=1e-3, + pre_conditioner_update_probability=1.0, + balance_prob=1.0, + mu_dtype=torch.bfloat16, + ) + optimizer.zero_grad() + + model.fc1.weight.grad = torch.randn((1, 1)) + model.norm1.weight.grad = torch.randn((1,)) + + optimizer.step() diff --git a/tests/test_utils.py b/tests/test_utils.py index 19b00069..db72d7fd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,14 @@ from pytorch_optimizer.optimizer import get_optimizer_parameters from pytorch_optimizer.optimizer.nero import neuron_mean, neuron_norm +from pytorch_optimizer.optimizer.psgd import initialize_q_expressions +from pytorch_optimizer.optimizer.psgd_utils import ( + damped_pair_vg, + norm_lower_bound, + triu_with_diagonal_and_above, + update_precondition_dense, + woodbury_identity, +) from pytorch_optimizer.optimizer.shampoo_utils import ( BlockPartitioner, PreConditioner, @@ -271,3 +279,59 @@ def test_orthograd_name(): _ = optimizer.state assert str(optimizer).lower() == 'orthograd' + + +def test_damped_pair_vg(): + x = torch.zeros(2) + y = damped_pair_vg(x)[1] + + torch.testing.assert_close(x, y) + + +def test_norm_lower_bound(): + x = torch.zeros(1) + y = norm_lower_bound(x) + torch.testing.assert_close(y, x.squeeze()) + + x = torch.FloatTensor([[1, 1]]) + y = norm_lower_bound(x) + torch.testing.assert_close(y, torch.tensor(1.4142135)) + + x = torch.FloatTensor([[2, 1], [2, 1]]) + y = norm_lower_bound(x) + torch.testing.assert_close(y, torch.tensor(3.16227769)) + + +def test_woodbury_identity(): + x = torch.FloatTensor([[1]]) + woodbury_identity(x, x, x) + + +def test_triu_with_diagonal_and_above(): + x = torch.FloatTensor([[1, 2], [3, 
4]]) + y = triu_with_diagonal_and_above(x) + torch.testing.assert_close(y, torch.FloatTensor([[1, 4], [0, 4]])) + + +def test_update_precondition_dense(): + q = torch.FloatTensor([[1]]) + dxs = [q] * 1 + dgs = [q] * 1 + + y = update_precondition_dense(q, dxs, dgs) + + torch.testing.assert_close(y, q) + + +def test_initialize_q_expressions(): + x = torch.zeros(1) + _ = initialize_q_expressions(x.squeeze(), 0.0, 0, 0, None) + + with pytest.raises(ValueError): + initialize_q_expressions(x.expand(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), 0.0, 0, 0, None) + + with pytest.raises(NotImplementedError): + initialize_q_expressions(x, 0.0, 0, 0, 'invalid') + + for memory_save_mode in ('one_diag', 'all_diag', 'smart_one_diag'): + initialize_q_expressions(torch.FloatTensor([[1], [2]]), 0.0, 0, 0, memory_save_mode)