Skip to content

Commit

Permalink
move to v0.18.0 of mlx (#137)
Browse files Browse the repository at this point in the history
* move to v0.18.0 of mlx

- https://github.com/ml-explore/mlx-c v0.0.10
- https://github.com/ml-explore/mlx/compare/v0.16.0... v0.18.0

* turn on additional swift 6 concurrency checks and fix issues
* adopt new mlx_optional_*

Co-authored-by: Awni Hannun <[email protected]>
  • Loading branch information
davidkoski and awni authored Oct 3, 2024
1 parent 0bd59e8 commit 78a7cfe
Show file tree
Hide file tree
Showing 57 changed files with 2,962 additions and 841 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "v0.0.9")
GIT_TAG "v0.0.10")
FetchContent_MakeAvailable(mlx-c)

# swift-numerics
Expand Down
33 changes: 27 additions & 6 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,31 +148,52 @@ let package = Package(
dependencies: [
"Cmlx",
.product(name: "Numerics", package: "swift-numerics"),
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
name: "MLXRandom",
dependencies: ["MLX"]
dependencies: ["MLX"],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
name: "MLXFast",
dependencies: ["MLX", "Cmlx"]
dependencies: ["MLX", "Cmlx"],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
name: "MLXNN",
dependencies: ["MLX", "MLXRandom", "MLXFast"]
dependencies: ["MLX", "MLXRandom", "MLXFast"],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
name: "MLXOptimizers",
dependencies: ["MLX", "MLXNN"]
dependencies: ["MLX", "MLXNN"],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
name: "MLXFFT",
dependencies: ["MLX"]
dependencies: ["MLX"],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.target(
name: "MLXLinalg",
dependencies: ["MLX"]
dependencies: ["MLX"],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),

.testTarget(
Expand Down
1 change: 0 additions & 1 deletion Plugins/PrepareMetalShaders/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ struct PrepareMetalShaders: BuildToolPlugin {
"arg_reduce.metal",
"conv.metal",
"gemv.metal",
"gemv_masked.metal",
"random.metal",
"rms_norm.metal",
"layer_norm.metal",
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
Submodule mlx updated 252 files
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx-c
Submodule mlx-c updated 69 files
+6 −4 CMakeLists.txt
+1 −1 docs/src/conf.py
+1 −1 docs/src/distributed_ops.rst
+0 −5 docs/src/future.rst
+4 −3 docs/src/index.rst
+3 −3 docs/src/map.rst
+5 −0 docs/src/optional.rst
+6 −13 docs/src/overview.rst
+5 −0 docs/src/tuple.rst
+5 −0 docs/src/variant.rst
+0 −5 docs/src/vecarray.rst
+5 −0 docs/src/vector.rst
+0 −5 docs/src/vecvecarray.rst
+35 −15 examples/example-grad.c
+21 −0 examples/example.c
+0 −84 mlx/c/array.cpp
+0 −88 mlx/c/array.h
+437 −41 mlx/c/closure.cpp
+119 −34 mlx/c/closure.h
+24 −2 mlx/c/compile.cpp
+7 −1 mlx/c/compile.h
+45 −6 mlx/c/distributed.cpp
+26 −6 mlx/c/distributed.h
+53 −3 mlx/c/fast.cpp
+25 −2 mlx/c/fast.h
+2 −1 mlx/c/fft.cpp
+0 −1 mlx/c/fft.h
+0 −15 mlx/c/future.cpp
+0 −29 mlx/c/future.h
+2 −1 mlx/c/io.cpp
+0 −1 mlx/c/io.h
+17 −2 mlx/c/linalg.cpp
+5 −2 mlx/c/linalg.h
+89 −0 mlx/c/map.cpp
+67 −2 mlx/c/map.h
+6 −1 mlx/c/metal.cpp
+1 −1 mlx/c/metal.h
+13 −0 mlx/c/mlx.h
+100 −2 mlx/c/ops.cpp
+54 −2 mlx/c/ops.h
+41 −0 mlx/c/optional.h
+0 −32 mlx/c/private/array.h
+72 −2 mlx/c/private/closure.h
+0 −19 mlx/c/private/future.h
+27 −0 mlx/c/private/map.h
+84 −0 mlx/c/private/tuple.h
+31 −20 mlx/c/private/utils.h
+37 −0 mlx/c/private/variant.h
+149 −0 mlx/c/private/vector.h
+19 −2 mlx/c/random.cpp
+9 −2 mlx/c/random.h
+24 −9 mlx/c/transforms.cpp
+10 −3 mlx/c/transforms.h
+5 −4 mlx/c/transforms_impl.cpp
+1 −2 mlx/c/transforms_impl.h
+153 −0 mlx/c/tuple.cpp
+111 −0 mlx/c/tuple.h
+93 −0 mlx/c/variant.cpp
+66 −0 mlx/c/variant.h
+374 −0 mlx/c/vector.cpp
+176 −0 mlx/c/vector.h
+159 −17 python/c.py
+318 −0 python/closure_generator.py
+43 −35 python/generator.py
+295 −0 python/map_generator.py
+90 −0 python/mlxtypes.py
+246 −0 python/tuple_generator.py
+241 −0 python/variant_generator.py
+272 −0 python/vector_generator.py
66 changes: 45 additions & 21 deletions Source/Cmlx/mlx-generated/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,36 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
Expand All @@ -57,7 +87,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
Expand All @@ -72,25 +102,10 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
Expand All @@ -101,9 +116,18 @@ template <typename T, typename U, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;
idx.y += b_xstride;
}
}
)preamble";
}
Expand Down
82 changes: 56 additions & 26 deletions Source/Cmlx/mlx-generated/binary_two.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,45 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
Expand All @@ -73,7 +112,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
Expand All @@ -91,30 +130,12 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
Expand All @@ -126,11 +147,20 @@ template <typename T, typename U, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx++] = out[1];
idx.x += a_xstride;
idx.y += b_xstride;
}
}
)preamble";
}
Expand Down
28 changes: 16 additions & 12 deletions Source/Cmlx/mlx-generated/compiled_preamble.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ return R"preamble(
# 1 "Source/Cmlx/mlx/mlx/backend/common/compiled_preamble.h"
# 1 "<built-in>" 1
# 1 "<built-in>" 3
# 418 "<built-in>" 3
# 424 "<built-in>" 3
# 1 "<command line>" 1
# 1 "<built-in>" 2
# 1 "Source/Cmlx/mlx/mlx/backend/common/compiled_preamble.h" 2
Expand All @@ -22,35 +22,35 @@ return R"preamble(
# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/arm_fp16.h" 1 3
# 27 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/arm_fp16.h" 3
# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 1 3
# 96 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/arm_fp16.h" 1 3
# 27 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/arm_fp16.h" 3
# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 1 3
# 96 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef long long int int64_t;
typedef long long unsigned int uint64_t;
# 118 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 118 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef int64_t int_least64_t;
typedef uint64_t uint_least64_t;
typedef int64_t int_fast64_t;
typedef uint64_t uint_fast64_t;
# 193 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 193 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef int int32_t;
typedef unsigned int uint32_t;
# 216 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 216 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef int32_t int_least32_t;
typedef uint32_t uint_least32_t;
typedef int32_t int_fast32_t;
typedef uint32_t uint_fast32_t;
# 241 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 241 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef short int16_t;
typedef unsigned short uint16_t;
# 255 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 255 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef int16_t int_least16_t;
typedef uint16_t uint_least16_t;
typedef int16_t int_fast16_t;
Expand All @@ -74,7 +74,7 @@ typedef int8_t int_least8_t;
typedef uint8_t uint_least8_t;
typedef int8_t int_fast8_t;
typedef uint8_t uint_fast8_t;
# 291 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3
# 291 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3
typedef long int intptr_t;
Expand All @@ -90,7 +90,7 @@ typedef long unsigned int uintptr_t;
typedef long int intmax_t;
typedef long unsigned int uintmax_t;
# 28 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/arm_fp16.h" 2 3
# 28 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/arm_fp16.h" 2 3
typedef __fp16 float16_t;
# 7 "Source/Cmlx/mlx/mlx/types/half_types.h" 2
Expand Down Expand Up @@ -671,6 +671,10 @@ struct Sign {
uint64_t operator()(uint64_t x) {
return x != 0;
}
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
};
struct Sin {
Expand Down
4 changes: 2 additions & 2 deletions Source/Cmlx/mlx-generated/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ struct Conv2DWeightBlockLoader {
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_ -> wt_strides[0]),
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
Expand Down Expand Up @@ -581,7 +581,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_ -> wt_strides[0]),
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
Expand Down
Loading

0 comments on commit 78a7cfe

Please sign in to comment.