Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify allocator #388

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open

Simplify allocator #388

wants to merge 12 commits into from

Conversation

maxtremblay
Copy link
Collaborator

@maxtremblay maxtremblay commented Dec 20, 2024

Highlights

  • The allocator implementation is simpler.
  • The benchmark results are similar.
  • I discovered and fixed a few bugs.

Allocator refactor

  • Remove the allocator trait and its two implementations in favor of a single allocator struct.
  • Rename the local variable creations to new named closer to rust naming scheme (example, use mut for mutable variables).
  • Rename the variant in VariableKind to match the changes.
  • Propagate the changes. This explains the large amount of modified files. A lot of it is renaming and changing to the new allocator.
  • Document to allocator struct.

Changing the backends

  • Rename the local variable variant of the Variable enum to match cubecl-core.
  • Add the depth field to LocalVariable.
  • All local immutable variables are now named l_{depth}_{id}.
  • All local mutable variables (including restricted) are now named l_mut_{depth}_{id}.

Allocation strategy

  • In a few expand functions, I replaced the use of a mutable variable with an immutable one.

Discovered bugs

With all the above changes, I discovered a few bugs in cubecl that I fixed.

  • Because the reusing allocator (used by cuda) was not doing any distinction between mutable and immutable variables, we often used immutable variables when we should have used mutable one. Mostly for accumulators in plane operations.
  • Fix an issue with line indexing where elements of the line where tagged with a vectorization > 1 while there are scalar. Before, we were lucky enough for this bug to have no effect (wgpu: inference let us avoid the type during declaration. cpp: we were using mutables everywhere before which were not causing an error, but were wastefully declared as vectorized).
  • There were some missing fmt_left in the cpp backend.
  • I added some tests to reproduce the bugs more easily.

Benchmarks

There are no statistically meaningful changes.

bench_cuda_main.txt
bench_cuda_new_alloc.txt
bench_wgpu_main.txt
bench_wgpu_new_alloc.txt

Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the frontend tests are necessary anymore. They are impacted by the new allocator and I don't think it's worth updating them.

@@ -206,7 +206,7 @@ pub mod plane_any {
elem: ExpandElementTyped<bool>,
) -> ExpandElementTyped<bool> {
let elem: ExpandElement = elem.into();
let output = context.create_local_binding(elem.item);
let output = context.create_local_mut(elem.item);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it only for cuda? Does mutability work differently in wgsl?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, it is an issue with the planeAny implementation of cuda. I will fix it.

crates/cubecl-core/src/ir/allocator.rs Show resolved Hide resolved
Comment on lines +374 to +376
let array_len = scope.create_local_mut(Item::new(Elem::UInt(UIntKind::U32)));
let inside_bound = scope.create_local_mut(Item::new(Elem::Bool));
let item = scope.create_local_mut(out.item);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mut necessary?

Comment on lines +393 to +394
let array_len = scope.create_local_mut(Item::new(Elem::UInt(UIntKind::U32)));
let inside_bound = scope.create_local_mut(Item::new(Elem::Bool));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mut necessary?

@@ -67,7 +67,7 @@ impl Scope {

/// Create a variable initialized at zero.
pub fn zero<I: Into<Item>>(&mut self, item: I) -> Variable {
let local = self.create_local(item);
let local = self.create_local_mut(item);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mut necessary?

@@ -88,7 +88,7 @@ impl Scope {
Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(value.to_u64().unwrap(), kind),
Elem::Bool => ConstantScalarValue::Bool(value.to_u32().unwrap() == 1),
};
let local = self.create_local(item);
let local = self.create_local_mut(item);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mut necessary?

@@ -208,7 +208,7 @@ macro_rules! impl_line_comparison {
let lhs = self.expand.into();
let rhs = rhs.expand.into();

let output = context.create_local_binding(Item::vectorized(bool::as_elem(), size));
let output = context.create_local_mut(Item::vectorized(bool::as_elem(), size));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this needs to be mut?

@@ -79,7 +79,7 @@ mod fill {
value: ExpandElementTyped<P>,
) -> Self {
let length = self.expand.item.vectorization;
let output = context.create_local_binding(Item::vectorized(P::as_elem(), length));
let output = context.create_local(Item::vectorized(P::as_elem(), length));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mut required?

@@ -170,7 +170,7 @@ where
F: Fn(UnaryOperator) -> Operator,
{
let input = input.consume();
let output = context.create_local_binding(out_item);
let output = context.create_local_mut(out_item);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mut required?

@@ -694,11 +694,12 @@ impl<D: Dialect> Remainder<D> {

write_op(&lhs, &rhs, &out_tmp, item_out_optimized)?;

let maybe_const = if out.is_const() { " cosnt" } else { "" };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: cosnt


fn cast<D: Dialect>(input: &Variable<D>, target: Item<D>) -> String {
if target != input.item() {
let maybe_const = if input.is_const() { " const" } else { "" };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a method for that, it's repeated at multiple places.

Comment on lines +67 to +68
// let mut file = std::fs::File::create("tests/plane_elect.wgsl").unwrap();
// write!(file, "{compiled}").unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deadcode

Comment on lines +92 to +93
// let mut file = std::fs::File::create("tests/sequence_for_loop.wgsl").unwrap();
// write!(file, "{compiled}").unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deadcode

Comment on lines +120 to +121
// let mut file = std::fs::File::create("tests/unary_bench.wgsl").unwrap();
// write!(file, "{compiled}").unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deadcode

Comment on lines +141 to +142
// let mut file = std::fs::File::create("tests/constant_array.wgsl").unwrap();
// write!(file, "{compiled}").unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deadcode

Comment on lines +161 to +162
// let mut file = std::fs::File::create("tests/naming.wgsl").unwrap();
// write!(file, "{compiled}").unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deadcode

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants