diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 8b8916c631..de956c243e 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -6,6 +6,31 @@ on:
- "v*"
jobs:
+ publish-burn-router:
+ uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
+ with:
+ crate: burn-router
+ needs:
+ - publish-burn-common
+ - publish-burn-tensor
+ # dev dependencies
+ - publish-burn-autodiff
+ - publish-burn-ndarray
+ - publish-burn-wgpu
+ secrets:
+ CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
+
+ publish-burn-remote:
+ uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
+ with:
+ crate: burn-remote
+ needs:
+ - publish-burn-common
+ - publish-burn-tensor
+ - publish-burn-router
+ secrets:
+ CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
+
publish-burn-derive:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
@@ -162,6 +187,7 @@ jobs:
- publish-burn-tch
- publish-burn-ndarray
- publish-burn-candle
+ - publish-burn-remote
with:
crate: burn-core
secrets:
diff --git a/Cargo.lock b/Cargo.lock
index 7b555924fc..c9ce522699 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -41,7 +41,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"once_cell",
"version_check",
"zerocopy",
@@ -62,21 +62,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
-[[package]]
-name = "alloc-no-stdlib"
-version = "2.0.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
-
-[[package]]
-name = "alloc-stdlib"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
-dependencies = [
- "alloc-no-stdlib",
-]
-
[[package]]
name = "allocator-api2"
version = "0.2.21"
@@ -139,19 +124,20 @@ dependencies = [
[[package]]
name = "anstyle-wincon"
-version = "3.0.6"
+version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125"
+checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
+ "once_cell",
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
-version = "1.0.94"
+version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7"
+checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04"
[[package]]
name = "arbitrary"
@@ -188,7 +174,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -206,12 +192,6 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76"
-[[package]]
-name = "arrayref"
-version = "0.3.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
-
[[package]]
name = "arrayvec"
version = "0.7.6"
@@ -269,34 +249,25 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "async-trait"
-version = "0.1.83"
+version = "0.1.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
+checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
-]
-
-[[package]]
-name = "atoi"
-version = "2.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528"
-dependencies = [
- "num-traits",
+ "syn 2.0.98",
]
[[package]]
name = "atoi_simd"
-version = "0.15.6"
+version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9"
+checksum = "4790f9e8961209112beb783d85449b508673cf4a6a419c8449b210743ac4dbe9"
[[package]]
name = "atomic-waker"
@@ -310,17 +281,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a"
-[[package]]
-name = "atty"
-version = "0.2.14"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
-dependencies = [
- "hermit-abi 0.1.19",
- "libc",
- "winapi",
-]
-
[[package]]
name = "autocfg"
version = "1.4.0"
@@ -352,19 +312,19 @@ dependencies = [
[[package]]
name = "axum"
-version = "0.7.9"
+version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
+checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
dependencies = [
- "async-trait",
"axum-core",
"base64 0.22.1",
"bytes",
+ "form_urlencoded",
"futures-util",
- "http 1.2.0",
- "http-body 1.0.1",
+ "http",
+ "http-body",
"http-body-util",
- "hyper 1.5.2",
+ "hyper",
"hyper-util",
"itoa",
"matchit",
@@ -378,9 +338,9 @@ dependencies = [
"serde_path_to_error",
"serde_urlencoded",
"sha1",
- "sync_wrapper 1.0.2",
+ "sync_wrapper",
"tokio",
- "tokio-tungstenite 0.24.0",
+ "tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@@ -389,20 +349,19 @@ dependencies = [
[[package]]
name = "axum-core"
-version = "0.4.5"
+version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
+checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733"
dependencies = [
- "async-trait",
"bytes",
"futures-util",
- "http 1.2.0",
- "http-body 1.0.1",
+ "http",
+ "http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
- "sync_wrapper 1.0.2",
+ "sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
@@ -410,31 +369,31 @@ dependencies = [
[[package]]
name = "backend-comparison"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"arboard",
"burn",
"burn-common",
- "clap 4.5.23",
+ "chrono",
+ "clap",
"colored",
"cubecl",
"derive-new 0.7.0",
"dirs",
- "github-device-flow",
"half",
"indicatif",
"log",
"os_info",
"percent-encoding",
"rand",
- "reqwest 0.12.12",
+ "reqwest",
"rstest",
"serde",
"serde_json",
"serial_test",
"strum",
"strum_macros",
- "sysinfo 0.32.1",
+ "sysinfo",
"tracing-subscriber",
"wgpu",
"wsl",
@@ -461,12 +420,6 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
-[[package]]
-name = "base64"
-version = "0.21.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
-
[[package]]
name = "base64"
version = "0.22.1"
@@ -528,9 +481,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
-version = "2.6.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
+checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
dependencies = [
"serde",
]
@@ -541,19 +494,6 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
-[[package]]
-name = "blake3"
-version = "1.5.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e"
-dependencies = [
- "arrayref",
- "arrayvec",
- "cc",
- "cfg-if",
- "constant_time_eq 0.3.1",
-]
-
[[package]]
name = "blas-src"
version = "0.10.0"
@@ -589,32 +529,11 @@ dependencies = [
"objc2",
]
-[[package]]
-name = "brotli"
-version = "6.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b"
-dependencies = [
- "alloc-no-stdlib",
- "alloc-stdlib",
- "brotli-decompressor",
-]
-
-[[package]]
-name = "brotli-decompressor"
-version = "4.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
-dependencies = [
- "alloc-no-stdlib",
- "alloc-stdlib",
-]
-
[[package]]
name = "bstr"
-version = "1.11.1"
+version = "1.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "786a307d683a5bf92e6fd5fd69a7eb613751668d1d8d67d802846dfe367c62c8"
+checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0"
dependencies = [
"memchr",
"serde",
@@ -628,13 +547,13 @@ checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b"
[[package]]
name = "bumpalo"
-version = "3.16.0"
+version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
+checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
[[package]]
name = "burn"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-core",
"burn-train",
@@ -642,7 +561,7 @@ dependencies = [
[[package]]
name = "burn-autodiff"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-common",
"burn-tensor",
@@ -654,7 +573,7 @@ dependencies = [
[[package]]
name = "burn-candle"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-autodiff",
"burn-tch",
@@ -666,14 +585,14 @@ dependencies = [
[[package]]
name = "burn-common"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
- "cubecl-common 0.4.0",
+ "cubecl-common",
"dashmap",
- "getrandom",
+ "getrandom 0.2.15",
"indicatif",
"rayon",
- "reqwest 0.12.12",
+ "reqwest",
"serde",
"tokio",
"web-time",
@@ -681,7 +600,7 @@ dependencies = [
[[package]]
name = "burn-core"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"ahash",
"bincode",
@@ -713,13 +632,13 @@ dependencies = [
"serde_json",
"spin",
"tempfile",
- "thiserror 2.0.9",
+ "thiserror 2.0.11",
"uuid",
]
[[package]]
name = "burn-cuda"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-fusion",
"burn-jit",
@@ -734,7 +653,7 @@ dependencies = [
[[package]]
name = "burn-dataset"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-common",
"csv",
@@ -761,22 +680,22 @@ dependencies = [
"strum",
"strum_macros",
"tempfile",
- "thiserror 2.0.9",
+ "thiserror 2.0.11",
]
[[package]]
name = "burn-derive"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"derive-new 0.7.0",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "burn-fusion"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-common",
"burn-tensor",
@@ -790,7 +709,7 @@ dependencies = [
[[package]]
name = "burn-hip"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-fusion",
"burn-jit",
@@ -805,9 +724,10 @@ dependencies = [
[[package]]
name = "burn-import"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
+ "burn-ndarray",
"candle-core",
"derive-new 0.7.0",
"half",
@@ -821,8 +741,8 @@ dependencies = [
"rust-format",
"serde",
"serde_json",
- "syn 2.0.95",
- "thiserror 2.0.9",
+ "syn 2.0.98",
+ "thiserror 2.0.11",
"tracing-core",
"tracing-subscriber",
"zip 2.2.2",
@@ -830,7 +750,7 @@ dependencies = [
[[package]]
name = "burn-jit"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-autodiff",
"burn-common",
@@ -856,7 +776,7 @@ dependencies = [
[[package]]
name = "burn-ndarray"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"atomic_float",
"blas-src",
@@ -876,7 +796,7 @@ dependencies = [
[[package]]
name = "burn-no-std-tests"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-ndarray",
@@ -885,12 +805,11 @@ dependencies = [
[[package]]
name = "burn-remote"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"async-channel",
"axum",
"burn-common",
- "burn-remote",
"burn-router",
"burn-tensor",
"derive-new 0.7.0",
@@ -900,14 +819,14 @@ dependencies = [
"serde",
"serde_bytes",
"tokio",
- "tokio-tungstenite 0.26.1",
+ "tokio-tungstenite",
"tracing-core",
"tracing-subscriber",
]
[[package]]
name = "burn-router"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-autodiff",
"burn-common",
@@ -921,7 +840,7 @@ dependencies = [
[[package]]
name = "burn-tch"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-autodiff",
"burn-tensor",
@@ -934,7 +853,7 @@ dependencies = [
[[package]]
name = "burn-tensor"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"bincode",
"burn-common",
@@ -955,7 +874,7 @@ dependencies = [
[[package]]
name = "burn-tensor-testgen"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"proc-macro2",
"quote",
@@ -963,7 +882,7 @@ dependencies = [
[[package]]
name = "burn-train"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"async-channel",
"burn-core",
@@ -974,7 +893,7 @@ dependencies = [
"ratatui",
"rstest",
"serde",
- "sysinfo 0.32.1",
+ "sysinfo",
"systemstat",
"tracing-appender",
"tracing-core",
@@ -983,7 +902,7 @@ dependencies = [
[[package]]
name = "burn-wgpu"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn-fusion",
"burn-jit",
@@ -1004,13 +923,13 @@ dependencies = [
[[package]]
name = "bytemuck_derive"
-version = "1.8.0"
+version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec"
+checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -1027,9 +946,12 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
-version = "1.9.0"
+version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
+checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9"
+dependencies = [
+ "serde",
+]
[[package]]
name = "bytesize"
@@ -1060,9 +982,9 @@ dependencies = [
[[package]]
name = "candle-core"
-version = "0.8.1"
+version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d1e306c8a4276ba57ce9fac76d823cc8c8a7fca14bf222ac20ad8b12c4273152"
+checksum = "855dfedff437d2681d68e1f34ae559d88b0dd84aa5a6b63f2c8e75ebdd875bbf"
dependencies = [
"accelerate-src",
"byteorder",
@@ -1072,7 +994,7 @@ dependencies = [
"gemm",
"half",
"libc",
- "memmap2 0.9.5",
+ "memmap2",
"metal 0.27.0",
"num-traits",
"num_cpus",
@@ -1090,18 +1012,18 @@ dependencies = [
[[package]]
name = "candle-kernels"
-version = "0.8.1"
+version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cbd8ea6588f3c6286ea89a52dad3365f0536fd0b71e729fa998cc2347f1df3b6"
+checksum = "53343628fa470b7075c28c589b98735b4220b464e37ddbb8e117040e199f4787"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-metal-kernels"
-version = "0.8.1"
+version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cbc6621c7e2202f4f129bcc3185c2c6d4fa2fc6b8f3f2b07eaf7c06042910c83"
+checksum = "50fa64274a009a5d95c542b10bf3a4ea809bd394654c6ae99233bcc35b3a33ef"
dependencies = [
"metal 0.27.0",
"once_cell",
@@ -1135,9 +1057,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.2.4"
+version = "1.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf"
+checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf"
dependencies = [
"jobserver",
"libc",
@@ -1160,12 +1082,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
-[[package]]
-name = "cfg_aliases"
-version = "0.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
-
[[package]]
name = "cfg_aliases"
version = "0.2.1"
@@ -1182,16 +1098,15 @@ dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
- "serde",
"wasm-bindgen",
"windows-targets 0.52.6",
]
[[package]]
name = "chrono-tz"
-version = "0.8.6"
+version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e"
+checksum = "9c6ac4f2c0bf0f44e9161aec9675e1050aa4a530663c4a9e37e108fa948bca9f"
dependencies = [
"chrono",
"chrono-tz-build",
@@ -1200,42 +1115,14 @@ dependencies = [
[[package]]
name = "chrono-tz-build"
-version = "0.2.1"
+version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f"
+checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7"
dependencies = [
"parse-zoneinfo",
- "phf",
"phf_codegen",
]
-[[package]]
-name = "ciborium"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
-dependencies = [
- "ciborium-io",
- "ciborium-ll",
- "serde",
-]
-
-[[package]]
-name = "ciborium-io"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
-
-[[package]]
-name = "ciborium-ll"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
-dependencies = [
- "ciborium-io",
- "half",
-]
-
[[package]]
name = "cipher"
version = "0.4.4"
@@ -1248,75 +1135,36 @@ dependencies = [
[[package]]
name = "clap"
-version = "3.2.25"
+version = "4.5.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
-dependencies = [
- "atty",
- "bitflags 1.3.2",
- "clap_derive 3.2.25",
- "clap_lex 0.2.4",
- "indexmap 1.9.3",
- "once_cell",
- "strsim 0.10.0",
- "termcolor",
- "textwrap",
-]
-
-[[package]]
-name = "clap"
-version = "4.5.23"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84"
+checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796"
dependencies = [
"clap_builder",
- "clap_derive 4.5.18",
+ "clap_derive",
]
[[package]]
name = "clap_builder"
-version = "4.5.23"
+version = "4.5.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838"
+checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7"
dependencies = [
"anstream",
"anstyle",
- "clap_lex 0.7.4",
- "strsim 0.11.1",
-]
-
-[[package]]
-name = "clap_derive"
-version = "3.2.25"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008"
-dependencies = [
- "heck 0.4.1",
- "proc-macro-error",
- "proc-macro2",
- "quote",
- "syn 1.0.109",
+ "clap_lex",
+ "strsim",
]
[[package]]
name = "clap_derive"
-version = "4.5.18"
+version = "4.5.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
+checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c"
dependencies = [
- "heck 0.5.0",
+ "heck",
"proc-macro2",
"quote",
- "syn 2.0.95",
-]
-
-[[package]]
-name = "clap_lex"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
-dependencies = [
- "os_str_bytes",
+ "syn 2.0.98",
]
[[package]]
@@ -1336,9 +1184,9 @@ dependencies = [
[[package]]
name = "cmake"
-version = "0.1.52"
+version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e"
+checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6"
dependencies = [
"cc",
]
@@ -1389,9 +1237,9 @@ dependencies = [
[[package]]
name = "compact_str"
-version = "0.8.0"
+version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644"
+checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32"
dependencies = [
"castaway",
"cfg-if",
@@ -1456,16 +1304,6 @@ dependencies = [
"libc",
]
-[[package]]
-name = "core-foundation"
-version = "0.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63"
-dependencies = [
- "core-foundation-sys",
- "libc",
-]
-
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@@ -1479,7 +1317,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081"
dependencies = [
"bitflags 1.3.2",
- "core-foundation 0.9.4",
+ "core-foundation",
"core-graphics-types",
"foreign-types 0.5.0",
"libc",
@@ -1492,15 +1330,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf"
dependencies = [
"bitflags 1.3.2",
- "core-foundation 0.9.4",
+ "core-foundation",
"libc",
]
[[package]]
name = "cpufeatures"
-version = "0.2.16"
+version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3"
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
@@ -1578,7 +1416,7 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"crossterm_winapi",
"mio",
"parking_lot 0.12.3",
@@ -1599,9 +1437,9 @@ dependencies = [
[[package]]
name = "crunchy"
-version = "0.2.2"
+version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
+checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "crypto-common"
@@ -1636,46 +1474,33 @@ dependencies = [
[[package]]
name = "cubecl"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"cubecl-core",
"cubecl-cuda",
"cubecl-hip",
"cubecl-linalg",
- "cubecl-runtime 0.4.0",
+ "cubecl-reduce",
+ "cubecl-runtime",
"cubecl-wgpu",
"half",
]
[[package]]
name = "cubecl-common"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4"
-dependencies = [
- "derive-new 0.6.0",
- "embassy-futures",
- "futures-lite",
- "getrandom",
- "log",
- "portable-atomic",
- "rand",
- "serde",
- "spin",
- "web-time",
-]
-
-[[package]]
-name = "cubecl-common"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
+ "bytemuck",
"derive-new 0.6.0",
+ "derive_more 1.0.0",
"embassy-futures",
"futures-lite",
- "getrandom",
+ "getrandom 0.2.15",
+ "half",
"log",
+ "num-traits",
"portable-atomic",
"rand",
"serde",
@@ -1685,13 +1510,15 @@ dependencies = [
[[package]]
name = "cubecl-core"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
+ "bitflags 2.8.0",
"bytemuck",
- "cubecl-common 0.4.0",
+ "cubecl-common",
+ "cubecl-ir",
"cubecl-macros",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"derive-new 0.6.0",
"derive_more 1.0.0",
"half",
@@ -1704,13 +1531,13 @@ dependencies = [
[[package]]
name = "cubecl-cpp"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"bytemuck",
- "cubecl-common 0.4.0",
+ "cubecl-common",
"cubecl-core",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"derive-new 0.6.0",
"half",
"log",
@@ -1718,14 +1545,14 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"bytemuck",
- "cubecl-common 0.4.0",
+ "cubecl-common",
"cubecl-core",
"cubecl-cpp",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"cudarc",
"derive-new 0.6.0",
"half",
@@ -1734,15 +1561,15 @@ dependencies = [
[[package]]
name = "cubecl-hip"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"bytemuck",
- "cubecl-common 0.4.0",
+ "cubecl-common",
"cubecl-core",
"cubecl-cpp",
"cubecl-hip-sys",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"derive-new 0.6.0",
"half",
"log",
@@ -1751,86 +1578,104 @@ dependencies = [
[[package]]
name = "cubecl-hip-sys"
-version = "6.3.0"
+version = "6.3.1001"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9974218b3ff1f1e7b2f11ce254fd90b3ebcc2af6b4d084f7f6a0c351fb16112c"
+checksum = "c7e92df7f9feff6a469932fc4d4b349d28000af9e6f34e583eb4f8df70038d48"
dependencies = [
"libc",
]
+[[package]]
+name = "cubecl-ir"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
+dependencies = [
+ "cubecl-common",
+ "cubecl-macros-internal",
+ "derive_more 1.0.0",
+ "float-ord",
+ "fnv",
+ "half",
+ "hashbrown 0.14.5",
+ "num-traits",
+ "portable-atomic",
+ "serde",
+ "variadics_please",
+]
+
[[package]]
name = "cubecl-linalg"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"bytemuck",
"cubecl-core",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"half",
"serde",
]
[[package]]
name = "cubecl-macros"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
- "cubecl-common 0.4.0",
+ "cubecl-common",
"darling",
"derive-new 0.6.0",
"ident_case",
"prettyplease",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
+]
+
+[[package]]
+name = "cubecl-macros-internal"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
+dependencies = [
+ "darling",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
]
[[package]]
name = "cubecl-opt"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
- "cubecl-common 0.4.0",
- "cubecl-core",
+ "cubecl-common",
+ "cubecl-ir",
"float-ord",
"log",
"num",
"petgraph",
"smallvec",
"stable-vec",
+ "type-map",
]
[[package]]
-name = "cubecl-runtime"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3468467f412dff4bbf97fb5061a3557445f017299e2fb73ef7b96c6cdb799bc3"
+name = "cubecl-reduce"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
- "async-channel",
- "async-lock",
- "cfg_aliases 0.2.1",
- "cubecl-common 0.3.0",
- "derive-new 0.6.0",
- "dirs",
- "hashbrown 0.14.5",
- "log",
- "md5",
- "sanitize-filename 0.5.0",
- "serde",
- "serde_json",
- "spin",
- "wasm-bindgen-futures",
+ "cubecl-core",
+ "cubecl-runtime",
+ "num-traits",
]
[[package]]
name = "cubecl-runtime"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"async-channel",
"async-lock",
- "cfg_aliases 0.2.1",
- "cubecl-common 0.4.0",
+ "cfg_aliases",
+ "cubecl-common",
"derive-new 0.6.0",
"dirs",
"hashbrown 0.14.5",
@@ -1840,18 +1685,20 @@ dependencies = [
"serde",
"serde_json",
"spin",
+ "variadics_please",
"wasm-bindgen-futures",
]
[[package]]
name = "cubecl-spirv"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
- "cubecl-common 0.4.0",
+ "bitflags 2.8.0",
+ "cubecl-common",
"cubecl-core",
"cubecl-opt",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"half",
"hashbrown 0.14.5",
"rspirv",
@@ -1859,17 +1706,17 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
-version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf"
+version = "0.5.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b"
dependencies = [
"ash",
"async-channel",
"bytemuck",
"cfg-if",
- "cfg_aliases 0.2.1",
- "cubecl-common 0.4.0",
+ "cfg_aliases",
+ "cubecl-common",
"cubecl-core",
- "cubecl-runtime 0.4.0",
+ "cubecl-runtime",
"cubecl-spirv",
"derive-new 0.6.0",
"hashbrown 0.14.5",
@@ -1880,9 +1727,9 @@ dependencies = [
[[package]]
name = "cudarc"
-version = "0.12.2"
+version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25"
+checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
dependencies = [
"half",
"libloading",
@@ -1890,17 +1737,17 @@ dependencies = [
[[package]]
name = "custom-csv-dataset"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"csv",
- "reqwest 0.12.12",
+ "reqwest",
"serde",
]
[[package]]
name = "custom-cubecl-kernel"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-jit",
@@ -1913,7 +1760,7 @@ dependencies = [
[[package]]
name = "custom-image-dataset"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"flate2",
@@ -1922,7 +1769,7 @@ dependencies = [
[[package]]
name = "custom-renderer"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"bytemuck",
@@ -1934,7 +1781,7 @@ dependencies = [
[[package]]
name = "custom-training-loop"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"bytemuck",
@@ -1946,7 +1793,7 @@ dependencies = [
[[package]]
name = "custom-wgpu-kernel"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"bytemuck",
@@ -1976,8 +1823,8 @@ dependencies = [
"ident_case",
"proc-macro2",
"quote",
- "strsim 0.11.1",
- "syn 2.0.95",
+ "strsim",
+ "syn 2.0.98",
]
[[package]]
@@ -1988,7 +1835,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2007,9 +1854,9 @@ dependencies = [
[[package]]
name = "data-encoding"
-version = "2.6.0"
+version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
+checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f"
[[package]]
name = "deflate64"
@@ -2034,7 +1881,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2045,7 +1892,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2056,7 +1903,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2077,7 +1924,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2087,7 +1934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2098,7 +1945,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2118,7 +1965,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
"unicode-xid",
]
@@ -2174,15 +2021,9 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
-[[package]]
-name = "doc-comment"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
-
[[package]]
name = "document-features"
version = "0.2.10"
@@ -2204,9 +2045,9 @@ dependencies = [
[[package]]
name = "dyn-clone"
-version = "1.0.17"
+version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
+checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35"
[[package]]
name = "dyn-stack"
@@ -2223,9 +2064,6 @@ name = "either"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
-dependencies = [
- "serde",
-]
[[package]]
name = "embassy-futures"
@@ -2254,10 +2092,10 @@ version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
dependencies = [
- "heck 0.5.0",
+ "heck",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2269,14 +2107,14 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "env_filter"
-version = "0.1.2"
+version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab"
+checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
dependencies = [
"log",
"regex",
@@ -2284,9 +2122,9 @@ dependencies = [
[[package]]
name = "env_logger"
-version = "0.11.5"
+version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d"
+checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
dependencies = [
"anstream",
"anstyle",
@@ -2331,9 +2169,9 @@ checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c"
[[package]]
name = "event-listener"
-version = "5.3.1"
+version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
+checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae"
dependencies = [
"concurrent-queue",
"parking",
@@ -2367,9 +2205,9 @@ dependencies = [
[[package]]
name = "fake"
-version = "3.0.1"
+version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "661cb0601b5f4050d1e65452c5b0ea555c0b3e88fb5ed7855906adc6c42523ef"
+checksum = "aef603df4ba9adbca6a332db7da6f614f21eafefbaf8e087844e452fdec152d0"
dependencies = [
"deunicode",
"rand",
@@ -2388,10 +2226,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
-name = "fast-float"
-version = "0.2.0"
+name = "fast-float2"
+version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c"
+checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55"
[[package]]
name = "faster-hex"
@@ -2468,9 +2306,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
-version = "0.1.3"
+version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
+checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]]
name = "foreign-types"
@@ -2499,7 +2337,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2523,16 +2361,6 @@ dependencies = [
"percent-encoding",
]
-[[package]]
-name = "fs4"
-version = "0.9.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e8c6b3bd49c37d2aa3f3f2220233b29a7cd23f79d1fe70e5337d25fb390793de"
-dependencies = [
- "rustix",
- "windows-sys 0.52.0",
-]
-
[[package]]
name = "futures"
version = "0.3.31"
@@ -2583,9 +2411,9 @@ checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-lite"
-version = "2.5.0"
+version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cef40d21ae2c515b51041df9ed313ed21e572df340ea58a922a0aefe7e8891a1"
+checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532"
dependencies = [
"fastrand",
"futures-core",
@@ -2602,7 +2430,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -2788,10 +2616,22 @@ dependencies = [
"cfg-if",
"js-sys",
"libc",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
+[[package]]
+name = "getrandom"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi 0.13.3+wasi-0.2.2",
+ "windows-targets 0.52.6",
+]
+
[[package]]
name = "gif"
version = "0.13.1"
@@ -2808,20 +2648,6 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
-[[package]]
-name = "github-device-flow"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "98852ab71f5613dac02a0d1b41f3ffaf993b69449904dd13a10575612a56074d"
-dependencies = [
- "chrono",
- "clap 3.2.25",
- "reqwest 0.11.27",
- "serde",
- "serde_derive",
- "serde_json",
-]
-
[[package]]
name = "gix-features"
version = "0.39.1"
@@ -2836,9 +2662,9 @@ dependencies = [
[[package]]
name = "gix-fs"
-version = "0.12.0"
+version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34740384d8d763975858fa2c176b68652a6fcc09f616e24e3ce967b0d370e4d8"
+checksum = "3b3d4fac505a621f97e5ce2c69fdc425742af00c0920363ca4074f0eb48b1db9"
dependencies = [
"fastrand",
"gix-features",
@@ -2852,7 +2678,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b5eccc17194ed0e67d49285e4853307e4147e95407f91c1c3e4a13ba9f4e4ce"
dependencies = [
"faster-hex",
- "thiserror 2.0.9",
+ "thiserror 2.0.11",
]
[[package]]
@@ -2873,15 +2699,15 @@ dependencies = [
[[package]]
name = "gix-trace"
-version = "0.1.11"
+version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "04bdde120c29f1fc23a24d3e115aeeea3d60d8e65bab92cc5f9d90d9302eb952"
+checksum = "7c396a2036920c69695f760a65e7f2677267ccf483f25046977d87e4cb2665f7"
[[package]]
name = "gix-utils"
-version = "0.1.13"
+version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ba427e3e9599508ed98a6ddf8ed05493db114564e338e41f6a996d2e4790335f"
+checksum = "ff08f24e03ac8916c478c8419d7d3c33393da9bb41fa4c24455d5406aeefd35f"
dependencies = [
"fastrand",
"unicode-normalization",
@@ -2900,9 +2726,9 @@ dependencies = [
[[package]]
name = "glob"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
+checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "globset"
@@ -2923,16 +2749,16 @@ version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"ignore",
"walkdir",
]
[[package]]
name = "glow"
-version = "0.14.2"
+version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d51fa363f025f5c111e03f13eda21162faeacb6911fe8caa0c0349f9cf0c4483"
+checksum = "c5e5ea60d70410161c8bf5da3fdfeaa1c72ed2c15f8bbb9d19fe3a4fad085f08"
dependencies = [
"js-sys",
"slotmap",
@@ -2942,9 +2768,9 @@ dependencies = [
[[package]]
name = "glutin_wgl_sys"
-version = "0.6.0"
+version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0a4e1951bbd9434a81aa496fe59ccc2235af3820d27b85f9314e279609211e2c"
+checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e"
dependencies = [
"gl_generator",
]
@@ -2955,7 +2781,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"gpu-alloc-types",
]
@@ -2965,7 +2791,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
]
[[package]]
@@ -2982,13 +2808,13 @@ dependencies = [
[[package]]
name = "gpu-descriptor"
-version = "0.3.0"
+version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c08c1f623a8d0b722b8b99f821eb0ba672a1618f0d3b16ddbee1cedd2dd8557"
+checksum = "dcf29e94d6d243368b7a56caa16bc213e4f9f8ed38c4d9557069527b5d5281ca"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"gpu-descriptor-types",
- "hashbrown 0.14.5",
+ "hashbrown 0.15.2",
]
[[package]]
@@ -2997,37 +2823,18 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
]
[[package]]
name = "guide"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"log",
"serde",
]
-[[package]]
-name = "h2"
-version = "0.3.26"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8"
-dependencies = [
- "bytes",
- "fnv",
- "futures-core",
- "futures-sink",
- "futures-util",
- "http 0.2.12",
- "indexmap 2.7.0",
- "slab",
- "tokio",
- "tokio-util",
- "tracing",
-]
-
[[package]]
name = "h2"
version = "0.4.7"
@@ -3039,8 +2846,8 @@ dependencies = [
"fnv",
"futures-core",
"futures-sink",
- "http 1.2.0",
- "indexmap 2.7.0",
+ "http",
+ "indexmap",
"slab",
"tokio",
"tokio-util",
@@ -3062,22 +2869,6 @@ dependencies = [
"serde",
]
-[[package]]
-name = "halfbrown"
-version = "0.2.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f"
-dependencies = [
- "hashbrown 0.14.5",
- "serde",
-]
-
-[[package]]
-name = "hashbrown"
-version = "0.12.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
-
[[package]]
name = "hashbrown"
version = "0.13.2"
@@ -3121,27 +2912,12 @@ dependencies = [
"hashbrown 0.14.5",
]
-[[package]]
-name = "heck"
-version = "0.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
-
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
-[[package]]
-name = "hermit-abi"
-version = "0.1.19"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
-dependencies = [
- "libc",
-]
-
[[package]]
name = "hermit-abi"
version = "0.3.9"
@@ -3201,17 +2977,6 @@ version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
-[[package]]
-name = "http"
-version = "0.2.12"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
-dependencies = [
- "bytes",
- "fnv",
- "itoa",
-]
-
[[package]]
name = "http"
version = "1.2.0"
@@ -3223,17 +2988,6 @@ dependencies = [
"itoa",
]
-[[package]]
-name = "http-body"
-version = "0.4.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
-dependencies = [
- "bytes",
- "http 0.2.12",
- "pin-project-lite",
-]
-
[[package]]
name = "http-body"
version = "1.0.1"
@@ -3241,7 +2995,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
- "http 1.2.0",
+ "http",
]
[[package]]
@@ -3252,16 +3006,16 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
dependencies = [
"bytes",
"futures-util",
- "http 1.2.0",
- "http-body 1.0.1",
+ "http",
+ "http-body",
"pin-project-lite",
]
[[package]]
name = "httparse"
-version = "1.9.5"
+version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946"
+checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a"
[[package]]
name = "httpdate"
@@ -3277,40 +3031,16 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
-version = "0.14.32"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
-dependencies = [
- "bytes",
- "futures-channel",
- "futures-core",
- "futures-util",
- "h2 0.3.26",
- "http 0.2.12",
- "http-body 0.4.6",
- "httparse",
- "httpdate",
- "itoa",
- "pin-project-lite",
- "socket2",
- "tokio",
- "tower-service",
- "tracing",
- "want",
-]
-
-[[package]]
-name = "hyper"
-version = "1.5.2"
+version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0"
+checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
- "h2 0.4.7",
- "http 1.2.0",
- "http-body 1.0.1",
+ "h2",
+ "http",
+ "http-body",
"httparse",
"httpdate",
"itoa",
@@ -3322,35 +3052,21 @@ dependencies = [
[[package]]
name = "hyper-rustls"
-version = "0.27.3"
+version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333"
+checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
- "http 1.2.0",
- "hyper 1.5.2",
+ "http",
+ "hyper",
"hyper-util",
"rustls",
- "rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
-[[package]]
-name = "hyper-tls"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
-dependencies = [
- "bytes",
- "hyper 0.14.32",
- "native-tls",
- "tokio",
- "tokio-native-tls",
-]
-
[[package]]
name = "hyper-tls"
version = "0.6.0"
@@ -3359,7 +3075,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
- "hyper 1.5.2",
+ "hyper",
"hyper-util",
"native-tls",
"tokio",
@@ -3376,9 +3092,9 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
- "http 1.2.0",
- "http-body 1.0.1",
- "hyper 1.5.2",
+ "http",
+ "http-body",
+ "hyper",
"pin-project-lite",
"socket2",
"tokio",
@@ -3524,7 +3240,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -3595,13 +3311,12 @@ dependencies = [
[[package]]
name = "image-classification-web"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-candle",
"burn-import",
"console_error_panic_hook",
- "cubecl-runtime 0.3.0",
"js-sys",
"log",
"serde",
@@ -3615,9 +3330,9 @@ dependencies = [
[[package]]
name = "image-webp"
-version = "0.2.0"
+version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f"
+checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f"
dependencies = [
"byteorder-lite",
"quick-error",
@@ -3631,19 +3346,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
[[package]]
name = "indexmap"
-version = "1.9.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
-dependencies = [
- "autocfg",
- "hashbrown 0.12.3",
-]
-
-[[package]]
-name = "indexmap"
-version = "2.7.0"
+version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f"
+checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
dependencies = [
"equivalent",
"hashbrown 0.15.2",
@@ -3652,9 +3357,9 @@ dependencies = [
[[package]]
name = "indicatif"
-version = "0.17.9"
+version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281"
+checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
dependencies = [
"console",
"number_prefix",
@@ -3680,16 +3385,15 @@ dependencies = [
[[package]]
name = "instability"
-version = "0.3.3"
+version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e"
+checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d"
dependencies = [
"darling",
"indoc",
- "pretty_assertions",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -3709,14 +3413,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "ipnet"
-version = "2.10.1"
+version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
+checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
@@ -3757,12 +3461,6 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
-[[package]]
-name = "itoap"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8"
-
[[package]]
name = "jni-sys"
version = "0.3.0"
@@ -3786,9 +3484,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
[[package]]
name = "js-sys"
-version = "0.3.76"
+version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7"
+checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
dependencies = [
"once_cell",
"wasm-bindgen",
@@ -3825,15 +3523,15 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
-version = "0.2.168"
+version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d"
+checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libfuzzer-sys"
-version = "0.4.8"
+version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa"
+checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75"
dependencies = [
"arbitrary",
"cc",
@@ -3846,7 +3544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
- "windows-targets 0.52.6",
+ "windows-targets 0.48.5",
]
[[package]]
@@ -3861,7 +3559,7 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"libc",
"redox_syscall 0.5.8",
]
@@ -3879,9 +3577,9 @@ dependencies = [
[[package]]
name = "linux-raw-sys"
-version = "0.4.14"
+version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
+checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "litemap"
@@ -3913,9 +3611,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e"
[[package]]
name = "log"
-version = "0.4.22"
+version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
+checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]]
name = "loop9"
@@ -3937,9 +3635,9 @@ dependencies = [
[[package]]
name = "lz4"
-version = "1.28.0"
+version = "1.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725"
+checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4"
dependencies = [
"lz4-sys",
]
@@ -4000,9 +3698,9 @@ dependencies = [
[[package]]
name = "matchit"
-version = "0.7.3"
+version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
+checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrixmultiply"
@@ -4027,16 +3725,6 @@ dependencies = [
"rayon",
]
-[[package]]
-name = "md-5"
-version = "0.10.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
-dependencies = [
- "cfg-if",
- "digest",
-]
-
[[package]]
name = "md5"
version = "0.7.0"
@@ -4049,15 +3737,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
-[[package]]
-name = "memmap2"
-version = "0.7.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6"
-dependencies = [
- "libc",
-]
-
[[package]]
name = "memmap2"
version = "0.9.5"
@@ -4069,21 +3748,27 @@ dependencies = [
]
[[package]]
-name = "memoffset"
-version = "0.9.1"
+name = "metal"
+version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
+checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25"
dependencies = [
- "autocfg",
+ "bitflags 2.8.0",
+ "block",
+ "core-graphics-types",
+ "foreign-types 0.5.0",
+ "log",
+ "objc",
+ "paste",
]
[[package]]
name = "metal"
-version = "0.27.0"
+version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25"
+checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block",
"core-graphics-types",
"foreign-types 0.5.0",
@@ -4094,11 +3779,11 @@ dependencies = [
[[package]]
name = "metal"
-version = "0.29.0"
+version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
+checksum = "f569fb946490b5743ad69813cb19629130ce9374034abe31614a36402d18f99e"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block",
"core-graphics-types",
"foreign-types 0.5.0",
@@ -4121,9 +3806,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
-version = "0.8.2"
+version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394"
+checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924"
dependencies = [
"adler2",
"simd-adler32",
@@ -4137,13 +3822,13 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
"libc",
"log",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
[[package]]
name = "mnist"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"log",
@@ -4152,11 +3837,10 @@ dependencies = [
[[package]]
name = "mnist-inference-web"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"console_error_panic_hook",
- "cubecl-runtime 0.3.0",
"js-sys",
"serde",
"wasm-bindgen",
@@ -4165,12 +3849,23 @@ dependencies = [
[[package]]
name = "model"
-version = "0.5.0"
+version = "0.6.0"
dependencies = [
"burn",
"burn-import",
]
+[[package]]
+name = "modern-lstm"
+version = "0.1.0"
+dependencies = [
+ "burn",
+ "polars",
+ "rand",
+ "rand_distr",
+ "serde",
+]
+
[[package]]
name = "monostate"
version = "0.1.13"
@@ -4189,55 +3884,34 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
-]
-
-[[package]]
-name = "multiversion"
-version = "0.7.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142"
-dependencies = [
- "multiversion-macros",
- "target-features",
-]
-
-[[package]]
-name = "multiversion-macros"
-version = "0.7.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 1.0.109",
- "target-features",
+ "syn 2.0.98",
]
[[package]]
name = "naga"
-version = "23.1.0"
+version = "24.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f"
+checksum = "e380993072e52eef724eddfcde0ed013b0c023c3f0417336ed041aa9f076994e"
dependencies = [
"arrayvec",
"bit-set",
- "bitflags 2.6.0",
- "cfg_aliases 0.1.1",
+ "bitflags 2.8.0",
+ "cfg_aliases",
"codespan-reporting",
"hexf-parse",
- "indexmap 2.7.0",
+ "indexmap",
"log",
- "rustc-hash 1.1.0",
- "spirv",
+ "rustc-hash",
+ "spirv 0.3.0+sdk-1.3.268.0",
+ "strum",
"termcolor",
- "thiserror 1.0.69",
+ "thiserror 2.0.11",
"unicode-xid",
]
[[package]]
name = "named-tensor"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"serde",
@@ -4245,9 +3919,9 @@ dependencies = [
[[package]]
name = "native-tls"
-version = "0.2.12"
+version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
+checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c"
dependencies = [
"libc",
"log",
@@ -4255,7 +3929,7 @@ dependencies = [
"openssl-probe",
"openssl-sys",
"schannel",
- "security-framework 2.11.1",
+ "security-framework",
"security-framework-sys",
"tempfile",
]
@@ -4413,7 +4087,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -4463,7 +4137,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
- "hermit-abi 0.3.9",
+ "hermit-abi",
"libc",
]
@@ -4485,7 +4159,7 @@ dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -4509,7 +4183,7 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c9bff0aa1d48904a1385ea2a8b97576fbdcbc9a3cfccd0d31fe978e1c4038c5"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"libloading",
"nvml-wrapper-sys",
"static_assertions",
@@ -4558,7 +4232,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block2",
"libc",
"objc2",
@@ -4574,7 +4248,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block2",
"objc2",
"objc2-foundation",
@@ -4594,9 +4268,9 @@ dependencies = [
[[package]]
name = "objc2-encode"
-version = "4.0.3"
+version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8"
+checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
[[package]]
name = "objc2-foundation"
@@ -4604,7 +4278,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block2",
"libc",
"objc2",
@@ -4616,7 +4290,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block2",
"objc2",
"objc2-foundation",
@@ -4628,7 +4302,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block2",
"objc2",
"objc2-foundation",
@@ -4646,43 +4320,13 @@ dependencies = [
[[package]]
name = "object"
-version = "0.36.5"
+version = "0.36.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e"
+checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
dependencies = [
"memchr",
]
-[[package]]
-name = "object_store"
-version = "0.10.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3"
-dependencies = [
- "async-trait",
- "base64 0.22.1",
- "bytes",
- "chrono",
- "futures",
- "humantime",
- "hyper 1.5.2",
- "itertools 0.13.0",
- "md-5",
- "parking_lot 0.12.3",
- "percent-encoding",
- "quick-xml",
- "rand",
- "reqwest 0.12.12",
- "ring",
- "serde",
- "serde_json",
- "snafu",
- "tokio",
- "tracing",
- "url",
- "walkdir",
-]
-
[[package]]
name = "once_cell"
version = "1.20.2"
@@ -4713,7 +4357,7 @@ dependencies = [
[[package]]
name = "onnx-inference"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-import",
@@ -4722,7 +4366,7 @@ dependencies = [
[[package]]
name = "onnx-ir"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"bytemuck",
"half",
@@ -4739,7 +4383,7 @@ dependencies = [
[[package]]
name = "onnx-tests"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-import",
@@ -4750,16 +4394,16 @@ dependencies = [
[[package]]
name = "openblas-build"
-version = "0.10.10"
+version = "0.10.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8ca8f8c64eb5b43f5538059ccbc71391420bba14d987d7e8ab99ed62ed33e26b"
+checksum = "b8140c0c1afaf88d2d30c48abad86b3bdd2334d691e08f7325a960d784240647"
dependencies = [
"anyhow",
"cc",
"flate2",
"native-tls",
"tar",
- "thiserror 2.0.9",
+ "thiserror 2.0.11",
"ureq",
]
@@ -4777,11 +4421,11 @@ dependencies = [
[[package]]
name = "openssl"
-version = "0.10.68"
+version = "0.10.70"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5"
+checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
@@ -4798,20 +4442,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "openssl-probe"
-version = "0.1.5"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
+checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
-version = "0.9.104"
+version = "0.9.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741"
+checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc"
dependencies = [
"cc",
"libc",
@@ -4825,23 +4469,26 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
+[[package]]
+name = "ordered-float"
+version = "4.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
+dependencies = [
+ "num-traits",
+]
+
[[package]]
name = "os_info"
-version = "3.9.0"
+version = "3.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e5ca711d8b83edbb00b44d504503cd247c9c0bd8b0fa2694f2a1a3d8165379ce"
+checksum = "6e6520c8cc998c5741ee68ec1dc369fc47e5f0ea5320018ecf2a1ccd6328f48b"
dependencies = [
"log",
"serde",
"windows-sys 0.52.0",
]
-[[package]]
-name = "os_str_bytes"
-version = "6.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
-
[[package]]
name = "overload"
version = "0.1.1"
@@ -4963,23 +4610,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
- "indexmap 2.7.0",
+ "indexmap",
]
[[package]]
name = "phf"
-version = "0.11.2"
+version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
+checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_shared",
]
[[package]]
name = "phf_codegen"
-version = "0.11.2"
+version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a"
+checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a"
dependencies = [
"phf_generator",
"phf_shared",
@@ -4987,9 +4634,9 @@ dependencies = [
[[package]]
name = "phf_generator"
-version = "0.11.2"
+version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
+checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
"rand",
@@ -4997,18 +4644,18 @@ dependencies = [
[[package]]
name = "phf_shared"
-version = "0.11.2"
+version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b"
+checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]]
name = "pin-project-lite"
-version = "0.2.15"
+version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"
+checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
@@ -5033,9 +4680,9 @@ dependencies = [
[[package]]
name = "png"
-version = "0.17.15"
+version = "0.17.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b67582bd5b65bdff614270e2ea89a1cf15bef71245cc1e5f7ea126977144211d"
+checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526"
dependencies = [
"bitflags 1.3.2",
"crc32fast",
@@ -5046,11 +4693,11 @@ dependencies = [
[[package]]
name = "polars"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f65c6aa86d991a64c95416a61202f7952da2f8cccefa448f9a23c1b8f2301ecc"
+checksum = "72571dde488ecccbe799798bf99ab7308ebdb7cf5d95bcc498dbd5a132f0da4d"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"polars-arrow",
"polars-core",
"polars-error",
@@ -5066,12 +4713,11 @@ dependencies = [
[[package]]
name = "polars-arrow"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "87dbb24d29ddea5abb73d7954df8b8d3d4bb7f02a3e5c96d1519cdad9e816a3d"
+checksum = "6611c758d52e799761cc25900666b71552e6c929d88052811bc9daad4b3321a8"
dependencies = [
"ahash",
- "atoi",
"atoi_simd",
"bytemuck",
"chrono",
@@ -5079,21 +4725,16 @@ dependencies = [
"dyn-clone",
"either",
"ethnum",
- "fast-float",
- "getrandom",
+ "getrandom 0.2.15",
"hashbrown 0.15.2",
"itoa",
- "itoap",
"lz4",
- "multiversion",
"num-traits",
"parking_lot 0.12.3",
"polars-arrow-format",
"polars-error",
"polars-schema",
"polars-utils",
- "ryu",
- "serde",
"simdutf8",
"streaming-iterator",
"strength_reduce",
@@ -5114,28 +4755,33 @@ dependencies = [
[[package]]
name = "polars-compute"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cbdb1071147452a4c4b25560f23d2fbaffef255b04757291131b22fc2c0d35b2"
+checksum = "332f2547dbb27599a8ffe68e56159f5996ba03d1dad0382ccb62c109ceacdeb6"
dependencies = [
+ "atoi_simd",
"bytemuck",
+ "chrono",
"either",
+ "fast-float2",
+ "itoa",
"num-traits",
"polars-arrow",
"polars-error",
"polars-utils",
+ "ryu",
"strength_reduce",
"version_check",
]
[[package]]
name = "polars-core"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949"
+checksum = "796d06eae7e6e74ed28ea54a8fccc584ebac84e6cf0e1e9ba41ffc807b169a01"
dependencies = [
"ahash",
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"bytemuck",
"chrono",
"chrono-tz",
@@ -5143,7 +4789,8 @@ dependencies = [
"either",
"hashbrown 0.14.5",
"hashbrown 0.15.2",
- "indexmap 2.7.0",
+ "indexmap",
+ "itoa",
"num-traits",
"once_cell",
"polars-arrow",
@@ -5156,35 +4803,32 @@ dependencies = [
"rand_distr",
"rayon",
"regex",
- "serde",
- "serde_json",
"strum_macros",
- "thiserror 1.0.69",
+ "thiserror 2.0.11",
"version_check",
"xxhash-rust",
]
[[package]]
name = "polars-error"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4643898a644f30c83737db85f942f8c8956b0c11190b39afec745218eae1746b"
+checksum = "19d6529cae0d1db5ed690e47de41fac9b35ae0c26d476830c2079f130887b847"
dependencies = [
- "object_store",
"polars-arrow-format",
"regex",
"simdutf8",
- "thiserror 1.0.69",
+ "thiserror 2.0.11",
]
[[package]]
name = "polars-expr"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136"
+checksum = "c8e639991a8ad4fb12880ab44bcc3cf44a5703df003142334d9caf86d77d77e7"
dependencies = [
"ahash",
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"hashbrown 0.15.2",
"num-traits",
"once_cell",
@@ -5203,80 +4847,50 @@ dependencies = [
[[package]]
name = "polars-io"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b2fab2c016635cb416b49461fd6419b0208c6c13a4fd065bd65e4a87dbb66314"
+checksum = "719a77e94480f6be090512da196e378cbcbeb3584c6fe1134c600aee906e38ab"
dependencies = [
"ahash",
"async-trait",
"atoi_simd",
- "blake3",
"bytes",
"chrono",
- "fast-float",
- "fs4",
+ "fast-float2",
"futures",
"glob",
"hashbrown 0.15.2",
"home",
"itoa",
"memchr",
- "memmap2 0.7.1",
+ "memmap2",
"num-traits",
- "object_store",
"once_cell",
"percent-encoding",
"polars-arrow",
"polars-core",
"polars-error",
- "polars-json",
"polars-parquet",
"polars-schema",
"polars-time",
"polars-utils",
- "pyo3",
"rayon",
"regex",
- "reqwest 0.12.12",
"ryu",
- "serde",
- "serde_json",
- "simd-json",
"simdutf8",
"tokio",
"tokio-util",
- "url",
-]
-
-[[package]]
-name = "polars-json"
-version = "0.44.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d5c8c057ef04feaf34b6ce52096bdea3a766fa4725f50442078c8a4ee86397bf"
-dependencies = [
- "ahash",
- "chrono",
- "fallible-streaming-iterator",
- "hashbrown 0.15.2",
- "indexmap 2.7.0",
- "itoa",
- "num-traits",
- "polars-arrow",
- "polars-error",
- "polars-utils",
- "ryu",
- "simd-json",
- "streaming-iterator",
]
[[package]]
name = "polars-lazy"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310"
+checksum = "a0a731a672dfc8ac38c1f73c9a4b2ae38d2fc8ac363bfb64c5f3a3e072ffc5ad"
dependencies = [
"ahash",
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
+ "chrono",
"memchr",
"once_cell",
"polars-arrow",
@@ -5296,32 +4910,28 @@ dependencies = [
[[package]]
name = "polars-mem-engine"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a32614e5b52c9b83856d80c7e2880b79d83055bfd59969bd1d0b148f9cfdc7a"
+checksum = "33442189bcbf2e2559aa7914db3835429030a13f4f18e43af5fba9d1b018cf12"
dependencies = [
- "futures",
- "memmap2 0.7.1",
+ "memmap2",
"polars-arrow",
"polars-core",
"polars-error",
"polars-expr",
"polars-io",
- "polars-json",
"polars-ops",
"polars-plan",
"polars-time",
"polars-utils",
- "pyo3",
"rayon",
- "tokio",
]
[[package]]
name = "polars-ops"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "035c800fbe5bbd820afeb8313713ed345853bb014e0f821a4025d40cf0d60e1a"
+checksum = "cbb83218b0c216104f0076cd1a005128be078f958125f3d59b094ee73d78c18e"
dependencies = [
"ahash",
"argminmax",
@@ -5332,9 +4942,10 @@ dependencies = [
"either",
"hashbrown 0.15.2",
"hex",
- "indexmap 2.7.0",
+ "indexmap",
"memchr",
"num-traits",
+ "once_cell",
"polars-arrow",
"polars-compute",
"polars-core",
@@ -5344,39 +4955,33 @@ dependencies = [
"rayon",
"regex",
"regex-syntax 0.8.5",
- "serde",
"strum_macros",
+ "unicode-normalization",
"unicode-reverse",
"version_check",
]
[[package]]
name = "polars-parquet"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91dcf1d9f048079376949eaf2e24e240b313ff4a102fb83b57c9a5f807cdca52"
+checksum = "5c60ee85535590a38db6c703a21be4cb25342e40f573f070d1e16f9d84a53ac7"
dependencies = [
"ahash",
"async-stream",
"base64 0.22.1",
- "brotli",
"bytemuck",
"ethnum",
- "flate2",
"futures",
"hashbrown 0.15.2",
- "lz4",
"num-traits",
"polars-arrow",
"polars-compute",
"polars-error",
"polars-parquet-format",
"polars-utils",
- "serde",
"simdutf8",
- "snap",
"streaming-decompression",
- "zstd 0.13.2",
]
[[package]]
@@ -5391,15 +4996,16 @@ dependencies = [
[[package]]
name = "polars-pipe"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "05936f2b3981eecb2fe74d8ef092bb75a93d2a056b3e4f339f4ac20c71c9e331"
+checksum = "42d238fb76698f56e51ddfa89b135e4eda56a4767c6e8859eed0ab78386fcd52"
dependencies = [
"crossbeam-channel",
"crossbeam-queue",
"enum_dispatch",
"hashbrown 0.15.2",
"num-traits",
+ "once_cell",
"polars-arrow",
"polars-compute",
"polars-core",
@@ -5416,75 +5022,69 @@ dependencies = [
[[package]]
name = "polars-plan"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a"
+checksum = "4f03533a93aa66127fcb909a87153a3c7cfee6f0ae59f497e73d7736208da54c"
dependencies = [
"ahash",
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"bytemuck",
"bytes",
"chrono",
"chrono-tz",
- "ciborium",
"either",
- "futures",
"hashbrown 0.15.2",
- "memmap2 0.7.1",
+ "memmap2",
"num-traits",
"once_cell",
"percent-encoding",
"polars-arrow",
+ "polars-compute",
"polars-core",
"polars-io",
- "polars-json",
"polars-ops",
- "polars-parquet",
"polars-time",
"polars-utils",
- "pyo3",
"rayon",
"recursive",
"regex",
- "serde",
"strum_macros",
"version_check",
]
[[package]]
name = "polars-row"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3823d3de3e614509bba6929798f1f3d5ae05c1cdfc4eb7029d2ec6ad77201da2"
+checksum = "6bf47f7409f8e75328d7d034be390842924eb276716d0458607be0bddb8cc839"
dependencies = [
+ "bitflags 2.8.0",
"bytemuck",
"polars-arrow",
+ "polars-compute",
"polars-error",
"polars-utils",
]
[[package]]
name = "polars-schema"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d88667f770291cefa2e8cd366a54f29dc6fe362e9a263914c903db411a58ac1d"
+checksum = "416621ae82b84466cf4ff36838a9b0aeb4a67e76bd3065edc8c9cb7da19b1bc7"
dependencies = [
- "indexmap 2.7.0",
+ "indexmap",
"polars-error",
"polars-utils",
- "serde",
"version_check",
]
[[package]]
name = "polars-sql"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "69451f08363bb497407f6ebebe00bc01972a51716d20d115b75f9b5326f1f3c8"
+checksum = "edaab553b90aa4d6743bb538978e1982368acb58a94408d7dd3299cad49c7083"
dependencies = [
"hex",
- "once_cell",
- "polars-arrow",
"polars-core",
"polars-error",
"polars-lazy",
@@ -5493,22 +5093,22 @@ dependencies = [
"polars-time",
"polars-utils",
"rand",
+ "regex",
"serde",
- "serde_json",
"sqlparser",
]
[[package]]
name = "polars-stream"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "188622b0a4bc4530cf91a288134254ffa065d18932e261075377914225e757c2"
+checksum = "498997b656c779610c1496b3d96a59fe569ef22a5b81ccfe5325cb3df8dff2fd"
dependencies = [
"atomic-waker",
"crossbeam-deque",
"crossbeam-utils",
"futures",
- "memmap2 0.7.1",
+ "memmap2",
"parking_lot 0.12.3",
"pin-project-lite",
"polars-core",
@@ -5516,6 +5116,7 @@ dependencies = [
"polars-expr",
"polars-io",
"polars-mem-engine",
+ "polars-ops",
"polars-parquet",
"polars-plan",
"polars-utils",
@@ -5529,49 +5130,50 @@ dependencies = [
[[package]]
name = "polars-time"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "90f36e4d6b19f2c406faea585b9a1814f422fc5b310f65ccf8a55216df0754ef"
+checksum = "d192efbdab516d28b3fab1709a969e3385bd5cda050b7c9aa9e2502a01fda879"
dependencies = [
- "atoi",
+ "atoi_simd",
"bytemuck",
"chrono",
"chrono-tz",
"now",
+ "num-traits",
"once_cell",
"polars-arrow",
+ "polars-compute",
"polars-core",
"polars-error",
"polars-ops",
"polars-utils",
+ "rayon",
"regex",
- "serde",
"strum_macros",
]
[[package]]
name = "polars-utils"
-version = "0.44.2"
+version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "96186b70bda00c90b5027bf2f69193c5c40571e80d3e8ec505c22cdc8e3e39aa"
+checksum = "a8f6c8166a4a7fbc15b87c81645ed9e1f0651ff2e8c96cafc40ac5bf43441a10"
dependencies = [
"ahash",
"bytemuck",
"bytes",
"compact_str",
"hashbrown 0.15.2",
- "indexmap 2.7.0",
+ "indexmap",
"libc",
- "memmap2 0.7.1",
+ "memmap2",
"num-traits",
"once_cell",
"polars-error",
- "pyo3",
- "raw-cpuid 11.2.0",
+ "rand",
+ "raw-cpuid 11.3.0",
"rayon",
- "serde",
"stacker",
- "sysinfo 0.31.4",
+ "sysinfo",
"version_check",
]
@@ -5580,6 +5182,9 @@ name = "portable-atomic"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
+dependencies = [
+ "serde",
+]
[[package]]
name = "portable-atomic-util"
@@ -5623,12 +5228,12 @@ dependencies = [
[[package]]
name = "prettyplease"
-version = "0.2.25"
+version = "0.2.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033"
+checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac"
dependencies = [
"proc-macro2",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -5640,35 +5245,11 @@ dependencies = [
"toml_edit",
]
-[[package]]
-name = "proc-macro-error"
-version = "1.0.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
-dependencies = [
- "proc-macro-error-attr",
- "proc-macro2",
- "quote",
- "syn 1.0.109",
- "version_check",
-]
-
-[[package]]
-name = "proc-macro-error-attr"
-version = "1.0.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
-dependencies = [
- "proc-macro2",
- "quote",
- "version_check",
-]
-
[[package]]
name = "proc-macro2"
-version = "1.0.92"
+version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
+checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99"
dependencies = [
"unicode-ident",
]
@@ -5689,7 +5270,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30"
dependencies = [
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -5726,7 +5307,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322330e133eab455718444b4e033ebfac7c6528972c784fcde28d2cc783c6257"
dependencies = [
"anyhow",
- "indexmap 2.7.0",
+ "indexmap",
"log",
"protobuf",
"protobuf-support",
@@ -5765,72 +5346,9 @@ dependencies = [
"reborrow",
]
-[[package]]
-name = "pyo3"
-version = "0.21.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8"
-dependencies = [
- "cfg-if",
- "indoc",
- "libc",
- "memoffset",
- "parking_lot 0.12.3",
- "portable-atomic",
- "pyo3-build-config",
- "pyo3-ffi",
- "pyo3-macros",
- "unindent",
-]
-
-[[package]]
-name = "pyo3-build-config"
-version = "0.21.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50"
-dependencies = [
- "once_cell",
- "target-lexicon",
-]
-
-[[package]]
-name = "pyo3-ffi"
-version = "0.21.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403"
-dependencies = [
- "libc",
- "pyo3-build-config",
-]
-
-[[package]]
-name = "pyo3-macros"
-version = "0.21.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c"
-dependencies = [
- "proc-macro2",
- "pyo3-macros-backend",
- "quote",
- "syn 2.0.95",
-]
-
-[[package]]
-name = "pyo3-macros-backend"
-version = "0.21.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c"
-dependencies = [
- "heck 0.4.1",
- "proc-macro2",
- "pyo3-build-config",
- "quote",
- "syn 2.0.95",
-]
-
[[package]]
name = "pytorch-import"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-import",
@@ -5839,7 +5357,7 @@ dependencies = [
[[package]]
name = "pytorch-tests"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"burn-autodiff",
@@ -5864,68 +5382,6 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
-[[package]]
-name = "quick-xml"
-version = "0.36.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe"
-dependencies = [
- "memchr",
- "serde",
-]
-
-[[package]]
-name = "quinn"
-version = "0.11.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
-dependencies = [
- "bytes",
- "pin-project-lite",
- "quinn-proto",
- "quinn-udp",
- "rustc-hash 2.1.0",
- "rustls",
- "socket2",
- "thiserror 2.0.9",
- "tokio",
- "tracing",
-]
-
-[[package]]
-name = "quinn-proto"
-version = "0.11.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
-dependencies = [
- "bytes",
- "getrandom",
- "rand",
- "ring",
- "rustc-hash 2.1.0",
- "rustls",
- "rustls-pki-types",
- "slab",
- "thiserror 2.0.9",
- "tinyvec",
- "tracing",
- "web-time",
-]
-
-[[package]]
-name = "quinn-udp"
-version = "0.5.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527"
-dependencies = [
- "cfg_aliases 0.2.1",
- "libc",
- "once_cell",
- "socket2",
- "tracing",
- "windows-sys 0.59.0",
-]
-
[[package]]
name = "quote"
version = "1.0.38"
@@ -5984,7 +5440,7 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
]
[[package]]
@@ -5999,9 +5455,9 @@ dependencies = [
[[package]]
name = "range-alloc"
-version = "0.1.3"
+version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab"
+checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde"
[[package]]
name = "ratatui"
@@ -6009,7 +5465,7 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"cassowary",
"compact_str",
"crossterm",
@@ -6086,11 +5542,11 @@ dependencies = [
[[package]]
name = "raw-cpuid"
-version = "11.2.0"
+version = "11.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0"
+checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
]
[[package]]
@@ -6159,7 +5615,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b"
dependencies = [
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -6177,7 +5633,7 @@ version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
]
[[package]]
@@ -6186,31 +5642,11 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"libredox",
"thiserror 1.0.69",
]
-[[package]]
-name = "ref-cast"
-version = "1.0.23"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931"
-dependencies = [
- "ref-cast-impl",
-]
-
-[[package]]
-name = "ref-cast-impl"
-version = "1.0.23"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.95",
-]
-
[[package]]
name = "regex"
version = "1.11.1"
@@ -6267,46 +5703,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832"
-[[package]]
-name = "reqwest"
-version = "0.11.27"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
-dependencies = [
- "base64 0.21.7",
- "bytes",
- "encoding_rs",
- "futures-core",
- "futures-util",
- "h2 0.3.26",
- "http 0.2.12",
- "http-body 0.4.6",
- "hyper 0.14.32",
- "hyper-tls 0.5.0",
- "ipnet",
- "js-sys",
- "log",
- "mime",
- "native-tls",
- "once_cell",
- "percent-encoding",
- "pin-project-lite",
- "rustls-pemfile 1.0.4",
- "serde",
- "serde_json",
- "serde_urlencoded",
- "sync_wrapper 0.1.2",
- "system-configuration 0.5.1",
- "tokio",
- "tokio-native-tls",
- "tower-service",
- "url",
- "wasm-bindgen",
- "wasm-bindgen-futures",
- "web-sys",
- "winreg",
-]
-
[[package]]
name = "reqwest"
version = "0.12.12"
@@ -6319,13 +5715,13 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
- "h2 0.4.7",
- "http 1.2.0",
- "http-body 1.0.1",
+ "h2",
+ "http",
+ "http-body",
"http-body-util",
- "hyper 1.5.2",
+ "hyper",
"hyper-rustls",
- "hyper-tls 0.6.0",
+ "hyper-tls",
"hyper-util",
"ipnet",
"js-sys",
@@ -6335,26 +5731,19 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
- "quinn",
- "rustls",
- "rustls-native-certs 0.8.1",
- "rustls-pemfile 2.2.0",
- "rustls-pki-types",
+ "rustls-pemfile",
"serde",
"serde_json",
"serde_urlencoded",
- "sync_wrapper 1.0.2",
- "system-configuration 0.6.1",
+ "sync_wrapper",
+ "system-configuration",
"tokio",
"tokio-native-tls",
- "tokio-rustls",
- "tokio-util",
"tower",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
- "wasm-streams",
"web-sys",
"windows-registry",
]
@@ -6376,7 +5765,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"libc",
"spin",
"untrusted",
@@ -6407,12 +5796,11 @@ dependencies = [
[[package]]
name = "rspirv"
-version = "0.12.0+sdk-1.3.268.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d"
+version = "0.12.0+sdk-1.3.296.0"
+source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e"
dependencies = [
- "rustc-hash 1.1.0",
- "spirv",
+ "rustc-hash",
+ "spirv 0.3.0+sdk-1.3.296.0",
]
[[package]]
@@ -6441,7 +5829,7 @@ dependencies = [
"regex",
"relative-path",
"rustc_version",
- "syn 2.0.95",
+ "syn 2.0.98",
"unicode-ident",
]
@@ -6451,7 +5839,7 @@ version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
@@ -6481,12 +5869,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
-[[package]]
-name = "rustc-hash"
-version = "2.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497"
-
[[package]]
name = "rustc_version"
version = "0.4.1"
@@ -6498,11 +5880,11 @@ dependencies = [
[[package]]
name = "rustix"
-version = "0.38.42"
+version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85"
+checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"errno",
"libc",
"linux-raw-sys",
@@ -6511,9 +5893,9 @@ dependencies = [
[[package]]
name = "rustls"
-version = "0.23.20"
+version = "0.23.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b"
+checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7"
dependencies = [
"log",
"once_cell",
@@ -6531,31 +5913,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5"
dependencies = [
"openssl-probe",
- "rustls-pemfile 2.2.0",
- "rustls-pki-types",
- "schannel",
- "security-framework 2.11.1",
-]
-
-[[package]]
-name = "rustls-native-certs"
-version = "0.8.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3"
-dependencies = [
- "openssl-probe",
+ "rustls-pemfile",
"rustls-pki-types",
"schannel",
- "security-framework 3.1.0",
-]
-
-[[package]]
-name = "rustls-pemfile"
-version = "1.0.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
-dependencies = [
- "base64 0.21.7",
+ "security-framework",
]
[[package]]
@@ -6569,12 +5930,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
-version = "1.10.1"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37"
-dependencies = [
- "web-time",
-]
+checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
[[package]]
name = "rustls-webpki"
@@ -6589,15 +5947,15 @@ dependencies = [
[[package]]
name = "rustversion"
-version = "1.0.18"
+version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248"
+checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4"
[[package]]
name = "ryu"
-version = "1.0.18"
+version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
+checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "safetensors"
@@ -6649,9 +6007,9 @@ dependencies = [
[[package]]
name = "scc"
-version = "2.2.6"
+version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94b13f8ea6177672c49d12ed964cca44836f59621981b04a3e26b87e675181de"
+checksum = "28e1c91382686d21b5ac7959341fcb9780fa7c03773646995a87c950fa7be640"
dependencies = [
"sdd",
]
@@ -6692,21 +6050,8 @@ version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
- "bitflags 2.6.0",
- "core-foundation 0.9.4",
- "core-foundation-sys",
- "libc",
- "security-framework-sys",
-]
-
-[[package]]
-name = "security-framework"
-version = "3.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc"
-dependencies = [
- "bitflags 2.6.0",
- "core-foundation 0.10.0",
+ "bitflags 2.8.0",
+ "core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
@@ -6714,9 +6059,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
-version = "2.13.0"
+version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5"
+checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
@@ -6724,9 +6069,9 @@ dependencies = [
[[package]]
name = "semver"
-version = "1.0.24"
+version = "1.0.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba"
+checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03"
[[package]]
name = "seq-macro"
@@ -6771,14 +6116,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "serde_json"
-version = "1.0.134"
+version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d"
+checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
dependencies = [
"itoa",
"memchr",
@@ -6849,12 +6194,12 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "server"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"cfg-if",
@@ -6933,23 +6278,6 @@ version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
-[[package]]
-name = "simd-json"
-version = "0.14.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aa2bcf6c6e164e81bc7a5d49fc6988b3d515d9e8c07457d7b74ffb9324b9cd40"
-dependencies = [
- "ahash",
- "getrandom",
- "halfbrown",
- "once_cell",
- "ref-cast",
- "serde",
- "serde_json",
- "simdutf8",
- "value-trait",
-]
-
[[package]]
name = "simd_helpers"
version = "0.1.0"
@@ -6967,7 +6295,7 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
[[package]]
name = "simple-regression"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"log",
@@ -6978,9 +6306,9 @@ dependencies = [
[[package]]
name = "siphasher"
-version = "0.3.11"
+version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
+checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]]
name = "slab"
@@ -7006,34 +6334,6 @@ version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
-[[package]]
-name = "snafu"
-version = "0.7.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6"
-dependencies = [
- "doc-comment",
- "snafu-derive",
-]
-
-[[package]]
-name = "snafu-derive"
-version = "0.7.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf"
-dependencies = [
- "heck 0.4.1",
- "proc-macro2",
- "quote",
- "syn 1.0.109",
-]
-
-[[package]]
-name = "snap"
-version = "1.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b"
-
[[package]]
name = "socket2"
version = "0.5.8"
@@ -7060,7 +6360,15 @@ version = "0.3.0+sdk-1.3.268.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
+]
+
+[[package]]
+name = "spirv"
+version = "0.3.0+sdk-1.3.296.0"
+source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e"
+dependencies = [
+ "bitflags 2.8.0",
]
[[package]]
@@ -7077,9 +6385,9 @@ dependencies = [
[[package]]
name = "sqlparser"
-version = "0.49.0"
+version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e"
+checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8"
dependencies = [
"log",
]
@@ -7139,12 +6447,6 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
-[[package]]
-name = "strsim"
-version = "0.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
-
[[package]]
name = "strsim"
version = "0.11.1"
@@ -7166,11 +6468,11 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
- "heck 0.5.0",
+ "heck",
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -7186,27 +6488,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
- "quote",
"unicode-ident",
]
[[package]]
name = "syn"
-version = "2.0.95"
+version = "2.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a"
+checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
-[[package]]
-name = "sync_wrapper"
-version = "0.1.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
-
[[package]]
name = "sync_wrapper"
version = "1.0.2"
@@ -7224,7 +6519,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -7233,7 +6528,7 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"byteorder",
"enum-as-inner",
"libc",
@@ -7243,22 +6538,9 @@ dependencies = [
[[package]]
name = "sysinfo"
-version = "0.31.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "355dbe4f8799b304b05e1b0f05fc59b2a18d36645cf169607da45bde2f69a1be"
-dependencies = [
- "core-foundation-sys",
- "libc",
- "memchr",
- "ntapi",
- "windows 0.57.0",
-]
-
-[[package]]
-name = "sysinfo"
-version = "0.32.1"
+version = "0.33.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af"
+checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01"
dependencies = [
"core-foundation-sys",
"libc",
@@ -7269,36 +6551,15 @@ dependencies = [
"windows 0.57.0",
]
-[[package]]
-name = "system-configuration"
-version = "0.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
-dependencies = [
- "bitflags 1.3.2",
- "core-foundation 0.9.4",
- "system-configuration-sys 0.5.0",
-]
-
[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
- "bitflags 2.6.0",
- "core-foundation 0.9.4",
- "system-configuration-sys 0.6.0",
-]
-
-[[package]]
-name = "system-configuration-sys"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
-dependencies = [
- "core-foundation-sys",
- "libc",
+ "bitflags 2.8.0",
+ "core-foundation",
+ "system-configuration-sys",
]
[[package]]
@@ -7318,7 +6579,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349"
dependencies = [
"cfg-expr",
- "heck 0.5.0",
+ "heck",
"pkg-config",
"toml",
"version-compare",
@@ -7349,12 +6610,6 @@ dependencies = [
"xattr",
]
-[[package]]
-name = "target-features"
-version = "0.1.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5"
-
[[package]]
name = "target-lexicon"
version = "0.12.16"
@@ -7380,12 +6635,13 @@ dependencies = [
[[package]]
name = "tempfile"
-version = "3.14.0"
+version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
+checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91"
dependencies = [
"cfg-if",
"fastrand",
+ "getrandom 0.3.1",
"once_cell",
"rustix",
"windows-sys 0.59.0",
@@ -7402,7 +6658,7 @@ dependencies = [
[[package]]
name = "text-classification"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"derive-new 0.7.0",
@@ -7412,7 +6668,7 @@ dependencies = [
[[package]]
name = "text-generation"
-version = "0.16.0"
+version = "0.17.0"
dependencies = [
"burn",
"derive-new 0.7.0",
@@ -7442,12 +6698,6 @@ dependencies = [
"rgb",
]
-[[package]]
-name = "textwrap"
-version = "0.16.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9"
-
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -7459,11 +6709,11 @@ dependencies = [
[[package]]
name = "thiserror"
-version = "2.0.9"
+version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
+checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
dependencies = [
- "thiserror-impl 2.0.9",
+ "thiserror-impl 2.0.11",
]
[[package]]
@@ -7474,18 +6724,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
name = "thiserror-impl"
-version = "2.0.9"
+version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
+checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -7563,9 +6813,9 @@ dependencies = [
[[package]]
name = "tinyvec"
-version = "1.8.0"
+version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938"
+checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
@@ -7585,7 +6835,7 @@ dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
- "getrandom",
+ "getrandom 0.2.15",
"hf-hub",
"itertools 0.12.1",
"lazy_static",
@@ -7610,9 +6860,9 @@ dependencies = [
[[package]]
name = "tokio"
-version = "1.42.0"
+version = "1.43.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551"
+checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
dependencies = [
"backtrace",
"bytes",
@@ -7626,13 +6876,13 @@ dependencies = [
[[package]]
name = "tokio-macros"
-version = "2.4.0"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
+checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -7655,18 +6905,6 @@ dependencies = [
"tokio",
]
-[[package]]
-name = "tokio-tungstenite"
-version = "0.24.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
-dependencies = [
- "futures-util",
- "log",
- "tokio",
- "tungstenite 0.24.0",
-]
-
[[package]]
name = "tokio-tungstenite"
version = "0.26.1"
@@ -7676,7 +6914,7 @@ dependencies = [
"futures-util",
"log",
"tokio",
- "tungstenite 0.26.1",
+ "tungstenite",
]
[[package]]
@@ -7715,11 +6953,11 @@ dependencies = [
[[package]]
name = "toml_edit"
-version = "0.22.22"
+version = "0.22.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
+checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee"
dependencies = [
- "indexmap 2.7.0",
+ "indexmap",
"serde",
"serde_spanned",
"toml_datetime",
@@ -7750,7 +6988,7 @@ dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
- "sync_wrapper 1.0.2",
+ "sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
@@ -7776,7 +7014,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58fccce80a2ef6bc32a512514a53cf853d438a44abaea286a4acb0c9f8566860"
dependencies = [
"anyhow",
- "clap 4.5.23",
+ "clap",
"derive_more 0.99.18",
"env_logger",
"log",
@@ -7796,7 +7034,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -7831,7 +7069,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -7881,38 +7119,29 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
-version = "0.24.0"
+version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
+checksum = "413083a99c579593656008130e29255e54dcaae495be556cc26888f211648c24"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
- "http 1.2.0",
+ "http",
"httparse",
"log",
"rand",
"sha1",
- "thiserror 1.0.69",
+ "thiserror 2.0.11",
"utf-8",
]
[[package]]
-name = "tungstenite"
-version = "0.26.1"
+name = "type-map"
+version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "413083a99c579593656008130e29255e54dcaae495be556cc26888f211648c24"
+checksum = "deb68604048ff8fa93347f02441e4487594adc20bb8a084f9e564d2b827a0a9f"
dependencies = [
- "byteorder",
- "bytes",
- "data-encoding",
- "http 1.2.0",
- "httparse",
- "log",
- "rand",
- "sha1",
- "thiserror 2.0.9",
- "utf-8",
+ "rustc-hash",
]
[[package]]
@@ -7965,9 +7194,9 @@ dependencies = [
[[package]]
name = "unicode-ident"
-version = "1.0.14"
+version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
+checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]]
name = "unicode-normalization"
@@ -8037,12 +7266,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
-[[package]]
-name = "unindent"
-version = "0.2.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
-
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8061,7 +7284,7 @@ dependencies = [
"native-tls",
"once_cell",
"rustls",
- "rustls-native-certs 0.7.3",
+ "rustls-native-certs",
"rustls-pki-types",
"serde",
"serde_json",
@@ -8106,11 +7329,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
-version = "1.11.0"
+version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a"
+checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"rand",
]
@@ -8127,20 +7350,19 @@ dependencies = [
[[package]]
name = "valuable"
-version = "0.1.0"
+version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
+checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
-name = "value-trait"
-version = "0.10.1"
+name = "variadics_please"
+version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187"
+checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c"
dependencies = [
- "float-cmp",
- "halfbrown",
- "itoa",
- "ryu",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
]
[[package]]
@@ -8186,36 +7408,46 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+[[package]]
+name = "wasi"
+version = "0.13.3+wasi-0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2"
+dependencies = [
+ "wit-bindgen-rt",
+]
+
[[package]]
name = "wasm-bindgen"
-version = "0.2.99"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396"
+checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
+ "rustversion",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
-version = "0.2.99"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79"
+checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
dependencies = [
"bumpalo",
"log",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
-version = "0.4.49"
+version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2"
+checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if",
"js-sys",
@@ -8226,9 +7458,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
-version = "0.2.99"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe"
+checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@@ -8236,22 +7468,25 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
-version = "0.2.99"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2"
+checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
-version = "0.2.99"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6"
+checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
+dependencies = [
+ "unicode-ident",
+]
[[package]]
name = "wasm-logger"
@@ -8264,19 +7499,6 @@ dependencies = [
"web-sys",
]
-[[package]]
-name = "wasm-streams"
-version = "0.4.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
-dependencies = [
- "futures-util",
- "js-sys",
- "wasm-bindgen",
- "wasm-bindgen-futures",
- "web-sys",
-]
-
[[package]]
name = "wasm-timer"
version = "0.2.5"
@@ -8294,9 +7516,9 @@ dependencies = [
[[package]]
name = "web-sys"
-version = "0.3.76"
+version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc"
+checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
@@ -8314,9 +7536,9 @@ dependencies = [
[[package]]
name = "webpki-roots"
-version = "0.26.7"
+version = "0.26.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
+checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9"
dependencies = [
"rustls-pki-types",
]
@@ -8327,14 +7549,23 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
+[[package]]
+name = "wgan"
+version = "0.1.0"
+dependencies = [
+ "burn",
+ "image",
+]
+
[[package]]
name = "wgpu"
-version = "23.0.1"
+version = "24.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "80f70000db37c469ea9d67defdc13024ddf9a5f1b89cb2941b812ad7cde1735a"
+checksum = "47f55718f85c2fa756edffa0e7f0e0a60aba463d1362b57e23123c58f035e4b6"
dependencies = [
"arrayvec",
- "cfg_aliases 0.1.1",
+ "bitflags 2.8.0",
+ "cfg_aliases",
"document-features",
"js-sys",
"log",
@@ -8354,43 +7585,43 @@ dependencies = [
[[package]]
name = "wgpu-core"
-version = "23.0.1"
+version = "24.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a"
+checksum = "82a39b8842dc9ffcbe34346e3ab6d496b32a47f6497e119d762c97fcaae3cb37"
dependencies = [
"arrayvec",
"bit-vec",
- "bitflags 2.6.0",
- "cfg_aliases 0.1.1",
+ "bitflags 2.8.0",
+ "cfg_aliases",
"document-features",
- "indexmap 2.7.0",
+ "indexmap",
"log",
"naga",
"once_cell",
"parking_lot 0.12.3",
"profiling",
"raw-window-handle",
- "rustc-hash 1.1.0",
+ "rustc-hash",
"smallvec",
- "thiserror 1.0.69",
+ "thiserror 2.0.11",
"wgpu-hal",
"wgpu-types",
]
[[package]]
name = "wgpu-hal"
-version = "23.0.1"
+version = "24.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89364b8a0b211adc7b16aeaf1bd5ad4a919c1154b44c9ce27838213ba05fd821"
+checksum = "5a782e5056b060b0b4010881d1decddd059e44f2ecd01e2db2971b48ad3627e5"
dependencies = [
"android_system_properties",
"arrayvec",
"ash",
"bit-set",
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"block",
"bytemuck",
- "cfg_aliases 0.1.1",
+ "cfg_aliases",
"core-graphics-types",
"glow",
"glutin_wgl_sys",
@@ -8402,19 +7633,20 @@ dependencies = [
"libc",
"libloading",
"log",
- "metal 0.29.0",
+ "metal 0.31.0",
"naga",
"ndk-sys",
"objc",
"once_cell",
+ "ordered-float",
"parking_lot 0.12.3",
"profiling",
"range-alloc",
"raw-window-handle",
"renderdoc-sys",
- "rustc-hash 1.1.0",
+ "rustc-hash",
"smallvec",
- "thiserror 1.0.69",
+ "thiserror 2.0.11",
"wasm-bindgen",
"web-sys",
"wgpu-types",
@@ -8424,12 +7656,13 @@ dependencies = [
[[package]]
name = "wgpu-types"
-version = "23.0.0"
+version = "24.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068"
+checksum = "50ac044c0e76c03a0378e7786ac505d010a873665e2d51383dcff8dd227dc69c"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.8.0",
"js-sys",
+ "log",
"web-sys",
]
@@ -8538,7 +7771,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8549,7 +7782,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8560,7 +7793,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8571,7 +7804,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8763,21 +7996,20 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
-version = "0.6.20"
+version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
+checksum = "7e49d2d35d3fad69b39b94139037ecfb4f359f08958b9c11e7315ce770462419"
dependencies = [
"memchr",
]
[[package]]
-name = "winreg"
-version = "0.50.0"
+name = "wit-bindgen-rt"
+version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
+checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
dependencies = [
- "cfg-if",
- "windows-sys 0.48.0",
+ "bitflags 2.8.0",
]
[[package]]
@@ -8789,7 +8021,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8829,9 +8061,9 @@ checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d"
[[package]]
name = "xattr"
-version = "1.3.1"
+version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
+checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909"
dependencies = [
"libc",
"linux-raw-sys",
@@ -8840,13 +8072,13 @@ dependencies = [
[[package]]
name = "xml-rs"
-version = "0.8.24"
+version = "0.8.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ea8b391c9a790b496184c29f7f93b9ed5b16abb306c05415b68bcc16e4d06432"
+checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4"
[[package]]
name = "xtask"
-version = "1.1.0"
+version = "1.2.0"
dependencies = [
"log",
"rstest",
@@ -8856,9 +8088,9 @@ dependencies = [
[[package]]
name = "xxhash-rust"
-version = "0.8.12"
+version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984"
+checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3"
[[package]]
name = "yansi"
@@ -8886,7 +8118,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
"synstructure",
]
@@ -8908,7 +8140,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8928,7 +8160,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
"synstructure",
]
@@ -8949,7 +8181,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -8971,7 +8203,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.95",
+ "syn 2.0.98",
]
[[package]]
@@ -9004,7 +8236,7 @@ dependencies = [
"crc32fast",
"crossbeam-utils",
"displaydoc",
- "indexmap 2.7.0",
+ "indexmap",
"num_enum",
"thiserror 1.0.69",
]
@@ -9025,13 +8257,13 @@ dependencies = [
"displaydoc",
"flate2",
"hmac",
- "indexmap 2.7.0",
+ "indexmap",
"lzma-rs",
"memchr",
"pbkdf2 0.12.2",
"rand",
"sha1",
- "thiserror 2.0.9",
+ "thiserror 2.0.11",
"time",
"zeroize",
"zopfli",
diff --git a/Cargo.toml b/Cargo.toml
index e72651e70c..169d668aa8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,18 +23,18 @@ exclude = [
edition = "2021"
license = "MIT OR Apache-2.0"
readme = "README.md"
-version = "0.16.0"
+version = "0.17.0"
[workspace.dependencies]
atomic_float = "1"
bytemuck = "1.21.0"
candle-core = { version = "0.8" }
-clap = { version = "4.5.23", features = ["derive"] }
+clap = { version = "4.5.27", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
csv = "1.3.1"
dashmap = "6.1.0"
-data-encoding = { version = "2.6.0", default-features = false, features = [
+data-encoding = { version = "2.7.0", default-features = false, features = [
"alloc",
] }
dirs = "5.0.1"
@@ -47,16 +47,16 @@ globwalk = "0.9.1"
hashbrown = "0.15.2"
hound = "3.5.1"
image = "0.25.5"
-indicatif = "0.17.9"
+indicatif = "0.17.11"
js-sys = "0.3.72"
libm = "0.2.11"
-log = { default-features = false, version = "0.4.22" }
+log = { default-features = false, version = "0.4.25" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
-polars = { version = "0.44.2", features = ["lazy"] }
+polars = { version = "0.46.0", features = ["lazy"] }
pretty_assertions = "1.4.1"
-proc-macro2 = "1.0.92"
+proc-macro2 = "1.0.93"
protobuf = "3.7.1"
protobuf-codegen = "3.7.1"
quote = "1.0.38"
@@ -84,7 +84,7 @@ strum = "0.26.3"
strum_macros = "0.26.4"
syn = { version = "2.0.95", features = ["full", "extra-traits"] }
tempfile = "3.14.0"
-thiserror = "2.0.9"
+thiserror = "2.0.11"
tokio = { version = "1.42.0", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.33"
@@ -101,11 +101,11 @@ ratatui = "0.29.0"
# WGPU stuff
text_placeholder = "0.5.1"
-wgpu = "23.0.0"
+wgpu = "24.0.1"
# Benchmarks and Burnbench
arboard = "3.4.1"
-github-device-flow = "0.2.0"
+chrono = "0.4.39"
os_info = "3.9.0"
wsl = "0.1.0"
@@ -140,12 +140,12 @@ serde = { version = "1.0.217", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed
-serde_json = { version = "1.0.134", default-features = false }
-uuid = { version = "1.11.0", default-features = false }
+serde_json = { version = "1.0.137", default-features = false }
+uuid = { version = "1.12.1", default-features = false }
-libc = "0.2.168"
+libc = "0.2.169"
nvml-wrapper = "0.10.0"
-sysinfo = "0.32.1"
+sysinfo = "0.33.1"
systemstat = "0.2.3"
tch = "0.15.0"
@@ -153,14 +153,14 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "34af9342a2b4f8dcf1b0047afbea0f26405b92cf" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "34af9342a2b4f8dcf1b0047afbea0f26405b92cf" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
-# cubecl = { version = "0.3.0", default-features = false }
-# cubecl-common = { version = "0.3.0", default-features = false }
+# cubecl = { version = "0.4.0", default-features = false }
+# cubecl-common = { version = "0.4.0", default-features = false }
### For xtask crate ###
tracel-xtask = { version = "=1.1.8" }
diff --git a/NOTICES.md b/NOTICES.md
index a11156a772..0f559e27a4 100644
--- a/NOTICES.md
+++ b/NOTICES.md
@@ -9,7 +9,7 @@ repository copied or derived from.
License: BSD 3-Clause License
-Copyright (c) 2017,
+Copyright (c) 2017,
All rights reserved.
Redistribution and use in source and binary forms, with or without
@@ -572,4 +572,75 @@ SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-DEALINGS IN THE SOFTWARE.
\ No newline at end of file
+DEALINGS IN THE SOFTWARE.
+
+## github-device-flow
+
+**Source**:
+- Part of: https://github.com/jakewilkins/gh-device-flow/blob/main/src/lib.rs
+- https://github.com/jakewilkins/gh-device-flow/blob/main/src/util.rs
+
+MIT License
+
+Copyright (c) 2022 Jake Wilkins
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
+## ICU
+
+UNICODE LICENSE V3
+
+COPYRIGHT AND PERMISSION NOTICE
+
+Copyright © 2016-2024 Unicode, Inc.
+
+NOTICE TO USER: Carefully read the following legal agreement. BY
+DOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR
+SOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE
+TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT
+DOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE.
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of data files and any associated documentation (the "Data Files") or
+software and any associated documentation (the "Software") to deal in the
+Data Files or Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, and/or sell
+copies of the Data Files or Software, and to permit persons to whom the
+Data Files or Software are furnished to do so, provided that either (a)
+this copyright and permission notice appear with all copies of the Data
+Files or Software, or (b) this copyright and permission notice appear in
+associated Documentation.
+
+THE DATA FILES AND SOFTWARE ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
+KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF
+THIRD PARTY RIGHTS.
+
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE
+BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES,
+OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,
+ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA
+FILES OR SOFTWARE.
+
+Except as contained in this notice, the name of a copyright holder shall
+not be used in advertising or otherwise to promote the sale, use or other
+dealings in these Data Files or Software without prior written
+authorization of the copyright holder.
diff --git a/README.md b/README.md
index a0780dcc16..951b2a9f24 100644
--- a/README.md
+++ b/README.md
@@ -567,6 +567,8 @@ Additional examples:
sample.
- [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the
DbPedia dataset.
+- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits
+ based on MNIST.
For more practical insights, you can clone the repository and run any of them directly on your
computer!
@@ -619,19 +621,20 @@ leads to more reliable, bug-free solutions built faster (after some practice
> **Deprecation Note**
Since `0.14.0`, the internal structure for tensor data has changed. The
-> previous `Data` struct is being deprecated in favor of the new `TensorData` struct, which allows
-> for more flexibility by storing the underlying data as bytes and keeping the data type as a field.
-> If you are using `Data` in your code, make sure to switch to `TensorData`.
+> previous `Data` struct was deprecated and officially removed since `0.17.0` in favor of the new
+> `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and
+> keeping the data type as a field. If you are using `Data` in your code, make sure to switch to
+> `TensorData`.
@@ -640,8 +643,9 @@ Loading Model Records From Previous Versions ⚠️
-In the event that you are trying to load a model record saved in a previous version, make sure to
-enable the `record-backward-compat` feature flag.
+In the event that you are trying to load a model record saved in a version older than `0.14.0`, make
+sure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat`
+feature flag.
```
features = [..., "record-backward-compat"]
@@ -650,13 +654,14 @@ features = [..., "record-backward-compat"]
Otherwise, the record won't be deserialized correctly and you will get an error message. This error
will also point you to the backward compatible feature flag.
-The backward compatibility is maintained for deserialization when loading records. Therefore, as
-soon as you have saved the record again it will be saved according to the new structure and you
-won't need the backward compatible feature flag anymore.
+The backward compatibility was maintained for deserialization when loading records. Therefore, as
+soon as you have saved the record again it will be saved according to the new structure and you can
+upgrade back to the current version
Please note that binary formats are not backward compatible. Thus, you will need to load your record
in a previous version and save it in any of the other self-describing record format (e.g., using the
-`NamedMpkFileRecorder`) before using the new version with the `record-backward-compat` feature flag.
+`NamedMpkFileRecorder`) before using a compatible version (as described) with the
+`record-backward-compat` feature flag.
diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index ee5f0bd8a2..821d189fe0 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
-cuda-jit = ["burn/cuda-jit"]
-cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
+cuda = ["burn/cuda"]
+cuda-fusion = ["cuda", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
-hip-jit = ["burn/hip-jit"]
+hip = ["burn/hip"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
@@ -27,20 +27,20 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
-wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
+wgpu-spirv = ["burn/vulkan", "burn/autotune"]
wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]
[dependencies]
arboard = { workspace = true }
burn = { path = "../crates/burn", default-features = false }
-burn-common = { path = "../crates/burn-common", version = "0.16.0" }
+burn-common = { path = "../crates/burn-common", version = "0.17.0" }
clap = { workspace = true }
colored = { workspace = true }
+chrono = { workspace = true }
cubecl = { workspace = true, features = ["wgpu"], default-features = true }
derive-new = { workspace = true }
dirs = { workspace = true }
-github-device-flow = { workspace = true }
half = { workspace = true }
indicatif = { workspace = true }
os_info = { workspace = true }
@@ -124,6 +124,10 @@ path = "benches/resnet.rs"
harness = false
name = "autodiff"
+[[bench]]
+harness = false
+name = "reduce"
+
[[bin]]
name = "burnbench"
path = "src/bin/burnbench.rs"
diff --git a/backend-comparison/README.md b/backend-comparison/README.md
index ba7042bbc1..6f4b547a22 100644
--- a/backend-comparison/README.md
+++ b/backend-comparison/README.md
@@ -57,6 +57,7 @@ Available Benchmarks:
- conv-transpose3d
- conv2d
- conv3d
+- reduce
```
#### Run benchmarks
diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs
index 375be97b4e..fbec64c648 100644
--- a/backend-comparison/benches/matmul_fused.rs
+++ b/backend-comparison/benches/matmul_fused.rs
@@ -1,5 +1,9 @@
use backend_comparison::persistence::save;
-use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor};
+use burn::tensor::{
+ activation::{gelu, relu},
+ backend::Backend,
+ Distribution, Shape, Tensor,
+};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
@@ -14,7 +18,7 @@ impl Benchmark for MatmulBenchmark {
type Args = (Tensor, Tensor, Tensor);
fn name(&self) -> String {
- "matmul_bias_relu".into()
+ "matmul_relu_bias_gelu".into()
}
fn shapes(&self) -> Vec> {
@@ -23,7 +27,7 @@ impl Benchmark for MatmulBenchmark {
fn execute(&self, (lhs, rhs, bias): Self::Args) {
let bias = bias.unsqueeze();
- relu(lhs.matmul(rhs) + bias);
+ gelu(relu(lhs.matmul(rhs)) + bias);
}
fn prepare(&self) -> Self::Args {
diff --git a/backend-comparison/benches/reduce.rs b/backend-comparison/benches/reduce.rs
new file mode 100644
index 0000000000..df365f2306
--- /dev/null
+++ b/backend-comparison/benches/reduce.rs
@@ -0,0 +1,102 @@
+use backend_comparison::persistence::save;
+use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
+use burn_common::benchmark::{run_benchmark, Benchmark};
+
+enum Instruction {
+ ArgMin(usize),
+ SumDim(usize),
+ Sum,
+}
+
+struct ReduceBenchmark {
+ instruction: Instruction,
+ shape: Shape,
+ device: B::Device,
+ tensor: Tensor,
+}
+
+impl ReduceBenchmark {
+ pub fn new(instruction: Instruction, device: B::Device) -> Self {
+ let shape = Shape::new([4096, 512, 64]);
+ let tensor = Tensor::random(shape.clone(), Distribution::Default, &device);
+ Self {
+ instruction,
+ shape,
+ device,
+ tensor,
+ }
+ }
+}
+
+impl Benchmark for ReduceBenchmark {
+ type Args = ();
+
+ fn prepare(&self) -> Self::Args {}
+
+ fn execute(&self, _: Self::Args) {
+ match self.instruction {
+ Instruction::ArgMin(axis) => {
+ self.tensor.clone().argmin(axis);
+ }
+ Instruction::SumDim(axis) => {
+ self.tensor.clone().sum_dim(axis);
+ }
+ Instruction::Sum => {
+ self.tensor.clone().sum();
+ }
+ }
+ }
+
+ fn name(&self) -> String {
+ match self.instruction {
+ Instruction::ArgMin(axis) => format!("reduce-argmin-{axis}"),
+ Instruction::SumDim(axis) => format!("reduce-sum-{axis}"),
+ Instruction::Sum => String::from("reduce-sum-full"),
+ }
+ }
+
+ fn sync(&self) {
+ B::sync(&self.device)
+ }
+
+ fn shapes(&self) -> Vec> {
+ vec![self.shape.dims.clone()]
+ }
+}
+
+#[allow(dead_code)]
+fn bench(
+ device: &B::Device,
+ feature_name: &str,
+ url: Option<&str>,
+ token: Option<&str>,
+) {
+ let mut benchmarks = Vec::new();
+
+ for axis in 0..3 {
+ benchmarks.push(ReduceBenchmark::::new(
+ Instruction::ArgMin(axis),
+ device.clone(),
+ ));
+
+ benchmarks.push(ReduceBenchmark::::new(
+ Instruction::SumDim(axis),
+ device.clone(),
+ ));
+ }
+
+ benchmarks.push(ReduceBenchmark::::new(Instruction::Sum, device.clone()));
+
+ save::(
+ benchmarks.into_iter().map(run_benchmark).collect(),
+ device,
+ feature_name,
+ url,
+ token,
+ )
+ .unwrap();
+}
+
+fn main() {
+ backend_comparison::bench_on_backend!();
+}
diff --git a/backend-comparison/src/burnbenchapp/auth.rs b/backend-comparison/src/burnbenchapp/auth/base.rs
similarity index 96%
rename from backend-comparison/src/burnbenchapp/auth.rs
rename to backend-comparison/src/burnbenchapp/auth/base.rs
index 3e9470b2bd..f956b0a204 100644
--- a/backend-comparison/src/burnbenchapp/auth.rs
+++ b/backend-comparison/src/burnbenchapp/auth/base.rs
@@ -1,6 +1,5 @@
use arboard::Clipboard;
use burn::serde::{Deserialize, Serialize};
-use github_device_flow::{self, DeviceFlow};
use reqwest;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
@@ -64,7 +63,7 @@ pub(crate) fn get_tokens() -> Option {
pub(crate) fn get_username(access_token: &str) -> Option {
let client = reqwest::blocking::Client::new();
let response = client
- .get(format!("{}users/me", super::USER_BENCHMARK_SERVER_URL))
+ .get(format!("{}users/me", USER_BENCHMARK_SERVER_URL))
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(
@@ -77,7 +76,7 @@ pub(crate) fn get_username(access_token: &str) -> Option {
}
fn auth() -> Option {
- let mut flow = match DeviceFlow::start(CLIENT_ID, None) {
+ let mut flow = match DeviceFlow::start(CLIENT_ID, None, None) {
Ok(flow) => flow,
Err(e) => {
eprintln!("Error authenticating: {}", e);
@@ -134,7 +133,7 @@ fn verify_tokens(tokens: &Tokens) -> bool {
)
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.send();
- response.map_or(false, |resp| resp.status().is_success())
+ response.is_ok_and(|resp| resp.status().is_success())
}
fn refresh_tokens(tokens: &Tokens) -> Option {
@@ -142,10 +141,7 @@ fn refresh_tokens(tokens: &Tokens) -> Option {
println!("Refreshing token...");
let client = reqwest::blocking::Client::new();
let response = client
- .post(format!(
- "{}auth/refresh-token",
- super::USER_BENCHMARK_SERVER_URL
- ))
+ .post(format!("{}auth/refresh-token", USER_BENCHMARK_SERVER_URL))
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(
@@ -189,6 +185,8 @@ fn save_tokens(tokens: &Tokens) {
#[cfg(test)]
use serial_test::serial;
+use crate::burnbenchapp::{auth::github_device_flow::DeviceFlow, USER_BENCHMARK_SERVER_URL};
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs b/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs
new file mode 100644
index 0000000000..55aa00f73e
--- /dev/null
+++ b/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs
@@ -0,0 +1,232 @@
+// Initially from: https://github.com/jakewilkins/gh-device-flow
+use std::collections::HashMap;
+use std::{fmt, result::Result, thread, time};
+
+use chrono::offset::Utc;
+use chrono::{DateTime, Duration};
+use serde::{Deserialize, Serialize};
+
+pub fn credential_error(msg: String) -> DeviceFlowError {
+ DeviceFlowError::GitHubError(msg)
+}
+
+pub fn send_request(
+ device_flow: &mut DeviceFlow,
+ url: String,
+ body: String,
+) -> Option> {
+ let client = reqwest::blocking::Client::new();
+ let response_struct = client
+ .post(&url)
+ .header("Accept", "application/json")
+ .body(body)
+ .send();
+
+ match response_struct {
+ Ok(resp) => match resp.json::>() {
+ Ok(hm) => Some(hm),
+ Err(err) => {
+ device_flow.state = DeviceFlowState::Failure(err.into());
+ None
+ }
+ },
+ Err(err) => {
+ device_flow.state = DeviceFlowState::Failure(err.into());
+ None
+ }
+ }
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+pub struct Credential {
+ pub token: String,
+ pub expiry: String,
+ pub refresh_token: String,
+}
+
+impl Credential {
+ fn empty() -> Credential {
+ Credential {
+ token: String::new(),
+ expiry: String::new(),
+ refresh_token: String::new(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum DeviceFlowError {
+ HttpError(String),
+ GitHubError(String),
+}
+
+impl fmt::Display for DeviceFlowError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ DeviceFlowError::HttpError(string) => write!(f, "DeviceFlowError: {}", string),
+ DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string),
+ }
+ }
+}
+
+impl std::error::Error for DeviceFlowError {}
+
+impl From for DeviceFlowError {
+ fn from(e: reqwest::Error) -> Self {
+ DeviceFlowError::HttpError(format!("{:?}", e))
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum DeviceFlowState {
+ Pending,
+ Processing(time::Duration),
+ Success(Credential),
+ Failure(DeviceFlowError),
+}
+
+#[derive(Clone)]
+pub struct DeviceFlow {
+ pub host: String,
+ pub client_id: String,
+ pub scope: String,
+ pub user_code: Option,
+ pub device_code: Option,
+ pub verification_uri: Option,
+ pub state: DeviceFlowState,
+}
+
+const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
+
+impl DeviceFlow {
+ pub fn new(client_id: &str, maybe_host: Option<&str>, scope: Option<&str>) -> Self {
+ Self {
+ client_id: String::from(client_id),
+ scope: match scope {
+ Some(string) => String::from(string),
+ None => String::new(),
+ },
+ host: match maybe_host {
+ Some(string) => String::from(string),
+ None => String::from("github.com"),
+ },
+ user_code: None,
+ device_code: None,
+ verification_uri: None,
+ state: DeviceFlowState::Pending,
+ }
+ }
+
+ pub fn start(
+ client_id: &str,
+ maybe_host: Option<&str>,
+ scope: Option<&str>,
+ ) -> Result {
+ let mut flow = DeviceFlow::new(client_id, maybe_host, scope);
+
+ flow.setup();
+
+ match flow.state {
+ DeviceFlowState::Processing(_) => Ok(flow.to_owned()),
+ DeviceFlowState::Failure(err) => Err(err),
+ _ => Err(credential_error(
+ "Something truly unexpected happened".into(),
+ )),
+ }
+ }
+
+ pub fn setup(&mut self) {
+ let body = format!("client_id={}&scope={}", &self.client_id, &self.scope);
+ let entry_url = format!("https://{}/login/device/code", &self.host);
+
+ if let Some(res) = send_request(self, entry_url, body) {
+ if res.contains_key("error") && res.contains_key("error_description") {
+ self.state = DeviceFlowState::Failure(credential_error(
+ res["error_description"].as_str().unwrap().into(),
+ ))
+ } else if res.contains_key("error") {
+ self.state = DeviceFlowState::Failure(credential_error(format!(
+ "Error response: {:?}",
+ res["error"].as_str().unwrap()
+ )))
+ } else {
+ self.user_code = Some(String::from(res["user_code"].as_str().unwrap()));
+ self.device_code = Some(String::from(res["device_code"].as_str().unwrap()));
+ self.verification_uri =
+ Some(String::from(res["verification_uri"].as_str().unwrap()));
+ self.state = DeviceFlowState::Processing(FIVE_SECONDS);
+ }
+ };
+ }
+
+ pub fn poll(&mut self, iterations: u32) -> Result {
+ for count in 0..iterations {
+ self.update();
+
+ if let DeviceFlowState::Processing(interval) = self.state {
+ if count == iterations {
+ return Err(credential_error("Max poll iterations reached".into()));
+ }
+
+ thread::sleep(interval);
+ } else {
+ break;
+ }
+ }
+
+ match &self.state {
+ DeviceFlowState::Success(cred) => Ok(cred.to_owned()),
+ DeviceFlowState::Failure(err) => Err(err.to_owned()),
+ _ => Err(credential_error(
+ "Unable to fetch credential, sorry :/".into(),
+ )),
+ }
+ }
+
+ pub fn update(&mut self) {
+ let poll_url = format!("https://{}/login/oauth/access_token", self.host);
+ let poll_payload = format!(
+ "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
+ self.client_id,
+ &self.device_code.clone().unwrap()
+ );
+
+ if let Some(res) = send_request(self, poll_url, poll_payload) {
+ if res.contains_key("error") {
+ match res["error"].as_str().unwrap() {
+ "authorization_pending" => {}
+ "slow_down" => {
+ if let DeviceFlowState::Processing(current_interval) = self.state {
+ self.state =
+ DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
+ };
+ }
+ other_reason => {
+ self.state = DeviceFlowState::Failure(credential_error(format!(
+ "Error checking for token: {}",
+ other_reason
+ )));
+ }
+ }
+ } else {
+ let mut this_credential = Credential::empty();
+ this_credential.token = res["access_token"].as_str().unwrap().to_string();
+
+ if let Some(expires_in) = res.get("expires_in") {
+ this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
+ this_credential.refresh_token =
+ res["refresh_token"].as_str().unwrap().to_string();
+ }
+
+ self.state = DeviceFlowState::Success(this_credential);
+ }
+ }
+ }
+}
+
+fn calculate_expiry(expires_in: i64) -> String {
+ let expires_in = Duration::seconds(expires_in);
+ let mut expiry: DateTime = Utc::now();
+ expiry += expires_in;
+ expiry.to_rfc3339()
+}
diff --git a/backend-comparison/src/burnbenchapp/auth/mod.rs b/backend-comparison/src/burnbenchapp/auth/mod.rs
new file mode 100644
index 0000000000..7e1e4539d7
--- /dev/null
+++ b/backend-comparison/src/burnbenchapp/auth/mod.rs
@@ -0,0 +1,4 @@
+mod base;
+pub(crate) mod github_device_flow;
+
+pub(crate) use base::*;
diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs
index 83c5060a6b..4fb31edab8 100644
--- a/backend-comparison/src/burnbenchapp/base.rs
+++ b/backend-comparison/src/burnbenchapp/base.rs
@@ -62,6 +62,13 @@ enum BackendValues {
CandleCuda,
#[strum(to_string = "candle-metal")]
CandleMetal,
+ #[strum(to_string = "cuda")]
+ Cuda,
+ #[strum(to_string = "cuda-fusion")]
+ CudaFusion,
+ #[cfg(target_os = "linux")]
+ #[strum(to_string = "hip")]
+ Hip,
#[strum(to_string = "ndarray")]
Ndarray,
#[strum(to_string = "ndarray-blas-accelerate")]
@@ -82,13 +89,6 @@ enum BackendValues {
WgpuSpirv,
#[strum(to_string = "wgpu-spirv-fusion")]
WgpuSpirvFusion,
- #[strum(to_string = "cuda-jit")]
- CudaJit,
- #[strum(to_string = "cuda-jit-fusion")]
- CudaJitFusion,
- #[cfg(target_os = "linux")]
- #[strum(to_string = "hip-jit")]
- HipJit,
}
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
@@ -123,6 +123,8 @@ enum BenchmarkValues {
Conv2d,
#[strum(to_string = "conv3d")]
Conv3d,
+ #[strum(to_string = "reduce")]
+ Reduce,
}
pub fn execute() {
diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs
index 03e2d70444..b3351e9dd5 100644
--- a/backend-comparison/src/lib.rs
+++ b/backend-comparison/src/lib.rs
@@ -54,6 +54,9 @@ fn update_panic_hook() {
#[macro_export]
macro_rules! bench_on_backend {
() => {
+ $crate::bench_on_backend!(bench)
+ };
+ ($fn_name:ident) => {
use std::env;
backend_comparison::init_log().unwrap();
@@ -88,25 +91,25 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-spirv";
#[cfg(feature = "wgpu-spirv-fusion")]
let feature_name = "wgpu-spirv-fusion";
- #[cfg(feature = "cuda-jit")]
- let feature_name = "cuda-jit";
- #[cfg(feature = "cuda-jit-fusion")]
- let feature_name = "cuda-jit-fusion";
- #[cfg(feature = "hip-jit")]
- let feature_name = "hip-jit";
+ #[cfg(feature = "cuda")]
+ let feature_name = "cuda";
+ #[cfg(feature = "cuda-fusion")]
+ let feature_name = "cuda-fusion";
+ #[cfg(feature = "hip")]
+ let feature_name = "hip";
#[cfg(any(feature = "wgpu"))]
{
use burn::backend::wgpu::{Wgpu, WgpuDevice};
- bench::>(&WgpuDevice::default(), feature_name, url, token);
+ $fn_name::>(&WgpuDevice::default(), feature_name, url, token);
}
#[cfg(any(feature = "wgpu-spirv"))]
{
use burn::backend::wgpu::{Wgpu, WgpuDevice};
- bench::>(&WgpuDevice::default(), feature_name, url, token);
+ $fn_name::>(&WgpuDevice::default(), feature_name, url, token);
}
#[cfg(feature = "tch-gpu")]
@@ -117,7 +120,7 @@ macro_rules! bench_on_backend {
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;
- bench::>(&device, feature_name, url, token);
+ $fn_name::>(&device, feature_name, url, token);
}
#[cfg(feature = "tch-cpu")]
@@ -125,7 +128,7 @@ macro_rules! bench_on_backend {
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
let device = LibTorchDevice::Cpu;
- bench::(&device, feature_name, url, token);
+ $fn_name::(&device, feature_name, url, token);
}
#[cfg(any(
@@ -139,7 +142,7 @@ macro_rules! bench_on_backend {
use burn::backend::NdArray;
let device = NdArrayDevice::Cpu;
- bench::(&device, feature_name, url, token);
+ $fn_name::(&device, feature_name, url, token);
}
#[cfg(feature = "candle-cpu")]
@@ -148,7 +151,7 @@ macro_rules! bench_on_backend {
use burn::backend::Candle;
let device = CandleDevice::Cpu;
- bench::(&device, feature_name, url, token);
+ $fn_name::(&device, feature_name, url, token);
}
#[cfg(feature = "candle-cuda")]
@@ -157,7 +160,7 @@ macro_rules! bench_on_backend {
use burn::backend::Candle;
let device = CandleDevice::cuda(0);
- bench::(&device, feature_name, url, token);
+ $fn_name::(&device, feature_name, url, token);
}
#[cfg(feature = "candle-metal")]
@@ -166,21 +169,21 @@ macro_rules! bench_on_backend {
use burn::backend::Candle;
let device = CandleDevice::metal(0);
- bench::(&device, feature_name, url, token);
+ $fn_name::(&device, feature_name, url, token);
}
- #[cfg(feature = "cuda-jit")]
+ #[cfg(feature = "cuda")]
{
- use burn::backend::cuda_jit::{Cuda, CudaDevice};
+ use burn::backend::cuda::{Cuda, CudaDevice};
- bench::>(&CudaDevice::default(), feature_name, url, token);
+ $fn_name::>(&CudaDevice::default(), feature_name, url, token);
}
- #[cfg(feature = "hip-jit")]
+ #[cfg(feature = "hip")]
{
- use burn::backend::hip_jit::{Hip, HipDevice};
+ use burn::backend::hip::{Hip, HipDevice};
- bench::>(&HipDevice::default(), feature_name, url, token);
+ $fn_name::>(&HipDevice::default(), feature_name, url, token);
}
};
}
diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs
index 287b629c21..3fe24bc955 100644
--- a/backend-comparison/src/persistence/system_info.rs
+++ b/backend-comparison/src/persistence/system_info.rs
@@ -38,7 +38,7 @@ impl BenchmarkSystemInfo {
fn enumerate_cpus() -> Vec {
let system = sysinfo::System::new_with_specifics(
- sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()),
+ sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()),
);
let cpu_names: HashSet = system
.cpus()
diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md
index 7689d25354..e55afc904d 100644
--- a/burn-book/src/advanced/no-std.md
+++ b/burn-book/src/advanced/no-std.md
@@ -23,7 +23,7 @@ Some other dependencies have to be added
```toml
[dependencies]
embedded-alloc = "0.5.1" # Only if there is no default allocator for your chip
-burn = { version = "0.16", default-features = false, features = ["ndarray"] } # Backend must be ndarray
+burn = { version = "0.17", default-features = false, features = ["ndarray"] } # Backend must be ndarray
[build-dependencies]
burn-import = { version = "0.14" } # Used to auto generate the rust code to import the model
@@ -68,7 +68,7 @@ We are using ndarray, so we just need to define the NdArray backend as usual
use burn::{backend::NdArray, tensor::Tensor};
type Backend = NdArray;
-type BackendDeice = ::Device;
+type BackendDevice = ::Device;
```
Then inside the `main` function add
@@ -76,7 +76,7 @@ Then inside the `main` function add
use your_model::Model;
// Get a default device for the backend
-let device = BackendDeice::default();
+let device = BackendDevice::default();
// Create a new model and load the state
let model: Model = Model::default();
diff --git a/burn-book/src/basic-workflow/README.md b/burn-book/src/basic-workflow/README.md
index 5b32591a58..8515d73d2c 100644
--- a/burn-book/src/basic-workflow/README.md
+++ b/burn-book/src/basic-workflow/README.md
@@ -14,7 +14,7 @@ automatically add the missing imports as you add the code snippets to your code.
Be sure to checkout the git branch corresponding to the version of Burn you are using to follow
this guide.
-The current version of Burn is `0.16` and the corresponding branch to checkout is `main`.
+The current version of Burn is `0.17` and the corresponding branch to checkout is `main`.
The code for this demo can be executed from Burn's base directory using the command:
diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md
index adce46b297..ac4b16dbce 100644
--- a/burn-book/src/basic-workflow/model.md
+++ b/burn-book/src/basic-workflow/model.md
@@ -20,7 +20,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
-burn = { version = "~0.16", features = ["train", "wgpu", "vision"] }
+burn = { version = "~0.17", features = ["train", "wgpu", "vision"] }
```
Our goal will be to create a basic convolutional neural network used for image classification. We
diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md
index e029aca708..e5dd4eaae9 100644
--- a/burn-book/src/building-blocks/metric.md
+++ b/burn-book/src/building-blocks/metric.md
@@ -4,11 +4,12 @@ When working with the learner, you have the option to record metrics that will b
throughout the training process. We currently offer a restricted range of metrics.
| Metric | Description |
-|------------------|---------------------------------------------------------|
+| ---------------- | ------------------------------------------------------- |
| Accuracy | Calculate the accuracy in percentage |
| TopKAccuracy | Calculate the top-k accuracy in percentage |
| Precision | Calculate precision in percentage |
| Recall | Calculate recall in percentage |
+| FBetaScore | Calculate Fβ score in percentage |
| AUROC | Calculate the area under curve of ROC in percentage |
| Loss | Output the loss used for the backward pass |
| CPU Temperature | Fetch the temperature of CPUs |
diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md
index 0f5aca7f24..9598d6e39e 100644
--- a/burn-book/src/building-blocks/module.md
+++ b/burn-book/src/building-blocks/module.md
@@ -294,3 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
+| `PoissonNllLoss` | `nn.PoissonNLLLoss` |
diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md
index fc4833c3c4..c12bb82c00 100644
--- a/burn-book/src/building-blocks/tensor.md
+++ b/burn-book/src/building-blocks/tensor.md
@@ -131,47 +131,47 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
-| Burn | PyTorch Equivalent |
-| ------------------------------------- | ------------------------------------------------------------------------- |
-| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
-| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
-| `Tensor::from_primitive(primitive)` | N/A |
-| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
-| `tensor.all()` | `tensor.all()` |
-| `tensor.all_dim(dim)` | `tensor.all(dim)` |
-| `tensor.any()` | `tensor.any()` |
-| `tensor.any_dim(dim)` | `tensor.any(dim)` |
-| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
-| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
-| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
-| `tensor.device()` | `tensor.device` |
-| `tensor.dtype()` | `tensor.dtype` |
-| `tensor.dims()` | `tensor.size()` |
-| `tensor.equal(other)` | `x == y` |
-| `tensor.expand(shape)` | `tensor.expand(shape)` |
-| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
-| `tensor.flip(axes)` | `tensor.flip(axes)` |
-| `tensor.into_data()` | N/A |
-| `tensor.into_primitive()` | N/A |
-| `tensor.into_scalar()` | `tensor.item()` |
-| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
-| `tensor.not_equal(other)` | `x != y` |
-| `tensor.permute(axes)` | `tensor.permute(axes)` |
-| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
-| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
-| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
-| `tensor.reshape(shape)` | `tensor.view(shape)` |
-| `tensor.shape()` | `tensor.shape` |
-| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
-| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
-| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
-| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
-| `tensor.to_data()` | N/A |
-| `tensor.to_device(device)` | `tensor.to(device)` |
-| `tensor.transpose()` | `tensor.T` |
-| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
-| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
-| `tensor.unsqueeze_dims(dims)` | N/A |
+| Burn | PyTorch Equivalent |
+| ------------------------------------------- | ------------------------------------------------------------------------- |
+| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
+| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
+| `Tensor::from_primitive(primitive)` | N/A |
+| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
+| `tensor.all()` | `tensor.all()` |
+| `tensor.all_dim(dim)` | `tensor.all(dim)` |
+| `tensor.any()` | `tensor.any()` |
+| `tensor.any_dim(dim)` | `tensor.any(dim)` |
+| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
+| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
+| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
+| `tensor.device()` | `tensor.device` |
+| `tensor.dtype()` | `tensor.dtype` |
+| `tensor.dims()` | `tensor.size()` |
+| `tensor.equal(other)` | `x == y` |
+| `tensor.expand(shape)` | `tensor.expand(shape)` |
+| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
+| `tensor.flip(axes)` | `tensor.flip(axes)` |
+| `tensor.into_data()` | N/A |
+| `tensor.into_primitive()` | N/A |
+| `tensor.into_scalar()` | `tensor.item()` |
+| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
+| `tensor.not_equal(other)` | `x != y` |
+| `tensor.permute(axes)` | `tensor.permute(axes)` |
+| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
+| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
+| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
+| `tensor.reshape(shape)` | `tensor.view(shape)` |
+| `tensor.shape()` | `tensor.shape` |
+| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
+| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
+| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
+| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
+| `tensor.to_data()` | N/A |
+| `tensor.to_device(device)` | `tensor.to(device)` |
+| `tensor.transpose()` | `tensor.T` |
+| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
+| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
+| `tensor.unsqueeze_dims(dims)` | N/A |
### Numeric Operations
@@ -229,6 +229,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.neg()` or `-tensor` | `-tensor` |
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.ones_like()` | `torch.ones_like(tensor)` |
+| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` |
+| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A |
| `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
@@ -257,33 +259,32 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
Those operations are only available for `Float` tensors.
-| Burn API | PyTorch Equivalent |
-| --------------------------------------------- | ---------------------------------- |
-| `Tensor::one_hot(index, num_classes, device)` | N/A |
-| `tensor.cast(dtype)` | `tensor.to(dtype)` |
-| `tensor.ceil()` | `tensor.ceil()` |
-| `tensor.cos()` | `tensor.cos()` |
-| `tensor.erf()` | `tensor.erf()` |
-| `tensor.exp()` | `tensor.exp()` |
-| `tensor.floor()` | `tensor.floor()` |
-| `tensor.from_floats(floats, device)` | N/A |
-| `tensor.from_full_precision(tensor)` | N/A |
-| `tensor.int()` | Similar to `tensor.to(torch.long)` |
-| `tensor.log()` | `tensor.log()` |
-| `tensor.log1p()` | `tensor.log1p()` |
-| `tensor.matmul(other)` | `tensor.matmul(other)` |
-| `tensor.random(shape, distribution, device)` | N/A |
-| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
-| `tensor.recip()` | `tensor.reciprocal()` |
-| `tensor.round()` | `tensor.round()` |
-| `tensor.sin()` | `tensor.sin()` |
-| `tensor.sqrt()` | `tensor.sqrt()` |
-| `tensor.tanh()` | `tensor.tanh()` |
-| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
-| `tensor.var(dim)` | `tensor.var(dim)` |
-| `tensor.var_bias(dim)` | N/A |
-| `tensor.var_mean(dim)` | N/A |
-| `tensor.var_mean_bias(dim)` | N/A |
+| Burn API | PyTorch Equivalent |
+| -------------------------------------------- | ---------------------------------- |
+| `tensor.cast(dtype)` | `tensor.to(dtype)` |
+| `tensor.ceil()` | `tensor.ceil()` |
+| `tensor.cos()` | `tensor.cos()` |
+| `tensor.erf()` | `tensor.erf()` |
+| `tensor.exp()` | `tensor.exp()` |
+| `tensor.floor()` | `tensor.floor()` |
+| `tensor.from_floats(floats, device)` | N/A |
+| `tensor.from_full_precision(tensor)` | N/A |
+| `tensor.int()` | Similar to `tensor.to(torch.long)` |
+| `tensor.log()` | `tensor.log()` |
+| `tensor.log1p()` | `tensor.log1p()` |
+| `tensor.matmul(other)` | `tensor.matmul(other)` |
+| `tensor.random(shape, distribution, device)` | N/A |
+| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
+| `tensor.recip()` | `tensor.reciprocal()` |
+| `tensor.round()` | `tensor.round()` |
+| `tensor.sin()` | `tensor.sin()` |
+| `tensor.sqrt()` | `tensor.sqrt()` |
+| `tensor.tanh()` | `tensor.tanh()` |
+| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
+| `tensor.var(dim)` | `tensor.var(dim)` |
+| `tensor.var_bias(dim)` | N/A |
+| `tensor.var_mean(dim)` | N/A |
+| `tensor.var_mean_bias(dim)` | N/A |
### Int Operations
@@ -293,11 +294,21 @@ Those operations are only available for `Int` tensors.
| ------------------------------------------------ | ------------------------------------------------------- |
| `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` |
| `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
+| `tensor.bitwise_and(other)` | `torch.bitwise_and(tensor, other)` |
+| `tensor.bitwise_and_scalar(scalar)` | `torch.bitwise_and(tensor, scalar)` |
+| `tensor.bitwise_not()` | `torch.bitwise_not(tensor)` |
+| `tensor.bitwise_left_shift(other)` | `torch.bitwise_left_shift(tensor, other)` |
+| `tensor.bitwise_left_shift_scalar(scalar)` | `torch.bitwise_left_shift(tensor, scalar)` |
+| `tensor.bitwise_right_shift(other)` | `torch.bitwise_right_shift(tensor, other)` |
+| `tensor.bitwise_right_shift_scalar(scalar)` | `torch.bitwise_right_shift(tensor, scalar)` |
+| `tensor.bitwise_or(other)` | `torch.bitwise_or(tensor, other)` |
+| `tensor.bitwise_or_scalar(scalar)` | `torch.bitwise_or(tensor, scalar)` |
+| `tensor.bitwise_xor(other)` | `torch.bitwise_xor(tensor, other)` |
+| `tensor.bitwise_xor_scalar(scalar)` | `torch.bitwise_xor(tensor, scalar)` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
| `tensor.cartesian_grid(shape, device)` | N/A |
-| `tensor.one_hot(num_classes)` | N/A |
### Bool Operations
@@ -329,7 +340,7 @@ strategies.
| Burn API | PyTorch Equivalent |
| ------------------------------------------------ | -------------------------------------------------- |
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
-| `activation::hard_sigmoid(tensor, alpha, beta) | `nn.functional.hardsigmoid(tensor)` |
+| `activation::hard_sigmoid(tensor, alpha, beta)` | `nn.functional.hardsigmoid(tensor)` |
| `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` |
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
diff --git a/burn-book/src/examples.md b/burn-book/src/examples.md
index c9703a4389..2b083b6fbe 100644
--- a/burn-book/src/examples.md
+++ b/burn-book/src/examples.md
@@ -85,6 +85,7 @@ The following additional examples are currently available if you want to check t
| [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
| [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. |
+| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. |
For more information on each example, see their respective `README.md` file. Be sure to check out
the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date
diff --git a/burn-book/src/import/onnx-model.md b/burn-book/src/import/onnx-model.md
index 05b9d5de81..9b3b7917fd 100644
--- a/burn-book/src/import/onnx-model.md
+++ b/burn-book/src/import/onnx-model.md
@@ -74,7 +74,7 @@ First, add the `burn-import` crate to your `Cargo.toml`:
```toml
[build-dependencies]
-burn-import = "~0.16"
+burn-import = "~0.17"
```
Then, in your `build.rs` file:
diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md
index 5c17eee3e9..1f584cdc9f 100644
--- a/burn-book/src/import/pytorch-model.md
+++ b/burn-book/src/import/pytorch-model.md
@@ -162,17 +162,13 @@ struct NetConfig {
n_head: usize,
n_layer: usize,
d_model: usize,
- // Candle's pickle has a bug with float serialization
- // https://github.com/huggingface/candle/issues/1729
- // some_float: f64,
+ some_float: f64,
some_int: i32,
some_bool: bool,
some_str: String,
some_list_int: Vec,
some_list_str: Vec,
- // Candle's pickle has a bug with float serialization
- // https://github.com/huggingface/candle/issues/1729
- // some_list_float: Vec,
+ some_list_float: Vec,
some_dict: HashMap,
}
diff --git a/burn-book/src/saving-and-loading.md b/burn-book/src/saving-and-loading.md
index 13a96cc94d..24b52dd22a 100644
--- a/burn-book/src/saving-and-loading.md
+++ b/burn-book/src/saving-and-loading.md
@@ -4,7 +4,7 @@ Saving your trained machine learning model is quite easy, no matter the output f
mentioned in the [Record](./building-blocks/record.md) section, different formats are supported to
serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the
[MessagePack](https://msgpack.org/) binary serialization format with the help of
-[smp_serde](https://docs.rs/rmp-serde/).
+[rmp_serde](https://docs.rs/rmp-serde/).
```rust, ignore
// Save model in MessagePack format with full precision
@@ -22,7 +22,7 @@ Now that you have a trained model saved to your disk, you can easily load it in
```rust, ignore
// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::::new();
-model
+model = model
.load_file(model_path, &recorder, device)
.expect("Should be able to load the model weights from the provided file");
```
diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml
index 2144d46885..df7040f835 100644
--- a/crates/burn-autodiff/Cargo.toml
+++ b/crates/burn-autodiff/Cargo.toml
@@ -18,16 +18,16 @@ std = []
async = [] # Require std
[dependencies]
-burn-common = { path = "../burn-common", version = "0.16.0" }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false }
-burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true }
+burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
+burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false }
+burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true }
derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
[dev-dependencies]
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [
+burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs
index ffcc522051..8203d16212 100644
--- a/crates/burn-autodiff/src/ops/int_tensor.rs
+++ b/crates/burn-autodiff/src/ops/int_tensor.rs
@@ -352,4 +352,48 @@ impl IntTensorOps for Autodiff {
fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor {
B::int_argsort(tensor, dim, descending)
}
+
+ fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ B::bitwise_and(lhs, rhs)
+ }
+
+ fn bitwise_and_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor {
+ B::bitwise_and_scalar(lhs, rhs)
+ }
+
+ fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ B::bitwise_or(lhs, rhs)
+ }
+
+ fn bitwise_or_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor {
+ B::bitwise_or_scalar(lhs, rhs)
+ }
+
+ fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ B::bitwise_xor(lhs, rhs)
+ }
+
+ fn bitwise_xor_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor {
+ B::bitwise_xor_scalar(lhs, rhs)
+ }
+
+ fn bitwise_not(tensor: IntTensor) -> IntTensor {
+ B::bitwise_not(tensor)
+ }
+
+ fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ B::bitwise_left_shift(lhs, rhs)
+ }
+
+ fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor {
+ B::bitwise_left_shift_scalar(lhs, rhs)
+ }
+
+ fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ B::bitwise_right_shift(lhs, rhs)
+ }
+
+ fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor {
+ B::bitwise_right_shift_scalar(lhs, rhs)
+ }
}
diff --git a/crates/burn-candle/Cargo.toml b/crates/burn-candle/Cargo.toml
index 62af31d5fb..65fbf416ca 100644
--- a/crates/burn-candle/Cargo.toml
+++ b/crates/burn-candle/Cargo.toml
@@ -21,17 +21,17 @@ accelerate = ["candle-core/accelerate"]
[dependencies]
derive-new = { workspace = true }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false }
+burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false }
half = { workspace = true }
candle-core = { workspace = true }
[dev-dependencies]
-burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [
+burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
-burn-tch = { path = "../burn-tch", version = "0.16.0", default-features = false, features = [
+burn-tch = { path = "../burn-tch", version = "0.17.0", default-features = false, features = [
] }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [
+burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs
index 6e3586506b..67328c36f3 100644
--- a/crates/burn-candle/src/ops/int_tensor.rs
+++ b/crates/burn-candle/src/ops/int_tensor.rs
@@ -373,6 +373,50 @@ impl IntTensorOps for Candle, rhs: IntTensor) -> IntTensor {
+ unimplemented!("bitwise_and is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor {
+ unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ unimplemented!("bitwise_or is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor {
+ unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor {
+ unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_not(tensor: IntTensor) -> IntTensor {
+ unimplemented!("bitwise_not is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor {
+ unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor {
+ unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor");
+ }
+
+ fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor {
+ unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor");
+ }
+
fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor {
super::base::cumsum(tensor, dim)
}
diff --git a/crates/burn-common/src/lib.rs b/crates/burn-common/src/lib.rs
index 77faa0195d..efe3b6d7d2 100644
--- a/crates/burn-common/src/lib.rs
+++ b/crates/burn-common/src/lib.rs
@@ -11,6 +11,9 @@ pub mod id;
pub use cubecl_common::*;
+#[cfg(feature = "rayon")]
+pub use rayon;
+
extern crate alloc;
/// Network utilities.
diff --git a/crates/burn-common/src/parallel.rs b/crates/burn-common/src/parallel.rs
index 93c0da7d2d..969683f9e7 100644
--- a/crates/burn-common/src/parallel.rs
+++ b/crates/burn-common/src/parallel.rs
@@ -1,51 +1,90 @@
/// Macro for running a function in parallel.
+#[cfg(feature = "rayon")]
#[macro_export(local_inner_macros)]
macro_rules! run_par {
(
$func:expr
) => {{
- #[cfg(feature = "rayon")]
- use rayon::prelude::*;
+ use $crate::rayon::prelude::*;
- #[cfg(feature = "rayon")]
#[allow(clippy::redundant_closure_call)]
- let output = rayon::scope(|_| $func());
+ $crate::rayon::scope(|_| $func())
+ }};
+}
- #[cfg(not(feature = "rayon"))]
- let output = $func();
+/// Macro for running a function in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! run_par {
+ (
+ $func:expr
+ ) => {{
+ $func()
+ }};
+}
- output
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_par {
+ (
+ $iter:expr
+ ) => {{
+ $iter
}};
}
/// Macro for iterating in parallel.
+#[cfg(feature = "rayon")]
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
(
$iter:expr
) => {{
- #[cfg(feature = "rayon")]
- let output = $iter.into_par_iter();
+ $iter.into_par_iter()
+ }};
+}
- #[cfg(not(feature = "rayon"))]
- let output = $iter;
+/// Macro for iterating in parallel.
+#[cfg(feature = "rayon")]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+ (
+ $slice:expr
+ ) => {{
+ $slice.into_par_iter()
+ }};
+}
- output
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+ (
+ $slice:expr
+ ) => {{
+ $slice.iter()
}};
}
/// Macro for iterating over a range in parallel.
+#[cfg(feature = "rayon")]
#[macro_export(local_inner_macros)]
macro_rules! iter_range_par {
(
$start:expr, $end:expr
) => {{
- #[cfg(feature = "rayon")]
- let output = ($start..$end).into_par_iter();
-
- #[cfg(not(feature = "rayon"))]
- let output = ($start..$end);
+ ($start..$end).into_par_iter()
+ }};
+}
- output
+/// Macro for iterating over a range in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_range_par {
+ (
+ $start:expr, $end:expr
+ ) => {{
+ ($start..$end)
}};
}
diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml
index e63af0fba5..423dc784d8 100644
--- a/crates/burn-core/Cargo.toml
+++ b/crates/burn-core/Cargo.toml
@@ -36,8 +36,8 @@ doc = [
"ndarray",
"tch",
"wgpu",
- "cuda-jit",
- "hip-jit",
+ "cuda",
+ "hip",
"audio",
"vision",
"autodiff",
@@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"]
## Backend features
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
-autotune = ["burn-wgpu?/autotune"]
+autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"]
blas-netlib = ["burn-ndarray?/blas-netlib"]
metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
@@ -100,12 +100,13 @@ template = ["burn-wgpu?/template"]
candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
-cuda-jit = ["burn-cuda"]
-hip-jit = ["burn-hip"]
+cuda = ["burn-cuda"]
+hip = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
-wgpu-spirv = ["wgpu", "burn-wgpu/spirv"]
+vulkan = ["wgpu", "burn-wgpu/vulkan"]
+webgpu = ["wgpu", "burn-wgpu/webgpu"]
# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]
@@ -113,37 +114,34 @@ record-item-custom-serde = ["thiserror", "regex"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
-# Backwards compatibility with previous serialized data format.
-record-backward-compat = []
-
-test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
-test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
+test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray.
+test-hip = ["hip"] # To use hip during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-wgpu-spirv = [
"test-wgpu",
- "wgpu-spirv",
+ "vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
-burn-common = { path = "../burn-common", version = "0.16.0", default-features = false }
-burn-dataset = { path = "../burn-dataset", version = "0.16.0", optional = true, default-features = false }
-burn-derive = { path = "../burn-derive", version = "0.16.0" }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false }
+burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
+burn-dataset = { path = "../burn-dataset", version = "0.17.0", optional = true, default-features = false }
+burn-derive = { path = "../burn-derive", version = "0.17.0" }
+burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false }
# Backends
-burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true }
-burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true }
-burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false }
-burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false }
-burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false }
-burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }
-burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true }
-burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
-burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
+burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true }
+burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true }
+burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false }
+burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false }
+burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true, default-features = false }
+burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true }
+burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true }
+burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true }
+burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false }
data-encoding = { workspace = true }
uuid = { workspace = true }
@@ -173,13 +171,13 @@ thiserror = { workspace = true, optional = true }
portable-atomic-util = { workspace = true }
[dev-dependencies]
-burn-dataset = { path = "../burn-dataset", version = "0.16.0", features = [
+burn-dataset = { path = "../burn-dataset", version = "0.17.0", features = [
"fake",
] }
tempfile = { workspace = true }
-burn-autodiff = { path = "../burn-autodiff", version = "0.16.0" }
-burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false }
+burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" }
+burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false }
[package.metadata.docs.rs]
features = ["doc"]
diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs
index bd4c959302..31ac3a8c41 100644
--- a/crates/burn-core/src/backend.rs
+++ b/crates/burn-core/src/backend.rs
@@ -21,11 +21,17 @@ pub use burn_wgpu as wgpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;
-#[cfg(feature = "cuda-jit")]
-pub use burn_cuda as cuda_jit;
+#[cfg(feature = "webgpu")]
+pub use burn_wgpu::WebGpu;
-#[cfg(feature = "cuda-jit")]
-pub use burn_cuda::Cuda as CudaJit;
+#[cfg(feature = "vulkan")]
+pub use burn_wgpu::Vulkan;
+
+#[cfg(feature = "cuda")]
+pub use burn_cuda as cuda;
+
+#[cfg(feature = "cuda")]
+pub use burn_cuda::Cuda;
#[cfg(feature = "candle")]
pub use burn_candle as candle;
@@ -33,11 +39,11 @@ pub use burn_candle as candle;
#[cfg(feature = "candle")]
pub use burn_candle::Candle;
-#[cfg(feature = "hip-jit")]
-pub use burn_hip as hip_jit;
+#[cfg(feature = "hip")]
+pub use burn_hip as hip;
-#[cfg(feature = "hip-jit")]
-pub use burn_hip::Hip as HipJit;
+#[cfg(feature = "hip")]
+pub use burn_hip::Hip;
#[cfg(feature = "tch")]
pub use burn_tch as libtorch;
diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs
index f554518430..ade8d64db7 100644
--- a/crates/burn-core/src/lib.rs
+++ b/crates/burn-core/src/lib.rs
@@ -1,6 +1,7 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
+#![recursion_limit = "135"]
//! The core crate of Burn.
diff --git a/crates/burn-core/src/nn/conv/checks.rs b/crates/burn-core/src/nn/conv/checks.rs
index cd346163ad..36932621f1 100644
--- a/crates/burn-core/src/nn/conv/checks.rs
+++ b/crates/burn-core/src/nn/conv/checks.rs
@@ -9,3 +9,14 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
);
}
}
+
+// https://github.com/tracel-ai/burn/issues/2676
+/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+/// size is not supported as it will not produce the same output size.
+pub(crate) fn check_same_padding_support(kernel_size: &[usize]) {
+ for k in kernel_size.iter() {
+ if k % 2 == 0 {
+ unimplemented!("Same padding with an even kernel size is not supported");
+ }
+ }
+}
diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs
index 0b64eab324..c3f61a6b07 100644
--- a/crates/burn-core/src/nn/conv/conv1d.rs
+++ b/crates/burn-core/src/nn/conv/conv1d.rs
@@ -28,6 +28,10 @@ pub struct Conv1dConfig {
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// If bias should be added to the output.
@@ -87,6 +91,9 @@ impl Conv1dConfig {
/// Initialize a new [conv1d](Conv1d) module.
pub fn init(&self, device: &B::Device) -> Conv1d {
checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
+ if self.padding == PaddingConfig1d::Same {
+ checks::check_same_padding_support(&[self.kernel_size]);
+ }
let shape = [
self.channels_out,
@@ -175,6 +182,14 @@ mod tests {
.assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3);
}
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let device = Default::default();
+ let config = Conv1dConfig::new(5, 5, 4).with_padding(PaddingConfig1d::Same);
+ let _ = config.init::(&device);
+ }
+
#[test]
fn display() {
let config = Conv1dConfig::new(5, 5, 5);
diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs
index 73be36d357..72c00187be 100644
--- a/crates/burn-core/src/nn/conv/conv2d.rs
+++ b/crates/burn-core/src/nn/conv/conv2d.rs
@@ -30,6 +30,10 @@ pub struct Conv2dConfig {
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If bias should be added to the output.
@@ -68,6 +72,9 @@ impl Conv2dConfig {
/// Initialize a new [conv2d](Conv2d) module.
pub fn init(&self, device: &B::Device) -> Conv2d {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
+ if self.padding == PaddingConfig2d::Same {
+ checks::check_same_padding_support(&self.kernel_size);
+ }
let shape = [
self.channels[1],
@@ -228,6 +235,14 @@ mod tests {
let _ = config.init::(&device);
}
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let device = Default::default();
+ let config = Conv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same);
+ let _ = config.init::(&device);
+ }
+
#[test]
fn display() {
let config = Conv2dConfig::new([5, 1], [5, 5]);
diff --git a/crates/burn-core/src/nn/conv/conv3d.rs b/crates/burn-core/src/nn/conv/conv3d.rs
index de7fb1ce2b..0b5d530c5a 100644
--- a/crates/burn-core/src/nn/conv/conv3d.rs
+++ b/crates/burn-core/src/nn/conv/conv3d.rs
@@ -68,6 +68,9 @@ impl Conv3dConfig {
/// Initialize a new [conv3d](Conv3d) module.
pub fn init(&self, device: &B::Device) -> Conv3d {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
+ if self.padding == PaddingConfig3d::Same {
+ checks::check_same_padding_support(&self.kernel_size);
+ }
let shape = [
self.channels[1],
@@ -228,6 +231,14 @@ mod tests {
assert_eq!(config.initializer, init);
}
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let device = Default::default();
+ let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same);
+ let _ = config.init::(&device);
+ }
+
#[test]
fn display() {
let config = Conv3dConfig::new([5, 1], [5, 5, 5]);
diff --git a/crates/burn-core/src/nn/conv/deform_conv2d.rs b/crates/burn-core/src/nn/conv/deform_conv2d.rs
index 03becd9d4e..2baff11d07 100644
--- a/crates/burn-core/src/nn/conv/deform_conv2d.rs
+++ b/crates/burn-core/src/nn/conv/deform_conv2d.rs
@@ -33,6 +33,10 @@ pub struct DeformConv2dConfig {
#[config(default = "1")]
pub offset_groups: usize,
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If bias should be added to the output.
@@ -73,6 +77,9 @@ impl DeformConv2dConfig {
/// Initialize a new [DeformConv2d](DeformConv2d) module.
pub fn init(&self, device: &B::Device) -> DeformConv2d {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
+ if self.padding == PaddingConfig2d::Same {
+ checks::check_same_padding_support(&self.kernel_size);
+ }
let shape = [
self.channels[1],
@@ -250,6 +257,14 @@ mod tests {
let _ = config.init::(&device);
}
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let device = Default::default();
+ let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same);
+ let _ = config.init::(&device);
+ }
+
#[test]
fn display() {
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs
index d03e95c1f3..79fc12ecbf 100644
--- a/crates/burn-core/src/nn/dropout.rs
+++ b/crates/burn-core/src/nn/dropout.rs
@@ -30,6 +30,12 @@ pub struct Dropout {
impl DropoutConfig {
/// Initialize a new [dropout](Dropout) module.
pub fn init(&self) -> Dropout {
+ if self.prob < 0.0 || self.prob > 1.0 {
+ panic!(
+ "Dropout probability should be between 0 and 1, but got {}",
+ self.prob
+ );
+ }
Dropout { prob: self.prob }
}
}
@@ -108,4 +114,11 @@ mod tests {
assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}");
}
+
+ #[test]
+ #[should_panic = "Dropout probability should be between 0 and 1,"]
+ fn dropout_prob_invalid() {
+ let config = DropoutConfig::new(-10.);
+ let _layer = config.init();
+ }
}
diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs
index f645c84fd9..54b80f4f60 100644
--- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs
+++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs
@@ -118,9 +118,9 @@ impl BinaryCrossEntropyLoss {
(targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
} else {
// - (target * log(input) + (1 - target) * log(1 - input))
- (targets_float.clone() * logits.clone().log()
- + (targets_float.neg() + 1.) * (logits.neg() + 1.).log())
- .neg()
+ // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
+ (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
+ - targets_float * logits.log().clamp_min(-100.0)
};
if let Some(weights) = &self.weights {
@@ -171,6 +171,38 @@ mod tests {
use crate::tensor::{activation::sigmoid, TensorData};
use crate::TestBackend;
+ #[test]
+ fn test_binary_cross_entropy_preds_all_correct() {
+ let device = Default::default();
+ let preds = Tensor::::from_floats([1.0, 0.0, 1.0, 0.0], &device);
+ let targets =
+ Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device);
+
+ let loss_actual = BinaryCrossEntropyLossConfig::new()
+ .init(&device)
+ .forward(preds, targets)
+ .into_data();
+
+ let loss_expected = TensorData::from([0.000]);
+ loss_actual.assert_approx_eq(&loss_expected, 3);
+ }
+
+ #[test]
+ fn test_binary_cross_entropy_preds_all_incorrect() {
+ let device = Default::default();
+ let preds = Tensor::::from_floats([0.0, 1.0, 0.0, 1.0], &device);
+ let targets =
+ Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device);
+
+ let loss_actual = BinaryCrossEntropyLossConfig::new()
+ .init(&device)
+ .forward(preds, targets)
+ .into_data();
+
+ let loss_expected = TensorData::from([100.000]); // clamped value
+ loss_actual.assert_approx_eq(&loss_expected, 3);
+ }
+
#[test]
fn test_binary_cross_entropy() {
// import torch
diff --git a/crates/burn-core/src/nn/loss/mod.rs b/crates/burn-core/src/nn/loss/mod.rs
index cca7b4541b..475364e63b 100644
--- a/crates/burn-core/src/nn/loss/mod.rs
+++ b/crates/burn-core/src/nn/loss/mod.rs
@@ -2,10 +2,12 @@ mod binary_cross_entropy;
mod cross_entropy;
mod huber;
mod mse;
+mod poisson;
mod reduction;
pub use binary_cross_entropy::*;
pub use cross_entropy::*;
pub use huber::*;
pub use mse::*;
+pub use poisson::*;
pub use reduction::*;
diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs
new file mode 100644
index 0000000000..3cc989ad8e
--- /dev/null
+++ b/crates/burn-core/src/nn/loss/poisson.rs
@@ -0,0 +1,390 @@
+use core::f32::consts::PI;
+
+use crate as burn;
+use crate::module::{Content, DisplaySettings, ModuleDisplay};
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+use crate::{config::Config, module::Module};
+
+use super::Reduction;
+
+/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance.
+///
+/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss
+/// behavior, such as whether the input is in log-space, whether to include the Stirling
+/// approximation term, and a small epsilon value to avoid numerical instability.
+#[derive(Config, Debug)]
+pub struct PoissonNllLossConfig {
+ /// If `true`, the predictions are expected to be in log-space.
+ ///
+ /// When `log_input` is `true`, the loss is computed as:
+ /// ```text
+ /// L(predictions, target) = exp(predictions) - target * predictions
+ /// ```
+ /// When `log_input` is `false`, the loss is computed as:
+ /// ```text
+ /// L(predictions, target) = predictions - target * log(predictions + eps)
+ /// ```
+ #[config(default = true)]
+ pub log_input: bool,
+ /// Whether to compute the full loss, including the Stirling approximation term.
+ ///
+ /// When `full` is `true`, the Stirling approximation term is added to the loss:
+ /// ```text
+ /// target * log(target) - target + 0.5 * log(2 * PI * target)
+ /// ```
+ #[config(default = false)]
+ pub full: bool,
+ /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
+ ///
+ /// This epsilon value is added to the predictions to ensure numerical stability
+ /// when computing the logarithm.
+ #[config(default = 1e-8)]
+ pub eps: f64,
+}
+
+impl PoissonNllLossConfig {
+ /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration.
+ ///
+ /// # Panics
+ /// - Panics if `eps` is not a positive number.
+ pub fn init(&self) -> PoissonNllLoss {
+ self.assertions();
+ PoissonNllLoss {
+ log_input: self.log_input,
+ full: self.full,
+ eps: self.eps,
+ }
+ }
+
+ /// Validates the configuration parameters.
+ ///
+ /// # Panics
+ /// - Panics if `eps` is not a positive number.
+ fn assertions(&self) {
+ assert!(
+ self.eps > 0.,
+ "eps for PoissonNllLoss must be a positive number."
+ );
+ }
+}
+
+/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target.
+///
+/// This loss function is used when the target values are assumed to follow a Poisson distribution.
+/// The loss is defined as:
+/// ```text
+/// target ~ Poisson(input)
+/// L(predictions, target) = predictions - target * log(predictions) + log(target!)
+/// ```
+/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula.
+/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss.
+///
+/// For more details, see:
+///
+#[derive(Module, Debug, Clone)]
+#[module(custom_display)]
+pub struct PoissonNllLoss {
+ /// If `true`, the predictions are expected to be in log-space.
+ pub log_input: bool,
+ /// Whether to compute the full loss, including the Stirling approximation term.
+ pub full: bool,
+ /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.
+ pub eps: f64,
+}
+
+impl ModuleDisplay for PoissonNllLoss {
+ fn custom_settings(&self) -> Option {
+ DisplaySettings::new()
+ .with_new_line_after_attribute(false)
+ .optional()
+ }
+
+ fn custom_content(&self, content: Content) -> Option {
+ content
+ .add("log_input", &self.log_input)
+ .add("full", &self.full)
+ .add("eps", &self.eps)
+ .optional()
+ }
+}
+
+impl PoissonNllLoss {
+ /// Computes the loss element-wise for the given predictions and targets, then reduces
+ /// the result to a single loss value.
+ ///
+ /// # Arguments
+ /// - `predictions`: The predicted values.
+ /// - `targets`: The target values.
+ /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`.
+ ///
+ /// # Shapes
+ /// - `predictions`: `[...dims]`
+ /// - `targets`: `[...dims]`
+ /// - `output`: `[1]`
+ ///
+ /// # Panics
+ /// - Panics if the shapes of `predictions` and `targets` do not match.
+ /// - Panics if any target value is negative.
+ /// - Panics if `log_input` is `false` and any prediction value is negative.
+ pub fn forward(
+ &self,
+ predictions: Tensor,
+ targets: Tensor,
+ reduction: Reduction,
+ ) -> Tensor {
+ let loss = self.forward_no_reduction(predictions, targets);
+ match reduction {
+ Reduction::Mean | Reduction::Auto => loss.mean(),
+ Reduction::Sum => loss.sum(),
+ }
+ }
+
+ /// Computes the loss element-wise for the given predictions and targets without reduction.
+ ///
+ /// # Arguments
+ /// - `predictions`: The predicted values.
+ /// - `targets`: The target values.
+ ///
+ /// # Shapes
+ /// - `predictions`: `[...dims]`
+ /// - `targets`: `[...dims]`
+ /// - `output`: `[...dims]`
+ ///
+ /// # Panics
+ /// - Panics if the shapes of `predictions` and `targets` do not match.
+ /// - Panics if any target value is negative.
+ /// - Panics if `log_input` is `false` and any prediction value is negative.
+ pub fn forward_no_reduction(
+ &self,
+ predictions: Tensor,
+ targets: Tensor,
+ ) -> Tensor {
+ self.assertions(&predictions, &targets);
+ let mut loss;
+ if self.log_input {
+ loss = predictions.clone().exp() - targets.clone() * predictions;
+ } else {
+ loss = predictions.clone() - targets.clone() * (predictions + self.eps).log();
+ }
+ if self.full {
+ let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()
+ + (targets.clone() * 2. * PI).log() * 0.5;
+ loss = loss
+ + log_stirling_term
+ .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());
+ }
+ loss
+ }
+
+ /// Validates the input tensors for the loss computation.
+ ///
+ /// # Panics
+ /// - Panics if the shapes of `predictions` and `targets` do not match.
+ /// - Panics if any target value is negative.
+ /// - Panics if `log_input` is `false` and any prediction value is negative.
+ fn assertions(
+ &self,
+ predictions: &Tensor,
+ targets: &Tensor,
+ ) {
+ let predictions_dims = predictions.dims();
+ let targets_dims = targets.dims();
+ assert!(
+ predictions_dims == targets_dims,
+ "Shape of targets ({:?}) should correspond to outer shape of predictions ({:?}).",
+ targets_dims,
+ predictions_dims
+ );
+ assert!(
+ targets.clone().greater_equal_elem(0.).all().into_scalar(),
+ "All the values of `targets` must be non-negative."
+ );
+ if !self.log_input {
+ assert!(
+ predictions.clone().greater_equal_elem(0.).all().into_scalar(),
+ "When `log_input` is `false`, all the values of `predictions` must be non-negative."
+ );
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::tensor::TensorData;
+ use crate::TestBackend;
+ type TestTensor = Tensor;
+
+ #[test]
+ fn test_poisson_nll_loss() {
+ let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
+ let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
+
+ let device = Default::default();
+
+ let predictions = TestTensor::<1>::from_data(predictions, &device);
+ let targets = TestTensor::<1>::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().init();
+
+ let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
+ let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
+ let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);
+
+ let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]);
+ loss_no_reduction.into_data().assert_approx_eq(&expected, 5);
+
+ let expected = TensorData::from([21.0321]);
+ loss.into_data().assert_approx_eq(&expected, 5);
+
+ let expected = TensorData::from([126.1929]);
+ loss_sum.into_data().assert_approx_eq(&expected, 5);
+ }
+
+ #[test]
+ fn test_poisson_nll_loss_no_log_input() {
+ let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);
+ let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]);
+
+ let device = Default::default();
+
+ let predictions = TestTensor::<1>::from_data(predictions, &device);
+ let targets = TestTensor::<1>::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().with_log_input(false).init();
+
+ let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());
+
+ let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]);
+ loss_no_reduction.into_data().assert_approx_eq(&expected, 5);
+ }
+
+ #[test]
+ fn test_poisson_nll_loss_full() {
+ let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
+ let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
+
+ let device = Default::default();
+
+ let predictions = TestTensor::<1>::from_data(predictions, &device);
+ let targets = TestTensor::<1>::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().with_full(true).init();
+
+ let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
+ let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
+ let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);
+
+ let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]);
+ loss_no_reduction.into_data().assert_approx_eq(&expected, 5);
+
+ let expected = TensorData::from([21.9920]);
+ loss.into_data().assert_approx_eq(&expected, 5);
+
+ let expected = TensorData::from([131.9518]);
+ loss_sum.into_data().assert_approx_eq(&expected, 5);
+ }
+
+ #[cfg(feature = "std")]
+ #[test]
+ fn test_poisson_nll_loss_gradients() {
+ type TestAutodiffTensor = Tensor;
+
+ let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);
+ let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);
+
+ let device = Default::default();
+
+ let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad();
+ let predictions2 = predictions1.clone();
+ let targets = TestAutodiffTensor::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().with_full(false).init();
+ let poisson_full = PoissonNllLossConfig::new().with_full(true).init();
+
+ let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
+ let loss_full_sum =
+ poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);
+
+ let grads = loss_sum.backward();
+ let grads_full = loss_full_sum.backward();
+
+ let grads_predictions1 = predictions1.grad(&grads).unwrap();
+ let grads_predictions2 = predictions2.grad(&grads_full).unwrap();
+
+ let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);
+
+ grads_predictions1
+ .into_data()
+ .assert_approx_eq(&expected, 5);
+ grads_predictions2
+ .into_data()
+ .assert_approx_eq(&expected, 5);
+ }
+
+ #[test]
+ #[should_panic = "eps for PoissonNllLoss must be a positive number."]
+ fn test_negative_eps() {
+ let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
+ }
+
+ #[test]
+ #[should_panic = "All the values of `targets` must be non-negative."]
+ fn test_targets_with_negative_values() {
+ let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]);
+ let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]);
+
+ let device = Default::default();
+
+ let predictions = TestTensor::<1>::from_data(predictions, &device);
+ let targets = TestTensor::<1>::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().init();
+
+ let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
+ }
+
+ #[test]
+ #[should_panic = "Shape of targets"]
+ fn test_shape_tensors() {
+ let predictions = TensorData::from([0., 1., 2.]);
+ let targets = TensorData::from([0., 1.]);
+
+ let device = Default::default();
+
+ let predictions = TestTensor::<1>::from_data(predictions, &device);
+ let targets = TestTensor::<1>::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().init();
+
+ let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
+ }
+
+ #[test]
+ #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."]
+ fn test_exp_predictions_non_negative() {
+ let predictions = TensorData::from([0.3, -0.1, 0.4]);
+ let targets = TensorData::from([0., 1., 0.]);
+
+ let device = Default::default();
+
+ let predictions = TestTensor::<1>::from_data(predictions, &device);
+ let targets = TestTensor::<1>::from_data(targets, &device);
+
+ let poisson = PoissonNllLossConfig::new().with_log_input(false).init();
+
+ let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
+ }
+
+ #[test]
+ fn display() {
+ let config = PoissonNllLossConfig::new();
+ let loss = config.init();
+
+ assert_eq!(
+ alloc::format!("{}", loss),
+ "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
+ );
+ }
+}
diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs
index 949160fd5b..24ec8ff972 100644
--- a/crates/burn-core/src/nn/pool/avg_pool1d.rs
+++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs
@@ -1,4 +1,5 @@
use crate as burn;
+use crate::nn::conv::checks::check_same_padding_support;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
@@ -18,6 +19,10 @@ pub struct AvgPool1dConfig {
#[config(default = "1")]
pub stride: usize,
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// If the padding is counted in the denominator when computing the average.
@@ -36,10 +41,6 @@ pub struct AvgPool1dConfig {
/// legitimate values, and they contribute to the denominator
/// when calculating the average. This is equivalent to
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
-///
-/// TODO: Add support for `count_include_pad=False`, see
-/// [Issue 636](https://github.com/tracel-ai/burn/issues/636)
-
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool1d {
@@ -73,6 +74,9 @@ impl ModuleDisplay for AvgPool1d {
impl AvgPool1dConfig {
/// Initialize a new [avg pool 1d](AvgPool1d) module.
pub fn init(&self) -> AvgPool1d {
+ if self.padding == PaddingConfig1d::Same {
+ check_same_padding_support(&[self.kernel_size]);
+ }
AvgPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
@@ -111,6 +115,13 @@ impl AvgPool1d {
mod tests {
use super::*;
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let config = AvgPool1dConfig::new(2).with_padding(PaddingConfig1d::Same);
+ let _ = config.init();
+ }
+
#[test]
fn display() {
let config = AvgPool1dConfig::new(3);
diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs
index 6c6ffc87ed..343d59922b 100644
--- a/crates/burn-core/src/nn/pool/avg_pool2d.rs
+++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs
@@ -1,4 +1,5 @@
use crate as burn;
+use crate::nn::conv::checks::check_same_padding_support;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
@@ -18,6 +19,10 @@ pub struct AvgPool2dConfig {
#[config(default = "[1, 1]")]
pub strides: [usize; 2],
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If the padding is counted in the denominator when computing the average.
@@ -36,9 +41,6 @@ pub struct AvgPool2dConfig {
/// legitimate values, and they contribute to the denominator
/// when calculating the average. This is equivalent to
/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
-///
-/// TODO: Add support for `count_include_pad=False`, see
-/// [Issue 636](https://github.com/tracel-ai/burn/issues/636)
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool2d {
@@ -72,6 +74,9 @@ impl ModuleDisplay for AvgPool2d {
impl AvgPool2dConfig {
/// Initialize a new [avg pool 2d](AvgPool2d) module.
pub fn init(&self) -> AvgPool2d {
+ if self.padding == PaddingConfig2d::Same {
+ check_same_padding_support(&self.kernel_size);
+ }
AvgPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
@@ -110,6 +115,13 @@ impl AvgPool2d {
mod tests {
use super::*;
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let config = AvgPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same);
+ let _ = config.init();
+ }
+
#[test]
fn display() {
let config = AvgPool2dConfig::new([3, 3]);
diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs
index 5be363e908..71041e6155 100644
--- a/crates/burn-core/src/nn/pool/max_pool1d.rs
+++ b/crates/burn-core/src/nn/pool/max_pool1d.rs
@@ -1,4 +1,5 @@
use crate as burn;
+use crate::nn::conv::checks::check_same_padding_support;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
@@ -18,6 +19,10 @@ pub struct MaxPool1dConfig {
#[config(default = "1")]
pub stride: usize,
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig1d::Valid")]
pub padding: PaddingConfig1d,
/// The dilation.
@@ -61,6 +66,9 @@ impl ModuleDisplay for MaxPool1d {
impl MaxPool1dConfig {
/// Initialize a new [max pool 1d](MaxPool1d) module.
pub fn init(&self) -> MaxPool1d {
+ if self.padding == PaddingConfig1d::Same {
+ check_same_padding_support(&[self.kernel_size]);
+ }
MaxPool1d {
stride: self.stride,
kernel_size: self.kernel_size,
@@ -93,6 +101,13 @@ impl MaxPool1d {
mod tests {
use super::*;
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let config = MaxPool1dConfig::new(2).with_padding(PaddingConfig1d::Same);
+ let _ = config.init();
+ }
+
#[test]
fn display() {
let config = MaxPool1dConfig::new(3);
diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs
index ab9c60d276..3eb94f5db5 100644
--- a/crates/burn-core/src/nn/pool/max_pool2d.rs
+++ b/crates/burn-core/src/nn/pool/max_pool2d.rs
@@ -1,4 +1,5 @@
use crate as burn;
+use crate::nn::conv::checks::check_same_padding_support;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
@@ -18,6 +19,10 @@ pub struct MaxPool2dConfig {
#[config(default = "[1, 1]")]
pub strides: [usize; 2],
/// The padding configuration.
+ ///
+ /// ### Warning
+ /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
+ /// size is not supported as it will not produce the same output size.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// The dilation.
@@ -61,6 +66,9 @@ impl ModuleDisplay for MaxPool2d {
impl MaxPool2dConfig {
/// Initialize a new [max pool 2d](MaxPool2d) module.
pub fn init(&self) -> MaxPool2d {
+ if self.padding == PaddingConfig2d::Same {
+ check_same_padding_support(&self.kernel_size);
+ }
MaxPool2d {
stride: self.strides,
kernel_size: self.kernel_size,
@@ -93,6 +101,13 @@ impl MaxPool2d {
mod tests {
use super::*;
+ #[test]
+ #[should_panic = "Same padding with an even kernel size is not supported"]
+ fn same_with_even_kernel_is_invalid() {
+ let config = MaxPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same);
+ let _ = config.init();
+ }
+
#[test]
fn display() {
let config = MaxPool2dConfig::new([3, 3]);
diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs
index c66ad631b6..e2f8b2425e 100644
--- a/crates/burn-core/src/nn/rnn/gru.rs
+++ b/crates/burn-core/src/nn/rnn/gru.rs
@@ -20,6 +20,21 @@ pub struct GruConfig {
pub d_hidden: usize,
/// If a bias should be applied during the Gru transformation.
pub bias: bool,
+ /// If reset gate should be applied after weight multiplication.
+ ///
+ /// This configuration option controls how the reset gate is applied to the hidden state.
+ /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for
+ /// Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by
+ /// the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU).
+ /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine
+ /// Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication.
+ ///
+ /// The differing implementations can give slightly different numerical results and have different efficiencies. For more
+ /// motivation for why the `true` can be more efficient see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs).
+ ///
+ /// To set this field to `false` use [`with_reset_after`](`GruConfig::with_reset_after`).
+ #[config(default = "true")]
+ pub reset_after: bool,
/// Gru initializer
#[config(default = "Initializer::XavierNormal{gain:1.0}")]
pub initializer: Initializer,
@@ -41,6 +56,8 @@ pub struct Gru {
pub new_gate: GateController,
/// The size of the hidden state.
pub d_hidden: usize,
+ /// If reset gate should be applied after weight multiplication.
+ pub reset_after: bool,
}
impl ModuleDisplay for Gru {
@@ -58,6 +75,7 @@ impl ModuleDisplay for Gru {
.add("d_input", &d_input)
.add("d_hidden", &self.d_hidden)
.add("bias", &bias)
+ .add("reset_after", &self.reset_after)
.optional()
}
}
@@ -94,86 +112,92 @@ impl GruConfig {
reset_gate,
new_gate,
d_hidden: self.d_hidden,
+ reset_after: self.reset_after,
}
}
}
impl Gru {
/// Applies the forward pass on the input tensor. This GRU implementation
- /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size].
+ /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`.
///
- /// # Shapes
+ /// # Parameters
/// - batched_input: `[batch_size, sequence_length, input_size]`.
- /// - state: An optional tensor representing an initial cell state with the same dimensions
- /// as batched_input. If none is provided, one will be generated.
- /// - output: `[batch_size, sequence_length, hidden_size]`.
+ /// - state: An optional tensor representing an initial cell state with dimensions
+ /// `[batch_size, hidden_size]`. If none is provided, an empty state will be used.
+ ///
+ /// # Returns
+ /// - output: `[batch_size, sequence_length, hidden_size]`
pub fn forward(
&self,
batched_input: Tensor,
- state: Option>,
+ state: Option>,
) -> Tensor {
+ let device = batched_input.device();
let [batch_size, seq_length, _] = batched_input.shape().dims();
- let mut hidden_state = match state {
+ let mut batched_hidden_state =
+ Tensor::empty([batch_size, seq_length, self.d_hidden], &device);
+
+ let mut hidden_t = match state {
Some(state) => state,
- None => Tensor::zeros(
- [batch_size, seq_length, self.d_hidden],
- &batched_input.device(),
- ),
+ None => Tensor::zeros([batch_size, self.d_hidden], &device),
};
- for (t, (input_t, hidden_t)) in batched_input
- .iter_dim(1)
- .zip(hidden_state.clone().iter_dim(1))
- .enumerate()
- {
+ for (t, input_t) in batched_input.iter_dim(1).enumerate() {
let input_t = input_t.squeeze(1);
- let hidden_t = hidden_t.squeeze(1);
// u(pdate)g(ate) tensors
- let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
+ let biased_ug_input_sum =
+ self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t)
// r(eset)g(ate) tensors
- let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
+ let biased_rg_input_sum =
+ self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t)
- let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
// n(ew)g(ate) tensor
- let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
+ let biased_ng_input_sum = if self.reset_after {
+ self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
+ } else {
+ let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
+ self.gate_product(&input_t, &reset_t, None, &self.new_gate)
+ };
let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t)
// calculate linear interpolation between previous hidden state and candidate state:
// g(t) * (1 - z(t)) + z(t) * hidden_t
- let state_vector = candidate_state
+ hidden_t = candidate_state
.clone()
.mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
+ update_values.clone().mul(hidden_t);
- let current_shape = state_vector.shape().dims;
- let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
- let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);
- hidden_state = hidden_state.slice_assign(
+ let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);
+
+ batched_hidden_state = batched_hidden_state.slice_assign(
[0..batch_size, t..(t + 1), 0..self.d_hidden],
- reshaped_state_vector,
+ unsqueezed_hidden_state,
);
}
- hidden_state
+ batched_hidden_state
}
/// Helper function for performing weighted matrix product for a gate and adds
- /// bias, if any.
+ /// bias, if any, and optionally applies reset to hidden state.
///
- /// Mathematically, performs `Wx*X + Wh*H + b`, where:
+ /// Mathematically, performs `Wx*X + r .* (Wh*H + b)`, where:
/// Wx = weight matrix for the connection to input vector X
/// Wh = weight matrix for the connection to hidden state H
/// X = input vector
/// H = hidden state
/// b = bias terms
+ /// r = reset state
fn gate_product(
&self,
input: &Tensor,
hidden: &Tensor,
+ reset: Option<&Tensor>,
gate: &GateController,
) -> Tensor {
let input_product = input.clone().matmul(gate.input_transform.weight.val());
@@ -190,13 +214,29 @@ impl Gru {
.as_ref()
.map(|bias_param| bias_param.val());
- match (input_bias, hidden_bias) {
- (Some(input_bias), Some(hidden_bias)) => {
+ match (input_bias, hidden_bias, reset) {
+ (Some(input_bias), Some(hidden_bias), Some(r)) => {
+ input_product
+ + input_bias.unsqueeze()
+ + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
+ }
+ (Some(input_bias), Some(hidden_bias), None) => {
input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
}
- (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
- (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
- (None, None) => input_product + hidden_product,
+ (Some(input_bias), None, Some(r)) => {
+ input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product)
+ }
+ (Some(input_bias), None, None) => {
+ input_product + input_bias.unsqueeze() + hidden_product
+ }
+ (None, Some(hidden_bias), Some(r)) => {
+ input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
+ }
+ (None, Some(hidden_bias), None) => {
+ input_product + hidden_product + hidden_bias.unsqueeze()
+ }
+ (None, None, Some(r)) => input_product + r.clone().mul(hidden_product),
+ (None, None, None) => input_product + hidden_product,
}
}
}
@@ -207,29 +247,16 @@ mod tests {
use crate::tensor::{Distribution, TensorData};
use crate::{module::Param, nn::LinearRecord, TestBackend};
- /// Test forward pass with simple input vector.
- ///
- /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
- /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150
- /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
- ///
- /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
- #[test]
- fn tests_forward_single_input_single_feature() {
- TestBackend::seed(0);
- let config = GruConfig::new(1, 1, false);
- let device = Default::default();
- let mut gru = config.init::(&device);
-
- fn create_gate_controller(
+ fn init_gru(reset_after: bool, device: &B::Device) -> Gru {
+ fn create_gate_controller(
weights: f32,
biases: f32,
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
- device: &::Device,
- ) -> GateController {
+ device: &B::Device,
+ ) -> GateController {
let record_1 = LinearRecord {
weight: Param::from_data(TensorData::from([[weights]]), device),
bias: Some(Param::from_data(TensorData::from([biases]), device)),
@@ -248,6 +275,9 @@ mod tests {
)
}
+ let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
+ let mut gru = config.init::(device);
+
gru.update_gate = create_gate_controller(
0.5,
0.0,
@@ -255,7 +285,7 @@ mod tests {
1,
false,
Initializer::XavierNormal { gain: 1.0 },
- &device,
+ device,
);
gru.reset_gate = create_gate_controller(
0.6,
@@ -264,7 +294,7 @@ mod tests {
1,
false,
Initializer::XavierNormal { gain: 1.0 },
- &device,
+ device,
);
gru.new_gate = create_gate_controller(
0.7,
@@ -273,18 +303,72 @@ mod tests {
1,
false,
Initializer::XavierNormal { gain: 1.0 },
- &device,
+ device,
);
+ gru
+ }
+
+ /// Test forward pass with simple input vector.
+ ///
+ /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
+ /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150
+ /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
+ ///
+ /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
+ #[test]
+ fn tests_forward_single_input_single_feature() {
+ TestBackend::seed(0);
+ let device = Default::default();
+ let mut gru = init_gru::(false, &device);
let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device);
+ let expected = TensorData::from([[0.034]]);
+ // Reset gate applied to hidden state before the matrix multiplication
+ let state = gru.forward(input.clone(), None);
+
+ let output = state
+ .select(0, Tensor::arange(0..1, &device))
+ .squeeze::<2>(0);
+
+ output.to_data().assert_approx_eq(&expected, 3);
+
+ // Reset gate applied to hidden state after the matrix multiplication
+ gru.reset_after = true; // override forward behavior
+ let state = gru.forward(input, None);
+
+ let output = state
+ .select(0, Tensor::arange(0..1, &device))
+ .squeeze::<2>(0);
+
+ output.to_data().assert_approx_eq(&expected, 3);
+ }
+
+ #[test]
+ fn tests_forward_seq_len_3() {
+ TestBackend::seed(0);
+ let device = Default::default();
+ let mut gru = init_gru::(true, &device);
+
+ let input =
+ Tensor::::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
+ let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);
+
+ let result = gru.forward(input.clone(), None);
+ let output = result
+ .select(0, Tensor::arange(0..1, &device))
+ .squeeze::<2>(0);
+
+ output.to_data().assert_approx_eq(&expected, 3);
+
+ // Reset gate applied to hidden state before the matrix multiplication
+ gru.reset_after = false; // override forward behavior
let state = gru.forward(input, None);
let output = state
.select(0, Tensor::arange(0..1, &device))
.squeeze::<2>(0);
- let expected = TensorData::from([[0.034]]);
output.to_data().assert_approx_eq(&expected, 3);
}
@@ -308,7 +392,7 @@ mod tests {
assert_eq!(
alloc::format!("{}", layer),
- "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}"
+ "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
);
}
}
diff --git a/crates/burn-core/src/record/primitive.rs b/crates/burn-core/src/record/primitive.rs
index 9dd921e824..2f9fa3e83c 100644
--- a/crates/burn-core/src/record/primitive.rs
+++ b/crates/burn-core/src/record/primitive.rs
@@ -5,9 +5,7 @@ use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
use super::{PrecisionSettings, Record};
use crate::module::{Param, ParamId};
-#[allow(deprecated)]
-use burn_tensor::DataSerialize;
-use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor};
+use burn_tensor::{backend::Backend, Bool, Int, Tensor};
use hashbrown::HashMap;
use serde::{
@@ -143,23 +141,6 @@ where
}
}
-#[allow(deprecated)]
-impl Record for DataSerialize
-where
- E: Element,
- B: Backend,
-{
- type Item = DataSerialize;
-
- fn into_item(self) -> Self::Item {
- self.convert()
- }
-
- fn from_item(item: Self::Item, _device: &B::Device) -> Self {
- item.convert()
- }
-}
-
/// (De)serialize parameters into a clean format.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
pub struct ParamSerde {
diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs
index ab6f448b7e..a07453bcba 100644
--- a/crates/burn-core/src/record/tensor.rs
+++ b/crates/burn-core/src/record/tensor.rs
@@ -4,20 +4,7 @@ use super::{PrecisionSettings, Record};
use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData};
use serde::{Deserialize, Serialize};
-#[cfg(not(feature = "record-backward-compat"))]
use alloc::format;
-#[cfg(feature = "record-backward-compat")]
-use burn_tensor::DataSerialize;
-
-/// Versioned serde data deserialization to maintain backward compatibility between formats.
-#[cfg(feature = "record-backward-compat")]
-#[allow(deprecated)]
-#[derive(Serialize, Deserialize)]
-#[serde(untagged)]
-enum TensorDataSerde {
- V1(DataSerialize),
- V2(TensorData),
-}
/// Deserialize the value into [`TensorData`].
fn deserialize_data<'de, E, De>(deserializer: De) -> Result
@@ -25,31 +12,18 @@ where
E: Element + Deserialize<'de>,
De: serde::Deserializer<'de>,
{
- #[cfg(feature = "record-backward-compat")]
- {
- let data = match TensorDataSerde::::deserialize(deserializer)? {
- TensorDataSerde::V1(data) => data.into_tensor_data(),
- // NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16
- TensorDataSerde::V2(data) => data.convert::(),
- };
- Ok(data)
- }
-
- #[cfg(not(feature = "record-backward-compat"))]
- {
- let data = TensorData::deserialize(deserializer).map_err(|e| {
- serde::de::Error::custom(format!(
- "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag. Once you have saved the record in the new format, you can disable the feature flag.\n",
- e
- ))
- })?;
- let data = if let DType::QFloat(_) = data.dtype {
- data // do not convert quantized tensors
- } else {
- data.convert::()
- };
- Ok(data)
- }
+ let data = TensorData::deserialize(deserializer).map_err(|e| {
+ serde::de::Error::custom(format!(
+ "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n",
+ e
+ ))
+ })?;
+ let data = if let DType::QFloat(_) = data.dtype {
+ data // do not convert quantized tensors
+ } else {
+ data.convert::()
+ };
+ Ok(data)
}
/// This struct implements serde to lazily serialize and deserialize a float tensor
diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml
index c366386b0e..1a92e695b2 100644
--- a/crates/burn-cuda/Cargo.toml
+++ b/crates/burn-cuda/Cargo.toml
@@ -19,9 +19,9 @@ fusion = ["burn-fusion", "burn-jit/fusion"]
std = ["burn-jit/std", "cubecl/std"]
[dependencies]
-burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
-burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
+burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true }
+burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false }
+burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = [
"cubecl-cuda",
] }
cubecl = { workspace = true, features = ["cuda"] }
@@ -34,7 +34,7 @@ log = { workspace = true }
[dev-dependencies]
-burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [
+burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
paste = { workspace = true }
diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml
index 0237765973..c7ddbebc41 100644
--- a/crates/burn-dataset/Cargo.toml
+++ b/crates/burn-dataset/Cargo.toml
@@ -30,7 +30,7 @@ __sqlite-shared = [
dataframe = ["dep:polars"]
[dependencies]
-burn-common = { path = "../burn-common", version = "0.16.0", optional = true, features = [
+burn-common = { path = "../burn-common", version = "0.17.0", optional = true, features = [
"network",
] }
csv = { workspace = true }
diff --git a/crates/burn-dataset/src/dataset/dataframe.rs b/crates/burn-dataset/src/dataset/dataframe.rs
index 023b357454..c851e8a3e3 100644
--- a/crates/burn-dataset/src/dataset/dataframe.rs
+++ b/crates/burn-dataset/src/dataset/dataframe.rs
@@ -269,20 +269,20 @@ mod tests {
}
fn create_test_dataframe() -> DataFrame {
- let s0 = Column::Series(Series::new("int32".into(), &[1i32, 2i32, 3i32]));
- let s1 = Column::Series(Series::new("bool".into(), &[true, false, true]));
- let s2 = Column::Series(Series::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]));
- let s3 = Column::Series(Series::new("string".into(), &["Boo", "Boo2", "Boo3"]));
- let s6 = Column::Series(Series::new("int16".into(), &[1i16, 2i16, 3i16]));
- let s8 = Column::Series(Series::new("uint32".into(), &[1u32, 2u32, 3u32]));
- let s9 = Column::Series(Series::new("uint64".into(), &[1u64, 2u64, 3u64]));
- let s10 = Column::Series(Series::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]));
- let s11 = Column::Series(Series::new("int64".into(), &[1i64, 2i64, 3i64]));
- let s12 = Column::Series(Series::new("int8".into(), &[1i8, 2i8, 3i8]));
+ let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]);
+ let s1 = Column::new("bool".into(), &[true, false, true]);
+ let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]);
+ let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]);
+ let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]);
+ let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]);
+ let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]);
+ let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]);
+ let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]);
+ let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]);
let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];
- let s13 = Column::Series(Series::new("binary".into(), binary_data));
+ let s13 = Column::new("binary".into(), binary_data);
DataFrame::new(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap()
}
diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml
index eb4296097b..1f2f785940 100644
--- a/crates/burn-fusion/Cargo.toml
+++ b/crates/burn-fusion/Cargo.toml
@@ -17,8 +17,8 @@ std = ["serde/std"]
doc = ["default"]
[dependencies]
-burn-tensor = { path = "../burn-tensor", version = "0.16.0" }
-burn-common = { path = "../burn-common", version = "0.16.0" }
+burn-tensor = { path = "../burn-tensor", version = "0.17.0" }
+burn-common = { path = "../burn-common", version = "0.17.0" }
hashbrown = { workspace = true }
derive-new = {workspace = true }
spin = { workspace = true }
diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs
index baa5169db3..658907bf3e 100644
--- a/crates/burn-fusion/src/ops/boolean.rs
+++ b/crates/burn-fusion/src/ops/boolean.rs
@@ -1,5 +1,6 @@
use burn_tensor::{
ops::{binary_ops_shape, FloatTensor, IntTensor},
+ repr::{FromDataOperationDescription, TensorDescription},
DType, Element, TensorData,
};
use std::marker::PhantomData;
@@ -24,15 +25,32 @@ use burn_tensor::{
impl BoolTensorOps for Fusion {
fn bool_empty(shape: Shape, device: &Device) -> BoolTensor {
+ #[derive(new)]
+ struct EmptyOps {
+ desc: TensorDescription,
+ device: Device,
+ }
+
+ impl Operation for EmptyOps {
+ fn execute(self: Box, handles: &mut HandleContainer) {
+ let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device);
+ handles.register_bool_tensor::(&self.desc.id, output);
+ }
+ }
+
+ let stream = StreamId::current();
let client = get_client::(&device.clone());
- let tensor = B::bool_empty(shape.clone(), device);
+ let out = client.tensor_uninitialized(shape.dims.clone(), DType::Bool);
- client.register_tensor(
- B::bool_tensor_handle(tensor),
- shape.dims,
- StreamId::current(),
- DType::Bool,
- )
+ let desc = out.to_description_out();
+
+ client.register(
+ vec![stream],
+ OperationDescription::BaseBool(BaseOperationDescription::Empty(desc.clone())),
+ EmptyOps::::new(desc, device.clone()),
+ );
+
+ out
}
async fn bool_into_data(tensor: BoolTensor) -> TensorData {
@@ -40,16 +58,35 @@ impl BoolTensorOps for Fusion {
}
fn bool_from_data(data: burn_tensor::TensorData, device: &Device) -> BoolTensor {
+ #[derive(new)]
+ struct FromDataOps {
+ desc: FromDataOperationDescription,
+ device: Device,
+ }
+
+ impl Operation for FromDataOps {
+ fn execute(self: Box, handles: &mut HandleContainer) {
+ let output = B::bool_from_data(self.desc.data, &self.device);
+ handles.register_bool_tensor::(&self.desc.out.id, output);
+ }
+ }
+
+ let stream = StreamId::current();
let client = get_client::(&device.clone());
- let tensor = B::bool_from_data(data, device);
- let shape = burn_tensor::TensorMetadata::shape(&tensor);
-
- client.register_tensor(
- B::bool_tensor_handle(tensor),
- shape.dims,
- StreamId::current(),
- DType::Bool,
- )
+ let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool);
+
+ let desc = FromDataOperationDescription {
+ out: out.to_description_out(),
+ data,
+ };
+
+ client.register(
+ vec![stream],
+ OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())),
+ FromDataOps::::new(desc, device.clone()),
+ );
+
+ out
}
fn bool_into_int(tensor: BoolTensor) -> IntTensor {
diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs
index 41d691272f..afbe6ae77c 100644
--- a/crates/burn-fusion/src/ops/float.rs
+++ b/crates/burn-fusion/src/ops/float.rs
@@ -16,16 +16,35 @@ use std::{marker::PhantomData, ops::Range};
impl