diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8b8916c631..de956c243e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,6 +6,31 @@ on: - "v*" jobs: + publish-burn-router: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-router + needs: + - publish-burn-common + - publish-burn-tensor + # dev dependencies + - publish-burn-autodiff + - publish-burn-ndarray + - publish-burn-wgpu + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + + publish-burn-remote: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-remote + needs: + - publish-burn-common + - publish-burn-tensor + - publish-burn-router + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-derive: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: @@ -162,6 +187,7 @@ jobs: - publish-burn-tch - publish-burn-ndarray - publish-burn-candle + - publish-burn-remote with: crate: burn-core secrets: diff --git a/Cargo.lock b/Cargo.lock index 7b555924fc..c9ce522699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,7 +41,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -62,21 +62,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "allocator-api2" version = "0.2.21" @@ -139,19 +124,20 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "arbitrary" @@ -188,7 +174,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -206,12 +192,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" -[[package]] -name = "arrayref" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" - [[package]] name = "arrayvec" version = "0.7.6" @@ -269,34 +249,25 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", -] - -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ - "num-traits", + "syn 2.0.98", ] [[package]] name = "atoi_simd" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" +checksum = "4790f9e8961209112beb783d85449b508673cf4a6a419c8449b210743ac4dbe9" [[package]] name = "atomic-waker" @@ -310,17 +281,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -352,19 +312,19 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.9" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ - "async-trait", "axum-core", "base64 0.22.1", "bytes", + "form_urlencoded", "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", + "http-body", "http-body-util", - "hyper 1.5.2", + "hyper", "hyper-util", "itoa", "matchit", @@ -378,9 +338,9 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sha1", - "sync_wrapper 1.0.2", + "sync_wrapper", "tokio", - "tokio-tungstenite 0.24.0", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -389,20 +349,19 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.5" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ - "async-trait", "bytes", "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", + "http-body", "http-body-util", "mime", "pin-project-lite", "rustversion", - "sync_wrapper 1.0.2", + "sync_wrapper", "tower-layer", "tower-service", "tracing", @@ -410,31 +369,31 @@ dependencies = [ [[package]] name = "backend-comparison" -version = "0.16.0" +version = "0.17.0" dependencies = [ "arboard", "burn", "burn-common", - "clap 4.5.23", + "chrono", + "clap", "colored", "cubecl", "derive-new 0.7.0", "dirs", - "github-device-flow", "half", "indicatif", "log", "os_info", "percent-encoding", "rand", - "reqwest 0.12.12", + "reqwest", "rstest", "serde", "serde_json", "serial_test", "strum", "strum_macros", - "sysinfo 0.32.1", + "sysinfo", "tracing-subscriber", "wgpu", "wsl", @@ -461,12 +420,6 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "base64" version = "0.22.1" @@ -528,9 +481,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" dependencies = [ "serde", ] @@ -541,19 +494,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" -[[package]] -name = "blake3" -version = "1.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "blas-src" version = "0.10.0" @@ -589,32 +529,11 @@ dependencies = [ "objc2", ] -[[package]] -name = "brotli" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "4.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" -version = "1.11.1" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "786a307d683a5bf92e6fd5fd69a7eb613751668d1d8d67d802846dfe367c62c8" +checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" dependencies = [ "memchr", "serde", @@ -628,13 +547,13 @@ checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "burn" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "burn-train", @@ -642,7 +561,7 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -654,7 +573,7 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-tch", @@ -666,14 +585,14 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.16.0" +version = "0.17.0" dependencies = [ - "cubecl-common 0.4.0", + "cubecl-common", "dashmap", - "getrandom", + "getrandom 0.2.15", "indicatif", "rayon", - "reqwest 0.12.12", + "reqwest", "serde", "tokio", "web-time", @@ -681,7 +600,7 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.16.0" +version = "0.17.0" dependencies = [ "ahash", "bincode", @@ -713,13 +632,13 @@ dependencies = [ "serde_json", "spin", "tempfile", - "thiserror 2.0.9", + "thiserror 2.0.11", "uuid", ] [[package]] name = "burn-cuda" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -734,7 +653,7 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "csv", @@ -761,22 +680,22 @@ dependencies = [ "strum", "strum_macros", "tempfile", - "thiserror 2.0.9", + "thiserror 2.0.11", ] [[package]] name = "burn-derive" -version = "0.16.0" +version = "0.17.0" dependencies = [ "derive-new 0.7.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "burn-fusion" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -790,7 +709,7 @@ dependencies = [ [[package]] name = "burn-hip" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -805,9 +724,10 @@ dependencies = [ [[package]] name = "burn-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", + "burn-ndarray", "candle-core", "derive-new 0.7.0", "half", @@ -821,8 +741,8 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.95", - "thiserror 2.0.9", + "syn 2.0.98", + "thiserror 2.0.11", "tracing-core", "tracing-subscriber", "zip 2.2.2", @@ -830,7 +750,7 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -856,7 +776,7 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.16.0" +version = "0.17.0" dependencies = [ "atomic_float", "blas-src", @@ -876,7 +796,7 @@ dependencies = [ [[package]] name = "burn-no-std-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-ndarray", @@ -885,12 +805,11 @@ dependencies = [ [[package]] name = "burn-remote" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-channel", "axum", "burn-common", - "burn-remote", "burn-router", "burn-tensor", "derive-new 0.7.0", @@ -900,14 +819,14 @@ dependencies = [ "serde", "serde_bytes", "tokio", - "tokio-tungstenite 0.26.1", + "tokio-tungstenite", "tracing-core", "tracing-subscriber", ] [[package]] name = "burn-router" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -921,7 +840,7 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-tensor", @@ -934,7 +853,7 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bincode", "burn-common", @@ -955,7 +874,7 @@ dependencies = [ [[package]] name = "burn-tensor-testgen" -version = "0.16.0" +version = "0.17.0" dependencies = [ "proc-macro2", "quote", @@ -963,7 +882,7 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-channel", "burn-core", @@ -974,7 +893,7 @@ dependencies = [ "ratatui", "rstest", "serde", - "sysinfo 0.32.1", + "sysinfo", "systemstat", "tracing-appender", "tracing-core", @@ -983,7 +902,7 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -1004,13 +923,13 @@ dependencies = [ [[package]] name = "bytemuck_derive" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -1027,9 +946,12 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +dependencies = [ + "serde", +] [[package]] name = "bytesize" @@ -1060,9 +982,9 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1e306c8a4276ba57ce9fac76d823cc8c8a7fca14bf222ac20ad8b12c4273152" +checksum = "855dfedff437d2681d68e1f34ae559d88b0dd84aa5a6b63f2c8e75ebdd875bbf" dependencies = [ "accelerate-src", "byteorder", @@ -1072,7 +994,7 @@ dependencies = [ "gemm", "half", "libc", - "memmap2 0.9.5", + "memmap2", "metal 0.27.0", "num-traits", "num_cpus", @@ -1090,18 +1012,18 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbd8ea6588f3c6286ea89a52dad3365f0536fd0b71e729fa998cc2347f1df3b6" +checksum = "53343628fa470b7075c28c589b98735b4220b464e37ddbb8e117040e199f4787" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc6621c7e2202f4f129bcc3185c2c6d4fa2fc6b8f3f2b07eaf7c06042910c83" +checksum = "50fa64274a009a5d95c542b10bf3a4ea809bd394654c6ae99233bcc35b3a33ef" dependencies = [ "metal 0.27.0", "once_cell", @@ -1135,9 +1057,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.4" +version = "1.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf" dependencies = [ "jobserver", "libc", @@ -1160,12 +1082,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "cfg_aliases" version = "0.2.1" @@ -1182,16 +1098,15 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-targets 0.52.6", ] [[package]] name = "chrono-tz" -version = "0.8.6" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +checksum = "9c6ac4f2c0bf0f44e9161aec9675e1050aa4a530663c4a9e37e108fa948bca9f" dependencies = [ "chrono", "chrono-tz-build", @@ -1200,42 +1115,14 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] -[[package]] -name = "ciborium" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" - -[[package]] -name = "ciborium-ll" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" -dependencies = [ - "ciborium-io", - "half", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1248,75 +1135,36 @@ dependencies = [ [[package]] name = "clap" -version = "3.2.25" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" -dependencies = [ - "atty", - "bitflags 1.3.2", - "clap_derive 3.2.25", - "clap_lex 0.2.4", - "indexmap 1.9.3", - "once_cell", - "strsim 0.10.0", - "termcolor", - "textwrap", -] - -[[package]] -name = "clap" -version = "4.5.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", - "clap_derive 4.5.18", + "clap_derive", ] [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", - "clap_lex 0.7.4", - "strsim 0.11.1", -] - -[[package]] -name = "clap_derive" -version = "3.2.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008" -dependencies = [ - "heck 0.4.1", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", + "clap_lex", + "strsim", ] [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.95", -] - -[[package]] -name = "clap_lex" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" -dependencies = [ - "os_str_bytes", + "syn 2.0.98", ] [[package]] @@ -1336,9 +1184,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.52" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6" dependencies = [ "cc", ] @@ -1389,9 +1237,9 @@ dependencies = [ [[package]] name = "compact_str" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" dependencies = [ "castaway", "cfg-if", @@ -1456,16 +1304,6 @@ dependencies = [ "libc", ] -[[package]] -name = "core-foundation" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1479,7 +1317,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "core-graphics-types", "foreign-types 0.5.0", "libc", @@ -1492,15 +1330,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "libc", ] [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1578,7 +1416,7 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "crossterm_winapi", "mio", "parking_lot 0.12.3", @@ -1599,9 +1437,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -1636,46 +1474,33 @@ dependencies = [ [[package]] name = "cubecl" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-cuda", "cubecl-hip", "cubecl-linalg", - "cubecl-runtime 0.4.0", + "cubecl-reduce", + "cubecl-runtime", "cubecl-wgpu", "half", ] [[package]] name = "cubecl-common" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4" -dependencies = [ - "derive-new 0.6.0", - "embassy-futures", - "futures-lite", - "getrandom", - "log", - "portable-atomic", - "rand", - "serde", - "spin", - "web-time", -] - -[[package]] -name = "cubecl-common" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ + "bytemuck", "derive-new 0.6.0", + "derive_more 1.0.0", "embassy-futures", "futures-lite", - "getrandom", + "getrandom 0.2.15", + "half", "log", + "num-traits", "portable-atomic", "rand", "serde", @@ -1685,13 +1510,15 @@ dependencies = [ [[package]] name = "cubecl-core" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ + "bitflags 2.8.0", "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", + "cubecl-ir", "cubecl-macros", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "derive-new 0.6.0", "derive_more 1.0.0", "half", @@ -1704,13 +1531,13 @@ dependencies = [ [[package]] name = "cubecl-cpp" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "derive-new 0.6.0", "half", "log", @@ -1718,14 +1545,14 @@ dependencies = [ [[package]] name = "cubecl-cuda" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", "cubecl-cpp", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "cudarc", "derive-new 0.6.0", "half", @@ -1734,15 +1561,15 @@ dependencies = [ [[package]] name = "cubecl-hip" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", "cubecl-cpp", "cubecl-hip-sys", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "derive-new 0.6.0", "half", "log", @@ -1751,86 +1578,104 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.0" +version = "6.3.1001" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9974218b3ff1f1e7b2f11ce254fd90b3ebcc2af6b4d084f7f6a0c351fb16112c" +checksum = "c7e92df7f9feff6a469932fc4d4b349d28000af9e6f34e583eb4f8df70038d48" dependencies = [ "libc", ] +[[package]] +name = "cubecl-ir" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" +dependencies = [ + "cubecl-common", + "cubecl-macros-internal", + "derive_more 1.0.0", + "float-ord", + "fnv", + "half", + "hashbrown 0.14.5", + "num-traits", + "portable-atomic", + "serde", + "variadics_please", +] + [[package]] name = "cubecl-linalg" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-core", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "half", "serde", ] [[package]] name = "cubecl-macros" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ - "cubecl-common 0.4.0", + "cubecl-common", "darling", "derive-new 0.6.0", "ident_case", "prettyplease", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", +] + +[[package]] +name = "cubecl-macros-internal" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] name = "cubecl-opt" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ - "cubecl-common 0.4.0", - "cubecl-core", + "cubecl-common", + "cubecl-ir", "float-ord", "log", "num", "petgraph", "smallvec", "stable-vec", + "type-map", ] [[package]] -name = "cubecl-runtime" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3468467f412dff4bbf97fb5061a3557445f017299e2fb73ef7b96c6cdb799bc3" +name = "cubecl-reduce" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ - "async-channel", - "async-lock", - "cfg_aliases 0.2.1", - "cubecl-common 0.3.0", - "derive-new 0.6.0", - "dirs", - "hashbrown 0.14.5", - "log", - "md5", - "sanitize-filename 0.5.0", - "serde", - "serde_json", - "spin", - "wasm-bindgen-futures", + "cubecl-core", + "cubecl-runtime", + "num-traits", ] [[package]] name = "cubecl-runtime" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "async-channel", "async-lock", - "cfg_aliases 0.2.1", - "cubecl-common 0.4.0", + "cfg_aliases", + "cubecl-common", "derive-new 0.6.0", "dirs", "hashbrown 0.14.5", @@ -1840,18 +1685,20 @@ dependencies = [ "serde", "serde_json", "spin", + "variadics_please", "wasm-bindgen-futures", ] [[package]] name = "cubecl-spirv" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ - "cubecl-common 0.4.0", + "bitflags 2.8.0", + "cubecl-common", "cubecl-core", "cubecl-opt", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "half", "hashbrown 0.14.5", "rspirv", @@ -1859,17 +1706,17 @@ dependencies = [ [[package]] name = "cubecl-wgpu" -version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "ash", "async-channel", "bytemuck", "cfg-if", - "cfg_aliases 0.2.1", - "cubecl-common 0.4.0", + "cfg_aliases", + "cubecl-common", "cubecl-core", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "cubecl-spirv", "derive-new 0.6.0", "hashbrown 0.14.5", @@ -1880,9 +1727,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.12.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" dependencies = [ "half", "libloading", @@ -1890,17 +1737,17 @@ dependencies = [ [[package]] name = "custom-csv-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "csv", - "reqwest 0.12.12", + "reqwest", "serde", ] [[package]] name = "custom-cubecl-kernel" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-jit", @@ -1913,7 +1760,7 @@ dependencies = [ [[package]] name = "custom-image-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "flate2", @@ -1922,7 +1769,7 @@ dependencies = [ [[package]] name = "custom-renderer" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1934,7 +1781,7 @@ dependencies = [ [[package]] name = "custom-training-loop" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1946,7 +1793,7 @@ dependencies = [ [[package]] name = "custom-wgpu-kernel" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1976,8 +1823,8 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim 0.11.1", - "syn 2.0.95", + "strsim", + "syn 2.0.98", ] [[package]] @@ -1988,7 +1835,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2007,9 +1854,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f" [[package]] name = "deflate64" @@ -2034,7 +1881,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2045,7 +1892,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2056,7 +1903,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2077,7 +1924,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2087,7 +1934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2098,7 +1945,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2118,7 +1965,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", "unicode-xid", ] @@ -2174,15 +2021,9 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] -[[package]] -name = "doc-comment" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" - [[package]] name = "document-features" version = "0.2.10" @@ -2204,9 +2045,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" [[package]] name = "dyn-stack" @@ -2223,9 +2064,6 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" -dependencies = [ - "serde", -] [[package]] name = "embassy-futures" @@ -2254,10 +2092,10 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2269,14 +2107,14 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "env_filter" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -2284,9 +2122,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" dependencies = [ "anstream", "anstyle", @@ -2331,9 +2169,9 @@ checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2367,9 +2205,9 @@ dependencies = [ [[package]] name = "fake" -version = "3.0.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "661cb0601b5f4050d1e65452c5b0ea555c0b3e88fb5ed7855906adc6c42523ef" +checksum = "aef603df4ba9adbca6a332db7da6f614f21eafefbaf8e087844e452fdec152d0" dependencies = [ "deunicode", "rand", @@ -2388,10 +2226,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] -name = "fast-float" -version = "0.2.0" +name = "fast-float2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" +checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" [[package]] name = "faster-hex" @@ -2468,9 +2306,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "foreign-types" @@ -2499,7 +2337,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2523,16 +2361,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "fs4" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c6b3bd49c37d2aa3f3f2220233b29a7cd23f79d1fe70e5337d25fb390793de" -dependencies = [ - "rustix", - "windows-sys 0.52.0", -] - [[package]] name = "futures" version = "0.3.31" @@ -2583,9 +2411,9 @@ checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cef40d21ae2c515b51041df9ed313ed21e572df340ea58a922a0aefe7e8891a1" +checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" dependencies = [ "fastrand", "futures-core", @@ -2602,7 +2430,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -2788,10 +2616,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gif" version = "0.13.1" @@ -2808,20 +2648,6 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" -[[package]] -name = "github-device-flow" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98852ab71f5613dac02a0d1b41f3ffaf993b69449904dd13a10575612a56074d" -dependencies = [ - "chrono", - "clap 3.2.25", - "reqwest 0.11.27", - "serde", - "serde_derive", - "serde_json", -] - [[package]] name = "gix-features" version = "0.39.1" @@ -2836,9 +2662,9 @@ dependencies = [ [[package]] name = "gix-fs" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34740384d8d763975858fa2c176b68652a6fcc09f616e24e3ce967b0d370e4d8" +checksum = "3b3d4fac505a621f97e5ce2c69fdc425742af00c0920363ca4074f0eb48b1db9" dependencies = [ "fastrand", "gix-features", @@ -2852,7 +2678,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b5eccc17194ed0e67d49285e4853307e4147e95407f91c1c3e4a13ba9f4e4ce" dependencies = [ "faster-hex", - "thiserror 2.0.9", + "thiserror 2.0.11", ] [[package]] @@ -2873,15 +2699,15 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04bdde120c29f1fc23a24d3e115aeeea3d60d8e65bab92cc5f9d90d9302eb952" +checksum = "7c396a2036920c69695f760a65e7f2677267ccf483f25046977d87e4cb2665f7" [[package]] name = "gix-utils" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba427e3e9599508ed98a6ddf8ed05493db114564e338e41f6a996d2e4790335f" +checksum = "ff08f24e03ac8916c478c8419d7d3c33393da9bb41fa4c24455d5406aeefd35f" dependencies = [ "fastrand", "unicode-normalization", @@ -2900,9 +2726,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "globset" @@ -2923,16 +2749,16 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "ignore", "walkdir", ] [[package]] name = "glow" -version = "0.14.2" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d51fa363f025f5c111e03f13eda21162faeacb6911fe8caa0c0349f9cf0c4483" +checksum = "c5e5ea60d70410161c8bf5da3fdfeaa1c72ed2c15f8bbb9d19fe3a4fad085f08" dependencies = [ "js-sys", "slotmap", @@ -2942,9 +2768,9 @@ dependencies = [ [[package]] name = "glutin_wgl_sys" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e1951bbd9434a81aa496fe59ccc2235af3820d27b85f9314e279609211e2c" +checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e" dependencies = [ "gl_generator", ] @@ -2955,7 +2781,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "gpu-alloc-types", ] @@ -2965,7 +2791,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", ] [[package]] @@ -2982,13 +2808,13 @@ dependencies = [ [[package]] name = "gpu-descriptor" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c08c1f623a8d0b722b8b99f821eb0ba672a1618f0d3b16ddbee1cedd2dd8557" +checksum = "dcf29e94d6d243368b7a56caa16bc213e4f9f8ed38c4d9557069527b5d5281ca" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "gpu-descriptor-types", - "hashbrown 0.14.5", + "hashbrown 0.15.2", ] [[package]] @@ -2997,37 +2823,18 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", ] [[package]] name = "guide" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", "serde", ] -[[package]] -name = "h2" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http 0.2.12", - "indexmap 2.7.0", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "h2" version = "0.4.7" @@ -3039,8 +2846,8 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.2.0", - "indexmap 2.7.0", + "http", + "indexmap", "slab", "tokio", "tokio-util", @@ -3062,22 +2869,6 @@ dependencies = [ "serde", ] -[[package]] -name = "halfbrown" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" -dependencies = [ - "hashbrown 0.14.5", - "serde", -] - -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.13.2" @@ -3121,27 +2912,12 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.3.9" @@ -3201,17 +2977,6 @@ version = "3.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" -[[package]] -name = "http" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http" version = "1.2.0" @@ -3223,17 +2988,6 @@ dependencies = [ "itoa", ] -[[package]] -name = "http-body" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" -dependencies = [ - "bytes", - "http 0.2.12", - "pin-project-lite", -] - [[package]] name = "http-body" version = "1.0.1" @@ -3241,7 +2995,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http", ] [[package]] @@ -3252,16 +3006,16 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", + "http-body", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "httpdate" @@ -3277,40 +3031,16 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2 0.3.26", - "http 0.2.12", - "http-body 0.4.6", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", - "want", -] - -[[package]] -name = "hyper" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.7", - "http 1.2.0", - "http-body 1.0.1", + "h2", + "http", + "http-body", "httparse", "httpdate", "itoa", @@ -3322,35 +3052,21 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.3" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.2.0", - "hyper 1.5.2", + "http", + "hyper", "hyper-util", "rustls", - "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls", "tower-service", ] -[[package]] -name = "hyper-tls" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" -dependencies = [ - "bytes", - "hyper 0.14.32", - "native-tls", - "tokio", - "tokio-native-tls", -] - [[package]] name = "hyper-tls" version = "0.6.0" @@ -3359,7 +3075,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.5.2", + "hyper", "hyper-util", "native-tls", "tokio", @@ -3376,9 +3092,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "hyper 1.5.2", + "http", + "http-body", + "hyper", "pin-project-lite", "socket2", "tokio", @@ -3524,7 +3240,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -3595,13 +3311,12 @@ dependencies = [ [[package]] name = "image-classification-web" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-candle", "burn-import", "console_error_panic_hook", - "cubecl-runtime 0.3.0", "js-sys", "log", "serde", @@ -3615,9 +3330,9 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f" +checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f" dependencies = [ "byteorder-lite", "quick-error", @@ -3631,19 +3346,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - -[[package]] -name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3652,9 +3357,9 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.17.9" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ "console", "number_prefix", @@ -3680,16 +3385,15 @@ dependencies = [ [[package]] name = "instability" -version = "0.3.3" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e" +checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d" dependencies = [ "darling", "indoc", - "pretty_assertions", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -3709,14 +3413,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -3757,12 +3461,6 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" -[[package]] -name = "itoap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" - [[package]] name = "jni-sys" version = "0.3.0" @@ -3786,9 +3484,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ "once_cell", "wasm-bindgen", @@ -3825,15 +3523,15 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libfuzzer-sys" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa" +checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75" dependencies = [ "arbitrary", "cc", @@ -3846,7 +3544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -3861,7 +3559,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "libc", "redox_syscall 0.5.8", ] @@ -3879,9 +3577,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" @@ -3913,9 +3611,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "loop9" @@ -3937,9 +3635,9 @@ dependencies = [ [[package]] name = "lz4" -version = "1.28.0" +version = "1.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" dependencies = [ "lz4-sys", ] @@ -4000,9 +3698,9 @@ dependencies = [ [[package]] name = "matchit" -version = "0.7.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "matrixmultiply" @@ -4027,16 +3725,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "md-5" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" -dependencies = [ - "cfg-if", - "digest", -] - [[package]] name = "md5" version = "0.7.0" @@ -4049,15 +3737,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memmap2" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" -dependencies = [ - "libc", -] - [[package]] name = "memmap2" version = "0.9.5" @@ -4069,21 +3748,27 @@ dependencies = [ ] [[package]] -name = "memoffset" -version = "0.9.1" +name = "metal" +version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" dependencies = [ - "autocfg", + "bitflags 2.8.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", ] [[package]] name = "metal" -version = "0.27.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" +checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -4094,11 +3779,11 @@ dependencies = [ [[package]] name = "metal" -version = "0.29.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" +checksum = "f569fb946490b5743ad69813cb19629130ce9374034abe31614a36402d18f99e" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -4121,9 +3806,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", "simd-adler32", @@ -4137,13 +3822,13 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] [[package]] name = "mnist" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -4152,11 +3837,10 @@ dependencies = [ [[package]] name = "mnist-inference-web" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "console_error_panic_hook", - "cubecl-runtime 0.3.0", "js-sys", "serde", "wasm-bindgen", @@ -4165,12 +3849,23 @@ dependencies = [ [[package]] name = "model" -version = "0.5.0" +version = "0.6.0" dependencies = [ "burn", "burn-import", ] +[[package]] +name = "modern-lstm" +version = "0.1.0" +dependencies = [ + "burn", + "polars", + "rand", + "rand_distr", + "serde", +] + [[package]] name = "monostate" version = "0.1.13" @@ -4189,55 +3884,34 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", -] - -[[package]] -name = "multiversion" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" -dependencies = [ - "multiversion-macros", - "target-features", -] - -[[package]] -name = "multiversion-macros" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "target-features", + "syn 2.0.98", ] [[package]] name = "naga" -version = "23.1.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" +checksum = "e380993072e52eef724eddfcde0ed013b0c023c3f0417336ed041aa9f076994e" dependencies = [ "arrayvec", "bit-set", - "bitflags 2.6.0", - "cfg_aliases 0.1.1", + "bitflags 2.8.0", + "cfg_aliases", "codespan-reporting", "hexf-parse", - "indexmap 2.7.0", + "indexmap", "log", - "rustc-hash 1.1.0", - "spirv", + "rustc-hash", + "spirv 0.3.0+sdk-1.3.268.0", + "strum", "termcolor", - "thiserror 1.0.69", + "thiserror 2.0.11", "unicode-xid", ] [[package]] name = "named-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "serde", @@ -4245,9 +3919,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" dependencies = [ "libc", "log", @@ -4255,7 +3929,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework 2.11.1", + "security-framework", "security-framework-sys", "tempfile", ] @@ -4413,7 +4087,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -4463,7 +4137,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi", "libc", ] @@ -4485,7 +4159,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -4509,7 +4183,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c9bff0aa1d48904a1385ea2a8b97576fbdcbc9a3cfccd0d31fe978e1c4038c5" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "libloading", "nvml-wrapper-sys", "static_assertions", @@ -4558,7 +4232,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block2", "libc", "objc2", @@ -4574,7 +4248,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block2", "objc2", "objc2-foundation", @@ -4594,9 +4268,9 @@ dependencies = [ [[package]] name = "objc2-encode" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] name = "objc2-foundation" @@ -4604,7 +4278,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block2", "libc", "objc2", @@ -4616,7 +4290,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block2", "objc2", "objc2-foundation", @@ -4628,7 +4302,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "block2", "objc2", "objc2-foundation", @@ -4646,43 +4320,13 @@ dependencies = [ [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] -[[package]] -name = "object_store" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" -dependencies = [ - "async-trait", - "base64 0.22.1", - "bytes", - "chrono", - "futures", - "humantime", - "hyper 1.5.2", - "itertools 0.13.0", - "md-5", - "parking_lot 0.12.3", - "percent-encoding", - "quick-xml", - "rand", - "reqwest 0.12.12", - "ring", - "serde", - "serde_json", - "snafu", - "tokio", - "tracing", - "url", - "walkdir", -] - [[package]] name = "once_cell" version = "1.20.2" @@ -4713,7 +4357,7 @@ dependencies = [ [[package]] name = "onnx-inference" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -4722,7 +4366,7 @@ dependencies = [ [[package]] name = "onnx-ir" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bytemuck", "half", @@ -4739,7 +4383,7 @@ dependencies = [ [[package]] name = "onnx-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -4750,16 +4394,16 @@ dependencies = [ [[package]] name = "openblas-build" -version = "0.10.10" +version = "0.10.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ca8f8c64eb5b43f5538059ccbc71391420bba14d987d7e8ab99ed62ed33e26b" +checksum = "b8140c0c1afaf88d2d30c48abad86b3bdd2334d691e08f7325a960d784240647" dependencies = [ "anyhow", "cc", "flate2", "native-tls", "tar", - "thiserror 2.0.9", + "thiserror 2.0.11", "ureq", ] @@ -4777,11 +4421,11 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.68" +version = "0.10.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", "foreign-types 0.3.2", "libc", @@ -4798,20 +4442,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.104" +version = "0.9.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" dependencies = [ "cc", "libc", @@ -4825,23 +4469,26 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "os_info" -version = "3.9.0" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ca711d8b83edbb00b44d504503cd247c9c0bd8b0fa2694f2a1a3d8165379ce" +checksum = "6e6520c8cc998c5741ee68ec1dc369fc47e5f0ea5320018ecf2a1ccd6328f48b" dependencies = [ "log", "serde", "windows-sys 0.52.0", ] -[[package]] -name = "os_str_bytes" -version = "6.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" - [[package]] name = "overload" version = "0.1.1" @@ -4963,23 +4610,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.7.0", + "indexmap", ] [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -4987,9 +4634,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", "rand", @@ -4997,18 +4644,18 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ "siphasher", ] [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -5033,9 +4680,9 @@ dependencies = [ [[package]] name = "png" -version = "0.17.15" +version = "0.17.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67582bd5b65bdff614270e2ea89a1cf15bef71245cc1e5f7ea126977144211d" +checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526" dependencies = [ "bitflags 1.3.2", "crc32fast", @@ -5046,11 +4693,11 @@ dependencies = [ [[package]] name = "polars" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c6aa86d991a64c95416a61202f7952da2f8cccefa448f9a23c1b8f2301ecc" +checksum = "72571dde488ecccbe799798bf99ab7308ebdb7cf5d95bcc498dbd5a132f0da4d" dependencies = [ - "getrandom", + "getrandom 0.2.15", "polars-arrow", "polars-core", "polars-error", @@ -5066,12 +4713,11 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dbb24d29ddea5abb73d7954df8b8d3d4bb7f02a3e5c96d1519cdad9e816a3d" +checksum = "6611c758d52e799761cc25900666b71552e6c929d88052811bc9daad4b3321a8" dependencies = [ "ahash", - "atoi", "atoi_simd", "bytemuck", "chrono", @@ -5079,21 +4725,16 @@ dependencies = [ "dyn-clone", "either", "ethnum", - "fast-float", - "getrandom", + "getrandom 0.2.15", "hashbrown 0.15.2", "itoa", - "itoap", "lz4", - "multiversion", "num-traits", "parking_lot 0.12.3", "polars-arrow-format", "polars-error", "polars-schema", "polars-utils", - "ryu", - "serde", "simdutf8", "streaming-iterator", "strength_reduce", @@ -5114,28 +4755,33 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbdb1071147452a4c4b25560f23d2fbaffef255b04757291131b22fc2c0d35b2" +checksum = "332f2547dbb27599a8ffe68e56159f5996ba03d1dad0382ccb62c109ceacdeb6" dependencies = [ + "atoi_simd", "bytemuck", + "chrono", "either", + "fast-float2", + "itoa", "num-traits", "polars-arrow", "polars-error", "polars-utils", + "ryu", "strength_reduce", "version_check", ] [[package]] name = "polars-core" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949" +checksum = "796d06eae7e6e74ed28ea54a8fccc584ebac84e6cf0e1e9ba41ffc807b169a01" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.8.0", "bytemuck", "chrono", "chrono-tz", @@ -5143,7 +4789,8 @@ dependencies = [ "either", "hashbrown 0.14.5", "hashbrown 0.15.2", - "indexmap 2.7.0", + "indexmap", + "itoa", "num-traits", "once_cell", "polars-arrow", @@ -5156,35 +4803,32 @@ dependencies = [ "rand_distr", "rayon", "regex", - "serde", - "serde_json", "strum_macros", - "thiserror 1.0.69", + "thiserror 2.0.11", "version_check", "xxhash-rust", ] [[package]] name = "polars-error" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4643898a644f30c83737db85f942f8c8956b0c11190b39afec745218eae1746b" +checksum = "19d6529cae0d1db5ed690e47de41fac9b35ae0c26d476830c2079f130887b847" dependencies = [ - "object_store", "polars-arrow-format", "regex", "simdutf8", - "thiserror 1.0.69", + "thiserror 2.0.11", ] [[package]] name = "polars-expr" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136" +checksum = "c8e639991a8ad4fb12880ab44bcc3cf44a5703df003142334d9caf86d77d77e7" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.8.0", "hashbrown 0.15.2", "num-traits", "once_cell", @@ -5203,80 +4847,50 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fab2c016635cb416b49461fd6419b0208c6c13a4fd065bd65e4a87dbb66314" +checksum = "719a77e94480f6be090512da196e378cbcbeb3584c6fe1134c600aee906e38ab" dependencies = [ "ahash", "async-trait", "atoi_simd", - "blake3", "bytes", "chrono", - "fast-float", - "fs4", + "fast-float2", "futures", "glob", "hashbrown 0.15.2", "home", "itoa", "memchr", - "memmap2 0.7.1", + "memmap2", "num-traits", - "object_store", "once_cell", "percent-encoding", "polars-arrow", "polars-core", "polars-error", - "polars-json", "polars-parquet", "polars-schema", "polars-time", "polars-utils", - "pyo3", "rayon", "regex", - "reqwest 0.12.12", "ryu", - "serde", - "serde_json", - "simd-json", "simdutf8", "tokio", "tokio-util", - "url", -] - -[[package]] -name = "polars-json" -version = "0.44.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5c8c057ef04feaf34b6ce52096bdea3a766fa4725f50442078c8a4ee86397bf" -dependencies = [ - "ahash", - "chrono", - "fallible-streaming-iterator", - "hashbrown 0.15.2", - "indexmap 2.7.0", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "ryu", - "simd-json", - "streaming-iterator", ] [[package]] name = "polars-lazy" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310" +checksum = "a0a731a672dfc8ac38c1f73c9a4b2ae38d2fc8ac363bfb64c5f3a3e072ffc5ad" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.8.0", + "chrono", "memchr", "once_cell", "polars-arrow", @@ -5296,32 +4910,28 @@ dependencies = [ [[package]] name = "polars-mem-engine" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a32614e5b52c9b83856d80c7e2880b79d83055bfd59969bd1d0b148f9cfdc7a" +checksum = "33442189bcbf2e2559aa7914db3835429030a13f4f18e43af5fba9d1b018cf12" dependencies = [ - "futures", - "memmap2 0.7.1", + "memmap2", "polars-arrow", "polars-core", "polars-error", "polars-expr", "polars-io", - "polars-json", "polars-ops", "polars-plan", "polars-time", "polars-utils", - "pyo3", "rayon", - "tokio", ] [[package]] name = "polars-ops" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "035c800fbe5bbd820afeb8313713ed345853bb014e0f821a4025d40cf0d60e1a" +checksum = "cbb83218b0c216104f0076cd1a005128be078f958125f3d59b094ee73d78c18e" dependencies = [ "ahash", "argminmax", @@ -5332,9 +4942,10 @@ dependencies = [ "either", "hashbrown 0.15.2", "hex", - "indexmap 2.7.0", + "indexmap", "memchr", "num-traits", + "once_cell", "polars-arrow", "polars-compute", "polars-core", @@ -5344,39 +4955,33 @@ dependencies = [ "rayon", "regex", "regex-syntax 0.8.5", - "serde", "strum_macros", + "unicode-normalization", "unicode-reverse", "version_check", ] [[package]] name = "polars-parquet" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91dcf1d9f048079376949eaf2e24e240b313ff4a102fb83b57c9a5f807cdca52" +checksum = "5c60ee85535590a38db6c703a21be4cb25342e40f573f070d1e16f9d84a53ac7" dependencies = [ "ahash", "async-stream", "base64 0.22.1", - "brotli", "bytemuck", "ethnum", - "flate2", "futures", "hashbrown 0.15.2", - "lz4", "num-traits", "polars-arrow", "polars-compute", "polars-error", "polars-parquet-format", "polars-utils", - "serde", "simdutf8", - "snap", "streaming-decompression", - "zstd 0.13.2", ] [[package]] @@ -5391,15 +4996,16 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05936f2b3981eecb2fe74d8ef092bb75a93d2a056b3e4f339f4ac20c71c9e331" +checksum = "42d238fb76698f56e51ddfa89b135e4eda56a4767c6e8859eed0ab78386fcd52" dependencies = [ "crossbeam-channel", "crossbeam-queue", "enum_dispatch", "hashbrown 0.15.2", "num-traits", + "once_cell", "polars-arrow", "polars-compute", "polars-core", @@ -5416,75 +5022,69 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a" +checksum = "4f03533a93aa66127fcb909a87153a3c7cfee6f0ae59f497e73d7736208da54c" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.8.0", "bytemuck", "bytes", "chrono", "chrono-tz", - "ciborium", "either", - "futures", "hashbrown 0.15.2", - "memmap2 0.7.1", + "memmap2", "num-traits", "once_cell", "percent-encoding", "polars-arrow", + "polars-compute", "polars-core", "polars-io", - "polars-json", "polars-ops", - "polars-parquet", "polars-time", "polars-utils", - "pyo3", "rayon", "recursive", "regex", - "serde", "strum_macros", "version_check", ] [[package]] name = "polars-row" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3823d3de3e614509bba6929798f1f3d5ae05c1cdfc4eb7029d2ec6ad77201da2" +checksum = "6bf47f7409f8e75328d7d034be390842924eb276716d0458607be0bddb8cc839" dependencies = [ + "bitflags 2.8.0", "bytemuck", "polars-arrow", + "polars-compute", "polars-error", "polars-utils", ] [[package]] name = "polars-schema" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d88667f770291cefa2e8cd366a54f29dc6fe362e9a263914c903db411a58ac1d" +checksum = "416621ae82b84466cf4ff36838a9b0aeb4a67e76bd3065edc8c9cb7da19b1bc7" dependencies = [ - "indexmap 2.7.0", + "indexmap", "polars-error", "polars-utils", - "serde", "version_check", ] [[package]] name = "polars-sql" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69451f08363bb497407f6ebebe00bc01972a51716d20d115b75f9b5326f1f3c8" +checksum = "edaab553b90aa4d6743bb538978e1982368acb58a94408d7dd3299cad49c7083" dependencies = [ "hex", - "once_cell", - "polars-arrow", "polars-core", "polars-error", "polars-lazy", @@ -5493,22 +5093,22 @@ dependencies = [ "polars-time", "polars-utils", "rand", + "regex", "serde", - "serde_json", "sqlparser", ] [[package]] name = "polars-stream" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "188622b0a4bc4530cf91a288134254ffa065d18932e261075377914225e757c2" +checksum = "498997b656c779610c1496b3d96a59fe569ef22a5b81ccfe5325cb3df8dff2fd" dependencies = [ "atomic-waker", "crossbeam-deque", "crossbeam-utils", "futures", - "memmap2 0.7.1", + "memmap2", "parking_lot 0.12.3", "pin-project-lite", "polars-core", @@ -5516,6 +5116,7 @@ dependencies = [ "polars-expr", "polars-io", "polars-mem-engine", + "polars-ops", "polars-parquet", "polars-plan", "polars-utils", @@ -5529,49 +5130,50 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90f36e4d6b19f2c406faea585b9a1814f422fc5b310f65ccf8a55216df0754ef" +checksum = "d192efbdab516d28b3fab1709a969e3385bd5cda050b7c9aa9e2502a01fda879" dependencies = [ - "atoi", + "atoi_simd", "bytemuck", "chrono", "chrono-tz", "now", + "num-traits", "once_cell", "polars-arrow", + "polars-compute", "polars-core", "polars-error", "polars-ops", "polars-utils", + "rayon", "regex", - "serde", "strum_macros", ] [[package]] name = "polars-utils" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96186b70bda00c90b5027bf2f69193c5c40571e80d3e8ec505c22cdc8e3e39aa" +checksum = "a8f6c8166a4a7fbc15b87c81645ed9e1f0651ff2e8c96cafc40ac5bf43441a10" dependencies = [ "ahash", "bytemuck", "bytes", "compact_str", "hashbrown 0.15.2", - "indexmap 2.7.0", + "indexmap", "libc", - "memmap2 0.7.1", + "memmap2", "num-traits", "once_cell", "polars-error", - "pyo3", - "raw-cpuid 11.2.0", + "rand", + "raw-cpuid 11.3.0", "rayon", - "serde", "stacker", - "sysinfo 0.31.4", + "sysinfo", "version_check", ] @@ -5580,6 +5182,9 @@ name = "portable-atomic" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +dependencies = [ + "serde", +] [[package]] name = "portable-atomic-util" @@ -5623,12 +5228,12 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.25" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -5640,35 +5245,11 @@ dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -5689,7 +5270,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -5726,7 +5307,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "322330e133eab455718444b4e033ebfac7c6528972c784fcde28d2cc783c6257" dependencies = [ "anyhow", - "indexmap 2.7.0", + "indexmap", "log", "protobuf", "protobuf-support", @@ -5765,72 +5346,9 @@ dependencies = [ "reborrow", ] -[[package]] -name = "pyo3" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot 0.12.3", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.95", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.95", -] - [[package]] name = "pytorch-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -5839,7 +5357,7 @@ dependencies = [ [[package]] name = "pytorch-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-autodiff", @@ -5864,68 +5382,6 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" -[[package]] -name = "quick-xml" -version = "0.36.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" -dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "quinn" -version = "0.11.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" -dependencies = [ - "bytes", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash 2.1.0", - "rustls", - "socket2", - "thiserror 2.0.9", - "tokio", - "tracing", -] - -[[package]] -name = "quinn-proto" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" -dependencies = [ - "bytes", - "getrandom", - "rand", - "ring", - "rustc-hash 2.1.0", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.9", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" -dependencies = [ - "cfg_aliases 0.2.1", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.59.0", -] - [[package]] name = "quote" version = "1.0.38" @@ -5984,7 +5440,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -5999,9 +5455,9 @@ dependencies = [ [[package]] name = "range-alloc" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" +checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" [[package]] name = "ratatui" @@ -6009,7 +5465,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cassowary", "compact_str", "crossterm", @@ -6086,11 +5542,11 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "11.2.0" +version = "11.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", ] [[package]] @@ -6159,7 +5615,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -6177,7 +5633,7 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", ] [[package]] @@ -6186,31 +5642,11 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] -[[package]] -name = "ref-cast" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931" -dependencies = [ - "ref-cast-impl", -] - -[[package]] -name = "ref-cast-impl" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.95", -] - [[package]] name = "regex" version = "1.11.1" @@ -6267,46 +5703,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" -[[package]] -name = "reqwest" -version = "0.11.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" -dependencies = [ - "base64 0.21.7", - "bytes", - "encoding_rs", - "futures-core", - "futures-util", - "h2 0.3.26", - "http 0.2.12", - "http-body 0.4.6", - "hyper 0.14.32", - "hyper-tls 0.5.0", - "ipnet", - "js-sys", - "log", - "mime", - "native-tls", - "once_cell", - "percent-encoding", - "pin-project-lite", - "rustls-pemfile 1.0.4", - "serde", - "serde_json", - "serde_urlencoded", - "sync_wrapper 0.1.2", - "system-configuration 0.5.1", - "tokio", - "tokio-native-tls", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "winreg", -] - [[package]] name = "reqwest" version = "0.12.12" @@ -6319,13 +5715,13 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.4.7", - "http 1.2.0", - "http-body 1.0.1", + "h2", + "http", + "http-body", "http-body-util", - "hyper 1.5.2", + "hyper", "hyper-rustls", - "hyper-tls 0.6.0", + "hyper-tls", "hyper-util", "ipnet", "js-sys", @@ -6335,26 +5731,19 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "quinn", - "rustls", - "rustls-native-certs 0.8.1", - "rustls-pemfile 2.2.0", - "rustls-pki-types", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.2", - "system-configuration 0.6.1", + "sync_wrapper", + "system-configuration", "tokio", "tokio-native-tls", - "tokio-rustls", - "tokio-util", "tower", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", "windows-registry", ] @@ -6376,7 +5765,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -6407,12 +5796,11 @@ dependencies = [ [[package]] name = "rspirv" -version = "0.12.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" +version = "0.12.0+sdk-1.3.296.0" +source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e" dependencies = [ - "rustc-hash 1.1.0", - "spirv", + "rustc-hash", + "spirv 0.3.0+sdk-1.3.296.0", ] [[package]] @@ -6441,7 +5829,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.95", + "syn 2.0.98", "unicode-ident", ] @@ -6451,7 +5839,7 @@ version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "fallible-iterator", "fallible-streaming-iterator", "hashlink", @@ -6481,12 +5869,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustc-hash" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" - [[package]] name = "rustc_version" version = "0.4.1" @@ -6498,11 +5880,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "errno", "libc", "linux-raw-sys", @@ -6511,9 +5893,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.20" +version = "0.23.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" +checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7" dependencies = [ "log", "once_cell", @@ -6531,31 +5913,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", - "rustls-pki-types", - "schannel", - "security-framework 2.11.1", -] - -[[package]] -name = "rustls-native-certs" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" -dependencies = [ - "openssl-probe", + "rustls-pemfile", "rustls-pki-types", "schannel", - "security-framework 3.1.0", -] - -[[package]] -name = "rustls-pemfile" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" -dependencies = [ - "base64 0.21.7", + "security-framework", ] [[package]] @@ -6569,12 +5930,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" -dependencies = [ - "web-time", -] +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" [[package]] name = "rustls-webpki" @@ -6589,15 +5947,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "safetensors" @@ -6649,9 +6007,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b13f8ea6177672c49d12ed964cca44836f59621981b04a3e26b87e675181de" +checksum = "28e1c91382686d21b5ac7959341fcb9780fa7c03773646995a87c950fa7be640" dependencies = [ "sdd", ] @@ -6692,21 +6050,8 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.6.0", - "core-foundation 0.9.4", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc" -dependencies = [ - "bitflags 2.6.0", - "core-foundation 0.10.0", + "bitflags 2.8.0", + "core-foundation", "core-foundation-sys", "libc", "security-framework-sys", @@ -6714,9 +6059,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -6724,9 +6069,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "seq-macro" @@ -6771,14 +6116,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "serde_json" -version = "1.0.134" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -6849,12 +6194,12 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "server" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "cfg-if", @@ -6933,23 +6278,6 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" -[[package]] -name = "simd-json" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2bcf6c6e164e81bc7a5d49fc6988b3d515d9e8c07457d7b74ffb9324b9cd40" -dependencies = [ - "ahash", - "getrandom", - "halfbrown", - "once_cell", - "ref-cast", - "serde", - "serde_json", - "simdutf8", - "value-trait", -] - [[package]] name = "simd_helpers" version = "0.1.0" @@ -6967,7 +6295,7 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "simple-regression" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -6978,9 +6306,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "slab" @@ -7006,34 +6334,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "snafu" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" -dependencies = [ - "doc-comment", - "snafu-derive", -] - -[[package]] -name = "snafu-derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "snap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "socket2" version = "0.5.8" @@ -7060,7 +6360,15 @@ version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", +] + +[[package]] +name = "spirv" +version = "0.3.0+sdk-1.3.296.0" +source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e" +dependencies = [ + "bitflags 2.8.0", ] [[package]] @@ -7077,9 +6385,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" dependencies = [ "log", ] @@ -7139,12 +6447,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - [[package]] name = "strsim" version = "0.11.1" @@ -7166,11 +6468,11 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -7186,27 +6488,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.95" +version = "2.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] -[[package]] -name = "sync_wrapper" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" - [[package]] name = "sync_wrapper" version = "1.0.2" @@ -7224,7 +6519,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -7233,7 +6528,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "byteorder", "enum-as-inner", "libc", @@ -7243,22 +6538,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.31.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "355dbe4f8799b304b05e1b0f05fc59b2a18d36645cf169607da45bde2f69a1be" -dependencies = [ - "core-foundation-sys", - "libc", - "memchr", - "ntapi", - "windows 0.57.0", -] - -[[package]] -name = "sysinfo" -version = "0.32.1" +version = "0.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" +checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01" dependencies = [ "core-foundation-sys", "libc", @@ -7269,36 +6551,15 @@ dependencies = [ "windows 0.57.0", ] -[[package]] -name = "system-configuration" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" -dependencies = [ - "bitflags 1.3.2", - "core-foundation 0.9.4", - "system-configuration-sys 0.5.0", -] - [[package]] name = "system-configuration" version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags 2.6.0", - "core-foundation 0.9.4", - "system-configuration-sys 0.6.0", -] - -[[package]] -name = "system-configuration-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" -dependencies = [ - "core-foundation-sys", - "libc", + "bitflags 2.8.0", + "core-foundation", + "system-configuration-sys", ] [[package]] @@ -7318,7 +6579,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck 0.5.0", + "heck", "pkg-config", "toml", "version-compare", @@ -7349,12 +6610,6 @@ dependencies = [ "xattr", ] -[[package]] -name = "target-features" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" - [[package]] name = "target-lexicon" version = "0.12.16" @@ -7380,12 +6635,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.14.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -7402,7 +6658,7 @@ dependencies = [ [[package]] name = "text-classification" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "derive-new 0.7.0", @@ -7412,7 +6668,7 @@ dependencies = [ [[package]] name = "text-generation" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "derive-new 0.7.0", @@ -7442,12 +6698,6 @@ dependencies = [ "rgb", ] -[[package]] -name = "textwrap" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" - [[package]] name = "thiserror" version = "1.0.69" @@ -7459,11 +6709,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.11", ] [[package]] @@ -7474,18 +6724,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -7563,9 +6813,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -7585,7 +6835,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "getrandom", + "getrandom 0.2.15", "hf-hub", "itertools 0.12.1", "lazy_static", @@ -7610,9 +6860,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.42.0" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -7626,13 +6876,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -7655,18 +6905,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-tungstenite" -version = "0.24.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite 0.24.0", -] - [[package]] name = "tokio-tungstenite" version = "0.26.1" @@ -7676,7 +6914,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite 0.26.1", + "tungstenite", ] [[package]] @@ -7715,11 +6953,11 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee" dependencies = [ - "indexmap 2.7.0", + "indexmap", "serde", "serde_spanned", "toml_datetime", @@ -7750,7 +6988,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 1.0.2", + "sync_wrapper", "tokio", "tower-layer", "tower-service", @@ -7776,7 +7014,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58fccce80a2ef6bc32a512514a53cf853d438a44abaea286a4acb0c9f8566860" dependencies = [ "anyhow", - "clap 4.5.23", + "clap", "derive_more 0.99.18", "env_logger", "log", @@ -7796,7 +7034,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -7831,7 +7069,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -7881,38 +7119,29 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" -version = "0.24.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +checksum = "413083a99c579593656008130e29255e54dcaae495be556cc26888f211648c24" dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.2.0", + "http", "httparse", "log", "rand", "sha1", - "thiserror 1.0.69", + "thiserror 2.0.11", "utf-8", ] [[package]] -name = "tungstenite" -version = "0.26.1" +name = "type-map" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413083a99c579593656008130e29255e54dcaae495be556cc26888f211648c24" +checksum = "deb68604048ff8fa93347f02441e4487594adc20bb8a084f9e564d2b827a0a9f" dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http 1.2.0", - "httparse", - "log", - "rand", - "sha1", - "thiserror 2.0.9", - "utf-8", + "rustc-hash", ] [[package]] @@ -7965,9 +7194,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-normalization" @@ -8037,12 +7266,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" -[[package]] -name = "unindent" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" - [[package]] name = "untrusted" version = "0.9.0" @@ -8061,7 +7284,7 @@ dependencies = [ "native-tls", "once_cell", "rustls", - "rustls-native-certs 0.7.3", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -8106,11 +7329,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ - "getrandom", + "getrandom 0.2.15", "rand", ] @@ -8127,20 +7350,19 @@ dependencies = [ [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] -name = "value-trait" -version = "0.10.1" +name = "variadics_please" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" +checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" dependencies = [ - "float-cmp", - "halfbrown", - "itoa", - "ryu", + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] @@ -8186,36 +7408,46 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.49" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", @@ -8226,9 +7458,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -8236,22 +7468,25 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-logger" @@ -8264,19 +7499,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "wasm-streams" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "wasm-timer" version = "0.2.5" @@ -8294,9 +7516,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -8314,9 +7536,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.7" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] @@ -8327,14 +7549,23 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "wgan" +version = "0.1.0" +dependencies = [ + "burn", + "image", +] + [[package]] name = "wgpu" -version = "23.0.1" +version = "24.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80f70000db37c469ea9d67defdc13024ddf9a5f1b89cb2941b812ad7cde1735a" +checksum = "47f55718f85c2fa756edffa0e7f0e0a60aba463d1362b57e23123c58f035e4b6" dependencies = [ "arrayvec", - "cfg_aliases 0.1.1", + "bitflags 2.8.0", + "cfg_aliases", "document-features", "js-sys", "log", @@ -8354,43 +7585,43 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" +checksum = "82a39b8842dc9ffcbe34346e3ab6d496b32a47f6497e119d762c97fcaae3cb37" dependencies = [ "arrayvec", "bit-vec", - "bitflags 2.6.0", - "cfg_aliases 0.1.1", + "bitflags 2.8.0", + "cfg_aliases", "document-features", - "indexmap 2.7.0", + "indexmap", "log", "naga", "once_cell", "parking_lot 0.12.3", "profiling", "raw-window-handle", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", - "thiserror 1.0.69", + "thiserror 2.0.11", "wgpu-hal", "wgpu-types", ] [[package]] name = "wgpu-hal" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89364b8a0b211adc7b16aeaf1bd5ad4a919c1154b44c9ce27838213ba05fd821" +checksum = "5a782e5056b060b0b4010881d1decddd059e44f2ecd01e2db2971b48ad3627e5" dependencies = [ "android_system_properties", "arrayvec", "ash", "bit-set", - "bitflags 2.6.0", + "bitflags 2.8.0", "block", "bytemuck", - "cfg_aliases 0.1.1", + "cfg_aliases", "core-graphics-types", "glow", "glutin_wgl_sys", @@ -8402,19 +7633,20 @@ dependencies = [ "libc", "libloading", "log", - "metal 0.29.0", + "metal 0.31.0", "naga", "ndk-sys", "objc", "once_cell", + "ordered-float", "parking_lot 0.12.3", "profiling", "range-alloc", "raw-window-handle", "renderdoc-sys", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", - "thiserror 1.0.69", + "thiserror 2.0.11", "wasm-bindgen", "web-sys", "wgpu-types", @@ -8424,12 +7656,13 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "23.0.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" +checksum = "50ac044c0e76c03a0378e7786ac505d010a873665e2d51383dcff8dd227dc69c" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "js-sys", + "log", "web-sys", ] @@ -8538,7 +7771,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8549,7 +7782,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8560,7 +7793,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8571,7 +7804,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8763,21 +7996,20 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "7e49d2d35d3fad69b39b94139037ecfb4f359f08958b9c11e7315ce770462419" dependencies = [ "memchr", ] [[package]] -name = "winreg" -version = "0.50.0" +name = "wit-bindgen-rt" +version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" dependencies = [ - "cfg-if", - "windows-sys 0.48.0", + "bitflags 2.8.0", ] [[package]] @@ -8789,7 +8021,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8829,9 +8061,9 @@ checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" [[package]] name = "xattr" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" dependencies = [ "libc", "linux-raw-sys", @@ -8840,13 +8072,13 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea8b391c9a790b496184c29f7f93b9ed5b16abb306c05415b68bcc16e4d06432" +checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" [[package]] name = "xtask" -version = "1.1.0" +version = "1.2.0" dependencies = [ "log", "rstest", @@ -8856,9 +8088,9 @@ dependencies = [ [[package]] name = "xxhash-rust" -version = "0.8.12" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" [[package]] name = "yansi" @@ -8886,7 +8118,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", "synstructure", ] @@ -8908,7 +8140,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8928,7 +8160,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", "synstructure", ] @@ -8949,7 +8181,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -8971,7 +8203,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.98", ] [[package]] @@ -9004,7 +8236,7 @@ dependencies = [ "crc32fast", "crossbeam-utils", "displaydoc", - "indexmap 2.7.0", + "indexmap", "num_enum", "thiserror 1.0.69", ] @@ -9025,13 +8257,13 @@ dependencies = [ "displaydoc", "flate2", "hmac", - "indexmap 2.7.0", + "indexmap", "lzma-rs", "memchr", "pbkdf2 0.12.2", "rand", "sha1", - "thiserror 2.0.9", + "thiserror 2.0.11", "time", "zeroize", "zopfli", diff --git a/Cargo.toml b/Cargo.toml index e72651e70c..169d668aa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,18 +23,18 @@ exclude = [ edition = "2021" license = "MIT OR Apache-2.0" readme = "README.md" -version = "0.16.0" +version = "0.17.0" [workspace.dependencies] atomic_float = "1" bytemuck = "1.21.0" candle-core = { version = "0.8" } -clap = { version = "4.5.23", features = ["derive"] } +clap = { version = "4.5.27", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" dashmap = "6.1.0" -data-encoding = { version = "2.6.0", default-features = false, features = [ +data-encoding = { version = "2.7.0", default-features = false, features = [ "alloc", ] } dirs = "5.0.1" @@ -47,16 +47,16 @@ globwalk = "0.9.1" hashbrown = "0.15.2" hound = "3.5.1" image = "0.25.5" -indicatif = "0.17.9" +indicatif = "0.17.11" js-sys = "0.3.72" libm = "0.2.11" -log = { default-features = false, version = "0.4.22" } +log = { default-features = false, version = "0.4.25" } md5 = "0.7.0" paste = "1" percent-encoding = "2.3.1" -polars = { version = "0.44.2", features = ["lazy"] } +polars = { version = "0.46.0", features = ["lazy"] } pretty_assertions = "1.4.1" -proc-macro2 = "1.0.92" +proc-macro2 = "1.0.93" protobuf = "3.7.1" protobuf-codegen = "3.7.1" quote = "1.0.38" @@ -84,7 +84,7 @@ strum = "0.26.3" strum_macros = "0.26.4" syn = { version = "2.0.95", features = ["full", "extra-traits"] } tempfile = "3.14.0" -thiserror = "2.0.9" +thiserror = "2.0.11" tokio = { version = "1.42.0", features = ["rt", "macros"] } tracing-appender = "0.2.3" tracing-core = "0.1.33" @@ -101,11 +101,11 @@ ratatui = "0.29.0" # WGPU stuff text_placeholder = "0.5.1" -wgpu = "23.0.0" +wgpu = "24.0.1" # Benchmarks and Burnbench arboard = "3.4.1" -github-device-flow = "0.2.0" +chrono = "0.4.39" os_info = "3.9.0" wsl = "0.1.0" @@ -140,12 +140,12 @@ serde = { version = "1.0.217", default-features = false, features = [ "derive", "alloc", ] } # alloc is for no_std, derive is needed -serde_json = { version = "1.0.134", default-features = false } -uuid = { version = "1.11.0", default-features = false } +serde_json = { version = "1.0.137", default-features = false } +uuid = { version = "1.12.1", default-features = false } -libc = "0.2.168" +libc = "0.2.169" nvml-wrapper = "0.10.0" -sysinfo = "0.32.1" +sysinfo = "0.33.1" systemstat = "0.2.3" tch = "0.15.0" @@ -153,14 +153,14 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "34af9342a2b4f8dcf1b0047afbea0f26405b92cf" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "34af9342a2b4f8dcf1b0047afbea0f26405b92cf" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. ### -# cubecl = { version = "0.3.0", default-features = false } -# cubecl-common = { version = "0.3.0", default-features = false } +# cubecl = { version = "0.4.0", default-features = false } +# cubecl-common = { version = "0.4.0", default-features = false } ### For xtask crate ### tracel-xtask = { version = "=1.1.8" } diff --git a/NOTICES.md b/NOTICES.md index a11156a772..0f559e27a4 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -9,7 +9,7 @@ repository copied or derived from. License: BSD 3-Clause License -Copyright (c) 2017, +Copyright (c) 2017, All rights reserved. Redistribution and use in source and binary forms, with or without @@ -572,4 +572,75 @@ SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. \ No newline at end of file +DEALINGS IN THE SOFTWARE. + +## github-device-flow + +**Source**: +- Part of: https://github.com/jakewilkins/gh-device-flow/blob/main/src/lib.rs +- https://github.com/jakewilkins/gh-device-flow/blob/main/src/util.rs + +MIT License + +Copyright (c) 2022 Jake Wilkins + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +## ICU + +UNICODE LICENSE V3 + +COPYRIGHT AND PERMISSION NOTICE + +Copyright © 2016-2024 Unicode, Inc. + +NOTICE TO USER: Carefully read the following legal agreement. BY +DOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR +SOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE +TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT +DOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of data files and any associated documentation (the "Data Files") or +software and any associated documentation (the "Software") to deal in the +Data Files or Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, and/or sell +copies of the Data Files or Software, and to permit persons to whom the +Data Files or Software are furnished to do so, provided that either (a) +this copyright and permission notice appear with all copies of the Data +Files or Software, or (b) this copyright and permission notice appear in +associated Documentation. + +THE DATA FILES AND SOFTWARE ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY +KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF +THIRD PARTY RIGHTS. + +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE +BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, +OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, +ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA +FILES OR SOFTWARE. + +Except as contained in this notice, the name of a copyright holder shall +not be used in advertising or otherwise to promote the sale, use or other +dealings in these Data Files or Software without prior written +authorization of the copyright holder. diff --git a/README.md b/README.md index a0780dcc16..951b2a9f24 100644 --- a/README.md +++ b/README.md @@ -567,6 +567,8 @@ Additional examples: sample. - [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the DbPedia dataset. +- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits + based on MNIST. For more practical insights, you can clone the repository and run any of them directly on your computer! @@ -619,19 +621,20 @@ leads to more reliable, bug-free solutions built faster (after some practice
> **Deprecation Note**
Since `0.14.0`, the internal structure for tensor data has changed. The -> previous `Data` struct is being deprecated in favor of the new `TensorData` struct, which allows -> for more flexibility by storing the underlying data as bytes and keeping the data type as a field. -> If you are using `Data` in your code, make sure to switch to `TensorData`. +> previous `Data` struct was deprecated and officially removed since `0.17.0` in favor of the new +> `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and +> keeping the data type as a field. If you are using `Data` in your code, make sure to switch to +> `TensorData`.
@@ -640,8 +643,9 @@ Loading Model Records From Previous Versions ⚠️
-In the event that you are trying to load a model record saved in a previous version, make sure to -enable the `record-backward-compat` feature flag. +In the event that you are trying to load a model record saved in a version older than `0.14.0`, make +sure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat` +feature flag. ``` features = [..., "record-backward-compat"] @@ -650,13 +654,14 @@ features = [..., "record-backward-compat"] Otherwise, the record won't be deserialized correctly and you will get an error message. This error will also point you to the backward compatible feature flag. -The backward compatibility is maintained for deserialization when loading records. Therefore, as -soon as you have saved the record again it will be saved according to the new structure and you -won't need the backward compatible feature flag anymore. +The backward compatibility was maintained for deserialization when loading records. Therefore, as +soon as you have saved the record again it will be saved according to the new structure and you can +upgrade back to the current version Please note that binary formats are not backward compatible. Thus, you will need to load your record in a previous version and save it in any of the other self-describing record format (e.g., using the -`NamedMpkFileRecorder`) before using the new version with the `record-backward-compat` feature flag. +`NamedMpkFileRecorder`) before using a compatible version (as described) with the +`record-backward-compat` feature flag.
diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index ee5f0bd8a2..821d189fe0 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"] candle-cpu = ["burn/candle"] candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] -cuda-jit = ["burn/cuda-jit"] -cuda-jit-fusion = ["cuda-jit", "burn/fusion"] +cuda = ["burn/cuda"] +cuda-fusion = ["cuda", "burn/fusion"] default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] -hip-jit = ["burn/hip-jit"] +hip = ["burn/hip"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] @@ -27,20 +27,20 @@ tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu", "burn/autotune"] wgpu-fusion = ["wgpu", "burn/fusion"] -wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"] +wgpu-spirv = ["burn/vulkan", "burn/autotune"] wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"] [dependencies] arboard = { workspace = true } burn = { path = "../crates/burn", default-features = false } -burn-common = { path = "../crates/burn-common", version = "0.16.0" } +burn-common = { path = "../crates/burn-common", version = "0.17.0" } clap = { workspace = true } colored = { workspace = true } +chrono = { workspace = true } cubecl = { workspace = true, features = ["wgpu"], default-features = true } derive-new = { workspace = true } dirs = { workspace = true } -github-device-flow = { workspace = true } half = { workspace = true } indicatif = { workspace = true } os_info = { workspace = true } @@ -124,6 +124,10 @@ path = "benches/resnet.rs" harness = false name = "autodiff" +[[bench]] +harness = false +name = "reduce" + [[bin]] name = "burnbench" path = "src/bin/burnbench.rs" diff --git a/backend-comparison/README.md b/backend-comparison/README.md index ba7042bbc1..6f4b547a22 100644 --- a/backend-comparison/README.md +++ b/backend-comparison/README.md @@ -57,6 +57,7 @@ Available Benchmarks: - conv-transpose3d - conv2d - conv3d +- reduce ``` #### Run benchmarks diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs index 375be97b4e..fbec64c648 100644 --- a/backend-comparison/benches/matmul_fused.rs +++ b/backend-comparison/benches/matmul_fused.rs @@ -1,5 +1,9 @@ use backend_comparison::persistence::save; -use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor}; +use burn::tensor::{ + activation::{gelu, relu}, + backend::Backend, + Distribution, Shape, Tensor, +}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -14,7 +18,7 @@ impl Benchmark for MatmulBenchmark { type Args = (Tensor, Tensor, Tensor); fn name(&self) -> String { - "matmul_bias_relu".into() + "matmul_relu_bias_gelu".into() } fn shapes(&self) -> Vec> { @@ -23,7 +27,7 @@ impl Benchmark for MatmulBenchmark { fn execute(&self, (lhs, rhs, bias): Self::Args) { let bias = bias.unsqueeze(); - relu(lhs.matmul(rhs) + bias); + gelu(relu(lhs.matmul(rhs)) + bias); } fn prepare(&self) -> Self::Args { diff --git a/backend-comparison/benches/reduce.rs b/backend-comparison/benches/reduce.rs new file mode 100644 index 0000000000..df365f2306 --- /dev/null +++ b/backend-comparison/benches/reduce.rs @@ -0,0 +1,102 @@ +use backend_comparison::persistence::save; +use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; +use burn_common::benchmark::{run_benchmark, Benchmark}; + +enum Instruction { + ArgMin(usize), + SumDim(usize), + Sum, +} + +struct ReduceBenchmark { + instruction: Instruction, + shape: Shape, + device: B::Device, + tensor: Tensor, +} + +impl ReduceBenchmark { + pub fn new(instruction: Instruction, device: B::Device) -> Self { + let shape = Shape::new([4096, 512, 64]); + let tensor = Tensor::random(shape.clone(), Distribution::Default, &device); + Self { + instruction, + shape, + device, + tensor, + } + } +} + +impl Benchmark for ReduceBenchmark { + type Args = (); + + fn prepare(&self) -> Self::Args {} + + fn execute(&self, _: Self::Args) { + match self.instruction { + Instruction::ArgMin(axis) => { + self.tensor.clone().argmin(axis); + } + Instruction::SumDim(axis) => { + self.tensor.clone().sum_dim(axis); + } + Instruction::Sum => { + self.tensor.clone().sum(); + } + } + } + + fn name(&self) -> String { + match self.instruction { + Instruction::ArgMin(axis) => format!("reduce-argmin-{axis}"), + Instruction::SumDim(axis) => format!("reduce-sum-{axis}"), + Instruction::Sum => String::from("reduce-sum-full"), + } + } + + fn sync(&self) { + B::sync(&self.device) + } + + fn shapes(&self) -> Vec> { + vec![self.shape.dims.clone()] + } +} + +#[allow(dead_code)] +fn bench( + device: &B::Device, + feature_name: &str, + url: Option<&str>, + token: Option<&str>, +) { + let mut benchmarks = Vec::new(); + + for axis in 0..3 { + benchmarks.push(ReduceBenchmark::::new( + Instruction::ArgMin(axis), + device.clone(), + )); + + benchmarks.push(ReduceBenchmark::::new( + Instruction::SumDim(axis), + device.clone(), + )); + } + + benchmarks.push(ReduceBenchmark::::new(Instruction::Sum, device.clone())); + + save::( + benchmarks.into_iter().map(run_benchmark).collect(), + device, + feature_name, + url, + token, + ) + .unwrap(); +} + +fn main() { + backend_comparison::bench_on_backend!(); +} diff --git a/backend-comparison/src/burnbenchapp/auth.rs b/backend-comparison/src/burnbenchapp/auth/base.rs similarity index 96% rename from backend-comparison/src/burnbenchapp/auth.rs rename to backend-comparison/src/burnbenchapp/auth/base.rs index 3e9470b2bd..f956b0a204 100644 --- a/backend-comparison/src/burnbenchapp/auth.rs +++ b/backend-comparison/src/burnbenchapp/auth/base.rs @@ -1,6 +1,5 @@ use arboard::Clipboard; use burn::serde::{Deserialize, Serialize}; -use github_device_flow::{self, DeviceFlow}; use reqwest; #[cfg(unix)] use std::os::unix::fs::PermissionsExt; @@ -64,7 +63,7 @@ pub(crate) fn get_tokens() -> Option { pub(crate) fn get_username(access_token: &str) -> Option { let client = reqwest::blocking::Client::new(); let response = client - .get(format!("{}users/me", super::USER_BENCHMARK_SERVER_URL)) + .get(format!("{}users/me", USER_BENCHMARK_SERVER_URL)) .header(reqwest::header::USER_AGENT, "burnbench") .header(reqwest::header::CONTENT_TYPE, "application/json") .header( @@ -77,7 +76,7 @@ pub(crate) fn get_username(access_token: &str) -> Option { } fn auth() -> Option { - let mut flow = match DeviceFlow::start(CLIENT_ID, None) { + let mut flow = match DeviceFlow::start(CLIENT_ID, None, None) { Ok(flow) => flow, Err(e) => { eprintln!("Error authenticating: {}", e); @@ -134,7 +133,7 @@ fn verify_tokens(tokens: &Tokens) -> bool { ) .header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION) .send(); - response.map_or(false, |resp| resp.status().is_success()) + response.is_ok_and(|resp| resp.status().is_success()) } fn refresh_tokens(tokens: &Tokens) -> Option { @@ -142,10 +141,7 @@ fn refresh_tokens(tokens: &Tokens) -> Option { println!("Refreshing token..."); let client = reqwest::blocking::Client::new(); let response = client - .post(format!( - "{}auth/refresh-token", - super::USER_BENCHMARK_SERVER_URL - )) + .post(format!("{}auth/refresh-token", USER_BENCHMARK_SERVER_URL)) .header(reqwest::header::USER_AGENT, "burnbench") .header(reqwest::header::CONTENT_TYPE, "application/json") .header( @@ -189,6 +185,8 @@ fn save_tokens(tokens: &Tokens) { #[cfg(test)] use serial_test::serial; +use crate::burnbenchapp::{auth::github_device_flow::DeviceFlow, USER_BENCHMARK_SERVER_URL}; + #[cfg(test)] mod tests { use super::*; diff --git a/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs b/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs new file mode 100644 index 0000000000..55aa00f73e --- /dev/null +++ b/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs @@ -0,0 +1,232 @@ +// Initially from: https://github.com/jakewilkins/gh-device-flow +use std::collections::HashMap; +use std::{fmt, result::Result, thread, time}; + +use chrono::offset::Utc; +use chrono::{DateTime, Duration}; +use serde::{Deserialize, Serialize}; + +pub fn credential_error(msg: String) -> DeviceFlowError { + DeviceFlowError::GitHubError(msg) +} + +pub fn send_request( + device_flow: &mut DeviceFlow, + url: String, + body: String, +) -> Option> { + let client = reqwest::blocking::Client::new(); + let response_struct = client + .post(&url) + .header("Accept", "application/json") + .body(body) + .send(); + + match response_struct { + Ok(resp) => match resp.json::>() { + Ok(hm) => Some(hm), + Err(err) => { + device_flow.state = DeviceFlowState::Failure(err.into()); + None + } + }, + Err(err) => { + device_flow.state = DeviceFlowState::Failure(err.into()); + None + } + } +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct Credential { + pub token: String, + pub expiry: String, + pub refresh_token: String, +} + +impl Credential { + fn empty() -> Credential { + Credential { + token: String::new(), + expiry: String::new(), + refresh_token: String::new(), + } + } +} + +#[derive(Debug, Clone)] +pub enum DeviceFlowError { + HttpError(String), + GitHubError(String), +} + +impl fmt::Display for DeviceFlowError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DeviceFlowError::HttpError(string) => write!(f, "DeviceFlowError: {}", string), + DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string), + } + } +} + +impl std::error::Error for DeviceFlowError {} + +impl From for DeviceFlowError { + fn from(e: reqwest::Error) -> Self { + DeviceFlowError::HttpError(format!("{:?}", e)) + } +} + +#[derive(Debug, Clone)] +pub enum DeviceFlowState { + Pending, + Processing(time::Duration), + Success(Credential), + Failure(DeviceFlowError), +} + +#[derive(Clone)] +pub struct DeviceFlow { + pub host: String, + pub client_id: String, + pub scope: String, + pub user_code: Option, + pub device_code: Option, + pub verification_uri: Option, + pub state: DeviceFlowState, +} + +const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0); + +impl DeviceFlow { + pub fn new(client_id: &str, maybe_host: Option<&str>, scope: Option<&str>) -> Self { + Self { + client_id: String::from(client_id), + scope: match scope { + Some(string) => String::from(string), + None => String::new(), + }, + host: match maybe_host { + Some(string) => String::from(string), + None => String::from("github.com"), + }, + user_code: None, + device_code: None, + verification_uri: None, + state: DeviceFlowState::Pending, + } + } + + pub fn start( + client_id: &str, + maybe_host: Option<&str>, + scope: Option<&str>, + ) -> Result { + let mut flow = DeviceFlow::new(client_id, maybe_host, scope); + + flow.setup(); + + match flow.state { + DeviceFlowState::Processing(_) => Ok(flow.to_owned()), + DeviceFlowState::Failure(err) => Err(err), + _ => Err(credential_error( + "Something truly unexpected happened".into(), + )), + } + } + + pub fn setup(&mut self) { + let body = format!("client_id={}&scope={}", &self.client_id, &self.scope); + let entry_url = format!("https://{}/login/device/code", &self.host); + + if let Some(res) = send_request(self, entry_url, body) { + if res.contains_key("error") && res.contains_key("error_description") { + self.state = DeviceFlowState::Failure(credential_error( + res["error_description"].as_str().unwrap().into(), + )) + } else if res.contains_key("error") { + self.state = DeviceFlowState::Failure(credential_error(format!( + "Error response: {:?}", + res["error"].as_str().unwrap() + ))) + } else { + self.user_code = Some(String::from(res["user_code"].as_str().unwrap())); + self.device_code = Some(String::from(res["device_code"].as_str().unwrap())); + self.verification_uri = + Some(String::from(res["verification_uri"].as_str().unwrap())); + self.state = DeviceFlowState::Processing(FIVE_SECONDS); + } + }; + } + + pub fn poll(&mut self, iterations: u32) -> Result { + for count in 0..iterations { + self.update(); + + if let DeviceFlowState::Processing(interval) = self.state { + if count == iterations { + return Err(credential_error("Max poll iterations reached".into())); + } + + thread::sleep(interval); + } else { + break; + } + } + + match &self.state { + DeviceFlowState::Success(cred) => Ok(cred.to_owned()), + DeviceFlowState::Failure(err) => Err(err.to_owned()), + _ => Err(credential_error( + "Unable to fetch credential, sorry :/".into(), + )), + } + } + + pub fn update(&mut self) { + let poll_url = format!("https://{}/login/oauth/access_token", self.host); + let poll_payload = format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + self.client_id, + &self.device_code.clone().unwrap() + ); + + if let Some(res) = send_request(self, poll_url, poll_payload) { + if res.contains_key("error") { + match res["error"].as_str().unwrap() { + "authorization_pending" => {} + "slow_down" => { + if let DeviceFlowState::Processing(current_interval) = self.state { + self.state = + DeviceFlowState::Processing(current_interval + FIVE_SECONDS); + }; + } + other_reason => { + self.state = DeviceFlowState::Failure(credential_error(format!( + "Error checking for token: {}", + other_reason + ))); + } + } + } else { + let mut this_credential = Credential::empty(); + this_credential.token = res["access_token"].as_str().unwrap().to_string(); + + if let Some(expires_in) = res.get("expires_in") { + this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap()); + this_credential.refresh_token = + res["refresh_token"].as_str().unwrap().to_string(); + } + + self.state = DeviceFlowState::Success(this_credential); + } + } + } +} + +fn calculate_expiry(expires_in: i64) -> String { + let expires_in = Duration::seconds(expires_in); + let mut expiry: DateTime = Utc::now(); + expiry += expires_in; + expiry.to_rfc3339() +} diff --git a/backend-comparison/src/burnbenchapp/auth/mod.rs b/backend-comparison/src/burnbenchapp/auth/mod.rs new file mode 100644 index 0000000000..7e1e4539d7 --- /dev/null +++ b/backend-comparison/src/burnbenchapp/auth/mod.rs @@ -0,0 +1,4 @@ +mod base; +pub(crate) mod github_device_flow; + +pub(crate) use base::*; diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs index 83c5060a6b..4fb31edab8 100644 --- a/backend-comparison/src/burnbenchapp/base.rs +++ b/backend-comparison/src/burnbenchapp/base.rs @@ -62,6 +62,13 @@ enum BackendValues { CandleCuda, #[strum(to_string = "candle-metal")] CandleMetal, + #[strum(to_string = "cuda")] + Cuda, + #[strum(to_string = "cuda-fusion")] + CudaFusion, + #[cfg(target_os = "linux")] + #[strum(to_string = "hip")] + Hip, #[strum(to_string = "ndarray")] Ndarray, #[strum(to_string = "ndarray-blas-accelerate")] @@ -82,13 +89,6 @@ enum BackendValues { WgpuSpirv, #[strum(to_string = "wgpu-spirv-fusion")] WgpuSpirvFusion, - #[strum(to_string = "cuda-jit")] - CudaJit, - #[strum(to_string = "cuda-jit-fusion")] - CudaJitFusion, - #[cfg(target_os = "linux")] - #[strum(to_string = "hip-jit")] - HipJit, } #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)] @@ -123,6 +123,8 @@ enum BenchmarkValues { Conv2d, #[strum(to_string = "conv3d")] Conv3d, + #[strum(to_string = "reduce")] + Reduce, } pub fn execute() { diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 03e2d70444..b3351e9dd5 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -54,6 +54,9 @@ fn update_panic_hook() { #[macro_export] macro_rules! bench_on_backend { () => { + $crate::bench_on_backend!(bench) + }; + ($fn_name:ident) => { use std::env; backend_comparison::init_log().unwrap(); @@ -88,25 +91,25 @@ macro_rules! bench_on_backend { let feature_name = "wgpu-spirv"; #[cfg(feature = "wgpu-spirv-fusion")] let feature_name = "wgpu-spirv-fusion"; - #[cfg(feature = "cuda-jit")] - let feature_name = "cuda-jit"; - #[cfg(feature = "cuda-jit-fusion")] - let feature_name = "cuda-jit-fusion"; - #[cfg(feature = "hip-jit")] - let feature_name = "hip-jit"; + #[cfg(feature = "cuda")] + let feature_name = "cuda"; + #[cfg(feature = "cuda-fusion")] + let feature_name = "cuda-fusion"; + #[cfg(feature = "hip")] + let feature_name = "hip"; #[cfg(any(feature = "wgpu"))] { use burn::backend::wgpu::{Wgpu, WgpuDevice}; - bench::>(&WgpuDevice::default(), feature_name, url, token); + $fn_name::>(&WgpuDevice::default(), feature_name, url, token); } #[cfg(any(feature = "wgpu-spirv"))] { use burn::backend::wgpu::{Wgpu, WgpuDevice}; - bench::>(&WgpuDevice::default(), feature_name, url, token); + $fn_name::>(&WgpuDevice::default(), feature_name, url, token); } #[cfg(feature = "tch-gpu")] @@ -117,7 +120,7 @@ macro_rules! bench_on_backend { let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - bench::>(&device, feature_name, url, token); + $fn_name::>(&device, feature_name, url, token); } #[cfg(feature = "tch-cpu")] @@ -125,7 +128,7 @@ macro_rules! bench_on_backend { use burn::backend::{libtorch::LibTorchDevice, LibTorch}; let device = LibTorchDevice::Cpu; - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(any( @@ -139,7 +142,7 @@ macro_rules! bench_on_backend { use burn::backend::NdArray; let device = NdArrayDevice::Cpu; - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "candle-cpu")] @@ -148,7 +151,7 @@ macro_rules! bench_on_backend { use burn::backend::Candle; let device = CandleDevice::Cpu; - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "candle-cuda")] @@ -157,7 +160,7 @@ macro_rules! bench_on_backend { use burn::backend::Candle; let device = CandleDevice::cuda(0); - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "candle-metal")] @@ -166,21 +169,21 @@ macro_rules! bench_on_backend { use burn::backend::Candle; let device = CandleDevice::metal(0); - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } - #[cfg(feature = "cuda-jit")] + #[cfg(feature = "cuda")] { - use burn::backend::cuda_jit::{Cuda, CudaDevice}; + use burn::backend::cuda::{Cuda, CudaDevice}; - bench::>(&CudaDevice::default(), feature_name, url, token); + $fn_name::>(&CudaDevice::default(), feature_name, url, token); } - #[cfg(feature = "hip-jit")] + #[cfg(feature = "hip")] { - use burn::backend::hip_jit::{Hip, HipDevice}; + use burn::backend::hip::{Hip, HipDevice}; - bench::>(&HipDevice::default(), feature_name, url, token); + $fn_name::>(&HipDevice::default(), feature_name, url, token); } }; } diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs index 287b629c21..3fe24bc955 100644 --- a/backend-comparison/src/persistence/system_info.rs +++ b/backend-comparison/src/persistence/system_info.rs @@ -38,7 +38,7 @@ impl BenchmarkSystemInfo { fn enumerate_cpus() -> Vec { let system = sysinfo::System::new_with_specifics( - sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()), + sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()), ); let cpu_names: HashSet = system .cpus() diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md index 7689d25354..e55afc904d 100644 --- a/burn-book/src/advanced/no-std.md +++ b/burn-book/src/advanced/no-std.md @@ -23,7 +23,7 @@ Some other dependencies have to be added ```toml [dependencies] embedded-alloc = "0.5.1" # Only if there is no default allocator for your chip -burn = { version = "0.16", default-features = false, features = ["ndarray"] } # Backend must be ndarray +burn = { version = "0.17", default-features = false, features = ["ndarray"] } # Backend must be ndarray [build-dependencies] burn-import = { version = "0.14" } # Used to auto generate the rust code to import the model @@ -68,7 +68,7 @@ We are using ndarray, so we just need to define the NdArray backend as usual use burn::{backend::NdArray, tensor::Tensor}; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; ``` Then inside the `main` function add @@ -76,7 +76,7 @@ Then inside the `main` function add use your_model::Model; // Get a default device for the backend -let device = BackendDeice::default(); +let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); diff --git a/burn-book/src/basic-workflow/README.md b/burn-book/src/basic-workflow/README.md index 5b32591a58..8515d73d2c 100644 --- a/burn-book/src/basic-workflow/README.md +++ b/burn-book/src/basic-workflow/README.md @@ -14,7 +14,7 @@ automatically add the missing imports as you add the code snippets to your code. Be sure to checkout the git branch corresponding to the version of Burn you are using to follow this guide. -The current version of Burn is `0.16` and the corresponding branch to checkout is `main`. +The current version of Burn is `0.17` and the corresponding branch to checkout is `main`. The code for this demo can be executed from Burn's base directory using the command: diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md index adce46b297..ac4b16dbce 100644 --- a/burn-book/src/basic-workflow/model.md +++ b/burn-book/src/basic-workflow/model.md @@ -20,7 +20,7 @@ version = "0.1.0" edition = "2021" [dependencies] -burn = { version = "~0.16", features = ["train", "wgpu", "vision"] } +burn = { version = "~0.17", features = ["train", "wgpu", "vision"] } ``` Our goal will be to create a basic convolutional neural network used for image classification. We diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md index e029aca708..e5dd4eaae9 100644 --- a/burn-book/src/building-blocks/metric.md +++ b/burn-book/src/building-blocks/metric.md @@ -4,11 +4,12 @@ When working with the learner, you have the option to record metrics that will b throughout the training process. We currently offer a restricted range of metrics. | Metric | Description | -|------------------|---------------------------------------------------------| +| ---------------- | ------------------------------------------------------- | | Accuracy | Calculate the accuracy in percentage | | TopKAccuracy | Calculate the top-k accuracy in percentage | | Precision | Calculate precision in percentage | | Recall | Calculate recall in percentage | +| FBetaScore | Calculate Fβ score in percentage | | AUROC | Calculate the area under curve of ROC in percentage | | Loss | Output the loss used for the backward pass | | CPU Temperature | Fetch the temperature of CPUs | diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 0f5aca7f24..9598d6e39e 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -294,3 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules. | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `MseLoss` | `nn.MSELoss` | | `HuberLoss` | `nn.HuberLoss` | +| `PoissonNllLoss` | `nn.PoissonNLLLoss` | diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index fc4833c3c4..c12bb82c00 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -131,47 +131,47 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. -| Burn | PyTorch Equivalent | -| ------------------------------------- | ------------------------------------------------------------------------- | -| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | -| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | -| `Tensor::from_primitive(primitive)` | N/A | -| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | -| `tensor.all()` | `tensor.all()` | -| `tensor.all_dim(dim)` | `tensor.all(dim)` | -| `tensor.any()` | `tensor.any()` | -| `tensor.any_dim(dim)` | `tensor.any(dim)` | -| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | -| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | -| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | -| `tensor.device()` | `tensor.device` | -| `tensor.dtype()` | `tensor.dtype` | -| `tensor.dims()` | `tensor.size()` | -| `tensor.equal(other)` | `x == y` | -| `tensor.expand(shape)` | `tensor.expand(shape)` | -| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | -| `tensor.flip(axes)` | `tensor.flip(axes)` | -| `tensor.into_data()` | N/A | -| `tensor.into_primitive()` | N/A | -| `tensor.into_scalar()` | `tensor.item()` | -| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | -| `tensor.not_equal(other)` | `x != y` | -| `tensor.permute(axes)` | `tensor.permute(axes)` | -| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | -| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | -| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | -| `tensor.reshape(shape)` | `tensor.view(shape)` | -| `tensor.shape()` | `tensor.shape` | -| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | -| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | -| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | -| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | -| `tensor.to_data()` | N/A | -| `tensor.to_device(device)` | `tensor.to(device)` | -| `tensor.transpose()` | `tensor.T` | -| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | -| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | -| `tensor.unsqueeze_dims(dims)` | N/A | +| Burn | PyTorch Equivalent | +| ------------------------------------------- | ------------------------------------------------------------------------- | +| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | +| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | +| `Tensor::from_primitive(primitive)` | N/A | +| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | +| `tensor.all()` | `tensor.all()` | +| `tensor.all_dim(dim)` | `tensor.all(dim)` | +| `tensor.any()` | `tensor.any()` | +| `tensor.any_dim(dim)` | `tensor.any(dim)` | +| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | +| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | +| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | +| `tensor.device()` | `tensor.device` | +| `tensor.dtype()` | `tensor.dtype` | +| `tensor.dims()` | `tensor.size()` | +| `tensor.equal(other)` | `x == y` | +| `tensor.expand(shape)` | `tensor.expand(shape)` | +| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | +| `tensor.flip(axes)` | `tensor.flip(axes)` | +| `tensor.into_data()` | N/A | +| `tensor.into_primitive()` | N/A | +| `tensor.into_scalar()` | `tensor.item()` | +| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | +| `tensor.not_equal(other)` | `x != y` | +| `tensor.permute(axes)` | `tensor.permute(axes)` | +| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | +| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | +| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | +| `tensor.reshape(shape)` | `tensor.view(shape)` | +| `tensor.shape()` | `tensor.shape` | +| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | +| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | +| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | +| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | +| `tensor.to_data()` | N/A | +| `tensor.to_device(device)` | `tensor.to(device)` | +| `tensor.transpose()` | `tensor.T` | +| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | +| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | +| `tensor.unsqueeze_dims(dims)` | N/A | ### Numeric Operations @@ -229,6 +229,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.neg()` or `-tensor` | `-tensor` | | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` | | `tensor.ones_like()` | `torch.ones_like(tensor)` | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | +| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A | | `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` | | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` | | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | @@ -257,33 +259,32 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. Those operations are only available for `Float` tensors. -| Burn API | PyTorch Equivalent | -| --------------------------------------------- | ---------------------------------- | -| `Tensor::one_hot(index, num_classes, device)` | N/A | -| `tensor.cast(dtype)` | `tensor.to(dtype)` | -| `tensor.ceil()` | `tensor.ceil()` | -| `tensor.cos()` | `tensor.cos()` | -| `tensor.erf()` | `tensor.erf()` | -| `tensor.exp()` | `tensor.exp()` | -| `tensor.floor()` | `tensor.floor()` | -| `tensor.from_floats(floats, device)` | N/A | -| `tensor.from_full_precision(tensor)` | N/A | -| `tensor.int()` | Similar to `tensor.to(torch.long)` | -| `tensor.log()` | `tensor.log()` | -| `tensor.log1p()` | `tensor.log1p()` | -| `tensor.matmul(other)` | `tensor.matmul(other)` | -| `tensor.random(shape, distribution, device)` | N/A | -| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | -| `tensor.recip()` | `tensor.reciprocal()` | -| `tensor.round()` | `tensor.round()` | -| `tensor.sin()` | `tensor.sin()` | -| `tensor.sqrt()` | `tensor.sqrt()` | -| `tensor.tanh()` | `tensor.tanh()` | -| `tensor.to_full_precision()` | `tensor.to(torch.float)` | -| `tensor.var(dim)` | `tensor.var(dim)` | -| `tensor.var_bias(dim)` | N/A | -| `tensor.var_mean(dim)` | N/A | -| `tensor.var_mean_bias(dim)` | N/A | +| Burn API | PyTorch Equivalent | +| -------------------------------------------- | ---------------------------------- | +| `tensor.cast(dtype)` | `tensor.to(dtype)` | +| `tensor.ceil()` | `tensor.ceil()` | +| `tensor.cos()` | `tensor.cos()` | +| `tensor.erf()` | `tensor.erf()` | +| `tensor.exp()` | `tensor.exp()` | +| `tensor.floor()` | `tensor.floor()` | +| `tensor.from_floats(floats, device)` | N/A | +| `tensor.from_full_precision(tensor)` | N/A | +| `tensor.int()` | Similar to `tensor.to(torch.long)` | +| `tensor.log()` | `tensor.log()` | +| `tensor.log1p()` | `tensor.log1p()` | +| `tensor.matmul(other)` | `tensor.matmul(other)` | +| `tensor.random(shape, distribution, device)` | N/A | +| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | +| `tensor.recip()` | `tensor.reciprocal()` | +| `tensor.round()` | `tensor.round()` | +| `tensor.sin()` | `tensor.sin()` | +| `tensor.sqrt()` | `tensor.sqrt()` | +| `tensor.tanh()` | `tensor.tanh()` | +| `tensor.to_full_precision()` | `tensor.to(torch.float)` | +| `tensor.var(dim)` | `tensor.var(dim)` | +| `tensor.var_bias(dim)` | N/A | +| `tensor.var_mean(dim)` | N/A | +| `tensor.var_mean_bias(dim)` | N/A | ### Int Operations @@ -293,11 +294,21 @@ Those operations are only available for `Int` tensors. | ------------------------------------------------ | ------------------------------------------------------- | | `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` | | `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | +| `tensor.bitwise_and(other)` | `torch.bitwise_and(tensor, other)` | +| `tensor.bitwise_and_scalar(scalar)` | `torch.bitwise_and(tensor, scalar)` | +| `tensor.bitwise_not()` | `torch.bitwise_not(tensor)` | +| `tensor.bitwise_left_shift(other)` | `torch.bitwise_left_shift(tensor, other)` | +| `tensor.bitwise_left_shift_scalar(scalar)` | `torch.bitwise_left_shift(tensor, scalar)` | +| `tensor.bitwise_right_shift(other)` | `torch.bitwise_right_shift(tensor, other)` | +| `tensor.bitwise_right_shift_scalar(scalar)` | `torch.bitwise_right_shift(tensor, scalar)` | +| `tensor.bitwise_or(other)` | `torch.bitwise_or(tensor, other)` | +| `tensor.bitwise_or_scalar(scalar)` | `torch.bitwise_or(tensor, scalar)` | +| `tensor.bitwise_xor(other)` | `torch.bitwise_xor(tensor, other)` | +| `tensor.bitwise_xor_scalar(scalar)` | `torch.bitwise_xor(tensor, scalar)` | | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | -| `tensor.one_hot(num_classes)` | N/A | ### Bool Operations @@ -329,7 +340,7 @@ strategies. | Burn API | PyTorch Equivalent | | ------------------------------------------------ | -------------------------------------------------- | | `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` | -| `activation::hard_sigmoid(tensor, alpha, beta) | `nn.functional.hardsigmoid(tensor)` | +| `activation::hard_sigmoid(tensor, alpha, beta)` | `nn.functional.hardsigmoid(tensor)` | | `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` | | `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` | | `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` | diff --git a/burn-book/src/examples.md b/burn-book/src/examples.md index c9703a4389..2b083b6fbe 100644 --- a/burn-book/src/examples.md +++ b/burn-book/src/examples.md @@ -85,6 +85,7 @@ The following additional examples are currently available if you want to check t | [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. | | [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. | | [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. | +| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. | For more information on each example, see their respective `README.md` file. Be sure to check out the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date diff --git a/burn-book/src/import/onnx-model.md b/burn-book/src/import/onnx-model.md index 05b9d5de81..9b3b7917fd 100644 --- a/burn-book/src/import/onnx-model.md +++ b/burn-book/src/import/onnx-model.md @@ -74,7 +74,7 @@ First, add the `burn-import` crate to your `Cargo.toml`: ```toml [build-dependencies] -burn-import = "~0.16" +burn-import = "~0.17" ``` Then, in your `build.rs` file: diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 5c17eee3e9..1f584cdc9f 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -162,17 +162,13 @@ struct NetConfig { n_head: usize, n_layer: usize, d_model: usize, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_float: f64, + some_float: f64, some_int: i32, some_bool: bool, some_str: String, some_list_int: Vec, some_list_str: Vec, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_list_float: Vec, + some_list_float: Vec, some_dict: HashMap, } diff --git a/burn-book/src/saving-and-loading.md b/burn-book/src/saving-and-loading.md index 13a96cc94d..24b52dd22a 100644 --- a/burn-book/src/saving-and-loading.md +++ b/burn-book/src/saving-and-loading.md @@ -4,7 +4,7 @@ Saving your trained machine learning model is quite easy, no matter the output f mentioned in the [Record](./building-blocks/record.md) section, different formats are supported to serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the [MessagePack](https://msgpack.org/) binary serialization format with the help of -[smp_serde](https://docs.rs/rmp-serde/). +[rmp_serde](https://docs.rs/rmp-serde/). ```rust, ignore // Save model in MessagePack format with full precision @@ -22,7 +22,7 @@ Now that you have a trained model saved to your disk, you can easily load it in ```rust, ignore // Load model in full precision from MessagePack file let recorder = NamedMpkFileRecorder::::new(); -model +model = model .load_file(model_path, &recorder, device) .expect("Should be able to load the model weights from the provided file"); ``` diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index 2144d46885..df7040f835 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -18,16 +18,16 @@ std = [] async = [] # Require std [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0" } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } derive-new = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index ffcc522051..8203d16212 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -352,4 +352,48 @@ impl IntTensorOps for Autodiff { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { B::int_argsort(tensor, dim, descending) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_and(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_or(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_xor(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + B::bitwise_not(tensor) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_right_shift_scalar(lhs, rhs) + } } diff --git a/crates/burn-candle/Cargo.toml b/crates/burn-candle/Cargo.toml index 62af31d5fb..65fbf416ca 100644 --- a/crates/burn-candle/Cargo.toml +++ b/crates/burn-candle/Cargo.toml @@ -21,17 +21,17 @@ accelerate = ["candle-core/accelerate"] [dependencies] derive-new = { workspace = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } half = { workspace = true } candle-core = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tch = { path = "../burn-tch", version = "0.16.0", default-features = false, features = [ +burn-tch = { path = "../burn-tch", version = "0.17.0", default-features = false, features = [ ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 6e3586506b..67328c36f3 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -373,6 +373,50 @@ impl IntTensorOps for Candle, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_and is not implemented for Candle IntTensor"); + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_or is not implemented for Candle IntTensor"); + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_xor is not implemented for Candle IntTensor"); + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unimplemented!("bitwise_not is not implemented for Candle IntTensor"); + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor"); + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor"); + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor"); + } + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { super::base::cumsum(tensor, dim) } diff --git a/crates/burn-common/src/lib.rs b/crates/burn-common/src/lib.rs index 77faa0195d..efe3b6d7d2 100644 --- a/crates/burn-common/src/lib.rs +++ b/crates/burn-common/src/lib.rs @@ -11,6 +11,9 @@ pub mod id; pub use cubecl_common::*; +#[cfg(feature = "rayon")] +pub use rayon; + extern crate alloc; /// Network utilities. diff --git a/crates/burn-common/src/parallel.rs b/crates/burn-common/src/parallel.rs index 93c0da7d2d..969683f9e7 100644 --- a/crates/burn-common/src/parallel.rs +++ b/crates/burn-common/src/parallel.rs @@ -1,51 +1,90 @@ /// Macro for running a function in parallel. +#[cfg(feature = "rayon")] #[macro_export(local_inner_macros)] macro_rules! run_par { ( $func:expr ) => {{ - #[cfg(feature = "rayon")] - use rayon::prelude::*; + use $crate::rayon::prelude::*; - #[cfg(feature = "rayon")] #[allow(clippy::redundant_closure_call)] - let output = rayon::scope(|_| $func()); + $crate::rayon::scope(|_| $func()) + }}; +} - #[cfg(not(feature = "rayon"))] - let output = $func(); +/// Macro for running a function in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! run_par { + ( + $func:expr + ) => {{ + $func() + }}; +} - output +/// Macro for iterating in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! iter_par { + ( + $iter:expr + ) => {{ + $iter }}; } /// Macro for iterating in parallel. +#[cfg(feature = "rayon")] #[macro_export(local_inner_macros)] macro_rules! iter_par { ( $iter:expr ) => {{ - #[cfg(feature = "rayon")] - let output = $iter.into_par_iter(); + $iter.into_par_iter() + }}; +} - #[cfg(not(feature = "rayon"))] - let output = $iter; +/// Macro for iterating in parallel. +#[cfg(feature = "rayon")] +#[macro_export(local_inner_macros)] +macro_rules! iter_slice_par { + ( + $slice:expr + ) => {{ + $slice.into_par_iter() + }}; +} - output +/// Macro for iterating in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! iter_slice_par { + ( + $slice:expr + ) => {{ + $slice.iter() }}; } /// Macro for iterating over a range in parallel. +#[cfg(feature = "rayon")] #[macro_export(local_inner_macros)] macro_rules! iter_range_par { ( $start:expr, $end:expr ) => {{ - #[cfg(feature = "rayon")] - let output = ($start..$end).into_par_iter(); - - #[cfg(not(feature = "rayon"))] - let output = ($start..$end); + ($start..$end).into_par_iter() + }}; +} - output +/// Macro for iterating over a range in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! iter_range_par { + ( + $start:expr, $end:expr + ) => {{ + ($start..$end) }}; } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e63af0fba5..423dc784d8 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -36,8 +36,8 @@ doc = [ "ndarray", "tch", "wgpu", - "cuda-jit", - "hip-jit", + "cuda", + "hip", "audio", "vision", "autodiff", @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"] ## Backend features accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] -autotune = ["burn-wgpu?/autotune"] +autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"] blas-netlib = ["burn-ndarray?/blas-netlib"] metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] @@ -100,12 +100,13 @@ template = ["burn-wgpu?/template"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] -cuda-jit = ["burn-cuda"] -hip-jit = ["burn-hip"] +cuda = ["burn-cuda"] +hip = ["burn-hip"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] wgpu = ["burn-wgpu"] -wgpu-spirv = ["wgpu", "burn-wgpu/spirv"] +vulkan = ["wgpu", "burn-wgpu/vulkan"] +webgpu = ["wgpu", "burn-wgpu/webgpu"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. record-item-custom-serde = ["thiserror", "regex"] @@ -113,37 +114,34 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -# Backwards compatibility with previous serialized data format. -record-backward-compat = [] - -test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. -test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. +test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray. +test-hip = ["hip"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. test-wgpu-spirv = [ "test-wgpu", - "wgpu-spirv", + "vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-dataset = { path = "../burn-dataset", version = "0.16.0", optional = true, default-features = false } -burn-derive = { path = "../burn-derive", version = "0.16.0" } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-dataset = { path = "../burn-dataset", version = "0.17.0", optional = true, default-features = false } +burn-derive = { path = "../burn-derive", version = "0.17.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } # Backends -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true } -burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true } -burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false } -burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false } -burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } -burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true } -burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } +burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false } +burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true, default-features = false } +burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true } +burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false } data-encoding = { workspace = true } uuid = { workspace = true } @@ -173,13 +171,13 @@ thiserror = { workspace = true, optional = true } portable-atomic-util = { workspace = true } [dev-dependencies] -burn-dataset = { path = "../burn-dataset", version = "0.16.0", features = [ +burn-dataset = { path = "../burn-dataset", version = "0.17.0", features = [ "fake", ] } tempfile = { workspace = true } -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0" } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index bd4c959302..31ac3a8c41 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -21,11 +21,17 @@ pub use burn_wgpu as wgpu; #[cfg(feature = "wgpu")] pub use burn_wgpu::Wgpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda as cuda_jit; +#[cfg(feature = "webgpu")] +pub use burn_wgpu::WebGpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda::Cuda as CudaJit; +#[cfg(feature = "vulkan")] +pub use burn_wgpu::Vulkan; + +#[cfg(feature = "cuda")] +pub use burn_cuda as cuda; + +#[cfg(feature = "cuda")] +pub use burn_cuda::Cuda; #[cfg(feature = "candle")] pub use burn_candle as candle; @@ -33,11 +39,11 @@ pub use burn_candle as candle; #[cfg(feature = "candle")] pub use burn_candle::Candle; -#[cfg(feature = "hip-jit")] -pub use burn_hip as hip_jit; +#[cfg(feature = "hip")] +pub use burn_hip as hip; -#[cfg(feature = "hip-jit")] -pub use burn_hip::Hip as HipJit; +#[cfg(feature = "hip")] +pub use burn_hip::Hip; #[cfg(feature = "tch")] pub use burn_tch as libtorch; diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index f554518430..ade8d64db7 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![recursion_limit = "135"] //! The core crate of Burn. diff --git a/crates/burn-core/src/nn/conv/checks.rs b/crates/burn-core/src/nn/conv/checks.rs index cd346163ad..36932621f1 100644 --- a/crates/burn-core/src/nn/conv/checks.rs +++ b/crates/burn-core/src/nn/conv/checks.rs @@ -9,3 +9,14 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize ); } } + +// https://github.com/tracel-ai/burn/issues/2676 +/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel +/// size is not supported as it will not produce the same output size. +pub(crate) fn check_same_padding_support(kernel_size: &[usize]) { + for k in kernel_size.iter() { + if k % 2 == 0 { + unimplemented!("Same padding with an even kernel size is not supported"); + } + } +} diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index 0b64eab324..c3f61a6b07 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -28,6 +28,10 @@ pub struct Conv1dConfig { #[config(default = "1")] pub groups: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// If bias should be added to the output. @@ -87,6 +91,9 @@ impl Conv1dConfig { /// Initialize a new [conv1d](Conv1d) module. pub fn init(&self, device: &B::Device) -> Conv1d { checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); + if self.padding == PaddingConfig1d::Same { + checks::check_same_padding_support(&[self.kernel_size]); + } let shape = [ self.channels_out, @@ -175,6 +182,14 @@ mod tests { .assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = Conv1dConfig::new(5, 5, 4).with_padding(PaddingConfig1d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = Conv1dConfig::new(5, 5, 5); diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index 73be36d357..72c00187be 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -30,6 +30,10 @@ pub struct Conv2dConfig { #[config(default = "1")] pub groups: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If bias should be added to the output. @@ -68,6 +72,9 @@ impl Conv2dConfig { /// Initialize a new [conv2d](Conv2d) module. pub fn init(&self, device: &B::Device) -> Conv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + if self.padding == PaddingConfig2d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], @@ -228,6 +235,14 @@ mod tests { let _ = config.init::(&device); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = Conv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = Conv2dConfig::new([5, 1], [5, 5]); diff --git a/crates/burn-core/src/nn/conv/conv3d.rs b/crates/burn-core/src/nn/conv/conv3d.rs index de7fb1ce2b..0b5d530c5a 100644 --- a/crates/burn-core/src/nn/conv/conv3d.rs +++ b/crates/burn-core/src/nn/conv/conv3d.rs @@ -68,6 +68,9 @@ impl Conv3dConfig { /// Initialize a new [conv3d](Conv3d) module. pub fn init(&self, device: &B::Device) -> Conv3d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + if self.padding == PaddingConfig3d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], @@ -228,6 +231,14 @@ mod tests { assert_eq!(config.initializer, init); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = Conv3dConfig::new([5, 1], [5, 5, 5]); diff --git a/crates/burn-core/src/nn/conv/deform_conv2d.rs b/crates/burn-core/src/nn/conv/deform_conv2d.rs index 03becd9d4e..2baff11d07 100644 --- a/crates/burn-core/src/nn/conv/deform_conv2d.rs +++ b/crates/burn-core/src/nn/conv/deform_conv2d.rs @@ -33,6 +33,10 @@ pub struct DeformConv2dConfig { #[config(default = "1")] pub offset_groups: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If bias should be added to the output. @@ -73,6 +77,9 @@ impl DeformConv2dConfig { /// Initialize a new [DeformConv2d](DeformConv2d) module. pub fn init(&self, device: &B::Device) -> DeformConv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups); + if self.padding == PaddingConfig2d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], @@ -250,6 +257,14 @@ mod tests { let _ = config.init::(&device); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = DeformConv2dConfig::new([5, 1], [5, 5]); diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index d03e95c1f3..79fc12ecbf 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -30,6 +30,12 @@ pub struct Dropout { impl DropoutConfig { /// Initialize a new [dropout](Dropout) module. pub fn init(&self) -> Dropout { + if self.prob < 0.0 || self.prob > 1.0 { + panic!( + "Dropout probability should be between 0 and 1, but got {}", + self.prob + ); + } Dropout { prob: self.prob } } } @@ -108,4 +114,11 @@ mod tests { assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}"); } + + #[test] + #[should_panic = "Dropout probability should be between 0 and 1,"] + fn dropout_prob_invalid() { + let config = DropoutConfig::new(-10.); + let _layer = config.init(); + } } diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index f645c84fd9..54b80f4f60 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -118,9 +118,9 @@ impl BinaryCrossEntropyLoss { (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits) } else { // - (target * log(input) + (1 - target) * log(1 - input)) - (targets_float.clone() * logits.clone().log() - + (targets_float.neg() + 1.) * (logits.neg() + 1.).log()) - .neg() + // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values + (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0) + - targets_float * logits.log().clamp_min(-100.0) }; if let Some(weights) = &self.weights { @@ -171,6 +171,38 @@ mod tests { use crate::tensor::{activation::sigmoid, TensorData}; use crate::TestBackend; + #[test] + fn test_binary_cross_entropy_preds_all_correct() { + let device = Default::default(); + let preds = Tensor::::from_floats([1.0, 0.0, 1.0, 0.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([0.000]); + loss_actual.assert_approx_eq(&loss_expected, 3); + } + + #[test] + fn test_binary_cross_entropy_preds_all_incorrect() { + let device = Default::default(); + let preds = Tensor::::from_floats([0.0, 1.0, 0.0, 1.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([100.000]); // clamped value + loss_actual.assert_approx_eq(&loss_expected, 3); + } + #[test] fn test_binary_cross_entropy() { // import torch diff --git a/crates/burn-core/src/nn/loss/mod.rs b/crates/burn-core/src/nn/loss/mod.rs index cca7b4541b..475364e63b 100644 --- a/crates/burn-core/src/nn/loss/mod.rs +++ b/crates/burn-core/src/nn/loss/mod.rs @@ -2,10 +2,12 @@ mod binary_cross_entropy; mod cross_entropy; mod huber; mod mse; +mod poisson; mod reduction; pub use binary_cross_entropy::*; pub use cross_entropy::*; pub use huber::*; pub use mse::*; +pub use poisson::*; pub use reduction::*; diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs new file mode 100644 index 0000000000..3cc989ad8e --- /dev/null +++ b/crates/burn-core/src/nn/loss/poisson.rs @@ -0,0 +1,390 @@ +use core::f32::consts::PI; + +use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; +use crate::tensor::backend::Backend; +use crate::tensor::Tensor; +use crate::{config::Config, module::Module}; + +use super::Reduction; + +/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance. +/// +/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss +/// behavior, such as whether the input is in log-space, whether to include the Stirling +/// approximation term, and a small epsilon value to avoid numerical instability. +#[derive(Config, Debug)] +pub struct PoissonNllLossConfig { + /// If `true`, the predictions are expected to be in log-space. + /// + /// When `log_input` is `true`, the loss is computed as: + /// ```text + /// L(predictions, target) = exp(predictions) - target * predictions + /// ``` + /// When `log_input` is `false`, the loss is computed as: + /// ```text + /// L(predictions, target) = predictions - target * log(predictions + eps) + /// ``` + #[config(default = true)] + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + /// + /// When `full` is `true`, the Stirling approximation term is added to the loss: + /// ```text + /// target * log(target) - target + 0.5 * log(2 * PI * target) + /// ``` + #[config(default = false)] + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + /// + /// This epsilon value is added to the predictions to ensure numerical stability + /// when computing the logarithm. + #[config(default = 1e-8)] + pub eps: f64, +} + +impl PoissonNllLossConfig { + /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + pub fn init(&self) -> PoissonNllLoss { + self.assertions(); + PoissonNllLoss { + log_input: self.log_input, + full: self.full, + eps: self.eps, + } + } + + /// Validates the configuration parameters. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + fn assertions(&self) { + assert!( + self.eps > 0., + "eps for PoissonNllLoss must be a positive number." + ); + } +} + +/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target. +/// +/// This loss function is used when the target values are assumed to follow a Poisson distribution. +/// The loss is defined as: +/// ```text +/// target ~ Poisson(input) +/// L(predictions, target) = predictions - target * log(predictions) + log(target!) +/// ``` +/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula. +/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss. +/// +/// For more details, see: +/// +#[derive(Module, Debug, Clone)] +#[module(custom_display)] +pub struct PoissonNllLoss { + /// If `true`, the predictions are expected to be in log-space. + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + pub eps: f64, +} + +impl ModuleDisplay for PoissonNllLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("log_input", &self.log_input) + .add("full", &self.full) + .add("eps", &self.eps) + .optional() + } +} + +impl PoissonNllLoss { + /// Computes the loss element-wise for the given predictions and targets, then reduces + /// the result to a single loss value. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[1]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward( + &self, + predictions: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let loss = self.forward_no_reduction(predictions, targets); + match reduction { + Reduction::Mean | Reduction::Auto => loss.mean(), + Reduction::Sum => loss.sum(), + } + } + + /// Computes the loss element-wise for the given predictions and targets without reduction. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[...dims]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward_no_reduction( + &self, + predictions: Tensor, + targets: Tensor, + ) -> Tensor { + self.assertions(&predictions, &targets); + let mut loss; + if self.log_input { + loss = predictions.clone().exp() - targets.clone() * predictions; + } else { + loss = predictions.clone() - targets.clone() * (predictions + self.eps).log(); + } + if self.full { + let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone() + + (targets.clone() * 2. * PI).log() * 0.5; + loss = loss + + log_stirling_term + .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like()); + } + loss + } + + /// Validates the input tensors for the loss computation. + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + fn assertions( + &self, + predictions: &Tensor, + targets: &Tensor, + ) { + let predictions_dims = predictions.dims(); + let targets_dims = targets.dims(); + assert!( + predictions_dims == targets_dims, + "Shape of targets ({:?}) should correspond to outer shape of predictions ({:?}).", + targets_dims, + predictions_dims + ); + assert!( + targets.clone().greater_equal_elem(0.).all().into_scalar(), + "All the values of `targets` must be non-negative." + ); + if !self.log_input { + assert!( + predictions.clone().greater_equal_elem(0.).all().into_scalar(), + "When `log_input` is `false`, all the values of `predictions` must be non-negative." + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::TensorData; + use crate::TestBackend; + type TestTensor = Tensor; + + #[test] + fn test_poisson_nll_loss() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.0321]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([126.1929]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_no_log_input() { + let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]); + let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + + let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_full() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.9920]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([131.9518]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[cfg(feature = "std")] + #[test] + fn test_poisson_nll_loss_gradients() { + type TestAutodiffTensor = Tensor; + + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad(); + let predictions2 = predictions1.clone(); + let targets = TestAutodiffTensor::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(false).init(); + let poisson_full = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum); + let loss_full_sum = + poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum); + + let grads = loss_sum.backward(); + let grads_full = loss_full_sum.backward(); + + let grads_predictions1 = predictions1.grad(&grads).unwrap(); + let grads_predictions2 = predictions2.grad(&grads_full).unwrap(); + + let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]); + + grads_predictions1 + .into_data() + .assert_approx_eq(&expected, 5); + grads_predictions2 + .into_data() + .assert_approx_eq(&expected, 5); + } + + #[test] + #[should_panic = "eps for PoissonNllLoss must be a positive number."] + fn test_negative_eps() { + let _poisson = PoissonNllLossConfig::new().with_eps(0.).init(); + } + + #[test] + #[should_panic = "All the values of `targets` must be non-negative."] + fn test_targets_with_negative_values() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + } + + #[test] + #[should_panic = "Shape of targets"] + fn test_shape_tensors() { + let predictions = TensorData::from([0., 1., 2.]); + let targets = TensorData::from([0., 1.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."] + fn test_exp_predictions_non_negative() { + let predictions = TensorData::from([0.3, -0.1, 0.4]); + let targets = TensorData::from([0., 1., 0.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + fn display() { + let config = PoissonNllLossConfig::new(); + let loss = config.init(); + + assert_eq!( + alloc::format!("{}", loss), + "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index 949160fd5b..24ec8ff972 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct AvgPool1dConfig { #[config(default = "1")] pub stride: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// If the padding is counted in the denominator when computing the average. @@ -36,10 +41,6 @@ pub struct AvgPool1dConfig { /// legitimate values, and they contribute to the denominator /// when calculating the average. This is equivalent to /// `torch.nn.AvgPool2d` with `count_include_pad=True`. -/// -/// TODO: Add support for `count_include_pad=False`, see -/// [Issue 636](https://github.com/tracel-ai/burn/issues/636) - #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool1d { @@ -73,6 +74,9 @@ impl ModuleDisplay for AvgPool1d { impl AvgPool1dConfig { /// Initialize a new [avg pool 1d](AvgPool1d) module. pub fn init(&self) -> AvgPool1d { + if self.padding == PaddingConfig1d::Same { + check_same_padding_support(&[self.kernel_size]); + } AvgPool1d { stride: self.stride, kernel_size: self.kernel_size, @@ -111,6 +115,13 @@ impl AvgPool1d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = AvgPool1dConfig::new(2).with_padding(PaddingConfig1d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = AvgPool1dConfig::new(3); diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index 6c6ffc87ed..343d59922b 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct AvgPool2dConfig { #[config(default = "[1, 1]")] pub strides: [usize; 2], /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If the padding is counted in the denominator when computing the average. @@ -36,9 +41,6 @@ pub struct AvgPool2dConfig { /// legitimate values, and they contribute to the denominator /// when calculating the average. This is equivalent to /// `torch.nn.AvgPool2d` with `count_include_pad=True`. -/// -/// TODO: Add support for `count_include_pad=False`, see -/// [Issue 636](https://github.com/tracel-ai/burn/issues/636) #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool2d { @@ -72,6 +74,9 @@ impl ModuleDisplay for AvgPool2d { impl AvgPool2dConfig { /// Initialize a new [avg pool 2d](AvgPool2d) module. pub fn init(&self) -> AvgPool2d { + if self.padding == PaddingConfig2d::Same { + check_same_padding_support(&self.kernel_size); + } AvgPool2d { stride: self.strides, kernel_size: self.kernel_size, @@ -110,6 +115,13 @@ impl AvgPool2d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = AvgPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = AvgPool2dConfig::new([3, 3]); diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index 5be363e908..71041e6155 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct MaxPool1dConfig { #[config(default = "1")] pub stride: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// The dilation. @@ -61,6 +66,9 @@ impl ModuleDisplay for MaxPool1d { impl MaxPool1dConfig { /// Initialize a new [max pool 1d](MaxPool1d) module. pub fn init(&self) -> MaxPool1d { + if self.padding == PaddingConfig1d::Same { + check_same_padding_support(&[self.kernel_size]); + } MaxPool1d { stride: self.stride, kernel_size: self.kernel_size, @@ -93,6 +101,13 @@ impl MaxPool1d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = MaxPool1dConfig::new(2).with_padding(PaddingConfig1d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = MaxPool1dConfig::new(3); diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index ab9c60d276..3eb94f5db5 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct MaxPool2dConfig { #[config(default = "[1, 1]")] pub strides: [usize; 2], /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// The dilation. @@ -61,6 +66,9 @@ impl ModuleDisplay for MaxPool2d { impl MaxPool2dConfig { /// Initialize a new [max pool 2d](MaxPool2d) module. pub fn init(&self) -> MaxPool2d { + if self.padding == PaddingConfig2d::Same { + check_same_padding_support(&self.kernel_size); + } MaxPool2d { stride: self.strides, kernel_size: self.kernel_size, @@ -93,6 +101,13 @@ impl MaxPool2d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = MaxPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = MaxPool2dConfig::new([3, 3]); diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs index c66ad631b6..e2f8b2425e 100644 --- a/crates/burn-core/src/nn/rnn/gru.rs +++ b/crates/burn-core/src/nn/rnn/gru.rs @@ -20,6 +20,21 @@ pub struct GruConfig { pub d_hidden: usize, /// If a bias should be applied during the Gru transformation. pub bias: bool, + /// If reset gate should be applied after weight multiplication. + /// + /// This configuration option controls how the reset gate is applied to the hidden state. + /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for + /// Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by + /// the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU). + /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine + /// Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication. + /// + /// The differing implementations can give slightly different numerical results and have different efficiencies. For more + /// motivation for why the `true` can be more efficient see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs). + /// + /// To set this field to `false` use [`with_reset_after`](`GruConfig::with_reset_after`). + #[config(default = "true")] + pub reset_after: bool, /// Gru initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, @@ -41,6 +56,8 @@ pub struct Gru { pub new_gate: GateController, /// The size of the hidden state. pub d_hidden: usize, + /// If reset gate should be applied after weight multiplication. + pub reset_after: bool, } impl ModuleDisplay for Gru { @@ -58,6 +75,7 @@ impl ModuleDisplay for Gru { .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) + .add("reset_after", &self.reset_after) .optional() } } @@ -94,86 +112,92 @@ impl GruConfig { reset_gate, new_gate, d_hidden: self.d_hidden, + reset_after: self.reset_after, } } } impl Gru { /// Applies the forward pass on the input tensor. This GRU implementation - /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. + /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`. /// - /// # Shapes + /// # Parameters /// - batched_input: `[batch_size, sequence_length, input_size]`. - /// - state: An optional tensor representing an initial cell state with the same dimensions - /// as batched_input. If none is provided, one will be generated. - /// - output: `[batch_size, sequence_length, hidden_size]`. + /// - state: An optional tensor representing an initial cell state with dimensions + /// `[batch_size, hidden_size]`. If none is provided, an empty state will be used. + /// + /// # Returns + /// - output: `[batch_size, sequence_length, hidden_size]` pub fn forward( &self, batched_input: Tensor, - state: Option>, + state: Option>, ) -> Tensor { + let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); - let mut hidden_state = match state { + let mut batched_hidden_state = + Tensor::empty([batch_size, seq_length, self.d_hidden], &device); + + let mut hidden_t = match state { Some(state) => state, - None => Tensor::zeros( - [batch_size, seq_length, self.d_hidden], - &batched_input.device(), - ), + None => Tensor::zeros([batch_size, self.d_hidden], &device), }; - for (t, (input_t, hidden_t)) in batched_input - .iter_dim(1) - .zip(hidden_state.clone().iter_dim(1)) - .enumerate() - { + for (t, input_t) in batched_input.iter_dim(1).enumerate() { let input_t = input_t.squeeze(1); - let hidden_t = hidden_t.squeeze(1); // u(pdate)g(ate) tensors - let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); + let biased_ug_input_sum = + self.gate_product(&input_t, &hidden_t, None, &self.update_gate); let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) // r(eset)g(ate) tensors - let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); + let biased_rg_input_sum = + self.gate_product(&input_t, &hidden_t, None, &self.reset_gate); let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) - let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate // n(ew)g(ate) tensor - let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); + let biased_ng_input_sum = if self.reset_after { + self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate) + } else { + let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate + self.gate_product(&input_t, &reset_t, None, &self.new_gate) + }; let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) // calculate linear interpolation between previous hidden state and candidate state: // g(t) * (1 - z(t)) + z(t) * hidden_t - let state_vector = candidate_state + hidden_t = candidate_state .clone() .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1) + update_values.clone().mul(hidden_t); - let current_shape = state_vector.shape().dims; - let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; - let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); - hidden_state = hidden_state.slice_assign( + let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1); + + batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], - reshaped_state_vector, + unsqueezed_hidden_state, ); } - hidden_state + batched_hidden_state } /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. + /// bias, if any, and optionally applies reset to hidden state. /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Mathematically, performs `Wx*X + r .* (Wh*H + b)`, where: /// Wx = weight matrix for the connection to input vector X /// Wh = weight matrix for the connection to hidden state H /// X = input vector /// H = hidden state /// b = bias terms + /// r = reset state fn gate_product( &self, input: &Tensor, hidden: &Tensor, + reset: Option<&Tensor>, gate: &GateController, ) -> Tensor { let input_product = input.clone().matmul(gate.input_transform.weight.val()); @@ -190,13 +214,29 @@ impl Gru { .as_ref() .map(|bias_param| bias_param.val()); - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { + match (input_bias, hidden_bias, reset) { + (Some(input_bias), Some(hidden_bias), Some(r)) => { + input_product + + input_bias.unsqueeze() + + r.clone().mul(hidden_product + hidden_bias.unsqueeze()) + } + (Some(input_bias), Some(hidden_bias), None) => { input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, + (Some(input_bias), None, Some(r)) => { + input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product) + } + (Some(input_bias), None, None) => { + input_product + input_bias.unsqueeze() + hidden_product + } + (None, Some(hidden_bias), Some(r)) => { + input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze()) + } + (None, Some(hidden_bias), None) => { + input_product + hidden_product + hidden_bias.unsqueeze() + } + (None, None, Some(r)) => input_product + r.clone().mul(hidden_product), + (None, None, None) => input_product + hidden_product, } } } @@ -207,29 +247,16 @@ mod tests { use crate::tensor::{Distribution, TensorData}; use crate::{module::Param, nn::LinearRecord, TestBackend}; - /// Test forward pass with simple input vector. - /// - /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 - /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 - /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 - /// - /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 - #[test] - fn tests_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = GruConfig::new(1, 1, false); - let device = Default::default(); - let mut gru = config.init::(&device); - - fn create_gate_controller( + fn init_gru(reset_after: bool, device: &B::Device) -> Gru { + fn create_gate_controller( weights: f32, biases: f32, d_input: usize, d_output: usize, bias: bool, initializer: Initializer, - device: &::Device, - ) -> GateController { + device: &B::Device, + ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), @@ -248,6 +275,9 @@ mod tests { ) } + let config = GruConfig::new(1, 1, false).with_reset_after(reset_after); + let mut gru = config.init::(device); + gru.update_gate = create_gate_controller( 0.5, 0.0, @@ -255,7 +285,7 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); gru.reset_gate = create_gate_controller( 0.6, @@ -264,7 +294,7 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); gru.new_gate = create_gate_controller( 0.7, @@ -273,18 +303,72 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); + gru + } + + /// Test forward pass with simple input vector. + /// + /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 + /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 + /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 + /// + /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 + #[test] + fn tests_forward_single_input_single_feature() { + TestBackend::seed(0); + let device = Default::default(); + let mut gru = init_gru::(false, &device); let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device); + let expected = TensorData::from([[0.034]]); + // Reset gate applied to hidden state before the matrix multiplication + let state = gru.forward(input.clone(), None); + + let output = state + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + + // Reset gate applied to hidden state after the matrix multiplication + gru.reset_after = true; // override forward behavior + let state = gru.forward(input, None); + + let output = state + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn tests_forward_seq_len_3() { + TestBackend::seed(0); + let device = Default::default(); + let mut gru = init_gru::(true, &device); + + let input = + Tensor::::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device); + let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]); + + let result = gru.forward(input.clone(), None); + let output = result + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + + // Reset gate applied to hidden state before the matrix multiplication + gru.reset_after = false; // override forward behavior let state = gru.forward(input, None); let output = state .select(0, Tensor::arange(0..1, &device)) .squeeze::<2>(0); - let expected = TensorData::from([[0.034]]); output.to_data().assert_approx_eq(&expected, 3); } @@ -308,7 +392,7 @@ mod tests { assert_eq!( alloc::format!("{}", layer), - "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}" + "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}" ); } } diff --git a/crates/burn-core/src/record/primitive.rs b/crates/burn-core/src/record/primitive.rs index 9dd921e824..2f9fa3e83c 100644 --- a/crates/burn-core/src/record/primitive.rs +++ b/crates/burn-core/src/record/primitive.rs @@ -5,9 +5,7 @@ use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde}; use super::{PrecisionSettings, Record}; use crate::module::{Param, ParamId}; -#[allow(deprecated)] -use burn_tensor::DataSerialize; -use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor}; +use burn_tensor::{backend::Backend, Bool, Int, Tensor}; use hashbrown::HashMap; use serde::{ @@ -143,23 +141,6 @@ where } } -#[allow(deprecated)] -impl Record for DataSerialize -where - E: Element, - B: Backend, -{ - type Item = DataSerialize; - - fn into_item(self) -> Self::Item { - self.convert() - } - - fn from_item(item: Self::Item, _device: &B::Device) -> Self { - item.convert() - } -} - /// (De)serialize parameters into a clean format. #[derive(new, Debug, Clone, Serialize, Deserialize)] pub struct ParamSerde { diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs index ab6f448b7e..a07453bcba 100644 --- a/crates/burn-core/src/record/tensor.rs +++ b/crates/burn-core/src/record/tensor.rs @@ -4,20 +4,7 @@ use super::{PrecisionSettings, Record}; use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData}; use serde::{Deserialize, Serialize}; -#[cfg(not(feature = "record-backward-compat"))] use alloc::format; -#[cfg(feature = "record-backward-compat")] -use burn_tensor::DataSerialize; - -/// Versioned serde data deserialization to maintain backward compatibility between formats. -#[cfg(feature = "record-backward-compat")] -#[allow(deprecated)] -#[derive(Serialize, Deserialize)] -#[serde(untagged)] -enum TensorDataSerde { - V1(DataSerialize), - V2(TensorData), -} /// Deserialize the value into [`TensorData`]. fn deserialize_data<'de, E, De>(deserializer: De) -> Result @@ -25,31 +12,18 @@ where E: Element + Deserialize<'de>, De: serde::Deserializer<'de>, { - #[cfg(feature = "record-backward-compat")] - { - let data = match TensorDataSerde::::deserialize(deserializer)? { - TensorDataSerde::V1(data) => data.into_tensor_data(), - // NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16 - TensorDataSerde::V2(data) => data.convert::(), - }; - Ok(data) - } - - #[cfg(not(feature = "record-backward-compat"))] - { - let data = TensorData::deserialize(deserializer).map_err(|e| { - serde::de::Error::custom(format!( - "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag. Once you have saved the record in the new format, you can disable the feature flag.\n", - e - )) - })?; - let data = if let DType::QFloat(_) = data.dtype { - data // do not convert quantized tensors - } else { - data.convert::() - }; - Ok(data) - } + let data = TensorData::deserialize(deserializer).map_err(|e| { + serde::de::Error::custom(format!( + "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n", + e + )) + })?; + let data = if let DType::QFloat(_) = data.dtype { + data // do not convert quantized tensors + } else { + data.convert::() + }; + Ok(data) } /// This struct implements serde to lazily serialize and deserialize a float tensor diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index c366386b0e..1a92e695b2 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -19,9 +19,9 @@ fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] [dependencies] -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = [ "cubecl-cuda", ] } cubecl = { workspace = true, features = ["cuda"] } @@ -34,7 +34,7 @@ log = { workspace = true } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } paste = { workspace = true } diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml index 0237765973..c7ddbebc41 100644 --- a/crates/burn-dataset/Cargo.toml +++ b/crates/burn-dataset/Cargo.toml @@ -30,7 +30,7 @@ __sqlite-shared = [ dataframe = ["dep:polars"] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0", optional = true, features = [ +burn-common = { path = "../burn-common", version = "0.17.0", optional = true, features = [ "network", ] } csv = { workspace = true } diff --git a/crates/burn-dataset/src/dataset/dataframe.rs b/crates/burn-dataset/src/dataset/dataframe.rs index 023b357454..c851e8a3e3 100644 --- a/crates/burn-dataset/src/dataset/dataframe.rs +++ b/crates/burn-dataset/src/dataset/dataframe.rs @@ -269,20 +269,20 @@ mod tests { } fn create_test_dataframe() -> DataFrame { - let s0 = Column::Series(Series::new("int32".into(), &[1i32, 2i32, 3i32])); - let s1 = Column::Series(Series::new("bool".into(), &[true, false, true])); - let s2 = Column::Series(Series::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64])); - let s3 = Column::Series(Series::new("string".into(), &["Boo", "Boo2", "Boo3"])); - let s6 = Column::Series(Series::new("int16".into(), &[1i16, 2i16, 3i16])); - let s8 = Column::Series(Series::new("uint32".into(), &[1u32, 2u32, 3u32])); - let s9 = Column::Series(Series::new("uint64".into(), &[1u64, 2u64, 3u64])); - let s10 = Column::Series(Series::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32])); - let s11 = Column::Series(Series::new("int64".into(), &[1i64, 2i64, 3i64])); - let s12 = Column::Series(Series::new("int8".into(), &[1i8, 2i8, 3i8])); + let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]); + let s1 = Column::new("bool".into(), &[true, false, true]); + let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]); + let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]); + let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]); + let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]); + let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]); + let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]); + let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]); + let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]); let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]; - let s13 = Column::Series(Series::new("binary".into(), binary_data)); + let s13 = Column::new("binary".into(), binary_data); DataFrame::new(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap() } diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml index eb4296097b..1f2f785940 100644 --- a/crates/burn-fusion/Cargo.toml +++ b/crates/burn-fusion/Cargo.toml @@ -17,8 +17,8 @@ std = ["serde/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0" } -burn-common = { path = "../burn-common", version = "0.16.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } +burn-common = { path = "../burn-common", version = "0.17.0" } hashbrown = { workspace = true } derive-new = {workspace = true } spin = { workspace = true } diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index baa5169db3..658907bf3e 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,5 +1,6 @@ use burn_tensor::{ ops::{binary_ops_shape, FloatTensor, IntTensor}, + repr::{FromDataOperationDescription, TensorDescription}, DType, Element, TensorData, }; use std::marker::PhantomData; @@ -24,15 +25,32 @@ use burn_tensor::{ impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_bool_tensor::(&self.desc.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::bool_empty(shape.clone(), device); + let out = client.tensor_uninitialized(shape.dims.clone(), DType::Bool); - client.register_tensor( - B::bool_tensor_handle(tensor), - shape.dims, - StreamId::current(), - DType::Bool, - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } async fn bool_into_data(tensor: BoolTensor) -> TensorData { @@ -40,16 +58,35 @@ impl BoolTensorOps for Fusion { } fn bool_from_data(data: burn_tensor::TensorData, device: &Device) -> BoolTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::bool_from_data(self.desc.data, &self.device); + handles.register_bool_tensor::(&self.desc.out.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::bool_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::bool_tensor_handle(tensor), - shape.dims, - StreamId::current(), - DType::Bool, - ) + let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 41d691272f..afbe6ae77c 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -16,16 +16,35 @@ use std::{marker::PhantomData, ops::Range}; impl FloatTensorOps for Fusion { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::float_from_data(self.desc.data, &self.device); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::float_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::float_tensor_handle(tensor), - shape.dims, - StreamId::current(), - B::FloatElem::dtype(), - ) + let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype()); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn float_random( @@ -233,16 +252,32 @@ impl FloatTensorOps for Fusion { } fn float_empty(shape: Shape, device: &Device) -> FloatTensor { - let client = get_client::(&device.clone()); + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::float_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_float_tensor::(&self.desc.id, output); + } + } + let stream = StreamId::current(); - let tensor = B::float_empty(shape.clone(), device); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(shape.dims.clone(), B::FloatElem::dtype()); - client.register_tensor( - B::float_tensor_handle(tensor), - shape.dims, - stream, - B::FloatElem::dtype(), - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { @@ -278,9 +313,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -323,7 +356,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = ClampOperationDescription { tensor: tensor.into_description(), @@ -375,9 +408,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), rhs: rhs.elem(), @@ -428,9 +459,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -481,9 +510,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -534,9 +561,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -567,9 +592,7 @@ impl FloatTensorOps for Fusion { shape[ndims - 2] = lhs.shape[ndims - 2]; shape[ndims - 1] = rhs.shape[ndims - 1]; - let out = lhs - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(shape, dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), rhs: rhs.into_description(), @@ -601,13 +624,12 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let mut out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let mut out = tensor.client.tensor_uninitialized(shape, dtype); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -641,9 +663,8 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(shape.dims, B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(shape.dims, dtype); let desc = ReshapeDescription { input: tensor.into_description(), @@ -1300,9 +1321,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let dtype = tensor.dtype; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1327,9 +1346,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1352,9 +1369,8 @@ impl FloatTensorOps for Fusion { unary_float_ops!(ProdOps, B::float_prod, reduce); let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1363,7 +1379,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Prod(desc.clone()), ), ProdOps::::new(desc), @@ -1376,11 +1392,10 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(ProdDimOps, B::float_prod_dim, usize, noconvert); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1404,9 +1419,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let dtype = tensor.dtype; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1431,9 +1444,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1716,6 +1727,7 @@ impl FloatTensorOps for Fusion { } let tensor_first = tensors.first().unwrap(); + let dtype = tensor_first.dtype; let client = tensor_first.client.clone(); // Calculate the output shape @@ -1726,7 +1738,7 @@ impl FloatTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = client.tensor_uninitialized(shape, dtype); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index d3a2492a03..c6ea03e759 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -15,16 +15,32 @@ use std::marker::PhantomData; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { - let client = get_client::(&device.clone()); - let tensor = B::int_empty(shape.clone(), device); + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::int_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_int_tensor::(&self.desc.id, output); + } + } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(shape.dims.clone(), B::IntElem::dtype()); - client.register_tensor( - B::int_tensor_handle(tensor), - shape.dims, - stream, - B::IntElem::dtype(), - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } async fn int_into_data(tensor: IntTensor) -> TensorData { @@ -32,17 +48,35 @@ impl IntTensorOps for Fusion { } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { - let client = get_client::(&device.clone()); - let tensor = B::int_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::int_from_data(self.desc.data, &self.device); + handles.register_int_tensor::(&self.desc.out.id, output); + } + } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype()); - client.register_tensor( - B::int_tensor_handle(tensor), - shape.dims, - stream, - B::IntElem::dtype(), - ) + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn int_device(tensor: &IntTensor) -> Device { @@ -1820,6 +1854,269 @@ impl IntTensorOps for Fusion { out } + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseAndOps, B::bitwise_and); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseAnd(desc.clone())), + BitwiseAndOps::::new(desc), + ); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseAndScalar( + desc.clone(), + )), + BitwiseAndOps::::new(desc), + ); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseOrOps, B::bitwise_or); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOr(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOrScalar(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseXorOps, B::bitwise_xor); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXor(desc.clone())), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXorScalar( + desc.clone(), + )), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_int_ops!(BitwiseNotOps, B::bitwise_not); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseNot(desc.clone())), + BitwiseNotOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShift( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShiftScalar( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShift( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShiftScalar( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { scalar_int_ops!(CumsumOps, B::int_cumsum, usize, noconvert); diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 41bc7ccde6..1449a485af 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -4,8 +4,9 @@ use burn_tensor::{ ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, repr::{ - DequantizeOperationDescription, FloatOperationDescription, HandleContainer, - OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, + BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription, + FromDataOperationDescription, HandleContainer, OperationDescription, + QuantizationParametersDescription, QuantizeOperationDescription, }, DType, Device, Element, Shape, TensorData, }; @@ -19,19 +20,41 @@ use crate::{ impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::q_from_data(self.desc.data, &self.device); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + match data.dtype { DType::QFloat(_scheme) => { let dtype = data.dtype; - let client = get_client::(device); - let tensor = B::q_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::quantized_tensor_handle(tensor), - shape.dims, - StreamId::current(), - dtype, - ) + + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(data.shape.clone(), dtype); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::FromData( + desc.clone(), + )), + FromDataOps::::new(desc, device.clone()), + ); + + out } _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 477cf737f2..28562854b4 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -59,6 +59,84 @@ pub(crate) struct OperationConverter { scalar_u8: Vec, } +/// Fork of a [context](Context) which owns its data. +pub struct ContextOwned { + tensors: HashMap, + handles: HandleContainer, + scalar_f32: Vec, + scalar_f16: Vec, + scalar_bf16: Vec, + scalar_i64: Vec, + scalar_i32: Vec, + scalar_i16: Vec, + scalar_i8: Vec, + scalar_u64: Vec, + scalar_u32: Vec, + scalar_u16: Vec, + scalar_u8: Vec, +} + +impl ContextOwned { + /// Convert into [context](Context). + pub fn as_context(&mut self) -> Context<'_, H> { + Context { + tensors: &mut self.tensors, + handles: &mut self.handles, + scalar_f32: &self.scalar_f32, + scalar_f16: &self.scalar_f16, + scalar_bf16: &self.scalar_bf16, + scalar_i64: &self.scalar_i64, + scalar_i32: &self.scalar_i32, + scalar_i16: &self.scalar_i16, + scalar_i8: &self.scalar_i8, + scalar_u64: &self.scalar_u64, + scalar_u32: &self.scalar_u32, + scalar_u16: &self.scalar_u16, + scalar_u8: &self.scalar_u8, + } + } + + /// Fork the context again. + pub fn fork(&self) -> ContextOwned { + ContextOwned { + tensors: self.tensors.clone(), + handles: self.handles.fork(), + scalar_f32: self.scalar_f32.clone(), + scalar_f16: self.scalar_f16.clone(), + scalar_bf16: self.scalar_bf16.clone(), + scalar_i64: self.scalar_i64.clone(), + scalar_i32: self.scalar_i32.clone(), + scalar_i16: self.scalar_i16.clone(), + scalar_i8: self.scalar_i8.clone(), + scalar_u64: self.scalar_u64.clone(), + scalar_u32: self.scalar_u32.clone(), + scalar_u16: self.scalar_u16.clone(), + scalar_u8: self.scalar_u8.clone(), + } + } +} + +impl Context<'_, H> { + /// Fork the context into an [owned context](ContextOwned). + pub fn fork(&self) -> ContextOwned { + ContextOwned { + tensors: self.tensors.clone(), + handles: self.handles.fork(), + scalar_f32: self.scalar_f32.clone(), + scalar_f16: self.scalar_f16.clone(), + scalar_bf16: self.scalar_bf16.clone(), + scalar_i64: self.scalar_i64.clone(), + scalar_i32: self.scalar_i32.clone(), + scalar_i16: self.scalar_i16.clone(), + scalar_i8: self.scalar_i8.clone(), + scalar_u64: self.scalar_u64.clone(), + scalar_u32: self.scalar_u32.clone(), + scalar_u16: self.scalar_u16.clone(), + scalar_u8: self.scalar_u8.clone(), + } + } +} + pub(crate) trait RelativeOps { /// Convert (usually an [`OperationDescription`]) to a relative form. /// @@ -616,6 +694,82 @@ impl RelativeOps for IntOperationDescription { out: desc.out.to_relative(converter), }) } + IntOperationDescription::BitwiseAnd(desc) => { + IntOperationDescription::BitwiseAnd(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + IntOperationDescription::BitwiseAndScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOr(desc) => { + IntOperationDescription::BitwiseOr(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + IntOperationDescription::BitwiseOrScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXor(desc) => { + IntOperationDescription::BitwiseXor(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + IntOperationDescription::BitwiseXorScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseNot(desc) => { + IntOperationDescription::BitwiseNot(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + IntOperationDescription::BitwiseLeftShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + IntOperationDescription::BitwiseLeftShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShift(desc) => { + IntOperationDescription::BitwiseRightShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + IntOperationDescription::BitwiseRightShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } } } } @@ -1063,6 +1217,12 @@ impl RelativeOps for BaseOperationDescription { BaseOperationDescription::Empty(desc) => { BaseOperationDescription::Empty(desc.to_relative(converter)) } + BaseOperationDescription::FromData(desc) => { + BaseOperationDescription::FromData(FromDataOperationDescription { + data: desc.data.clone(), + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-hip/Cargo.toml b/crates/burn-hip/Cargo.toml index d5f0bb70f5..206f56e8fe 100644 --- a/crates/burn-hip/Cargo.toml +++ b/crates/burn-hip/Cargo.toml @@ -20,9 +20,9 @@ std = ["burn-jit/std", "cubecl/std"] [dependencies] cubecl = { workspace = true, features = ["hip"] } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = ["cubecl-hip"] } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = ["cubecl-hip"] } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } half = { workspace = true } bytemuck = { workspace = true } @@ -31,7 +31,7 @@ log = { workspace = true } derive-new = { workspace = true } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } paste = { workspace = true } diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index fc8f704e74..13f5239637 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -26,7 +26,8 @@ pub type Hip = burn_fusion::Fusion, some_list_str: Vec, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_list_float: Vec, + some_list_float: Vec, some_dict: HashMap, } @@ -35,17 +31,13 @@ mod tests { n_head: 2, n_layer: 3, d_model: 512, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_float: 0.1, + some_float: 0.1, some_int: 1, some_bool: true, some_str: "hello".to_string(), some_list_int: vec![1, 2, 3], some_list_str: vec!["hello".to_string(), "world".to_string()], - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_list_float: vec![0.1, 0.2, 0.3], + some_list_float: vec![0.1, 0.2, 0.3], some_dict: { let mut map = HashMap::new(); map.insert("some_key".to_string(), "some_value".to_string()); diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index 798636c323..7f511dafd4 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -5,8 +5,9 @@ use burn::nn::PaddingConfig1d; use burn::nn::PaddingConfig2d; use burn::nn::PaddingConfig3d; -fn convert_primitive(primitive: T) -> TokenStream { - let value = primitive.to_string(); +fn convert_primitive(primitive: T) -> TokenStream { + let value = format!("{:?}", primitive); + value.parse().unwrap() } diff --git a/crates/burn-import/src/burn/graph.rs b/crates/burn-import/src/burn/graph.rs index f6dee5479d..2411a60e27 100644 --- a/crates/burn-import/src/burn/graph.rs +++ b/crates/burn-import/src/burn/graph.rs @@ -50,7 +50,7 @@ pub struct BurnGraph { } // The backend used for recording. -type Backend = burn::backend::ndarray::NdArray; +type Backend = burn_ndarray::NdArray; impl BurnGraph { /// Register a new operation node into the graph. diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index f945cb0dce..480d3a1f1c 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -17,13 +17,12 @@ use super::{ unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; -use burn::backend::NdArray; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use serde::Serialize; /// Backend used for serialization. -pub type SerializationBackend = NdArray; +pub type SerializationBackend = burn_ndarray::NdArray; /// Codegen trait that should be implemented by all [node](Node) entries. pub trait NodeCodegen: std::fmt::Debug { diff --git a/crates/burn-import/src/burn/node/resize.rs b/crates/burn-import/src/burn/node/resize.rs index 59afcfb607..606f3ef38d 100644 --- a/crates/burn-import/src/burn/node/resize.rs +++ b/crates/burn-import/src/burn/node/resize.rs @@ -228,7 +228,7 @@ mod tests { TensorType::new_float("tensor1", 3), TensorType::new_float("tensor2", 3), "cubic".to_string(), - vec![], + vec![2.0], vec![20], )); @@ -253,7 +253,7 @@ mod tests { pub fn new(device: &B::Device) -> Self { let resize = Interpolate1dConfig::new() .with_output_size(Some(20)) - .with_scale_factor(None) + .with_scale_factor(Some(2.0)) .with_mode(InterpolateMode::Cubic) .init(); Self { diff --git a/crates/burn-import/src/pytorch/recorder.rs b/crates/burn-import/src/pytorch/recorder.rs index 170f64a9d3..32dea273c9 100644 --- a/crates/burn-import/src/pytorch/recorder.rs +++ b/crates/burn-import/src/pytorch/recorder.rs @@ -11,7 +11,7 @@ use serde::{de::DeserializeOwned, Serialize}; use super::reader::from_file; -/// A recorder that that loads PyTorch files (`.pt`) into Burn modules. +/// A recorder that loads PyTorch files (`.pt`) into Burn modules. /// /// LoadArgs can be used to remap keys or file path. /// See [LoadArgs](struct.LoadArgs.html) for more information. diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 0811374fd1..214b21eef3 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -25,17 +25,19 @@ export_tests = [ "paste", ] fusion = ["burn-fusion"] -std = ["cubecl/std"] +fusion-experimental = ["fusion"] +std = ["cubecl/std", "burn-tensor/std"] + template = [] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0" } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-common = { path = "../burn-common", version = "0.17.0" } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl", "repr", ] } -cubecl = { workspace = true, features = ["linalg"] } +cubecl = { workspace = true, features = ["linalg", "reduce"] } bytemuck = { workspace = true } derive-new = { workspace = true } @@ -52,12 +54,12 @@ futures-lite = { workspace = true, features = ["std"] } serde = { workspace = true } text_placeholder = { workspace = true, features = ["struct_context"] } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } hashbrown = { workspace = true } # When exporting tests -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, optional = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true } paste = { workspace = true, optional = true } serial_test = { workspace = true, optional = true } diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index f0e15352cf..a1bbab7f5f 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -57,6 +57,7 @@ impl IntElement for i64 {} impl IntElement for i32 {} impl IntElement for i16 {} impl IntElement for i8 {} +impl IntElement for u32 {} impl BoolElement for u8 {} impl BoolElement for u32 {} diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs index 1dbbf3baea..bba18e88f9 100644 --- a/crates/burn-jit/src/fusion/matmul/args.rs +++ b/crates/burn-jit/src/fusion/matmul/args.rs @@ -247,7 +247,7 @@ impl CubeType for FusedMatmulState { } impl Init for FusedMatmulStateExpand { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index 986332914f..f197237819 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -47,7 +47,13 @@ impl OptimizationBuilder> for MatmulBuilder let rhs = self.builder.input_unhandled(&op.rhs); let out = self.builder.output_unhandled(&op.out); - self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone())); + self.matmul = Some(FusedMatmul::new( + lhs, + rhs, + out, + op.clone(), + Default::default(), + )); } else { self.builder.close(); } diff --git a/crates/burn-jit/src/fusion/matmul/mod.rs b/crates/burn-jit/src/fusion/matmul/mod.rs index 1afeef9c88..cddec5983a 100644 --- a/crates/burn-jit/src/fusion/matmul/mod.rs +++ b/crates/burn-jit/src/fusion/matmul/mod.rs @@ -2,3 +2,4 @@ pub(crate) mod args; pub(crate) mod builder; pub(crate) mod optimization; pub(crate) mod spec; +pub(crate) mod tune; diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index b1d8431c67..804628613d 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -12,7 +12,9 @@ use burn_tensor::Shape; use cubecl::linalg::matmul::components; use cubecl::linalg::matmul::components::tile::accelerated::Accelerated; use cubecl::linalg::matmul::components::MatmulProblem; -use cubecl::linalg::matmul::kernels::matmul::{MatmulSelector, StandardSelector}; +use cubecl::linalg::matmul::kernels::matmul::{ + MatmulSelector, PipelinedSelector, SpecializedSelector, StandardSelector, +}; use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; use cubecl::linalg::tensor::{matrix_layout, MatrixLayout}; use cubecl::{client::ComputeClient, prelude::*}; @@ -26,16 +28,18 @@ use crate::fusion::on_write::{ use super::args::FusedMatmulInputLaunch; use super::spec::FusedMatmulSpec; +use super::tune::fused_matmul_autotune; -#[derive(new)] /// Fuse matmul operation followed by elemwise operations into a single kernel. pub struct MatmulOptimization { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - client: ComputeClient, - device: R::Device, - len: usize, - matmul: FusedMatmul, + pub(crate) client: ComputeClient, + pub(crate) device: R::Device, + pub(crate) len: usize, + pub(crate) matmul_standard: FusedMatmul, + pub(crate) matmul_pipelined: FusedMatmul, + pub(crate) matmul_specialized: FusedMatmul, } #[derive(Serialize, Deserialize, Debug)] @@ -43,14 +47,47 @@ pub struct MatmulOptimization { pub struct MatmulOptimizationState { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - matmul: FusedMatmul, + matmul_standard: FusedMatmul, + matmul_pipelined: FusedMatmul, + matmul_specialized: FusedMatmul, len: usize, } impl MatmulOptimization { + pub fn new( + trace: FuseOnWriteTrace, + trace_fallback: FuseOnWriteTrace, + client: ComputeClient, + device: R::Device, + len: usize, + matmul: FusedMatmul, + ) -> Self { + let mut matmul_standard = matmul.clone(); + let mut matmul_specialized = matmul.clone(); + let mut matmul_pipelined = matmul; + + matmul_standard.selector = FusedMatmulSelector::Standard; + matmul_specialized.selector = FusedMatmulSelector::Specialized; + matmul_pipelined.selector = FusedMatmulSelector::Pipelined; + + Self { + trace, + trace_fallback, + client, + device, + len, + matmul_standard, + matmul_pipelined, + matmul_specialized, + } + } /// Execute the optimization. pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { - if self.execute_fused::(context).is_err() { + #[cfg(feature = "autotune")] + fused_matmul_autotune::(self, context); + + #[cfg(not(feature = "autotune"))] + if self.execute_standard_fused::(context).is_err() { self.execute_fallback::(context); } } @@ -68,7 +105,9 @@ impl MatmulOptimization { len: state.len, client: R::client(device), device: device.clone(), - matmul: state.matmul.clone(), + matmul_standard: state.matmul_standard.clone(), + matmul_specialized: state.matmul_specialized.clone(), + matmul_pipelined: state.matmul_pipelined.clone(), } } @@ -77,21 +116,51 @@ impl MatmulOptimization { MatmulOptimizationState { trace: self.trace.clone(), trace_fallback: self.trace_fallback.clone(), - matmul: self.matmul.clone(), + matmul_standard: self.matmul_standard.clone(), + matmul_specialized: self.matmul_specialized.clone(), + matmul_pipelined: self.matmul_pipelined.clone(), len: self.len, } } - fn execute_fused( - &mut self, + pub fn execute_standard_fused( + &self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_standard, + ) + } + + pub fn execute_specialized_fused( + &self, context: &mut Context<'_, JitFusionHandle>, ) -> Result<(), FusedMatmulError> { - self.trace - .run::(&self.client, &self.device, context, &self.matmul) + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_specialized, + ) } - fn execute_fallback(&mut self, context: &mut Context<'_, JitFusionHandle>) { - match self.matmul.lhs.precision() { + pub fn execute_pipelined_fused( + &self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_pipelined, + ) + } + + pub fn execute_fallback(&self, context: &mut Context<'_, JitFusionHandle>) { + match self.matmul_standard.lhs.precision() { ElemwisePrecision::F32 => self.run_fallback::(context), ElemwisePrecision::F16 => self.run_fallback::(context), ElemwisePrecision::BF16 => self.run_fallback::(context), @@ -100,13 +169,25 @@ impl MatmulOptimization { } fn run_fallback( - &mut self, + &self, context: &mut Context<'_, JitFusionHandle>, ) { let (out_tensor, out_desc) = { - let lhs = context.tensors.get(&self.matmul.op.lhs.id).unwrap().clone(); - let rhs = context.tensors.get(&self.matmul.op.rhs.id).unwrap().clone(); - let out = context.tensors.get(&self.matmul.op.out.id).unwrap().clone(); + let lhs = context + .tensors + .get(&self.matmul_standard.op.lhs.id) + .unwrap() + .clone(); + let rhs = context + .tensors + .get(&self.matmul_standard.op.rhs.id) + .unwrap() + .clone(); + let out = context + .tensors + .get(&self.matmul_standard.op.out.id) + .unwrap() + .clone(); let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly); let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly); @@ -122,7 +203,8 @@ impl MatmulOptimization { rhs_tensor, None, matmul::MatmulStrategy::default(), - ); + ) + .unwrap(); (out_tensor, out) }; context @@ -135,12 +217,21 @@ impl MatmulOptimization { } } +#[derive(Default, Clone, Serialize, Deserialize, Debug)] +pub enum FusedMatmulSelector { + #[default] + Standard, + Pipelined, + Specialized, +} + #[derive(new, Clone, Serialize, Deserialize, Debug)] pub struct FusedMatmul { lhs: Arg, rhs: Arg, out: Arg, - op: BinaryOperationDescription, + pub(crate) op: BinaryOperationDescription, + pub(crate) selector: FusedMatmulSelector, } #[derive(Debug)] @@ -260,15 +351,43 @@ impl FusedMatmul { } }; - match matmul_launch_kernel::>( - client, - FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), - outputs, - problem, - plane_size, - ) { - Ok(_) => Ok(()), - Err(err) => Err(FusedMatmulError::LaunchError(err)), + match self.selector { + FusedMatmulSelector::Standard => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } + FusedMatmulSelector::Pipelined => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } + FusedMatmulSelector::Specialized => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } } } } diff --git a/crates/burn-jit/src/fusion/matmul/tune.rs b/crates/burn-jit/src/fusion/matmul/tune.rs new file mode 100644 index 0000000000..0f6e42c486 --- /dev/null +++ b/crates/burn-jit/src/fusion/matmul/tune.rs @@ -0,0 +1,133 @@ +use crate::{ + fusion::{ + tune::{TuneContext, TuneInput}, + JitFusionHandle, + }, + kernel::matmul::MatmulAutotuneKey, + BoolElement, JitRuntime, JitTuneId, +}; +use burn_fusion::stream::Context; +use cubecl::{ + tune::{local_tuner, LocalTuner, TunableSet}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use super::optimization::MatmulOptimization; + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +pub struct FusedMatmulAutotuneKey { + matmul_key: MatmulAutotuneKey, + #[autotune(anchor)] + num_ops_fused: usize, +} + +/// Executes autotune on matmul operations +pub fn fused_matmul_autotune( + optimization: &MatmulOptimization, + context: &mut Context>, +) { + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key::, input_gen::) + .with_tunable(tune_standard_fused::) + .with_tunable(tune_specialized_fused::) + .with_tunable(tune_pipelined_fused::) + .with_tunable(tune_fallback::); + + TUNER.execute( + &JitTuneId::new::(&optimization.device), + &optimization.client, + &tunables, + TuneInput::new(context, optimization), + ); +} + +pub(crate) fn create_key( + input: &TuneInput>, +) -> FusedMatmulAutotuneKey { + let opt = input.optimization(); + let context = match input.context() { + TuneContext::Original(context) => context, + TuneContext::Fork(_) => panic!("Not supported when generating key"), + }; + + let lhs = context.tensors.get(&opt.matmul_standard.op.lhs.id).unwrap(); + let rhs = context.tensors.get(&opt.matmul_standard.op.rhs.id).unwrap(); + let out = context.tensors.get(&opt.matmul_standard.op.out.id).unwrap(); + + let key = MatmulAutotuneKey::from_shape( + &lhs.shape.clone().into(), + &rhs.shape.clone().into(), + out.dtype, + ); + FusedMatmulAutotuneKey::new(key, opt.len) +} + +fn input_gen( + _key: &FusedMatmulAutotuneKey, + input: &TuneInput>, +) -> TuneInput> { + input.clone() +} + +fn tune_standard_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_standard_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_standard_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_specialized_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_specialized_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_specialized_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_pipelined_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_pipelined_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_pipelined_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_fallback( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_fallback::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_fallback::(&mut context_owned.as_context()) + } + }; + + Ok(()) +} diff --git a/crates/burn-jit/src/fusion/mod.rs b/crates/burn-jit/src/fusion/mod.rs index 4c44770b4e..96e1704964 100644 --- a/crates/burn-jit/src/fusion/mod.rs +++ b/crates/burn-jit/src/fusion/mod.rs @@ -3,5 +3,6 @@ mod base; pub(crate) mod elemwise; pub(crate) mod matmul; pub(crate) mod on_write; +pub(crate) mod tune; pub use base::*; diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 0b6272b4f1..36c8e402a0 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -45,13 +45,13 @@ impl CubeType for Arg { } impl Init for Arg { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } impl IntoRuntime for Arg { - fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { + fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType { self } } @@ -154,7 +154,7 @@ impl GlobalArgsLaunch<'_, R> { } } - /// Resolve the [argument](Arg) to a [tensor arguemnt](TensorArg). + /// Resolve the [argument](Arg) to a [tensor argument](TensorArg). /// /// # Panics /// diff --git a/crates/burn-jit/src/fusion/tune.rs b/crates/burn-jit/src/fusion/tune.rs new file mode 100644 index 0000000000..8c45f93bb0 --- /dev/null +++ b/crates/burn-jit/src/fusion/tune.rs @@ -0,0 +1,108 @@ +use super::JitFusionHandle; +use crate::JitRuntime; +use burn_fusion::stream::{Context, ContextOwned}; + +/// Fusion context used when tuning kernels. +/// +/// Either the original context is returned or a fork of the original. +/// The fork is only given when performing autotuning, and not when actually performing the +/// operation. +pub enum TuneContext<'a, R: JitRuntime> { + Original(&'a mut Context<'a, JitFusionHandle>), + Fork(Box>>), +} + +/// Fusion input wrapper containing the context and the optimization. +/// +/// # Safety +/// +/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions +/// are made based on its behavior. +pub struct TuneInput { + context: UnsafeTuneContext, + optimization: *const O, +} + +/// Unsafe wrapper around the context. +/// +/// # Safety +/// +/// The wrapper removes the context lifetime. +/// +/// For it to be correct, the context must not be used after the invocation of the +/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are +/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find +/// the best kernel to use, which can be async. +enum UnsafeTuneContext { + Original(*mut Context<'static, JitFusionHandle>), + Fork(Box>>), +} + +unsafe impl Send for UnsafeTuneContext {} +unsafe impl Send for TuneInput {} + +impl TuneInput { + /// Create a new autotune input from the [context](Context) and an optimization. + pub fn new(context: &mut Context>, optimization: &O) -> Self { + let context = UnsafeTuneContext::new(context); + // We can erase the lifetime for the same reason we do with the context. + let optimization = core::ptr::from_ref(optimization); + + Self { + context, + optimization, + } + } + + /// Retrieve the [autotune context](TuneContext) for the current input. + pub fn context(&self) -> TuneContext<'static, R> { + self.context.get() + } + + /// Retrieve the optimization for the current input. + pub fn optimization(&self) -> &O { + unsafe { self.optimization.as_ref().unwrap() } + } +} + +impl UnsafeTuneContext { + fn new(context: &mut Context<'_, JitFusionHandle>) -> Self { + let ptr = core::ptr::from_mut(context); + + // It is necessary for the lifetime. + #[allow(clippy::unnecessary_cast)] + Self::Original(ptr as *mut Context<'static, _>) + } + + fn get(&self) -> TuneContext<'static, R> { + match self { + UnsafeTuneContext::Original(ptr) => { + TuneContext::Original(unsafe { ptr.as_mut().unwrap() }) + } + UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())), + } + } +} + +impl Clone for TuneInput { + fn clone(&self) -> Self { + Self { + context: self.context.clone(), + optimization: self.optimization, + } + } +} + +impl Clone for UnsafeTuneContext { + fn clone(&self) -> Self { + let context = match self { + UnsafeTuneContext::Original(ptr) => { + let context: &mut Context<'static, JitFusionHandle> = + unsafe { ptr.as_mut().unwrap() }; + context.fork() + } + UnsafeTuneContext::Fork(context) => context.fork(), + }; + UnsafeTuneContext::Fork(Box::new(context)) + } +} diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d799d1caea..f0da764a7a 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -4,34 +4,60 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, - tensor_vectorization_factor, + tensor_line_size_parallel, }; use super::into_contiguous; +pub(crate) trait BinaryOpFamily: Send + Sync + 'static { + type BinaryOp: BinaryOp; +} + #[cube] pub(crate) trait BinaryOp: 'static + Send + Sync { /// Execute a binary operation. fn execute(lhs: Line, rhs: Line) -> Line; } -pub(crate) trait BinaryOpSpec: Send + Sync + 'static { - type C: Numeric; -} -pub(crate) struct Spec { - _c: PhantomData, -} - -impl BinaryOpSpec for Spec { - type C = C; -} - pub(crate) struct AddOp; pub(crate) struct SubOp; pub(crate) struct MulOp; pub(crate) struct DivOp; pub(crate) struct RemainderOp; -pub(crate) struct PowOp; + +/// Since Powf only works on float, but we still want to implement the numeric binary op family, we +/// set another precision in the family type to cast, when necessary, the input value to a valid +/// float. +/// +/// Because of this we won't benefit from the cubecl rust compilation speed improvement from using +/// the family pattern for [PowOp], but at least we don't duplicate code. +pub(crate) struct PowOp { + _f: PhantomData, +} + +impl BinaryOpFamily for AddOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for SubOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for MulOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for DivOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for RemainderOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for PowOp { + type BinaryOp = Self; +} #[cube] impl BinaryOp for AddOp { @@ -69,30 +95,34 @@ impl BinaryOp for RemainderOp { } #[cube] -impl BinaryOp for PowOp { +impl BinaryOp for PowOp { fn execute(lhs: Line, rhs: Line) -> Line { - Line::powf(lhs, rhs) + let lhs = Line::::cast_from(lhs); + let rhs = Line::::cast_from(rhs); + let out = Line::powf(lhs, rhs); + + Line::cast_from(out) } } -#[cube(launch)] -pub(crate) fn kernel_scalar_binop>( - input: &Tensor>, - scalar: BS::C, - output: &mut Tensor>, +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } - output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], Line::new(scalar)); + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); } -#[cube(launch)] -pub(crate) fn kernel_binop>( - lhs: &Tensor>, - rhs: &Tensor>, - out: &mut Tensor>, +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, #[comptime] rank: Option, #[comptime] to_contiguous_lhs: bool, #[comptime] to_contiguous_rhs: bool, @@ -102,11 +132,11 @@ pub(crate) fn kernel_binop>( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { - offset_lhs = index_offset_with_layout::( + offset_lhs = index_offset_with_layout::( lhs, out, offset_out, @@ -117,7 +147,7 @@ pub(crate) fn kernel_binop>( } if to_contiguous_rhs { - offset_rhs = index_offset_with_layout::( + offset_rhs = index_offset_with_layout::( rhs, out, offset_out, @@ -127,20 +157,27 @@ pub(crate) fn kernel_binop>( ); } - out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]); + out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); } -pub(crate) fn launch_binop>( +pub(crate) fn launch_binop( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { let ndims = lhs.shape.num_dims(); - let vectorization_factor_lhs = - tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1); - let vectorization_factor_rhs = - tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1); - - let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); let mut shape_out = vec![0; ndims]; lhs.shape @@ -157,59 +194,60 @@ pub(crate) fn launch_binop>( let num_elems = shape_out.num_elements(); let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - - if lhs.can_mut_broadcast(&rhs) { - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - TensorArg::alias(0), - None, - false, - rhs.strides != lhs.strides || rhs.shape != lhs.shape, - ); - - lhs - } else if rhs.can_mut_broadcast(&lhs) { - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - TensorArg::alias(1), - None, - rhs.strides != lhs.strides || rhs.shape != lhs.shape, - false, - ); - - rhs - } else { - let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); - let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; - let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; - - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), - None, - to_contiguous_lhs, - to_contiguous_rhs, - ); - - output + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } } } -pub(crate) fn launch_scalar_binop>( +pub(crate) fn launch_scalar_binop( mut tensor: JitTensor, scalar: E, ) -> JitTensor { @@ -219,42 +257,47 @@ pub(crate) fn launch_scalar_binop>( // Vectorization is only enabled when the last dimension is contiguous. let ndims = tensor.shape.num_dims(); - let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); let client = tensor.client.clone(); let num_elems = tensor.shape.num_elements(); let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - - if tensor.can_mut() { - kernel_scalar_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - ScalarArg::new(scalar), - TensorArg::alias(0), - ); - - tensor - } else { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); - - kernel_scalar_binop::launch::, O, R>( - &client, - cube_count, - CubeDim::default(), - tensor.as_tensor_arg::(vectorization_factor), - ScalarArg::new(scalar), - output.as_tensor_arg::(vectorization_factor), - ); - - output + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } } } diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs new file mode 100644 index 0000000000..390bfc479e --- /dev/null +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -0,0 +1,276 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use burn_tensor::Shape; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +use super::into_contiguous; + +pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { + type BinaryOp: BinaryOpInt; +} + +#[cube] +pub(crate) trait BinaryOpInt: 'static + Send + Sync { + /// Execute a binary operation. + fn execute(lhs: Line, rhs: Line) -> Line; +} + +pub(crate) struct BitwiseAndOp; +pub(crate) struct BitwiseOrOp; +pub(crate) struct BitwiseXorOp; +pub(crate) struct BitwiseShrOp; +pub(crate) struct BitwiseShlOp; + +impl BinaryOpIntFamily for BitwiseAndOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseOrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseXorOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShlOp { + type BinaryOp = Self; +} + +#[cube] +impl BinaryOpInt for BitwiseAndOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs & rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseOrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs | rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseXorOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs ^ rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs >> rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShlOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs << rhs + } +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop_int( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop_int( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, + #[comptime] rank: Option, + #[comptime] to_contiguous_lhs: bool, + #[comptime] to_contiguous_rhs: bool, +) { + let offset_out = ABSOLUTE_POS; + let mut offset_lhs = ABSOLUTE_POS; + let mut offset_rhs = ABSOLUTE_POS; + + if offset_out >= out.len() { + terminate!(); + } + + if to_contiguous_lhs { + offset_lhs = index_offset_with_layout::( + lhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + if to_contiguous_rhs { + offset_rhs = index_offset_with_layout::( + rhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); +} + +pub(crate) fn launch_binop_int( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + let ndims = lhs.shape.num_dims(); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); + + let mut shape_out = vec![0; ndims]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::from(shape_out); + let client = lhs.client.clone(); + let num_elems = shape_out.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } + } +} + +pub(crate) fn launch_scalar_binop_int( + mut tensor: JitTensor, + scalar: E, +) -> JitTensor { + if !tensor.is_contiguous_buffer() { + tensor = into_contiguous(tensor); + } + + // Vectorization is only enabled when the last dimension is contiguous. + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } + } +} diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 798b79a0f0..43b24f071a 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -12,7 +12,7 @@ pub(crate) fn cast_element( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } let offset_input = index_offset_with_layout::( diff --git a/crates/burn-jit/src/kernel/clamp.rs b/crates/burn-jit/src/kernel/clamp.rs index 683e8aff8f..ec2bc93d1f 100644 --- a/crates/burn-jit/src/kernel/clamp.rs +++ b/crates/burn-jit/src/kernel/clamp.rs @@ -1,7 +1,11 @@ use cubecl::prelude::*; -use crate::kernel::{launch_unary, UnaryOp}; -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}, + tensor::JitTensor, + JitRuntime, +}; #[derive(CubeLaunch)] struct Options { @@ -16,28 +20,25 @@ pub(crate) fn clamp( ) -> JitTensor { struct ClampOp; - impl UnaryOp for ClampOp { - type Options = Options; + #[cube] + impl NumericUnaryOp for ClampOp { + type Options = Options; - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - options: OptionsExpand, - ) -> as CubeType>::ExpandType { - #[cube] - fn execute(input: Line, options: &Options) -> Line { - Line::clamp( - input, - Line::new(options.min_value), - Line::new(options.max_value), - ) - } - - execute::expand(context, input, options) + fn execute(input: Line, options: &Self::Options) -> Line { + Line::clamp( + input, + Line::new(options.min_value), + Line::new(options.max_value), + ) } } - launch_unary::(input, |_| { + impl NumericUnaryOpFamily for ClampOp { + type Options = Options; + type Unary = Self; + } + + launch_unary_numeric::(input, |_| { OptionsLaunch::new(ScalarArg::new(min_value), ScalarArg::new(max_value)) }) } diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index e33687fb5a..a6de9025bb 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp>( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar))); @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp>( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 0b3a35dc45..f015677a2b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -1,6 +1,8 @@ use burn_tensor::ops::{ConvOptions, ConvTransposeOptions}; -use crate::{tensor::JitTensor, FloatElement, IntElement, JitRuntime}; +use crate::{ + kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime, +}; #[cfg(feature = "autotune")] use super::{conv2d_autotune, conv_transpose2d_autotune}; @@ -75,11 +77,11 @@ pub fn conv2d( bias: Option>, options: ConvOptions<2>, strategy: Conv2dStrategy, -) -> JitTensor { +) -> Result, ConvLaunchError> { match strategy { Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] - Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), + Conv2dStrategy::Autotune => Ok(conv2d_autotune::(input, weight, bias, options)), Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), Conv2dStrategy::ImplicitGemmComplex => { @@ -102,15 +104,15 @@ pub fn conv_transpose2d( bias: Option>, options: ConvTransposeOptions<2>, strategy: ConvTranspose2dStrategy, -) -> JitTensor { +) -> Result, ConvLaunchError> { match strategy { ConvTranspose2dStrategy::Direct => { conv_transpose2d_direct::(input, weight, bias, options) } #[cfg(feature = "autotune")] - ConvTranspose2dStrategy::Autotune => { - conv_transpose2d_autotune::(input, weight, bias, options) - } + ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::( + input, weight, bias, options, + )), ConvTranspose2dStrategy::Gemm => { conv_transpose2d_col2im::(input, weight, bias, options) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 0d9c48dc30..4f6931f86d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -6,6 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ + conv::ConvLaunchError, into_contiguous, matmul::{matmul, MatmulStrategy}, slice, @@ -29,7 +30,7 @@ pub fn conv_transpose2d_col2im( weight: JitTensor, bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims(); let [batch_size, _, input_h, input_w] = input.shape.dims(); let groups = options.groups; @@ -94,9 +95,12 @@ pub fn conv_transpose2d_col2im( options.clone(), kernel_h, kernel_w, - ); + )?; } - reshape(image, Shape::new([batch_size, im_channels, im_h, im_w])) + Ok(reshape( + image, + Shape::new([batch_size, im_channels, im_h, im_w]), + )) } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); @@ -108,8 +112,8 @@ pub fn conv_transpose2d_col2im( options, kernel_h, kernel_w, - ); - image + )?; + Ok(image) } } @@ -135,7 +139,7 @@ fn execute( options: ConvTransposeOptions<2>, kernel_h: usize, kernel_w: usize, -) { +) -> Result<(), ConvLaunchError> { let [batch_size, _, input_h, input_w] = input.shape.dims(); let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims(); @@ -145,12 +149,14 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = matmul::(weight, input, None, MatmulStrategy::default()); + let columns = matmul::(weight, input, None, MatmulStrategy::default())?; let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( columns, bias, image, kernel_h, kernel_w, input_h, input_w, options, ); + + Ok(()) } #[allow(clippy::too_many_arguments)] @@ -235,7 +241,7 @@ fn col2im_kernel( #[comptime] has_bias: bool, ) { if ABSOLUTE_POS >= image.len() { - return; + terminate!(); } let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index d5154ecc4b..1cd24f7c0c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -5,7 +5,7 @@ use burn_tensor::{ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, @@ -35,7 +35,7 @@ fn direct_conv2d_kernel( #[comptime] kernel_size_1_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); @@ -125,7 +125,7 @@ pub fn conv2d_direct( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, _, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let channels_per_group = out_channels / options.groups; @@ -193,5 +193,5 @@ pub fn conv2d_direct( kernel_w_unroll, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs index 374a03be29..ffff3675bd 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs @@ -5,7 +5,7 @@ use cubecl::{ tile::{accelerated::Accelerated, TileMatmulFamily}, InvalidConfigError, }, - kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + kernels::matmul::AdvancedConfig, }, prelude::*, }; @@ -13,7 +13,6 @@ use cubecl::{ use super::{ base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem}, homogeneous::base::ImplicitGemmConvolutionFamily, - precision::ConvPrecision, selection::ConvSelection, }; @@ -47,34 +46,6 @@ pub trait Algorithm { Self::GlobalConvolution::check_config(&config)?; Ok(config) } - - /// Check availability of the matmul algorithm - fn check_availability( - client: &ComputeClient, - config: &::Config, - ) -> Result<(), MatmulAvailabilityError> { - Self::GlobalConvolution::check_availability::(client, config) - } - - /// Determine whether the given convolution problem is valid to launch (within hardware limits) - fn can_launch( - client: &ComputeClient, - problem: &ConvolutionProblem, - config: &::Config, - selection: &Self::Selection, - ) -> bool { - if problem.options.groups > 1 || Self::check_availability::(client, config).is_err() - { - return false; - } - - let cube_count = Self::cube_count(selection, problem); - let (max_x, max_y, max_z) = R::max_cube_count(); - match cube_count { - CubeCount::Static(x, y, z) => x <= max_x && y <= max_y && z <= max_z, - _ => true, - } - } } /// Cmma convolution diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs index e69b33b40f..a78082950a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs @@ -6,7 +6,7 @@ use cubecl::linalg::{ stage::{StageMatmul, StageMatmulFamily}, InvalidConfigError, MatmulProblem, MatrixLayout, }, - kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + kernels::matmul::AdvancedConfig, }, tensor::{ReadWrite, VirtualTensor}, }; @@ -91,12 +91,6 @@ pub trait ConvolutionConfigFactory: Send + Sync + 'static { /// Asserts that the configuration for this matmul will lead to a valid computation fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>; - /// Checks if the client can handle the features used in this computation - fn check_availability( - client: &ComputeClient, - config: &Self::Config, - ) -> Result<(), MatmulAvailabilityError>; - fn make_config( input: Self::Input, problem: &ConvolutionProblem, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index 2f32c8471e..988cd0ead6 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -16,7 +16,7 @@ use cubecl::{ }, Ident, InvalidConfigError, MatrixLayout, StageDim, }, - kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + kernels::matmul::AdvancedConfig, }, tensor::{ReadWrite, VirtualTensor}, }, @@ -194,13 +194,6 @@ where SMM::check_config(&config.to_smm_config()) } - fn check_availability( - client: &ComputeClient, - config: &Self::Config, - ) -> Result<(), MatmulAvailabilityError> { - SMM::check_availability::(client, &config.to_smm_config()) - } - fn make_config( input: Self::Input, problem: &ConvolutionProblem, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index c99861c82d..ad70a9b825 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -7,7 +7,7 @@ use burn_tensor::{ use cubecl::{ flex32, ir::{Elem, FloatKind}, - linalg::matmul::{self, components::MatrixLayout}, + linalg::matmul::{self, kernels::MatmulLaunchError}, tensor_line_size, tf32, Feature, }; use half::{bf16, f16}; @@ -23,7 +23,7 @@ use crate::{ algorithm::{Algorithm, ImplicitCmmaConv}, base::{ConvolutionLaunch, ConvolutionProblem}, }, - nchw_to_nhwc, Conv2dAutotuneKey, + nchw_to_nhwc, ConvLaunchError, }, into_contiguous, }, @@ -44,7 +44,7 @@ pub fn conv2d_gemm_cmma_large_m( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -60,7 +60,7 @@ pub fn conv2d_gemm_cmma_balanced( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -74,14 +74,16 @@ fn conv2d_gemm_cmma_strategy< weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { if TypeId::of::() == TypeId::of::() { conv2d_gemm_with_algo::(input, weight, bias, options) } else if TypeId::of::() == TypeId::of::() || TypeId::of::() == TypeId::of::() { conv2d_gemm_with_algo::(input, weight, bias, options) - } else { + } else if has_tf32(&input) { conv2d_gemm_with_algo::(input, weight, bias, options) + } else { + conv2d_gemm_with_algo::(input, weight, bias, options) } } @@ -102,10 +104,14 @@ pub fn conv2d_gemm_with_algo< weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor +) -> Result, ConvLaunchError> where SP::EG: JitElement, { + if options.groups != 1 { + return Err(ConvLaunchError::Groups(options.groups)); + } + let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); @@ -189,18 +195,14 @@ where let cube_count = Alg::cube_count(&selection, &problem); let advanced_config = Default::default(); - let config = match Alg::make_config( + let config = Alg::make_config( config_input, &problem, &cube_dim, &cube_count, &advanced_config, - ) { - Ok(val) => val, - Err(err) => { - panic!("Can't launch conv kernel because of an invalid config: {err}") - } - }; + ) + .map_err(MatmulLaunchError::InvalidConfig)?; let bias = bias.unwrap_or_else(|| { empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) @@ -221,59 +223,7 @@ where // Reset to NCHW let out = reshape(out, Shape::new([batch_size, out_h, out_w, out_channels])); - permute(out, &[0, 3, 1, 2]) -} - -pub fn problem_from_key( - key: &Conv2dAutotuneKey, - out_h: usize, - out_w: usize, -) -> ConvolutionProblem { - let in_stride_2 = key.in_channels; - let in_stride_1 = key.width * in_stride_2; - let in_stride_0 = key.height * in_stride_1; - - let m = key.batch_size * out_h * out_w; - let n = key.out_channels; - let k = key.kernel_size[0] * key.kernel_size[1] * key.in_channels; - - let options = ConvOptions { - stride: key.stride, - padding: key.padding, - dilation: key.dilation, - groups: key.groups, - }; - - // Target 128 bit accesses - let available_vectorizations = R::supported_line_sizes() - .iter() - .copied() - .filter(|it| *it as usize * size_of::() <= 16) - .collect::>(); - let lhs_line_size = tensor_line_size( - &available_vectorizations, - &[key.batch_size, key.height, key.width, key.in_channels], - &[in_stride_0, in_stride_1, in_stride_2, 1], - 3, - ); - let rhs_line_size = tensor_line_size(&available_vectorizations, &[k, n], &[n, 1], 1); - let out_line_size = tensor_line_size(&available_vectorizations, &[m, n], &[n, 1], 1); - - ConvolutionProblem { - m, - n, - k, - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size, - rhs_line_size, - out_line_size, - kernel_size: (key.kernel_size[0] as u32, key.kernel_size[1] as u32), - options, - out_shape_y: out_h, - out_shape_x: out_w, - has_bias: key.has_bias, - } + Ok(permute(out, &[0, 3, 1, 2])) } pub(crate) fn has_tf32(c: &JitTensor) -> bool { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index a65c29466c..09ce56898b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -6,7 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ - conv::index, + conv::{index, ConvLaunchError}, into_contiguous, launch_binop, matmul::{matmul, MatmulStrategy}, AddOp, @@ -53,7 +53,7 @@ fn im2col_kernel( let out_w = args.out_w; if ABSOLUTE_POS > args.num_elements { - return; + terminate!(); } let out_x = ABSOLUTE_POS % out_w; @@ -98,25 +98,38 @@ fn im2col_kernel( } #[cfg(not(test))] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + use cubecl::linalg::matmul::kernels::MatmulAvailabilityError; + let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); if max_simultaneous == 0 { - return None; + return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( + cube_count_per_batch as u32, + 1, + 1, + )) + .into()); } - Some( - (0..=max_simultaneous) - .rev() - .find(|per_run| batch_size % per_run == 0) - .expect("Logically not possible"), - ) + Ok((0..=max_simultaneous) + .rev() + .find(|per_run| batch_size % per_run == 0) + .expect("Logically not possible")) } #[cfg(test)] #[allow(unused)] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - Some(1) +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + Ok(1) } fn im2col( @@ -188,7 +201,7 @@ pub fn conv2d_im2col( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, in_channels, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -214,8 +227,7 @@ pub fn conv2d_im2col( return execute_1x1_kernel::(input, weight, bias, options); } - let batches_per_run = batches_per_run(batch_size, out_h, out_w) - .expect("Image too large to run even one batch at once"); + let batches_per_run = batches_per_run(batch_size, out_h, out_w)?; let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]); let mut out = if batches_per_run != batch_size { @@ -237,13 +249,13 @@ pub fn conv2d_im2col( options.clone(), out_h, out_w, - ); + )?; } let out = swap_dims(out, 1, 2); reshape(out, Shape::new([batch_size, out_channels, out_h, out_w])) } else { let out = empty_device::(input.client.clone(), input.device.clone(), matmul_shape); - execute::(input, weight, out.clone(), options, out_h, out_w); + execute::(input, weight, out.clone(), options, out_h, out_w)?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); swap_dims(out, 0, 1) }; @@ -252,7 +264,8 @@ pub fn conv2d_im2col( let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); out = launch_binop::(out, bias) } - out + + Ok(out) } fn execute_1x1_kernel( @@ -260,7 +273,7 @@ fn execute_1x1_kernel( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, _, height, width] = input.shape.dims(); let [out_channels, in_c_per_grp, _, _] = weight.shape.dims(); let groups = options.groups; @@ -271,7 +284,7 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = matmul::(weight, input, None, MatmulStrategy::default()); + let out = matmul::(weight, input, None, MatmulStrategy::default())?; let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { @@ -279,7 +292,7 @@ fn execute_1x1_kernel( out = launch_binop::(out, bias) } - swap_dims(out, 0, 1) + Ok(swap_dims(out, 0, 1)) } fn execute( @@ -289,7 +302,7 @@ fn execute( options: ConvOptions<2>, out_h: usize, out_w: usize, -) { +) -> Result<(), ConvLaunchError> { let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -301,5 +314,7 @@ fn execute( let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1])); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); - matmul::(weight, columns, Some(out), Default::default()); + matmul::(weight, columns, Some(out), Default::default())?; + + Ok(()) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 2e4f469068..9c8edf0103 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -6,13 +6,14 @@ use cmma::{Matrix, MatrixIdent, MatrixLayout}; use cubecl::{ cube, ir::{Elem, FloatKind}, + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, prelude::*, Compiler, CubeCount, CubeDim, Feature, }; use half::f16; use crate::{ - kernel::{into_contiguous, slice, slice_assign}, + kernel::{conv::ConvLaunchError, into_contiguous, slice, slice_assign}, ops::{ numeric::{empty_device, zeros_device}, permute, @@ -35,7 +36,7 @@ pub fn conv2d_implicit_gemm( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let is_tf32 = F::as_elem_native_unchecked() == Elem::Float(FloatKind::F32) && input .client @@ -66,7 +67,7 @@ pub fn conv2d_implicit_gemm( let padded_batch_size = padded_batch_size(batch_size, out_h, out_w); - if !can_do_implicit_gemm::( + check_availability::( batch_size, in_channels, out_channels, @@ -75,15 +76,7 @@ pub fn conv2d_implicit_gemm( out_h, out_w, &input.client, - ) { - panic!( - "Requirements for implicit GEMM not met: -- CMMA must be available -- `groups` must be 1 -- subcube size must be non-variable (might not hold on Intel) - " - ); - } + )?; // If input is contiguous NCHW, use custom transpose kernel let input = match input.is_contiguous() { @@ -210,7 +203,7 @@ pub fn conv2d_implicit_gemm( let out = slice::(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); // Reset to NCHW - permute(out, &[0, 3, 1, 2]) + Ok(permute(out, &[0, 3, 1, 2])) } fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 { @@ -643,7 +636,7 @@ fn load_weight_tile( } #[allow(clippy::too_many_arguments)] -pub(crate) fn can_do_implicit_gemm( +pub(crate) fn check_availability( batch_size: usize, in_channels: usize, out_channels: usize, @@ -652,7 +645,7 @@ pub(crate) fn can_do_implicit_gemm( out_h: usize, out_w: usize, client: &ComputeClient, -) -> bool { +) -> Result<(), ConvLaunchError> { let cmma_k = match ( E::as_elem_native_unchecked(), client @@ -672,19 +665,43 @@ pub(crate) fn can_do_implicit_gemm( let gemm_n = out_channels; let gemm_k = in_channels * kernel_h * kernel_w; - let size = find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32); - - if let Some((cmma_m, cmma_k, cmma_n)) = size { - let warps_per_cube = 8; + let (cmma_m, cmma_n, cmma_k) = + find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32).ok_or_else( + || { + ConvLaunchError::Matmul(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::CmmaInstructionUnavailable { + input: E::as_elem_native_unchecked(), + output: E::as_elem_native_unchecked(), + m: 16, + n: 16, + k: cmma_k as u32, + }, + )) + }, + )?; + + let warps_per_cube = 8; + + let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::(); + if ::max_shared_memory_size() < smem_size { + return Err(ConvLaunchError::Matmul(MatmulLaunchError::InvalidConfig( + Box::new("Not enough shared memory"), + ))); + } - let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::(); - let topology = client.properties().hardware_properties(); - let not_intel = topology.plane_size_min >= 32; + let topology = client.properties().hardware_properties(); + if topology.plane_size_min < 32 { + return Err(ConvLaunchError::Matmul(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::PlaneDimUnsupported { + plane_dim: topology.plane_size_min, + }, + ))); + } - ::max_shared_memory_size() >= smem_size && groups == 1 && not_intel - } else { - false + if groups != 1 { + return Err(ConvLaunchError::Groups(groups)); } + Ok(()) } fn padded_k( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index 62f0e56d8f..7cbe09dbc0 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel( let batch = CUBE_POS_Z; if batch >= input.shape(0) { - return; + terminate!(); } let batch_offset = batch * input.stride(0); @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel( let hw = base_hw + mat_hw; if hw >= shape_hw { - return; + terminate!(); } let mat_c_start = mat_hw_start; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 6a97ab8759..a8cd1ceb7f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -2,7 +2,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ element::JitElement, - kernel::into_contiguous, + kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel( args: ConvArgs, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_c_per_group = weight.shape(0) / args.groups; @@ -126,7 +126,7 @@ pub fn conv_transpose2d_direct( weight: JitTensor, bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_height, in_width] = input.shape.dims(); @@ -184,5 +184,5 @@ pub fn conv_transpose2d_direct( ), ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 157d4d443d..36d12e2255 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -1,26 +1,12 @@ -use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions}, - ElementConversion, Shape, -}; -use cubecl::{ - ir::{Elem, FloatKind}, - tf32, tune, - tune::{local_tuner, tune_with, LocalTuner}, -}; -use half::f16; +use burn_tensor::{ops::ConvOptions, ElementConversion, Shape}; +use cubecl::tune::{local_tuner, LocalTuner, TunableSet}; use super::Conv2dAutotuneKey; use crate::{ kernel::{ conv::{ - algorithm::{Algorithm, ImplicitCmmaConv}, - batches_per_run, can_do_implicit_gemm, - conv2d::gemm::base::ConvolutionProblem, conv2d_direct, conv2d_gemm_cmma_balanced, conv2d_gemm_cmma_large_m, conv2d_im2col, - conv2d_implicit_gemm, has_tf32, - precision::ConvPrecision, - problem_from_key, - selection::{Balanced, ConvSelector, Large}, + conv2d_implicit_gemm, }, prng::random_uniform, }, @@ -39,31 +25,33 @@ pub fn conv2d_autotune( static TUNER: LocalTuner = local_tuner!(); + let tunables = TunableSet::new(create_key::, create_conv2d_input::) + .with_tunable(conv2d_direct::) + .with_tunable(conv2d_im2col::) + .with_tunable(conv2d_implicit_gemm::) + .with_tunable(conv2d_gemm_cmma_large_m::) + .with_tunable(conv2d_gemm_cmma_balanced::); + TUNER.execute( &JitTuneId::new::(&input.device), &client, - Box::new(Conv2dOperations::::new(input, weights, bias, options)), + &tunables, + (input, weights, bias, options), ) } -#[tune( - operations( - conv2d_direct, - conv2d_im2col, - conv2d_implicit_gemm, - conv2d_gemm_cmma_large_m, - conv2d_gemm_cmma_balanced - ), - create_key = create_key::, - should_run = should_run -)] -pub fn conv2d_operations( - key: JitAutotuneKey, - input: JitTensor, - weights: JitTensor, - bias: Option>, - options: ConvOptions<2>, -) -> JitTensor { +pub fn create_conv2d_input( + key: &JitAutotuneKey, + input: &JitTensor, + _weights: &JitTensor, + _bias: &Option>, + options: &ConvOptions<2>, +) -> ( + JitTensor, + JitTensor, + Option>, + ConvOptions<2>, +) { let device = &input.device; let key = match key { JitAutotuneKey::Conv2d(key) => key, @@ -82,125 +70,7 @@ pub fn conv2d_operations( .has_bias .then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1)); - tune_with!(input, weights, bias, options) -} - -macro_rules! check_algo { - ($algo:tt, $float:ty, $input:expr, $problem:expr) => { - match (<$float>::as_elem_native_unchecked(), has_tf32(&$input)) { - (Elem::Float(FloatKind::F32), true) => { - can_launch::<$algo, R, ($float, tf32, f32)>($input, $problem) - } - (Elem::Float(FloatKind::Flex32), _) => { - can_launch::<$algo, R, ($float, f16, f32)>($input, $problem) - } - _ => can_launch::<$algo, R, ($float, $float, f32)>($input, $problem), - } - }; - - ($algo:tt, $input:expr, $problem:expr) => { - let plane_dim = 32; - let conv_problem = $problem; - - let (selection, config_input) = $algo::select_kernel::(plane_dim); - let cube_dim = ImplicitCmmaConv::cube_dim(&selection); - let cube_count = ImplicitCmmaConv::cube_count(&selection, &conv_problem); - - let advanced_config = Default::default(); - let config = ImplicitCmmaConv::make_config( - config_input, - &conv_problem, - &cube_dim, - &cube_count, - &advanced_config, - ); - - match config { - Ok(config) => ImplicitCmmaConv::can_launch::( - &op.input.client, - &conv_problem, - &config, - &selection, - ), - Err(_) => false, - } - }; -} - -fn should_run( - op: &Conv2dOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let key = match key { - JitAutotuneKey::Conv2d(key) => key, - _ => unreachable!(), - }; - - let out_h = calculate_conv_output_size( - key.kernel_size[0], - key.stride[0], - key.padding[0], - key.dilation[0], - key.height, - ); - let out_w = calculate_conv_output_size( - key.kernel_size[1], - key.stride[1], - key.padding[1], - key.dilation[1], - key.width, - ); - - let conv_problem = problem_from_key::(key, out_h, out_w); - - match index { - // im2col - 1 => batches_per_run(key.batch_size, out_h, out_w).is_some(), - // Implicit gemm. - 2 => can_do_implicit_gemm::( - key.batch_size, - key.in_channels, - key.out_channels, - key.kernel_size, - op.options.groups, - out_h, - out_w, - &op.input.client, - ), - // GEMM large m - 3 => check_algo!(Large, F, &op.input, &conv_problem), - // GEMM balanced - 4 => check_algo!(Balanced, F, &op.input, &conv_problem), - _ => true, - } -} - -fn can_launch, R: JitRuntime, CS: ConvPrecision>( - input: &JitTensor, - conv_problem: &ConvolutionProblem, -) -> bool { - let plane_dim = 32; - - let (selection, config_input) = S::select_kernel::(plane_dim); - let cube_dim = ImplicitCmmaConv::cube_dim(&selection); - let cube_count = ImplicitCmmaConv::cube_count(&selection, conv_problem); - - let advanced_config = Default::default(); - let config = ImplicitCmmaConv::make_config( - config_input, - conv_problem, - &cube_dim, - &cube_count, - &advanced_config, - ); - - match config { - Ok(config) => { - ImplicitCmmaConv::can_launch::(&input.client, conv_problem, &config, &selection) - } - Err(_) => false, - } + (input, weights, bias, options.clone()) } fn create_key( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index c2d546151a..df0159b75d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -1,12 +1,9 @@ use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape}; -use cubecl::{ - tune, - tune::{local_tuner, tune_with, LocalTuner}, -}; +use cubecl::tune::{local_tuner, LocalTuner, TunableSet}; use crate::{ kernel::{ - conv::{batches_per_run, conv_transpose2d_col2im, conv_transpose2d_direct}, + conv::{conv_transpose2d_col2im, conv_transpose2d_direct}, prng::random_uniform, }, tensor::JitTensor, @@ -26,23 +23,30 @@ pub fn conv_transpose2d_autotune( static TUNER: LocalTuner = local_tuner!(); + let tune_set = TunableSet::new(create_key::, create_transpose2d_input::) + .with_tunable(conv_transpose2d_direct::) + .with_tunable(conv_transpose2d_col2im::); + TUNER.execute( &JitTuneId::new::(&input.device), &client, - Box::new(ConvTranspose2dOperations::::new( - input, weights, bias, options, - )), + &tune_set, + (input, weights, bias, options), ) } -#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key::, should_run = should_run)] -pub fn conv_transpose2d_operations( - key: JitAutotuneKey, - input: JitTensor, - weights: JitTensor, - bias: Option>, - options: ConvTransposeOptions<2>, -) -> JitTensor { +pub fn create_transpose2d_input( + key: &JitAutotuneKey, + input: &JitTensor, + _weights: &JitTensor, + _bias: &Option>, + options: &ConvTransposeOptions<2>, +) -> ( + JitTensor, + JitTensor, + Option>, + ConvTransposeOptions<2>, +) { let key = match key { JitAutotuneKey::ConvTranspose2d(key) => key, _ => unreachable!(), @@ -60,7 +64,7 @@ pub fn conv_transpose2d_operations( let bias = key .has_bias .then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1)); - tune_with!(input, weights, bias, options) + (input, weights, bias, options.clone()) } fn create_key( @@ -94,20 +98,3 @@ fn create_key( E::dtype(), )) } - -fn should_run( - _op: &ConvTranspose2dOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let key = match key { - JitAutotuneKey::ConvTranspose2d(key) => key, - _ => unreachable!(), - }; - - match index { - // im2col - 1 => batches_per_run(key.batch_size, key.height, key.width).is_some(), - _ => true, - } -} diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index 157610794b..a616c432b9 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -41,7 +41,7 @@ fn conv3d_kernel( #[comptime] kernel_size_2_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index b22821aef1..300d714335 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -19,6 +19,8 @@ use crate::{ FloatElement, JitRuntime, }; +use super::ConvLaunchError; + #[derive(CubeLaunch)] struct DeformConv2dArgs { conv_stride_h: u32, @@ -262,7 +264,7 @@ pub(crate) fn deform_conv2d( mask: Option>, bias: Option>, options: DeformConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let offset = into_contiguous(offset); let weight = into_contiguous(weight); @@ -298,15 +300,15 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = matmul::(weight, columns, None, MatmulStrategy::default()); + let out = matmul::(weight, columns, None, MatmulStrategy::default())?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); - launch_binop::(out, bias) + Ok(launch_binop::(out, bias)) } else { - out + Ok(out) } } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index b75ac43182..5840f4dc9f 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -1,8 +1,13 @@ +use std::marker::PhantomData; + use burn_tensor::{ ops::{DeformConv2dBackward, DeformConvOptions, FloatTensorOps as _}, Shape, }; -use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch}; +use cubecl::{ + calculate_cube_count_elemwise, cube, ir::Elem, prelude::*, AtomicFeature, CubeDim, CubeLaunch, + Feature, +}; use crate::{ element::BoolElement, @@ -19,7 +24,7 @@ use crate::{ FloatElement, IntElement, JitBackend, JitRuntime, }; -use super::{bilinear_interpolate, deform_im2col, index}; +use super::{bilinear_interpolate, deform_im2col, index, ConvLaunchError}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. #[allow(clippy::single_range_in_vec_init)] @@ -36,7 +41,7 @@ pub(crate) fn deform_conv2d_backward< bias: Option>, out_grad: JitTensor, options: DeformConvOptions<2>, -) -> DeformConv2dBackward> { +) -> Result>, ConvLaunchError> { let [_, _, out_h, out_w] = out_grad.shape.dims(); let [_, _, kernel_h, kernel_w] = weight.shape.dims(); @@ -60,7 +65,7 @@ pub(crate) fn deform_conv2d_backward< out_grad.clone(), &options, (kernel_h, kernel_w), - ); + )?; let weight_grad = compute_weight_grad::( input, @@ -70,15 +75,15 @@ pub(crate) fn deform_conv2d_backward< options, (kernel_h, kernel_w), (out_h, out_w), - ); + )?; - DeformConv2dBackward::new( + Ok(DeformConv2dBackward::new( input_gradient, offset_gradient, weight_grad, mask_gradient, gradient_bias, - ) + )) } fn compute_weight_grad( @@ -89,7 +94,7 @@ fn compute_weight_grad( options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), -) -> JitTensor { +) -> Result, ConvLaunchError> { let [_, in_channels, _, _] = input.shape.dims(); let [_, out_channels, _, _] = out_grad.shape.dims(); let (kernel_h, kernel_w) = kernel_dims; @@ -108,12 +113,12 @@ fn compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default()); + let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default())?; - reshape( + Ok(reshape( grad_weight, Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), - ) + )) } type InputGradients = (JitTensor, JitTensor, Option>); @@ -126,7 +131,7 @@ fn backward_gradient_inputs( out_grad: JitTensor, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> InputGradients { +) -> Result, ConvLaunchError> { let client = out_grad.client.clone(); let device = out_grad.device.clone(); @@ -150,7 +155,7 @@ fn backward_gradient_inputs( for group in 0..groups { let weight = swap_dims(index::(weight.clone(), group), 0, 1); let out_grad = index::(out_grad.clone(), group); - let values = matmul::(weight, out_grad, None, MatmulStrategy::default()); + let values = matmul::(weight, out_grad, None, MatmulStrategy::default())?; let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); columns = slice_assign::( columns, @@ -169,12 +174,12 @@ fn backward_gradient_inputs( mask.clone(), options, kernel_dims, - ); + )?; let input_gradient = compute_input_grad::(columns, offset, mask, options, kernel_dims, input_shape); - (input_gradient, offset_gradient, mask_gradient) + Ok((input_gradient, offset_gradient, mask_gradient)) } fn compute_offset_and_mask_gradient( @@ -184,7 +189,7 @@ fn compute_offset_and_mask_gradient( mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> (JitTensor, Option>) { +) -> Result<(JitTensor, Option>), ConvLaunchError> { let client = offset.client.clone(); let device = offset.device.clone(); let (kernel_height, kernel_width) = kernel_dims; @@ -238,7 +243,7 @@ fn compute_offset_and_mask_gradient( }; let mask_gradient = if use_mask { Some(grad_mask) } else { None }; - (grad_offset, mask_gradient) + Ok((grad_offset, mask_gradient)) } #[derive(CubeLaunch)] @@ -270,7 +275,7 @@ fn deform_col2img_coord_kernel( // Alternatively : [batch, offset_channels, out_h, out_w] if ABSOLUTE_POS >= grad_offset.len() { - return; + terminate!(); } let offset_channels = offset.shape(1); @@ -439,15 +444,29 @@ fn compute_input_grad( let client = offset.client.clone(); let device = offset.device.clone(); + let kind = match E::as_elem_native_unchecked() { + Elem::Float(kind) => kind, + _ => unreachable!("Should be float"), + }; + let props = client.properties(); + + let supports_fadd = props.feature_enabled(Feature::AtomicFloat(AtomicFeature::Add)); + let supports_same_type = props.feature_enabled(Feature::Type(Elem::AtomicFloat(kind))); + let [batch_size, in_channels, height, width] = input_shape.dims(); let (kernel_height, kernel_width) = kernel_dims; - // Force `f32` to enable bitcasting as `u32` - let grad_in = zeros_device::( - client.clone(), - device.clone(), - Shape::new([batch_size, in_channels, height, width]), - ); + let shape = Shape::new([batch_size, in_channels, height, width]); + let grad_in = match supports_fadd && supports_same_type { + // Use type as is to save a cast + true => zeros_device::(client.clone(), device.clone(), shape), + // Force `f32` to enable bitcasting as `u32`, or use intrinsic when supported + false => zeros_device::(client.clone(), device.clone(), shape), + }; + let grad_arg = match supports_fadd && supports_same_type { + true => grad_in.as_tensor_arg::(1), + false => grad_in.as_tensor_arg::(1), + }; let use_mask = mask.is_some(); let mask = mask.unwrap_or_else(|| { @@ -458,43 +477,60 @@ fn compute_input_grad( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elements, cube_dim); - deform_col2img_kernel::launch::( - &offset.client, - cube_count, - cube_dim, - offset.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - columns.as_tensor_arg::(1), - grad_in.as_tensor_arg::(1), - DeformConv2dCol2ImgArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.padding[0] as f32), - ScalarArg::new(options.padding[1] as f32), - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(batch_size as u32), - ScalarArg::new(in_channels as u32), - ScalarArg::new(height as u32), - ScalarArg::new(width as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), - ), - use_mask, - ); - - cast::(grad_in) + let launch = match (supports_fadd, supports_same_type) { + // use same type intrinsic if supported + (true, true) => deform_col2img_kernel::launch_unchecked::, R>, + // use f32 intrinsic if float add is supported at all + (true, false) => { + deform_col2img_kernel::launch_unchecked::, R> + } + // fall back to compare and swap + _ => deform_col2img_kernel::launch_unchecked::, + }; + + unsafe { + launch( + &offset.client, + cube_count, + cube_dim, + offset.as_tensor_arg::(1), + mask.as_tensor_arg::(1), + columns.as_tensor_arg::(1), + grad_arg, + DeformConv2dCol2ImgArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(E::new(options.padding[0] as f32)), + ScalarArg::new(E::new(options.padding[1] as f32)), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(batch_size as u32), + ScalarArg::new(in_channels as u32), + ScalarArg::new(height as u32), + ScalarArg::new(width as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ), + use_mask, + ) + }; + + if !supports_same_type || !supports_fadd { + cast::(grad_in) + } else { + grad_in + } } #[derive(CubeLaunch)] -struct DeformConv2dCol2ImgArgs { +struct DeformConv2dCol2ImgArgs { stride_h: u32, stride_w: u32, dilation_h: u32, dilation_w: u32, - pad_h: f32, - pad_w: f32, + pad_h: F, + pad_w: F, offset_groups: u32, batch_size: u32, in_channels: u32, @@ -504,17 +540,19 @@ struct DeformConv2dCol2ImgArgs { kernel_width: u32, } -#[cube(launch)] -fn deform_col2img_kernel( +#[cube(launch_unchecked)] +fn deform_col2img_kernel( offset: &Tensor, mask: &Tensor, columns: &Tensor, - grad_input: &mut Tensor, - args: &DeformConv2dCol2ImgArgs, + grad_input: &mut Tensor>, + args: &DeformConv2dCol2ImgArgs, #[comptime] use_mask: bool, ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] - let _ = mask[0]; // Keep mask in bind group + if ABSOLUTE_POS >= columns.len() { + terminate!(); + } let n_in_channels = args.in_channels; let height = args.height; @@ -545,8 +583,8 @@ fn deform_col2img_kernel( let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x; let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; - let offset_y = f32::cast_from(offset[offset_base_idx + offset_y_idx]); - let offset_x = f32::cast_from(offset[offset_base_idx + offset_x_idx]); + let offset_y = offset[offset_base_idx + offset_y_idx]; + let offset_x = offset[offset_base_idx + offset_x_idx]; let mask_value = if use_mask { let mask_base_idx = @@ -558,47 +596,78 @@ fn deform_col2img_kernel( }; let y = - f32::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y; + F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y; let x = - f32::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x; + F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x; for dy in -1..=1 { #[unroll] for dx in -1..=1 { - let yp = f32::floor(y) + dy as f32; - let xp = f32::floor(x) + dx as f32; - - if yp >= 0.0 - && yp < height as f32 - && xp >= 0.0 - && xp < width as f32 - && f32::abs(y - yp) < 1.0 - && f32::abs(x - xp) < 1.0 + let yp = F::floor(y) + F::cast_from(dy); + let xp = F::floor(x) + F::cast_from(dx); + + if yp >= F::new(0.0) + && yp < F::cast_from(height) + && xp >= F::new(0.0) + && xp < F::cast_from(width) + && F::abs(y - yp) < F::new(1.0) + && F::abs(x - xp) < F::new(1.0) { let gradient_pos = - ((batch * n_in_channels + in_channel) * height + yp as u32) * width + xp as u32; + ((batch * n_in_channels + in_channel) * height + u32::cast_from(yp)) * width + + u32::cast_from(xp); - let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp)); + let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp)); let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS]; - float_atomic_add(&mut grad_input[gradient_pos], f32::cast_from(value)); + FAdd::float_atomic_add::(&mut grad_input[gradient_pos], value); } } } } #[cube] -fn float_atomic_add(ptr: &mut AtomicU32, value: f32) { - if value != 0.0 { - let mut v = AtomicU32::load(ptr); - loop { - let prev = v; - let v_float = f32::bitcast_from(v); - let new = u32::bitcast_from(v_float + value); - v = AtomicU32::compare_and_swap(ptr, v, new); - if prev == v { - break; +trait FloatAtomicAdd: Send + Sync + 'static { + type ProxyType: Numeric; + + fn float_atomic_add(ptr: &mut Atomic, value: F); +} + +#[derive(CubeType)] +struct IntrinsicFloatAtomicAdd { + _ty: PhantomData, +} + +#[derive(CubeType)] +struct CASFloatAtomicAdd; + +#[cube] +impl FloatAtomicAdd for IntrinsicFloatAtomicAdd { + type ProxyType = FAdd; + + fn float_atomic_add(ptr: &mut Atomic, value: F) { + let value = FAdd::cast_from(value); + Atomic::add(ptr, value); + } +} + +#[cube] +impl FloatAtomicAdd for CASFloatAtomicAdd { + type ProxyType = u32; + + fn float_atomic_add(ptr: &mut Atomic, value: F) { + let value = f32::cast_from(value); + if value != 0.0 { + let mut v = Atomic::load(ptr); + loop { + let prev = v; + let v_float = f32::bitcast_from(v); + let new = u32::bitcast_from(v_float + value); + v = Atomic::compare_and_swap(ptr, v, new); + if prev == v { + break; + } } } } diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs new file mode 100644 index 0000000000..2654a20e24 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -0,0 +1,47 @@ +use core::fmt::Debug; +use cubecl::{ + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, + tune::AutotuneError, +}; + +pub enum ConvLaunchError { + Matmul(MatmulLaunchError), + Groups(usize), + Unknown, +} + +impl Debug for ConvLaunchError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConvLaunchError::Matmul(err) => { + write!(f, "{err:?}") + } + ConvLaunchError::Groups(groups) => { + writeln!( + f, + "Unable to launch matmul because groups must be one, is actually {groups}", + ) + } + ConvLaunchError::Unknown => write!(f, "Unknown"), + } + } +} + +impl From for ConvLaunchError { + fn from(value: MatmulLaunchError) -> Self { + Self::Matmul(value) + } +} + +impl From for ConvLaunchError { + fn from(value: MatmulAvailabilityError) -> Self { + Self::Matmul(MatmulLaunchError::Unavailable(value)) + } +} + +#[allow(clippy::from_over_into)] +impl Into for ConvLaunchError { + fn into(self) -> AutotuneError { + AutotuneError::Unknown(format!("{self:?}")) + } +} diff --git a/crates/burn-jit/src/kernel/conv/mod.rs b/crates/burn-jit/src/kernel/conv/mod.rs index 5d6794495f..04794e9b42 100644 --- a/crates/burn-jit/src/kernel/conv/mod.rs +++ b/crates/burn-jit/src/kernel/conv/mod.rs @@ -3,11 +3,13 @@ mod conv3d; mod conv_transpose3d; mod deform_conv2d; mod deform_conv_transpose2d; +mod error; pub(crate) use conv2d::*; pub(crate) use conv3d::*; pub(crate) use conv_transpose3d::*; pub(crate) use deform_conv2d::*; pub(crate) use deform_conv_transpose2d::*; +pub(crate) use error::*; pub use conv2d::{conv2d, conv_transpose2d, nchw_to_nhwc, Conv2dStrategy, ConvTranspose2dStrategy}; diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 583e0346d3..a682a76eac 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -11,7 +11,7 @@ fn flip_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 9e9b5685bb..c1aa56072e 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -12,7 +12,7 @@ fn gather_kernel( dim: &u32, ) { if ABSOLUTE_POS >= indices.len() { - return; + terminate!(); } let index = indices[ABSOLUTE_POS]; diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index 3887bfbd8b..b19f9e2b21 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor, dim: u32) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 4ddd9c00fb..4cca94f824 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -46,7 +46,7 @@ fn scatter_kernel( let should_stop = ABSOLUTE_POS >= num_elems; if should_stop { - return; + terminate!(); } for i in 0..shape_value { diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index b104bf504f..fe664ab420 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -10,7 +10,7 @@ fn select_kernel( dim: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index a0fed49dbd..cd4c013f63 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -29,7 +29,7 @@ fn select_assign_kernel( } if ABSOLUTE_POS >= num_elems { - return; + terminate!(); } let strides_tensor_dim = tensor.stride(dim); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index 7f20f033b8..b6daba8da5 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -52,7 +52,7 @@ fn slice_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 1d545d79c7..3f77ef1302 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 3557fcdbb8..f0cb95b536 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 0743a13567..0e6ba32552 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 5ea860a7ae..f0442ec92e 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_backward_kernel(grad: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let out_h = output.shape(2); diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index 386e7a5039..95096c7994 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -16,7 +16,7 @@ fn mask_fill_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -35,7 +35,7 @@ fn mask_fill_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 5518e9648b..99384fde98 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -16,7 +16,7 @@ fn mask_where_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -36,7 +36,7 @@ fn mask_where_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 7fa141cf67..611f1e32d4 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,3 +1,5 @@ +use cubecl::linalg::matmul::kernels::MatmulLaunchError; + use super::init_matmul_output; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; @@ -30,7 +32,7 @@ pub fn matmul( rhs: JitTensor, out: Option>, strategy: MatmulStrategy, -) -> JitTensor { +) -> Result, MatmulLaunchError> { match strategy { MatmulStrategy::Cube => { let out = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); @@ -43,11 +45,11 @@ pub fn matmul( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ) - .unwrap(); - out + )?; + + Ok(out) } #[cfg(feature = "autotune")] - MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs, out), + MatmulStrategy::Autotune => Ok(matmul_autotune::(lhs, rhs, out)), } } diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 3f3232db10..dacd2693b9 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -1,10 +1,7 @@ use burn_tensor::{Element, ElementConversion}; use cubecl::{ - ir::{Elem, FloatKind}, linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy}, - tune, - tune::{local_tuner, tune_with, LocalTuner}, - Feature, + tune::{local_tuner, LocalTuner, TunableSet}, }; use crate::{ @@ -18,44 +15,19 @@ use crate::{ use super::key::create_key; -#[tune( - operations(matmul_tiling2d, matmul_accelerated, matmul_simple), - create_key = create_key::, - should_run = should_run -)] -fn matmul_ops( - key: JitAutotuneKey, - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) { +fn matmul_input_gen( + _key: &JitAutotuneKey, + lhs: &JitTensor, + rhs: &JitTensor, + out: &JitTensor, +) -> (JitTensor, JitTensor, JitTensor) { let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); let lhs = random_like_uniform(lhs, random_bounds.0, random_bounds.1); let rhs = random_like_uniform(rhs, random_bounds.0, random_bounds.1); let out = empty_device::(out.client.clone(), out.device.clone(), out.shape.clone()); - tune_with!(lhs, rhs, out) -} - -fn should_run( - op: &MatmulOps, - _key: &JitAutotuneKey, - index: usize, -) -> bool { - match index { - // Accelerated - // TODO: Add way to query actual requirements from cubecl - 1 => op.lhs.client.properties().feature_enabled(Feature::Cmma { - a: Elem::Float(FloatKind::F16), - b: Elem::Float(FloatKind::F16), - c: Elem::Float(FloatKind::F32), - m: 16, - k: 16, - n: 16, - }), - _ => true, - } + (lhs, rhs, out) } /// Executes autotune on matmul operations @@ -70,10 +42,16 @@ pub fn matmul_autotune( static TUNER: LocalTuner = local_tuner!(); + let tunables = TunableSet::new(create_key::, matmul_input_gen::) + .with_tunable(matmul_tiling2d::) + .with_tunable(matmul_accelerated::) + .with_tunable(matmul_simple::); + TUNER.execute( &JitTuneId::new::(&lhs.device), &client, - Box::new(MatmulOps::::new(lhs, rhs, output.clone())), + &tunables, + (lhs, rhs, output.clone()), ); output @@ -83,7 +61,7 @@ fn matmul_accelerated( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Standard, &lhs.client, @@ -91,14 +69,14 @@ fn matmul_accelerated( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } fn matmul_tiling2d( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Tiling2D(Tiling2dConfig::default()), &lhs.client, @@ -106,14 +84,14 @@ fn matmul_tiling2d( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } fn matmul_simple( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Simple, &lhs.client, @@ -121,5 +99,5 @@ fn matmul_simple( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } diff --git a/crates/burn-jit/src/kernel/matmul/tune/key.rs b/crates/burn-jit/src/kernel/matmul/tune/key.rs index d25cce3023..44cb079399 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/key.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/key.rs @@ -22,7 +22,7 @@ pub struct MatmulAutotuneKey { } impl MatmulAutotuneKey { - fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { + pub(crate) fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { let ndims = lhs_shape.num_dims(); let m = lhs_shape.dims[ndims - 2]; let k = lhs_shape.dims[ndims - 1]; diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index afa8ecd6fa..93d2833976 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -1,19 +1,26 @@ mod binary; +mod binary_int; mod cast; mod clamp; mod comparison; mod contiguous; mod index; mod mask; -mod unary; +mod unary_float; +mod unary_int; +mod unary_numeric; pub(crate) use binary::*; +pub(crate) use binary_int::*; pub use cast::*; pub use contiguous::*; pub use mask::*; -pub(crate) use unary::*; +pub(crate) use unary_float::*; +pub(crate) use unary_int::*; +pub(crate) use unary_numeric::*; -pub use cubecl::{Kernel, PLANE_DIM_APPROX}; +pub use burn_common::PLANE_DIM_APPROX; +pub use cubecl::Kernel; /// Convolution kernels pub mod conv; diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index bba68c7166..d2a5a21d0a 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -24,7 +24,7 @@ fn avg_pool2d_backward_kernel( #[comptime] count_include_pad: bool, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 6da6e2b37c..40259c4573 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -16,7 +16,7 @@ fn max_pool2d_with_indices_backward_kernel( #[comptime] kernel_size_1: i32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 72040d8839..270e32f854 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -48,7 +48,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( ) { // Last two positions contain the qparams if ABSOLUTE_POS >= input.len() - 2 { - return; + terminate!(); } let qparams = QParams::new(scheme); @@ -85,7 +85,7 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( ) { // Last position contains the qparam if ABSOLUTE_POS >= input.len() - 1 { - return; + terminate!(); } let qparams = QParams::new(scheme); diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e9494aa987..0a7b0ea553 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -34,7 +34,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -43,13 +43,13 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } // Cast the offset to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 2 { output[ABSOLUTE_POS] = u32::bitcast_from(offset); - return; + terminate!(); } let line_size = comptime!(input.line_size()); @@ -120,7 +120,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -128,7 +128,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } let line_size = comptime!(input.line_size()); diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 730cc83f37..ccfcc3ef9e 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,83 +1,167 @@ -use cubecl::prelude::Numeric; - #[cfg(feature = "autotune")] -use crate::kernel::reduce::reduce_dim_autotune; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; - -use super::{ - naive::{base::ReduceDimNaiveFamily, kernel::reduce_dim_naive}, - shared::{base::ReduceDimShared, kernel::reduce_dim_shared}, - subcube::{base::ReduceDimSubcube, kernel::reduce_dim_subcube}, +use super::{autotune_reduce, autotune_sum}; +use crate::{ + element::JitElement, + ops::{from_data, numeric::empty_device}, + tensor::JitTensor, + JitRuntime, }; +use burn_tensor::{Shape, TensorData}; +pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; +use cubecl::reduce::shared_sum; -#[allow(dead_code)] -pub(crate) trait ReduceDimAlgorithm: - core::fmt::Debug + ReduceDimNaiveFamily + ReduceDimShared + ReduceDimSubcube -{ -} - -/// Creates an empty output tensor with reduce output shape -pub fn init_reduce_output( - input: &JitTensor, - reduce_dim: usize, -) -> JitTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[reduce_dim] = 1; +/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return +/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`. +/// +/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction. +/// +/// Return an error if the `client` doesn't support atomic add for the type `E`. +pub fn sum( + tensor: JitTensor, + cube_count: SumStrategy, +) -> Result, cubecl::reduce::ReduceError> { + let client = tensor.client.clone(); + let device = tensor.device.clone(); - empty_device::(input.client.clone(), input.device.clone(), shape_out) + match cube_count { + SumStrategy::OneShot(cube_count) => { + let output = shared_sum::(&client, tensor.as_handle_ref(), cube_count)?; + Ok(from_data::( + TensorData::new(vec![output], vec![1]), + &device, + )) + } + SumStrategy::Chained(strategy) => reduce::(tensor, strategy), + #[cfg(feature = "autotune")] + SumStrategy::Autotune => Ok(autotune_sum::(&client, tensor)), + } } -#[derive(Copy, Clone, Debug)] -#[allow(missing_docs)] -pub enum ReduceStrategy { - /// Naive - Naive, - /// Use shared memory as an accumulator - SharedMemory, - /// Use subcube functions - Subcube, +/// Select a strategy to perform a sum. +pub enum SumStrategy { + /// Run a single kernel with many cubes working in parallel to sum all elements. + /// The provided value is the number of elements summed per unit (up-to-rounding ) + OneShot(u32), + /// Use multiple kernels + Chained(ReduceStrategy), + /// Use autotune to find the best cube count given the hardware and the input. #[cfg(feature = "autotune")] Autotune, } -impl Default for ReduceStrategy { +impl Default for SumStrategy { fn default() -> Self { - // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] - return ReduceStrategy::Autotune; + return Self::Autotune; #[cfg(not(feature = "autotune"))] - ReduceStrategy::Naive + return Self::OneShot(4); } } -macro_rules! reduce_operation { - ($name:ident, $ops:ident) => { - #[derive(Debug)] - pub(crate) struct $ops; +/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). +/// +/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. +/// +/// If there is no error, the output is a tensor with decreasing strides +/// where the shape of reduced dim is set to 1 but all shape are similar to the input. +pub fn reduce( + mut tensor: JitTensor, + strategy: ReduceStrategy, +) -> Result, cubecl::reduce::ReduceError> { + // In practice, it looks like starting by the axis with the smallest shape + // and going in increasing order lead to the fastest calculation. + let sorted_axis = argsort(&tensor.shape.dims); + for axis in sorted_axis { + tensor = reduce_dim::(tensor, axis, strategy)?; + } + // reshape to scalar tensor + tensor.shape = Shape::new([1]); + tensor.strides = vec![1]; + Ok(tensor) +} - impl ReduceDimAlgorithm for $ops {} +fn argsort(shape: &[usize]) -> Vec { + let mut indices = (0..shape.len()).collect::>(); + indices.sort_by_key(|&i| &shape[i]); + indices +} - /// Executes the reduce operation with the given strategy. - pub fn $name( - tensor: JitTensor, - dim: usize, - strategy: ReduceStrategy, - ) -> JitTensor { - match strategy { - ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), - #[cfg(feature = "autotune")] - ReduceStrategy::Autotune => reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim), - } +/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). +/// +/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. +/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. +/// +/// If there is no error, the output is a tensor with decreasing strides +/// where the shape of reduced dim is set to 1 but all shape are similar to the input. +pub fn reduce_dim( + input: JitTensor, + dim: usize, + strategy: ReduceStrategy, +) -> Result, cubecl::reduce::ReduceError> { + let client = input.client.clone(); + let output = init_reduce_output::(&input, dim).ok_or( + cubecl::reduce::ReduceError::InvalidAxis { + axis: dim, + rank: input.shape.num_dims(), + }, + )?; + let result = match strategy { + ReduceStrategy::Unspecified => cubecl::reduce::reduce::( + &client, + input.as_handle_ref(), + output.as_handle_ref(), + dim, + None, + ), + ReduceStrategy::Specific(strategy) => cubecl::reduce::reduce::( + &client, + input.as_handle_ref(), + output.as_handle_ref(), + dim, + Some(strategy), + ), + #[cfg(feature = "autotune")] + ReduceStrategy::Autotune => { + autotune_reduce::(&client, input, output.clone(), dim); + Ok(()) } }; + result.map(|_| output) } -// Autotunable reduce operation variants -reduce_operation!(sum_dim, SumDim); -reduce_operation!(mean_dim, MeanDim); -reduce_operation!(prod_dim, ProdDim); -reduce_operation!(argmin, Argmin); -reduce_operation!(argmax, Argmax); +/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input` +/// or return `None` if `axis` is out-of-bound. +pub fn init_reduce_output( + input: &JitTensor, + dim: usize, +) -> Option> { + (dim < input.shape.num_dims()).then(|| { + let mut shape_out = input.shape.clone(); + shape_out.dims[dim] = 1; + empty_device::(input.client.clone(), input.device.clone(), shape_out) + }) +} + +/// Select a strategy to perform a reduction. +#[derive(Copy, Clone, Debug)] +pub enum ReduceStrategy { + /// Use a best-effort strategy based on the hardware capacity. + /// This differs from Autotune as it doesn't try and compare many strategies to select the best. + Unspecified, + /// Fix the exact strategy for the reduction. + Specific(cubecl::reduce::ReduceStrategy), + /// Use autotune to find the best strategy given the hardware and the inputs. + #[cfg(feature = "autotune")] + Autotune, +} + +impl Default for ReduceStrategy { + fn default() -> Self { + #[cfg(feature = "autotune")] + return Self::Autotune; + + #[cfg(not(feature = "autotune"))] + return Self::Unspecified; + } +} diff --git a/crates/burn-jit/src/kernel/reduce/mod.rs b/crates/burn-jit/src/kernel/reduce/mod.rs index 2401f9467e..8ff38a9da7 100644 --- a/crates/burn-jit/src/kernel/reduce/mod.rs +++ b/crates/burn-jit/src/kernel/reduce/mod.rs @@ -1,12 +1,5 @@ mod base; -mod naive; -mod prod; -mod shared; -mod subcube; -mod sum; mod tune; pub use base::*; -pub use prod::*; -pub use sum::*; pub use tune::*; diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs deleted file mode 100644 index d577d3decf..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs +++ /dev/null @@ -1,36 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmax; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for Argmax { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for Argmax { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (EI::min_value(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (max, index) = accumulator; - if current_value > *max { - *max = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs deleted file mode 100644 index 2302a2b205..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs +++ /dev/null @@ -1,36 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmin; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for Argmin { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for Argmin { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (EI::max_value(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (min, index) = accumulator; - if current_value < *min { - *min = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/base.rs b/crates/burn-jit/src/kernel/reduce/naive/base.rs deleted file mode 100644 index 7512103ebb..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/base.rs +++ /dev/null @@ -1,25 +0,0 @@ -use cubecl::prelude::*; - -pub trait ReduceDimNaiveFamily: Send + Sync + 'static { - type Reduce: ReduceDimNaive; -} - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimNaive: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - - /// Initialization for naive algorithm - fn initialize_naive() -> Self::Accumulator; - - /// Inner loop for naive algorithm - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32); - - /// Assignation for naive algorithm - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - shape_reduce_dim: u32, - ); -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs deleted file mode 100644 index a3a1a5441b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::{ - element::JitElement, kernel::reduce::init_reduce_output, tensor::JitTensor, JitRuntime, -}; -use cubecl::calculate_cube_count_elemwise; -use cubecl::prelude::*; - -use super::base::ReduceDimNaive; -use super::base::ReduceDimNaiveFamily; - -#[cube(launch_unchecked)] -pub(crate) fn naive_reduce_dim_kernel( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - naive_reduce::, EI, EO>(input, output, dim) -} - -#[cube] -fn naive_reduce, EI: Numeric, EO: Numeric>( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - if ABSOLUTE_POS >= output.len() { - return; - } - - let mut offset_input = 0; - - for i in 0..input.rank() { - let mut offset_local = ABSOLUTE_POS / output.stride(i); - offset_local %= output.shape(i); - if i != dim { - offset_input += offset_local * input.stride(i); - } - } - - let mut accumulator = RD::initialize_naive(); - - for i in 0..input.shape(dim) { - let index = i * input.stride(dim) + offset_input; - RD::inner_loop_naive(&mut accumulator, input[index], i); - } - - RD::assign_naive::(output, accumulator, input.shape(dim)); -} - -/// Executes the naive kernel for reduce dim -pub fn reduce_dim_naive( - input: JitTensor, - dim: usize, -) -> JitTensor { - let output = init_reduce_output::(&input, dim); - - let cube_dim = CubeDim::default(); - let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); - - unsafe { - naive_reduce_dim_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - ScalarArg::new(dim as u32), - ); - } - - output -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs deleted file mode 100644 index 774c9b251c..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs +++ /dev/null @@ -1,27 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::MeanDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for MeanDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for MeanDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(0) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, shape_reduce_dim: u32) { - let mean = accumulator / EI::cast_from(shape_reduce_dim); - output[ABSOLUTE_POS] = EO::cast_from(mean); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mod.rs b/crates/burn-jit/src/kernel/reduce/naive/mod.rs deleted file mode 100644 index b11ee5e2da..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs deleted file mode 100644 index 1ea52a149c..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::ProdDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for ProdDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for ProdDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(1) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator *= current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs deleted file mode 100644 index 7168e07ff3..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::SumDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for SumDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for SumDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(0) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/prod.rs b/crates/burn-jit/src/kernel/reduce/prod.rs deleted file mode 100644 index 77227bae6f..0000000000 --- a/crates/burn-jit/src/kernel/reduce/prod.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{prod_dim, ReduceStrategy}; - -/// Multiply all elements in the input buffer. -pub fn prod( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - prod_dim::(input, 0, strategy) -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs deleted file mode 100644 index 43c03c09ce..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::kernel::reduce::Argmax; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::min_value(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value > values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs deleted file mode 100644 index 0e47693c5a..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ /dev/null @@ -1,64 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmin; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::max_value(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value < values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/base.rs b/crates/burn-jit/src/kernel/reduce/shared/base.rs deleted file mode 100644 index 256123fe1b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/base.rs +++ /dev/null @@ -1,33 +0,0 @@ -use cubecl::prelude::*; - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimShared: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - type Value: CubeType; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> Self::Accumulator; - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut Self::Accumulator, - write_position: u32, - value: Self::Value, - ); - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> Self::Value; - - /// How to read from shared memory - fn read_from_shared(shared_memory: &Self::Accumulator, read_position: u32) -> Self::Value; - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &Self::Accumulator, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ); -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs deleted file mode 100644 index 1b2dcb356e..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ /dev/null @@ -1,117 +0,0 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::init_reduce_output, tensor::JitTensor, JitElement, JitRuntime}; - -use super::base::ReduceDimShared; - -#[cube(launch)] -pub fn reduce_dim_shared_kernel< - RD: ReduceDimShared, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] smem_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let mut shared_memory = RD::initialize_shared(smem_size, UNIT_POS); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } else { - if nth < shape_reduce_dim_input { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } - } - } - - sync_units(); - - let mut n_threads = CUBE_DIM; - - while n_threads > 1 { - n_threads /= 2; - - if UNIT_POS < n_threads { - let read_pos = n_threads + UNIT_POS; - let read_value = RD::read_from_shared(&shared_memory, read_pos); - RD::write_to_shared(&mut shared_memory, UNIT_POS, read_value); - } - - sync_units(); - } - - if UNIT_POS == 0 { - RD::assign_shared( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_shared< - RD: ReduceDimShared, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> JitTensor { - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim::default(); - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_shared_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - cube_dim.num_elems(), - elems_per_thread, - divisible_shape, - ); - - output -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs deleted file mode 100644 index eef8f5f478..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::kernel::reduce::MeanDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ) { - let mean = shared_memory[0] / EIn::cast_from(shape_reduce_dim); - output[write_position] = EOut::cast_from(mean); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/mod.rs b/crates/burn-jit/src/kernel/reduce/shared/mod.rs deleted file mode 100644 index b11ee5e2da..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs deleted file mode 100644 index 594f2fec11..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::kernel::reduce::ProdDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(1); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] *= value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs deleted file mode 100644 index 476dd554a4..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::kernel::reduce::SumDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs b/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs deleted file mode 100644 index c8e567e816..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::Argmax; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::min_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Max::max(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let max = plane_max(val); - - if max == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs b/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs deleted file mode 100644 index b7950ebfe2..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::Argmin; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::max_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Min::min(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let min = plane_min(val); - - if min == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/base.rs b/crates/burn-jit/src/kernel/reduce/subcube/base.rs deleted file mode 100644 index f20e538914..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/base.rs +++ /dev/null @@ -1,15 +0,0 @@ -use cubecl::prelude::*; - -#[cube] -pub trait ReduceDimSubcube: Send + Sync + 'static { - type Accumulator: CubeType; - type Value: CubeType; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator; - fn init_value() -> Self::Value; - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value; - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value; - fn update_value(current: &mut Self::Value, new: Self::Value); - fn reduce_subcube(acc: &mut Self::Accumulator, pos: u32, value: Self::Value); - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_len: u32); -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs deleted file mode 100644 index 4a32b5d641..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ /dev/null @@ -1,134 +0,0 @@ -use cubecl::{prelude::*, CubeCount, CubeDim, Feature}; - -use crate::{ - kernel::reduce::{init_reduce_output, shared::kernel::reduce_dim_shared, ReduceDimAlgorithm}, - tensor::JitTensor, - JitElement, JitRuntime, -}; - -use super::base::ReduceDimSubcube; - -#[cube(launch)] -pub fn reduce_dim_subcube_kernel< - RD: ReduceDimSubcube, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] subcube_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let should_unroll = elems_per_thread <= 8; - - let warp_id = plane_broadcast(UNIT_POS / PLANE_DIM, 0); - - let mut shared_memory = RD::init_shared(subcube_size); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - let mut value = RD::init_value(); - - #[unroll(should_unroll)] - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - let current_pos = nth * stride_reduce_dim_input + index_offset; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } else { - if nth < shape_reduce_dim_input { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } - } - } - - RD::reduce_subcube(&mut shared_memory, warp_id, value); - - sync_units(); - - if UNIT_POS >= PLANE_DIM { - return; - } - - let value = RD::read_from_shared(&shared_memory, UNIT_POS); - RD::reduce_subcube(&mut shared_memory, 0, value); - - if UNIT_POS == 0 { - RD::store( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_subcube< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> JitTensor { - let topology = input.client.properties().hardware_properties(); - - if !input.client.properties().feature_enabled(Feature::Plane) - || topology.plane_size_min != topology.plane_size_max - { - return reduce_dim_shared::(input, dim); - } - - let subcube_size = topology.plane_size_min; - - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim { - x: subcube_size, - y: subcube_size, - z: 1, - }; - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_subcube_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - subcube_size, - elems_per_thread, - divisible_shape, - ); - - output -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs deleted file mode 100644 index fb8c0b41d6..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs +++ /dev/null @@ -1,45 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::MeanDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::cast_from(0u32) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_length: u32) { - let denom = EIn::cast_from(dim_length); - out[pos] = EOut::cast_from(acc[0] / denom); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/mod.rs b/crates/burn-jit/src/kernel/reduce/subcube/mod.rs deleted file mode 100644 index 183c1e2daf..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod argmax; -pub mod argmin; -pub mod base; -pub mod kernel; -pub mod mean_dim; -pub mod prod_dim; -pub mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs deleted file mode 100644 index cccec95167..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::ProdDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::from_int(1) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current *= new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let prod = plane_prod(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = prod; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs deleted file mode 100644 index 1059432eb2..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::SumDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::cast_from(0u32) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/sum.rs b/crates/burn-jit/src/kernel/reduce/sum.rs deleted file mode 100644 index fea80bccf0..0000000000 --- a/crates/burn-jit/src/kernel/reduce/sum.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{sum_dim, ReduceStrategy}; - -/// Sum all elements in the input buffer. -pub fn sum( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - sum_dim::(input, 0, strategy) -} diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs new file mode 100644 index 0000000000..cd5cd61157 --- /dev/null +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -0,0 +1,294 @@ +#![allow(missing_docs)] + +use burn_tensor::ElementConversion; +use cubecl::{ + client::ComputeClient, + tune::{local_tuner, LocalTuner, TunableSet}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, + JitAutotuneKey, JitElement, JitRuntime, JitTuneId, +}; + +/// Executes autotune on reduce operations. +pub fn autotune_reduce< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, +>( + client: &ComputeClient, + input: JitTensor, + output: JitTensor, + dim: usize, +) { + use reduce_ops::*; + + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key::, reduce_input_gen::) + .with_tunable(reduce::) + .with_tunable(reduce_shared::) + .with_tunable(reduce_plane::) + .with_tunable(reduce_shared_plane::); + + TUNER.execute( + &JitTuneId::new::(&input.device), + client, + &tunables, + (input, output, dim), + ); +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of reduce versions +pub struct ReduceAutotuneKey { + dtype: burn_tensor::DType, + #[autotune(anchor)] + reduce_axis_shape: usize, + #[autotune(anchor)] + reduce_axis_stride: usize, + #[autotune(anchor)] + outer_axes_product: usize, // The product of the shapes of all axes with greater strides. +} + +impl ReduceAutotuneKey { + pub(crate) fn generate(input: &JitTensor, axis: usize) -> Self { + let rank = input.shape.num_dims(); + + if axis > rank { + panic!("axis {axis} is out-of-bound for a rank of {rank}"); + } + + let dtype = input.dtype; + let reduce_axis_shape = input.shape.dims[axis]; + let reduce_axis_stride = input.strides[axis]; + + let outer_axes_product = input + .strides + .iter() + .zip(input.shape.dims.iter()) + .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape)) + .product(); + + Self::new( + dtype, + reduce_axis_shape, + reduce_axis_stride, + outer_axes_product, + ) + } +} + +pub(crate) fn create_key( + input: &JitTensor, + _output: &JitTensor, + dim: &usize, +) -> JitAutotuneKey { + JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim)) +} + +mod reduce_ops { + #![allow(missing_docs)] + + use super::*; + + pub(crate) fn reduce_input_gen( + _key: &JitAutotuneKey, + input: &JitTensor, + output: &JitTensor, + dim: &usize, + ) -> (JitTensor, JitTensor, usize) { + let random_bounds: (In, In) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); + let input = random_like_uniform(input, random_bounds.0, random_bounds.1); + + let output = empty_device::( + output.client.clone(), + output.device.clone(), + output.shape.clone(), + ); + + (input, output, *dim) + } + + pub(crate) fn reduce< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: false, + use_planes: false, + }), + ) + .map_err(|e| format!("{e}")) + } + + pub(crate) fn reduce_shared< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: true, + use_planes: false, + }), + ) + .map_err(|e| format!("{e}")) + } + + pub(crate) fn reduce_plane< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: false, + use_planes: true, + }), + ) + .map_err(|e| format!("{e}")) + } + + pub(crate) fn reduce_shared_plane< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: true, + use_planes: true, + }), + ) + .map_err(|e| format!("{e}")) + } +} + +/// Executes autotune on reduce operations. +#[cfg(feature = "autotune")] +pub fn autotune_sum( + client: &ComputeClient, + input: JitTensor, +) -> JitTensor { + use sum_ops::*; + + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key_sum::, sum_input_gen::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_chained::); + + TUNER.execute( + &JitTuneId::new::(&input.device), + client, + &tunables, + input, + ) +} + +pub(crate) fn create_key_sum(input: &JitTensor) -> JitAutotuneKey { + JitAutotuneKey::Sum(SumAutotuneKey::generate(input)) +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of sum versions +pub struct SumAutotuneKey { + dtype: burn_tensor::DType, + #[autotune(anchor)] + length: usize, +} + +impl SumAutotuneKey { + pub(crate) fn generate(input: &JitTensor) -> Self { + let dtype = input.dtype; + let length = input.shape.num_elements(); + Self { dtype, length } + } +} +mod sum_ops { + #![allow(missing_docs)] + + use burn_tensor::TensorData; + use cubecl::reduce::instructions::Sum; + + use crate::ops::from_data; + + use super::*; + + pub(crate) fn sum_input_gen( + _key: &JitAutotuneKey, + input: &JitTensor, + ) -> JitTensor { + let random_bounds: (E, E) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); + random_like_uniform(input, random_bounds.0, random_bounds.1) + } + + pub(crate) fn sum_one_shot( + input: JitTensor, + ) -> Result, String> { + let device = input.device.clone(); + cubecl::reduce::shared_sum::(&input.client, input.as_handle_ref(), C) + .map(|output| from_data::(TensorData::new(vec![output], vec![1]), &device)) + .map_err(|e| e.to_string()) + } + + #[cfg(feature = "autotune")] + pub(crate) fn sum_chained( + input: JitTensor, + ) -> Result, String> { + crate::kernel::reduce::reduce::( + input, + crate::kernel::reduce::ReduceStrategy::Autotune, + ) + .map_err(|e| e.to_string()) + } +} diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs deleted file mode 100644 index f52bfd7ca0..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ /dev/null @@ -1,94 +0,0 @@ -use burn_tensor::{Element, ElementConversion}; -use cubecl::tune::{local_tuner, tune_with, LocalTuner}; -use cubecl::{tune, Feature}; - -use crate::{ - element::JitElement, - kernel::{ - prng::random_like_uniform, - reduce::{ - naive::kernel::reduce_dim_naive, shared::kernel::reduce_dim_shared, - subcube::kernel::reduce_dim_subcube, ReduceDimAlgorithm, - }, - }, - tensor::JitTensor, - tune_key::JitAutotuneKey, - JitRuntime, JitTuneId, -}; - -use super::create_key; - -/// Set of reduce_dim implementations available for autotune -/// Autotune key is given by concatenating the closest upper power of 2 of -/// dim to reduce, and product of others -#[tune( - operations(reduce_dim_naive, reduce_dim_shared, reduce_dim_subcube), - create_key = create_key::, - should_run = should_run -)] -pub fn reduce_dim_operations< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - key: JitAutotuneKey, - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let random_bounds: (EI, EI) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(input, random_bounds.0, random_bounds.1); - - tune_with!(input, reduce_dim) -} - -/// Executes autotune on reduce_dim operation -pub(crate) fn reduce_dim_autotune< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let client = input.client.clone(); - - let id = JitTuneId::new::(&input.device); - - let operation_set = Box::new(ReduceDimOperations::::new(input, reduce_dim)); - - static TUNER: LocalTuner = local_tuner!(); - - TUNER.execute(&id, &client, operation_set) -} - -fn should_run< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - op: &ReduceDimOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let JitAutotuneKey::ReduceDim(key) = key else { - unreachable!() - }; - - match index { - // Naive - 0 => key.reduce_dim_length <= 8192, - // Shared - 1 => key.reduce_dim_length >= 16, - // Subcube - 2 => { - let props = op.input.client.properties(); - let hardware = props.hardware_properties(); - props.feature_enabled(Feature::Plane) - && hardware.plane_size_min == hardware.plane_size_max - } - _ => true, - } -} diff --git a/crates/burn-jit/src/kernel/reduce/tune/key.rs b/crates/burn-jit/src/kernel/reduce/tune/key.rs deleted file mode 100644 index 3634022bc7..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/key.rs +++ /dev/null @@ -1,39 +0,0 @@ -use cubecl::AutotuneKey; -use serde::{Deserialize, Serialize}; - -use burn_tensor::DType; - -use crate::{tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime}; - -/// Autotune key representative of reduce versions -#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] -pub struct ReduceAutotuneKey { - #[autotune(anchor)] - pub(crate) reduce_dim_length: usize, - #[autotune(anchor)] - pub(crate) reduce_dim_stride: usize, - #[autotune(anchor)] - pub(crate) others_product: usize, - dtype: DType, -} - -pub(crate) fn create_key( - input: &JitTensor, - reduce_dim: &usize, -) -> JitAutotuneKey { - let dims = &input.shape.dims; - let reduce_dim = *reduce_dim; - - let mut others_product = 1; - for (d, len) in dims.iter().enumerate() { - if d != reduce_dim { - others_product *= len - } - } - JitAutotuneKey::ReduceDim(ReduceAutotuneKey::new( - dims[reduce_dim], - input.strides[reduce_dim], - others_product, - EI::dtype(), - )) -} diff --git a/crates/burn-jit/src/kernel/reduce/tune/mod.rs b/crates/burn-jit/src/kernel/reduce/tune/mod.rs deleted file mode 100644 index aee5569b6b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(feature = "autotune")] -mod base; -mod key; - -#[cfg(feature = "autotune")] -pub(crate) use base::*; -pub use key::*; diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs deleted file mode 100644 index 09f9c77689..0000000000 --- a/crates/burn-jit/src/kernel/unary.rs +++ /dev/null @@ -1,158 +0,0 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use cubecl::{ - calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, - tensor_vectorization_factor, unexpanded, -}; - -#[cube] -pub(crate) trait UnaryOp: 'static + Send + Sync { - type Options: LaunchArg; - - /// Execute a unary operation. - fn execute(_input: Line, _options: &Self::Options) -> Line { - unexpanded!(); - } -} - -#[cube(launch)] -pub(crate) fn unary_kernel>( - input: &Tensor>, - output: &mut Tensor>, - options: &O::Options, - #[comptime] rank: Option, - #[comptime] to_contiguous: bool, -) { - let offset_output = ABSOLUTE_POS; - - if offset_output >= output.len() { - return; - } - - if to_contiguous { - let offset_input = index_offset_with_layout::( - input, - output, - offset_output, - 0, - rank.unwrap_or_else(|| output.rank()), - rank.is_some(), - ); - - output[offset_output] = O::execute(input[offset_input], options); - } else { - output[offset_output] = O::execute(input[offset_output], options); - } -} - -pub(crate) fn launch_unary, F>( - tensor: JitTensor, - options: F, -) -> JitTensor -where - // Magic fix for lifetime, the closure is supposed to capture everything required to create the - // argument. - for<'a> F: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, -{ - let ndims = tensor.shape.num_dims(); - // Vectorization is only enabled when the last dimension is contiguous. - let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1); - - let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); - - let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let is_contiguous = tensor.is_contiguous(); - - if tensor.can_mut() && tensor.is_contiguous_buffer() { - unary_kernel::launch::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - TensorArg::alias(0), - options(&()), - None, - false, - ); - - tensor - } else { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); - - unary_kernel::launch::( - &client, - cube_count, - CubeDim::default(), - tensor.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), - options(&()), - Some(ndims as u32), - !is_contiguous, - ); - output - } -} - -macro_rules! unary_op { - ($name:ident, $elem:ident, $expand:expr) => { - struct $name; - - impl UnaryOp for $name { - type Options = (); - - #[allow(clippy::redundant_closure_call)] - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - _options: ::ExpandType, - ) -> as CubeType>::ExpandType { - $expand(context, input) - } - } - }; - (scalar $name:ident, $elem:ident, $expand:expr) => { - struct $name; - - impl UnaryOp for $name { - type Options = C; - - #[allow(clippy::redundant_closure_call)] - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - scalar: C::ExpandType, - ) -> as CubeType>::ExpandType { - $expand(context, input, scalar) - } - } - }; - (float($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Float, $exp); - launch_unary::($tensor, |_| ()) - }}; - (int($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Numeric, $exp); - launch_unary::($tensor, |_| ()) - }}; - (numeric($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Numeric, $exp); - launch_unary::($tensor, |_| ()) - }}; - (numeric($tensor:expr, $scalar:expr) => $exp:expr) => {{ - unary_op!(scalar Op, Numeric, $exp); - launch_unary::($tensor, |_| ScalarArg::new($scalar)) - }}; - (float($tensor:expr, $scalar:expr) => $exp:expr) => {{ - unary_op!(scalar Op, Float, $exp); - launch_unary::($tensor, |_| ScalarArg::new($scalar)) - }}; -} - -pub(crate) use unary_op; diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs new file mode 100644 index 0000000000..4664d3c0b3 --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -0,0 +1,181 @@ +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: FloatUnaryOp>; +} + +#[cube] +pub(crate) trait FloatUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_float( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + terminate!(); + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_float(tensor: JitTensor, args: Args) -> JitTensor +where + // Magic fix for lifetime, the closure is supposed to capture everything required to create the + // argument. + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: JitElement + Float, + O: FloatUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_float::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_float::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +/// Use comptime enum to implement all unary operations that don't have any input argument in the +/// kernel definition. +pub(crate) mod unary_basic { + use crate::execute_with_dtype; + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicFloatUnaryKind, + { + execute_with_dtype!( + float(tensor.dtype), + F, + launch_unary_float::(tensor, |input| { + BasicFloatUnaryOptionsLaunch::new(args(input)) + }) + ) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicFloatUnaryKind { + Exp, + Log, + Log1p, + Sqrt, + Abs, + Cos, + Sin, + Tanh, + Round, + Floor, + Ceil, + Erf, + Recip, + } + + #[derive(CubeLaunch)] + struct BasicFloatUnaryOptions { + #[cube(comptime)] + kind: BasicFloatUnaryKind, + } + struct BasicFloatUnary; + + #[cube] + impl FloatUnaryOp for BasicFloatUnary { + type Options = BasicFloatUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicFloatUnaryKind::Exp => Line::exp(input), + BasicFloatUnaryKind::Log => Line::log(input), + BasicFloatUnaryKind::Log1p => Line::log1p(input), + BasicFloatUnaryKind::Sqrt => Line::sqrt(input), + BasicFloatUnaryKind::Abs => Line::abs(input), + BasicFloatUnaryKind::Cos => Line::cos(input), + BasicFloatUnaryKind::Sin => Line::sin(input), + BasicFloatUnaryKind::Tanh => Line::tanh(input), + BasicFloatUnaryKind::Round => Line::round(input), + BasicFloatUnaryKind::Floor => Line::floor(input), + BasicFloatUnaryKind::Ceil => Line::ceil(input), + BasicFloatUnaryKind::Erf => Line::erf(input), + BasicFloatUnaryKind::Recip => Line::recip(input), + } + } + } + + impl FloatUnaryOpFamily for BasicFloatUnary { + type Options = BasicFloatUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs new file mode 100644 index 0000000000..17bced52d1 --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -0,0 +1,148 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: IntUnaryOp>; +} + +#[cube] +pub(crate) trait IntUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_int( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + terminate!(); + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_int(tensor: JitTensor, args: Args) -> JitTensor +where + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: IntElement + Int, + O: IntUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +pub(crate) mod unary_basic_int { + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind, + I: IntElement, + { + launch_unary_int::(tensor, |input| { + BasicIntUnaryOptionsLaunch::new(args(input)) + }) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicIntUnaryKind { + BitwiseNot, + } + + #[derive(CubeLaunch)] + struct BasicIntUnaryOptions { + #[cube(comptime)] + kind: BasicIntUnaryKind, + } + struct BasicIntUnary; + + #[cube] + impl IntUnaryOp for BasicIntUnary { + type Options = BasicIntUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicIntUnaryKind::BitwiseNot => Line::bitwise_not(input), + } + } + } + + impl IntUnaryOpFamily for BasicIntUnary { + type Options = BasicIntUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs new file mode 100644 index 0000000000..aaeadbb685 --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -0,0 +1,106 @@ +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: NumericUnaryOp>; +} + +#[cube] +pub(crate) trait NumericUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_numeric( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + terminate!(); + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_numeric(tensor: JitTensor, args: Args) -> JitTensor +where + // Magic fix for lifetime, the closure is supposed to capture everything required to create the + // argument. + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: JitElement + Numeric, + O: NumericUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_numeric::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_numeric::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index bce600604e..645aaf1535 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -17,7 +17,8 @@ pub(crate) async fn into_data(tensor: JitTensor let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; - TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) + let actual_len = tensor.shape.num_elements() * size_of::(); + TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape) } /// Read data from a `JitTensor` synchronously @@ -26,7 +27,8 @@ pub fn into_data_sync(tensor: JitTensor) -> Ten let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one(tensor.handle.binding()); - TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) + let actual_len = tensor.shape.num_elements() * size_of::(); + TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape) } pub(crate) async fn bool_into_data( @@ -34,8 +36,9 @@ pub(crate) async fn bool_into_data( ) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; + let actual_len = tensor.shape.num_elements() * size_of::(); TensorData::new( - BT::from_bytes(&bytes) + BT::from_bytes(&bytes[..actual_len]) .iter() .map(|i| *i != BT::false_val()) .collect(), diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 8cf10292d7..0bacf75045 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -1,6 +1,9 @@ use super::{expand, numeric, permute}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; -use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; +use crate::kernel::unary_basic::BasicFloatUnaryKind; +use crate::kernel::{ + self, launch_unary_float, reduce, unary_basic, FloatUnaryOp, FloatUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::matmul::{matmul, MatmulStrategy}, @@ -162,7 +165,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - matmul::(lhs, rhs, None, MatmulStrategy::default()) + matmul::(lhs, rhs, None, MatmulStrategy::default()).unwrap() ) } @@ -352,7 +355,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum::(tensor, Default::default()) + reduce::sum::(tensor, Default::default()).unwrap() ) } @@ -360,7 +363,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum_dim::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -368,7 +371,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::mean_dim::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -376,7 +379,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() ) } @@ -384,197 +387,87 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod_dim::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } fn float_exp(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::exp(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Exp) } fn float_log(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Log) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log1p(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Log1p) } fn float_powf_scalar(lhs: FloatTensor, rhs: f32) -> FloatTensor { + struct Powf; + + #[cube] + impl FloatUnaryOp for Powf { + type Options = F; + + fn execute(input: Line, options: &Self::Options) -> Line { + Line::powf(input, Line::new(*options)) + } + } + + impl FloatUnaryOpFamily for Powf { + type Options = F; + type Unary = Self; + } + execute_with_dtype!( float(lhs.dtype), F, - unary_op!(float(lhs, rhs.elem::()) => |context, tensor, scalar| { - #[cube] - fn execute(input: Line, scalar: C) -> Line { - Line::powf(input, Line::new(scalar)) - } - execute::expand::(context, tensor, scalar) - }) + launch_unary_float::(lhs, |_| ScalarArg::new(rhs.elem::())) ) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sqrt(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Sqrt) } fn float_abs(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::abs(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Abs) } fn float_cos(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::cos(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Cos) } fn float_sin(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sin(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Sin) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::tanh(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Tanh) } fn float_round(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::round(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Round) } fn float_floor(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::floor(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Floor) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::ceil(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Ceil) } fn float_erf(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::erf(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Erf) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { execute_with_dtype!( float(tensor.dtype), E, - reduce::argmax::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -582,7 +475,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmin::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -603,17 +496,7 @@ where } fn float_recip(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::recip(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Recip) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index b0772f11cc..c60b2fe171 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,10 @@ +use self::unary_basic_int::BasicIntUnaryKind; + use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::kernel::{ + launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int, + BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -193,31 +198,31 @@ where } fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::reduce::sum::(tensor, Default::default()) + reduce::sum::(tensor, Default::default()).unwrap() } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::sum_dim::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_prod(tensor: IntTensor) -> IntTensor { - kernel::reduce::prod::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::prod_dim::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::mean_dim::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin::(tensor, dim, Default::default()) + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_clamp( @@ -229,13 +234,23 @@ where } fn int_abs(tensor: IntTensor) -> IntTensor { - unary_op!(int(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { + struct Abs; + + #[cube] + impl NumericUnaryOp for Abs { + type Options = (); + + fn execute(input: Line, _options: &Self::Options) -> Line { Line::abs(input) } - execute::expand::(context, tensor) - }) + } + + impl NumericUnaryOpFamily for Abs { + type Options = (); + type Unary = Self; + } + + launch_unary_numeric::(tensor, |_| ()) } fn int_into_float(tensor: IntTensor) -> FloatTensor { @@ -284,7 +299,52 @@ where kernel::flip::(tensor, axes) } +<<<<<<< HEAD fn int_cumsum(_tensor: IntTensor, _dim: usize) -> IntTensor { todo!() +======= + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_and::(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_and_scalar::(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_or::(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_xor::(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_basic_int::launch::(tensor, |_| &BasicIntUnaryKind::BitwiseNot) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + launch_binop_int::(lhs, rhs) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + launch_scalar_binop_int::(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + launch_binop_int::(lhs, rhs) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + launch_scalar_binop_int::(lhs, rhs) +>>>>>>> main } } diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index b5c96058f9..c7f7b18b32 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -25,7 +25,7 @@ where bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) + kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()).unwrap() } fn deform_conv2d( @@ -36,7 +36,7 @@ where bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { - kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) + kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options).unwrap() } fn deform_conv2d_backward( @@ -57,6 +57,7 @@ where output_grad, options, ) + .unwrap() } fn conv3d( @@ -81,6 +82,7 @@ where options, ConvTranspose2dStrategy::default(), ) + .unwrap() } fn conv_transpose3d( diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 5632425198..cf15916aab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -1,8 +1,9 @@ use crate::kernel::{ - launch_binop, launch_scalar_binop, AddOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, + launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int, AddOp, + BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, }; use crate::{element::JitElement, tensor::JitTensor}; -use crate::{FloatElement, JitRuntime}; +use crate::{FloatElement, IntElement, JitRuntime}; use burn_tensor::{ElementConversion, Shape}; use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; @@ -30,7 +31,7 @@ pub fn full_device( #[cube(launch)] pub fn full_kernel(tensor: &mut Tensor, value: C) { if ABSOLUTE_POS >= tensor.len() { - return; + terminate!(); } tensor[ABSOLUTE_POS] = value; @@ -137,5 +138,38 @@ pub fn remainder_scalar(lhs: JitTensor, rhs: E) } pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_binop::(lhs, rhs) + launch_binop::>(lhs, rhs) +} + +pub fn bitwise_and( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_or( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) } diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 54e50468fb..cfdf3319fe 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -1,5 +1,6 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use cubecl::{prelude::*, Compiler, ExecutionMode, KernelId}; +use burn_common::ExecutionMode; +use cubecl::{prelude::*, Compiler, KernelId}; use super::SourceTemplate; diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index e114b2f8e6..b586c4a6b7 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,5 +1,5 @@ use crate::element::JitElement; -use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; use crate::JitRuntime; use burn_tensor::quantization::QTensorPrimitive; use burn_tensor::{DType, Shape, TensorMetadata}; @@ -314,15 +314,29 @@ where /// Copy the current tensor. pub fn copy(&self) -> Self { - execute_with_dtype!(self.dtype, E, { - unary_op!(numeric(self.clone()) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - input - } - execute::expand::(context, tensor) - }) - }) + struct Copy; + + #[cube] + impl NumericUnaryOp for Copy { + type Options = (); + + fn execute(input: Line, _options: &Self::Options) -> Line { + input + } + } + + impl NumericUnaryOpFamily for Copy { + type Options = (); + type Unary = Self; + } + + let tensor = self.clone(); + + execute_with_dtype!( + tensor.dtype, + E, + launch_unary_numeric::(tensor, |_| ()) + ) } /// Check if the tensor is safe to mutate. diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index f60edc2a1b..a79ac3c437 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -48,7 +48,6 @@ macro_rules! testgen_all { mod kernel { use super::*; - burn_jit::testgen_reduction!(); burn_jit::testgen_conv2d!(); burn_jit::testgen_conv3d!(); burn_jit::testgen_conv_transpose2d!(); @@ -80,6 +79,8 @@ macro_rules! testgen_all { burn_jit::testgen_clamp!(); burn_jit::testgen_unary!(); + burn_jit::testgen_reduce!(); + burn_jit::testgen_quantization!(); } } diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs index 3e8f81fa8c..8e533361e9 100644 --- a/crates/burn-jit/src/tests/reduce.rs +++ b/crates/burn-jit/src/tests/reduce.rs @@ -1,566 +1,128 @@ -#[burn_tensor_testgen::testgen(reduction)] -mod reduction { +#[burn_tensor_testgen::testgen(reduce)] +mod reduce { use super::*; use burn_jit::kernel::reduce::{ - argmax, argmin, mean_dim, prod, prod_dim, sum, sum_dim, ReduceStrategy, + reduce, reduce_dim, ArgMax, ArgMin, Mean, Prod, ReduceStrategy, Sum, }; use burn_tensor::{ backend::Backend, ops::IntTensorOps, Distribution, Int, Shape, Tensor, TensorData, TensorPrimitive, }; - #[test] - fn reduction_sum_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - ))); - let val_ref = tensor_ref.sum_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_prod_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - ))); - let val_ref = tensor_ref.prod_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmin_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(argmin::( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - )); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(argmax::( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - )); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn sum_dim_should_work_with_int() { - let summed_shape = Shape::new([1]); - let data = TensorData::from([1, 2, 3, 4]); - let tensor = TestBackend::int_from_data(data, &Default::default()); - - let val = Tensor::::from_primitive(sum_dim::( - tensor, - 0, - ReduceStrategy::Naive, - )); - - let sum_as_data = TensorData::from([10]); - val.into_data().assert_approx_eq(&sum_as_data, 1); - } - - #[test] - fn mean_dim_should_work_with_int() { - let mean_shape = Shape::new([1]); - let data = TensorData::from([1, 2, 3, 4]); - let tensor = TestBackend::int_from_data(data, &Default::default()); - - let val = Tensor::::from_primitive(mean_dim::( - tensor, - 0, - ReduceStrategy::Naive, - )); - - // Mean calculation truncates to an integer - let mean_as_data = TensorData::from([2]); - val.into_data().assert_approx_eq(&mean_as_data, 1); - } - - #[test] - fn reduction_sum_dim_shared_memory_small() { - let tensor = - Tensor::::random([700], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_small() { - let tensor = - Tensor::::random([700], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } + const RANK: usize = 4; + const SHAPE: [usize; RANK] = [2, 4, 8, 16]; #[test] - fn reduction_sum_dim_shared_memory_medium_divisible() { + fn reduction_argmax_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_medium_divisible() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .argmax(dim) + .into_data() + .assert_eq(&tensor_ref.clone().argmax(dim).into_data(), false); + } } #[test] - fn reduction_sum_dim_shared_memory_medium_not_divisible() { + fn reduction_argmin_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1025], Distribution::Default, &Default::default()); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .argmin(dim) + .into_data() + .assert_eq(&tensor_ref.clone().argmin(dim).into_data(), false); + } } #[test] - fn reduction_sum_dim_subcube_medium_not_divisible() { + fn reduction_mean_dim_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1025], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_shared_memory_large() { - let tensor = Tensor::::random( - [4, 1024, 50], - Distribution::Default, - &Default::default(), - ); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_large() { - let tensor = Tensor::::random( - [4, 1024, 50], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .mean_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().mean_dim(dim).into_data(), 1e-6); + } } #[test] - fn reduction_mean_dim_shared_memory_medium() { + fn reduction_mean_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(mean_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.mean_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .mean() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().mean().into_data(), 1e-6); } #[test] - fn reduction_mean_dim_subcube_medium() { + fn reduction_prod_dim_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(mean_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.mean_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .prod_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().prod_dim(dim).into_data(), 1e-6); + } } #[test] - fn reduction_argmin_shared_memory_medium() { + fn reduction_prod_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmin::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .prod() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().prod().into_data(), 1e-6); } #[test] - fn reduction_argmin_subcube_medium() { + fn reduction_sum_dim_should_match_reference_backend() { let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmin::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_shared_memory_medium() { - let tensor = Tensor::::random( - [10, 3000], - Distribution::Default, - &Default::default(), - ); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmax::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_subcube_medium() { - let tensor = Tensor::::random( - [10, 3000], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmax::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .sum_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().sum_dim(dim).into_data(), 1e-6); + } } #[test] fn reduction_sum_should_match_reference_backend() { let tensor = - Tensor::::random([6, 256], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum::< - _, - ::FloatElem, - >( - tensor.into_primitive().tensor(), - ReduceStrategy::default(), - ))); - let val_ref = tensor_ref.sum(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_prod_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 256], Distribution::Default, &Default::default()); + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod::< - _, - ::FloatElem, - >( - tensor.into_primitive().tensor(), - ReduceStrategy::default(), - ))); - let val_ref = tensor_ref.prod(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmax_shared_memory_extreme_values_float() { - let data = TensorData::from([-999999., -999997., -999998.]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmax::( - tensor.into_primitive().tensor(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 1, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmin_shared_memory_extreme_values_float() { - let data = TensorData::from([999999., 999998., 999997.]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmin::( - tensor.into_primitive().tensor(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 2, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmin_shared_memory_extreme_values_i32() { - let data = TensorData::from([999999, 999998, 999997]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmin::( - tensor.into_primitive(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 2, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmax_shared_memory_extreme_values_i32() { - let data = TensorData::from([-999999, -999997, -999998]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmax::( - tensor.into_primitive(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 1, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .sum() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().sum().into_data(), 1e-6); } } diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index 0a7ae855b9..9a86a85483 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -1,7 +1,7 @@ use crate::kernel::{ conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey}, matmul::MatmulAutotuneKey, - reduce::ReduceAutotuneKey, + reduce::{ReduceAutotuneKey, SumAutotuneKey}, }; use cubecl::tune::AutotuneKey; use serde::{Deserialize, Serialize}; @@ -13,7 +13,9 @@ pub enum JitAutotuneKey { /// Key for matmul operation Matmul(MatmulAutotuneKey), /// Key for reduce dim operations - ReduceDim(ReduceAutotuneKey), + Reduce(ReduceAutotuneKey), + /// Key for sum operations + Sum(SumAutotuneKey), /// Key for convolution operations Conv2d(Conv2dAutotuneKey), /// Key for transpose convolution operations @@ -24,7 +26,8 @@ impl Display for JitAutotuneKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), - JitAutotuneKey::ReduceDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Sum(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 89253cd7e8..111649ab25 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -43,9 +43,9 @@ blas-openblas-system = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true } -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"] } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"] } atomic_float = { workspace = true } blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible @@ -62,10 +62,10 @@ spin = { workspace = true } # usi portable-atomic-util = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index ac2e25ae86..56b969a67c 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -6,7 +6,7 @@ use burn_tensor::{ use core::ops::AddAssign; use ndarray::{ s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, - Ix4, + Ix4, Zip, }; #[cfg(not(feature = "std"))] use num_traits::Float; @@ -593,29 +593,37 @@ pub mod backward { AtomicF32::new(0.0) }); + let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| { + let group = in_channel / channels_per_offset_group; + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset + .to_shape((offs_groups, kernel_h, kernel_w, 2)) + .unwrap(); + let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); + let offset = [offset[0], offset[1]]; + let mask = mask + .as_ref() + .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); + deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in); + }; + + // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise + #[cfg(feature = "std")] run_par!(|| { - iter_par!(columns.indexed_iter()).for_each( - |((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| { - let group = in_channel / channels_per_offset_group; - let offset = offset.slice(s![batch, .., out_y, out_x]); - let offset = offset - .to_shape((offs_groups, kernel_h, kernel_w, 2)) - .unwrap(); - let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); - let offset = [offset[0], offset[1]]; - let mask = mask - .as_ref() - .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); - let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - - F::from_elem(args.padding[0]) - + offset[0]; - let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - - F::from_elem(args.padding[1]) - + offset[1]; - let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); - deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in); - }, - ) + iter_par!(Zip::indexed(columns)) + .for_each(|(args0, args1)| compute_for_each(args0, args1)) + }); + + #[cfg(not(feature = "std"))] + run_par!(|| { + iter_par!(Zip::indexed(columns)).for_each(|args0, args1| compute_for_each(args0, args1)) }); let grad_in: Array1 = grad_in diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 87d6e46a3f..0f43ef1dac 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -352,6 +352,73 @@ impl IntTensorOps NdArrayOps::expand(tensor, shape) } + fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() & (b.elem::())).elem() + }) + } + + fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() & rhs.elem::()).elem() + }) + } + + fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() | (b.elem::())).elem() + }) + } + + fn bitwise_or_scalar( + lhs: burn_tensor::ops::IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> burn_tensor::ops::IntTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() | rhs.elem::()).elem() + }) + } + + fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() ^ (b.elem::())).elem() + }) + } + + fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() ^ rhs.elem::()).elem() + }) + } + + fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) + } + + fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() << (b.elem::())).elem() + }) + } + + fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() << rhs.elem::()).elem() + }) + } + + fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() >> (b.elem::())).elem() + }) + } + + fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() >> rhs.elem::()).elem() + }) + } + fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::cumsum(tensor, dim) } diff --git a/crates/burn-no-std-tests/Cargo.toml b/crates/burn-no-std-tests/Cargo.toml index e15ce56d15..77c7524f6f 100644 --- a/crates/burn-no-std-tests/Cargo.toml +++ b/crates/burn-no-std-tests/Cargo.toml @@ -14,7 +14,7 @@ version.workspace = true # ** Please make sure all dependencies support no_std ** -burn = { path = "../burn", version = "0.16.0", default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } +burn = { path = "../burn", version = "0.17.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } serde = { workspace = true } diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index 16a236ac94..9ebd0c8568 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -19,9 +19,9 @@ server = ["axum", "tracing-core", "tracing-subscriber"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]} -burn-common = { path = "../burn-common", version = "0.16.0", default-features = true} -burn-router = { path = "../burn-router", version = "0.16.0", default-features = true} +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = true, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.17.0", default-features = true} +burn-router = { path = "../burn-router", version = "0.17.0", default-features = true} # Basic dependencies derive-new = {workspace = true } @@ -39,14 +39,10 @@ async-channel = { workspace = true, optional = true } tokio-tungstenite = { version = "0.26", optional = true } # Server dependencies -axum = { version = "0.7.9", features = ["ws"], optional = true } +axum = { version = "0.8.1", features = ["ws"], optional = true } tracing-core = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } -[dev-dependencies] -# We activate the features client and server during dev. -burn-remote = { path = ".", version = "0.16.0", features=["client", "server"] } - [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/burn-remote/src/server/base.rs b/crates/burn-remote/src/server/base.rs index 1a364c87e9..87c0d7f7cf 100644 --- a/crates/burn-remote/src/server/base.rs +++ b/crates/burn-remote/src/server/base.rs @@ -102,7 +102,10 @@ impl WsServer { let response = callback.recv().unwrap(); let bytes = rmp_serde::to_vec(&response).unwrap(); - socket.send(ws::Message::Binary(bytes)).await.unwrap(); + socket + .send(ws::Message::Binary(bytes.into())) + .await + .unwrap(); } } Err(err) => panic!("Can't start the response handler {err:?}"), diff --git a/crates/burn-remote/src/server/session.rs b/crates/burn-remote/src/server/session.rs index 7d32d04b74..3da6b2afa1 100644 --- a/crates/burn-remote/src/server/session.rs +++ b/crates/burn-remote/src/server/session.rs @@ -101,12 +101,12 @@ impl SessionManager { impl Session { fn new(runner: Runner) -> Self { - let (sender, reveiver) = std::sync::mpsc::sync_channel(1); + let (sender, receiver) = std::sync::mpsc::sync_channel(1); Self { runner, streams: Default::default(), sender, - receiver: Some(reveiver), + receiver: Some(receiver), } } diff --git a/crates/burn-router/Cargo.toml b/crates/burn-router/Cargo.toml index f6df54e59f..6f21d63640 100644 --- a/crates/burn-router/Cargo.toml +++ b/crates/burn-router/Cargo.toml @@ -17,22 +17,22 @@ std = ["burn-tensor/std", "burn-common/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]} -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false} +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false} hashbrown = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false } [package.metadata.docs.rs] diff --git a/crates/burn-router/src/lib.rs b/crates/burn-router/src/lib.rs index 644f65ee67..773235f781 100644 --- a/crates/burn-router/src/lib.rs +++ b/crates/burn-router/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![recursion_limit = "138"] //! Burn multi-backend router. diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 25c46ae854..5d01ddd7d3 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -4,9 +4,9 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, - OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription, - ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, + FromDataOperationDescription, OperationDescription, PermuteOperationDescription, + RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription, + SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata}; @@ -31,7 +31,18 @@ impl BoolTensorOps for BackendRouter { fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::()) + let out = client.register_empty_tensor(data.shape.clone(), DType::Bool); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::FromData(desc), + )); + + out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index e66343ec2a..c988b76b7b 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -8,13 +8,13 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, + ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, + SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -25,7 +25,18 @@ use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; impl FloatTensorOps for BackendRouter { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::<::FloatElem>()) + let out = client.register_empty_tensor(data.shape.clone(), FloatElem::::dtype()); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::FromData(desc), + )); + + out } fn float_random( diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index c0d3c13218..22072f900f 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -8,13 +8,13 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + FromDataOperationDescription, GatherOperationDescription, IntOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, + ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, + SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -45,7 +45,18 @@ impl IntTensorOps for BackendRouter { fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::<::IntElem>()) + let out = client.register_empty_tensor(data.shape.clone(), IntElem::::dtype()); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::FromData(desc), + )); + + out } fn int_device(tensor: &IntTensor) -> Device { @@ -1174,6 +1185,203 @@ impl IntTensorOps for BackendRouter { out } + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAnd(desc), + )); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOr(desc), + )); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXor(desc), + )); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseNot(desc), + )); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAndScalar(desc), + )); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOrScalar(desc), + )); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXorScalar(desc), + )); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShift(desc), + )); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShiftScalar(desc), + )); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShift(desc), + )); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShiftScalar(desc), + )); + + out + } + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let dtype = tensor.dtype; diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 0c9238e5e7..596eabec00 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -245,6 +245,10 @@ impl RunnerClient for Runner { let output = B::float_empty(shape, &self.device); handles.register_float_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::float_from_data(desc.data.clone(), &self.device); + handles.register_float_tensor::(&desc.out.id, output); + } }, OperationDescription::BaseInt(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -316,6 +320,10 @@ impl RunnerClient for Runner { let output = B::int_empty(shape, &self.device); handles.register_int_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::int_from_data(desc.data.clone(), &self.device); + handles.register_int_tensor::(&desc.out.id, output); + } }, OperationDescription::BaseBool(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -391,6 +399,10 @@ impl RunnerClient for Runner { let output = B::bool_empty(shape, &self.device); handles.register_bool_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::bool_from_data(desc.data.clone(), &self.device); + handles.register_bool_tensor::(&desc.out.id, output); + } }, OperationDescription::NumericFloat(_dtype, op) => match op { NumericOperationDescription::Add(desc) => { @@ -798,6 +810,39 @@ impl RunnerClient for Runner { let output = B::int_into_float(tensor); handles.register_float_tensor::(&desc.out.id, output); } + IntOperationDescription::BitwiseAnd(desc) => { + binary_int_ops!(handles, desc, B::bitwise_and) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_and_scalar) + } + IntOperationDescription::BitwiseOr(desc) => { + binary_int_ops!(handles, desc, B::bitwise_or) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_or_scalar) + } + IntOperationDescription::BitwiseXor(desc) => { + binary_int_ops!(handles, desc, B::bitwise_xor) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_xor_scalar) + } + IntOperationDescription::BitwiseNot(desc) => { + unary_int_ops!(handles, desc, B::bitwise_not) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_left_shift) + } + IntOperationDescription::BitwiseRightShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_right_shift) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar) + } }, OperationDescription::Float(_dtype, op) => match op { FloatOperationDescription::Exp(desc) => { diff --git a/crates/burn-tch/Cargo.toml b/crates/burn-tch/Cargo.toml index 44702c21ba..69b0240c34 100644 --- a/crates/burn-tch/Cargo.toml +++ b/crates/burn-tch/Cargo.toml @@ -16,7 +16,7 @@ default = [] doc = ["tch/doc-only"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } half = { workspace = true, features = ["std"] } libc = { workspace = true } @@ -25,10 +25,10 @@ tch = { workspace = true, features = ["download-libtorch"] } log = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index e1a7bf7d09..9f57081969 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -484,4 +484,118 @@ impl TchOps { pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchTensor::new(tensor.tensor.argsort(dim as i64, descending)) } + + pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_and_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_or_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_xor_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_not(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_not_().unwrap(), + |tensor| tensor.f_bitwise_not().unwrap(), + ) + } + + pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_left_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } + + pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_right_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } } diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index b2cd14f326..e23e534dbf 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -417,6 +417,65 @@ impl IntTensorOps for LibTorch { TchOps::argsort(tensor, dim, descending) } + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_and(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_or(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_xor(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + TchOps::bitwise_not(tensor) + } + + fn bitwise_and_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_right_shift_scalar(lhs, rhs) + } + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { TchOps::cumsum(tensor, dim) } diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 0eb79c2e88..7428408292 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -16,7 +16,7 @@ cubecl = ["dep:cubecl"] cubecl-cuda = ["cubecl", "cubecl/cuda"] cubecl-hip = ["cubecl", "cubecl/hip"] cubecl-wgpu = ["cubecl", "cubecl/wgpu"] -default = ["std", "repr"] +default = ["std", "repr", "burn-common/rayon"] doc = ["default"] experimental-named-tensor = [] export_tests = ["burn-tensor-testgen", "cubecl"] @@ -26,14 +26,13 @@ std = [ "half/std", "num-traits/std", "burn-common/std", - "burn-common/rayon", "colored", ] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } -cubecl = { workspace = true, optional = true, default-features = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } +cubecl = { workspace = true, optional = true, default-features = false } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } colored = { workspace = true, optional = true } diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index d3cb280e90..0376da57a2 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -1,8 +1,6 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -// Allow deprecated `Data` and `DataSerialize` -#![allow(deprecated)] //! This library provides multiple tensor implementations hidden behind an easy to use API //! that supports reverse mode automatic differentiation. @@ -59,6 +57,8 @@ mod cube_wgpu { use crate::backend::{DeviceId, DeviceOps}; use cubecl::wgpu::WgpuDevice; + // Allow deprecated `WgpuDevice::BestAvailable` + #[allow(deprecated)] impl DeviceOps for WgpuDevice { fn id(&self) -> DeviceId { match self { diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 85e18ec444..dce51f5ee2 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -26,6 +26,23 @@ pub struct HandleContainer { pub handles_orphan: Vec, } +impl HandleContainer { + /// Fork the container, useful for autotune. + pub fn fork(&self) -> Self { + let mut handles = HashMap::with_capacity(self.handles.len()); + + for (id, handle) in self.handles.iter() { + handles.insert(*id, handle.clone()); + } + + Self { + handles, + counter: self.counter, + handles_orphan: self.handles_orphan.clone(), + } + } +} + impl core::fmt::Debug for HandleContainer { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("HandleContainer") @@ -37,6 +54,7 @@ impl core::fmt::Debug for HandleContainer { } /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state +#[derive(Clone)] pub enum Handle { /// No [tensor handle](ReprBackend::Handle) has been created yet NotInit, diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 75b46f2220..2e6b1ae709 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -6,6 +6,7 @@ use alloc::borrow::ToOwned; use alloc::boxed::Box; use alloc::{string::String, vec, vec::Vec}; +use crate::TensorData; use crate::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, @@ -197,6 +198,12 @@ pub enum ModuleOperationDescription { /// Basic operations that can be done on any tensor type. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BaseOperationDescription { + /// Operation corresponding to: + /// + /// Float => [from_data](crate::ops::FloatTensorOps::float_from_data). + /// Int => [from_data](crate::ops::IntTensorOps::int_from_data). + /// Bool => [from_data](crate::ops::BoolTensorOps::bool_from_data). + FromData(FromDataOperationDescription), /// Operation corresponding to: /// /// Float => [to device](crate::ops::FloatTensorOps::float_to_device). @@ -272,9 +279,9 @@ pub enum BaseOperationDescription { /// Operation corresponding to: /// - /// Float => [equal](crate::ops::FloatTensorOps::float_empty). - /// Int => [equal](crate::ops::IntTensorOps::int_empty). - /// Bool => [equal](crate::ops::BoolTensorOps::bool_empty). + /// Float => [empty](crate::ops::FloatTensorOps::float_empty). + /// Int => [empty](crate::ops::IntTensorOps::int_empty). + /// Bool => [empty](crate::ops::BoolTensorOps::bool_empty). Empty(TensorDescription), } @@ -522,6 +529,50 @@ pub enum NumericOperationDescription { pub enum IntOperationDescription { /// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float). IntoFloat(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and](crate::ops::IntTensorOps::bitwise_and). + BitwiseAnd(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and scalar](crate::ops::IntTensorOps::bitwise_and_scalar). + BitwiseAndScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or](crate::ops::IntTensorOps::bitwise_or). + BitwiseOr(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or scalar](crate::ops::IntTensorOps::bitwise_or_scalar). + BitwiseOrScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor](crate::ops::IntTensorOps::bitwise_xor). + BitwiseXor(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor scalar](crate::ops::IntTensorOps::bitwise_xor_scalar). + BitwiseXorScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise not](crate::ops::IntTensorOps::bitwise_not). + BitwiseNot(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift](crate::ops::IntTensorOps::bitwise_left_shift). + BitwiseLeftShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift scalar](crate::ops::IntTensorOps::bitwise_left_shift_scalar). + BitwiseLeftShiftScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift](crate::ops::IntTensorOps::bitwise_right_shift). + BitwiseRightShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift scalar](crate::ops::IntTensorOps::bitwise_right_shift_scalar). + BitwiseRightShiftScalar(ScalarOperationDescription), } /// Operation description specific to a bool tensor. @@ -588,6 +639,13 @@ pub struct RandomOperationDescription { pub distribution: Distribution, } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[allow(missing_docs)] +pub struct FromDataOperationDescription { + pub out: TensorDescription, + pub data: TensorData, +} + #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ReshapeDescription { @@ -1366,6 +1424,7 @@ impl BaseOperationDescription { BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out], BaseOperationDescription::Empty(desc) => vec![desc], + BaseOperationDescription::FromData(desc) => vec![&desc.out], } } } @@ -1547,6 +1606,39 @@ impl IntOperationDescription { fn nodes(&self) -> Vec<&TensorDescription> { match self { IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out], + IntOperationDescription::BitwiseAnd(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseAndScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseOr(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseOrScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseXor(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseXorScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseNot(desc) => { + vec![&desc.input, &desc.out] + } + IntOperationDescription::BitwiseLeftShift(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseRightShift(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + vec![&desc.lhs, &desc.out] + } } } } @@ -1680,6 +1772,12 @@ impl ModuleOperationDescription { } } +impl core::hash::Hash for FromDataOperationDescription { + fn hash(&self, state: &mut H) { + self.out.hash(state); + } +} + impl core::hash::Hash for RandomOperationDescription { fn hash(&self, state: &mut H) { self.out.hash(state); diff --git a/crates/burn-tensor/src/tensor/activation/base.rs b/crates/burn-tensor/src/tensor/activation/base.rs index cc5990d375..15fcc7ab50 100644 --- a/crates/burn-tensor/src/tensor/activation/base.rs +++ b/crates/burn-tensor/src/tensor/activation/base.rs @@ -144,6 +144,8 @@ pub fn sigmoid(tensor: Tensor) -> Tensor } /// Applies the hard sigmoid function +/// +/// `hard_sigmoid(x) = max(0, min(1, alpha * x + beta))` pub fn hard_sigmoid( tensor: Tensor, alpha: f64, diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index fabf321d96..4bbc522f49 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -805,6 +805,7 @@ where /// # Arguments /// /// * `ranges` - A type implementing the `RangesArg` trait, which can be: + /// - A single `core::ops::Range` (slice the first dimension) /// - An array of `core::ops::Range` /// - An array of `Option<(i64, i64)>` /// - An array of `(i64, i64)` tuples @@ -2988,6 +2989,13 @@ impl RangesArg for [(i64, i64); D2] { } } +impl RangesArg<1> for core::ops::Range { + fn into_ranges(self, shape: Shape) -> [core::ops::Range; 1] { + let (start, end) = Self::clamp_range(self.start, self.end, shape.dims[0]); + [(start..end)] + } +} + /// Trait used for reshape arguments. pub trait ReshapeArgs { /// Converts to a shape. diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d4ab13faf4..8a6fb2ad78 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, Int, Shape, Tensor}; +use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -447,22 +447,8 @@ impl TensorCheck { check } - pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { - let mut check = Self::Ok; - if index >= num_classes { - check = check.register( - "One Hot", - TensorError::new(format!( - "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", - )), - ); - } - - check - } - - pub(crate) fn one_hot_tensor( - index_tensor: Tensor, + pub(crate) fn one_hot_tensor>( + index_tensor: Tensor, num_classes: usize, ) -> Self { let mut check = Self::Ok; @@ -487,6 +473,20 @@ impl TensorCheck { check } + pub(crate) fn one_hot_tensor_rank() -> Self { + let mut check = Self::Ok; + if D + 1 != D2 { + check = check.register( + "One Hot", + TensorError::new( + "The one-hot tensor rank must correspond to the rank of the tensor + 1", + ) + .details(format!("Expected D2={}, got {D2}", D + 1)), + ); + } + check + } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a6f59f6e88..b50d0d0596 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,11 +1,8 @@ -use alloc::vec::Vec; -use core::convert::TryInto; - use crate::check::TensorCheck; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, Shape, TensorData}; +use crate::tensor::{Distribution, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; @@ -174,35 +171,6 @@ where ))) } - /// Create a one hot tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let device = Default::default(); - /// let one_hot = Tensor::::one_hot(2, 10, &device); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { - check!(TensorCheck::one_hot_index(index, num_classes)); - - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape, device); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) - } - /// Applies the matrix multiplication operation. /// /// `C = AB` diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 08bdab0fe7..5d65b68ceb 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,5 +1,3 @@ -use crate::check; -use crate::check::TensorCheck; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -29,34 +27,6 @@ where pub fn arange_step(range: Range, step: usize, device: &B::Device) -> Self { Tensor::new(B::int_arange_step(range, step, device)) } - - /// Create a one hot tensor from an index tensor. - /// - /// # Arguments - /// - /// * `num_classes` - The number of classes to use in encoding. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); - /// let one_hot = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - let [num_samples] = self.dims(); - let indices = self.unsqueeze_dim(1); - let values = indices.ones_like(); - Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values) - } } impl Tensor @@ -129,4 +99,59 @@ where ) -> Tensor { cartesian_grid::(shape, device) } + + /// Applies the bitwise logical and operation with each bit representing the integer. + pub fn bitwise_and(self, other: Self) -> Self { + Self::new(B::bitwise_and(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical or operation with another tensor. + pub fn bitwise_or(self, other: Self) -> Self { + Self::new(B::bitwise_or(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical xor operation with another tensor. + pub fn bitwise_xor(self, other: Self) -> Self { + Self::new(B::bitwise_xor(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical not operation. + pub fn bitwise_not(self) -> Self { + Self::new(B::bitwise_not(self.primitive)) + } + + /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_and_scalar(self.primitive, other)) + } + + /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_or_scalar(self.primitive, other)) + } + + /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_xor_scalar(self.primitive, other)) + } + + /// Applies the bitwise left shift operation with the integers in the tensor. + pub fn bitwise_left_shift(self, other: Self) -> Self { + Self::new(B::bitwise_left_shift(self.primitive, other.primitive)) + } + + /// Applies the bitwise right shift operation with the integers in the tensor. + pub fn bitwise_right_shift(self, other: Self) -> Self { + Self::new(B::bitwise_right_shift(self.primitive, other.primitive)) + } + + /// Applies the bitwise left shift operation with the integers in the tensor. + pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_left_shift_scalar(self.primitive, other)) + } + + /// Applies the bitwise right shift operation with the integers in the tensor. + pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_right_shift_scalar(self.primitive, other)) + } } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 54d801c125..1270315132 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2041,6 +2041,103 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, 1.0, 0.0, -1) + } + + /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. + /// + /// # Arguments + /// + /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `on_value`: The value to assign for active positions (corresponding to indices). + /// * `off_value`: The value to assign for inactive positions. + /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. + /// + /// # Returns + /// + /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Float}; + /// fn example>>() { + /// let device = B::Device::default(); + /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); + /// // One-hot encoding + /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); + /// println!("{tensor}"); + /// // [[[5.0, 0.0, 0.0], + /// // [0.0, 0.0, 5.0]], + /// // [[0.0, 5.0, 0.0], + /// // [0.0, 0.0, 5.0]]] + /// } + /// ``` + pub fn one_hot_fill( + self, + num_classes: usize, + on_value: f32, + off_value: f32, + axis: i64, + ) -> Tensor { + check!(TensorCheck::one_hot_tensor_rank::()); + // Initialize shape from the current tensor dimensions and prepare for modification + let mut shape = self.shape().dims::().to_vec(); + let device = self.device(); + let rank = self.dims().len(); + + // Adjust negative axis to a positive index + let axis = if axis < 0 { + axis + rank as i64 + 1 + } else { + axis + }; + + // Ensure axis is within valid range + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + // Convert the input tensor to integer indices + let indices: Tensor = + Tensor::from_data(self.to_data().convert::(), &device); + // Insert the new dimension for the one-hot representation + shape.insert(axis as usize, num_classes); + // Adjust indices to valid range and handle invalid indices + let adjusted_indices = indices + .clone() + .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices + // Unsqueeze the indices tensor along the specified axis + let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); + + // Initialize the output tensor with the off_value + let output = Tensor::full(shape.clone(), off_value, &device); + + // Prepare scatter tensor for on_value and off_value adjustments + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + + // Scatter on_value at the appropriate indices to create the one-hot representation + output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) + } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// diff --git a/crates/burn-tensor/src/tensor/backend/conversion.rs b/crates/burn-tensor/src/tensor/backend/conversion.rs index 46b0423b71..6aebe06463 100644 --- a/crates/burn-tensor/src/tensor/backend/conversion.rs +++ b/crates/burn-tensor/src/tensor/backend/conversion.rs @@ -188,7 +188,7 @@ mod tests { } #[test] - fn should_build_indices_2d_complexe() { + fn should_build_indices_2d_complex() { let shape = Shape::new([2, 3]); let indices = build_indices(&shape, Order::Left); @@ -206,7 +206,7 @@ mod tests { } #[test] - fn should_build_indices_3d_complexe() { + fn should_build_indices_3d_complex() { let shape = Shape::new([2, 5, 3]); let indices = build_indices(&shape, Order::Left); diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 5fa6f765fc..bd144e397f 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -1,7 +1,4 @@ -use core::{ - any::{Any, TypeId}, - f32, -}; +use core::f32; use alloc::boxed::Box; use alloc::format; @@ -14,7 +11,7 @@ use crate::{ quantization::{ Quantization, QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes, }, - tensor::{bytes::Bytes, Shape}, + tensor::bytes::Bytes, DType, Distribution, Element, ElementConversion, }; @@ -777,396 +774,6 @@ impl core::fmt::Display for TensorData { } } -/// Data structure for serializing and deserializing tensor data. -#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)] -#[deprecated( - since = "0.14.0", - note = "the internal data format has changed, please use `TensorData` instead" -)] -pub struct DataSerialize { - /// The values of the tensor. - pub value: Vec, - /// The shape of the tensor. - pub shape: Vec, -} - -/// Data structure for tensors. -#[derive(new, Debug, Clone, PartialEq, Eq)] -#[deprecated( - since = "0.14.0", - note = "the internal data format has changed, please use `TensorData` instead" -)] -pub struct Data { - /// The values of the tensor. - pub value: Vec, - - /// The shape of the tensor. - pub shape: Shape, -} - -#[allow(deprecated)] -impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - - Data { - value, - shape: self.shape, - } - } - - /// Asserts each value is within a given range. - /// - /// # Arguments - /// - /// * `range` - The range. - /// - /// # Panics - /// - /// If any value is not within the half-open range bounded inclusively below - /// and exclusively above (`start..end`). - pub fn assert_within_range(&self, range: core::ops::Range) { - let start = range.start.elem::(); - let end = range.end.elem::(); - - for elem in self.value.iter() { - let elem = elem.elem::(); - if elem < start || elem >= end { - panic!("Element ({elem:?}) is not within range {range:?}"); - } - } - } -} - -#[allow(deprecated)] -impl DataSerialize { - /// Converts the data to a different element type. - pub fn convert(self) -> DataSerialize { - if TypeId::of::() == TypeId::of::() { - let cast: Box = Box::new(self); - let cast: Box> = cast.downcast().unwrap(); - return *cast; - } - - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - - DataSerialize { - value, - shape: self.shape, - } - } - - /// Converts the data to the new [TensorData] format. - pub fn into_tensor_data(self) -> TensorData { - TensorData::new(self.value, self.shape) - } -} - -#[allow(deprecated)] -impl Data { - /// Populates the data with random values. - pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(E::random(distribution, rng)); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with zeros. - pub fn zeros>(shape: S) -> Data { - let shape = shape.into(); - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(0.elem()); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with ones. - pub fn ones(shape: Shape) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(1.elem()); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with the given value - pub fn full(shape: Shape, fill_value: E) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(fill_value) - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data { - /// Serializes the data. - /// - /// # Returns - /// - /// The serialized data. - pub fn serialize(&self) -> DataSerialize { - DataSerialize { - value: self.value.clone(), - shape: self.shape.dims.to_vec(), - } - } -} - -#[allow(deprecated)] -impl + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data { - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `precision` - The precision of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq(&self, other: &Self, precision: usize) { - let tolerance = 0.1.pow(precision as f64); - - self.assert_approx_eq_diff(other, tolerance) - } - - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `tolerance` - The tolerance of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { - let mut message = String::new(); - if self.shape != other.shape { - message += format!( - "\n => Shape is different: {:?} != {:?}", - self.shape.dims, other.shape.dims - ) - .as_str(); - } - - let iter = self.value.clone().into_iter().zip(other.value.clone()); - - let mut num_diff = 0; - let max_num_diff = 5; - - for (i, (a, b)) in iter.enumerate() { - let a: f64 = a.into(); - let b: f64 = b.into(); - - //if they are both nan, then they are equally nan - let both_nan = a.is_nan() && b.is_nan(); - //this works for both infinities - let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.)); - - if both_nan || both_inf { - continue; - } - - let err = (a - b).abs(); - - if E::dtype().is_float() { - if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ - {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } - } else if err > tolerance || err.is_nan() { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ - {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } - } - - if num_diff >= max_num_diff { - message += format!("\n{} more errors...", num_diff - 5).as_str(); - } - - if !message.is_empty() { - panic!("Tensors are not approx eq:{}", message); - } - } -} - -#[allow(deprecated)] -impl Data { - /// Converts the usize data to a different element type. - pub fn from_usize(self) -> Data { - let value: Vec = self - .value - .into_iter() - .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) - .collect(); - - Data { - value, - shape: self.shape, - } - } -} - -#[allow(deprecated)] -impl From<&DataSerialize> for Data { - fn from(data: &DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value.clone(), Shape::new(dims)) - } -} - -#[allow(deprecated)] -impl From> for Data { - fn from(data: DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value, Shape::new(dims)) - } -} - -#[allow(deprecated)] -impl From<[E; A]> for Data { - fn from(elems: [E; A]) -> Self { - let mut data = Vec::with_capacity(2 * A); - for elem in elems.into_iter() { - data.push(elem); - } - - Data::new(data, Shape::new([A])) - } -} - -#[allow(deprecated)] -impl From<&[E]> for Data { - fn from(elems: &[E]) -> Self { - let mut data = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - data.push(*elem); - } - - Data::new(data, Shape::new([elems.len()])) - } -} - -#[allow(deprecated)] -impl From<[[E; B]; A]> for Data { - fn from(elems: [[E; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B); - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - data.push(elem); - } - } - - Data::new(data, Shape::new([A, B])) - } -} - -#[allow(deprecated)] -impl - From<[[[E; C]; B]; A]> for Data -{ - fn from(elems: [[[E; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - data.push(elem); - } - } - } - - Data::new(data, Shape::new([A, B, C])) - } -} - -#[allow(deprecated)] -impl< - E: core::fmt::Debug + Copy, - const A: usize, - const B: usize, - const C: usize, - const D: usize, - > From<[[[[E; D]; C]; B]; A]> for Data -{ - fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C * D); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - for elem in elem.into_iter().take(D) { - data.push(elem); - } - } - } - } - - Data::new(data, Shape::new([A, B, C, D])) - } -} - -#[allow(deprecated)] -impl core::fmt::Display for Data { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{:?}", &self.value).as_str()) - } -} - fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> { let epsilon_deviations = tolerance / f32::EPSILON as f64; let epsilon = match ty { @@ -1192,9 +799,8 @@ fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<( } #[cfg(test)] -#[allow(deprecated)] mod tests { - use crate::quantization::AffineQuantization; + use crate::{quantization::AffineQuantization, Shape}; use super::*; use alloc::vec; diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 89cbf74ebc..d38e78f62b 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1197,4 +1197,37 @@ pub trait IntTensorOps { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { argsort::(tensor, dim, descending) } + + /// Bitwise AND operation for Int Tensors + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise AND operation for Int Tensors with a scalar + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise OR operation for Int Tensors + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise OR operation for Int Tensors with a scalar + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise XOR operation for Int Tensors + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise XOR operation for Int Tensors with a scalar + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise NOT operation for Int Tensors + fn bitwise_not(tensor: IntTensor) -> IntTensor; + + /// Bitwise left shift operation for Int Tensors + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise left shift operation for Int Tensors with a scalar + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise right shift operation for Int Tensors + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise right shift operation for Int Tensors with a scalar + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; } diff --git a/crates/burn-tensor/src/tensor/quantization/bytes.rs b/crates/burn-tensor/src/tensor/quantization/bytes.rs index 9091c37960..6d880cc923 100644 --- a/crates/burn-tensor/src/tensor/quantization/bytes.rs +++ b/crates/burn-tensor/src/tensor/quantization/bytes.rs @@ -100,7 +100,7 @@ impl QuantizedBytes { /// Splits the quantized values of the tensor from the quantization parameters. /// - /// Returns the packed values and a newly allocated vector containining the quantization parameters. + /// Returns the packed values and a newly allocated vector containing the quantization parameters. fn split_values_off(self) -> (Vec, Vec) { // The bytes can be created either from packed u32 or existing bytes with the same representation. let mut values = match self.bytes.align() { diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index fb141ee16d..27fa996ad6 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -37,7 +37,7 @@ impl CubeType for QuantizationScheme { } #[cfg(feature = "cubecl")] impl cubecl::frontend::Init for QuantizationScheme { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _scope: &mut cubecl::ir::Scope) -> Self { self } } diff --git a/crates/burn-tensor/src/tensor/quantization/strategy.rs b/crates/burn-tensor/src/tensor/quantization/strategy.rs index bb8b4c6bfb..73f1b1c0b0 100644 --- a/crates/burn-tensor/src/tensor/quantization/strategy.rs +++ b/crates/burn-tensor/src/tensor/quantization/strategy.rs @@ -4,7 +4,7 @@ use core::{ }; use alloc::vec::Vec; -use burn_common::{iter_par, run_par}; +use burn_common::{iter_slice_par, run_par}; use num_traits::{Float, PrimInt}; use serde::{Deserialize, Serialize}; @@ -35,7 +35,7 @@ impl QuantizationStrategy { /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision /// data type `Q` and vice-versa. -pub trait Quantization { +pub trait Quantization { /// Create a new quantization scheme for an input range `[alpha, beta]`. fn new(alpha: E, beta: E) -> Self; /// Convert the values to a lower precision data type. @@ -48,7 +48,7 @@ pub trait Quantization { /// /// Note that the accumulation type `A` should have a bigger range than quantized type `Q`. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct AffineQuantization { +pub struct AffineQuantization { /// The scaling factor. pub scale: E, /// The zero-point offset. @@ -66,7 +66,7 @@ fn valid_scale(mut scale: E) -> E { scale } -impl AffineQuantization { +impl AffineQuantization { /// Initialize an affine quantization scheme with the given parameters. pub fn init(scale: E, offset: Q) -> Self { Self { @@ -77,7 +77,9 @@ impl AffineQuantization { } } -impl Quantization for AffineQuantization { +impl Quantization + for AffineQuantization +{ fn new(alpha: E, beta: E) -> Self { // Q range `[a, b]` let a = E::from(Q::min_value()).unwrap(); @@ -107,7 +109,7 @@ impl Quantization for AffineQuantization // x_q = clamp(round(x / scale + offset), a, b) let z = E::from(self.offset).unwrap(); run_par!(|| { - iter_par!(values.iter()) + iter_slice_par!(values) .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap()) .collect() }) @@ -116,7 +118,7 @@ impl Quantization for AffineQuantization fn dequantize(&self, values: &[Q]) -> Vec { // x = scale * (x_q - offset) run_par!(|| { - iter_par!(values.iter()) + iter_slice_par!(values) .map(|x_q| { self.scale * (E::from( @@ -133,14 +135,14 @@ impl Quantization for AffineQuantization /// Symmetric quantization scheme. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct SymmetricQuantization { +pub struct SymmetricQuantization { /// The scaling factor. pub scale: E, /// The quantized type. _q: PhantomData, } -impl SymmetricQuantization { +impl SymmetricQuantization { /// Initialize a symmetric quantization scheme with the given parameters. pub fn init(scale: E) -> Self { Self { @@ -150,7 +152,9 @@ impl SymmetricQuantization { } } -impl Quantization for SymmetricQuantization { +impl Quantization + for SymmetricQuantization +{ fn new(alpha: E, beta: E) -> Self { assert!( !Q::min_value().is_zero(), @@ -214,7 +218,9 @@ fn canonicalize_signed_zero(x: T) -> T { x + T::zero() } -impl Hash for AffineQuantization { +impl Hash + for AffineQuantization +{ fn hash(&self, state: &mut H) { // Hash raw bits. let bits = raw_double_bits(&canonicalize_signed_zero(self.scale)); @@ -223,15 +229,20 @@ impl Hash for AffineQuantization PartialEq for AffineQuantization { +impl PartialEq + for AffineQuantization +{ fn eq(&self, other: &Self) -> bool { self.scale == other.scale && self.offset == other.offset } } -impl Eq for AffineQuantization {} +impl Eq + for AffineQuantization +{ +} -impl Hash for SymmetricQuantization { +impl Hash for SymmetricQuantization { fn hash(&self, state: &mut H) { // Hash raw bits. let bits = raw_double_bits(&canonicalize_signed_zero(self.scale)); @@ -239,13 +250,13 @@ impl Hash for SymmetricQuantization { } } -impl PartialEq for SymmetricQuantization { +impl PartialEq for SymmetricQuantization { fn eq(&self, other: &Self) -> bool { self.scale == other.scale } } -impl Eq for SymmetricQuantization {} +impl Eq for SymmetricQuantization {} #[cfg(test)] mod tests { diff --git a/crates/burn-tensor/src/tensor/shape.rs b/crates/burn-tensor/src/tensor/shape.rs index 8ad54ba4d9..29eebd549e 100644 --- a/crates/burn-tensor/src/tensor/shape.rs +++ b/crates/burn-tensor/src/tensor/shape.rs @@ -33,6 +33,13 @@ impl Shape { dims[..D].copy_from_slice(&self.dims[..D]); dims } + + /// Change the shape to one dimensional with the same number of elements. + pub fn flatten(&self) -> Self { + Self { + dims: [self.dims.iter().product()].into(), + } + } } impl From<[usize; D]> for Shape { diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 2ce0a2cc1d..08bc10b37e 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -312,6 +312,7 @@ macro_rules! testgen_with_int_param { burn_tensor::testgen_sub!(); burn_tensor::testgen_transpose!(); burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_bitwise!(); // test stats burn_tensor::testgen_eye!(); diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs new file mode 100644 index 0000000000..c85f5edcc5 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/bitwise.rs @@ -0,0 +1,176 @@ +#[burn_tensor_testgen::testgen(bitwise)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_apply_bitwise_and_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_and(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[2, 4, 0], [9, 2, 8]]), false); + } + + #[test] + fn should_apply_bitwise_and_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_and(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([9, 3]), false); + } + + #[test] + fn should_apply_bitwise_and_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_and_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[1, 4, 5], [1, 1, 0]]), false); + } + + #[test] + fn should_apply_bitwise_not_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + + let output = tensor_1.bitwise_not(); + + output + .into_data() + .assert_eq(&TensorData::from([[-4, -5, -6], [-10, -4, -9]]), false); + } + + #[test] + fn should_apply_bitwise_or_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_or_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[7, 5, 5], [13, 7, 13]]), false); + } + + #[test] + fn should_apply_bitwise_or_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_or(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[7, 7, 13], [9, 11, 15]]), false); + } + + #[test] + fn should_apply_bitwise_or_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_or(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([15, 7]), false); + } + + #[test] + fn should_apply_bitwise_xor_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_xor_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[6, 1, 0], [12, 6, 13]]), false); + } + + #[test] + fn should_apply_bitwise_xor_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_xor(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[5, 3, 13], [0, 9, 7]]), false); + } + + #[test] + fn should_apply_bitwise_xor_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_xor(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([6, 4]), false); + } + + #[test] + fn should_apply_bitwise_left_shift_2d() { + if (IntType::MAX as u32) < 512 { + return; + } + + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); + + let output = tensor_1.bitwise_left_shift(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[6, 16, 40], [144, 96, 512]]), false); + } + + #[test] + fn should_apply_bitwise_left_shift_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 2; + + let output = tensor_1.bitwise_left_shift_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[12, 16, 20], [36, 12, 32]]), false); + } + + #[test] + fn should_apply_bitwise_right_shift_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); + + let output = tensor_1.bitwise_right_shift(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[1, 1, 0], [0, 0, 0]]), false); + } + + #[test] + fn should_apply_bitwise_right_shift_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 2; + + let output = tensor_1.bitwise_right_shift_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[0, 1, 1], [2, 0, 2]]), false); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index e101dc7777..13c0ae5ea4 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -7,6 +7,7 @@ mod arange; mod arange_step; mod arg; mod argwhere_nonzero; +mod bitwise; mod bool; mod cartesian_grid; mod cast; diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 310399119f..24e8f24b38 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -1,74 +1,114 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { use super::*; - use burn_tensor::{Int, TensorData}; + use burn_tensor::{ + as_type, + backend::Backend, + tests::{Float as _, Int as _}, + Float, Int, Numeric, Shape, Tensor, TensorData, + }; #[test] fn float_should_support_one_hot() { - let device = Default::default(); - - let tensor = TestTensor::<1>::one_hot(0, 5, &device); - let expected = TensorData::from([1., 0., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(1, 5, &device); - let expected = TensorData::from([0., 1., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(4, 5, &device); - let expected = TensorData::from([0., 0., 0., 0., 1.]); - tensor.into_data().assert_eq(&expected, false); + let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]); + one_hot_tensor.into_data().assert_eq(&expected, false); + } - let tensor = TestTensor::<1>::one_hot(1, 2, &device); - let expected = TensorData::from([0., 1.]); - tensor.into_data().assert_eq(&expected, false); + #[test] + fn float_should_support_one_hot_index() { + let tensor = TestTensor::<1>::from([2.0]); + let one_hot_tensor: Tensor = tensor.one_hot::<2>(10); + let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]); + one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(1, 1, &device); + let tensor = TestTensor::<1>::from([5.0]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn float_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(0, 0, &device); + let tensor = TestTensor::<1>::from([0.0]); + let result: Tensor = tensor.one_hot(0); } #[test] fn int_should_support_one_hot() { - let device = Default::default(); - - let index_tensor = TestTensorInt::<1>::arange(0..5, &device); - let one_hot_tensor = index_tensor.one_hot(5); - let expected = TestTensorInt::eye(5, &device).into_data(); + let tensor = TestTensorInt::<1>::from([0, 1, 4]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..6, &device); - let one_hot_tensor = index_tensor.one_hot(5); + let tensor = TestTensorInt::<1>::from([5]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(0); + let tensor = TestTensorInt::<1>::from([2]); + let result: Tensor = tensor.one_hot(0); + } + + #[test] + fn one_hot_fill_with_positive_axis_and_indices() { + let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); + let expected = TensorData::from(as_type!(IntType: [ + [[1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3]], + [[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + + #[test] + fn one_hot_fill_with_negative_axis_and_indices() { + let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + let expected = TensorData::from(as_type!(FloatType: [ + [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], + [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); + + one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] + fn one_hot_fill_with_negative_indices() { + let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); + let expected = TensorData::from(as_type!(FloatType: [ + [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + #[should_panic] - fn int_one_hot_should_panic_when_number_of_classes_is_1() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(1); + #[test] + fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(2, 5.0, 0.0, 3); } } diff --git a/crates/burn-tensor/src/tests/ops/slice.rs b/crates/burn-tensor/src/tests/ops/slice.rs index 61725a506a..1be5b76315 100644 --- a/crates/burn-tensor/src/tests/ops/slice.rs +++ b/crates/burn-tensor/src/tests/ops/slice.rs @@ -47,6 +47,17 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_slice_range_first_dim() { + let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.slice(0..1); + let expected = TensorData::from([[0.0, 1.0, 2.0]]); + + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_support_partial_sliceing_3d() { let tensor = TestTensor::<3>::from_floats( diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index 35707f5052..8c024c88f8 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -12,13 +12,13 @@ documentation = "https://docs.rs/burn-train" version.workspace = true [features] -default = ["metrics", "tui"] +default = ["sys-metrics", "tui"] doc = ["default"] -metrics = ["nvml-wrapper", "sysinfo", "systemstat"] +sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui"] [dependencies] -burn-core = { path = "../burn-core", version = "0.16.0", features = [ +burn-core = { path = "../burn-core", version = "0.17.0", features = [ "dataset", "std", ], default-features = false } @@ -28,7 +28,7 @@ tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } tracing-core = { workspace = true } -# Metrics +# System Metrics nvml-wrapper = { workspace = true, optional = true } sysinfo = { workspace = true, optional = true } systemstat = { workspace = true, optional = true } @@ -40,11 +40,11 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } async-channel = { workspace = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } rstest.workspace = true [dev-dependencies] -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-train/src/checkpoint/strategy/metric.rs b/crates/burn-train/src/checkpoint/strategy/metric.rs index 4efcf14028..68f36bcd7f 100644 --- a/crates/burn-train/src/checkpoint/strategy/metric.rs +++ b/crates/burn-train/src/checkpoint/strategy/metric.rs @@ -114,7 +114,7 @@ mod tests { process_train(&mut processor, 0.3, epoch); end_epoch(&mut processor, epoch); - // Should save the current record and delete the pervious one. + // Should save the current record and delete the previous one. assert_eq!( vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], strategy.checkpointing(epoch, &store) diff --git a/crates/burn-train/src/metric/cpu_use.rs b/crates/burn-train/src/metric/cpu_use.rs index 2769793088..d06d8429db 100644 --- a/crates/burn-train/src/metric/cpu_use.rs +++ b/crates/burn-train/src/metric/cpu_use.rs @@ -26,7 +26,9 @@ impl CpuUse { } fn refresh(sys: &mut System) -> f64 { - sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); + sys.refresh_specifics( + RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()), + ); let cpus = sys.cpus(); let num_cpus = cpus.len(); diff --git a/crates/burn-train/src/metric/fbetascore.rs b/crates/burn-train/src/metric/fbetascore.rs new file mode 100644 index 0000000000..5eeba0aa9c --- /dev/null +++ b/crates/burn-train/src/metric/fbetascore.rs @@ -0,0 +1,195 @@ +use super::{ + classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, + confusion_stats::{ConfusionStats, ConfusionStatsInput}, + state::{FormatOptions, NumericMetricState}, + Metric, MetricEntry, MetricMetadata, Numeric, +}; +use burn_core::{ + prelude::{Backend, Tensor}, + tensor::cast::ToElement, +}; +use core::marker::PhantomData; +use std::num::NonZeroUsize; + +/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric. +#[derive(Default)] +pub struct FBetaScoreMetric { + state: NumericMetricState, + _b: PhantomData, + config: ClassificationMetricConfig, + beta: f64, +} + +impl FBetaScoreMetric { + /// F-beta score metric for binary classification. + /// + /// # Arguments + /// + /// * `beta` - Positive real factor to weight recall's importance. + /// * `threshold` - The threshold to transform a probability into a binary prediction. + #[allow(dead_code)] + pub fn binary(beta: f64, threshold: f64) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + // binary classification results are the same independently of class_reduction + ..Default::default() + }, + beta, + ..Default::default() + } + } + + /// F-beta score metric for multiclass classification. + /// + /// # Arguments + /// + /// * `beta` - Positive real factor to weight recall's importance. + /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + /// * `class_reduction` - [Class reduction](ClassReduction) type. + #[allow(dead_code)] + pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::TopK( + NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + ), + class_reduction, + }, + beta, + ..Default::default() + } + } + + /// F-beta score metric for multi-label classification. + /// + /// # Arguments + /// + /// * `beta` - Positive real factor to weight recall's importance. + /// * `threshold` - The threshold to transform a probability into a binary prediction. + /// * `class_reduction` - [Class reduction](ClassReduction) type. + #[allow(dead_code)] + pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + class_reduction, + }, + beta, + ..Default::default() + } + } + + fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { + use ClassReduction::{Macro, Micro}; + let avg_tensor = match self.config.class_reduction { + Micro => aggregated_metric, + Macro => { + if aggregated_metric.contains_nan().any().into_scalar() { + let nan_mask = aggregated_metric.is_nan(); + aggregated_metric = aggregated_metric + .clone() + .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + } + aggregated_metric.mean() + } + }; + avg_tensor.into_scalar().to_f64() + } +} + +impl Metric for FBetaScoreMetric { + const NAME: &'static str = "FBetaScore"; + type Input = ConfusionStatsInput; + + fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let [sample_size, _] = input.predictions.dims(); + + let cf_stats = ConfusionStats::new(input, &self.config); + let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2)); + let metric = self.class_average( + scaled_true_positive.clone() + / (scaled_true_positive + + cf_stats.clone().false_negative() * self.beta.powi(2) + + cf_stats.false_positive()), + ); + + self.state.update( + 100.0 * metric, + sample_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } +} + +impl Numeric for FBetaScoreMetric { + fn value(&self) -> f64 { + self.state.value() + } +} + +#[cfg(test)] +mod tests { + use super::{ + ClassReduction::{self, *}, + FBetaScoreMetric, Metric, MetricMetadata, Numeric, + }; + use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; + use burn_core::tensor::TensorData; + use rstest::rstest; + + #[rstest] + #[case::binary_b1(1.0, THRESHOLD, 0.5)] + #[case::binary_b2(2.0, THRESHOLD, 0.5)] + fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) { + let input = dummy_classification_input(&ClassificationType::Binary).into(); + let mut metric = FBetaScoreMetric::binary(beta, threshold); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)] + #[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))] + #[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)] + #[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)] + #[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)] + #[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))] + #[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)] + #[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)] + fn test_multiclass_fscore( + #[case] beta: f64, + #[case] class_reduction: ClassReduction, + #[case] top_k: usize, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multiclass).into(); + let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))] + #[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)] + #[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))] + #[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)] + fn test_multilabel_fscore( + #[case] beta: f64, + #[case] class_reduction: ClassReduction, + #[case] threshold: f64, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multilabel).into(); + let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } +} diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index e6358e3023..ac8211e884 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -1,62 +1,56 @@ /// State module. pub mod state; +/// Module responsible to save and exposes data collected during training. +pub mod store; -mod acc; -mod auroc; -mod base; -#[cfg(feature = "metrics")] +// System metrics +#[cfg(feature = "sys-metrics")] mod cpu_temp; -#[cfg(feature = "metrics")] +#[cfg(feature = "sys-metrics")] mod cpu_use; -#[cfg(feature = "metrics")] +#[cfg(feature = "sys-metrics")] mod cuda; -mod hamming; -mod learning_rate; -mod loss; -#[cfg(feature = "metrics")] +#[cfg(feature = "sys-metrics")] mod memory_use; +#[cfg(feature = "sys-metrics")] +pub use cpu_temp::*; +#[cfg(feature = "sys-metrics")] +pub use cpu_use::*; +#[cfg(feature = "sys-metrics")] +pub use cuda::*; +#[cfg(feature = "sys-metrics")] +pub use memory_use::*; -#[cfg(feature = "metrics")] +// Training metrics +mod acc; +mod auroc; +mod base; +mod confusion_stats; +mod fbetascore; +mod hamming; mod iteration; -#[cfg(feature = "metrics")] +mod learning_rate; +mod loss; +mod precision; +mod recall; mod top_k_acc; pub use acc::*; pub use auroc::*; pub use base::*; -#[cfg(feature = "metrics")] -pub use cpu_temp::*; -#[cfg(feature = "metrics")] -pub use cpu_use::*; -#[cfg(feature = "metrics")] -pub use cuda::*; +pub use confusion_stats::ConfusionStatsInput; +pub use fbetascore::*; pub use hamming::*; -#[cfg(feature = "metrics")] pub use iteration::*; pub use learning_rate::*; pub use loss::*; -#[cfg(feature = "metrics")] -pub use memory_use::*; -#[cfg(feature = "metrics")] +pub use precision::*; +pub use recall::*; pub use top_k_acc::*; +pub(crate) mod classification; pub(crate) mod processor; -// Expose `ItemLazy` so it can be implemented for custom types -pub use processor::ItemLazy; -/// Module responsible to save and exposes data collected during training. -pub mod store; - -pub(crate) mod classification; -#[cfg(feature = "metrics")] pub use crate::metric::classification::ClassReduction; -mod confusion_stats; -pub use confusion_stats::ConfusionStatsInput; -#[cfg(feature = "metrics")] -mod precision; -#[cfg(feature = "metrics")] -pub use precision::*; -#[cfg(feature = "metrics")] -mod recall; -#[cfg(feature = "metrics")] -pub use recall::*; +// Expose `ItemLazy` so it can be implemented for custom types +pub use processor::ItemLazy; diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 067261cbdf..375d368795 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -42,6 +42,7 @@ impl PrecisionMetric { /// # Arguments /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { @@ -60,6 +61,7 @@ impl PrecisionMetric { /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary value. + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { @@ -129,7 +131,7 @@ mod tests { use rstest::rstest; #[rstest] - #[case::binary_macro(THRESHOLD, 0.5)] + #[case::binary(THRESHOLD, 0.5)] fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = PrecisionMetric::binary(threshold); diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index 8ce4351396..5003ddcd03 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -11,7 +11,7 @@ use burn_core::{ use core::marker::PhantomData; use std::num::NonZeroUsize; -///The Precision Metric +///The Recall Metric #[derive(Default)] pub struct RecallMetric { state: NumericMetricState, @@ -42,6 +42,7 @@ impl RecallMetric { /// # Arguments /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { @@ -60,6 +61,7 @@ impl RecallMetric { /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary prediction. + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { @@ -128,7 +130,7 @@ mod tests { use rstest::rstest; #[rstest] - #[case::binary_macro(THRESHOLD, 0.5)] + #[case::binary(THRESHOLD, 0.5)] fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = RecallMetric::binary(threshold); diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index d3975faad3..e0c247172d 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -17,22 +17,29 @@ default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"] doc = ["burn-jit/doc"] exclusive-memory-only = ["cubecl/exclusive-memory-only"] fusion = ["burn-fusion", "burn-jit/fusion"] -spirv = ["cubecl/wgpu-spirv"] std = ["burn-jit/std", "cubecl/std"] template = ["burn-jit/template", "cubecl/template"] +# Backends +webgpu = ["cubecl-wgsl"] +vulkan = ["cubecl-spirv"] + +# Compilers +cubecl-wgsl = [] +cubecl-spirv = ["cubecl/wgpu-spirv"] + [dependencies] cubecl = { workspace = true, features = ["wgpu"] } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl-wgpu", ] } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } half = { workspace = true } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index deb6a8ebd8..c11854fcaf 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -12,21 +12,21 @@ pub use burn_jit::{ pub use burn_jit::{tensor::JitTensor, JitBackend}; pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; -pub use cubecl::ir::CubeDim; -pub use cubecl::wgpu::*; +pub use cubecl::CubeDim; -pub type Wgsl = cubecl::wgpu::WgslCompiler; -#[cfg(feature = "spirv")] -pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler; +pub use cubecl::wgpu::{ + init_device, init_setup, init_setup_async, MemoryConfiguration, RuntimeOptions, WgpuDevice, + WgpuResource, WgpuRuntime, WgpuSetup, WgpuStorage, +}; +// Vulkan and WebGpu would have conflicting type names +pub mod graphics { + pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu}; +} -#[cfg(feature = "spirv")] -type Compiler = SpirV; -#[cfg(feature = "spirv")] -type Bool = u8; -#[cfg(not(feature = "spirv"))] -type Compiler = Wgsl; -#[cfg(not(feature = "spirv"))] -type Bool = u32; +#[cfg(feature = "cubecl-spirv")] +pub use cubecl::wgpu::spirv::SpirvCompiler; +#[cfg(feature = "cubecl-wgsl")] +pub use cubecl::wgpu::WgslCompiler; #[cfg(feature = "fusion")] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -44,14 +44,14 @@ type Bool = u32; /// ```rust, ignore /// fn custom_init() { /// let device = Default::default(); -/// burn::backend::wgpu::init_sync::( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -60,7 +60,7 @@ type Bool = u32; /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = +pub type Wgpu = burn_fusion::Fusion, F, I, B>>; #[cfg(not(feature = "fusion"))] @@ -79,14 +79,14 @@ pub type Wgpu = /// ```rust, ignore /// fn custom_init() { /// let device = Default::default(); -/// burn::backend::wgpu::init_sync::( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -95,20 +95,33 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = +pub type Wgpu = JitBackend, F, I, B>; +#[cfg(feature = "vulkan")] +/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V. +pub type Vulkan = Wgpu; + +#[cfg(feature = "webgpu")] +/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL. +pub type WebGpu = Wgpu; + #[cfg(test)] mod tests { use burn_jit::JitBackend; - #[cfg(feature = "spirv")] + #[cfg(feature = "vulkan")] pub use half::f16; - pub type TestRuntime = cubecl::wgpu::WgpuRuntime; + + #[cfg(feature = "cubecl-spirv")] + type Compiler = cubecl::wgpu::spirv::VkSpirvCompiler; + #[cfg(not(feature = "cubecl-spirv"))] + type Compiler = cubecl::wgpu::WgslCompiler; + pub type TestRuntime = cubecl::wgpu::WgpuRuntime; // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it // breaks a lot of tests from precision issues - #[cfg(feature = "spirv")] + #[cfg(feature = "vulkan")] burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); - #[cfg(not(feature = "spirv"))] + #[cfg(not(feature = "vulkan"))] burn_jit::testgen_all!([f32], [i32], [u32]); } diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 7f6af14fbb..b0abf7d178 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -24,7 +24,7 @@ train = ["burn-train", "autodiff", "dataset"] tui = ["burn-train?/tui"] ## Includes system info metrics (CPU/GPU usage, etc) -metrics = ["burn-train?/metrics"] +metrics = ["burn-train?/sys-metrics"] # Datasets dataset = ["burn-core/dataset"] @@ -50,15 +50,16 @@ openblas-system = ["burn-core/openblas-system"] template = ["burn-core/template"] candle = ["burn-core/candle"] -cuda-jit = ["burn-core/cuda-jit"] -hip-jit = ["burn-core/hip-jit"] +cuda = ["burn-core/cuda"] +hip = ["burn-core/hip"] ndarray = ["burn-core/ndarray"] remote = ["burn-core/remote"] router = ["burn-core/router"] server = ["burn-core/server"] tch = ["burn-core/tch"] wgpu = ["burn-core/wgpu"] -wgpu-spirv = ["burn-core/wgpu-spirv"] +vulkan = ["burn-core/vulkan"] +webgpu = ["burn-core/webgpu"] # Network utils network = ["burn-core/network"] @@ -67,12 +68,11 @@ network = ["burn-core/network"] experimental-named-tensor = ["burn-core/experimental-named-tensor"] # Records -record-backward-compat = ["burn-core/record-backward-compat"] record-item-custom-serde = ["burn-core/record-item-custom-serde"] [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** -burn-core = { path = "../burn-core", version = "0.16.0", default-features = false } -burn-train = { path = "../burn-train", version = "0.16.0", optional = true, default-features = false } +burn-core = { path = "../burn-core", version = "0.17.0", default-features = false } +burn-train = { path = "../burn-train", version = "0.17.0", optional = true, default-features = false } diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index b0ecf06a71..203d1a802d 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -76,12 +76,14 @@ //! - `vision`: Enables vision datasets (MnistDataset) //! - Backends //! - `wgpu`: Makes available the WGPU backend -//! - `wgpu-spirv`: Makes available the `wgpu` backend with the alternative SPIR-V compiler +//! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler +//! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler +//! - `cuda`: Makes available the CUDA backend +//! - `hip`: Makes available the HIP backend //! - `candle`: Makes available the Candle backend //! - `tch`: Makes available the LibTorch backend //! - `ndarray`: Makes available the NdArray backend //! - Backend specifications -//! - `cuda`: If supported, CUDA will be used //! - `accelerate`: If supported, Accelerate will be used //! - `blas-netlib`: If supported, Blas Netlib will be use //! - `openblas`: If supported, Openblas will be use diff --git a/deny.toml b/deny.toml index a9a4506064..ac64c923fe 100644 --- a/deny.toml +++ b/deny.toml @@ -75,12 +75,15 @@ allow = [ "Apache-2.0 WITH LLVM-exception", "Apache-2.0", "BSD-3-Clause", + "BSD-2-Clause", + "BSL-1.0", # in NOTICES.md "CC0-1.0", "ISC", "MIT", "MPL-2.0", "OpenSSL", "Unicode-DFS-2016", + "Unicode-3.0", "Unlicense", "Zlib", ] diff --git a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs index 0809971327..08d4ded4d7 100644 --- a/examples/custom-cubecl-kernel/src/kernel.rs +++ b/examples/custom-cubecl-kernel/src/kernel.rs @@ -17,7 +17,7 @@ pub fn fused_matmul_add_relu_kernel( let dim_k = rhs.shape(rhs.rank() - 1); if row >= n_rows || col >= n_cols { - return; + terminate!(); } let offset_output = batch * n_rows * n_cols; diff --git a/examples/custom-image-dataset/src/dataset.rs b/examples/custom-image-dataset/src/dataset.rs index eee2bdf9fc..b40396d3c6 100644 --- a/examples/custom-image-dataset/src/dataset.rs +++ b/examples/custom-image-dataset/src/dataset.rs @@ -5,7 +5,7 @@ use tar::Archive; use burn::data::{dataset::vision::ImageFolderDataset, network::downloader}; /// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44). -/// Licensed under the [Appache License](https://github.com/fastai/fastai/blob/master/LICENSE). +/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). const URL: &str = "https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz"; /// The [CIFAR-10](https://www.cs.toronto.edu/%7Ekriz/cifar.html) dataset consists of 60,000 32x32 diff --git a/examples/custom-renderer/examples/custom-renderer.rs b/examples/custom-renderer/examples/custom-renderer.rs index ea580833df..aa344b1d2b 100644 --- a/examples/custom-renderer/examples/custom-renderer.rs +++ b/examples/custom-renderer/examples/custom-renderer.rs @@ -1,5 +1,5 @@ -use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu}; +use burn::backend::{wgpu::WgpuDevice, Autodiff, WebGpu}; fn main() { - custom_renderer::run::>(WgpuDevice::default()); + custom_renderer::run::>(WgpuDevice::default()); } diff --git a/examples/custom-training-loop/Cargo.toml b/examples/custom-training-loop/Cargo.toml index 536307fdba..6e1fca1e92 100644 --- a/examples/custom-training-loop/Cargo.toml +++ b/examples/custom-training-loop/Cargo.toml @@ -7,7 +7,7 @@ publish = false version.workspace = true [dependencies] -burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "vision"]} +burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]} guide = {path = "../guide"} # Serialization diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs index a418ede196..ec9d55f42a 100644 --- a/examples/custom-training-loop/examples/custom-training-loop.rs +++ b/examples/custom-training-loop/examples/custom-training-loop.rs @@ -1,5 +1,5 @@ -use burn::backend::{Autodiff, Wgpu}; +use burn::backend::{Autodiff, WebGpu}; fn main() { - custom_training_loop::run::>(Default::default()); + custom_training_loop::run::>(Default::default()); } diff --git a/examples/guide/Cargo.toml b/examples/guide/Cargo.toml index e60b8d45e5..aea61f5e25 100644 --- a/examples/guide/Cargo.toml +++ b/examples/guide/Cargo.toml @@ -10,7 +10,7 @@ version.workspace = true default = ["burn/default"] [dependencies] -burn = {path = "../../crates/burn", features = ["wgpu", "train", "vision"]} +burn = {path = "../../crates/burn", features = ["webgpu", "train", "vision"]} # Serialization log = {workspace = true} diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs index 3c64879bc5..44c5b1dabc 100644 --- a/examples/guide/src/bin/infer.rs +++ b/examples/guide/src/bin/infer.rs @@ -1,8 +1,9 @@ -use burn::{backend::Wgpu, data::dataset::Dataset}; +#![recursion_limit = "131"] +use burn::{backend::WebGpu, data::dataset::Dataset}; use guide::inference; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; let device = burn::backend::wgpu::WgpuDevice::default(); diff --git a/examples/guide/src/bin/print.rs b/examples/guide/src/bin/print.rs index 9432aa93a4..6f3b710c25 100644 --- a/examples/guide/src/bin/print.rs +++ b/examples/guide/src/bin/print.rs @@ -1,8 +1,8 @@ -use burn::backend::Wgpu; +use burn::backend::WebGpu; use guide::model::ModelConfig; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; let device = Default::default(); let model = ModelConfig::new(10, 512).init::(&device); diff --git a/examples/guide/src/bin/train.rs b/examples/guide/src/bin/train.rs index 04f1f44146..a4acf02b69 100644 --- a/examples/guide/src/bin/train.rs +++ b/examples/guide/src/bin/train.rs @@ -1,5 +1,5 @@ use burn::{ - backend::{Autodiff, Wgpu}, + backend::{Autodiff, WebGpu}, data::dataset::Dataset, optim::AdamConfig, }; @@ -10,7 +10,7 @@ use guide::{ }; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device diff --git a/examples/image-classification-web/Cargo.toml b/examples/image-classification-web/Cargo.toml index 5f036532ae..9429b24d25 100644 --- a/examples/image-classification-web/Cargo.toml +++ b/examples/image-classification-web/Cargo.toml @@ -14,11 +14,10 @@ default = [] half_precision = [] [dependencies] -burn = { path = "../../crates/burn", version = "0.16.0", default-features = false, features = [ +burn = { path = "../../crates/burn", version = "0.17.0", default-features = false, features = [ "ndarray", "wgpu", ] } -cubecl-runtime = { version = "0.3.0", features = ["channel-mpsc"] } # missing feature flags -burn-candle = { path = "../../crates/burn-candle", version = "0.16.0", default-features = false } +burn-candle = { path = "../../crates/burn-candle", version = "0.17.0", default-features = false } log = { workspace = true } serde = { workspace = true } @@ -35,4 +34,4 @@ js-sys = "0.3" [build-dependencies] # Used to generate code from ONNX model -burn-import = { path = "../../crates/burn-import" } +burn-import = { path = "../../crates/burn-import", default-features = false, features = ["onnx"]} diff --git a/examples/image-classification-web/src/lib.rs b/examples/image-classification-web/src/lib.rs index 3881123eaf..3d528f2e9d 100644 --- a/examples/image-classification-web/src/lib.rs +++ b/examples/image-classification-web/src/lib.rs @@ -1,4 +1,5 @@ #![cfg_attr(not(test), no_std)] +#![recursion_limit = "135"] pub mod model; pub mod web; diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 4b20507abc..a9868099f6 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -14,7 +14,7 @@ use burn::{ tensor::activation::softmax, }; -use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn::backend::wgpu::{graphics::AutoGraphicsApi, WebGpu, WgpuDevice}; use burn_candle::Candle; use serde::Serialize; @@ -37,8 +37,8 @@ pub enum ModelType { /// The model is loaded to the NdArray backend WithNdArrayBackend(Model>), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(Model>), + /// The model is loaded to the WebGpu backend + WithWgpuBackend(Model>), } /// The image is 224x224 pixels with 3 channels (RGB) diff --git a/examples/mnist-inference-web/Cargo.toml b/examples/mnist-inference-web/Cargo.toml index c8f7803607..a72b3d3c1b 100644 --- a/examples/mnist-inference-web/Cargo.toml +++ b/examples/mnist-inference-web/Cargo.toml @@ -13,11 +13,10 @@ crate-type = ["cdylib"] default = ["ndarray"] ndarray = ["burn/ndarray"] -wgpu = ["burn/wgpu", "cubecl-runtime"] +wgpu = ["burn/wgpu"] [dependencies] burn = { path = "../../crates/burn", default-features = false } -cubecl-runtime = { version = "0.3.0", optional = true, features = ["channel-mpsc"] } # missing feature flag serde = { workspace = true } console_error_panic_hook = { workspace = true } diff --git a/examples/mnist-inference-web/src/state.rs b/examples/mnist-inference-web/src/state.rs index 5516fe90fd..c54290bb7b 100644 --- a/examples/mnist-inference-web/src/state.rs +++ b/examples/mnist-inference-web/src/state.rs @@ -5,7 +5,7 @@ use burn::{ }; #[cfg(feature = "wgpu")] -use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn::backend::wgpu::{init_setup_async, AutoGraphicsApi, Wgpu, WgpuDevice}; #[cfg(feature = "wgpu")] pub type Backend = Wgpu; @@ -18,7 +18,7 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin"); /// Builds and loads trained parameters into the model. pub async fn build_and_load_model() -> Model { #[cfg(feature = "wgpu")] - init_async::(&WgpuDevice::default(), Default::default()).await; + init_setup_async::(&WgpuDevice::default(), Default::default()).await; let model: Model = Model::new(&Default::default()); let record = BinBytesRecorder::::default() diff --git a/examples/modern-lstm/Cargo.toml b/examples/modern-lstm/Cargo.toml new file mode 100644 index 0000000000..86855e9ad4 --- /dev/null +++ b/examples/modern-lstm/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "modern-lstm" +version = "0.1.0" +edition = "2021" + +[features] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] +cuda = ["burn/cuda"] + +[dependencies] +burn = { path = "../../crates/burn", features=["train"] } + +# Random number generator +rand = { workspace = true } +rand_distr = { workspace = true } + +# Serialization +serde = {workspace = true, features = ["std", "derive"]} + +# Organise the results in dataframe +polars = { workspace = true } diff --git a/examples/modern-lstm/README.md b/examples/modern-lstm/README.md new file mode 100644 index 0000000000..832851a1f0 --- /dev/null +++ b/examples/modern-lstm/README.md @@ -0,0 +1,46 @@ +# Advanced LSTM Implementation with Burn + +A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined +weight matrices for the input and hidden states, based on the +[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch). + +`LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM +variants differ by `bidirectional` and `num_layers` settings: + +- LSTM: `num_layers = 1` and `bidirectional = false` +- Stacked LSTM: `num_layers > 1` and `bidirectional = false` +- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true` +- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true` + +This implementation is complementary to Burn's official LSTM, users can choose either one depends on +the project's specific needs. + +## Usage + +## Training + +```sh +# Cuda backend +cargo run --example lstm-train --release --features cuda-jit + +# Wgpu backend +cargo run --example lstm-train --release --features wgpu + +# Tch GPU backend +export TORCH_CUDA_VERSION=cu121 # Set the cuda version +cargo run --example lstm-train --release --features tch-gpu + +# Tch CPU backend +cargo run --example lstm-train --release --features tch-cpu + +# NdArray backend (CPU) +cargo run --example lstm-train --release --features ndarray +cargo run --example lstm-train --release --features ndarray-blas-openblas +cargo run --example lstm-train --release --features ndarray-blas-netlib +``` + +### Inference + +```sh +cargo run --example lstm-infer --release --features cuda-jit +``` diff --git a/examples/modern-lstm/examples/lstm-infer.rs b/examples/modern-lstm/examples/lstm-infer.rs new file mode 100644 index 0000000000..f601d08c79 --- /dev/null +++ b/examples/modern-lstm/examples/lstm-infer.rs @@ -0,0 +1,86 @@ +use burn::tensor::backend::Backend; + +pub fn launch(device: B::Device) { + modern_lstm::inference::infer::("/tmp/modern-lstm", device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + + use crate::launch; + + pub fn run() { + launch::(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + launch::(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::wgpu::Wgpu; + + pub fn run() { + launch::(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::Cuda; + + pub fn run() { + launch::(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/modern-lstm/examples/lstm-train.rs b/examples/modern-lstm/examples/lstm-train.rs new file mode 100644 index 0000000000..454263d331 --- /dev/null +++ b/examples/modern-lstm/examples/lstm-train.rs @@ -0,0 +1,104 @@ +use burn::{ + grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend, +}; +use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig}; + +pub fn launch(device: B::Device) { + let config = TrainingConfig::new( + LstmNetworkConfig::new(), + // Gradient clipping via optimizer config + AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))), + ); + + modern_lstm::training::train::("/tmp/modern-lstm", config, device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::{cuda::CudaDevice, Autodiff, Cuda}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/modern-lstm/src/dataset.rs b/examples/modern-lstm/src/dataset.rs new file mode 100644 index 0000000000..b2d04d525f --- /dev/null +++ b/examples/modern-lstm/src/dataset.rs @@ -0,0 +1,110 @@ +use burn::{ + data::{ + dataloader::batcher::Batcher, + dataset::{Dataset, InMemDataset}, + }, + prelude::*, +}; +use rand::Rng; +use rand_distr::{Distribution, Normal}; +use serde::{Deserialize, Serialize}; + +// Dataset parameters +pub const NUM_SEQUENCES: usize = 1000; +pub const SEQ_LENGTH: usize = 10; +pub const NOISE_LEVEL: f32 = 0.1; +pub const RANDOM_SEED: u64 = 5; + +// Generate a sequence where each number is the sum of previous two numbers plus noise +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SequenceDatasetItem { + pub sequence: Vec, + pub target: f32, +} + +impl SequenceDatasetItem { + pub fn new(seq_length: usize, noise_level: f32) -> Self { + // Start with two random numbers between 0 and 1 + let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()]; + + // Generate sequence + for _i in 0..seq_length { + // Next number is sum of previous two plus noise + let normal = Normal::new(0.0, noise_level).unwrap(); + let next_val = + seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng()); + seq.push(next_val); + } + + Self { + // Convert to sequence and target + sequence: seq[0..seq.len() - 1].to_vec(), // All but last + target: seq[seq.len() - 1], // Last value + } + } +} + +// Custom Dataset for Sequence Data +pub struct SequenceDataset { + dataset: InMemDataset, +} + +impl SequenceDataset { + pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self { + let mut items = vec![]; + for _i in 0..num_sequences { + items.push(SequenceDatasetItem::new(seq_length, noise_level)); + } + let dataset = InMemDataset::new(items); + + Self { dataset } + } +} + +impl Dataset for SequenceDataset { + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } + + fn len(&self) -> usize { + self.dataset.len() + } +} + +#[derive(Clone, Debug)] +pub struct SequenceBatcher { + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct SequenceBatch { + pub sequences: Tensor, // [batch_size, seq_length, input_size] + pub targets: Tensor, // [batch_size, 1] +} + +impl SequenceBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +impl Batcher> for SequenceBatcher { + fn batch(&self, items: Vec) -> SequenceBatch { + let mut sequences: Vec> = Vec::new(); + + for item in items.iter() { + let seq_tensor = Tensor::::from_floats(item.sequence.as_slice(), &self.device); + // Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations + sequences.push(seq_tensor.unsqueeze_dims(&[-1])); + } + let sequences = Tensor::stack(sequences, 0); + + let targets = items + .iter() + .map(|item| Tensor::::from_floats([item.target], &self.device)) + .collect(); + let targets = Tensor::stack(targets, 0); + + SequenceBatch { sequences, targets } + } +} diff --git a/examples/modern-lstm/src/inference.rs b/examples/modern-lstm/src/inference.rs new file mode 100644 index 0000000000..bad0af2996 --- /dev/null +++ b/examples/modern-lstm/src/inference.rs @@ -0,0 +1,45 @@ +use crate::{ + dataset::{ + SequenceBatcher, SequenceDataset, SequenceDatasetItem, NOISE_LEVEL, NUM_SEQUENCES, + SEQ_LENGTH, + }, + model::LstmNetwork, + training::TrainingConfig, +}; +use burn::{ + data::{dataloader::batcher::Batcher, dataset::Dataset}, + prelude::*, + record::{CompactRecorder, Recorder}, +}; +use polars::prelude::*; + +pub fn infer(artifact_dir: &str, device: B::Device) { + // Loading model + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model; run train first"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into(), &device) + .expect("Trained model should exist; run train first"); + + let model: LstmNetwork = config.model.init(&device).load_record(record); + + let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL); + let items: Vec = dataset.iter().collect(); + + let batcher = SequenceBatcher::new(device); + // Put all items in one batch + let batch = batcher.batch(items); + let predicted = model.forward(batch.sequences, None); + let targets = batch.targets; + + let predicted = predicted.squeeze::<1>(1).into_data(); + let expected = targets.squeeze::<1>(1).into_data(); + + // Display the predicted vs expected values + let results = df![ + "predicted" => &predicted.to_vec::().unwrap(), + "expected" => &expected.to_vec::().unwrap(), + ] + .unwrap(); + println!("{}", &results.head(Some(10))); +} diff --git a/examples/modern-lstm/src/lib.rs b/examples/modern-lstm/src/lib.rs new file mode 100644 index 0000000000..1a167ffd75 --- /dev/null +++ b/examples/modern-lstm/src/lib.rs @@ -0,0 +1,4 @@ +pub mod dataset; +pub mod inference; +pub mod model; +pub mod training; diff --git a/examples/modern-lstm/src/model.rs b/examples/modern-lstm/src/model.rs new file mode 100644 index 0000000000..268de59a0b --- /dev/null +++ b/examples/modern-lstm/src/model.rs @@ -0,0 +1,362 @@ +use burn::{ + nn::{ + Dropout, DropoutConfig, Initializer, LayerNorm, LayerNormConfig, Linear, LinearConfig, + LstmState, Sigmoid, Tanh, + }, + prelude::*, +}; + +/// LSTM Cell implementation with layer normalization. +/// +/// Mathematical formulation of LSTM: +/// f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # Forget gate +/// i_t = σ(W_i · [h_{t-1}, x_t] + b_i] # Input gate +/// g_t = tanh(W_g · [h_{t-1}, x_t] + b_g] # Candidate cell state +/// o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # Output gate +/// +/// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t # New cell state +/// h_t = o_t ⊙ tanh(c_t) # New hidden state +/// +/// where: +/// - σ is the sigmoid function +/// - ⊙ is the element-wise multiplication +/// - [h_{t-1}, x_t] represents concatenation + +#[derive(Module, Debug)] +pub struct LstmCell { + pub hidden_size: usize, + // Combined weight matrices for efficiency + // weight_ih layer uses combined weights for [i_t, f_t, g_t, o_t] for input x_t + // weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1} + pub weight_ih: Linear, + pub weight_hh: Linear, + // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM. + pub norm_x: LayerNorm, // Normalize gate pre-activations + pub norm_h: LayerNorm, // Normalize hidden state + pub norm_c: LayerNorm, // Normalize cell state + pub dropout: Dropout, +} + +/// Configuration to create a Lstm module using the init function. +#[derive(Config, Debug)] +pub struct LstmCellConfig { + // The size of the input features + pub input_size: usize, + // The size of the hidden state + pub hidden_size: usize, + // The number of hidden layers + pub dropout: f64, +} + +impl LstmCellConfig { + // Initialize parameters using best practices: + // 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn) + // 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training + #[allow(clippy::single_range_in_vec_init)] + pub fn init(&self, device: &B::Device) -> LstmCell { + let initializer = Initializer::XavierNormal { gain: 1.0 }; + let init_bias = Tensor::::ones([self.hidden_size], device); + + let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size) + .with_initializer(initializer.clone()) + .init(device); + // Set forget gate bias to 1.0 (helps with learning long sequences) + let bias = weight_ih + .bias + .clone() + .unwrap() + .val() + .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone()); + weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias)); + + let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size) + .with_initializer(initializer) + .init(device); + let bias = weight_hh + .bias + .clone() + .unwrap() + .val() + .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias); + weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias)); + + LstmCell { + hidden_size: self.hidden_size, + weight_ih, + weight_hh, + norm_x: LayerNormConfig::new(4 * self.hidden_size).init(device), + norm_h: LayerNormConfig::new(self.hidden_size).init(device), + norm_c: LayerNormConfig::new(self.hidden_size).init(device), + dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +impl LstmCell { + /// Forward pass of LSTM cell. + /// Args: + /// x: Input tensor of shape (batch_size, input_size) + /// state: Tuple of (h_{t-1}, c_{t-1}) each of shape (batch_size, hidden_size) + /// Returns: + /// Tuple of (h_t, c_t) representing new hidden and cell states + pub fn forward(&self, x: Tensor, state: LstmState) -> LstmState { + let (h_prev, c_prev) = (state.hidden, state.cell); + + // Combined matrix multiplication for all gates + // Shape: (batch_size, 4 * hidden_size) + let gates_x = self.weight_ih.forward(x); // Transform input + let gates_h = self.weight_hh.forward(h_prev); // Transform previous hidden state + + // Apply layer normalization + let gates_x = self.norm_x.forward(gates_x); + // Combined gate pre-activations + let gates = gates_x + gates_h; + + // Split into individual gates + // Each gate shape: (batch_size, hidden_size) + let gates = gates.chunk(4, 1); + let i_gate = gates[0].clone(); + let f_gate = gates[1].clone(); + let g_gate = gates[2].clone(); + let o_gate = gates[3].clone(); + + // Apply gate non-linearities + let i_t = Sigmoid::new().forward(i_gate); + let f_t = Sigmoid::new().forward(f_gate); + let g_t = Tanh::new().forward(g_gate); + let o_t = Sigmoid::new().forward(o_gate); + + // Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t + let c_t = f_t * c_prev + i_t * g_t; + let c_t = self.norm_c.forward(c_t); + + // Update cell state: h_t = o_t ⊙ tanh(c_t) + let h_t = o_t * Tanh::new().forward(c_t.clone()); + let h_t = self.norm_h.forward(h_t); + + let h_t = self.dropout.forward(h_t); + + LstmState::new(h_t, c_t) + } + + // Initialize cell state and hidden state if provided or with zeros + pub fn init_state(&self, batch_size: usize, device: &B::Device) -> LstmState { + let cell = Tensor::zeros([batch_size, self.hidden_size], device); + let hidden = Tensor::zeros([batch_size, self.hidden_size], device); + + LstmState::new(cell, hidden) + } +} + +/// Stacked LSTM implementation supporting multiple layers +/// Each layer processes the output of the previous layer +#[derive(Module, Debug)] +pub struct StackedLstm { + pub layers: Vec>, +} + +#[derive(Config, Debug)] +pub struct StackedLstmConfig { + pub input_size: usize, + pub hidden_size: usize, + pub num_layers: usize, + pub dropout: f64, +} + +impl StackedLstmConfig { + pub fn init(&self, device: &B::Device) -> StackedLstm { + let mut layers: Vec> = vec![]; + // Create list of LSTM cells, one for each layer + for i in 0..self.num_layers { + if i == 0 { + if i < self.num_layers - 1 { + layers.push( + LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout) + .init(device), + ); + } else { + // No dropout on last layer + layers.push( + LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device), + ); + } + } else if i < self.num_layers - 1 { + layers.push( + LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout) + .init(device), + ); + } else { + // No dropout on last layer + layers.push( + LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device), + ); + } + } + StackedLstm { layers } + } +} + +impl StackedLstm { + /// Process input sequence through stacked LSTM layers. + /// + /// Args: + /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// states: Optional initial states for each layer + /// + /// Returns: + /// Tuple of (output, states) where output has shape (batch_size, seq_length, hidden_size) + /// and states is a vector of length num_layers, both cell and hidden state in each element have shape (batch_size, hidden_size) + pub fn forward( + &self, + x: Tensor, + states: Option>>, + ) -> (Tensor, Vec>) { + let [batch_size, seq_length, _] = x.dims(); + let device = x.device(); + + let mut states = match states { + None => { + let mut temp: Vec> = vec![]; + for layer in self.layers.iter() { + temp.push(layer.init_state(batch_size, &device)); + } + temp + } + _ => states.unwrap(), + }; + + let mut layer_outputs = vec![]; + for t in 0..seq_length { + let mut input_t = x + .clone() + .slice([None, Some((t as i64, t as i64 + 1)), None]) + .squeeze::<2>(1); + for (i, lstm_cell) in self.layers.iter().enumerate() { + let mut state: LstmState = + LstmState::new(states[i].cell.clone(), states[i].hidden.clone()); + state = lstm_cell.forward(input_t, state); + input_t = state.hidden.clone(); + states[i] = state; + } + layer_outputs.push(input_t); + } + + // Stack output along sequence dimension + let output = Tensor::stack(layer_outputs, 1); + + (output, states) + } +} + +/// Complete LSTM network with bidirectional support. +/// +/// In bidirectional mode: +/// - Forward LSTM processes sequence from left to right +/// - Backward LSTM processes sequence from right to left +/// - Outputs are concatenated for final prediction +#[derive(Module, Debug)] +pub struct LstmNetwork { + // Forward direction LSTM + pub stacked_lstm: StackedLstm, + // Optional backward direction LSTM for bidirectional processing + pub reverse_lstm: Option>, + pub dropout: Dropout, + pub fc: Linear, +} + +#[derive(Config, Debug)] +pub struct LstmNetworkConfig { + #[config(default = 1)] + pub input_size: usize, // Single feature (number sequence) + #[config(default = 32)] + pub hidden_size: usize, // Size of LSTM hidden state + #[config(default = 2)] + pub num_layers: usize, // Number of LSTM layers + #[config(default = 1)] + pub output_size: usize, // Predict one number + #[config(default = 0.1)] + pub dropout: f64, + #[config(default = true)] + pub bidirectional: bool, // Use bidirectional LSTM +} + +impl LstmNetworkConfig { + pub fn init(&self, device: &B::Device) -> LstmNetwork { + // Forward direction LSTM + let stacked_lstm = StackedLstmConfig::new( + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout, + ) + .init(device); + + // Optional backward direction LSTM for bidirectional processing + let (reverse_lstm, hidden_size) = if self.bidirectional { + let lstm = StackedLstmConfig::new( + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout, + ) + .init(device); + (Some(lstm), 2 * self.hidden_size) + } else { + (None, self.hidden_size) + }; + + let fc = LinearConfig::new(hidden_size, self.output_size).init(device); + let dropout = DropoutConfig::new(self.dropout).init(); + + LstmNetwork { + stacked_lstm, + reverse_lstm, + dropout, + fc, + } + } +} + +impl LstmNetwork { + /// Forward pass of the network. + /// + /// For bidirectional processing: + /// 1. Process sequence normally with forward LSTM + /// 2. Process reversed sequence with backward LSTM + /// 3. Concatenate both outputs + /// 4. Apply final linear transformation + /// + /// Args: + /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// states: Optional initial states + /// + /// Returns: + /// Output tensor of shape (batch_size, output_size) + pub fn forward(&self, x: Tensor, states: Option>>) -> Tensor { + let seq_length = x.dims()[1] as i64; + // Forward direction + let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states); + + output = match &self.reverse_lstm { + Some(reverse_lstm) => { + //Process sequence in reverse direction + let (mut reverse_output, _states) = reverse_lstm.forward(x.flip([1]), None); + // Flip back to align with forward sequence + reverse_output = reverse_output.flip([1]); + // Concatenate forward and backward outputs along the feature dimension + output = Tensor::cat(vec![output, reverse_output], 2); + output + } + None => output, + }; + + // Apply dropout before final layer + output = self.dropout.forward(output); + // Use final timestep output for prediction + self.fc.forward( + output + .slice([None, Some((seq_length - 1, seq_length)), None]) + .squeeze::<2>(1), + ) + } +} diff --git a/examples/modern-lstm/src/training.rs b/examples/modern-lstm/src/training.rs new file mode 100644 index 0000000000..9f6af81328 --- /dev/null +++ b/examples/modern-lstm/src/training.rs @@ -0,0 +1,131 @@ +use crate::dataset::{ + SequenceBatcher, SequenceDataset, NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH, +}; +use crate::model::{LstmNetwork, LstmNetworkConfig}; +use burn::{ + data::dataloader::DataLoaderBuilder, + module::AutodiffModule, + nn::loss::{MseLoss, Reduction::Mean}, + optim::{AdamConfig, GradientsParams, Optimizer}, + prelude::*, + record::CompactRecorder, + tensor::backend::AutodiffBackend, +}; + +#[derive(Config)] +pub struct TrainingConfig { + pub model: LstmNetworkConfig, + pub optimizer: AdamConfig, + + #[config(default = 30)] + pub num_epochs: usize, + #[config(default = 32)] + pub batch_size: usize, + #[config(default = 2)] + pub num_workers: usize, + #[config(default = 1e-3)] + pub lr: f64, +} + +// Create the directory to save the model and model config +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Save training config + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + B::seed(RANDOM_SEED); + + // Create the model and optimizer + let mut model = config.model.init::(&device); + let mut optim = config.optimizer.init::>(); + + // Create the batcher + let batcher_train = SequenceBatcher::::new(device.clone()); + let batcher_valid = SequenceBatcher::::new(device.clone()); + + // Create the dataloaders + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(RANDOM_SEED) + .num_workers(config.num_workers) + .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL)); + + let dataloader_valid = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(RANDOM_SEED) + .num_workers(config.num_workers) + // 20% size of training + .build(SequenceDataset::new( + NUM_SEQUENCES / 5, + SEQ_LENGTH, + NOISE_LEVEL, + )); + + let train_num_items = dataloader_train.num_items(); + let valid_num_items = dataloader_valid.num_items(); + + println!("Starting training..."); + // Iterate over our training for X epochs + for epoch in 1..config.num_epochs + 1 { + // Initialize the training and validation metrics at the start of each epoch + let mut train_losses = vec![]; + let mut train_loss = 0.0; + let mut valid_losses = vec![]; + let mut valid_loss = 0.0; + + // Implement our training loop + for batch in dataloader_train.iter() { + let output = model.forward(batch.sequences, None); + let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); + train_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; + + // Gradients for the current backward pass + let grads = loss.backward(); + // Gradients linked to each parameter of the model + let grads = GradientsParams::from_grads(grads, &model); + // Update the model using the optimizer + model = optim.step(config.lr, model, grads); + } + + // The averaged train loss per epoch + let avg_train_loss = train_loss / train_num_items as f32; + train_losses.push(avg_train_loss); + + // Get the model without autodiff + let valid_model = model.valid(); + + // Implement our validation loop + for batch in dataloader_valid.iter() { + let output = valid_model.forward(batch.sequences, None); + let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); + valid_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; + } + // The averaged train loss per epoch + let avg_valid_loss = valid_loss / valid_num_items as f32; + valid_losses.push(avg_valid_loss); + + // Display the averaged training and validataion metrics every 10 epochs + if (epoch + 1) % 5 == 0 { + println!( + "Epoch {}/{}, Avg Loss {:.4}, Avg Val Loss: {:.4}", + epoch + 1, + config.num_epochs, + avg_train_loss, + avg_valid_loss, + ); + } + } + + // Save the trained model + model + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Trained model should be saved successfully"); +} diff --git a/examples/pytorch-import/Cargo.toml b/examples/pytorch-import/Cargo.toml index a7b3305689..dd2b56e92d 100644 --- a/examples/pytorch-import/Cargo.toml +++ b/examples/pytorch-import/Cargo.toml @@ -4,7 +4,7 @@ edition = "2021" license = "MIT OR Apache-2.0" name = "pytorch-import" publish = false -version = "0.16.0" +version = "0.17.0" [dependencies] burn = { path = "../../crates/burn", features = [ diff --git a/examples/pytorch-import/model/Cargo.toml b/examples/pytorch-import/model/Cargo.toml index 894ac7e48f..f2678bfcbc 100644 --- a/examples/pytorch-import/model/Cargo.toml +++ b/examples/pytorch-import/model/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "model" -version = "0.5.0" +version = "0.6.0" edition = "2021" [dependencies] diff --git a/examples/raspberry-pi-pico/Cargo.lock b/examples/raspberry-pi-pico/Cargo.lock index 2cbc8fb721..a2f5e866d3 100644 --- a/examples/raspberry-pi-pico/Cargo.lock +++ b/examples/raspberry-pi-pico/Cargo.lock @@ -286,7 +286,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "burn-train", @@ -294,7 +294,7 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -305,7 +305,7 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-tensor", "candle-core", @@ -315,7 +315,7 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.16.0" +version = "0.17.0" dependencies = [ "cubecl-common", "data-encoding", @@ -326,7 +326,7 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bincode", "burn-autodiff", @@ -357,7 +357,7 @@ dependencies = [ [[package]] name = "burn-cuda" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -371,7 +371,7 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "csv", "derive-new", @@ -395,7 +395,7 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.16.0" +version = "0.17.0" dependencies = [ "derive-new", "proc-macro2", @@ -405,7 +405,7 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -418,7 +418,7 @@ dependencies = [ [[package]] name = "burn-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "candle-core", @@ -441,7 +441,7 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-fusion", @@ -461,7 +461,7 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -478,7 +478,7 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-tensor", "half", @@ -489,7 +489,7 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "bytemuck", @@ -507,7 +507,7 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "crossterm", @@ -525,7 +525,7 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -2959,7 +2959,7 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "onnx-ir" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bytemuck", "half", diff --git a/examples/raspberry-pi-pico/src/bin/main.rs b/examples/raspberry-pi-pico/src/bin/main.rs index 1b7f6acdf0..a502a8193e 100644 --- a/examples/raspberry-pi-pico/src/bin/main.rs +++ b/examples/raspberry-pi-pico/src/bin/main.rs @@ -10,7 +10,7 @@ use embassy_rp as _; use embedded_alloc::Heap; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; #[global_allocator] static HEAP: Heap = Heap::empty(); @@ -25,7 +25,7 @@ async fn main(_spawner: Spawner) { } // Get a default device for the backend - let device = BackendDeice::default(); + let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); @@ -47,7 +47,7 @@ async fn main(_spawner: Spawner) { } } -fn run_model<'a>(model: &Model, device: &BackendDeice, input: f32) -> Tensor { +fn run_model<'a>(model: &Model, device: &BackendDevice, input: f32) -> Tensor { // Define the tensor let input = Tensor::::from_floats([[input]], &device); diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 5d06497e08..f9f80bdb8d 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -7,12 +7,12 @@ publish = false version.workspace = true [features] -default = ["wgpu"] -cuda-jit = ["burn/cuda-jit"] -wgpu = ["burn/wgpu"] -wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +default = ["webgpu"] +cuda = ["burn/cuda"] +webgpu = ["burn/webgpu"] +vulkan = ["burn/vulkan"] ndarray = ["burn/ndarray"] [dependencies] cfg-if = { workspace = true } -burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] } +burn = { path = "../../crates/burn", version = "0.17.0", features = ["server"] } diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs index 92cba57a2a..014a5e2cf5 100644 --- a/examples/server/src/lib.rs +++ b/examples/server/src/lib.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "141"] + pub fn start() { let port = std::env::var("REMOTE_BACKEND_PORT") .map(|port| match port.parse::() { @@ -9,10 +11,12 @@ pub fn start() { cfg_if::cfg_if! { if #[cfg(feature = "ndarray")]{ burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "cuda-jit")]{ - burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "wgpu")] { - burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "cuda")]{ + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "webgpu")] { + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "vulkan")] { + burn::server::start::(Default::default(), port); } else { panic!("No backend selected, can't start server on port {port}"); } diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 4ec5d7c89a..043c61672d 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -16,10 +16,10 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] -wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +vulkan = ["wgpu", "burn/vulkan"] remote = ["burn/remote"] -cuda-jit = ["burn/cuda-jit"] -hip-jit = ["burn/hip-jit"] +cuda = ["burn/cuda"] +hip = ["burn/hip"] [dependencies] # Burn diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index 8bc611361f..9d62606706 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -102,6 +102,6 @@ cd burn # Use the --release flag to really speed up training. # AG News -cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset -cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset +cargo run --example ag-news-train --release --features cuda # Train on the ag news dataset +cargo run --example ag-news-infer --release --features cuda # Run inference on the ag news dataset ``` diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs index 9af5c6c6eb..77626e0b60 100644 --- a/examples/text-classification/examples/ag-news-infer.rs +++ b/examples/text-classification/examples/ag-news-infer.rs @@ -81,13 +81,13 @@ mod wgpu { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::{launch, ElemType}; - use burn::backend::{cuda_jit::CudaDevice, CudaJit}; + use burn::backend::{cuda::CudaDevice, Cuda}; pub fn run() { - launch::>(CudaDevice::default()); + launch::>(CudaDevice::default()); } } @@ -105,6 +105,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index bf12a0b6d9..927c190b2c 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "256"] + use burn::{ nn::transformer::TransformerEncoderConfig, optim::{decay::WeightDecayConfig, AdamConfig}, @@ -101,23 +103,23 @@ mod remote { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, CudaJit}; + use burn::backend::{Autodiff, Cuda}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } -#[cfg(feature = "hip-jit")] -mod hip_jit { +#[cfg(feature = "hip")] +mod hip { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, HipJit}; + use burn::backend::{Autodiff, Hip}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } @@ -135,10 +137,10 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); - #[cfg(feature = "hip-jit")] - hip_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); + #[cfg(feature = "hip")] + hip::run(); #[cfg(feature = "remote")] remote::run(); } diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index 490ed3b97e..027eb76122 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -1,6 +1,6 @@ use text_classification::DbPediaDataset; -use burn::tensor::backend::AutodiffBackend; +use burn::tensor::backend::Backend; #[cfg(not(feature = "f16"))] #[allow(dead_code)] @@ -8,7 +8,7 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", @@ -34,24 +34,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; @@ -61,35 +55,29 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - tch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::tch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::{ - wgpu::{Wgpu, WgpuDevice}, - Autodiff, - }; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(WgpuDevice::default()); + launch::>(WgpuDevice::default()); } } diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml new file mode 100644 index 0000000000..d6ee6345b1 --- /dev/null +++ b/examples/wgan/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "wgan" +version = "0.1.0" +edition = "2021" + +[features] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] +cuda = ["burn/cuda"] + +[dependencies] +burn = { path = "../../crates/burn", features=["train", "vision"] } +image = { workspace = true } diff --git a/examples/wgan/README.md b/examples/wgan/README.md new file mode 100644 index 0000000000..0828145f61 --- /dev/null +++ b/examples/wgan/README.md @@ -0,0 +1,40 @@ +# Wasserstein Generative Adversarial Network + +A burn implementation of examplar WGAN model to generate MNIST digits inspired by +[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html). +Please note that better performance maybe gained by adopting a convolution layer in +[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch). + +## Usage + + +## Training + +```sh +# Cuda backend +cargo run --example wgan-mnist --release --features cuda + +# Wgpu backend +cargo run --example wgan-mnist --release --features wgpu + +# Tch GPU backend +export TORCH_CUDA_VERSION=cu121 # Set the cuda version +cargo run --example wgan-mnist --release --features tch-gpu + +# Tch CPU backend +cargo run --example wgan-mnist --release --features tch-cpu + +# NdArray backend (CPU) +cargo run --example wgan-mnist --release --features ndarray # f32 - single thread +cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas +cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib +``` + + +### Generating + +To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. + +```sh +cargo run --example wgan-generate --release --features cuda +``` diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs new file mode 100644 index 0000000000..1d0a4fd87d --- /dev/null +++ b/examples/wgan/examples/wgan-generate.rs @@ -0,0 +1,86 @@ +use burn::tensor::backend::Backend; + +pub fn launch(device: B::Device) { + wgan::infer::generate::("/tmp/wgan-mnist", device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + + use crate::launch; + + pub fn run() { + launch::(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + launch::(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::wgpu::Wgpu; + + pub fn run() { + launch::(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::Cuda; + + pub fn run() { + launch::(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs new file mode 100644 index 0000000000..787acfec94 --- /dev/null +++ b/examples/wgan/examples/wgan-mnist.rs @@ -0,0 +1,107 @@ +use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend}; + +use wgan::{model::ModelConfig, training::TrainingConfig}; + +pub fn launch(device: B::Device) { + let config = TrainingConfig::new( + ModelConfig::new(), + RmsPropConfig::new() + .with_alpha(0.99) + .with_momentum(0.0) + .with_epsilon(0.00000008) + .with_weight_decay(None) + .with_centered(false), + ); + + wgan::training::train::("/tmp/wgan-mnist", config, device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::{cuda::CudaDevice, Autodiff, Cuda}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/wgan/src/dataset.rs b/examples/wgan/src/dataset.rs new file mode 100644 index 0000000000..46848d4ffb --- /dev/null +++ b/examples/wgan/src/dataset.rs @@ -0,0 +1,49 @@ +use burn::{ + data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, + prelude::*, +}; + +#[derive(Clone, Debug)] +pub struct MnistBatcher { + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct MnistBatch { + pub images: Tensor, + pub targets: Tensor, +} + +impl MnistBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +impl Batcher> for MnistBatcher { + fn batch(&self, items: Vec) -> MnistBatch { + let images = items + .iter() + .map(|item| TensorData::from(item.image)) + .map(|data| Tensor::::from_data(data.convert::(), &self.device)) + .map(|tensor| tensor.reshape([1, 28, 28])) + // Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example + .map(|tensor| ((tensor / 255) - 0.5) / 0.5) + .collect(); + + let targets = items + .iter() + .map(|item| { + Tensor::::from_data( + TensorData::from([(item.label as i64).elem::()]), + &self.device, + ) + }) + .collect(); + + let images = Tensor::stack(images, 0); + let targets = Tensor::cat(targets, 0); + + MnistBatch { images, targets } + } +} diff --git a/examples/wgan/src/infer.rs b/examples/wgan/src/infer.rs new file mode 100644 index 0000000000..25ca984feb --- /dev/null +++ b/examples/wgan/src/infer.rs @@ -0,0 +1,41 @@ +use crate::training::{save_image, TrainingConfig}; +use burn::{ + prelude::*, + record::{CompactRecorder, Recorder}, + tensor::Distribution, +}; + +pub fn generate(artifact_dir: &str, device: B::Device) { + // Loading model + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model; run train first"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/generator").into(), &device) + .expect("Trained model should exist; run train first"); + let (mut generator, _) = config.model.init::(&device); + generator = generator.load_record(record); + + // Get a batch of noise + let noise = Tensor::::random( + [config.batch_size, config.model.latent_dim], + Distribution::Normal(0.0, 1.0), + &device, + ); + let fake_images = generator.forward(noise); // [batch_size, channesl*height*width] + let fake_images = fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] + let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); + // Normalize the images. The Rgb32 images should be in range 0.0-1.0 + let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1])) + / (fake_images.clone().max().reshape([1, 1, 1, 1]) + - fake_images.clone().min().reshape([1, 1, 1, 1])); + // Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer, refer to pytorch save_image source + let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); + // Save images in artifact directory + save_image::(fake_images, 5, format!("{artifact_dir}/fake_image.png")).unwrap(); +} diff --git a/examples/wgan/src/lib.rs b/examples/wgan/src/lib.rs new file mode 100644 index 0000000000..021f62278a --- /dev/null +++ b/examples/wgan/src/lib.rs @@ -0,0 +1,4 @@ +pub mod dataset; +pub mod infer; +pub mod model; +pub mod training; diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs new file mode 100644 index 0000000000..755d8e9e1d --- /dev/null +++ b/examples/wgan/src/model.rs @@ -0,0 +1,157 @@ +use burn::{ + module::{Module, ModuleMapper, ParamId}, + nn::BatchNorm, + prelude::*, + tensor::backend::AutodiffBackend, +}; + +/// Layer block of generator model +#[derive(Module, Debug)] +pub struct LayerBlock { + fc: nn::Linear, + bn: nn::BatchNorm, + leakyrelu: nn::LeakyRelu, +} + +impl LayerBlock { + pub fn new(input: usize, output: usize, device: &B::Device) -> Self { + let fc = nn::LinearConfig::new(input, output) + .with_bias(true) + .init(device); + let bn: BatchNorm = nn::BatchNormConfig::new(output) + .with_epsilon(0.8) + .init(device); + let leakyrelu = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + + Self { fc, bn, leakyrelu } + } + + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.fc.forward(input); // output: [Batch, x] + let output = self.bn.forward(output); // output: [Batch, x] + + self.leakyrelu.forward(output) // output: [Batch, x] + } +} + +/// Generator model +#[derive(Module, Debug)] +pub struct Generator { + layer1: LayerBlock, + layer2: LayerBlock, + layer3: LayerBlock, + layer4: LayerBlock, + fc: nn::Linear, + tanh: nn::Tanh, +} + +impl Generator { + /// Applies the forward pass on the input tensor by specified order + pub fn forward(&self, noise: Tensor) -> Tensor { + let output = self.layer1.forward(noise); + let output = self.layer2.forward(output); + let output = self.layer3.forward(output); + let output = self.layer4.forward(output); + let output = self.fc.forward(output); + + self.tanh.forward(output) // [batch_size, channels*height*width] + } +} + +/// Discriminator model +#[derive(Module, Debug)] +pub struct Discriminator { + fc1: nn::Linear, + leakyrelu1: nn::LeakyRelu, + fc2: nn::Linear, + leakyrelu2: nn::LeakyRelu, + fc3: nn::Linear, +} + +impl Discriminator { + /// Applies the forward pass on the input tensor by specified order. + /// The input image shape is [batch, channels, height, width] + pub fn forward(&self, images: Tensor) -> Tensor { + // Full connection for each batch + let output = images.flatten(1, 3); // output: [batch, channels*height*width] + let output = self.fc1.forward(output); // output: [batch, 512] + let output = self.leakyrelu1.forward(output); // output: [batch, 512] + let output = self.fc2.forward(output); // output: [batch, 256] + let output = self.leakyrelu2.forward(output); // output: [batch, 256] + + self.fc3.forward(output) // output: [batch, 1] + } +} + +// Use model config to construct a generative and adversarial model +#[derive(Config, Debug)] +pub struct ModelConfig { + /// Dimensionality of the latent space + #[config(default = 100)] + pub latent_dim: usize, + #[config(default = 28)] + pub image_size: usize, + #[config(default = 1)] + pub channels: usize, +} + +impl ModelConfig { + /// Initialize the generator and discriminator models based on the config. + pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { + // Construct the initialized generator + let layer1 = LayerBlock::new(self.latent_dim, 128, device); + let layer2 = LayerBlock::new(128, 256, device); + let layer3 = LayerBlock::new(256, 512, device); + let layer4 = LayerBlock::new(512, 1024, device); + let fc = nn::LinearConfig::new(1024, self.channels * self.image_size * self.image_size) + .with_bias(true) + .init(device); + + let generator = Generator { + layer1, + layer2, + layer3, + layer4, + fc, + tanh: nn::Tanh::new(), + }; + + // Construct the initialized discriminator + let fc1 = nn::LinearConfig::new(self.channels * self.image_size * self.image_size, 512) + .init(device); + let leakyrelu1 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + let fc2 = nn::LinearConfig::new(512, 256).init(device); + let leakyrelu2 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + let fc3 = nn::LinearConfig::new(256, 1).init(device); + + let discriminator = Discriminator { + fc1, + leakyrelu1, + fc2, + leakyrelu2, + fc3, + }; + + (generator, discriminator) + } +} + +/// Clip module mapper to clip all module parameters between a range of values +#[derive(Module, Clone, Debug)] +pub struct Clip { + pub min: f32, + pub max: f32, +} + +impl ModuleMapper for Clip { + fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { + let is_require_grad = tensor.is_require_grad(); + + let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max)); + + if is_require_grad { + tensor = tensor.require_grad(); + } + tensor + } +} diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs new file mode 100644 index 0000000000..25fbef21c1 --- /dev/null +++ b/examples/wgan/src/training.rs @@ -0,0 +1,211 @@ +use crate::dataset::MnistBatcher; +use crate::model::{Clip, ModelConfig}; +use burn::optim::{GradientsParams, Optimizer, RmsPropConfig}; +use burn::{ + data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, + prelude::*, + record::CompactRecorder, + tensor::{backend::AutodiffBackend, Distribution}, +}; +use image::{buffer::ConvertBuffer, error::ImageResult, Rgb32FImage, RgbImage}; +use std::path::Path; + +#[derive(Config)] +pub struct TrainingConfig { + pub model: ModelConfig, + pub optimizer: RmsPropConfig, + + #[config(default = 200)] + pub num_epochs: usize, + #[config(default = 512)] + pub batch_size: usize, + #[config(default = 8)] + pub num_workers: usize, + #[config(default = 5)] + pub seed: u64, + #[config(default = 3e-4)] + pub lr: f64, + + /// Number of training steps for discriminator before generator is trained per iteration + #[config(default = 5)] + pub num_critic: usize, + /// Lower and upper clip value for disc. weights + #[config(default = 0.01)] + pub clip_value: f32, + /// Save a sample of images every `sample_interval` epochs + #[config(default = 10)] + pub sample_interval: usize, +} + +// Create the directory to save the model and model config +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +/// Save the generated images +// The images format is [B, H, W, C] +pub fn save_image>( + images: Tensor, + nrow: u32, + path: Q, +) -> ImageResult<()> { + let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32; + + let width = images.dims()[2] as u32; + let height = images.dims()[1] as u32; + + // Supports both 1 and 3 channels image + let channels = match images.dims()[3] { + 1 => 3, + 3 => 1, + _ => panic!("Wrong channels number"), + }; + + let mut imgbuf = RgbImage::new(nrow * width, ncol * height); + // Write images into a nrow*ncol grid layout + for row in 0..nrow { + for col in 0..ncol { + let image: Tensor = images + .clone() + .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize) + .squeeze(0); + // The Rgb32 should be in range 0.0-1.0 + let image = image.into_data().iter::().collect::>(); + // Supports both 1 and 3 channels image + let image = image + .into_iter() + .flat_map(|n| std::iter::repeat(n).take(channels)) + .collect(); + + let image = Rgb32FImage::from_vec(width, height, image).unwrap(); + let image: RgbImage = image.convert(); + for (x, y, pixel) in image.enumerate_pixels() { + imgbuf.put_pixel(row * width + x, col * height + y, *pixel); + } + } + } + imgbuf.save(path) +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Create the Clip module mapper + let mut clip = Clip { + min: -config.clip_value, + max: config.clip_value, + }; + + // Save training config + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + B::seed(config.seed); + + // Create the model and optimizer + let (mut generator, mut discriminator) = config.model.init::(&device); + let mut optimizer_g = config.optimizer.init(); + let mut optimizer_d = config.optimizer.init(); + + // Create the dataset batcher + let batcher_train = MnistBatcher::::new(device.clone()); + + // Create the dataloaders + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MnistDataset::train()); + + // Iterate over our training for X epochs + for epoch in 0..config.num_epochs { + // Implement our training loop + for (iteration, batch) in dataloader_train.iter().enumerate() { + // Generate a batch of fake images from noise (standarded normal distribution) + let noise = Tensor::::random( + [config.batch_size, config.model.latent_dim], + Distribution::Normal(0.0, 1.0), + &device, + ); + // datach: do not update gerenator, only discriminator is updated + let fake_images = generator.forward(noise.clone()).detach(); // [batch_size, channels*height*width] + let fake_images = fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // Adversarial loss + let loss_d = -discriminator.forward(batch.images).mean() + + discriminator.forward(fake_images.clone()).mean(); + + // Gradients for the current backward pass + let grads = loss_d.backward(); + // Gradients linked to each parameter of the discriminator + let grads = GradientsParams::from_grads(grads, &discriminator); + // Update the discriminator using the optimizer + discriminator = optimizer_d.step(config.lr, discriminator, grads); + // Clip parameters (weights) of discriminator + discriminator = discriminator.map(&mut clip); + + // Train the generator every num_critic iterations + if iteration % config.num_critic == 0 { + // Generate a batch of images again without detaching + let critic_fake_images = generator.forward(noise.clone()); + let critic_fake_images = critic_fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // Adversarial loss. Minimize it to make the fake images as truth + let loss_g = -discriminator.forward(critic_fake_images).mean(); + + let grads = loss_g.backward(); + let grads = GradientsParams::from_grads(grads, &generator); + generator = optimizer_g.step(config.lr, generator, grads); + + // Print the progression + let batch_num = (dataloader_train.num_items() as f32 / config.batch_size as f32) + .ceil() as usize; + println!( + "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]", + epoch + 1, + config.num_epochs, + iteration, + batch_num, + loss_d.into_scalar(), + loss_g.into_scalar() + ); + } + // If at save interval => save the first 25 generated images + if epoch % config.sample_interval == 0 && iteration == 0 { + // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] + let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); + // Normalize the images. The Rgb32 images should be in range 0.0-1.0 + let fake_images = (fake_images.clone() + - fake_images.clone().min().reshape([1, 1, 1, 1])) + / (fake_images.clone().max().reshape([1, 1, 1, 1]) + - fake_images.clone().min().reshape([1, 1, 1, 1])); + // Add 0.5/255.0 to the images, refer to pytorch save_image source + let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); + // Save images in artifact directory + let path = format!("{artifact_dir}/image-{}.png", epoch); + save_image::(fake_images, 5, path).unwrap(); + } + } + } + + // Save the trained models + generator + .save_file(format!("{artifact_dir}/generator"), &CompactRecorder::new()) + .expect("Generator should be saved successfully"); + discriminator + .save_file( + format!("{artifact_dir}/discriminator"), + &CompactRecorder::new(), + ) + .expect("Discriminator should be saved successfully"); +} diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index ce796eb7b1..63ac5e4c70 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xtask" -version = "1.1.0" +version = "1.2.0" edition = "2021" license = "MIT OR Apache-2.0" diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 47e50f80ed..5b94b2909e 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -83,7 +83,7 @@ pub(crate) fn handle_command( vec!["--features", "test-wgpu-spirv"], None, None, - "std wgpu-spirv", + "std vulkan", )?; } }