Skip to content

Commit

Permalink
handle non-contiguous backing when reading out MLXArray (#96)
Browse files Browse the repository at this point in the history
* handle non-contiguous backing when reading out MLXArray

- fixes #83
- mlx::core::array can have non-contiguous backing
- handle those cases and simplify the readout
  • Loading branch information
davidkoski authored Jun 5, 2024
1 parent d6d9472 commit 36d63a1
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 91 deletions.
232 changes: 141 additions & 91 deletions Source/MLX/MLXArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,35 +160,30 @@ public final class MLXArray {
/// let value = array[1].item(Float.self)
/// ```
public func item<T: HasDType>(_ type: T.Type) -> T {
self.eval()
precondition(self.size == 1)

var array_ctx = self.ctx
var free = false
if type.dtype != self.dtype {
array_ctx = mlx_astype(self.ctx, type.dtype.cmlxDtype, StreamOrDevice.default.ctx)
mlx_array_eval(array_ctx)
free = true
return self.asType(type).item(type)
}

// can't do it inside the else as it will free at the end of the block
defer { if free { mlx_free(array_ctx) } }
self.eval()

switch type {
case is Bool.Type: return mlx_array_item_bool(array_ctx) as! T
case is UInt8.Type: return mlx_array_item_uint8(array_ctx) as! T
case is UInt16.Type: return mlx_array_item_uint16(array_ctx) as! T
case is UInt32.Type: return mlx_array_item_uint32(array_ctx) as! T
case is UInt64.Type: return mlx_array_item_uint64(array_ctx) as! T
case is Int8.Type: return mlx_array_item_int8(array_ctx) as! T
case is Int16.Type: return mlx_array_item_int16(array_ctx) as! T
case is Int32.Type: return mlx_array_item_int32(array_ctx) as! T
case is Int64.Type: return mlx_array_item_int64(array_ctx) as! T
case is Int.Type: return Int(mlx_array_item_int64(array_ctx)) as! T
case is Bool.Type: return mlx_array_item_bool(self.ctx) as! T
case is UInt8.Type: return mlx_array_item_uint8(self.ctx) as! T
case is UInt16.Type: return mlx_array_item_uint16(self.ctx) as! T
case is UInt32.Type: return mlx_array_item_uint32(self.ctx) as! T
case is UInt64.Type: return mlx_array_item_uint64(self.ctx) as! T
case is Int8.Type: return mlx_array_item_int8(self.ctx) as! T
case is Int16.Type: return mlx_array_item_int16(self.ctx) as! T
case is Int32.Type: return mlx_array_item_int32(self.ctx) as! T
case is Int64.Type: return mlx_array_item_int64(self.ctx) as! T
case is Int.Type: return Int(mlx_array_item_int64(self.ctx)) as! T
#if !arch(x86_64)
case is Float16.Type: return mlx_array_item_float16(array_ctx) as! T
case is Float16.Type: return mlx_array_item_float16(self.ctx) as! T
#endif
case is Float32.Type: return mlx_array_item_float32(array_ctx) as! T
case is Float.Type: return mlx_array_item_float32(array_ctx) as! T
case is Float32.Type: return mlx_array_item_float32(self.ctx) as! T
case is Float.Type: return mlx_array_item_float32(self.ctx) as! T
case is Complex<Float32>.Type:
// mlx_array_item_complex64() isn't visible in swift so read the array
// contents
Expand Down Expand Up @@ -246,6 +241,109 @@ public final class MLXArray {
asType(T.dtype, stream: stream)
}

/// Return the dimension where the storage is contiguous.
///
/// If this returns 0 then the whole storage is contiguous. If it returns ndmin + 1 then none of it is contiguous.
func contiguousToDimension() -> Int {
let shape = self.shape
let strides = self.strides

var expectedStride = 1

for (dimension, (shape, stride)) in zip(shape, strides).enumerated().reversed() {
// as long as the actual strides match the expected (contiguous) strides
// the backing is contiguous in these dimensions
if stride != expectedStride {
return dimension + 1
}
expectedStride *= shape
}

return 0
}

/// Return the physical size of the backing (assuming it is evaluated) in elements
var physicalSize: Int {
// nbytes is the logical size of the input, not the physical size
return zip(self.shape, self.strides)
.map { Swift.abs($0.0 * $0.1) }
.max()
?? self.size
}

func copy(from: UnsafeRawBufferPointer, to output: UnsafeMutableRawBufferPointer) {
let contiguousDimension = self.contiguousToDimension()

if contiguousDimension == 0 {
// entire backing is contiguous
from.copyBytes(to: output)

} else {
// only part of the backing is contiguous (possibly a single element)
// iterate the non-contiguous parts and copy the contiguous chunks into
// the output.

// these are the parts to iterate
let shape = self.shape.prefix(upTo: contiguousDimension)
let strides = self.strides.prefix(upTo: contiguousDimension)
let ndim = contiguousDimension
let itemSize = self.itemSize

// the size of each chunk that we copy. this computes the stride of
// (contiguousDimension - 1) if it were contiguous
let destItemSize: Int
if contiguousDimension == self.ndim {
// nothing contiguous
destItemSize = itemSize
} else {
destItemSize =
self.strides[contiguousDimension] * self.shape[contiguousDimension] * itemSize
}

// the index of the current source item
var index = Array.init(repeating: 0, count: ndim)

// output pointer
var dest = output.baseAddress!

while true {
// compute the source index by multiplying the index by the
// stride for each dimension

// note: in the case where the array has negative strides / offset
// the base pointer we have will have the offset already applied,
// e.g. asStrided(a, [3, 3], strides: [-3, -1], offset: 8)

let sourceIndex = zip(index, strides).reduce(0) { $0 + ($1.0 * $1.1) }

// convert to byte pointer
let src = from.baseAddress! + sourceIndex * itemSize
dest.copyMemory(from: src, byteCount: destItemSize)

// next output address
dest += destItemSize

// increment the index
for dimension in Swift.stride(from: ndim - 1, through: 0, by: -1) {
// do we need to "carry" into the next dimension?
if index[dimension] == (shape[dimension] - 1) {
if dimension == 0 {
// all done
return
}

index[dimension] = 0
} else {
// just increment the dimension and we are done
index[dimension] += 1
break
}
}
}

}
}

/// Return the contents as a single contiguous 1d `Swift.Array`.
///
/// Note: because the number of dimensions is dynamic, this cannot produce a multi-dimensional
Expand All @@ -255,53 +353,17 @@ public final class MLXArray {
/// - <doc:conversion>
/// - ``asData(noCopy:)``
public func asArray<T: HasDType>(_ type: T.Type) -> [T] {
self.eval()

var array_ctx = self.ctx
var free = false
if type.dtype != self.dtype {
array_ctx = mlx_astype(self.ctx, type.dtype.cmlxDtype, StreamOrDevice.default.ctx)
mlx_array_eval(array_ctx)
free = true
return self.asType(type).asArray(type)
}

// can't do it inside the else as it will free at the end of the block
defer { if free { mlx_free(array_ctx) } }
self.eval()

func convert(_ ptr: UnsafePointer<T>) -> [T] {
Array(UnsafeBufferPointer(start: ptr, count: self.size))
}

switch type {
case is Bool.Type: return convert(mlx_array_data_bool(array_ctx) as! UnsafePointer<T>)
case is UInt8.Type: return convert(mlx_array_data_uint8(array_ctx) as! UnsafePointer<T>)
case is UInt16.Type: return convert(mlx_array_data_uint16(array_ctx) as! UnsafePointer<T>)
case is UInt32.Type: return convert(mlx_array_data_uint32(array_ctx) as! UnsafePointer<T>)
case is UInt64.Type: return convert(mlx_array_data_uint64(array_ctx) as! UnsafePointer<T>)
case is Int8.Type: return convert(mlx_array_data_int8(array_ctx) as! UnsafePointer<T>)
case is Int16.Type: return convert(mlx_array_data_int16(array_ctx) as! UnsafePointer<T>)
case is Int32.Type: return convert(mlx_array_data_int32(array_ctx) as! UnsafePointer<T>)
case is Int64.Type: return convert(mlx_array_data_int64(array_ctx) as! UnsafePointer<T>)
case is Int.Type:
// Int and Int64 are the same bits but distinct types. coerce pointers as needed
let pointer = mlx_array_data_int64(array_ctx)
let bufferPointer = UnsafeBufferPointer(start: pointer, count: self.size)
return bufferPointer.withMemoryRebound(to: Int.self) { buffer in
Array(buffer) as! [T]
}
#if !arch(x86_64)
case is Float16.Type:
return convert(mlx_array_data_float16(array_ctx) as! UnsafePointer<T>)
#endif
case is Float32.Type: return convert(mlx_array_data_float32(array_ctx) as! UnsafePointer<T>)
case is Float.Type: return convert(mlx_array_data_float32(array_ctx) as! UnsafePointer<T>)
case is Complex<Float32>.Type:
let ptr = UnsafeBufferPointer(
start: UnsafePointer<Complex<Float32>>(mlx_array_data_complex64(ctx)),
count: self.size)
return Array(ptr) as! [T]
default:
fatalError("Unable to get item() as \(type)")
return [T](unsafeUninitializedCapacity: self.size) { destination, initializedCount in
let source = UnsafeRawBufferPointer(
start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize)
copy(from: source, to: UnsafeMutableRawBufferPointer(destination))
initializedCount = self.size
}
}

Expand All @@ -317,34 +379,22 @@ public final class MLXArray {
public func asData(noCopy: Bool = false) -> Data {
self.eval()

func convert<T>(_ ptr: UnsafePointer<T>) -> Data {
if noCopy {
Data(
bytesNoCopy: UnsafeMutableRawPointer(mutating: ptr), count: self.nbytes,
deallocator: .none)
} else {
Data(buffer: UnsafeBufferPointer(start: ptr, count: self.size))
if noCopy && self.contiguousToDimension() == 0 {
// the backing is contiguous, we can provide a wrapper
// for the contents without a copy (if requested)
let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))!
return Data(
bytesNoCopy: source, count: self.nbytes,
deallocator: .none)
} else {
let source = UnsafeRawBufferPointer(
start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize)

var data = Data(count: self.nbytes)
data.withUnsafeMutableBytes { destination in
copy(from: source, to: destination)
}
}

switch self.dtype {
case .bool: return convert(mlx_array_data_bool(ctx))
case .uint8: return convert(mlx_array_data_uint8(ctx))
case .uint16: return convert(mlx_array_data_uint16(ctx))
case .uint32: return convert(mlx_array_data_uint32(ctx))
case .uint64: return convert(mlx_array_data_uint64(ctx))
case .int8: return convert(mlx_array_data_int8(ctx))
case .int16: return convert(mlx_array_data_int16(ctx))
case .int32: return convert(mlx_array_data_int32(ctx))
case .int64: return convert(mlx_array_data_int64(ctx))
#if !arch(x86_64)
case .float16: return convert(mlx_array_data_float16(ctx))
#endif
case .float32: return convert(mlx_array_data_float32(ctx))
case .complex64:
return convert(UnsafePointer<Complex<Float32>>(mlx_array_data_complex64(ctx)))
default:
fatalError("Unable to get asData() for \(self.dtype)")
return data
}
}

Expand Down
78 changes: 78 additions & 0 deletions Tests/MLXTests/MLXArrayTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,82 @@ class MLXArrayTests: XCTestCase {
XCTAssertEqual(a[1][2].item(Int.self), 5)
}

func testAsArrayContiguous() {
// read array from contiguous memory
let a = MLXArray(0 ..< 12, [4, 3])
let b = a.asArray(Int.self)
XCTAssertEqual(b, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
}

func testAsArrayNonContiguous1() {
// skipping elements via slicing
let a = MLXArray(0 ..< 9, [3, 3])

let s = a[0 ..< 2, 1 ..< 3]
assertEqual(s, MLXArray([1, 2, 4, 5], [2, 2]))

XCTAssertEqual(s.shape, [2, 2])

// size and nbytes are the logical size
XCTAssertEqual(s.size, 2 * 2)
XCTAssertEqual(s.nbytes, 2 * 2 * s.itemSize)

// internal property for counting the physical size of the backing.
// note that the physical size doesn't include the row that is
// sliced out
XCTAssertEqual(s.physicalSize, 3 * 2)

// evaluating s (the comparison above) will realize the strides.
// if we eamine these before they might be [2, 1] which are the
// "logical" strides
XCTAssertEqual(s.strides, [3, 1])

let s_arr = s.asArray(Int32.self)
XCTAssertEqual(s_arr, [1, 2, 4, 5])
}

func testAsArrayNonContiguous2() {
// a transpose via strides
let a = MLXArray(0 ..< 12, [4, 3])

let s = asStrided(a, [3, 4], strides: [1, 3])

let expected: [Int32] = [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
assertEqual(s, MLXArray(expected, [3, 4]))

// Note: be careful to use the matching type -- if we transcode
// to a different type it will be converted to contiguous
let s_arr = s.asArray(Int32.self)
XCTAssertEqual(s_arr, expected)
}

func testAsArrayNonContiguous3() {
// reversed via strides -- note that the base pointer for the
// storage has an offset applied to it
let a = MLXArray(0 ..< 9, [3, 3])

let s = asStrided(a, [3, 3], strides: [-3, -1], offset: 8)

let expected: [Int32] = [8, 7, 6, 5, 4, 3, 2, 1, 0]
assertEqual(s, MLXArray(expected, [3, 3]))

let s_arr = s.asArray(Int32.self)
XCTAssertEqual(s_arr, expected)
}

func testAsArrayNonContiguous4() {
// buffer with holes (last dimension has stride of 2 and
// thus larger storage than it physically needs)
let a = MLXArray(0 ..< 16, [4, 4])
let s = a[0..., .stride(by: 2)]

let expected: [Int32] = [0, 2, 4, 6, 8, 10, 12, 14]
assertEqual(s, MLXArray(expected, [4, 2]))

XCTAssertEqual(s.strides, [4, 2])

let s_arr = s.asArray(Int32.self)
XCTAssertEqual(s_arr, expected)
}

}

0 comments on commit 36d63a1

Please sign in to comment.