From 7350209da3736a89523e70e82cc54f60d5a0c6df Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 27 Sep 2024 11:30:16 -0600 Subject: [PATCH 1/5] higher order tensor --- src/tensor/extension.rs | 111 ++++++++++++++++++++++++++++++++++++++++ src/tensor/mod.rs | 2 + 2 files changed, 113 insertions(+) create mode 100644 src/tensor/extension.rs diff --git a/src/tensor/extension.rs b/src/tensor/extension.rs new file mode 100644 index 0000000..4154ce4 --- /dev/null +++ b/src/tensor/extension.rs @@ -0,0 +1,111 @@ +use core::{ + marker::PhantomData, + ops::{Add, Mul}, +}; + +use super::{Tensor, V}; + +trait TensorProduct { + type T1: TensorProduct; + type T2: TensorProduct; + + fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self; + + fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2; + + fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1; +} + +impl + Mul + Copy> TensorProduct + for V +{ + type T1 = Self; + type T2 = V<1, F>; // Scalar ring + + fn tensor_product(tensor_1: Self::T1, _tensor_2: Self::T2) -> Self { + tensor_1 + } + + fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 { + let val = self + .0 + .iter() + .zip(tensor_1.0.iter()) + .fold(F::default(), |acc, (a, b)| acc + (*a * *b)); + V([val]) + } + + fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 { + *self * tensor_2.0[0] + } +} + +impl + Mul + Copy> + TensorProduct for Tensor +where + [(); M * N]:, +{ + type T1 = V; + type T2 = V; + + fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self { + todo!() + } + + fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 { + todo!() + } + + fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 { + todo!() + } +} + +#[derive(Clone)] +pub struct HigherTensor, T2: TensorProduct, F> { + tensor_1: T1, + tensor_2: T2, + _p: PhantomData, +} + +impl, T2: TensorProduct, F> TensorProduct for HigherTensor { + type T1 = T1; + type T2 = T2; + + fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self { + todo!() + } + + fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 { + todo!() + } + + fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 { + todo!() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn intro() { + let tensor_1 = Tensor::<3, 2, f64> { + coefficients: V::<2>(V::<3>([1, 2, 3])), + }; + let tensor_2 = tensor_1.clone(); + + let tensor = HigherTensor { + tensor_1, + tensor_2, + _p: PhantomData, + }; + + let nested_tensor = HigherTensor { + tensor_1: tensor.clone(), + tensor_2: tensor.clone(), + _p: PhantomData, + }; + } +} diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 44299e5..d501b1f 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -2,8 +2,10 @@ use core::ops::AddAssign; use super::*; +pub mod extension; pub mod macros; +#[derive(Clone)] pub struct Tensor where [(); M * N]:, From 05865a89709d9c9ac782a2b95a4ee350b621aeaa Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 29 Sep 2024 11:07:12 -0400 Subject: [PATCH 2/5] save --- src/tensor/extension.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tensor/extension.rs b/src/tensor/extension.rs index 4154ce4..fa3312b 100644 --- a/src/tensor/extension.rs +++ b/src/tensor/extension.rs @@ -1,3 +1,5 @@ +//! Something cool to do here with recursion + use core::{ marker::PhantomData, ops::{Add, Mul}, From 4b97e93a017fb494e02dfae6280c60dc4b781c89 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 21 Dec 2024 13:26:43 -0700 Subject: [PATCH 3/5] remove: tons of junk --- Cargo.lock | 261 ---------------------------------------- Cargo.toml | 8 -- src/coproduct.rs | 93 -------------- src/lib.rs | 68 +---------- src/module.rs | 73 +++++++++++ src/product.rs | 68 ----------- src/tensor.rs | 111 +++++++++++++++++ src/tensor/extension.rs | 113 ----------------- src/tensor/macros.rs | 205 ------------------------------- src/tensor/mod.rs | 161 ------------------------- src/unique_coproduct.rs | 33 ----- 11 files changed, 190 insertions(+), 1004 deletions(-) delete mode 100644 src/coproduct.rs create mode 100644 src/module.rs delete mode 100644 src/product.rs create mode 100644 src/tensor.rs delete mode 100644 src/tensor/extension.rs delete mode 100644 src/tensor/macros.rs delete mode 100644 src/tensor/mod.rs delete mode 100644 src/unique_coproduct.rs diff --git a/Cargo.lock b/Cargo.lock index ddd7ce2..1ab5361 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,267 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "aho-corasick" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" -dependencies = [ - "memchr", -] - -[[package]] -name = "anstream" -version = "0.6.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" -dependencies = [ - "anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "utf8parse", -] - -[[package]] -name = "anstyle" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" - -[[package]] -name = "anstyle-parse" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" -dependencies = [ - "utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" -dependencies = [ - "windows-sys", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" -dependencies = [ - "anstyle", - "windows-sys", -] - -[[package]] -name = "colorchoice" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" - -[[package]] -name = "env_filter" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" -dependencies = [ - "log", - "regex", -] - -[[package]] -name = "env_logger" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" -dependencies = [ - "anstream", - "anstyle", - "env_filter", - "humantime", - "log", -] - [[package]] name = "extensor" version = "0.1.1" -dependencies = [ - "env_logger", - "extensor-macros", - "log", -] - -[[package]] -name = "extensor-macros" -version = "0.1.0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "humantime" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" - -[[package]] -name = "log" -version = "0.4.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" - -[[package]] -name = "memchr" -version = "2.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" - -[[package]] -name = "proc-macro2" -version = "1.0.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "regex" -version = "1.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" - -[[package]] -name = "syn" -version = "2.0.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "utf8parse" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets", -] - -[[package]] -name = "windows-targets" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" diff --git a/Cargo.toml b/Cargo.toml index 2bdee11..b583646 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,3 @@ keywords = ["tensor", "algebra", "math", "macros"] license-file = "LICENSE" readme = "README.md" rust-version = "1.80.0" - -[dependencies] -extensor-macros = { path = "macros/", version = "0.1.0" } - - -[dev-dependencies] -log = "0.4.21" -env_logger = "0.11.3" diff --git a/src/coproduct.rs b/src/coproduct.rs deleted file mode 100644 index 8ceaa5a..0000000 --- a/src/coproduct.rs +++ /dev/null @@ -1,93 +0,0 @@ -use super::*; - -pub trait Coproduct -where - Self: Sized, -{ - type X; - type Y; - - fn construct(x: Option, y: Option) -> Self; - - #[allow(non_snake_case)] - fn iota_X(x: Option) -> Self { - Self::construct(x, None) - } - - #[allow(non_snake_case)] - fn iota_Y(y: Option) -> Self { - Self::construct(None, y) - } - - #[allow(non_snake_case)] - fn get_X_via_tag(&self) -> Option; - - #[allow(non_snake_case)] - fn get_Y_via_tag(&self) -> Option; - - #[allow(non_snake_case)] - fn f>( - &self, - f_X: impl Fn(Option) -> Z, - f_Y: impl Fn(Option) -> Z, - ) -> Z { - f_X(self.get_X_via_tag()) + f_Y(self.get_Y_via_tag()) - } -} - -pub struct DirectSum { - v: Option>, - w: Option>, -} - -impl Coproduct for DirectSum -where - F: Copy, -{ - type X = V; - type Y = V; - - fn construct(v: Option, w: Option) -> Self { - assert!(v.is_some() || w.is_some()); - DirectSum { v, w } - } - - fn get_X_via_tag(&self) -> Option { - self.v - } - - fn get_Y_via_tag(&self) -> Option { - self.w - } -} - -impl Add for DirectSum -where - F: Add + Default + Copy, -{ - type Output = Self; - fn add(self, other: DirectSum) -> Self::Output { - DirectSum::construct( - self.v - .zip(other.v) - .map(|(v, other_v)| v + other_v) - .or(self.v) - .or(other.v), - self.w - .zip(other.w) - .map(|(w, other_w)| w + other_w) - .or(self.w) - .or(other.w), - ) - } -} - -impl Mul for DirectSum -where - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - DirectSum::construct(self.v.map(|v| v * scalar), self.w.map(|w| w * scalar)) - } -} diff --git a/src/lib.rs b/src/lib.rs index ea73703..2991bca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,69 +4,13 @@ #![no_std] use core::{ - fmt::{Debug, Formatter, Result}, - ops::{Add, Mul}, + fmt::Debug, + ops::{Add, Mul, Neg}, }; -use coproduct::{Coproduct, DirectSum}; -use product::{DirectProduct, ProductType}; +#[cfg(test)] +#[macro_use] +extern crate std; -pub mod coproduct; -pub mod product; +pub mod module; pub mod tensor; -pub mod unique_coproduct; - -#[derive(Copy, Clone, Debug)] -pub struct V([F; M]); - -impl Default for V -where - F: Default + Copy, -{ - fn default() -> Self { - V([F::default(); M]) - } -} - -impl + Default + Copy> Add for V { - type Output = Self; - fn add(self, other: V) -> Self::Output { - let mut sum = V::default(); - for i in 0..M { - sum.0[i] = self.0[i] + other.0[i]; - } - sum - } -} - -impl + Default + Copy> Mul for V { - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - let mut scalar_multiple = V::default(); - for i in 0..M { - scalar_multiple.0[i] = scalar * self.0[i]; - } - scalar_multiple - } -} - -impl From> for DirectProduct -where - F: Add + Default + Copy, -{ - fn from(sum: DirectSum) -> DirectProduct { - DirectProduct::construct( - sum.get_X_via_tag().unwrap_or_default(), - sum.get_Y_via_tag().unwrap_or_default(), - ) - } -} - -impl From> for DirectSum -where - F: Add + Default + Copy, -{ - fn from(prod: DirectProduct) -> DirectSum { - DirectSum::iota_X(Some(prod.pi_X())) + DirectSum::iota_Y(Some(prod.pi_Y())) - } -} diff --git a/src/module.rs b/src/module.rs new file mode 100644 index 0000000..8ad5158 --- /dev/null +++ b/src/module.rs @@ -0,0 +1,73 @@ +use core::ops::Div; + +use super::*; + +pub trait Module: + Add + Neg + Mul + Copy +{ + type Ring: Add + Neg + Mul + Default + Copy; +} + +pub trait VectorSpace: Module +where + Self::Ring: Div, +{ +} + +#[derive(Copy, Clone, Debug)] +pub struct Vector(pub [F; M]); + +impl Default for Vector +where + F: Default + Copy, +{ + fn default() -> Self { + Self([F::default(); M]) + } +} + +impl + Default + Copy> Add for Vector { + type Output = Self; + fn add(self, other: Self) -> Self::Output { + let mut sum = Self::default(); + for i in 0..M { + sum.0[i] = self.0[i] + other.0[i]; + } + sum + } +} + +impl + Default + Copy> Neg for Vector { + type Output = Self; + fn neg(self) -> Self::Output { + let mut neg = Self::default(); + for i in 0..M { + neg.0[i] = -self.0[i]; + } + neg + } +} + +impl + Default + Copy> Mul for Vector { + type Output = Self; + fn mul(self, scalar: F) -> Self::Output { + let mut scalar_multiple = Self::default(); + for i in 0..M { + scalar_multiple.0[i] = scalar * self.0[i]; + } + scalar_multiple + } +} + +impl + Neg + Mul + Default + Copy> Module + for Vector +{ + type Ring = F; +} + +impl< + const M: usize, + F: Add + Neg + Mul + Div + Default + Copy, + > VectorSpace for Vector +{ +} diff --git a/src/product.rs b/src/product.rs deleted file mode 100644 index f6e734a..0000000 --- a/src/product.rs +++ /dev/null @@ -1,68 +0,0 @@ -use super::*; - -pub trait ProductType -where - Self: Sized, -{ - type X; - type Y; - - fn construct(x: Self::X, y: Self::Y) -> Self; - - #[allow(non_snake_case)] - fn pi_X(&self) -> Self::X; - - #[allow(non_snake_case)] - fn pi_Y(&self) -> Self::Y; - - #[allow(non_snake_case)] - fn f(z: &Z, f_X: impl Fn(&Z) -> Self::X, f_Y: impl Fn(&Z) -> Self::Y) -> Self { - Self::construct(f_X(z), f_Y(z)) - } -} - -#[derive(Copy, Clone)] -pub struct DirectProduct { - v: V, - w: V, -} - -impl ProductType for DirectProduct -where - F: Copy, -{ - type X = V; - type Y = V; - - fn construct(v: Self::X, w: Self::Y) -> Self { - DirectProduct { v, w } - } - - fn pi_X(&self) -> Self::X { - self.v - } - - fn pi_Y(&self) -> Self::Y { - self.w - } -} - -impl Add for DirectProduct -where - F: Add + Default + Copy, -{ - type Output = Self; - fn add(self, other: DirectProduct) -> Self::Output { - DirectProduct::construct(self.pi_X() + other.pi_X(), self.pi_Y() + other.pi_Y()) - } -} - -impl Mul for DirectProduct -where - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - DirectProduct::construct(self.pi_X() * scalar, self.pi_Y() * scalar) - } -} diff --git a/src/tensor.rs b/src/tensor.rs new file mode 100644 index 0000000..843a8f3 --- /dev/null +++ b/src/tensor.rs @@ -0,0 +1,111 @@ +use module::Module; + +use super::*; + +#[derive(Copy, Clone, Debug)] +pub struct TensorProduct +where + A: Module, + B: Module, +{ + a: A, + b: B, +} + +impl Add for TensorProduct +where + A: Module, + B: Module, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + a: self.a + rhs.a, + b: self.b + rhs.b, + } + } +} + +impl Neg for TensorProduct +where + A: Module, + B: Module, +{ + type Output = Self; + + fn neg(self) -> Self::Output { + Self { + a: -self.a, + b: -self.b, + } + } +} + +impl Mul<::Ring> for TensorProduct +where + A: Module, + B: Module, +{ + type Output = Self; + + fn mul(self, rhs: ::Ring) -> Self::Output { + Self { + a: self.a * rhs, + b: self.b * rhs, + } + } +} + +impl Module for TensorProduct +where + A: Module, + B: Module, +{ + type Ring = A::Ring; +} + +impl TensorProduct +where + A: Module + Copy, + B: Module + Copy, +{ + pub const fn new(a: A, b: B) -> Self { + Self { a, b } + } + + pub const fn append + Copy>( + self, + c: C, + ) -> TensorProduct> { + let a = self.a; + let b = self.b; + let prod = TensorProduct { a: b, b: c }; + TensorProduct { a, b: prod } + } +} + +#[cfg(test)] +mod tests { + use module::Vector; + + use super::*; + + #[test] + fn intro() { + let a = Vector::<1, f64>::default(); + let b = Vector::<2, f64>::default(); + let c = Vector::<3, f64>::default(); + let tensor = TensorProduct::new(a, b); + let tensor = tensor.append(c); + + let a = Vector::<1, f64>::default(); + let b = Vector::<2, f64>::default(); + let c = Vector::<3, f64>::default(); + let tensor2 = TensorProduct::new(a, b); + let tensor2 = tensor2.append(c); + + let added = tensor + tensor2; + dbg!(added); + } +} diff --git a/src/tensor/extension.rs b/src/tensor/extension.rs deleted file mode 100644 index fa3312b..0000000 --- a/src/tensor/extension.rs +++ /dev/null @@ -1,113 +0,0 @@ -//! Something cool to do here with recursion - -use core::{ - marker::PhantomData, - ops::{Add, Mul}, -}; - -use super::{Tensor, V}; - -trait TensorProduct { - type T1: TensorProduct; - type T2: TensorProduct; - - fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self; - - fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2; - - fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1; -} - -impl + Mul + Copy> TensorProduct - for V -{ - type T1 = Self; - type T2 = V<1, F>; // Scalar ring - - fn tensor_product(tensor_1: Self::T1, _tensor_2: Self::T2) -> Self { - tensor_1 - } - - fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 { - let val = self - .0 - .iter() - .zip(tensor_1.0.iter()) - .fold(F::default(), |acc, (a, b)| acc + (*a * *b)); - V([val]) - } - - fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 { - *self * tensor_2.0[0] - } -} - -impl + Mul + Copy> - TensorProduct for Tensor -where - [(); M * N]:, -{ - type T1 = V; - type T2 = V; - - fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self { - todo!() - } - - fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 { - todo!() - } - - fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 { - todo!() - } -} - -#[derive(Clone)] -pub struct HigherTensor, T2: TensorProduct, F> { - tensor_1: T1, - tensor_2: T2, - _p: PhantomData, -} - -impl, T2: TensorProduct, F> TensorProduct for HigherTensor { - type T1 = T1; - type T2 = T2; - - fn tensor_product(tensor_1: Self::T1, tensor_2: Self::T2) -> Self { - todo!() - } - - fn multilinear_t1(&self, tensor_1: Self::T1) -> Self::T2 { - todo!() - } - - fn multilinear_t2(&self, tensor_2: Self::T2) -> Self::T1 { - todo!() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn intro() { - let tensor_1 = Tensor::<3, 2, f64> { - coefficients: V::<2>(V::<3>([1, 2, 3])), - }; - let tensor_2 = tensor_1.clone(); - - let tensor = HigherTensor { - tensor_1, - tensor_2, - _p: PhantomData, - }; - - let nested_tensor = HigherTensor { - tensor_1: tensor.clone(), - tensor_2: tensor.clone(), - _p: PhantomData, - }; - } -} diff --git a/src/tensor/macros.rs b/src/tensor/macros.rs deleted file mode 100644 index ed2ea32..0000000 --- a/src/tensor/macros.rs +++ /dev/null @@ -1,205 +0,0 @@ -use super::*; - -// TODO: Could probably just assign a valence to the tensors and use N0, N1, N2, -// etc. as dims - -#[macro_export] -macro_rules! tensor { - ($name:ident, $($consts:ident),+) => { - #[derive(extensor_macros::MultilinearMap)] - pub struct $name<$(const $consts: usize),+, F> - where F: Default + Copy + AddAssign + Mul, - { - pub coefficients: coeff_builder!($($consts),+; F), - } - - impl<$(const $consts: usize),+, F: Default + Copy + AddAssign + Mul> Default for $name<$($consts),+, F> { - fn default() -> Self { - let coefficients = ::default(); - $name { coefficients } - } - - } - - impl<$(const $consts: usize),+, F> Debug for $name<$($consts),+, F> - where - F: Default + Copy + Debug + AddAssign + Mul, - { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.debug_struct(stringify!($name)) - .field("coefficients", &self.coefficients) - .finish() - } - } - - impl<$(const $consts: usize),+, F> Add for $name<$($consts),+, F> - where - F: Add + Copy + Default + AddAssign + Mul, - { - type Output = Self; - - fn add(self, other: Self) -> Self::Output { - let mut result = Self::default(); - add_tensors!(result.coefficients, self.coefficients, other.coefficients; $($consts),+); - result - } - } - - impl<$(const $consts: usize),+, F> Mul for $name<$($consts),+, F> - where - F: Mul + Copy + Default + AddAssign, - { - type Output = Self; - - fn mul(self, scalar: F) -> Self::Output { - let mut result = Self::default(); - scalar_mul_tensor!(result.coefficients, self.coefficients, scalar; $($consts),+); - result - } - } - } -} - -macro_rules! coeff_builder { - ($const:ident; $expr:ty) => { - V<$const, $expr> - }; - ($const:ident, $($rest:ident),+; $expr:ty) => { - V<$const, coeff_builder!($($rest),+; $expr)> - }; -} - -macro_rules! def_builder { - ($const:ident; $expr:ty) => { - V::<$const, $expr> - }; - ($const:ident, $($rest:ident),+; $expr:ty) => { - V::<$const, def_builder!($($rest),+; $expr)> - }; -} - -macro_rules! add_tensors { - ($result:expr, $self:expr, $other:expr; $const:ident) => { - for i in 0..$const { - $result.0[i] = $self.0[i] + $other.0[i]; - } - }; - ($result:expr, $self:expr, $other:expr; $const:ident, $($rest:ident),+) => { - for i in 0..$const { - add_tensors!($result.0[i], $self.0[i], $other.0[i]; $($rest),+); - } - }; -} - -macro_rules! scalar_mul_tensor { - ($result:expr, $self:expr, $scalar:expr; $const:ident) => { - for i in 0..$const { - $result.0[i] = $self.0[i] * $scalar; - } - }; - ($result:expr, $self:expr, $scalar:expr; $const:ident, $($rest:ident),+) => { - for i in 0..$const { - scalar_mul_tensor!($result.0[i], $self.0[i], $scalar; $($rest),+); - } - }; -} - -tensor!(TensorTester, M, N, P); - -#[cfg(test)] -mod tests { - - use super::*; - tensor!(Tensor2, M, N); - - tensor!(Tensor3, M, N, P); - - use log::{debug, info}; - - fn log() { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init(); - } - - #[test] - fn create_arbitrary_tensor() { - // log(); - let tensor = Tensor2::<2, 3, f64>::default(); - debug!("{:?}", tensor.coefficients); - - let tensor = Tensor3::<2, 3, 4, f64>::default(); - debug!("{:?}", tensor.coefficients); - } - - #[test] - fn add_tensors() { - // log(); - let mut tensor1 = Tensor2::<2, 3, f64>::default(); - for i in 0..2 { - for j in 0..3 { - tensor1.coefficients.0[i].0[j] = (i + j) as f64; - } - } - debug!("tensor1: {:?}", tensor1.coefficients); - let mut tensor2 = Tensor2::<2, 3, f64>::default(); - for i in 0..2 { - for j in 0..3 { - tensor2.coefficients.0[i].0[j] = i as f64 - j as f64; - } - } - debug!("tensor2: {:?}", tensor2.coefficients); - let tensor3 = tensor1 + tensor2; - info!("output: {:?}", tensor3.coefficients); - } - - #[test] - fn scalar_mul_tensor() { - // log(); - let mut tensor1 = Tensor2::<2, 3, f64>::default(); - for i in 0..2 { - for j in 0..3 { - tensor1.coefficients.0[i].0[j] = (i + j) as f64; - } - } - debug!("tensor1: {:?}", tensor1.coefficients); - let scalar = 2.0; - let tensor2 = tensor1 * scalar; - info!("output: {:?}", tensor2.coefficients); - } - - #[test] - fn multilinear_map() { - log(); - // / 1 0 0 \ - // tensor = \ 0 1 0 / - let mut tensor = Tensor2::<2, 3, f64>::default(); - tensor.coefficients.0[0].0[0] = 1.0; - tensor.coefficients.0[1].0[1] = 1.0; - debug!("tensor: {:?}", tensor); - - // / -1 \ - // v_0 = \ 1 / - let mut v_0 = V::default(); - v_0.0[0] = -1.0; - v_0.0[1] = 1.0; - debug!("v_0: {:?}", v_0); - - // / 1 \ - // | 2 | - // v_1 = \ 3 / - let mut v_1 = V::default(); - v_1.0[0] = 1.0; - v_1.0[1] = 2.0; - v_1.0[2] = 3.0; - debug!("v_1: {:?}", v_1); - - // / 1 \ - // tensor.map(_,v_1) = \ 2 / - // - // then the next is: - // / 1 \ - // tensor.map(v_0, v_1) = < -1 1 > \ 2 / = -1 + 2 = 1 - let output = tensor.multilinear_map(v_0, v_1); - info!("output: {:?}", output); - assert_eq!(output, 1.0); - } -} diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs deleted file mode 100644 index d501b1f..0000000 --- a/src/tensor/mod.rs +++ /dev/null @@ -1,161 +0,0 @@ -use core::ops::AddAssign; - -use super::*; - -pub mod extension; -pub mod macros; - -#[derive(Clone)] -pub struct Tensor -where - [(); M * N]:, -{ - /// This set up makes the first index into `coefficients` the "rows" and - /// second the "columns" - coefficients: V>, -} - -impl Default for Tensor -where - [(); M * N]:, - F: Default + Copy, -{ - fn default() -> Self { - let coefficients = V::>::default(); - Tensor { coefficients } - } -} - -impl Tensor -where - [(); M * N]:, - F: Mul + Default + Copy, -{ - pub fn tensor_product(v: [V; P], w: [V; P]) -> Tensor { - let mut tensor = Tensor::default(); - - for p in 0..P { - for i in 0..M { - for j in 0..N { - tensor.coefficients.0[i].0[j] = v[p].0[i] * w[p].0[j]; - } - } - } - tensor - } -} - -impl Add for Tensor -where - [(); M * N]:, - F: Add + Copy + Default, -{ - type Output = Self; - fn add(self, other: Tensor) -> Self::Output { - let mut tensor = Tensor::default(); - for i in 0..M { - for j in 0..N { - tensor.coefficients.0[i].0[j] = - self.coefficients.0[i].0[j] + other.coefficients.0[i].0[j]; - } - } - tensor - } -} - -impl Mul for Tensor -where - [(); M * N]:, - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - let mut tensor = Tensor::default(); - for i in 0..M { - for j in 0..N { - tensor.coefficients.0[i].0[j] = self.coefficients.0[i].0[j] * scalar; - } - } - tensor - } -} - -/// Below are more features of tensor that we can define for free! - -impl Tensor -where - [(); M * N]:, - F: Add + Mul + AddAssign + Default + Copy, -{ - pub fn bilinear_map(&self, v: V, w: V) -> F { - let mut sum = F::default(); - for i in 0..M { - for j in 0..N { - sum += v.0[j] * self.coefficients.0[i].0[j] * w.0[i]; - } - } - sum - } - - /// Here, for each choice of `w`, we get a distinct linear functional on `V` - /// that utilizes the tensor product. - #[allow(non_snake_case)] - pub fn get_functional_on_V(&self, w: V) -> impl Fn(V) -> F + '_ { - move |v| self.bilinear_map(v, w) - } - - /// Here, for each choice of `v`, we get a distinct linear functional on `W` - /// that utilizes the tensor product. - #[allow(non_snake_case)] - pub fn get_functional_on_W(&self, v: V) -> impl Fn(V) -> F + '_ { - move |w| self.bilinear_map(v, w) - } - - /// Matrix multiplication acting from the left :) - #[allow(non_snake_case)] - pub fn linear_map_V_to_W(&self, v: V) -> V { - let mut w = V([F::default(); N]); - for j in 0..N { - for i in 0..M { - w.0[j] += self.coefficients.0[i].0[j] * v.0[j]; - } - } - w - } - - /// Matrix multiplication acting from the right :) - #[allow(non_snake_case)] - pub fn linear_map_W_to_V(&self, w: V) -> V { - let mut v = V([F::default(); M]); - for j in 0..N { - for i in 0..M { - v.0[j] += self.coefficients.0[i].0[j] * w.0[i]; - } - } - v - } -} - -/// This implementation makes `Tensor` an "Algebra" :) -/// In other words, we can multiply M x N matrices with N x P matrices to get an -/// M x P matrix. -impl Mul> for Tensor -where - [(); M * N]:, - [(); N * P]:, - F: Add + AddAssign + Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, other: Tensor) -> Self::Output { - let mut product = Tensor::default(); - for i in 0..N { - for k in 0..P { - for j in 0..M { - product.coefficients.0[j].0[k] += - self.coefficients.0[i].0[j] * other.coefficients.0[j].0[k]; - } - } - } - product - } -} diff --git a/src/unique_coproduct.rs b/src/unique_coproduct.rs deleted file mode 100644 index 69dc666..0000000 --- a/src/unique_coproduct.rs +++ /dev/null @@ -1,33 +0,0 @@ -use super::*; - -pub enum UniqueDirectSum { - V(V), - W(V), -} - -impl Add for UniqueDirectSum -where - F: Add + Default + Copy, -{ - type Output = Self; - fn add(self, other: UniqueDirectSum) -> Self::Output { - match (self, other) { - (UniqueDirectSum::V(v), UniqueDirectSum::V(w)) => UniqueDirectSum::V(V::add(v, w)), - (UniqueDirectSum::W(v), UniqueDirectSum::W(w)) => UniqueDirectSum::W(V::add(v, w)), - _ => panic!("Cannot add V and W with Rust `UniqueDirectSum`!"), - } - } -} - -impl Mul for UniqueDirectSum -where - F: Mul + Default + Copy, -{ - type Output = Self; - fn mul(self, scalar: F) -> Self::Output { - match self { - UniqueDirectSum::V(v) => UniqueDirectSum::V(v * scalar), - UniqueDirectSum::W(w) => UniqueDirectSum::W(w * scalar), - } - } -} From 3743d8f3610dc575d89d724b3aa25911f5fd201c Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 21 Dec 2024 14:33:16 -0700 Subject: [PATCH 4/5] cleanup --- src/lib.rs | 2 +- src/module.rs | 32 ++++++++++++++++++++++++++- src/tensor.rs | 60 +++++++++++++++++++++++++++++++++++++-------------- 3 files changed, 76 insertions(+), 18 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2991bca..ca36054 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![allow(unstable_features)] #![allow(incomplete_features)] -#![feature(generic_const_exprs)] +#![feature(specialization)] #![no_std] use core::{ diff --git a/src/module.rs b/src/module.rs index 8ad5158..110c4e9 100644 --- a/src/module.rs +++ b/src/module.rs @@ -1,4 +1,4 @@ -use core::ops::Div; +use core::{marker::PhantomData, ops::Div}; use super::*; @@ -71,3 +71,33 @@ impl< > VectorSpace for Vector { } + +#[derive(Clone, Copy, Default)] +pub struct TrivialModule { + pub(crate) _r: PhantomData, +} + +impl Module for TrivialModule { + type Ring = R; +} + +impl Add for TrivialModule { + type Output = Self; + fn add(self, _: Self) -> Self::Output { + Self { _r: PhantomData } + } +} + +impl Neg for TrivialModule { + type Output = Self; + fn neg(self) -> Self::Output { + Self { _r: PhantomData } + } +} + +impl Mul for TrivialModule { + type Output = Self; + fn mul(self, _: R) -> Self::Output { + Self { _r: PhantomData } + } +} diff --git a/src/tensor.rs b/src/tensor.rs index 843a8f3..5a07276 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,9 +1,11 @@ -use module::Module; +use core::marker::PhantomData; + +use module::{Module, TrivialModule, Vector}; use super::*; #[derive(Copy, Clone, Debug)] -pub struct TensorProduct +pub struct Tensor where A: Module, B: Module, @@ -12,7 +14,7 @@ where b: B, } -impl Add for TensorProduct +impl Add for Tensor where A: Module, B: Module, @@ -27,7 +29,7 @@ where } } -impl Neg for TensorProduct +impl Neg for Tensor where A: Module, B: Module, @@ -42,7 +44,7 @@ where } } -impl Mul<::Ring> for TensorProduct +impl Mul<::Ring> for Tensor where A: Module, B: Module, @@ -57,7 +59,7 @@ where } } -impl Module for TensorProduct +impl Module for Tensor where A: Module, B: Module, @@ -65,23 +67,41 @@ where type Ring = A::Ring; } -impl TensorProduct +impl Tensor> +where + A: Module + Copy, +{ + pub const fn append_trivial + Copy>(self, b: B) -> Tensor { + let a = self.a; + Tensor { a, b } + } +} + +impl Tensor where A: Module + Copy, B: Module + Copy, { - pub const fn new(a: A, b: B) -> Self { + pub const fn product(a: A, b: B) -> Self { Self { a, b } } - pub const fn append + Copy>( - self, - c: C, - ) -> TensorProduct> { + pub const fn append + Copy>(self, c: C) -> Tensor> { let a = self.a; let b = self.b; - let prod = TensorProduct { a: b, b: c }; - TensorProduct { a, b: prod } + let prod = Tensor { a: b, b: c }; + Tensor { a, b: prod } + } +} + +impl + Neg + Mul + Default + Copy> + From> for Tensor, TrivialModule> +{ + fn from(value: Vector) -> Self { + Self { + a: value, + b: TrivialModule { _r: PhantomData }, + } } } @@ -96,16 +116,24 @@ mod tests { let a = Vector::<1, f64>::default(); let b = Vector::<2, f64>::default(); let c = Vector::<3, f64>::default(); - let tensor = TensorProduct::new(a, b); + let tensor = Tensor::product(a, b); let tensor = tensor.append(c); let a = Vector::<1, f64>::default(); let b = Vector::<2, f64>::default(); let c = Vector::<3, f64>::default(); - let tensor2 = TensorProduct::new(a, b); + let tensor2 = Tensor::product(a, b); let tensor2 = tensor2.append(c); let added = tensor + tensor2; dbg!(added); } + + #[test] + fn terminal() { + let a = Vector::<1, f64>::default(); + let b = Vector::<2, f64>::default(); + let tensor = Tensor::from(a); + let tensor = tensor.append_trivial(b); + } } From 2218e05ed07697c5b4d44f4a0d2d160f9cf28e21 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 21 Dec 2024 17:41:08 -0700 Subject: [PATCH 5/5] feat: init --- Cargo.lock | 9 +++++ Cargo.toml | 3 ++ src/algebra/mod.rs | 39 ++++++++++++++++++++++ src/algebra/quadratic_form.rs | 63 +++++++++++++++++++++++++++++++++++ src/lib.rs | 6 +++- src/module.rs | 42 +++++++++++++---------- src/tensor.rs | 22 ++++++------ 7 files changed, 155 insertions(+), 29 deletions(-) create mode 100644 src/algebra/mod.rs create mode 100644 src/algebra/quadratic_form.rs diff --git a/Cargo.lock b/Cargo.lock index 1ab5361..05768c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "const-default" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b396d1f76d455557e1218ec8066ae14bba60b4b36ecd55577ba979f5db7ecaa" + [[package]] name = "extensor" version = "0.1.1" +dependencies = [ + "const-default", +] diff --git a/Cargo.toml b/Cargo.toml index b583646..3aac495 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,3 +8,6 @@ keywords = ["tensor", "algebra", "math", "macros"] license-file = "LICENSE" readme = "README.md" rust-version = "1.80.0" + +[dependencies] +const-default = "1.0" diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs new file mode 100644 index 0000000..a63b93d --- /dev/null +++ b/src/algebra/mod.rs @@ -0,0 +1,39 @@ +use core::marker::PhantomData; + +use module::{Module, Vector}; +use tensor::Tensor; + +use super::*; + +pub mod quadratic_form; + +// #[derive(Clone, Copy)] +// pub struct TensorAlgebra(PhantomData); + +// impl Module for TensorAlgebra { +// type Ring = M::Ring; +// } + +// impl Add for TensorAlgebra { +// type Output = TensorAlgebra; + +// fn add(self, rhs: Self) -> Self::Output { +// todo!() +// } +// } + +// impl Neg for TensorAlgebra { +// type Output = TensorAlgebra; + +// fn neg(self) -> Self::Output { +// todo!() +// } +// } + +// impl Mul for TensorAlgebra { +// type Output = TensorAlgebra; + +// fn mul(self, rhs: M::Ring) -> Self::Output { +// todo!() +// } +// } diff --git a/src/algebra/quadratic_form.rs b/src/algebra/quadratic_form.rs new file mode 100644 index 0000000..06b8b7c --- /dev/null +++ b/src/algebra/quadratic_form.rs @@ -0,0 +1,63 @@ +use core::ops::AddAssign; + +use module::{Module, Vector}; + +use super::*; + +pub struct SymmetricIndex +where + [(); D * (D + 1) / 2]:, +{ + // Store only upper triangular indices + // Length is (D * (D + 1)) / 2 + indices: [(usize, usize); (D * (D + 1)) / 2], +} + +pub struct QuadraticForm { + eigenbasis: [Vector; D], + eigenvalues: [F; D], +} + +impl + AddAssign> QuadraticForm { + pub const fn new_diagonal(eigenvalues: [F; D]) -> Self { + let mut eigenbasis = [Vector::DEFAULT; D]; + let mut i = 0; + while i < D { + let mut j = 0; + while j < D { + if i == j { + eigenbasis[i].0[i] = eigenvalues[i]; + } + j += 1; + } + i += 1; + } + + Self { + eigenbasis, + eigenvalues, + } + } + + // TODO: Make `const` if we ever get a const `Mul` + pub fn eval(&self, lhs: Vector, rhs: Vector) -> F { + let mut sum = F::DEFAULT; + let mut i = 0; + while i < D { + // Project vectors onto eigenbasis + let mut lhs_comp = F::DEFAULT; + let mut rhs_comp = F::DEFAULT; + let mut j = 0; + while j < D { + lhs_comp += lhs.0[j] * self.eigenbasis[i].0[j]; + rhs_comp += rhs.0[j] * self.eigenbasis[i].0[j]; + j += 1; + } + + // Multiply by eigenvalue and add to sum + sum += lhs_comp * rhs_comp * self.eigenvalues[i]; + i += 1; + } + sum + } +} diff --git a/src/lib.rs b/src/lib.rs index ca36054..1f4617e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![allow(unstable_features)] #![allow(incomplete_features)] -#![feature(specialization)] +#![feature(generic_const_exprs)] +#![feature(const_trait_impl)] #![no_std] use core::{ @@ -8,9 +9,12 @@ use core::{ ops::{Add, Mul, Neg}, }; +use const_default::ConstDefault; + #[cfg(test)] #[macro_use] extern crate std; +pub mod algebra; pub mod module; pub mod tensor; diff --git a/src/module.rs b/src/module.rs index 110c4e9..b99dc11 100644 --- a/src/module.rs +++ b/src/module.rs @@ -5,7 +5,7 @@ use super::*; pub trait Module: Add + Neg + Mul + Copy { - type Ring: Add + Neg + Mul + Default + Copy; + type Ring: Add + Neg + Mul + Default + ConstDefault + Copy; } pub trait VectorSpace: Module @@ -17,30 +17,32 @@ where #[derive(Copy, Clone, Debug)] pub struct Vector(pub [F; M]); -impl Default for Vector +impl ConstDefault for Vector where - F: Default + Copy, + F: ConstDefault + Copy, { - fn default() -> Self { - Self([F::default(); M]) - } + const DEFAULT: Self = Self([F::DEFAULT; M]); } -impl + Default + Copy> Add for Vector { +impl + ConstDefault + Copy> const Add + for Vector +{ type Output = Self; fn add(self, other: Self) -> Self::Output { - let mut sum = Self::default(); - for i in 0..M { + let mut sum = Self::DEFAULT; + let mut i = 0; + while i < M { sum.0[i] = self.0[i] + other.0[i]; + i += 1; } sum } } -impl + Default + Copy> Neg for Vector { +impl + ConstDefault + Copy> Neg for Vector { type Output = Self; fn neg(self) -> Self::Output { - let mut neg = Self::default(); + let mut neg = Self::DEFAULT; for i in 0..M { neg.0[i] = -self.0[i]; } @@ -48,26 +50,30 @@ impl + Default + Copy> Neg for Vector { } } -impl + Default + Copy> Mul for Vector { +impl + ConstDefault + Copy> Mul for Vector { type Output = Self; fn mul(self, scalar: F) -> Self::Output { - let mut scalar_multiple = Self::default(); - for i in 0..M { + let mut scalar_multiple = Self::DEFAULT; + let mut i = 0; + while i < M { scalar_multiple.0[i] = scalar * self.0[i]; + i += 1; } scalar_multiple } } -impl + Neg + Mul + Default + Copy> Module - for Vector +impl< + const M: usize, + F: Add + Neg + Mul + Default + ConstDefault + Copy, + > Module for Vector { type Ring = F; } impl< const M: usize, - F: Add + Neg + Mul + Div + Default + Copy, + F: Add + Neg + Mul + Div + Default + ConstDefault + Copy, > VectorSpace for Vector { } @@ -77,7 +83,7 @@ pub struct TrivialModule { pub(crate) _r: PhantomData, } -impl Module for TrivialModule { +impl Module for TrivialModule { type Ring = R; } diff --git a/src/tensor.rs b/src/tensor.rs index 5a07276..e419d00 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -94,8 +94,10 @@ where } } -impl + Neg + Mul + Default + Copy> - From> for Tensor, TrivialModule> +impl< + const M: usize, + F: Add + Neg + Mul + Default + ConstDefault + Copy, + > From> for Tensor, TrivialModule> { fn from(value: Vector) -> Self { Self { @@ -113,15 +115,15 @@ mod tests { #[test] fn intro() { - let a = Vector::<1, f64>::default(); - let b = Vector::<2, f64>::default(); - let c = Vector::<3, f64>::default(); + let a = Vector::<1, f64>::DEFAULT; + let b = Vector::<2, f64>::DEFAULT; + let c = Vector::<3, f64>::DEFAULT; let tensor = Tensor::product(a, b); let tensor = tensor.append(c); - let a = Vector::<1, f64>::default(); - let b = Vector::<2, f64>::default(); - let c = Vector::<3, f64>::default(); + let a = Vector::<1, f64>::DEFAULT; + let b = Vector::<2, f64>::DEFAULT; + let c = Vector::<3, f64>::DEFAULT; let tensor2 = Tensor::product(a, b); let tensor2 = tensor2.append(c); @@ -131,8 +133,8 @@ mod tests { #[test] fn terminal() { - let a = Vector::<1, f64>::default(); - let b = Vector::<2, f64>::default(); + let a = Vector::<1, f64>::DEFAULT; + let b = Vector::<2, f64>::DEFAULT; let tensor = Tensor::from(a); let tensor = tensor.append_trivial(b); }