diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 62c35b64d3..78feeab2ff 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -29,7 +29,7 @@ impl Benchmark for MatmulBenchmark { } fn execute(&self, (lhs, rhs): Self::Args) { - lhs.clone().matmul(rhs.clone()); + lhs.clone().transpose().matmul(rhs.clone()); } fn prepare(&self) -> Self::Args { @@ -56,7 +56,7 @@ fn bench( let m = 256; let k = 1024; let n = 256; - let shape_lhs = [batch_size, m, k].into(); + let shape_lhs = [batch_size, k, m].into(); let shape_rhs = [batch_size, k, n].into(); let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()); diff --git a/crates/burn-cube-macros/src/analyzer.rs b/crates/burn-cube-macros/src/analyzer.rs index a180fe6257..98355408a7 100644 --- a/crates/burn-cube-macros/src/analyzer.rs +++ b/crates/burn-cube-macros/src/analyzer.rs @@ -133,10 +133,14 @@ impl VariableAnalyzer { self.find_occurrences_in_expr(&expr.cond, depth); self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth); if let Some((_, expr)) = &expr.else_branch { - if let syn::Expr::Block(expr_block) = &**expr { - self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); - } else { - // Unsupported: handled in codegen. 
+ match &**expr { + syn::Expr::Block(expr_block) => { + self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); + } + syn::Expr::If(expr) => { + self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth); + } + _ => unreachable!(), } } } diff --git a/crates/burn-cube-macros/src/codegen_function/branch.rs b/crates/burn-cube-macros/src/codegen_function/branch.rs index 8def970f64..e8d132d3a0 100644 --- a/crates/burn-cube-macros/src/codegen_function/branch.rs +++ b/crates/burn-cube-macros/src/codegen_function/branch.rs @@ -120,6 +120,7 @@ pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream { /// if cond {...} /// if cond {...} else {...} /// if Comptime::get(...) {...} [else {...}] +/// if Comptime::get(...) {...} [else if Comptime::get(...) {...}]* [else {...}] pub(crate) fn codegen_if( expr_if: &syn::ExprIf, loop_level: usize, @@ -135,19 +136,19 @@ pub(crate) fn codegen_if( let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_tracker); if let Some((_, expr)) = &expr_if.else_branch { - if let syn::Expr::Block(expr_block) = &**expr { - let else_block = codegen_block(&expr_block.block, loop_level + 1, variable_tracker); + let else_block = match &**expr { + syn::Expr::Block(expr_block) => { + codegen_block(&expr_block.block, loop_level + 1, variable_tracker) + } - quote::quote! { - let _cond = #cond; - burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block); + syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level + 1, variable_tracker), + _ => unreachable!(), + }; + quote::quote! { + { + let _cond = #cond; + burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block); } - } else { - syn::Error::new_spanned( - expr, - "Unsupported: only `else` block is allowed after an `if` statement.", - ) - .into_compile_error() } } else { quote::quote! 
{ diff --git a/crates/burn-cube/src/compute/builder.rs b/crates/burn-cube/src/compute/builder.rs index 067b87d30d..5664a9f619 100644 --- a/crates/burn-cube/src/compute/builder.rs +++ b/crates/burn-cube/src/compute/builder.rs @@ -39,6 +39,26 @@ impl KernelBuilder { self.context.scalar(index, elem) } + /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion. + pub fn output_tensor(&mut self, item: Item) -> ExpandElement { + self.outputs.push(OutputInfo::Array { item }); + let variable = self.context.output(self.num_output, item); + self.num_output += 1; + + variable + } + + /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion. + pub fn input_tensor(&mut self, item: Item) -> ExpandElement { + self.inputs.push(InputInfo::Array { + item, + visibility: Visibility::Read, + }); + let variable = self.context.input(self.num_input, item); + self.num_input += 1; + variable + } + /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion. 
pub fn output_array(&mut self, item: Item) -> ExpandElement { self.outputs.push(OutputInfo::Array { item }); diff --git a/crates/burn-cube/src/frontend/context.rs b/crates/burn-cube/src/frontend/context.rs index 3b632a6754..ef0378f218 100644 --- a/crates/burn-cube/src/frontend/context.rs +++ b/crates/burn-cube/src/frontend/context.rs @@ -124,6 +124,10 @@ impl CubeContext { } } + pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement { + ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size)) + } + /// Obtain the index-th input pub fn input(&mut self, index: u16, item: Item) -> ExpandElement { ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item)) diff --git a/crates/burn-cube/src/frontend/element/array.rs b/crates/burn-cube/src/frontend/element/array.rs index cce27a54aa..c845e2f84a 100644 --- a/crates/burn-cube/src/frontend/element/array.rs +++ b/crates/burn-cube/src/frontend/element/array.rs @@ -6,14 +6,17 @@ use crate::{ ir::{Item, Vectorization}, unexpanded, KernelSettings, Runtime, }; +use crate::{ + frontend::{indexation::Index, CubeContext}, + prelude::{assign, index, index_assign, Comptime}, +}; use super::{ - ArgSettings, CubePrimitive, ExpandElementTyped, Init, LaunchArg, LaunchArgExpand, TensorHandle, - UInt, + ArgSettings, CubePrimitive, ExpandElement, ExpandElementTyped, Init, LaunchArg, + LaunchArgExpand, TensorHandle, UInt, }; /// A contiguous array of elements. 
-#[derive(new)] pub struct Array { _val: PhantomData, } @@ -22,6 +25,77 @@ impl CubeType for Array { type ExpandType = ExpandElementTyped>; } +impl Array { + pub fn new(_size: S) -> Self { + Array { _val: PhantomData } + } + + pub fn new_expand( + context: &mut CubeContext, + size: S, + ) -> ::ExpandType { + let size = size.value(); + let size = match size { + crate::ir::Variable::ConstantScalar(val, _) => val as u32, + _ => panic!("Array need constant initialization value"), + }; + context + .create_local_array(Item::new(T::as_elem()), size) + .into() + } + + pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { + Array { _val: PhantomData } + } + + pub fn vectorized_expand( + context: &mut CubeContext, + size: S, + vectorization_factor: UInt, + ) -> ::ExpandType { + let size = size.value(); + let size = match size { + crate::ir::Variable::ConstantScalar(val, _) => val as u32, + _ => panic!("Shared memory need constant initialization value"), + }; + context + .create_local_array( + Item::vectorized(T::as_elem(), vectorization_factor.val as u8), + size, + ) + .into() + } + + pub fn to_vectorized(self, _vectorization_factor: Comptime) -> T { + unexpanded!() + } +} + +impl ExpandElementTyped> { + pub fn to_vectorized_expand( + self, + context: &mut CubeContext, + vectorization_factor: UInt, + ) -> ExpandElement { + let factor = vectorization_factor.val; + let var = self.expand.clone(); + let mut new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8)); + if vectorization_factor.val == 1 { + let element = index::expand(context, self.clone(), 0u32); + assign::expand(context, element, new_var.clone()); + } else { + for i in 0..factor { + let element = index::expand(context, self.expand.clone(), i); + new_var = index_assign::expand(context, new_var, i, element); + } + } + new_var + } +} + +impl CubeType for &Array { + type ExpandType = ExpandElementTyped>; +} impl Init for ExpandElementTyped> { fn init(self, _context: &mut 
crate::prelude::CubeContext) -> Self { // The type can't be deeply cloned/copied. diff --git a/crates/burn-cube/src/frontend/element/float.rs b/crates/burn-cube/src/frontend/element/float.rs index 15292fa710..83b8b23b19 100644 --- a/crates/burn-cube/src/frontend/element/float.rs +++ b/crates/burn-cube/src/frontend/element/float.rs @@ -26,6 +26,7 @@ pub trait Float: + Erf + Recip + core::ops::Index + + core::ops::IndexMut { fn new(val: f32) -> Self; fn new_expand(context: &mut CubeContext, val: f32) -> ::ExpandType; @@ -35,6 +36,11 @@ pub trait Float: val: f32, vectorization: UInt, ) -> ::ExpandType; + fn vectorized_empty(vectorization: UInt) -> Self; + fn vectorized_empty_expand( + context: &mut CubeContext, + vectorization: UInt, + ) -> ::ExpandType; } macro_rules! impl_float { @@ -101,6 +107,21 @@ macro_rules! impl_float { new_var } } + + fn vectorized_empty(vectorization: UInt) -> Self { + Self::vectorized(0., vectorization) + } + + fn vectorized_empty_expand( + context: &mut CubeContext, + vectorization: UInt, + ) -> ::ExpandType { + if vectorization.val == 1 { + Self::new_expand(context, 0.) + } else { + context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)) + } + } } impl core::ops::Index for $type { @@ -111,6 +132,12 @@ macro_rules! 
impl_float { } } + impl core::ops::IndexMut for $type { + fn index_mut(&mut self, _index: UInt) -> &mut Self::Output { + unexpanded!() + } + } + impl LaunchArgExpand for $type { fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement { assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); diff --git a/crates/burn-cube/src/frontend/element/shared_memory.rs b/crates/burn-cube/src/frontend/element/shared_memory.rs index 9c9f40d5d2..92df54b00d 100644 --- a/crates/burn-cube/src/frontend/element/shared_memory.rs +++ b/crates/burn-cube/src/frontend/element/shared_memory.rs @@ -5,7 +5,7 @@ use crate::{ ir::Item, }; -use super::{ExpandElement, Init}; +use super::{ExpandElement, Init, UInt}; #[derive(Clone, Copy)] pub struct SharedMemory { @@ -49,4 +49,24 @@ impl SharedMemory { }; context.create_shared(Item::new(T::as_elem()), size) } + + pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { + SharedMemory { _val: PhantomData } + } + + pub fn vectorized_expand( + context: &mut CubeContext, + size: S, + vectorization_factor: UInt, + ) -> ::ExpandType { + let size = size.value(); + let size = match size { + crate::ir::Variable::ConstantScalar(val, _) => val as u32, + _ => panic!("Shared memory need constant initialization value"), + }; + context.create_shared( + Item::vectorized(T::as_elem(), vectorization_factor.val as u8), + size, + ) + } } diff --git a/crates/burn-cube/src/frontend/operation/binary.rs b/crates/burn-cube/src/frontend/operation/binary.rs index 8e743a4011..136ac60a2d 100644 --- a/crates/burn-cube/src/frontend/operation/binary.rs +++ b/crates/burn-cube/src/frontend/operation/binary.rs @@ -51,8 +51,8 @@ pub mod sub { impl core::ops::Sub for $type { type Output = Self; - fn sub(self, _rhs: Self) -> Self::Output { - unexpanded!() + fn sub(self, rhs: Self) -> Self::Output { + (self.val - rhs.val).into() } } }; @@ -83,8 +83,8 @@ pub mod mul { impl core::ops::Mul for $type { type Output = Self; - fn mul(self, 
_rhs: Self) -> Self::Output { - unexpanded!() + fn mul(self, rhs: Self) -> Self::Output { + (self.val * rhs.val).into() } } }; @@ -115,8 +115,8 @@ pub mod div { impl core::ops::Div for $type { type Output = Self; - fn div(self, _rhs: Self) -> Self::Output { - unexpanded!() + fn div(self, rhs: Self) -> Self::Output { + (self.val / rhs.val).into() } } }; diff --git a/crates/burn-cube/tests/error/if_else_if.rs b/crates/burn-cube/tests/error/if_else_if.rs deleted file mode 100644 index 01042252de..0000000000 --- a/crates/burn-cube/tests/error/if_else_if.rs +++ /dev/null @@ -1,10 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -fn range(x: UInt, y: UInt) { - if x == y { - } else if x != y { - } -} - -fn main() {} diff --git a/crates/burn-cube/tests/error/if_else_if.stderr b/crates/burn-cube/tests/error/if_else_if.stderr deleted file mode 100644 index 338ab405e1..0000000000 --- a/crates/burn-cube/tests/error/if_else_if.stderr +++ /dev/null @@ -1,7 +0,0 @@ -error: Unsupported: only `else` block is allowed after an `if` statement. 
- --> tests/error/if_else_if.rs:6:12 - | -6 | } else if x != y { - | ____________^ -7 | | } - | |_____^ diff --git a/crates/burn-cube/tests/frontend/array.rs b/crates/burn-cube/tests/frontend/array.rs index c1bf3068f3..093b2e7c81 100644 --- a/crates/burn-cube/tests/frontend/array.rs +++ b/crates/burn-cube/tests/frontend/array.rs @@ -1,5 +1,27 @@ use burn_cube::prelude::*; +#[cube] +fn array_read_write(array_size: Comptime) { + let mut array = Array::::new(array_size); + array[0] = T::from_int(3); + let _ = array[0]; +} + +#[cube] +fn array_to_vectorized_variable() -> T { + let mut array = Array::::new(2); + array[0] = T::from_int(0); + array[1] = T::from_int(1); + array.to_vectorized(Comptime::new(UInt::new(2))) +} + +#[cube] +fn array_of_one_to_vectorized_variable() -> T { + let mut array = Array::::new(1); + array[0] = T::from_int(3); + array.to_vectorized(Comptime::new(UInt::new(1))) +} + #[cube] fn array_add_assign_simple(array: &mut Array) { array[UInt::new(1)] += UInt::new(1); @@ -17,6 +39,19 @@ mod tests { ir::{Elem, Item, Variable}, }; + type ElemType = F32; + + #[test] + fn cube_support_array() { + let mut context = CubeContext::root(); + + array_read_write_expand::(&mut context, 512); + assert_eq!( + format!("{:?}", context.into_scope().operations), + inline_macro_ref_read_write() + ) + } + #[test] fn array_add_assign() { let mut context = CubeContext::root(); @@ -31,6 +66,48 @@ mod tests { ); } + #[test] + fn cube_array_to_vectorized() { + let mut context = CubeContext::root(); + + array_to_vectorized_variable_expand::(&mut context); + assert_eq!( + format!("{:?}", context.into_scope().operations), + inline_macro_ref_to_vectorized() + ); + } + + #[test] + fn cube_array_of_one_to_vectorized() { + let mut context = CubeContext::root(); + + array_of_one_to_vectorized_variable_expand::(&mut context); + assert_eq!( + format!("{:?}", context.into_scope().operations), + inline_macro_ref_one_to_vectorized() + ); + } + + fn inline_macro_ref_read_write() -> String 
{ + let context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + + let mut scope = context.into_scope(); + let var = scope.create_local(item); + let pos: Variable = 0u32.into(); + + // Create + let array = scope.create_local_array(item, 512); + + // Write + cpa!(scope, array[pos] = 3.0_f32); + + // Read + cpa!(scope, var = array[pos]); + + format!("{:?}", scope.operations) + } + #[test] fn array_add_assign_expr() { let mut context = CubeContext::root(); @@ -62,6 +139,46 @@ mod tests { format!("{:?}", scope.operations) } + fn inline_macro_ref_to_vectorized() -> String { + let context = CubeContext::root(); + let scalar_item = Item::new(ElemType::as_elem()); + let vectorized_item = Item::vectorized(ElemType::as_elem(), 2); + + let mut scope = context.into_scope(); + let pos0: Variable = 0u32.into(); + let pos1: Variable = 1u32.into(); + let array = scope.create_local_array(scalar_item, 2); + cpa!(scope, array[pos0] = 0.0_f32); + cpa!(scope, array[pos1] = 1.0_f32); + + let vectorized_var = scope.create_local(vectorized_item); + let tmp = scope.create_local(scalar_item); + cpa!(scope, tmp = array[pos0]); + cpa!(scope, vectorized_var[pos0] = tmp); + cpa!(scope, tmp = array[pos1]); + cpa!(scope, vectorized_var[pos1] = tmp); + + format!("{:?}", scope.operations) + } + + fn inline_macro_ref_one_to_vectorized() -> String { + let context = CubeContext::root(); + let scalar_item = Item::new(ElemType::as_elem()); + let unvectorized_item = Item::new(ElemType::as_elem()); + + let mut scope = context.into_scope(); + let pos0: Variable = 0u32.into(); + let array = scope.create_local_array(scalar_item, 1); + cpa!(scope, array[pos0] = 3.0_f32); + + let unvectorized_var = scope.create_local(unvectorized_item); + let tmp = scope.create_local(scalar_item); + cpa!(scope, tmp = array[pos0]); + cpa!(scope, unvectorized_var = tmp); + + format!("{:?}", scope.operations) + } + fn inline_macro_array_add_assign_expr() -> String { let context = CubeContext::root(); diff 
--git a/crates/burn-cube/tests/frontend/comptime.rs b/crates/burn-cube/tests/frontend/comptime.rs index d4ef987876..e0790a75ec 100644 --- a/crates/burn-cube/tests/frontend/comptime.rs +++ b/crates/burn-cube/tests/frontend/comptime.rs @@ -21,6 +21,55 @@ pub fn comptime_if_else(lhs: T, cond: Comptime) { } } +#[cube] +#[allow(clippy::collapsible_else_if)] +pub fn comptime_else_then_if(lhs: T, cond1: Comptime, cond2: Comptime) { + if Comptime::get(cond1) { + let _ = lhs + T::from_int(4); + } else { + if Comptime::get(cond2) { + let _ = lhs + T::from_int(5); + } else { + let _ = lhs - T::from_int(6); + } + } +} + +#[cube] +pub fn comptime_elsif(lhs: T, cond1: Comptime, cond2: Comptime) { + if Comptime::get(cond1) { + let _ = lhs + T::from_int(4); + } else if Comptime::get(cond2) { + let _ = lhs + T::from_int(5); + } else { + let _ = lhs - T::from_int(6); + } +} + +#[cube] +pub fn comptime_elsif_with_runtime1(lhs: T, comptime_cond: Comptime) { + let runtime_cond = lhs >= T::from_int(2); + if Comptime::get(comptime_cond) { + let _ = lhs + T::from_int(4); + } else if runtime_cond { + let _ = lhs + T::from_int(5); + } else { + let _ = lhs - T::from_int(6); + } +} + +#[cube] +pub fn comptime_elsif_with_runtime2(lhs: T, comptime_cond: Comptime) { + let runtime_cond = lhs >= T::from_int(2); + if runtime_cond { + let _ = lhs + T::from_int(4); + } else if Comptime::get(comptime_cond) { + let _ = lhs + T::from_int(5); + } else { + let _ = lhs - T::from_int(6); + } +} + #[cube] pub fn comptime_if_expr(lhs: T, x: Comptime, y: Comptime) { let y2 = x + y; @@ -62,7 +111,7 @@ mod tests { use burn_cube::{ cpa, frontend::{CubeContext, CubePrimitive, F32}, - ir::{Item, Variable}, + ir::{Elem, Item, Variable}, }; type ElemType = F32; @@ -96,6 +145,7 @@ mod tests { inline_macro_ref_comptime(true) ); } + #[test] fn cube_comptime_else_test() { let mut context = CubeContext::root(); @@ -111,6 +161,58 @@ mod tests { ); } + #[test] + fn cube_comptime_elsif_test() { + for cond1 in [false, true] { 
+ for cond2 in [false, true] { + let mut context1 = CubeContext::root(); + let lhs = context1.create_local(Item::new(ElemType::as_elem())); + comptime_else_then_if_expand::(&mut context1, lhs, cond1, cond2); + let scope1 = context1.into_scope(); + + let mut context2 = CubeContext::root(); + let lhs = context2.create_local(Item::new(ElemType::as_elem())); + comptime_elsif_expand::(&mut context2, lhs, cond1, cond2); + let scope2 = context2.into_scope(); + + assert_eq!( + format!("{:?}", scope1.operations), + format!("{:?}", scope2.operations), + ); + } + } + } + + #[test] + fn cube_comptime_elsif_runtime1_test() { + for cond in [false, true] { + let mut context = CubeContext::root(); + let lhs = context.create_local(Item::new(ElemType::as_elem())); + comptime_elsif_with_runtime1_expand::(&mut context, lhs, cond); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_elsif_runtime1(cond) + ); + } + } + + #[test] + fn cube_comptime_elsif_runtime2_test() { + for cond in [false, true] { + let mut context = CubeContext::root(); + let lhs = context.create_local(Item::new(ElemType::as_elem())); + comptime_elsif_with_runtime2_expand::(&mut context, lhs, cond); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_elsif_runtime2(cond) + ); + } + } + #[test] fn cube_comptime_map_bool_test() { let mut context1 = CubeContext::root(); @@ -170,4 +272,52 @@ mod tests { format!("{:?}", scope.operations) } + + fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let runtime_cond = scope.create_local(Item::new(Elem::Bool)); + let y = scope.create_local(item); + cpa!(scope, runtime_cond = x >= 2.0f32); + + if comptime_cond { + cpa!(scope, y = x + 4.0f32); + } else { + 
cpa!(&mut scope, if(runtime_cond).then(|scope| { + cpa!(scope, y = x + 5.0f32); + }).else(|scope| { + cpa!(scope, y = x - 6.0f32); + })); + }; + + format!("{:?}", scope.operations) + } + + fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let runtime_cond = scope.create_local(Item::new(Elem::Bool)); + let y = scope.create_local(item); + cpa!(scope, runtime_cond = x >= 2.0f32); + + cpa!(&mut scope, if(runtime_cond).then(|scope| { + cpa!(scope, y = x + 4.0f32); + }).else(|scope| { + if comptime_cond { + cpa!(scope, y = x + 5.0f32); + } else { + cpa!(scope, y = x - 6.0f32); + } + })); + + format!("{:?}", scope.operations) + } } diff --git a/crates/burn-cube/tests/frontend/for_loop.rs b/crates/burn-cube/tests/frontend/for_loop.rs index d102b11de1..75e888109a 100644 --- a/crates/burn-cube/tests/frontend/for_loop.rs +++ b/crates/burn-cube/tests/frontend/for_loop.rs @@ -17,10 +17,7 @@ pub fn for_loop(mut lhs: Array, rhs: F, end: UInt, unroll: Comptime } mod tests { - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; + use burn_cube::{cpa, ir::Item}; use super::*; @@ -29,7 +26,7 @@ mod tests { let mut context = CubeContext::root(); let unroll = true; - let lhs = context.create_local(Item::new(ElemType::as_elem())); + let lhs = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); let rhs = context.create_local(Item::new(ElemType::as_elem())); let end = 4u32.into(); @@ -44,7 +41,7 @@ mod tests { let mut context = CubeContext::root(); let unroll = false; - let lhs = context.create_local(Item::new(ElemType::as_elem())); + let lhs = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); let rhs = context.create_local(Item::new(ElemType::as_elem())); let end = 4u32.into(); @@ -55,15 +52,13 @@ mod tests { } fn inline_macro_ref(unroll: bool) 
-> String { - let mut context = CubeContext::root(); + let context = CubeContext::root(); let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local(item); - let rhs = context.create_local(item); - let lhs: Variable = lhs.into(); - let rhs: Variable = rhs.into(); - let end = 4u32; let mut scope = context.into_scope(); + let lhs = scope.create_local_array(item, 4u32); + let rhs = scope.create_local(item); + let end = 4u32; // Kernel let tmp1 = scope.create_local(item); diff --git a/crates/burn-cube/tests/frontend/if.rs b/crates/burn-cube/tests/frontend/if.rs index 0703786416..200113d56c 100644 --- a/crates/burn-cube/tests/frontend/if.rs +++ b/crates/burn-cube/tests/frontend/if.rs @@ -24,6 +24,17 @@ pub fn if_then_else(lhs: F) { } } +#[cube] +pub fn elsif(lhs: F) { + if lhs < F::new(0.) { + let _ = lhs + F::new(2.); + } else if lhs > F::new(0.) { + let _ = lhs + F::new(1.); + } else { + let _ = lhs + F::new(0.); + } +} + mod tests { use burn_cube::{ cpa, @@ -62,6 +73,18 @@ mod tests { ); } + #[test] + fn cube_elsif_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::new(ElemType::as_elem())); + + elsif_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif()); + } + fn inline_macro_ref_if() -> String { let mut context = CubeContext::root(); let item = Item::new(ElemType::as_elem()); @@ -99,4 +122,30 @@ mod tests { format!("{:?}", scope.operations) } + + fn inline_macro_ref_elsif() -> String { + let mut context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond1 = scope.create_local(Item::new(Elem::Bool)); + let lhs: Variable = lhs.into(); + let y = scope.create_local(item); + let cond2 = scope.create_local(Item::new(Elem::Bool)); + + cpa!(scope, cond1 = lhs < 0f32); + cpa!(&mut scope, if(cond1).then(|scope| { + cpa!(scope, 
y = lhs + 2.0f32); + }).else(|mut scope|{ + cpa!(scope, cond2 = lhs > 0f32); + cpa!(&mut scope, if(cond2).then(|scope| { + cpa!(scope, y = lhs + 1.0f32); + }).else(|scope|{ + cpa!(scope, y = lhs + 0.0f32); + })); + })); + + format!("{:?}", scope.operations) + } } diff --git a/crates/burn-cube/tests/frontend/trait.rs b/crates/burn-cube/tests/frontend/trait.rs index 2bbd064c2a..65fd55df54 100644 --- a/crates/burn-cube/tests/frontend/trait.rs +++ b/crates/burn-cube/tests/frontend/trait.rs @@ -3,13 +3,9 @@ use burn_cube::prelude::*; /// Traits used in Cube kernels must expose an _expand variant /// for all their methods. However, one does not need to provide its /// implementation, see examples below. +#[cube] trait Strategy { fn operation(input_1: T, input_2: T) -> T; - fn operation_expand( - context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType; } struct AddStrategy; @@ -21,40 +17,19 @@ fn add_strategy_operation(input_1: T, input_2: T) -> T { input_1 + input_2 } +#[cube] impl Strategy for AddStrategy { - /// Here we link the trait's method to the cube function fn operation(input_1: T, input_2: T) -> T { - add_strategy_operation(input_1, input_2) - } - - /// Here we link the trait's expanded method to the cube expanded function - fn operation_expand( - context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType { - add_strategy_operation_expand::(context, input_1, input_2) + add_strategy_operation::(input_1, input_2) } } struct SubStrategy; #[cube] -fn sub_strategy_operation(input_1: T, input_2: T) -> T { - input_1 - input_2 -} - impl Strategy for SubStrategy { fn operation(input_1: T, input_2: T) -> T { - sub_strategy_operation(input_1, input_2) - } - - fn operation_expand( - context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType { - sub_strategy_operation_expand::(context, input_1, input_2) + input_1 - input_2 } } diff --git 
a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 0e0257186e..4b914636cd 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,88 +1,12 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; -use burn_cube::{prelude::*, Compiler}; +use burn_cube::prelude::*; use burn_tensor::Shape; -use std::cmp::{max, min}; use super::{ - init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded, + config::Tiling2dConfig, init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, + matmul_tiling_2d_cube, matmul_tiling_2d_padded, }; -#[derive(Debug, Clone)] -/// Tiling 2D parameters -pub struct Tiling2dConfig { - /// Number of invocations in x - pub grid_x: usize, - /// Number of invocations in y - pub grid_y: usize, - /// Block size along dimension of lhs - pub block_size_m: usize, - /// Block size along common dimension - pub block_size_k: usize, - /// Block size along dimension of rhs - pub block_size_n: usize, - /// Tile size along dimension of lhs - pub tile_size_m: usize, - /// Tile size along dimension of rhs - pub tile_size_n: usize, - /// Loop unrolling - pub unroll: bool, -} - -impl Tiling2dConfig { - #[allow(unused, clippy::too_many_arguments)] - fn new( - grid_x: usize, - grid_y: usize, - block_size_m: usize, - block_size_k: usize, - block_size_n: usize, - tile_size_m: usize, - tile_size_n: usize, - unroll: bool, - ) -> Self { - assert!(grid_x == f32::ceil(block_size_m as f32 / tile_size_m as f32) as usize); - assert!(grid_y == f32::ceil(block_size_n as f32 / tile_size_n as f32) as usize); - assert!( - block_size_k <= min(block_size_m, block_size_n), - "Not enough invocations to fill shared memory" - ); - assert!( - block_size_k * max(block_size_m, block_size_n) - <= ::max_shared_memory_size(), - "Shared memory limit will be busted. 
" - ); - assert!( - block_size_m % tile_size_m == 0 && block_size_n % tile_size_n == 0, - "Tile size must divide block size in m and n dimensions" - ); - Self { - grid_x, - grid_y, - block_size_m, - block_size_k, - block_size_n, - tile_size_m, - tile_size_n, - unroll, - } - } -} - -impl Default for Tiling2dConfig { - fn default() -> Self { - Self { - grid_x: 16, - grid_y: 16, - block_size_m: 64, - block_size_k: 32, - block_size_n: 64, - tile_size_m: 4, - tile_size_n: 4, - unroll: false, - } - } -} - /// The strategy to be used when launching a matmul kernel. pub enum MatmulStrategy { /// A simple kernel will be used with memory coalescing optimization. @@ -99,6 +23,8 @@ pub enum MatmulStrategy { #[cfg(feature = "autotune")] /// Using autotune to chose the best kernel based on runtime information. Autotune, + /// A tiling 2d kernel with everything vectorized, and comptime bound checks + Tiling2dCube(Tiling2dConfig), } #[allow(clippy::derivable_impls)] // Necessary otherwise the feature flags dont' work. @@ -109,13 +35,6 @@ impl Default for MatmulStrategy { } } -#[cfg(not(feature = "autotune"))] -impl Default for MatmulStrategy { - fn default() -> Self { - MatmulStrategy::Tiling2d(Tiling2dConfig::default()) - } -} - /// Launch a matmul kernel using the given strategy. 
pub fn matmul( lhs: JitTensor, @@ -135,6 +54,10 @@ pub fn matmul( let out = init_matmul_output(&lhs, &rhs); matmul_tiling_2d_padded(lhs, rhs, out, config) } + MatmulStrategy::Tiling2dCube(config) => { + let out = init_matmul_output(&lhs, &rhs); + matmul_tiling_2d_cube(lhs, rhs, out, config) + } #[cfg(feature = "autotune")] MatmulStrategy::Autotune => matmul_autotune(lhs, rhs), } @@ -159,20 +82,3 @@ pub(crate) fn simple_cube_count( CubeCount::Static(cubes_x, cubes_y, num_iter as u32) } - -pub(crate) fn tiling2d_launch_options( - output_shape: &Shape, - config: Tiling2dConfig, -) -> CubeCount { - let num_rows = output_shape.dims[D - 2]; - let num_cols = output_shape.dims[D - 1]; - - let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32; - let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32; - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output_shape.dims[i]; - } - - CubeCount::Static(cubes_x, cubes_y, num_iter as u32) -} diff --git a/crates/burn-jit/src/kernel/matmul/config.rs b/crates/burn-jit/src/kernel/matmul/config.rs new file mode 100644 index 0000000000..22b15f5f1f --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/config.rs @@ -0,0 +1,130 @@ +use burn_cube::{ + compute::CubeCount, + frontend::{CubeContext, Init, UInt}, + ir::CubeDim, +}; +use burn_tensor::Shape; + +use crate::JitRuntime; + +#[derive(Debug, Clone)] +/// Tiling 2D parameters +pub struct Tiling2dConfig { + /// Block size along dimension of lhs + pub block_size_m: usize, + /// Block size along common dimension + pub block_size_k: usize, + /// Block size along dimension of rhs + pub block_size_n: usize, + /// Tile size and shared memory vectorization + pub tile_size: usize, + /// Loop unrolling + pub unroll: bool, +} + +impl Default for Tiling2dConfig { + fn default() -> Self { + Self { + block_size_m: 64, + block_size_k: 32, + block_size_n: 64, + tile_size: 4, + unroll: false, + } + } +} + +impl Init for CubeTiling2dConfig { + fn 
init(self, _context: &mut CubeContext) -> Self { + self + } +} + +#[derive(Debug, Clone, Copy)] +/// Tiling 2D parameters +pub struct CubeTiling2dConfig { + /// Block size along dimension of lhs + pub block_size_m: UInt, + /// Block size along common dimension + pub block_size_k: UInt, + /// Block size along dimension of rhs + pub block_size_n: UInt, + /// Loop unrolling for inner compute loop. Probably slower + pub unroll_compute: bool, + /// Loop unrolling for all loops related to vectorization/tile size. Probably faster + pub unroll_tile: bool, + /// Bounds must be checked on lhs dimension + pub check_m_bounds: bool, + /// Bounds must be checked on common dimension + pub check_k_bounds: bool, + /// Bounds must be checked on rhs dimension + pub check_n_bounds: bool, + /// Tile size. Should correspond to vectorization of inputs/outputs/shared memory + pub tile_size: UInt, + /// Lhs is transposed in global memory + pub lhs_transposed: bool, + /// Rhs is transposed in global memory + pub rhs_transposed: bool, +} + +impl CubeTiling2dConfig { + pub fn new( + config: &Tiling2dConfig, + m: usize, + k: usize, + n: usize, + lhs_transposed: bool, + rhs_transposed: bool, + ) -> Self { + assert!( + config.block_size_k <= config.block_size_m + && config.block_size_k <= config.block_size_n, + "Larger block size in k than m or n results in unfilled shared memory." + ); + assert!( + config.block_size_m % config.tile_size == 0 + && config.block_size_k % config.tile_size == 0 + && config.block_size_n % config.tile_size == 0, + "Tiling 2d algorithm assumes tile size divides block size perfectly. 
" + ); + + CubeTiling2dConfig { + block_size_m: UInt::new(config.block_size_m as u32), + block_size_k: UInt::new(config.block_size_k as u32), + block_size_n: UInt::new(config.block_size_n as u32), + unroll_compute: config.unroll, + unroll_tile: true, + check_m_bounds: m % config.block_size_m != 0, + check_k_bounds: k % config.block_size_k != 0, + check_n_bounds: n % config.block_size_n != 0, + tile_size: UInt::new(config.tile_size as u32), + lhs_transposed, + rhs_transposed, + } + } +} + +pub fn tiling2d_cube_count( + output_shape: &Shape, + config: &Tiling2dConfig, +) -> CubeCount { + let num_rows = output_shape.dims[D - 2]; + let num_cols = output_shape.dims[D - 1]; + // Integer ceiling division stays exact for every dimension size; f32 cannot represent all values above 2^24 + let cubes_x = num_rows.div_ceil(config.block_size_m) as u32; + let cubes_y = num_cols.div_ceil(config.block_size_n) as u32; + let mut num_iter = 1; + for i in 0..D - 2 { + num_iter *= output_shape.dims[i]; + } + + CubeCount::Static(cubes_x, cubes_y, num_iter as u32) +} + +pub fn tiling2d_cube_dim(config: &Tiling2dConfig) -> CubeDim { + CubeDim::new( + (config.block_size_m / config.tile_size) as u32, + (config.block_size_n / config.tile_size) as u32, + 1, + ) +} diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index 324827ea41..e1e4e30daa 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -1,6 +1,12 @@ mod base; +mod config; mod simple; mod tiling2d; +#[cfg(not(feature = "export_tests"))] +mod tiling2d_cube; +#[cfg(feature = "export_tests")] +/// Tiling 2d cube functions +pub mod tiling2d_cube; mod tiling2d_shader; mod tune; @@ -19,4 +25,6 @@ pub mod padding; #[cfg(not(feature = "export_tests"))] mod padding; +pub use config::Tiling2dConfig; pub use tiling2d::*; +pub use tiling2d_cube::*; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index dea73cd020..cb284b8fe0 100644 --- 
a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -1,6 +1,6 @@ use burn_cube::{ frontend::TensorHandle, - ir::{BinaryOperator, CubeDim, Elem, FloatKind, KernelDefinition, Scope, Variable, Visibility}, + ir::{BinaryOperator, Elem, FloatKind, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; @@ -8,18 +8,19 @@ use burn_tensor::{Element, Shape}; use crate::{ element::JitElement, - kernel::{into_contiguous, Kernel}, - tensor::JitTensor, + kernel::{into_contiguous, matmul::config::tiling2d_cube_dim, Kernel}, + tensor::{JitTensor, MatrixLayout}, JitRuntime, }; use std::marker::PhantomData; use super::{ + config::tiling2d_cube_count, padding::{crop, pad_round, PaddingOutput}, - shape_out, tiling2d_launch_options, + shape_out, tiling2d_shader::MatmulTiling2dShader, - Tiling2dConfig, }; +use crate::kernel::matmul::config::Tiling2dConfig; #[derive(new, Debug)] struct MatmulTiling2dEagerKernel { @@ -68,11 +69,7 @@ impl Kernel for MatmulTiling2dEagerKernel { scope, }; - let settings = KernelSettings::default().cube_dim(CubeDim::new( - self.config.grid_x as u32, - self.config.grid_y as u32, - 1, - )); + let settings = KernelSettings::default().cube_dim(tiling2d_cube_dim(&self.config)); KernelIntegrator::new(info).integrate(settings) } @@ -99,14 +96,16 @@ pub fn matmul_tiling_2d( let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), bounds_check_required); let client = lhs.client.clone(); - let lhs = match lhs.batch_swapped_with_row_col() { - true => into_contiguous(lhs), - false => lhs, - }; - let rhs = match rhs.batch_swapped_with_row_col() { - true => into_contiguous(rhs), - false => rhs, + let check_layout = |tensor: JitTensor| match tensor.matrix_layout() { + MatrixLayout::Contiguous => (tensor, false), + MatrixLayout::MildlyPermuted { + transposed, + batch_swap: _, + } => (tensor, transposed), + 
MatrixLayout::HighlyPermuted => (into_contiguous(tensor), false), }; + let (lhs, _lhs_transposed) = check_layout(lhs); + let (rhs, _rhs_transposed) = check_layout(rhs); Execution::start(kernel, client) .inputs(&[ @@ -118,8 +117,8 @@ pub fn matmul_tiling_2d( &out.strides, &out.shape.dims, )]) - .execute(CubeCountSettings::Custom(tiling2d_launch_options::( - &out.shape, config, + .execute(CubeCountSettings::Custom(tiling2d_cube_count::( + &out.shape, &config, ))); out @@ -141,14 +140,18 @@ pub fn matmul_tiling_2d_padded(lhs, config.block_size_m, config.block_size_k); let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + PaddingOutput::Unchanged(tensor) + if tensor.matrix_layout() == MatrixLayout::HighlyPermuted => + { into_contiguous(tensor) } _ => round_lhs.into_tensor(), }; let round_rhs = pad_round::(rhs, config.block_size_k, config.block_size_n); let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => { + PaddingOutput::Unchanged(tensor) + if tensor.matrix_layout() == MatrixLayout::HighlyPermuted => + { into_contiguous(tensor) } _ => round_rhs.into_tensor(), @@ -175,9 +178,9 @@ pub fn matmul_tiling_2d_padded( + .execute(CubeCountSettings::Custom(tiling2d_cube_count::( &rounded_output.shape, - config, + &config, ))); crop(rounded_output, out) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs new file mode 100644 index 0000000000..6268da11a5 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs @@ -0,0 +1,156 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +use super::block_loop::{block_loop, block_loop_expand}; + +#[cube(launch)] +#[allow(unused_mut)] +fn tiling2d_cube( + lhs: &Tensor, + rhs: &Tensor, + out: &mut Tensor, + config: Comptime, +) { + let dims = get_dims::(lhs, rhs); + let coordinates = 
calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config); + let offsets = calculate_batch_offsets::(lhs, rhs, out, CUBE_POS_Z); + let shared_memories = make_shared_memories::(config); + block_loop::( + lhs, + rhs, + out, + coordinates, + offsets, + shared_memories, + config, + dims, + ); +} + +#[derive(CubeType, Copy, Clone)] +/// Information available at runtime only +/// Strides assume contiguous +pub(crate) struct Dimensions { + pub m: UInt, + pub k: UInt, + pub n: UInt, +} + +#[derive(CubeType, Copy, Clone)] +pub(crate) struct SharedMemories { + pub lhs: SharedMemory, + pub rhs: SharedMemory, +} + +#[derive(CubeType, Copy, Clone)] +/// Number of elements in previous batches +/// Not divided by vectorization factor +pub(crate) struct BatchOffsets { + pub lhs: UInt, + pub rhs: UInt, + pub out: UInt, +} + +#[derive(CubeType, Copy, Clone)] +pub(crate) struct Coordinates { + pub unit_row: UInt, + pub unit_col: UInt, + pub skip_row: UInt, + pub skip_col: UInt, +} + +#[cube] +fn get_dims(lhs: &Tensor, rhs: &Tensor) -> Dimensions { + let rank = lhs.rank(); + let first_dim = rank - UInt::new(2); + let second_dim = rank - UInt::new(1); + let m = lhs.shape(first_dim); + let k = lhs.shape(second_dim); + let n = rhs.shape(second_dim); + + Dimensions { m, k, n } +} + +#[cube] +fn calculate_coordinates( + cube_pos_x: UInt, + cube_pos_y: UInt, + unit_pos: UInt, + config: Comptime, +) -> Coordinates { + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let block_size_n = Comptime::map(config, |c| c.block_size_n); + let tile_size = Comptime::map(config, |c| c.tile_size); + + let n_units_per_row = ((Comptime::runtime(block_size_n) - UInt::new(1)) + / Comptime::runtime(tile_size)) + + UInt::new(1); + + // Cube offset + let skip_row = cube_pos_x * Comptime::runtime(block_size_m); + let skip_col = cube_pos_y * Comptime::runtime(block_size_n); + + // Position of the first element of the unit, relative to the cube + let unit_row = (unit_pos / n_units_per_row) * 
Comptime::runtime(tile_size); + let unit_col = (unit_pos % n_units_per_row) * Comptime::runtime(tile_size); + + Coordinates { + unit_row, + unit_col, + skip_row, + skip_col, + } +} + +#[cube] +#[allow(unused_mut)] +fn calculate_batch_offsets( + lhs: &Tensor, + rhs: &Tensor, + out: &Tensor, + batch_number: UInt, +) -> BatchOffsets { + let rank = out.rank(); + + let dim_m = lhs.shape(rank - UInt::new(2)); + let dim_n = rhs.shape(rank - UInt::new(1)); + + // Batch offset for output + let mut offset_out = dim_m * dim_n * batch_number; + let mut offset_lhs = UInt::new(0); + let mut offset_rhs = UInt::new(0); + + // Batch offset for lhs, rhs + for b in range(0u32, rank - UInt::new(2), Comptime::new(false)) { + let tmp = offset_out / out.stride(b); + offset_lhs += tmp % lhs.shape(b) * lhs.stride(b); + offset_rhs += tmp % rhs.shape(b) * rhs.stride(b); + } + + BatchOffsets { + lhs: offset_lhs, + rhs: offset_rhs, + out: offset_out, + } +} + +#[cube] +fn make_shared_memories(config: Comptime) -> SharedMemories { + let tile_size = Comptime::map(config, |c| c.tile_size); + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_n = Comptime::map(config, |c| c.block_size_n); + + let lhs = SharedMemory::::vectorized( + Comptime::get(block_size_k * block_size_m / tile_size), + Comptime::get(tile_size), + ); + + let rhs = SharedMemory::::vectorized( + Comptime::get(block_size_k * block_size_n / tile_size), + Comptime::get(tile_size), + ); + + SharedMemories { lhs, rhs } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs new file mode 100644 index 0000000000..123f991fe1 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs @@ -0,0 +1,82 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +use super::{ + base::{BatchOffsets, Coordinates, 
Dimensions, SharedMemories}, + compute_loop::{compute_loop, compute_loop_expand}, + load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand}, + tile::{loader::TileLoader, writer::TileWriter}, + write_output::{write_to_output, write_to_output_expand}, +}; + +#[cube] +pub(crate) fn block_loop( + lhs: &Tensor, + rhs: &Tensor, + out: &mut Tensor, + coordinates: Coordinates, + offsets: BatchOffsets, + shared: SharedMemories, + config: Comptime, + dims: Dimensions, +) { + let block_size_k = Comptime::map(config, |c| c.block_size_k); + let mut results = init_results::(config); + + let n_loops = calculate_n_loops::(lhs.shape(lhs.rank() - UInt::new(1)), config); + + for k in range(0u32, n_loops, Comptime::new(false)) { + let k = k * Comptime::runtime(block_size_k); + + load_to_shared_memories::>( + lhs, + rhs, + coordinates, + k, + offsets, + shared, + config, + dims, + ); + + sync_units(); + + compute_loop::(coordinates, shared.lhs, shared.rhs, &mut results, config); + + sync_units(); + } + + write_to_output::>(out, &results, coordinates, offsets.out, dims, config); +} + +#[cube] +fn init_results(config: Comptime) -> Array { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + + let mut results = Array::::new(Comptime::get(tile_size * tile_size)); + for i in range(0u32, Comptime::get(tile_size * tile_size), unroll) { + results[i] = F::new(0.); + } + + results +} + +#[cube] +#[allow(unused_assignments)] +fn calculate_n_loops(dim_k: UInt, config: Comptime) -> UInt { + let block_size_k = Comptime::map(config, |c| c.block_size_k); + let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + + let mut n_loops = UInt::new(0); // TODO support syntax let x = if ... else ... 
+ if Comptime::get(check_k_bounds) { + n_loops = UInt::cast_from(F::ceil( + F::cast_from(dim_k) / F::cast_from(Comptime::runtime(block_size_k)), + )); + } else { + n_loops = dim_k / Comptime::runtime(block_size_k); + } + + n_loops +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs new file mode 100644 index 0000000000..082421b252 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs @@ -0,0 +1,155 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +use super::{ + base::Coordinates, + outer_product::{tile_outer_product, tile_outer_product_expand}, +}; + +#[cube] +#[allow(unused_mut)] +pub(crate) fn compute_loop( + coordinates: Coordinates, + shared_lhs: SharedMemory, + shared_rhs: SharedMemory, + results: &mut Array, + config: Comptime, +) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + let block_size_n = Comptime::map(config, |c| c.block_size_n); + let unroll = Comptime::map(config, |c| c.unroll_compute); + + let unit_row = coordinates.unit_row; + let unit_col = coordinates.unit_col; + + for dot_index in range(0u32, block_size_k, unroll) { + let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m)) + / Comptime::runtime(tile_size)]; + let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n)) + / Comptime::runtime(tile_size)]; + + tile_outer_product::(register_m, register_n, results, config); + } +} + +#[cfg(feature = "export_tests")] +/// Compute loop exported tests +pub mod tests { + use crate::{ + kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::test_utils::{ + assert_equals, create_empty, make_config, range_tensor, range_tensor_transposed, + TILE_SIZE, + }, + }, + 
JitRuntime, + }; + + use super::{super::base::CoordinatesExpand, *}; + + #[cube(launch)] + fn compute_loop_test( + lhs: Tensor, + rhs: Tensor, + unit_row: UInt, + unit_col: UInt, + results: &mut Array, + config: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_n = Comptime::map(config, |c| c.block_size_n); + let sm_size_lhs = block_size_m * block_size_k / tile_size; + let sm_size_rhs = block_size_n * block_size_k / tile_size; + // Sizes mirror make_shared_memories: k * m / tile_size for lhs, k * n / tile_size for rhs + // Shared memories are not launchable, so we launch with tensor and convert to shared memory + let mut shared_lhs = + SharedMemory::::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size)); + for i in range(0u32, lhs.len(), Comptime::new(false)) { + shared_lhs[i] = lhs[i]; + } + + let mut shared_rhs = + SharedMemory::::vectorized(Comptime::get(sm_size_rhs), Comptime::get(tile_size)); + for i in range(0u32, rhs.len(), Comptime::new(false)) { + shared_rhs[i] = rhs[i]; + } + + for i in range(0u32, 16u32, Comptime::new(false)) { + results[i] = F::new(0.); + } + + let coordinates = Coordinates { + unit_row, + unit_col, + skip_row: UInt::new(0), + skip_col: UInt::new(0), + }; + + compute_loop(coordinates, shared_lhs, shared_rhs, results, config) + } + + /// Exported test + pub fn compute_loop_unit_test(device: &R::Device) { + let lhs = range_tensor::(8, 8, device); + let rhs = range_tensor::(8, 8, device); + let results = create_empty::(TILE_SIZE, TILE_SIZE, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + const SOME_DIM: usize = 12; + let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); + + compute_loop_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape.dims), + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, 
&rhs.strides, &rhs.shape.dims), + ScalarArg::new(0), + ScalarArg::new(0), + ArrayArg::new(&results, 1), + config, + ); + + let expected = &[ + 8960.0, 9184.0, 9408.0, 9632.0, 9184.0, 9416.0, 9648.0, 9880.0, 9408.0, 9648.0, 9888.0, + 10128.0, 9632.0, 9880.0, 10128.0, 10376.0, + ]; + assert_equals::(results, expected, device); + } + + /// Exported test + pub fn compute_loop_unit_offset_test(device: &R::Device) { + let lhs = range_tensor_transposed::(8, 4, device); + let rhs = range_tensor::(4, 8, device); + let results = create_empty::(TILE_SIZE, TILE_SIZE, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(4, 8, 4); + + compute_loop_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape.dims), + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), + ScalarArg::new(4), + ScalarArg::new(4), + ArrayArg::new(&results, 1), + config, + ); + + let expected = &[ + 1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0, + 1978.0, 1928.0, 2046.0, 2164.0, 2282.0, + ]; + assert_equals::(results, expected, device); + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs new file mode 100644 index 0000000000..b196ff4f4d --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs @@ -0,0 +1,98 @@ +use std::cmp::max; + +use burn_cube::{frontend::TensorArg, Compiler}; + +use crate::{ + kernel::{ + into_contiguous, + matmul::{ + config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig}, + Tiling2dConfig, + }, + }, + tensor::{JitTensor, MatrixLayout}, + FloatElement, JitRuntime, +}; + +use super::base::tiling2d_cube_launch; + +/// Matrix multiplication using tiling 2d algorithm +pub fn matmul_tiling_2d_cube( + lhs: JitTensor, + rhs: JitTensor, + out: 
JitTensor, + config: Tiling2dConfig, +) -> JitTensor { + assert!( + config.block_size_k * max(config.block_size_m, config.block_size_n) + <= ::max_shared_memory_size(), + "Shared memory limit will be busted. " + ); + + let m = lhs.shape.dims[D - 2]; + let k = lhs.shape.dims[D - 1]; + let n = rhs.shape.dims[D - 1]; + + let client = lhs.client.clone(); + + let check_layout = |tensor: JitTensor| match tensor.matrix_layout() { + MatrixLayout::Contiguous => (tensor, false), + MatrixLayout::MildlyPermuted { + transposed, + batch_swap: _, + } => (tensor, transposed), + MatrixLayout::HighlyPermuted => (into_contiguous(tensor), false), + }; + let (lhs, lhs_transposed) = check_layout(lhs); + let (rhs, rhs_transposed) = check_layout(rhs); + + let vectorization = |shape: usize| { + [4, 2] + .into_iter() + .filter(|v| shape % v == 0) + .map(|v| v as u8) + .next() + .unwrap_or(1) + }; + + let lhs_vectorization = match lhs_transposed { + true => vectorization(m), + false => 1, + }; + let rhs_vectorization = match rhs_transposed { + true => 1, + false => vectorization(n), + }; + let out_vectorization = vectorization(n); + + let cube_count = tiling2d_cube_count::(&out.shape, &config); + let cube_dim = tiling2d_cube_dim(&config); + let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); + + tiling2d_cube_launch::( + client, + cube_count, + cube_dim, + TensorArg::vectorized( + lhs_vectorization, + &lhs.handle, + &lhs.strides, + &lhs.shape.dims, + ), + TensorArg::vectorized( + rhs_vectorization, + &rhs.handle, + &rhs.strides, + &rhs.shape.dims, + ), + TensorArg::vectorized( + out_vectorization, + &out.handle, + &out.strides, + &out.shape.dims, + ), + cube_config, + ); + + out +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs new file mode 100644 index 0000000000..aae7e8804d --- /dev/null +++ 
b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs @@ -0,0 +1,723 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +use super::{ + base::{BatchOffsets, Coordinates, Dimensions, SharedMemories}, + tile::block_io::{ + base::BlockLoader, horizontal_block_check::HorizontalCheckBlockIO, + unchecked_block::UncheckedBlockIO, vertical_block_check::VerticalCheckBlockIO, + whole_block_check::WholeCheckBlockIO, + }, +}; + +#[derive(CubeType)] +#[allow(dead_code)] +pub(crate) struct LoadInfo { + pub coordinates: Coordinates, + pub k: UInt, + pub batch_offset: UInt, + pub shared_memory: SharedMemory, + pub config: Comptime, + pub dims: Dimensions, +} + +#[cube] +pub(crate) trait Loader: Sync + Send + 'static { + fn load_lhs_plain>(lhs: &Tensor, load_info: LoadInfo); + fn load_lhs_transposed>(lhs: &Tensor, load_info: LoadInfo); + fn load_rhs_plain>(rhs: &Tensor, load_info: LoadInfo); + fn load_rhs_transposed>(rhs: &Tensor, load_info: LoadInfo); +} + +#[cube] +pub(crate) fn load_to_shared_memories>( + lhs: &Tensor, + rhs: &Tensor, + coordinates: Coordinates, + k: UInt, + offsets: BatchOffsets, + shared: SharedMemories, + config: Comptime, + dims: Dimensions, +) { + let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed); + let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed); + + let lhs_load_info = LoadInfo { + coordinates, + k, + batch_offset: offsets.lhs, + shared_memory: shared.lhs, + config, + dims, + }; + let rhs_load_info = LoadInfo { + coordinates, + k, + batch_offset: offsets.rhs, + shared_memory: shared.rhs, + config, + dims, + }; + + // Lhs must be loaded as transposed. If it already is transposed in global memory, we load as plain. + if Comptime::get(lhs_transposed) { + load_lhs_plain::(lhs, lhs_load_info, config); + } else { + load_lhs_transposed::(lhs, lhs_load_info, config); + } + + // Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back. 
+ if Comptime::get(rhs_transposed) { + load_rhs_transposed::(rhs, rhs_load_info, config); + } else { + load_rhs_plain::(rhs, rhs_load_info, config); + } +} + +#[cube] +fn load_lhs_transposed>( + lhs: &Tensor, + load_info: LoadInfo, + config: Comptime, +) { + let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); + let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + + if Comptime::get(check_m_bounds) { + if Comptime::get(check_k_bounds) { + L::load_lhs_transposed::(lhs, load_info); + } else { + L::load_lhs_transposed::(lhs, load_info); + } + } else if Comptime::get(check_k_bounds) { + L::load_lhs_transposed::(lhs, load_info); + } else { + L::load_lhs_transposed::(lhs, load_info); + } +} + +#[cube] +fn load_lhs_plain>( + lhs: &Tensor, + load_info: LoadInfo, + config: Comptime, +) { + let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); + let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + + if Comptime::get(check_k_bounds) { + if Comptime::get(check_m_bounds) { + L::load_lhs_plain::(lhs, load_info); + } else { + L::load_lhs_plain::(lhs, load_info); + } + } else if Comptime::get(check_m_bounds) { + L::load_lhs_plain::(lhs, load_info); + } else { + L::load_lhs_plain::(lhs, load_info); + } +} + +#[cube] +fn load_rhs_transposed>( + rhs: &Tensor, + load_info: LoadInfo, + config: Comptime, +) { + let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); + + if Comptime::get(check_n_bounds) { + if Comptime::get(check_k_bounds) { + L::load_rhs_transposed::(rhs, load_info); + } else { + L::load_rhs_transposed::(rhs, load_info); + } + } else if Comptime::get(check_k_bounds) { + L::load_rhs_transposed::(rhs, load_info); + } else { + L::load_rhs_transposed::(rhs, load_info); + } +} + +#[cube] +fn load_rhs_plain>( + rhs: &Tensor, + load_info: LoadInfo, + config: Comptime, +) { + let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + let 
check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); + + if Comptime::get(check_k_bounds) { + if Comptime::get(check_n_bounds) { + L::load_rhs_plain::(rhs, load_info); + } else { + L::load_rhs_plain::(rhs, load_info); + } + } else if Comptime::get(check_n_bounds) { + L::load_rhs_plain::(rhs, load_info); + } else { + L::load_rhs_plain::(rhs, load_info); + } +} + +#[cfg(feature = "export_tests")] +/// Exported tests for loading to shared memory +pub mod tests { + use crate::kernel::matmul::tiling2d_cube::load_shared_memory::LoadInfoExpand; + use crate::kernel::matmul::tiling2d_cube::test_utils::{ + assert_equals, create_empty, make_config, range_tensor, TILE_SIZE, + }; + use crate::kernel::matmul::tiling2d_cube::tile::loader::TileLoader; + use crate::JitRuntime; + + use super::{ + super::base::{CoordinatesExpand, DimensionsExpand}, + *, + }; + + #[cube(launch)] + fn load_tensor_test( + tensor: &Tensor, + sm_out: &mut Array, + unit_row: UInt, + unit_col: UInt, + k: UInt, + config: Comptime, + is_lhs: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let sm_size = block_size_k * block_size_m / tile_size; + let shared_memory = + SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); + + let batch_offset = UInt::new(0); + + let coordinates = Coordinates { + unit_row, + unit_col, + skip_row: UInt::new(0), + skip_col: UInt::new(0), + }; + + if Comptime::get(is_lhs) { + let dims = Dimensions { + m: tensor.shape(tensor.rank() - UInt::new(2)), + k: tensor.shape(tensor.rank() - UInt::new(1)), + n: UInt::new(0), + }; + let info = LoadInfo { + coordinates, + k, + batch_offset, + shared_memory, + config, + dims, + }; + + load_lhs_transposed::>(tensor, info, config); + } else { + let dims = Dimensions { + m: UInt::new(0), + k: tensor.shape(tensor.rank() - UInt::new(2)), + n: 
tensor.shape(tensor.rank() - UInt::new(1)), + }; + let info = LoadInfo { + coordinates, + k, + batch_offset, + shared_memory, + config, + dims, + }; + + load_rhs_plain::>(tensor, info, config); + } + + for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + sm_out[i] = shared_memory[i]; + } + } + + #[cube(launch)] + fn load_tensor_permuted_test( + tensor: &Tensor, + sm_out: &mut Array, + unit_row: UInt, + unit_col: UInt, + k: UInt, + config: Comptime, + is_lhs: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let sm_size = block_size_k * block_size_m / tile_size; + let shared_memory = + SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); + + let batch_offset = UInt::new(0); + + let coordinates = Coordinates { + unit_row, + unit_col, + skip_row: UInt::new(0), + skip_col: UInt::new(0), + }; + + if Comptime::get(is_lhs) { + // Permuted + let dims = Dimensions { + m: tensor.shape(tensor.rank() - UInt::new(1)), + k: tensor.shape(tensor.rank() - UInt::new(2)), + n: UInt::new(0), + }; + let info = LoadInfo { + coordinates, + k, + batch_offset, + shared_memory, + config, + dims, + }; + + load_lhs_plain::>(tensor, info, config); + } else { + // Permuted + let dims = Dimensions { + m: UInt::new(0), + k: tensor.shape(tensor.rank() - UInt::new(1)), + n: tensor.shape(tensor.rank() - UInt::new(2)), + }; + let info = LoadInfo { + coordinates, + k, + batch_offset, + shared_memory, + config, + dims, + }; + + load_rhs_transposed::>(tensor, info, config); + } + + for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + sm_out[i] = shared_memory[i]; + } + } + + #[cube(launch)] + fn load_tensor_multiple_tiles_test( + tensor: &Tensor, + sm_out: &mut Array, + k: UInt, + config: Comptime, + is_lhs: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let 
block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_m = Comptime::map(config, |c| c.block_size_m); + let sm_size = block_size_k * block_size_m / tile_size; + let shared_memory = + SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); + + let unit_row = UInt::new(4) * UNIT_POS_X; + let unit_col = UInt::new(4) * UNIT_POS_Y; + let batch_offset = UInt::new(0); + + let coordinates = Coordinates { + unit_row, + unit_col, + skip_row: UInt::new(0), + skip_col: UInt::new(0), + }; + + if Comptime::get(is_lhs) { + let dims = Dimensions { + m: tensor.shape(tensor.rank() - UInt::new(2)), + k: tensor.shape(tensor.rank() - UInt::new(1)), + n: UInt::new(0), + }; + let info = LoadInfo { + coordinates, + k, + batch_offset, + shared_memory, + config, + dims, + }; + + load_lhs_transposed::>(tensor, info, config); + } else { + let dims = Dimensions { + m: UInt::new(0), + k: tensor.shape(tensor.rank() - UInt::new(2)), + n: tensor.shape(tensor.rank() - UInt::new(1)), + }; + let info = LoadInfo { + coordinates, + k, + batch_offset, + shared_memory, + config, + dims, + }; + + load_rhs_plain::>(tensor, info, config); + } + + for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + sm_out[i] = shared_memory[i]; + } + } + + /// Exported test + pub fn load_lhs_transposed_unit_test(device: &R::Device) { + let lhs = range_tensor::(16, 16, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(16, 16, 8); + + load_tensor_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0, + 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Device) { + let vectorization_factor = 1; + let lhs = range_tensor::(5, 1, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(2, 2, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(5, 1, 1); + + load_tensor_multiple_tiles_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized( + vectorization_factor as u8, + &lhs.handle, + &lhs.strides, + &lhs.shape.dims, + ), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(0), + config, + true, + ); + + let expected = &[ + 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_lhs_transposed_cube_test(device: &R::Device) { + let lhs = range_tensor::(8, 8, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(2, 2, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, 8, 8); + + load_tensor_multiple_tiles_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(0), + config, + true, + ); + + let expected = &[ + 0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0, + 57.0, 
2.0, 10.0, 18.0, 26.0, 34.0, 42.0, 50.0, 58.0, 3.0, 11.0, 19.0, 27.0, 35.0, 43.0, + 51.0, 59.0, 4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 5.0, 13.0, 21.0, 29.0, 37.0, + 45.0, 53.0, 61.0, 6.0, 14.0, 22.0, 30.0, 38.0, 46.0, 54.0, 62.0, 7.0, 15.0, 23.0, 31.0, + 39.0, 47.0, 55.0, 63.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { + let lhs = range_tensor::(8, 16, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(2, 2, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, 8, 16); + + load_tensor_multiple_tiles_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(8), + config, + true, + ); + + let expected = &[ + 8.0, 24.0, 40.0, 56.0, 72.0, 88.0, 104.0, 120.0, 9.0, 25.0, 41.0, 57.0, 73.0, 89.0, + 105.0, 121.0, 10.0, 26.0, 42.0, 58.0, 74.0, 90.0, 106.0, 122.0, 11.0, 27.0, 43.0, 59.0, + 75.0, 91.0, 107.0, 123.0, 12.0, 28.0, 44.0, 60.0, 76.0, 92.0, 108.0, 124.0, 13.0, 29.0, + 45.0, 61.0, 77.0, 93.0, 109.0, 125.0, 14.0, 30.0, 46.0, 62.0, 78.0, 94.0, 110.0, 126.0, + 15.0, 31.0, 47.0, 63.0, 79.0, 95.0, 111.0, 127.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_rhs_plain_unit_test(device: &R::Device) { + let rhs = range_tensor::(16, 16, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, 16, 16); + + load_tensor_test_launch::( + rhs.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + + let 
expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0, + 0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0, + 246.0, 247.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_rhs_plain_cube_test(device: &R::Device) { + let rhs = range_tensor::(8, 8, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(2, 2, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, 8, 8); + + load_tensor_multiple_tiles_test_launch::( + rhs.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(0), + config, + false, + ); + + let expected = &[ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, + 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, + 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, + 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, + 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { + let rhs = range_tensor::(16, 8, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(2, 2, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(16, 16, 8); + + load_tensor_multiple_tiles_test_launch::( + rhs.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + 
ScalarArg::new(8), + config, + false, + ); + + let expected = &[ + 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, + 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, + 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, + 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, + 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_lhs_plain_unit_test(device: &R::Device) { + let lhs = range_tensor::(16, 16, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(16, 16, 8); + + load_tensor_permuted_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0, + 0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0, + 246.0, 247.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { + let (m, k) = (6, 14); + let lhs = range_tensor::(k, m, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(m, k, 8); + + load_tensor_permuted_test_launch::( + lhs.client.clone(), + cube_count, + cube_dim, + 
TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + true, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 76.0, 77.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_rhs_transposed_unit_test(device: &R::Device) { + let rhs = range_tensor::(16, 16, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(16, 16, 8); + + load_tensor_permuted_test_launch::( + rhs.client.clone(), + cube_count, + cube_dim, + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0, + 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0, + ]; + assert_equals::(sm_out, expected, device); + } + + /// Exported test + pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Device) { + let (k, n) = (14, 6); + let rhs = range_tensor::(n, k, device); + let sm_out = create_empty::(8, 8, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, k, n); + + 
load_tensor_permuted_test_launch::( + rhs.client.clone(), + cube_count, + cube_dim, + TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), + ScalarArg::new(4), + ScalarArg::new(4), + ScalarArg::new(8), + config, + false, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 68.0, 82.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 69.0, 83.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]; + assert_equals::(sm_out, expected, device); + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs new file mode 100644 index 0000000000..d971c92ff1 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs @@ -0,0 +1,19 @@ +mod base; +mod block_loop; +mod compute_loop; +mod launch; +mod load_shared_memory; +mod outer_product; +#[cfg(feature = "export_tests")] +mod test_utils; +mod tile; +mod write_output; + +pub use launch::matmul_tiling_2d_cube; + +#[cfg(feature = "export_tests")] +pub use { + compute_loop::tests as compute_loop_tests, + load_shared_memory::tests as load_shared_memory_tests, + outer_product::tests as outer_product_tests, write_output::tests as write_output_tests, +}; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs new file mode 100644 index 0000000000..fb75125390 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs @@ -0,0 +1,118 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +#[cube] +pub(crate) fn tile_outer_product( + register_m: F, + register_n: F, + results: &mut Array, + config: Comptime, +) { + let tile_size = Comptime::map(config, |c| 
c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + + for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) { + let res_pos_base = res_idx_m * Comptime::runtime(tile_size); + for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) { + let mul = register_m[res_idx_m] * register_n[res_idx_n]; + results[res_pos_base + res_idx_n] += mul; + } + } +} + +#[cfg(feature = "export_tests")] +/// Exported tests for outer product +pub mod tests { + use crate::{ + kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::test_utils::{assert_equals, create_empty, make_config}, + }, + JitRuntime, + }; + + use super::*; + + #[cube(launch)] + #[allow(unused_mut)] + fn tile_outer_product_test( + register_m: Array, + register_n: Array, + results: &mut Array, + config: Comptime, + ) { + // We launch with array then convert to vectorized float, + // because direct launch of vectorized float is not supported + let tile_size = Comptime::map(config, |c| c.tile_size); + let register_m = register_m.to_vectorized(tile_size); + let register_n = register_n.to_vectorized(tile_size); + + for i in range( + 0u32, + Comptime::get(tile_size * tile_size), + Comptime::new(false), + ) { + results[i] = F::new(0.); + } + tile_outer_product::(register_m, register_n, results, config) + } + + /// Exported test + pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { + let client = R::client(device); + let register_m = client.create(f32::as_bytes(&[0., 1., 2., 3.])); + let register_n = client.create(f32::as_bytes(&[1., 2., 3., 4.])); + let results = create_empty::(4, 4, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + const SOME_DIM: usize = 12; + let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); + + tile_outer_product_test_launch::( + client.clone(), + cube_count, + cube_dim, + ArrayArg::new(&register_m, 4), + ArrayArg::new(&register_n, 4), + ArrayArg::new(&results, 16), + config, + ); + + let 
expected = &[ + 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, + ]; + assert_equals::(results, expected, device); + } + + /// Exported test + pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) { + let client = R::client(device); + + let register_m = client.create(f32::as_bytes(&[16., 20., 24., 28.])); + let register_n = client.create(f32::as_bytes(&[4., 5., 6., 7.])); + let results = create_empty::(4, 4, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + const SOME_DIM: usize = 12; + let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); + + tile_outer_product_test_launch::( + client.clone(), + cube_count, + cube_dim, + ArrayArg::new(&register_m, 4), + ArrayArg::new(&register_n, 4), + ArrayArg::new(&results, 16), + config, + ); + + let expected = &[ + 64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0, + 140.0, 168.0, 196.0, + ]; + assert_equals::(results, expected, device); + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/test_utils.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/test_utils.rs new file mode 100644 index 0000000000..4dacb8e7da --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/test_utils.rs @@ -0,0 +1,89 @@ +use burn_compute::server::Handle; +use burn_cube::CubeElement; + +use crate::{ + kernel::matmul::config::{CubeTiling2dConfig, Tiling2dConfig}, + tensor::JitTensor, + JitBackend, JitRuntime, +}; + +pub(crate) const TILE_SIZE: usize = 4; + +pub(crate) fn range_tensor( + x: usize, + y: usize, + device: &R::Device, +) -> JitTensor { + type B = JitBackend; + + let n_elements = (x * y) as i64; + burn_tensor::Tensor::, 1, burn_tensor::Int>::arange(0..n_elements, device) + .reshape([x, y]) + .float() + .into_primitive() + .tensor() +} + +pub(crate) fn range_tensor_transposed( + x: usize, + y: usize, + device: &R::Device, +) -> JitTensor { + type B = JitBackend; + + let n_elements = (x 
* y) as i64; + + burn_tensor::Tensor::, 2>::from_data( + burn_tensor::Tensor::, 1, burn_tensor::Int>::arange(0..n_elements, device) + .reshape([x, y]) + .float() + .transpose() + .into_data(), + device, + ) + .into_primitive() + .tensor() +} + +pub(crate) fn zeros_tensor( + x: usize, + y: usize, + device: &R::Device, +) -> JitTensor { + type B = JitBackend; + burn_tensor::Tensor::, 2>::zeros([x, y], device) + .into_primitive() + .tensor() +} + +pub(crate) fn create_empty( + x: usize, + y: usize, + device: &R::Device, +) -> Handle<::JitServer> { + let client = R::client(device); + client.empty(x * y * core::mem::size_of::()) +} + +pub(crate) fn assert_equals( + output: Handle<::JitServer>, + expected: &[f32], + device: &R::Device, +) { + let client = R::client(device); + + let actual = client.read(output.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual, expected); +} + +pub(crate) fn make_config(m: usize, k: usize, n: usize) -> CubeTiling2dConfig { + let tiling2d_config = Tiling2dConfig { + block_size_m: 8, + block_size_k: 8, + block_size_n: 8, + ..Default::default() + }; + CubeTiling2dConfig::new(&tiling2d_config, m, k, n, false, false) +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/base.rs new file mode 100644 index 0000000000..4d50b86aa2 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/base.rs @@ -0,0 +1,78 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::{ + tile::{ + loader::{CheckBounds, ReadTileInfo}, + memory_access::ContiguousAccess, + }, + write_output::WriteTileInfo, + }, +}; + +#[cube] +pub(crate) trait BlockLoader: Send + Sync + 'static { + fn load_tile_plain>( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + read_tile_info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ); + + fn load_tile_transposed( + 
tensor: &Tensor, + shared_memory: &mut SharedMemory, + read_tile_info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ); +} + +#[cube] +pub(crate) trait BlockWriter: Send + Sync + 'static { + fn write_output>( + out: &mut Tensor, + results: &Array, + write_tile_info: WriteTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ); +} + +#[cube] +pub(crate) fn all_zeros_runtime( + shared_memory: &mut SharedMemory, + start: UInt, + sm_position_base: UInt, + sm_stride: UInt, + config: Comptime, +) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let zeros = F::vectorized(0., Comptime::get(tile_size)); + + for i in range(start, Comptime::get(tile_size), Comptime::new(false)) { + let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = zeros; + } +} + +#[cube] +pub(crate) fn all_zeros_comptime( + shared_memory: &mut SharedMemory, + sm_position_base: UInt, + sm_stride: UInt, + config: Comptime, +) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + let zeros = F::vectorized(0., Comptime::get(tile_size)); + + for i in range(0u32, Comptime::get(tile_size), unroll) { + let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = zeros; + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs new file mode 100644 index 0000000000..5bd140fb45 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs @@ -0,0 +1,121 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::{ + tile::{ + loader::{CheckBounds, ReadTileInfo}, + memory_access::{ + ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, + 
WritePositionsExpand, + }, + }, + write_output::WriteTileInfo, + }, +}; + +use super::base::{ + all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand, + BlockLoader, BlockWriter, +}; + +pub(crate) struct HorizontalCheckBlockIO; + +#[cube] +impl BlockLoader for HorizontalCheckBlockIO { + fn load_tile_plain>( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let vectorization = Comptime::vectorization(&tensor); + let unroll = Comptime::map(config, |c| c.unroll_tile); + + let col = check_bounds.skip_col + info.read_col; + if check_bounds.dim_horizontal > col { + for i in range(0u32, Comptime::get(tile_size), unroll) { + let gm_position = + (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = + A::read_contiguous_checked(tensor, gm_position, check_bounds, info, config); + } + } else { + all_zeros_comptime(shared_memory, info.sm_position_base, info.sm_stride, config); + } + } + + fn load_tile_transposed( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + + let mut num_reads = UInt::new(0); + let col = check_bounds.skip_col + info.read_col; + let dim_horizontal = check_bounds.dim_horizontal; + if dim_horizontal > col { + num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size)); + } + + for i in range(0u32, num_reads, Comptime::new(false)) { + let gm_position = info.gm_position_base + i; + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = UnmatchingVectorization::read_strided_unchecked( + tensor, + 
gm_position, + info.gm_stride, + config, + ); + } + + all_zeros_runtime( + shared_memory, + num_reads, + info.sm_position_base, + info.sm_stride, + config, + ); + } +} + +#[cube] +impl BlockWriter for HorizontalCheckBlockIO { + fn write_output>( + out: &mut Tensor, + results: &Array, + info: WriteTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + let coordinates = info.coordinates; + + let col = coordinates.skip_col + coordinates.unit_col; + + if check_bounds.dim_horizontal > col { + let row = coordinates.skip_row + coordinates.unit_row; + let out_position_base = row * info.out_stride + col + info.offset_output; + + for result_index in range(0u32, Comptime::get(tile_size), unroll) { + let positions = WritePositions { + result: result_index * Comptime::runtime(tile_size), + out: out_position_base + result_index * info.out_stride, + }; + + A::write_contiguous_checked(out, results, positions, check_bounds, col, config); + } + } + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/mod.rs new file mode 100644 index 0000000000..50c913843b --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod horizontal_block_check; +pub mod unchecked_block; +pub mod vertical_block_check; +pub mod whole_block_check; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/unchecked_block.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/unchecked_block.rs new file mode 100644 index 0000000000..d695f3da65 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/unchecked_block.rs @@ -0,0 +1,96 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::{ + tile::{ + 
loader::{CheckBounds, ReadTileInfo}, + memory_access::{ + ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, + WritePositionsExpand, + }, + }, + write_output::WriteTileInfo, + }, +}; + +use super::base::{BlockLoader, BlockWriter}; + +/// Assumes block sizes divide tensor shape +pub(crate) struct UncheckedBlockIO; + +#[cube] +impl BlockLoader for UncheckedBlockIO { + fn load_tile_plain>( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + _check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + let vectorization = Comptime::vectorization(&tensor); + + for i in range(0u32, Comptime::get(tile_size), unroll) { + let gm_position = + (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = A::read_contiguous_unchecked(tensor, gm_position, config); + } + } + + fn load_tile_transposed( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + _check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + + for i in range(0u32, Comptime::get(tile_size), unroll) { + let gm_position = info.gm_position_base + i; + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = UnmatchingVectorization::read_strided_unchecked( + tensor, + gm_position, + info.gm_stride, + config, + ); + } + } +} + +#[cube] +impl BlockWriter for UncheckedBlockIO { + fn write_output>( + out: &mut Tensor, + results: &Array, + info: WriteTileInfo, + config: Comptime, + _check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, 
|c| c.unroll_tile); + let coordinates = info.coordinates; + + let row = coordinates.skip_row + coordinates.unit_row; + let col = coordinates.skip_col + coordinates.unit_col; + let out_position_base = row * info.out_stride + col + info.offset_output; + + for result_index in range(0u32, Comptime::get(tile_size), unroll) { + let positions = WritePositions { + result: result_index * Comptime::runtime(tile_size), + out: out_position_base + result_index * info.out_stride, + }; + + A::write_contiguous_unchecked(out, results, positions, config); + } + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs new file mode 100644 index 0000000000..677978e7fe --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs @@ -0,0 +1,120 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::{ + tile::{ + loader::{CheckBounds, ReadTileInfo}, + memory_access::{ + ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, + WritePositionsExpand, + }, + }, + write_output::WriteTileInfo, + }, +}; + +use super::base::{all_zeros_runtime, all_zeros_runtime_expand, BlockLoader, BlockWriter}; + +pub(crate) struct VerticalCheckBlockIO; + +#[cube] +impl BlockLoader for VerticalCheckBlockIO { + fn load_tile_plain>( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let vectorization = Comptime::vectorization(&tensor); + + let mut num_reads = UInt::new(0); + let row = check_bounds.skip_row + info.read_row; + if check_bounds.dim_vertical > row { + num_reads = UInt::min( + check_bounds.dim_vertical - row, + Comptime::runtime(tile_size), + ); + } + + for i in range(0u32, num_reads, Comptime::new(false)) { + let 
gm_position = + (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = A::read_contiguous_unchecked(tensor, gm_position, config); + } + + all_zeros_runtime( + shared_memory, + num_reads, + info.sm_position_base, + info.sm_stride, + config, + ); + } + + fn load_tile_transposed( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + + for i in range(0u32, Comptime::get(tile_size), unroll) { + let gm_position = info.gm_position_base + i; + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = UnmatchingVectorization::read_strided_checked( + tensor, + gm_position, + info.gm_stride, + check_bounds, + info, + config, + ); + } + } +} + +#[cube] +impl BlockWriter for VerticalCheckBlockIO { + fn write_output>( + out: &mut Tensor, + results: &Array, + info: WriteTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let coordinates = info.coordinates; + + let row = coordinates.skip_row + coordinates.unit_row; + let col = coordinates.skip_col + coordinates.unit_col; + let out_position_base = row * info.out_stride + col + info.offset_output; + + let mut num_writes = UInt::new(0); + if check_bounds.dim_vertical > row { + num_writes = UInt::min( + check_bounds.dim_vertical - row, + Comptime::runtime(tile_size), + ); + } + + for result_index in range(0u32, num_writes, Comptime::new(false)) { + let positions = WritePositions { + result: result_index * Comptime::runtime(tile_size), + out: out_position_base + result_index * info.out_stride, + }; + + A::write_contiguous_unchecked(out, results, 
positions, config); + } + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs new file mode 100644 index 0000000000..274c79181c --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs @@ -0,0 +1,146 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::{ + tile::{ + loader::{CheckBounds, ReadTileInfo}, + memory_access::{ + ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, + WritePositionsExpand, + }, + }, + write_output::WriteTileInfo, + }, +}; + +use super::base::{ + all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand, + BlockLoader, BlockWriter, +}; + +pub(crate) struct WholeCheckBlockIO; + +#[cube] +impl BlockLoader for WholeCheckBlockIO { + fn load_tile_plain>( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let vectorization = Comptime::vectorization(&tensor); + + let col = check_bounds.skip_col + info.read_col; + if check_bounds.dim_horizontal > col { + let mut num_reads_vertical = UInt::new(0); + let row = check_bounds.skip_row + info.read_row; + if check_bounds.dim_vertical > row { + num_reads_vertical = UInt::min( + check_bounds.dim_vertical - row, + Comptime::runtime(tile_size), + ); + } + + for i in range(0u32, num_reads_vertical, Comptime::new(false)) { + let gm_position = + (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = + A::read_contiguous_checked(tensor, gm_position, check_bounds, info, config); + } + + all_zeros_runtime( + shared_memory, + 
num_reads_vertical, + info.sm_position_base, + info.sm_stride, + config, + ); + } else { + all_zeros_comptime(shared_memory, info.sm_position_base, info.sm_stride, config); + } + } + fn load_tile_transposed( + tensor: &Tensor, + shared_memory: &mut SharedMemory, + info: ReadTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + + let mut num_reads_horizontal = UInt::new(0); + let col = check_bounds.skip_col + info.read_col; + let dim_horizontal = check_bounds.dim_horizontal; + if dim_horizontal > col { + num_reads_horizontal = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size)); + } + + for i in range(0u32, num_reads_horizontal, Comptime::new(false)) { + let gm_position = info.gm_position_base + i; + let sm_position = + (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + + shared_memory[sm_position] = UnmatchingVectorization::read_strided_checked( + tensor, + gm_position, + info.gm_stride, + check_bounds, + info, + config, + ); + } + + all_zeros_runtime( + shared_memory, + num_reads_horizontal, + info.sm_position_base, + info.sm_stride, + config, + ); + } +} + +#[cube] +impl BlockWriter for WholeCheckBlockIO { + fn write_output>( + out: &mut Tensor, + results: &Array, + info: WriteTileInfo, + config: Comptime, + check_bounds: CheckBounds, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let coordinates = info.coordinates; + + let col = coordinates.skip_col + coordinates.unit_col; + + if check_bounds.dim_horizontal > col { + let mut num_writes_vertical = UInt::new(0); + let row = coordinates.skip_row + coordinates.unit_row; + + if check_bounds.dim_vertical > row { + num_writes_vertical = UInt::min( + check_bounds.dim_vertical - row, + Comptime::runtime(tile_size), + ); + } + + let out_position_base = row * info.out_stride + col + info.offset_output; + + for result_index in range(0u32, num_writes_vertical, Comptime::new(false)) { + let positions 
= WritePositions { + result: result_index * Comptime::runtime(tile_size), + out: out_position_base + result_index * info.out_stride, + }; + + A::write_contiguous_checked(out, results, positions, check_bounds, col, config); + } + } + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/loader.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/loader.rs new file mode 100644 index 0000000000..4df21a1df0 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/loader.rs @@ -0,0 +1,216 @@ +use std::marker::PhantomData; + +use burn_cube::prelude::*; + +use crate::kernel::matmul::tiling2d_cube::load_shared_memory::{LoadInfo, Loader}; + +use super::{ + block_io::base::BlockLoader, + memory_access::{MatchingVectorization, UnmatchingVectorization}, +}; + +// Transposed tensor's vectorization must be 1 +// Plain tensor's vectorization must equal tile size +pub(crate) struct TileLoader { + _f: PhantomData, +} + +#[derive(CubeType)] +pub(crate) struct LoadIndices { + pub offset: UInt, + pub gm_stride: UInt, + pub sm_stride: UInt, +} + +#[derive(CubeType, Copy, Clone)] +pub(crate) struct CheckBounds { + pub dim_vertical: UInt, + pub dim_horizontal: UInt, + pub skip_row: UInt, + pub skip_col: UInt, +} + +#[derive(CubeType, Copy, Clone)] +pub(crate) struct ReadTileInfo { + pub read_row: UInt, + pub read_col: UInt, + pub gm_position_base: UInt, + pub sm_position_base: UInt, + pub gm_stride: UInt, + pub sm_stride: UInt, +} + +#[cube] +impl Loader for TileLoader { + fn load_lhs_plain>(lhs: &Tensor, load_info: LoadInfo) { + let config = load_info.config; + let dims = load_info.dims; + let coordinates = load_info.coordinates; + let gm_stride = dims.m; + + let load_indices = LoadIndices { + offset: coordinates.skip_row + load_info.k * gm_stride + load_info.batch_offset, + gm_stride, + sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), + }; + let check_bounds = CheckBounds { + dim_vertical: dims.k, + dim_horizontal: dims.m, + 
skip_row: load_info.k, + skip_col: coordinates.skip_row, + }; + + load_plain::(lhs, load_info, load_indices, check_bounds); + } + + fn load_lhs_transposed>(lhs: &Tensor, load_info: LoadInfo) { + let config = load_info.config; + let dims = load_info.dims; + let coordinates = load_info.coordinates; + let gm_stride = dims.k; + + let load_indices = LoadIndices { + offset: coordinates.skip_row * gm_stride + load_info.k + load_info.batch_offset, + gm_stride, + sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_m)), + }; + let check_bounds = CheckBounds { + dim_vertical: dims.m, + dim_horizontal: dims.k, + skip_row: coordinates.skip_row, + skip_col: load_info.k, + }; + + load_transposed::(lhs, load_info, load_indices, check_bounds); + } + + fn load_rhs_plain>(rhs: &Tensor, load_info: LoadInfo) { + let coordinates = load_info.coordinates; + let dims = load_info.dims; + let config = load_info.config; + let gm_stride = dims.n; + + let load_indices = LoadIndices { + offset: coordinates.skip_col + load_info.k * gm_stride + load_info.batch_offset, + gm_stride, + sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), + }; + let check_bounds = CheckBounds { + dim_vertical: dims.k, + dim_horizontal: dims.n, + skip_row: load_info.k, + skip_col: coordinates.skip_col, + }; + + load_plain::(rhs, load_info, load_indices, check_bounds); + } + + fn load_rhs_transposed>(rhs: &Tensor, load_info: LoadInfo) { + let config = load_info.config; + let dims = load_info.dims; + let coordinates = load_info.coordinates; + let gm_stride = dims.k; + + let load_indices = LoadIndices { + offset: coordinates.skip_col * gm_stride + load_info.k + load_info.batch_offset, + gm_stride, + sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), + }; + let check_bounds = CheckBounds { + dim_vertical: dims.n, + dim_horizontal: dims.k, + skip_row: coordinates.skip_col, + skip_col: load_info.k, + }; + + load_transposed::(rhs, load_info, load_indices, 
check_bounds); + } +} + +#[cube] +pub(crate) fn load_plain>( + tensor: &Tensor, + load_info: LoadInfo, + load_indices: LoadIndices, + check_bounds: CheckBounds, +) { + let coordinates = load_info.coordinates; + let config = load_info.config; + + let vectorization = Comptime::vectorization(tensor); + let tile_size = Comptime::map(config, |c| c.tile_size); + let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + + let read_row = coordinates.unit_row; + let read_col = coordinates.unit_col; + let write_row = coordinates.unit_row; + let write_col = coordinates.unit_col; + + let gm_position_base = read_row * load_indices.gm_stride + read_col + load_indices.offset; + let sm_position_base = write_row * load_indices.sm_stride + write_col; + + let read_tile_info = ReadTileInfo { + read_row, + read_col, + gm_position_base, + sm_position_base, + gm_stride: load_indices.gm_stride, + sm_stride: load_indices.sm_stride, + }; + let mut sm = load_info.shared_memory; + + if write_row < sm_dim_vertical { + if vectorization == tile_size { + L::load_tile_plain::( + tensor, + &mut sm, + read_tile_info, + config, + check_bounds, + ); + } else { + L::load_tile_plain::( + tensor, + &mut sm, + read_tile_info, + config, + check_bounds, + ); + } + } +} + +#[cube] +pub(crate) fn load_transposed>( + tensor: &Tensor, + load_info: LoadInfo, + load_indices: LoadIndices, + check_bounds: CheckBounds, +) { + let coordinates = load_info.coordinates; + let config = load_info.config; + + let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + + let read_row = coordinates.unit_row; + let read_col = coordinates.unit_col; + let write_row = coordinates.unit_col; + let write_col = coordinates.unit_row; + + let gm_position_base = read_row * load_indices.gm_stride + read_col + load_indices.offset; + let sm_position_base = write_row * load_indices.sm_stride + write_col; + + let read_tile_info = ReadTileInfo { + read_row, + read_col, + gm_position_base, 
+ sm_position_base, + gm_stride: load_indices.gm_stride, + sm_stride: load_indices.sm_stride, + }; + let mut sm = load_info.shared_memory; + + if write_row < sm_dim_vertical { + L::load_tile_transposed(tensor, &mut sm, read_tile_info, config, check_bounds); + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/memory_access.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/memory_access.rs new file mode 100644 index 0000000000..862472e0dc --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/memory_access.rs @@ -0,0 +1,319 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +use super::loader::{CheckBounds, ReadTileInfo}; + +#[derive(CubeType)] +pub(crate) struct WritePositions { + pub out: UInt, + pub result: UInt, +} + +#[cube] +pub(crate) trait ContiguousAccess: Send + Sync + 'static { + fn read_contiguous_unchecked( + tensor: &Tensor, + gm_position: UInt, + config: Comptime, + ) -> F; + + fn read_contiguous_checked( + tensor: &Tensor, + gm_position: UInt, + check_bounds: CheckBounds, + read_info: ReadTileInfo, + config: Comptime, + ) -> F; + + fn write_contiguous_unchecked( + out: &mut Tensor, + results: &Array, + positions: WritePositions, + config: Comptime, + ); + + fn write_contiguous_checked( + out: &mut Tensor, + results: &Array, + positions: WritePositions, + check_bounds: CheckBounds, + write_col: UInt, + config: Comptime, + ); +} + +#[cube] +pub(crate) trait StridedAccess: Send + Sync + 'static { + fn read_strided_unchecked( + tensor: &Tensor, + gm_position: UInt, + gm_stride: UInt, + config: Comptime, + ) -> F; + + fn read_strided_checked( + tensor: &Tensor, + gm_position: UInt, + gm_stride: UInt, + check_bounds: CheckBounds, + info: ReadTileInfo, + config: Comptime, + ) -> F; +} + +#[derive(new)] +/// When vectorization == tile_size +pub(crate) struct MatchingVectorization; + +/// When vectorization != tile_size +#[derive(new)] +pub(crate) struct 
UnmatchingVectorization; + +#[cube] +impl ContiguousAccess for MatchingVectorization { + fn read_contiguous_unchecked( + tensor: &Tensor, + gm_position: UInt, + _config: Comptime, + ) -> F { + tensor[gm_position] + } + + fn read_contiguous_checked( + tensor: &Tensor, + gm_position: UInt, + _check_bounds: CheckBounds, + _read_info: ReadTileInfo, + config: Comptime, + ) -> F { + // If vectorization matches, then it's certain to fit since tile_size divides block_sizes + MatchingVectorization::read_contiguous_unchecked(tensor, gm_position, config) + } + + fn write_contiguous_unchecked( + out: &mut Tensor, + results: &Array, + positions: WritePositions, + config: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + + let mut output_elem = F::vectorized_empty(Comptime::get(tile_size)); + + for i in range(0u32, Comptime::get(tile_size), unroll) { + output_elem[i] = results[positions.result + i]; + } + + out[positions.out / Comptime::runtime(tile_size)] = output_elem; + } + + fn write_contiguous_checked( + out: &mut Tensor, + results: &Array, + positions: WritePositions, + _check_bounds: CheckBounds, + _write_col: UInt, + config: Comptime, + ) { + // If vectorization matches, then it's certain to fit since tile_size divides block_sizes + MatchingVectorization::write_contiguous_unchecked(out, results, positions, config) + } +} + +#[cube] +impl ContiguousAccess for UnmatchingVectorization { + fn read_contiguous_unchecked( + tensor: &Tensor, + gm_position: UInt, + config: Comptime, + ) -> F { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + let vectorization_factor = Comptime::vectorization(tensor); + let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); + + let mut vector = F::vectorized(0., Comptime::get(tile_size)); + + for i in range( + 0u32, + Comptime::get(tile_size / vectorization_factor), + unroll, + ) { + 
let runtime_vectorization = Comptime::runtime(vectorization_factor); + + if Comptime::get(is_scalar) { + vector[i] = tensor[gm_position + i]; + } else { + let intermediate = tensor[gm_position + i]; + + for j in range(0u32, Comptime::get(vectorization_factor), unroll) { + vector[i * runtime_vectorization + j] = intermediate[j]; + } + } + } + + vector + } + + fn read_contiguous_checked( + tensor: &Tensor, + gm_position: UInt, + check_bounds: CheckBounds, + read_info: ReadTileInfo, + config: Comptime, + ) -> F { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + let vectorization_factor = Comptime::vectorization(tensor); + let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); + let runtime_vectorization = Comptime::runtime(vectorization_factor); + + let mut vector = F::vectorized(0., Comptime::get(tile_size)); + + let mut num_loops = UInt::new(0); + if check_bounds.dim_horizontal > read_info.read_col { + let num_reads = UInt::min( + check_bounds.dim_horizontal - read_info.read_col, + Comptime::runtime(tile_size), + ); + num_loops = num_reads / runtime_vectorization; + } + + for i in range(0u32, num_loops, Comptime::new(false)) { + if Comptime::get(is_scalar) { + vector[i] = tensor[gm_position + i]; + } else { + let intermediate = tensor[gm_position + i]; + + for j in range(0u32, Comptime::get(vectorization_factor), unroll) { + vector[i * runtime_vectorization + j] = intermediate[j]; + } + } + } + + vector + } + + fn write_contiguous_unchecked( + out: &mut Tensor, + results: &Array, + positions: WritePositions, + config: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = Comptime::map(config, |c| c.unroll_tile); + let vectorization_factor = Comptime::vectorization(out); + let runtime_vectorization = Comptime::runtime(vectorization_factor); + let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); + + for i in range( + 0u32, + 
Comptime::get(tile_size / vectorization_factor), + unroll, + ) { + if Comptime::get(is_scalar) { + out[i + positions.out] = results[positions.result + i]; + } else { + let mut output_elem = F::vectorized_empty(Comptime::get(vectorization_factor)); + + for j in range(0u32, Comptime::get(vectorization_factor), unroll) { + let index = i * runtime_vectorization + j; + output_elem[j] = results[positions.result + index]; + } + + out[i + positions.out / runtime_vectorization] = output_elem; + } + } + } + + fn write_contiguous_checked( + out: &mut Tensor, + results: &Array, + positions: WritePositions, + check_bounds: CheckBounds, + write_col: UInt, + config: Comptime, + ) { + let tile_size = Comptime::map(config, |c| c.tile_size); + let vectorization_factor = Comptime::vectorization(out); + let runtime_vectorization = Comptime::runtime(vectorization_factor); + let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); + + let mut num_loops = UInt::new(0); + if check_bounds.dim_horizontal > write_col { + let num_writes = UInt::min( + check_bounds.dim_horizontal - write_col, + Comptime::runtime(tile_size), + ); + num_loops = num_writes / runtime_vectorization; + } + + for i in range(0u32, num_loops, Comptime::new(false)) { + let unroll = Comptime::map(config, |c| c.unroll_tile); + + if Comptime::get(is_scalar) { + out[i + positions.out] = results[positions.result + i]; + } else { + let mut output_elem = F::vectorized_empty(Comptime::get(vectorization_factor)); + + for j in range(0u32, Comptime::get(vectorization_factor), unroll) { + let index = i * runtime_vectorization + j; + output_elem[j] = results[positions.result + index]; + } + + out[i + positions.out / runtime_vectorization] = output_elem; + } + } + } +} + +#[cube] +impl StridedAccess for UnmatchingVectorization { + fn read_strided_unchecked( + tensor: &Tensor, + gm_position: UInt, + gm_stride: UInt, + config: Comptime, + ) -> F { + let tile_size = Comptime::map(config, |c| c.tile_size); + let unroll = 
Comptime::map(config, |c| c.unroll_tile); + + let mut vertical = F::vectorized_empty(Comptime::get(tile_size)); + for i in range(0u32, Comptime::get(tile_size), unroll) { + vertical[i] = tensor[gm_position + i * gm_stride]; + } + + vertical + } + + fn read_strided_checked( + tensor: &Tensor, + gm_position: UInt, + gm_stride: UInt, + check_bounds: CheckBounds, + info: ReadTileInfo, + config: Comptime, + ) -> F { + let tile_size = Comptime::map(config, |c| c.tile_size); + + let mut vertical = F::vectorized_empty(Comptime::get(tile_size)); + + let mut num_reads = UInt::new(0); + let row = check_bounds.skip_row + info.read_row; + let dim_vertical = check_bounds.dim_vertical; + if dim_vertical > row { + num_reads = UInt::min(dim_vertical - row, Comptime::runtime(tile_size)); + } + + for i in range(0u32, num_reads, Comptime::new(false)) { + vertical[i] = tensor[gm_position + i * gm_stride]; + } + for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) { + vertical[i] = F::new(0.); + } + + vertical + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/mod.rs new file mode 100644 index 0000000000..015d4a59c7 --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/mod.rs @@ -0,0 +1,4 @@ +pub mod block_io; +pub mod loader; +pub mod memory_access; +pub mod writer; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/writer.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/writer.rs new file mode 100644 index 0000000000..09c1a063ee --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/writer.rs @@ -0,0 +1,60 @@ +use std::marker::PhantomData; + +use burn_cube::prelude::*; + +use crate::kernel::matmul::{ + config::CubeTiling2dConfig, + tiling2d_cube::{ + base::Dimensions, + write_output::{OutputWriter, WriteTileInfo}, + }, +}; + +use super::{ + block_io::base::BlockWriter, + loader::{CheckBounds, CheckBoundsExpand}, + 
memory_access::{MatchingVectorization, UnmatchingVectorization}, +}; +pub(crate) struct TileWriter { + _f: PhantomData, +} + +#[cube] +impl OutputWriter for TileWriter { + fn write_output>( + out: &mut Tensor, + results: &Array, + write_info: WriteTileInfo, + dims: Dimensions, + config: Comptime, + ) { + let vectorization = Comptime::vectorization(out); + let tile_size = Comptime::map(config, |c| c.tile_size); + let coordinates = write_info.coordinates; + + let check_bounds = CheckBounds { + dim_vertical: dims.m, + dim_horizontal: dims.n, + skip_row: coordinates.skip_row, + skip_col: coordinates.skip_col, + }; + + if vectorization == tile_size { + B::write_output::( + out, + results, + write_info, + config, + check_bounds, + ); + } else { + B::write_output::( + out, + results, + write_info, + config, + check_bounds, + ); + } + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs new file mode 100644 index 0000000000..95a12697ee --- /dev/null +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs @@ -0,0 +1,263 @@ +use burn_cube::prelude::*; + +use crate::kernel::matmul::config::CubeTiling2dConfig; + +use super::{ + base::{Coordinates, Dimensions}, + tile::block_io::{ + base::BlockWriter, horizontal_block_check::HorizontalCheckBlockIO, + unchecked_block::UncheckedBlockIO, vertical_block_check::VerticalCheckBlockIO, + whole_block_check::WholeCheckBlockIO, + }, +}; + +#[derive(CubeType)] +pub(crate) struct WriteTileInfo { + pub coordinates: Coordinates, + pub offset_output: UInt, + pub out_stride: UInt, +} + +#[cube] +pub(crate) trait OutputWriter: Sync + Send + 'static { + fn write_output>( + out: &mut Tensor, + results: &Array, + write_tile_info: WriteTileInfo, + dims: Dimensions, + config: Comptime, + ); +} + +#[cube] +pub(crate) fn write_to_output>( + out: &mut Tensor, + results: &Array, + coordinates: Coordinates, + offset_output: UInt, + dims: 
Dimensions, + config: Comptime, +) { + let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); + let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); + + let write_info = WriteTileInfo { + coordinates, + offset_output, + out_stride: dims.n, + }; + + if Comptime::get(check_m_bounds) { + if Comptime::get(check_n_bounds) { + W::write_output::(out, results, write_info, dims, config); + } else { + W::write_output::(out, results, write_info, dims, config); + } + } else if Comptime::get(check_n_bounds) { + W::write_output::(out, results, write_info, dims, config); + } else { + W::write_output::(out, results, write_info, dims, config); + } +} + +#[cfg(feature = "export_tests")] +/// Exported tests for write output +pub mod tests { + use crate::{ + kernel::matmul::tiling2d_cube::{ + test_utils::{ + assert_equals, make_config, range_tensor, range_tensor_transposed, zeros_tensor, + TILE_SIZE, + }, + tile::writer::TileWriter, + }, + JitRuntime, + }; + + use super::{ + super::base::{CoordinatesExpand, DimensionsExpand}, + *, + }; + + #[cube(launch)] + fn write_to_output_test( + out: &mut Tensor, + results: &mut Array, + config: Comptime, + ) { + let coordinates = Coordinates { + unit_row: UInt::new(4), + unit_col: UInt::new(4), + skip_row: UInt::new(0), + skip_col: UInt::new(0), + }; + let dims = Dimensions { + m: out.shape(out.rank() - UInt::new(2)), + k: UInt::new(0), + n: out.shape(out.rank() - UInt::new(1)), + }; + + write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); + } + + #[cube(launch)] + fn write_results_to_output_out_of_bounds_test( + out: &mut Tensor, + results: &mut Array, + config: Comptime, + ) { + let coordinates = Coordinates { + unit_row: UNIT_POS_X * UInt::new(4), + unit_col: UNIT_POS_Y * UInt::new(4), + skip_row: UInt::new(0), + skip_col: UInt::new(0), + }; + let dims = Dimensions { + m: out.shape(out.rank() - UInt::new(2)), + k: UInt::new(0), + n: out.shape(out.rank() - UInt::new(1)), + }; + + 
write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); + } + + /// Exported test + pub fn write_to_output_over_height_unit_test(device: &R::Device) { + let out = zeros_tensor::(6, 8, device); + let tile = range_tensor::(4, 4, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(6, 8, 8); + + write_to_output_test_launch::( + out.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape.dims), + ArrayArg::new(&tile.handle, 16), + config, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, + ]; + assert_equals::(out.handle, expected, device); + } + + /// Exported test + pub fn write_to_output_over_width_unit_test(device: &R::Device) { + let out = zeros_tensor::(8, 4, device); + let tile = range_tensor::(4, 4, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, 8, 4); + + write_to_output_test_launch::( + out.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape.dims), + ArrayArg::new(&tile.handle, 16), + config, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + ]; + assert_equals::(out.handle, expected, device); + } + + /// Exported test + pub fn write_to_output_vectorized_less_than_tile_unit_test(device: &R::Device) { + let vectorization = 2; + let out = zeros_tensor::(8, 8, device); + let tile = range_tensor::(4, 4, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let 
config = make_config(8, 8, 8); + + write_to_output_test_launch::( + out.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized( + vectorization as u8, + &out.handle, + &out.strides, + &out.shape.dims, + ), + ArrayArg::new(&tile.handle, 16), + config, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, + 0.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, + ]; + assert_equals::(out.handle, expected, device); + } + + /// Exported test + pub fn write_to_output_scalar_unit_test(device: &R::Device) { + let vectorization = 1; + let out = zeros_tensor::(8, 8, device); + let tile = range_tensor::(4, 4, device); + let cube_dim = CubeDim::new(1, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(8, 8, 8); + + write_to_output_test_launch::( + out.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized( + vectorization as u8, + &out.handle, + &out.strides, + &out.shape.dims, + ), + ArrayArg::new(&tile.handle, 16), + config, + ); + + let expected = &[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, + 0.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, + ]; + assert_equals::(out.handle, expected, device); + } + + /// Exported test + pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::Device) { + let vectorization = 1; + let out = zeros_tensor::(5, 1, device); + let results = range_tensor_transposed::(4, 4, device); + let cube_dim = CubeDim::new(2, 1, 1); + let cube_count = CubeCount::Static(1, 1, 1); + + let config = make_config(5, 8, 
1); + + write_results_to_output_out_of_bounds_test_launch::( + out.client.clone(), + cube_count, + cube_dim, + TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape.dims), + ArrayArg::new(&results.handle, 16), + config, + ); + + let expected = &[0.0, 1.0, 2.0, 3.0, 0.0]; + assert_equals::(out.handle, expected, device); + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs index 05a768a6e7..762c470356 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs @@ -1,10 +1,10 @@ use burn_cube::cpa; use burn_cube::ir::{BinaryOperator, Scope, Synchronization, Variable}; +use crate::kernel::matmul::config::Tiling2dConfig; use crate::kernel::matmul::tiling2d_shader::{ computation_loop, gather_shader_information, load_shared_memory, write_to_output, }; -use crate::kernel::matmul::Tiling2dConfig; pub(crate) struct MatmulTiling2dShader { pub variables: BinaryOperator, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs index 41201e7357..6b128684cf 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs @@ -56,11 +56,11 @@ pub fn computation_loop( cpa!( scope, - range(0u32, shader.config.tile_size_m as u32, shader.config.unroll).for_each( + range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( |res_idx_m, scope| { cpa!( scope, - range(0u32, shader.config.tile_size_n as u32, shader.config.unroll) + range(0u32, shader.config.tile_size as u32, shader.config.unroll) .for_each(|res_idx_n, scope| { cpa!(scope, registered_m = register_m[res_idx_m]); cpa!(scope, registered_n = register_n[res_idx_n]); @@ -69,7 +69,7 @@ pub fn computation_loop( cpa!( scope, - results_position = res_idx_m * 
shader.config.tile_size_n + results_position = res_idx_m * shader.config.tile_size ); cpa!(scope, results_position += res_idx_n); diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs index 903321fe59..d9bad3b8d3 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs @@ -18,11 +18,11 @@ pub(crate) fn gather_shader_information( let block_size_m: Variable = shader.config.block_size_m.into(); let block_size_k: Variable = shader.config.block_size_k.into(); let block_size_n: Variable = shader.config.block_size_n.into(); - let tile_size_m: Variable = shader.config.tile_size_m.into(); - let tile_size_n: Variable = shader.config.tile_size_n.into(); + let tile_size_m: Variable = shader.config.tile_size.into(); + let tile_size_n: Variable = shader.config.tile_size.into(); let n_threads_per_row: Variable = - (((shader.config.block_size_n - 1) / shader.config.tile_size_n) + 1).into(); - let results_size = (shader.config.tile_size_m * shader.config.tile_size_n) as u32; + (((shader.config.block_size_n - 1) / shader.config.tile_size) + 1).into(); + let results_size = (shader.config.tile_size * shader.config.tile_size) as u32; // Shader info let local_idx = Variable::UnitPos; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs index b90d0e1a5f..0ce06307a9 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs @@ -26,12 +26,12 @@ pub fn write_to_output( cpa!( scope, - range(0u32, shader.config.tile_size_m as u32, shader.config.unroll).for_each( + range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( |res_idx_m, scope| { cpa!( scope, - range(0u32, 
shader.config.tile_size_n as u32, shader.config.unroll) - .for_each(|res_idx_n, scope| { + range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( + |res_idx_n, scope| { cpa!(scope, row_index = row + res_idx_m); cpa!(scope, col_index = col + res_idx_n); @@ -50,7 +50,8 @@ pub fn write_to_output( col_index, ); })); - }) + } + ) ); } ) @@ -58,12 +59,12 @@ pub fn write_to_output( } else { cpa!( scope, - range(0u32, shader.config.tile_size_m as u32, shader.config.unroll).for_each( + range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( |res_idx_m, scope| { cpa!( scope, - range(0u32, shader.config.tile_size_n as u32, shader.config.unroll) - .for_each(|res_idx_n, scope| { + range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( + |res_idx_n, scope| { cpa!(scope, row_index = row + res_idx_m); cpa!(scope, col_index = col + res_idx_n); @@ -76,7 +77,8 @@ pub fn write_to_output( row_index, col_index, ) - }) + } + ) ); } ) @@ -107,7 +109,7 @@ fn write_inner( cpa!( scope, - results_position = res_idx_m * shader.config.tile_size_n + results_position = res_idx_m * shader.config.tile_size ); cpa!(scope, results_position += res_idx_n); diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 50d4cc2e75..d35de0bf6f 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -4,7 +4,7 @@ use burn_tensor::{Element, ElementConversion}; use crate::{ element::FloatElement, kernel::{ - matmul::{utils::init_matmul_output, Tiling2dConfig}, + matmul::{config::Tiling2dConfig, utils::init_matmul_output}, prng::random_like_uniform, }, ops::numeric::empty_device, @@ -60,6 +60,11 @@ impl AutotuneOperationSet AutotuneOperationSet AutotuneOperationSet Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)), 1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)), 2 => Box::new(Tiling2dMatmul::new(self.lhs, 
self.rhs, self.out)), - 3 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)), - 4 => Box::new(Tiling2dMatmulPaddedUnrolled::new( + 3 => Box::new(Tiling2dMatmulUnrolled::new(self.lhs, self.rhs, self.out)), + 4 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)), + 5 => Box::new(Tiling2dMatmulPaddedUnrolled::new( + self.lhs, self.rhs, self.out, + )), + 6 => Box::new(Tiling2dMatmulCube::new(self.lhs, self.rhs, self.out)), + 7 => Box::new(Tiling2dMatmulCubeUnrolled::new( self.lhs, self.rhs, self.out, )), - 5 => Box::new(Tiling2dMatmulUnrolled::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -146,12 +160,30 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16) }); -// Probably the fastest when fixed size, without loop unrolling +// Maybe the fastest for transposed inputs, without loop unrolling +matmul_tune_ops!(Tiling2dMatmul, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default()) +}); + +// Maybe the fastest for transposed inputs, with loop unrolling +matmul_tune_ops!(Tiling2dMatmulUnrolled, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d( + lhs, + rhs, + out, + Tiling2dConfig { + unroll: true, + ..Default::default() + }, + ) +}); + +// Maybe the fastest when fixed size, without loop unrolling matmul_tune_ops!(Tiling2dMatmulPadded, |lhs, rhs, out| { crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, Tiling2dConfig::default()) }); -// Probably the fastest when fixed sizes, with loop unrolling +// Maybe the fastest when fixed sizes, with loop unrolling matmul_tune_ops!(Tiling2dMatmulPaddedUnrolled, |lhs, rhs, out| { crate::kernel::matmul::matmul_tiling_2d_padded( lhs, @@ -165,13 +197,13 @@ matmul_tune_ops!(Tiling2dMatmulPaddedUnrolled, |lhs, rhs, out| { }); // Probably the fastest in the general case, without loop unrolling -matmul_tune_ops!(Tiling2dMatmul, |lhs, 
rhs, out| { - crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default()) +matmul_tune_ops!(Tiling2dMatmulCube, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d_cube(lhs, rhs, out, Tiling2dConfig::default()) }); // Probably the fastest in the general case, with loop unrolling -matmul_tune_ops!(Tiling2dMatmulUnrolled, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d_padded( +matmul_tune_ops!(Tiling2dMatmulCubeUnrolled, |lhs, rhs, out| { + crate::kernel::matmul::matmul_tiling_2d_cube( lhs, rhs, out, diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 9182882c54..5e7efa2fa0 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -8,6 +8,8 @@ use burn_cube::prelude::*; use burn_tensor::Shape; use std::marker::PhantomData; +use super::layout::{memory_layout, MatrixLayout}; + /// The basic tensor primitive struct. #[derive(new)] pub struct JitTensor @@ -165,27 +167,10 @@ where /// Check if the current tensor is contiguous. pub fn is_contiguous(&self) -> bool { - let mut current_stride = 0; - for d in 0..D { - let stride = self.strides[D - 1 - d]; - - if stride <= current_stride { - return false; - } - - current_stride = stride; - } - - true + self.matrix_layout() == MatrixLayout::Contiguous } - pub(crate) fn batch_swapped_with_row_col(&self) -> bool { - for d in 0..D - 2 { - let stride = self.strides[d]; - if stride < self.strides[D - 2] || stride < self.strides[D - 1] { - return true; - } - } - false + pub(crate) fn matrix_layout(&self) -> MatrixLayout { + memory_layout(&self.strides) } } diff --git a/crates/burn-jit/src/tensor/layout.rs b/crates/burn-jit/src/tensor/layout.rs new file mode 100644 index 0000000000..52b6b65166 --- /dev/null +++ b/crates/burn-jit/src/tensor/layout.rs @@ -0,0 +1,123 @@ +#[derive(PartialEq, Eq, Debug)] +/// Layout for matrix tensors, i.e. 
tensors whose interpretation +/// is a bunch of batched matrices of 2 dimensions +pub(crate) enum MatrixLayout { + /// Memory is wholly contiguous, with row major layout + Contiguous, + /// Permutations happened, but may not impact some kernels + MildlyPermuted { + /// Last two dims are inverted + transposed: bool, + /// Some permutations exist in batch dimensions + batch_swap: bool, + }, + /// Permutations happened between batch dimensions and last two dims + HighlyPermuted, +} + +pub(crate) fn memory_layout(strides: &[usize; D]) -> MatrixLayout { + if D <= 1 { + return MatrixLayout::Contiguous; + } + + let mut transposed = false; + let mut batch_swap = false; + let row_stride = strides[D - 2]; + let col_stride = strides[D - 1]; + if row_stride < col_stride { + transposed = true; + } + let mut previous_stride = row_stride; + + for d in 0..D - 2 { + let current_stride = strides[D - 3 - d]; + if current_stride < row_stride || current_stride < col_stride { + return MatrixLayout::HighlyPermuted; + } + if current_stride < previous_stride { + batch_swap = true; + } + + previous_stride = current_stride; + } + + if transposed || batch_swap { + MatrixLayout::MildlyPermuted { + transposed, + batch_swap, + } + } else { + MatrixLayout::Contiguous + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn layout_is_contiguous() { + let strides = &[8, 4, 2, 1]; + assert_eq!(memory_layout(strides), MatrixLayout::Contiguous); + } + + #[test] + fn vector_is_contiguous() { + let strides = &[1]; + assert_eq!(memory_layout(strides), MatrixLayout::Contiguous) + } + + #[test] + fn layout_is_transposed_only() { + let strides = &[8, 4, 1, 2]; + if let MatrixLayout::MildlyPermuted { + transposed, + batch_swap, + } = memory_layout(strides) + { + assert!(transposed && !batch_swap); + } else { + unreachable!() + } + } + + #[test] + fn layout_has_swapped_batches_only() { + let strides = &[4, 8, 2, 1]; + if let MatrixLayout::MildlyPermuted { + transposed, + batch_swap, + } = 
memory_layout(strides) + { + assert!(!transposed && batch_swap); + } else { + unreachable!() + } + } + + #[test] + fn layout_has_swapped_batches_and_is_transposed() { + let strides = &[4, 8, 1, 2]; + if let MatrixLayout::MildlyPermuted { + transposed, + batch_swap, + } = memory_layout(strides) + { + assert!(transposed && batch_swap); + } else { + unreachable!() + } + } + + #[test] + fn layout_has_batch_swapped_with_row() { + let strides = &[8, 2, 4, 1]; + assert_eq!(memory_layout(strides), MatrixLayout::HighlyPermuted); + } + + #[test] + fn layout_has_batch_swapped_with_col() { + let strides = &[1, 4, 2, 8]; + assert_eq!(memory_layout(strides), MatrixLayout::HighlyPermuted); + } +} diff --git a/crates/burn-jit/src/tensor/mod.rs b/crates/burn-jit/src/tensor/mod.rs index 096c94ead7..3a474b8b9a 100644 --- a/crates/burn-jit/src/tensor/mod.rs +++ b/crates/burn-jit/src/tensor/mod.rs @@ -1,2 +1,4 @@ mod base; +mod layout; pub use base::*; +pub(crate) use layout::*; diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs index bb8948634c..65ea81e56d 100644 --- a/crates/burn-jit/src/tests/matmul.rs +++ b/crates/burn-jit/src/tests/matmul.rs @@ -406,6 +406,197 @@ mod tests { } } + mod tiling2d_cube { + use super::*; + + #[test] + pub fn straightforward() { + test_with_params(1, 2, 1, 1, 1); + } + + #[test] + pub fn shapes_smaller_than_blocks() { + test_with_params(8, 8, 8, 1, 1); + } + + #[test] + pub fn shapes_equal_blocks() { + test_with_params(64, 32, 64, 2, 2); + } + + #[test] + pub fn m_exceeds_block() { + test_with_params(75, 32, 64, 2, 2); + } + + #[test] + pub fn k_exceeds_block() { + test_with_params(64, 33, 32, 1, 1); + } + + #[test] + pub fn test_matmul_irregular_shape() { + test_with_params(123, 255, 72, 3, 5); + } + + #[test] + pub fn test64_matmul_unpadded_n_exceeds_block() { + test_with_params(64, 32, 75, 2, 2); + } + + #[test] + pub fn n_smaller_than_m() { + test_with_params(8, 8, 3, 1, 1); + } + + #[test] + pub fn 
m_smaller_than_n() { + test_with_params(3, 8, 8, 1, 1); + } + + #[test] + pub fn k_smaller_than_m_n() { + test_with_params(8, 3, 8, 1, 1); + } + + #[test] + pub fn k_larger_than_m_n() { + test_with_params(8, 48, 8, 1, 1); + } + + #[test] + pub fn multibatch_1_dim() { + test_with_params(8, 8, 8, 3, 1); + } + + #[test] + pub fn multibatch_2_dims() { + test_with_params(8, 8, 8, 3, 4); + } + + #[test] + pub fn blocks_divide_shapes_unevenly() { + test_with_params(7, 7, 7, 1, 1); + } + + #[test] + pub fn medium() { + test_with_params(17, 16, 16, 1, 1); + } + + #[test] + pub fn large() { + test_with_params(256, 256, 256, 1, 1); + } + + #[test] + pub fn use_vec2() { + test_with_params(2, 2, 2, 1, 1); + } + + #[test] + fn swapped_batches_no_padding() { + let swap = [0, 1]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims( + MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()), + swap, + swap, + shape_lhs, + shape_rhs, + ); + } + + #[test] + fn swapped_row_col_no_padding() { + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, 4, 4]; + let shape_rhs = [3, 2, 4, 4]; + same_as_reference_swapped_dims( + MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), + swap_lhs, + swap_rhs, + shape_lhs, + shape_rhs, + ); + } + + #[test] + fn swapped_lhs_row_col_large_uneven_m() { + let (m, k, n) = (252, 256, 256); + let swap_lhs = [2, 3]; + let swap_rhs = [0, 0]; + let shape_lhs = [3, 2, k, m]; + let shape_rhs = [3, 2, k, n]; + same_as_reference_swapped_dims( + MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), + swap_lhs, + swap_rhs, + shape_lhs, + shape_rhs, + ); + } + + #[test] + fn swapped_rhs_row_col_large_uneven_n() { + let (m, k, n) = (256, 256, 252); + let swap_lhs = [0, 0]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, m, k]; + let shape_rhs = [3, 2, n, k]; + same_as_reference_swapped_dims( + MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), + swap_lhs, + swap_rhs, + shape_lhs, + 
shape_rhs, + ); + } + + #[test] + fn swapped_both_row_col_large_uneven_k() { + let (m, k, n) = (256, 252, 256); + let swap_lhs = [2, 3]; + let swap_rhs = [2, 3]; + let shape_lhs = [3, 2, k, m]; + let shape_rhs = [3, 2, n, k]; + same_as_reference_swapped_dims( + MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), + swap_lhs, + swap_rhs, + shape_lhs, + shape_rhs, + ); + } + + #[test] + fn swapped_row_with_batch_no_padding() { + let swap_lhs = [0, 3]; + let swap_rhs = [0, 2]; + let shape_lhs = [4, 4, 4, 4]; + let shape_rhs = [4, 4, 4, 4]; + same_as_reference_swapped_dims( + MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()), + swap_lhs, + swap_rhs, + shape_lhs, + shape_rhs, + ); + } + + fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference( + MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()), + shape_lhs, + shape_rhs, + ); + } + } + mod padding { use super::*; use burn_jit::kernel::matmul::padding::{crop, pad_round}; diff --git a/crates/burn-jit/src/tests/matmul_cube.rs b/crates/burn-jit/src/tests/matmul_cube.rs new file mode 100644 index 0000000000..8cea6a3990 --- /dev/null +++ b/crates/burn-jit/src/tests/matmul_cube.rs @@ -0,0 +1,125 @@ +#[burn_tensor_testgen::testgen(matmul_cube)] +mod tests { + use super::*; + use burn_jit::kernel::matmul::tiling2d_cube::{ + compute_loop_tests, load_shared_memory_tests, outer_product_tests, write_output_tests, + }; + use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig}; + use burn_tensor::{Shape, Tensor}; + + #[test] + pub fn tiling2d_matmul_outer_product_vectorized_test() { + outer_product_tests::tile_outer_product_vectorized_unit_test::( + &Default::default(), + ) + } + + #[test] + pub fn tiling2d_matmul_outer_product_vectorized_test_2() { + outer_product_tests::tile_outer_product_vectorized_unit_test_2::( + &Default::default(), + ) + } + + #[test] + pub fn 
tiling2d_matmul_compute_loop_vectorized_test() { + compute_loop_tests::compute_loop_unit_test::(&Default::default()) + } + + #[test] + pub fn compute_loop_unit_offset_test() { + compute_loop_tests::compute_loop_unit_offset_test::(&Default::default()) + } + + #[test] + pub fn load_lhs_transposed_unit_test() { + load_shared_memory_tests::load_lhs_transposed_unit_test::(&Default::default()) + } + + #[test] + pub fn load_lhs_transposed_cube_test() { + load_shared_memory_tests::load_lhs_transposed_cube_test::(&Default::default()) + } + + #[test] + pub fn load_lhs_plain_unit_test() { + load_shared_memory_tests::load_lhs_plain_unit_test::(&Default::default()) + } + + #[test] + pub fn load_lhs_plain_out_of_bounds_unit_test() { + load_shared_memory_tests::load_lhs_plain_out_of_bounds_unit_test::( + &Default::default(), + ) + } + + #[test] + pub fn load_lhs_transposed_out_of_bounds_cube_test() { + load_shared_memory_tests::load_lhs_transposed_out_of_bounds_cube_test::( + &Default::default(), + ) + } + + #[test] + pub fn load_lhs_transposed_offset_cube_test() { + load_shared_memory_tests::load_lhs_transposed_offset_cube_test::( + &Default::default(), + ) + } + + #[test] + pub fn load_rhs_plain_unit_test() { + load_shared_memory_tests::load_rhs_plain_unit_test::(&Default::default()) + } + + #[test] + pub fn load_rhs_plain_cube_test() { + load_shared_memory_tests::load_rhs_plain_cube_test::(&Default::default()) + } + + #[test] + pub fn load_rhs_plain_cube_offset_test() { + load_shared_memory_tests::load_rhs_plain_cube_offset_test::(&Default::default()) + } + + #[test] + pub fn load_rhs_transposed_unit_test() { + load_shared_memory_tests::load_rhs_transposed_unit_test::(&Default::default()) + } + + #[test] + pub fn load_rhs_transposed_out_of_bounds_unit_test() { + load_shared_memory_tests::load_rhs_transposed_out_of_bounds_unit_test::( + &Default::default(), + ) + } + + #[test] + pub fn write_to_output_over_height_unit_test() { + 
write_output_tests::write_to_output_over_height_unit_test::(&Default::default()) + } + + #[test] + pub fn write_to_output_over_width_unit_test() { + write_output_tests::write_to_output_over_width_unit_test::(&Default::default()) + } + + #[test] + pub fn write_to_output_vectorized_less_than_tile_unit_test() { + write_output_tests::write_to_output_vectorized_less_than_tile_unit_test::( + &Default::default(), + ) + } + + #[test] + pub fn write_to_output_scalar_unit_test() { + write_output_tests::write_to_output_scalar_unit_test::(&Default::default()) + } + + #[test] + pub fn write_to_output_scalar_out_of_bounds_cube_test() { + write_output_tests::write_to_output_scalar_out_of_bounds_cube_test::( + &Default::default(), + ) + } +} diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index 1561271b7f..d6ec2edf17 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -13,6 +13,7 @@ mod gather; mod mask_fill; mod mask_where; mod matmul; +pub mod matmul_cube; mod max_pool2d; mod max_pool2d_backward; mod normal; @@ -74,6 +75,7 @@ macro_rules! 
testgen_all { burn_jit::testgen_clamp!(); burn_jit::testgen_unary!(); burn_jit::testgen_matmul!(); + burn_jit::testgen_matmul_cube!(); } } mod jit_fusion { diff --git a/crates/burn-tensor/src/tests/ops/matmul.rs b/crates/burn-tensor/src/tests/ops/matmul.rs index fd9a031b0f..b897def7a3 100644 --- a/crates/burn-tensor/src/tests/ops/matmul.rs +++ b/crates/burn-tensor/src/tests/ops/matmul.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(matmul)] mod tests { use super::*; - use burn_tensor::{Tensor, TensorData}; + use burn_tensor::{Int, Tensor, TensorData}; #[test] fn test_matmul_d2() { @@ -81,7 +81,7 @@ mod tests { } #[test] - fn test_matmul_tmp() { + fn test_matmul_4_3() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats( [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]], @@ -98,6 +98,69 @@ mod tests { tensor_3.into_data().assert_eq(&expected, false); } + #[test] + fn test_matmul_trivial() { + let device = Default::default(); + + let tensor_1 = Tensor::::arange(0..16, &device) + .reshape([4, 4]) + .float(); + + let tensor_3 = tensor_1.clone().matmul(tensor_1); + + tensor_3.into_data().assert_eq( + &TensorData::from([ + [56., 62., 68., 74.], + [152., 174., 196., 218.], + [248., 286., 324., 362.], + [344., 398., 452., 506.], + ]), + false, + ); + } + + #[test] + fn test_matmul_trivial_transposed() { + let device = Default::default(); + + let tensor_1 = Tensor::::arange(0..16, &device) + .reshape([4, 4]) + .float(); + + let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); + + tensor_3.into_data().assert_eq( + &TensorData::from([ + [14., 38., 62., 86.], + [38., 126., 214., 302.], + [62., 214., 366., 518.], + [86., 302., 518., 734.], + ]), + false, + ); + } + + #[test] + fn test_matmul_4_8() { + let device = Default::default(); + + let tensor_1 = Tensor::::arange(0..32, &device) + .reshape([4, 8]) + .float(); + + let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); + + tensor_3.into_data().assert_eq( + 
&TensorData::from([ + [140., 364., 588., 812.], + [364., 1100., 1836., 2572.], + [588., 1836., 3084., 4332.], + [812., 2572., 4332., 6092.], + ]), + false, + ); + } + #[test] fn test_matmul_simple_2() { let device = Default::default();