Skip to content

Commit

Permalink
Add webgpu backend
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Jan 29, 2025
1 parent ccb61d7 commit dde82fc
Show file tree
Hide file tree
Showing 10 changed files with 579 additions and 0 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
Expand Down Expand Up @@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
endif()
endif()

if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
message(FATAL_ERROR "Can not build both webgpu and metal backends.")
endif()

else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
Expand Down Expand Up @@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(MLX_BUILD_WEBGPU)
FetchContent_Declare(
betann
GIT_REPOSITORY https://github.com/frost-beta/betann.git
GIT_TAG 77d0837879e6549f04ef37158000697c94fe6702
EXCLUDE_FROM_ALL)
set(BETANN_BUILD_TESTS OFF)
FetchContent_MakeAvailable(betann)
target_link_libraries(mlx PRIVATE betann)
endif()

if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
Expand Down
1 change: 1 addition & 0 deletions examples/cpp/tutorial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace mx = mlx::core;
void array_basics() {
// Make a scalar array:
mx::array x(1.0);
std::cout << x + x << std::endl;

// Get the value out of it:
auto s = x.item<float>();
Expand Down
2 changes: 2 additions & 0 deletions mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ endif()

# Select exactly one GPU backend: metal takes precedence, then webgpu;
# otherwise fall back to the stub no_metal backend.
if(MLX_BUILD_METAL)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
elseif(MLX_BUILD_WEBGPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()
4 changes: 4 additions & 0 deletions mlx/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ bool array::is_tracer() const {
detail::retain_graph();
}

// Re-syncs the cached data_ptr with the buffer's current raw pointer.
// Used after a backend rewrites the buffer's storage, e.g. when the
// webgpu backend copies GPU results into a fresh CPU block.
void array::reset_data_ptr() {
  array_desc_->data_ptr = buffer().raw_ptr();
}

void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
Expand Down
2 changes: 2 additions & 0 deletions mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;

void reset_data_ptr();

void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

void set_data(
Expand Down
6 changes: 6 additions & 0 deletions mlx/backend/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Compile the webgpu backend sources into the mlx target. event.cpp is
# borrowed from the no_metal backend. NOTE(review): confirm reusing the
# no_metal event implementation is intended rather than a placeholder.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)
119 changes: 119 additions & 0 deletions mlx/backend/webgpu/allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/webgpu/allocator.h"

#include <cassert>
#include <cstdlib>
#include <cstring>
#include <new>

namespace mlx::core {

namespace allocator {

// Routes the core allocator interface to the webgpu allocator singleton.
Allocator& allocator() {
  return webgpu::allocator();
}

// Exposes the CPU-visible pointer of the underlying DoubleBuffer.
// Returns nullptr when the data currently lives only on the GPU
// (DoubleBuffer::cpu_data() is null in that case).
void* Buffer::raw_ptr() {
  return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
}

} // namespace allocator

namespace webgpu {

// Allocates a CPU-side block with a leading size_t header that records
// the payload size (read back by size()).
DoubleBuffer::DoubleBuffer(size_t size)
    : cpu_data_(std::malloc(size + sizeof(size_t))) {
  // std::malloc returns nullptr on failure; the original code would
  // dereference it below. Fail loudly instead.
  if (!cpu_data_) {
    throw std::bad_alloc();
  }
  *static_cast<size_t*>(cpu_data_) = size;
}

// Allocates memory in GPU: a storage buffer that is also CopySrc so the
// result can later be staged back to the CPU (see metal.cpp read-back).
DoubleBuffer::DoubleBuffer(betann::Device& device, size_t size)
    : gpu_data_(device.CreateBuffer(
          size,
          betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}

// Frees the CPU block; std::free(nullptr) is a no-op for GPU-only
// buffers. gpu_data_ is released by its own member destructor.
DoubleBuffer::~DoubleBuffer() {
  std::free(cpu_data_);
}

// Copies `size` bytes of host data into a freshly allocated CPU block
// (with the size_t size header). Must only be called when no CPU copy
// exists yet (asserted).
void DoubleBuffer::copy_to_cpu(const void* data, size_t size) {
  assert(!cpu_data_);
  cpu_data_ = std::malloc(size + sizeof(size_t));
  // Guard the allocation: the original code wrote through a potentially
  // null pointer on OOM.
  if (!cpu_data_) {
    throw std::bad_alloc();
  }
  *static_cast<size_t*>(cpu_data_) = size;
  std::memcpy(cpu_data(), data, size);
}

// Size in bytes of this buffer: the CPU header takes precedence when a
// host copy exists, otherwise the GPU buffer's size (0 if neither).
size_t DoubleBuffer::size() const {
  if (cpu_data_) {
    return *static_cast<size_t*>(cpu_data_);
  }
  return gpu_data_ ? gpu_data_.GetSize() : 0;
}

// Binds the allocator to the shared betann device (see webgpu::device).
WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}

// Allocates CPU-backed storage; GPU upload happens lazily through
// ensure_gpu_data(). `allow_swap` is part of the Allocator interface
// and is ignored here.
Buffer WgpuAllocator::malloc(size_t size, bool allow_swap) {
  return Buffer(new DoubleBuffer(size));
}

// Destroys the DoubleBuffer created by malloc()/gpu_malloc().
void WgpuAllocator::free(Buffer buffer) {
  delete static_cast<DoubleBuffer*>(buffer.ptr());
}

// Reports the payload size tracked by the DoubleBuffer.
size_t WgpuAllocator::size(Buffer buffer) const {
  return static_cast<DoubleBuffer*>(buffer.ptr())->size();
}

// Allocates GPU-only storage (no CPU copy until the result is read back).
Buffer WgpuAllocator::gpu_malloc(size_t size) {
  return Buffer(new DoubleBuffer(device_, size));
}

// Lazily uploads a buffer's CPU contents to the GPU. No-op when GPU
// data already exists or the buffer is empty.
void WgpuAllocator::ensure_gpu_data(Buffer& buffer) {
  auto* dbuf = static_cast<DoubleBuffer*>(buffer.ptr());
  if (dbuf->gpu_data() || dbuf->size() == 0)
    return;
  dbuf->set_gpu_data(device_.CreateBufferFromData(
      dbuf->cpu_data(), dbuf->size(), betann::BufferUsage::Storage));
}

// Process-wide allocator singleton (function-local static, thread-safe
// initialization per the C++11 rules).
WgpuAllocator& allocator() {
  static WgpuAllocator allocator_;
  return allocator_;
}

// Returns the single global betann device. The Device argument is
// deliberately unnamed and ignored: all streams share one GPU device.
betann::Device& device(mlx::core::Device) {
  static betann::Device device;
  return device;
}

} // namespace webgpu

namespace metal {

// Memory-introspection stubs: the webgpu backend does not track
// allocation statistics, so the limits/counters all report zero and the
// mutating calls are no-ops.
size_t get_active_memory() {
  return 0;
}
size_t get_peak_memory() {
  return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
  return 0;
}
size_t set_memory_limit(size_t, bool) {
  return 0;
}
size_t set_cache_limit(size_t) {
  return 0;
}
size_t set_wired_limit(size_t) {
  return 0;
}

std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
  throw std::runtime_error("[webgpu::device_info] Not implemented");
}  // fixed: stray ';' after the function body removed

void clear_cache() {}

}  // namespace metal

} // namespace mlx::core
63 changes: 63 additions & 0 deletions mlx/backend/webgpu/allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/device.h"

#include <betann/betann.h>

namespace mlx::core::webgpu {

using allocator::Buffer;

// Holds data for both CPU and GPU.
//
// CPU layout: [size_t size][payload...] — cpu_data() returns a pointer
// past the header; size() reads the header when a CPU copy exists.
class DoubleBuffer {
 public:
  // Allocates memory in CPU.
  explicit DoubleBuffer(size_t size);
  // Allocates memory in GPU.
  DoubleBuffer(betann::Device& device, size_t size);

  ~DoubleBuffer();

  // Owns a raw heap block (cpu_data_) freed in the destructor: a copy
  // would double-free it, so copying is disabled (Rule of Five).
  DoubleBuffer(const DoubleBuffer&) = delete;
  DoubleBuffer& operator=(const DoubleBuffer&) = delete;

  // Copies `size` bytes of host data into a new CPU block; callers must
  // ensure no CPU copy exists yet.
  void copy_to_cpu(const void* data, size_t size);
  void set_gpu_data(betann::Buffer buffer) {
    gpu_data_ = std::move(buffer);
  }

  // CPU pointer past the size header, or nullptr when GPU-only.
  void* cpu_data() const {
    return cpu_data_ ? static_cast<size_t*>(cpu_data_) + 1 : nullptr;
  }
  const betann::Buffer& gpu_data() const {
    return gpu_data_;
  }

  // Payload size in bytes (CPU header if present, else GPU buffer size).
  size_t size() const;

 private:
  void* cpu_data_ = nullptr;
  betann::Buffer gpu_data_;
};

// Allocator implementation backed by DoubleBuffer: host-visible
// allocations by default, with explicit GPU allocation and lazy upload.
class WgpuAllocator : public allocator::Allocator {
 public:
  // Allocates CPU-backed storage; `allow_swap` is ignored.
  Buffer malloc(size_t size, bool allow_swap = false) override;
  void free(Buffer buffer) override;
  size_t size(Buffer buffer) const override;

  // Allocates GPU-only storage.
  Buffer gpu_malloc(size_t size);
  // Uploads `buffer`'s CPU data to the GPU if not already resident.
  void ensure_gpu_data(Buffer& buffer);

 private:
  // Construction is restricted to the allocator() singleton accessor.
  WgpuAllocator();
  friend WgpuAllocator& allocator();

  betann::Device& device_;
};

WgpuAllocator& allocator();

betann::Device& device(mlx::core::Device);

} // namespace mlx::core::webgpu
102 changes: 102 additions & 0 deletions mlx/backend/webgpu/metal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright © 2023-2024 Apple Inc.

#include <stdexcept>

#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/webgpu/allocator.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

namespace mlx::core::metal {

// The webgpu backend implements the metal interface, so the "metal"
// backend reports itself as available here.
bool is_available() {
  return true;
}

// No per-stream state is needed; all streams share the global device.
void new_stream(Stream) {}

// Builds the closure that evaluates `arr` on the webgpu device. When
// `signal` is set, the task flushes, waits for the device to go idle,
// and signals the array's event so cross-stream waiters can proceed.
std::function<void()> make_task(array arr, bool signal) {
  return [arr = std::move(arr), signal]() mutable {
    auto s = arr.primitive().stream();
    auto& device = webgpu::device(s.device);

    for (auto& input : arr.inputs()) {
      // Wait on producers that ran on a different stream before reading
      // their outputs.
      if (input.event().valid() &&
          input.event().stream() != arr.primitive().stream()) {
        input.event().wait();
      }
      // Ensure all inputs copy their CPU data to GPU.
      webgpu::allocator().ensure_gpu_data(input.buffer());
    }

    auto outputs = arr.outputs();
    {
      // When tracing, hold an extra reference to the inputs for the
      // duration of eval_gpu (released at the end of this scope).
      std::vector<array> inputs;
      if (arr.is_tracer()) {
        inputs = arr.inputs();
      }

      try {
        arr.primitive().eval_gpu(arr.inputs(), outputs);
      } catch (const std::exception& error) {
        abort_with_exception(error);
      }
    }
    // Collect input/sibling data so it stays alive until the GPU work
    // that reads it has completed (captured by the lambdas below).
    std::vector<std::shared_ptr<array::Data>> buffers;
    for (auto& in : arr.inputs()) {
      buffers.push_back(in.data_shared_ptr());
    }
    for (auto& s : arr.siblings()) {
      buffers.push_back(s.data_shared_ptr());
    }
    if (!arr.is_tracer()) {
      arr.detach();
    }
    for (auto& out : outputs) {
      out.set_status(array::Status::evaluated);
    }

    // Copy data from GPU to CPU.
    // FIXME(zcbenz): Should only do it when necessary.
    if (arr.data_shared_ptr()) {
      auto* dbuf = static_cast<webgpu::DoubleBuffer*>(arr.buffer().ptr());
      if (dbuf->gpu_data() && !dbuf->cpu_data()) {
        // Stage the GPU result into a readable buffer, then copy it
        // asynchronously into a fresh CPU block and repoint the array.
        device.Flush();
        wgpu::Buffer staging = device.CopyToStagingBuffer(dbuf->gpu_data());
        device.Flush();
        device.ReadStagingBuffer(
            staging,
            [arr, dbuf, buffers = std::move(buffers)](
                const void* data) mutable {
              dbuf->copy_to_cpu(data, dbuf->size());
              arr.reset_data_ptr();
            });
      }
    }

    if (signal) {
      // Block until all submitted work finishes, then signal waiters.
      device.Flush();
      device.WaitAll();
      arr.event().signal();
    } else {
      // Extend the lifetime of `buffers` until submitted work is done.
      // NOTE(review): if the read-back branch above ran, `buffers` was
      // already moved into ReadStagingBuffer's callback, so this
      // captures an empty (moved-from) vector — confirm intended.
      device.OnSubmittedWorkDone([buffers = std::move(buffers)]() {});
    }
  };
}

// Builds a task that blocks until all work on the device backing `s`
// has finished, then fulfills the promise to wake synchronize() callers.
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  return [s, promise = std::move(p)]() {
    webgpu::device(s.device).WaitAll();
    promise->set_value();
  };
}

// GPU frame capture is a metal-only debugging feature; no-ops here.
void start_capture(std::string) {}
void stop_capture() {}

} // namespace mlx::core::metal
Loading

0 comments on commit dde82fc

Please sign in to comment.