From dde82fc0beefecce8d013f9c8f019637e36df398 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 24 Jan 2025 09:44:54 +0900 Subject: [PATCH] Add webgpu backend --- CMakeLists.txt | 16 ++ examples/cpp/tutorial.cpp | 1 + mlx/CMakeLists.txt | 2 + mlx/array.cpp | 4 + mlx/array.h | 2 + mlx/backend/webgpu/CMakeLists.txt | 6 + mlx/backend/webgpu/allocator.cpp | 119 ++++++++++++++ mlx/backend/webgpu/allocator.h | 63 +++++++ mlx/backend/webgpu/metal.cpp | 102 ++++++++++++ mlx/backend/webgpu/primitives.cpp | 264 ++++++++++++++++++++++++++++++ 10 files changed, 579 insertions(+) create mode 100644 mlx/backend/webgpu/CMakeLists.txt create mode 100644 mlx/backend/webgpu/allocator.cpp create mode 100644 mlx/backend/webgpu/allocator.h create mode 100644 mlx/backend/webgpu/metal.cpp create mode 100644 mlx/backend/webgpu/primitives.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 58ccd0a60e..8e9bc00d76 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) +option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") endif() endif() + if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL) + message(FATAL_ERROR "Can not build both webgpu and metal backends.") + endif() + else() set(MLX_BUILD_METAL OFF) message(WARNING "MLX is prioritised for Apple silicon systems using macOS.") @@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL) target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) endif() +if(MLX_BUILD_WEBGPU) + FetchContent_Declare( + betann + GIT_REPOSITORY https://github.com/frost-beta/betann.git + GIT_TAG 77d0837879e6549f04ef37158000697c94fe6702 + EXCLUDE_FROM_ALL) + set(BETANN_BUILD_TESTS OFF) + FetchContent_MakeAvailable(betann) + target_link_libraries(mlx PRIVATE betann) +endif() + if(WIN32) if(MSVC) # GGUF does not build with MSVC. diff --git a/examples/cpp/tutorial.cpp b/examples/cpp/tutorial.cpp index ae2cd3cfbe..554d497316 100644 --- a/examples/cpp/tutorial.cpp +++ b/examples/cpp/tutorial.cpp @@ -10,6 +10,7 @@ namespace mx = mlx::core; void array_basics() { // Make a scalar array: mx::array x(1.0); + std::cout << x + x << std::endl; // Get the value out of it: auto s = x.item(); diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index c7ef4670f0..1ecdf9e793 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,6 +47,8 @@ endif() if(MLX_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) +elseif(MLX_BUILD_WEBGPU) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) endif() diff --git a/mlx/array.cpp b/mlx/array.cpp index c2edb4940d..452d365156 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -109,6 +109,10 @@ bool array::is_tracer() const { detail::retain_graph(); } +void array::reset_data_ptr() { + array_desc_->data_ptr = buffer().raw_ptr(); +} + void array::set_data(allocator::Buffer buffer, Deleter d) { array_desc_->data = std::make_shared(buffer, d); array_desc_->data_ptr = buffer.raw_ptr(); diff --git a/mlx/array.h b/mlx/array.h index 6ad0e578ac..bcde905c58 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -401,6 +401,8 @@ class array { // Check if the array is a tracer array bool is_tracer() const; + void reset_data_ptr(); + void set_data(allocator::Buffer buffer, Deleter d = allocator::free); void set_data( diff --git a/mlx/backend/webgpu/CMakeLists.txt b/mlx/backend/webgpu/CMakeLists.txt new file mode 100644 index 0000000000..6ef70d56f1 --- /dev/null +++ b/mlx/backend/webgpu/CMakeLists.txt @@ -0,0 +1,6 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp) diff --git a/mlx/backend/webgpu/allocator.cpp b/mlx/backend/webgpu/allocator.cpp new file mode 100644 index 0000000000..56830a1cc4 --- /dev/null +++ b/mlx/backend/webgpu/allocator.cpp @@ -0,0 +1,119 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/webgpu/allocator.h" + +namespace mlx::core { + +namespace allocator { + +Allocator& allocator() { + return webgpu::allocator(); +} + +void* Buffer::raw_ptr() { + return static_cast(ptr_)->cpu_data(); +} + +} // namespace allocator + +namespace webgpu { + +DoubleBuffer::DoubleBuffer(size_t size) + : cpu_data_(std::malloc(size + sizeof(size_t))) { + *static_cast(cpu_data_) = size; +} + +DoubleBuffer::DoubleBuffer(betann::Device& device, size_t size) + : gpu_data_(device.CreateBuffer( + size, + betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {} + +DoubleBuffer::~DoubleBuffer() { + std::free(cpu_data_); +} + +void DoubleBuffer::copy_to_cpu(const void* data, size_t size) { + assert(!cpu_data_); + cpu_data_ = std::malloc(size + sizeof(size_t)); + *static_cast(cpu_data_) = size; + std::memcpy(cpu_data(), data, size); +} + +size_t DoubleBuffer::size() const { + if (cpu_data_) + return *static_cast(cpu_data_); + if (gpu_data_) + return gpu_data_.GetSize(); + return 0; +} + +WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {} + +Buffer WgpuAllocator::malloc(size_t size, bool allow_swap) { + return Buffer(new DoubleBuffer(size)); +} + +void WgpuAllocator::free(Buffer buffer) { + delete static_cast(buffer.ptr()); +} + +size_t WgpuAllocator::size(Buffer buffer) const { + return static_cast(buffer.ptr())->size(); +} + +Buffer WgpuAllocator::gpu_malloc(size_t size) { + return Buffer(new DoubleBuffer(device_, size)); +} + +void WgpuAllocator::ensure_gpu_data(Buffer& buffer) { + auto* dbuf = static_cast(buffer.ptr()); + if (dbuf->gpu_data() || dbuf->size() == 0) + return; + dbuf->set_gpu_data(device_.CreateBufferFromData( + dbuf->cpu_data(), dbuf->size(), betann::BufferUsage::Storage)); +} + +WgpuAllocator& allocator() { + static WgpuAllocator allocator_; + return allocator_; +} + +betann::Device& device(mlx::core::Device) { + static betann::Device device; + return device; +} + +} // namespace webgpu + +namespace metal { + +size_t get_active_memory() { + return 0; +} +size_t get_peak_memory() { + return 0; +} +void reset_peak_memory() {} +size_t get_cache_memory() { + return 0; +} +size_t set_memory_limit(size_t, bool) { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} + +std::unordered_map> +device_info() { + throw std::runtime_error("[webgpu::device_info] Not implemented"); +}; + +void clear_cache() {} + +} // namespace metal + +} // namespace mlx::core diff --git a/mlx/backend/webgpu/allocator.h b/mlx/backend/webgpu/allocator.h new file mode 100644 index 0000000000..6033ea6101 --- /dev/null +++ b/mlx/backend/webgpu/allocator.h @@ -0,0 +1,63 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/device.h" + +#include + +namespace mlx::core::webgpu { + +using allocator::Buffer; + +// Holds data for both CPU and GPU. +class DoubleBuffer { + public: + // Allocates memory in CPU. + explicit DoubleBuffer(size_t size); + // Allocates memory in GPU. + DoubleBuffer(betann::Device& device, size_t size); + + ~DoubleBuffer(); + + void copy_to_cpu(const void* data, size_t size); + void set_gpu_data(betann::Buffer buffer) { + gpu_data_ = std::move(buffer); + } + + void* cpu_data() const { + return cpu_data_ ? static_cast(cpu_data_) + 1 : nullptr; + } + const betann::Buffer& gpu_data() const { + return gpu_data_; + } + + size_t size() const; + + private: + void* cpu_data_ = nullptr; + betann::Buffer gpu_data_; +}; + +class WgpuAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size, bool allow_swap = false) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + Buffer gpu_malloc(size_t size); + void ensure_gpu_data(Buffer& buffer); + + private: + WgpuAllocator(); + friend WgpuAllocator& allocator(); + + betann::Device& device_; +}; + +WgpuAllocator& allocator(); + +betann::Device& device(mlx::core::Device); + +} // namespace mlx::core::webgpu diff --git a/mlx/backend/webgpu/metal.cpp b/mlx/backend/webgpu/metal.cpp new file mode 100644 index 0000000000..a508dd954a --- /dev/null +++ b/mlx/backend/webgpu/metal.cpp @@ -0,0 +1,102 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/metal.h" +#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/webgpu/allocator.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" +#include "mlx/utils.h" + +namespace mlx::core::metal { + +bool is_available() { + return true; +} + +void new_stream(Stream) {} + +std::function make_task(array arr, bool signal) { + return [arr = std::move(arr), signal]() mutable { + auto s = arr.primitive().stream(); + auto& device = webgpu::device(s.device); + + for (auto& input : arr.inputs()) { + if (input.event().valid() && + input.event().stream() != arr.primitive().stream()) { + input.event().wait(); + } + // Ensure all inputs copy their CPU data to GPU. + webgpu::allocator().ensure_gpu_data(input.buffer()); + } + + auto outputs = arr.outputs(); + { + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + + try { + arr.primitive().eval_gpu(arr.inputs(), outputs); + } catch (const std::exception& error) { + abort_with_exception(error); + } + } + std::vector> buffers; + for (auto& in : arr.inputs()) { + buffers.push_back(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.push_back(s.data_shared_ptr()); + } + if (!arr.is_tracer()) { + arr.detach(); + } + for (auto& out : outputs) { + out.set_status(array::Status::evaluated); + } + + // Copy data from GPU to CPU. + // FIXME(zcbenz): Should only do it when necessary. + if (arr.data_shared_ptr()) { + auto* dbuf = static_cast(arr.buffer().ptr()); + if (dbuf->gpu_data() && !dbuf->cpu_data()) { + device.Flush(); + wgpu::Buffer staging = device.CopyToStagingBuffer(dbuf->gpu_data()); + device.Flush(); + device.ReadStagingBuffer( + staging, + [arr, dbuf, buffers = std::move(buffers)]( + const void* data) mutable { + dbuf->copy_to_cpu(data, dbuf->size()); + arr.reset_data_ptr(); + }); + } + } + + if (signal) { + device.Flush(); + device.WaitAll(); + arr.event().signal(); + } else { + device.OnSubmittedWorkDone([buffers = std::move(buffers)]() {}); + } + }; +} + +std::function make_synchronize_task( + Stream s, + std::shared_ptr> p) { + return [s, p = std::move(p)]() { + auto& device = webgpu::device(s.device); + device.WaitAll(); + p->set_value(); + }; +} + +void start_capture(std::string) {} +void stop_capture() {} + +} // namespace mlx::core::metal diff --git a/mlx/backend/webgpu/primitives.cpp b/mlx/backend/webgpu/primitives.cpp new file mode 100644 index 0000000000..27f5ce09a4 --- /dev/null +++ b/mlx/backend/webgpu/primitives.cpp @@ -0,0 +1,264 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/webgpu/allocator.h" +#include "mlx/distributed/primitives.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +#define BINARY_GPU(func, op) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + binary_op_gpu(inputs, out, op); \ + } + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no GPU implementation."); \ + } + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no GPU implementation."); \ + } + +namespace mlx::core { + +namespace { + +void set_binary_op_output_gpu_data( + betann::Device& device, + const array& a, + const array& b, + array& out, + BinaryOpType bopt) { + switch (bopt) { + case BinaryOpType::ScalarScalar: + out.set_data( + webgpu::allocator().gpu_malloc(out.itemsize()), + 1, + a.strides(), + a.flags()); + break; + case BinaryOpType::ScalarVector: + out.set_data( + webgpu::allocator().gpu_malloc(b.data_size() * out.itemsize()), + b.data_size(), + b.strides(), + b.flags()); + break; + case BinaryOpType::VectorScalar: + out.set_data( + webgpu::allocator().gpu_malloc(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + break; + case BinaryOpType::VectorVector: + out.set_data( + webgpu::allocator().gpu_malloc(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + break; + case BinaryOpType::General: + out.set_data(webgpu::allocator().gpu_malloc(out.nbytes())); + break; + } +} + +const char* dtype_to_wgsl(Dtype dtype) { + switch (dtype) { + case bool_: + return "bool"; + case int32: + return "i32"; + case uint32: + return "u32"; + case float32: + return "f32"; + case float16: + return "f16"; + default: + throw std::runtime_error("Unsupported dtype in WGSL."); + } +} + +template +std::vector to_u32_vector(const std::vector& vec) { + return std::vector(vec.begin(), vec.end()); +} + +const betann::Buffer& get_gpu_buffer(const array& arr) { + return static_cast(arr.buffer().ptr()) + ->gpu_data(); +} + +void binary_op_gpu( + const std::vector& inputs, + array& out, + const char* op) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + auto& device = webgpu::device(out.primitive().stream().device); + set_binary_op_output_gpu_data(device, a, b, out, bopt); + if (bopt == BinaryOpType::General) { + betann::BinaryOpGeneral( + device, + op, + dtype_to_wgsl(out.dtype()), + get_gpu_buffer(out), + to_u32_vector(a.shape()), + dtype_to_wgsl(a.dtype()), + get_gpu_buffer(a), + a.data_size(), + to_u32_vector(a.strides()), + get_gpu_buffer(b), + b.data_size(), + to_u32_vector(b.strides())); + } else { + betann::BinaryOpContiguous( + device, + op, + static_cast(bopt), + dtype_to_wgsl(out.dtype()), + get_gpu_buffer(out), + out.data_size(), + dtype_to_wgsl(a.dtype()), + get_gpu_buffer(a), + get_gpu_buffer(b)); + } +} + +} // namespace + +BINARY_GPU(Add, "add") +BINARY_GPU(ArcTan2, "arc_tan2") +BINARY_GPU(Divide, "divide") +BINARY_GPU(Remainder, "remainder") +BINARY_GPU(Equal, "equal") +BINARY_GPU(Greater, "greater") +BINARY_GPU(GreaterEqual, "greater_equal") +BINARY_GPU(Less, "less") +BINARY_GPU(LessEqual, "less_equal") +BINARY_GPU(LogicalAnd, "logical_and") +BINARY_GPU(LogicalOr, "logical_or") +BINARY_GPU(LogAddExp, "log_add_exp") +BINARY_GPU(Maximum, "maximum") +BINARY_GPU(Minimum, "minimum") +BINARY_GPU(Multiply, "multiply") +BINARY_GPU(NotEqual, "not_equal") +BINARY_GPU(Power, "power") +BINARY_GPU(Subtract, "subtract") + +NO_GPU(Abs) +NO_GPU(AddMM) +NO_GPU(Arange) +NO_GPU(ArcCos) +NO_GPU(ArcCosh) +NO_GPU(ArcSin) +NO_GPU(ArcSinh) +NO_GPU(ArcTan) +NO_GPU(ArcTanh) +NO_GPU(ArgPartition) +NO_GPU(ArgReduce) +NO_GPU(ArgSort) +NO_GPU(AsType) +NO_GPU(AsStrided) +NO_GPU(BitwiseBinary) +NO_GPU(BlockMaskedMM) +NO_GPU(Broadcast) +NO_GPU(BroadcastAxes) +NO_GPU(Ceil) +NO_GPU_MULTI(Compiled) +NO_GPU(Concatenate) +NO_GPU(Conjugate) +NO_GPU(Contiguous) +NO_GPU(Convolution) +NO_GPU(Copy) +NO_GPU(Cos) +NO_GPU(Cosh) +NO_GPU_MULTI(CustomTransforms) +NO_GPU_MULTI(Depends) +NO_GPU_MULTI(DivMod) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) +NO_GPU(NumberOfElements) +NO_GPU(Erf) +NO_GPU(ErfInv) +NO_GPU(Exp) +NO_GPU(ExpandDims) +NO_GPU(Expm1) +NO_GPU(FFT) +NO_GPU(Flatten) +NO_GPU(Floor) +NO_GPU(Full) +NO_GPU(Gather) +NO_GPU(GatherMM) +NO_GPU(GatherQMM) +NO_GPU(Hadamard) +NO_GPU(Imag) +NO_GPU(Load) +NO_GPU(Log) +NO_GPU(Log1p) +NO_GPU(LogicalNot) +NO_GPU(Matmul) +NO_GPU(Negative) +NO_GPU(Pad) +NO_GPU(Partition) +NO_GPU_MULTI(QRF) +NO_GPU(QuantizedMatmul) +NO_GPU(RandomBits) +NO_GPU(Real) +NO_GPU(Reduce) +NO_GPU(Reshape) +NO_GPU(Round) +NO_GPU(Scan) +NO_GPU(Scatter) +NO_GPU(Select) +NO_GPU(Sigmoid) +NO_GPU(Sign) +NO_GPU(Sin) +NO_GPU(Sinh) +NO_GPU(Slice) +NO_GPU(SliceUpdate) +NO_GPU(Softmax) +NO_GPU(Sort) +NO_GPU_MULTI(Split) +NO_GPU(Square) +NO_GPU(Squeeze) +NO_GPU(Sqrt) +NO_GPU(StopGradient) +NO_GPU_MULTI(SVD) +NO_GPU(Tan) +NO_GPU(Tanh) +NO_GPU(Transpose) +NO_GPU(Unflatten) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eigh) +NO_GPU(View) + +namespace fast { +NO_GPU_MULTI(LayerNorm) +NO_GPU_MULTI(LayerNormVJP) +NO_GPU_MULTI(RMSNorm) +NO_GPU_MULTI(RMSNormVJP) +NO_GPU_MULTI(RoPE) +NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(CustomKernel) +} // namespace fast + +namespace distributed { +NO_GPU_MULTI(AllReduce) +NO_GPU_MULTI(AllGather) +NO_GPU_MULTI(Send) +NO_GPU_MULTI(Recv) +} // namespace distributed + +} // namespace mlx::core