Skip to content

Commit 259aef7

Browse files
committed
Change extension on CoreML types to internal access level
Extract and consolidate CoreML extensions into separate file
1 parent ea69849 commit 259aef7

File tree

2 files changed

+47
-57
lines changed

2 files changed

+47
-57
lines changed

Sources/Generation/MLMultiArray+Utils.swift renamed to Sources/Generation/CoreML+Extensions.swift

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// MLMultiArray+Utils.swift
2+
// CoreML+Extensions.swift
33
// CoreMLBert
44
//
55
// Created by Julien Chaumond on 27/06/2019.
@@ -10,7 +10,7 @@
1010
import CoreML
1111
import Foundation
1212

13-
public extension MLMultiArray {
13+
extension MLMultiArray {
1414
/// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
1515
static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
1616
var shape = Array(repeating: 1, count: dims)
@@ -88,7 +88,7 @@ public extension MLMultiArray {
8888
}
8989
}
9090

91-
public extension MLMultiArray {
91+
extension MLMultiArray {
9292
/// Provides a way to index n-dimensionals arrays a la numpy.
9393
enum Indexing: Equatable {
9494
case select(Int)
@@ -197,4 +197,48 @@ extension MLMultiArray {
197197
return s + "]"
198198
}
199199
}
200+
201+
extension MLShapedArray<Float> {
202+
var floats: [Float] {
203+
guard strides.first == 1, strides.count == 1 else {
204+
// For some reason this path is slow.
205+
// If strides is not 1, we can write a Metal kernel to copy the values properly.
206+
return scalars
207+
}
208+
209+
// Fast path: memcpy
210+
let mlArray = MLMultiArray(self)
211+
return mlArray.floats ?? scalars
212+
}
213+
}
214+
215+
extension MLShapedArraySlice<Float> {
216+
var floats: [Float] {
217+
guard strides.first == 1, strides.count == 1 else {
218+
// For some reason this path is slow.
219+
// If strides is not 1, we can write a Metal kernel to copy the values properly.
220+
return scalars
221+
}
222+
223+
// Fast path: memcpy
224+
let mlArray = MLMultiArray(self)
225+
return mlArray.floats ?? scalars
226+
}
227+
}
228+
229+
extension MLMultiArray {
230+
var floats: [Float]? {
231+
guard dataType == .float32 else { return nil }
232+
233+
var result: [Float] = Array(repeating: 0, count: count)
234+
return withUnsafeBytes { ptr in
235+
guard let source = ptr.baseAddress else { return nil }
236+
result.withUnsafeMutableBytes { resultPtr in
237+
let dest = resultPtr.baseAddress!
238+
memcpy(dest, source, self.count * MemoryLayout<Float>.stride)
239+
}
240+
return result
241+
}
242+
}
243+
}
200244
#endif // canImport(CoreML)

Sources/Generation/MLShapedArray+Utils.swift

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)