|
1 | 1 | // |
2 | | -// MLMultiArray+Utils.swift |
| 2 | +// CoreML+Extensions.swift |
3 | 3 | // CoreMLBert |
4 | 4 | // |
5 | 5 | // Created by Julien Chaumond on 27/06/2019. |
|
10 | 10 | import CoreML |
11 | 11 | import Foundation |
12 | 12 |
|
13 | | -public extension MLMultiArray { |
| 13 | +extension MLMultiArray { |
14 | 14 | /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) |
15 | 15 | static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { |
16 | 16 | var shape = Array(repeating: 1, count: dims) |
@@ -88,7 +88,7 @@ public extension MLMultiArray { |
88 | 88 | } |
89 | 89 | } |
90 | 90 |
|
91 | | -public extension MLMultiArray { |
| 91 | +extension MLMultiArray { |
92 | 92 | /// Provides a way to index n-dimensionals arrays a la numpy. |
93 | 93 | enum Indexing: Equatable { |
94 | 94 | case select(Int) |
@@ -197,4 +197,48 @@ extension MLMultiArray { |
197 | 197 | return s + "]" |
198 | 198 | } |
199 | 199 | } |
| 200 | + |
| 201 | +extension MLShapedArray<Float> { |
| 202 | + var floats: [Float] { |
| 203 | + guard strides.first == 1, strides.count == 1 else { |
| 204 | + // For some reason this path is slow. |
| 205 | + // If strides is not 1, we can write a Metal kernel to copy the values properly. |
| 206 | + return scalars |
| 207 | + } |
| 208 | + |
| 209 | + // Fast path: memcpy |
| 210 | + let mlArray = MLMultiArray(self) |
| 211 | + return mlArray.floats ?? scalars |
| 212 | + } |
| 213 | +} |
| 214 | + |
| 215 | +extension MLShapedArraySlice<Float> { |
| 216 | + var floats: [Float] { |
| 217 | + guard strides.first == 1, strides.count == 1 else { |
| 218 | + // For some reason this path is slow. |
| 219 | + // If strides is not 1, we can write a Metal kernel to copy the values properly. |
| 220 | + return scalars |
| 221 | + } |
| 222 | + |
| 223 | + // Fast path: memcpy |
| 224 | + let mlArray = MLMultiArray(self) |
| 225 | + return mlArray.floats ?? scalars |
| 226 | + } |
| 227 | +} |
| 228 | + |
| 229 | +extension MLMultiArray { |
| 230 | + var floats: [Float]? { |
| 231 | + guard dataType == .float32 else { return nil } |
| 232 | + |
| 233 | + var result: [Float] = Array(repeating: 0, count: count) |
| 234 | + return withUnsafeBytes { ptr in |
| 235 | + guard let source = ptr.baseAddress else { return nil } |
| 236 | + result.withUnsafeMutableBytes { resultPtr in |
| 237 | + let dest = resultPtr.baseAddress! |
| 238 | + memcpy(dest, source, self.count * MemoryLayout<Float>.stride) |
| 239 | + } |
| 240 | + return result |
| 241 | + } |
| 242 | + } |
| 243 | +} |
200 | 244 | #endif // canImport(CoreML) |
0 commit comments