Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 43 additions & 39 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -308,45 +308,49 @@ public extension TextDecoding {
return kvCache
}

/// Copies the key and value slices into their cache tensors in place, starting at `index`.
///
/// For every `j` in `0..<keyTensor.shape[1]` and every `k` in `0..<keySlice.shape[3]`, one
/// `FloatType`-sized sample is copied from `slice[0, j, 0, k]` to `tensor[0, j, 0, index + k]`,
/// once for the key pair and once for the value pair.
///
/// - Parameters:
///   - keyTensor: Key cache tensor, written in place.
///   - keySlice: Source of the key samples.
///   - valueTensor: Value cache tensor, written in place.
///   - valueSlice: Source of the value samples; it is addressed with `keySlice`'s strides.
///   - index: Position along dimension 3 of the tensors that receives slice column 0.
///
/// - Note: `index + k` is not bounds-checked against the tensor shape and the buffer base
///   addresses are force-unwrapped. NOTE(review): presumably callers guarantee the slice
///   fits inside the cache and that all four arrays store `FloatType` — confirm.
static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray,
valueTensor: MLMultiArray, valueSlice: MLMultiArray,
insertAtIndex index: Int)
{
let tensorShape = keyTensor.shape.map { $0.intValue }
let sliceShape = keySlice.shape.map { $0.intValue }
let sliceStrides = keySlice.strides.map { $0.intValue } // reused for valueSlice; valueSlice.strides is never read
let bytesPerSample = MemoryLayout<FloatType>.size

// All four buffers are pinned for the duration of the copy; the mutable accessors also
// hand back the strides of the destination tensors, which are used for the offset math.
keyTensor.withUnsafeMutableBytes { keyTensorPointer, keyTargetStrides in
keySlice.withUnsafeBytes { keySlicePointer in
valueTensor.withUnsafeMutableBytes { valueTensorPointer, valueTargetStrides in
valueSlice.withUnsafeBytes { valueSlicePointer in
// Assuming batch size is always 1
// One concurrent iteration per index `j` along dimension 1; each iteration writes
// only to offsets derived from its own `j`.
DispatchQueue.concurrentPerform(iterations: tensorShape[1]) { j in
// Slice size is 3 for prefill and 1 for decode loops
for k in 0..<sliceShape[3] {
// Equivalent to:
// `tensor[0, j, 0, index + k] = slice[0, j, 0, k]`
// Strides are multiplied by `bytesPerSample` to turn them into byte offsets.
let keyDestIndex = j * keyTargetStrides[1] + (index + k) * keyTargetStrides[3]
let keyDest = keyTensorPointer.baseAddress! + keyDestIndex * bytesPerSample

let keySliceIndex = j * sliceStrides[1] + k * sliceStrides[3]
let keySlice = keySlicePointer.baseAddress! + keySliceIndex * bytesPerSample
memcpy(keyDest, keySlice, bytesPerSample)

// Same copy for the value pair, using the value tensor's own destination strides.
let valDestIndex = j * valueTargetStrides[1] + (index + k) * valueTargetStrides[3]
let valDest = valueTensorPointer.baseAddress! + valDestIndex * bytesPerSample

let valSliceIndex = j * sliceStrides[1] + k * sliceStrides[3]
let valSlice = valueSlicePointer.baseAddress! + valSliceIndex * bytesPerSample
memcpy(valDest, valSlice, bytesPerSample)
}
}
}
}
}
}
}
/// Writes the key and value slices into their cache tensors in place, starting at `index`.
///
/// For every `j` in `0..<keyTensor.shape[1]` and every `k` in `0..<keySlice.shape[3]`, this
/// performs `tensor[0, j, 0, index + k] = slice[0, j, 0, k]` for the key pair and the value
/// pair. Only the touched samples are written; the rest of each cache tensor is left as is.
///
/// Offsets are computed from each array's own strides, so padded or otherwise
/// non-contiguous buffers are addressed correctly.
///
/// - Parameters:
///   - keyTensor: Key cache tensor, updated in place.
///   - keySlice: Source of the key samples.
///   - valueTensor: Value cache tensor, updated in place.
///   - valueSlice: Source of the value samples.
///   - index: Position along dimension 3 of the cache tensors that receives slice column 0.
///     Slice columns that would land outside `0..<keyTensor.shape[3]` are skipped.
///
/// - Note: Samples are copied as `FloatType`-sized units. NOTE(review): assumes all four
///   arrays store `FloatType` and that only batch entry 0 is used — confirm against callers.
static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray,
                          valueTensor: MLMultiArray, valueSlice: MLMultiArray,
                          insertAtIndex index: Int)
{
    let tensorShape = keyTensor.shape.map { $0.intValue }
    let sliceShape = keySlice.shape.map { $0.intValue }
    let keySliceStrides = keySlice.strides.map { $0.intValue }
    let valueSliceStrides = valueSlice.strides.map { $0.intValue }
    let bytesPerSample = MemoryLayout<FloatType>.size

    let hiddenDim = tensorShape[1]
    let seqLength = tensorShape[3]
    let sliceLength = sliceShape[3]

    // The scoped accessors pin the buffers and report the destination strides, which
    // avoids both the deprecated `dataPointer` and any assumption of a dense layout.
    keyTensor.withUnsafeMutableBytes { keyTensorPointer, keyTargetStrides in
        keySlice.withUnsafeBytes { keySlicePointer in
            valueTensor.withUnsafeMutableBytes { valueTensorPointer, valueTargetStrides in
                valueSlice.withUnsafeBytes { valueSlicePointer in
                    // An empty buffer has no base address; there is nothing to copy then.
                    guard let keyDestBase = keyTensorPointer.baseAddress,
                          let keySourceBase = keySlicePointer.baseAddress,
                          let valueDestBase = valueTensorPointer.baseAddress,
                          let valueSourceBase = valueSlicePointer.baseAddress
                    else { return }

                    // Each iteration writes only to offsets derived from its own `j`, so the
                    // workers touch disjoint memory and need no further synchronization.
                    DispatchQueue.concurrentPerform(iterations: hiddenDim) { j in
                        for k in 0..<sliceLength {
                            let targetSeqPos = index + k
                            // Skip columns that fall outside the cache instead of writing
                            // out of bounds.
                            guard targetSeqPos >= 0, targetSeqPos < seqLength else { continue }

                            // Strides are in elements; scale by the sample size for bytes.
                            let keyDestOffset = j * keyTargetStrides[1] + targetSeqPos * keyTargetStrides[3]
                            let keySourceOffset = j * keySliceStrides[1] + k * keySliceStrides[3]
                            memcpy(keyDestBase + keyDestOffset * bytesPerSample,
                                   keySourceBase + keySourceOffset * bytesPerSample,
                                   bytesPerSample)

                            let valueDestOffset = j * valueTargetStrides[1] + targetSeqPos * valueTargetStrides[3]
                            let valueSourceOffset = j * valueSliceStrides[1] + k * valueSliceStrides[3]
                            memcpy(valueDestBase + valueDestOffset * bytesPerSample,
                                   valueSourceBase + valueSourceOffset * bytesPerSample,
                                   bytesPerSample)
                        }
                    }
                }
            }
        }
    }
}

static func updateAlignmentWeights(
alignmentTensor: MLMultiArray,
Expand Down
Loading