From e6bdadbb717b3b9d908bd9a300af18ab060614ae Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Sat, 31 Aug 2024 12:29:05 -0400 Subject: [PATCH 1/6] Add Phi 3.5 MoE --- Libraries/LLM/Configuration.swift | 5 + Libraries/LLM/Models.swift | 9 + Libraries/LLM/Phi3.swift | 35 +- Libraries/LLM/PhiMoE.swift | 442 +++++++++++++++++++ Libraries/LLM/SuScaledRotaryEmbedding.swift | 14 +- mlx-swift-examples.xcodeproj/project.pbxproj | 4 + 6 files changed, 489 insertions(+), 20 deletions(-) create mode 100644 Libraries/LLM/PhiMoE.swift diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift index f432af4..0fbdb91 100644 --- a/Libraries/LLM/Configuration.swift +++ b/Libraries/LLM/Configuration.swift @@ -53,6 +53,11 @@ private class ModelTypeRegistry: @unchecked Sendable { Phi3Configuration.self, from: Data(contentsOf: url)) return Phi3Model(configuration) }, + "phimoe": { url in + let configuration = try JSONDecoder().decode( + PhiMoEConfiguration.self, from: Data(contentsOf: url)) + return PhiMoEModel(configuration) + }, "gemma": { url in let configuration = try JSONDecoder().decode( GemmaConfiguration.self, from: Data(contentsOf: url)) diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift index bb5e8c3..41fb058 100644 --- a/Libraries/LLM/Models.swift +++ b/Libraries/LLM/Models.swift @@ -163,6 +163,15 @@ extension ModelConfiguration { "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" } + public static let phi3_5MoE = ModelConfiguration( + id: "mlx-community/Phi-3.5-MoE-instruct-4bit", + defaultPrompt: "What is the gravity on Mars and the moon?", + extraEOSTokens: ["<|end|>"] + ) { + prompt in + "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" + } + public static let gemma2bQuantized = ModelConfiguration( id: "mlx-community/quantized-gemma-2b-it", overrideTokenizer: "PreTrainedTokenizer", diff --git a/Libraries/LLM/Phi3.swift b/Libraries/LLM/Phi3.swift index d9f8ada..c7709ed 100644 --- a/Libraries/LLM/Phi3.swift +++ b/Libraries/LLM/Phi3.swift @@ -207,21 +207,25 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Phi3Configuration: Codable, Sendable { - struct RopeScaling: Codable { - let longFactor: [Float]? - let shortFactor: [Float]? - let factor: Float? - let type: String? - - enum CodingKeys: String, CodingKey { - case type - case factor - case longFactor = "long_factor" - case shortFactor = "short_factor" - } +struct RopeScalingWithFactorArrays: Codable { + let longFactor: [Float]? + let shortFactor: [Float]? + let factor: Float? + let type: String? + let longMScale: Float? + let shortMScale: Float? + + enum CodingKeys: String, CodingKey { + case type + case factor + case longFactor = "long_factor" + case shortFactor = "short_factor" + case longMScale = "long_mscale" + case shortMScale = "short_mscale" } +} +public struct Phi3Configuration: Codable, Sendable { var hiddenSize: Int var hiddenLayers: Int var intermediateSize: Int @@ -231,7 +235,7 @@ public struct Phi3Configuration: Codable, Sendable { var kvHeads: Int var ropeTheta: Float = 10_000 var ropeTraditional: Bool = false - var ropeScaling: RopeScaling? + var ropeScaling: RopeScalingWithFactorArrays? var maxPositionEmbeddings: Int var originalMaxPositionEmbeddings: Int @@ -273,7 +277,8 @@ public struct Phi3Configuration: Codable, Sendable { ropeTraditional = try container.decodeIfPresent( Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? 
false - ropeScaling = try container.decodeIfPresent(RopeScaling.self, forKey: .ropeScaling) + ropeScaling = try container.decodeIfPresent( + RopeScalingWithFactorArrays.self, forKey: .ropeScaling) maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) originalMaxPositionEmbeddings = try container.decode( Int.self, forKey: .originalMaxPositionEmbeddings) diff --git a/Libraries/LLM/PhiMoE.swift b/Libraries/LLM/PhiMoE.swift new file mode 100644 index 0000000..f84c2c9 --- /dev/null +++ b/Libraries/LLM/PhiMoE.swift @@ -0,0 +1,442 @@ +import Foundation +import MLX +import MLXFast +import MLXNN +import MLXRandom + +// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phimoe.py + +public struct PhiMoEConfiguration: Codable, Sendable { + var modelType: String = "phimoe" + var vocabularySize: Int = 32064 + var hiddenSize: Int = 4096 + var intermediateSize: Int = 6400 + var hiddenLayers: Int = 32 + var attentionHeads: Int = 32 + var kvHeads: Int = 8 + var maxPositionEmbeddings: Int = 131072 + var originalMaxPositionEmbeddings: Int = 4096 + var rmsNormEps: Float = 1e-6 + var ropeScaling: RopeScalingWithFactorArrays? + var numLocalExperts: Int = 16 + var numExpertsPerToken: Int = 2 + var ropeTheta: Float = 10000.0 + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case vocabularySize = "vocab_size" + case hiddenSize = "hidden_size" + case intermediateSize = "intermediate_size" + case hiddenLayers = "num_hidden_layers" + case attentionHeads = "num_attention_heads" + case kvHeads = "num_key_value_heads" + case maxPositionEmbeddings = "max_position_embeddings" + case originalMaxPositionEmbeddings = "original_max_position_embeddings" + case rmsNormEps = "rms_norm_eps" + case ropeScaling = "rope_scaling" + case numLocalExperts = "num_local_experts" + case numExpertsPerToken = "num_experts_per_tok" + case ropeTheta = "rope_theta" + } +} + +private class Attention: Module { + let args: PhiMoEConfiguration + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: SuScaledRotaryEmbedding + + init(_ args: PhiMoEConfiguration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + let headDim = args.hiddenSize / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: true) + + self.rope = SuScaledRotaryEmbedding( + dimensions: headDim, + base: args.ropeTheta, + maxPositionEmbeddings: args.maxPositionEmbeddings, + originalMaxPositionEmbeddings: args.originalMaxPositionEmbeddings, + longFactor: args.ropeScaling?.longFactor as? [Float] ?? [1.0], + longMScale: args.ropeScaling?.longMScale as? Float + ) + } + + func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) 
-> MLXArray {
+        let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2))
+
+        let queries = wq(x)
+        let keys = wk(x)
+        let values = wv(x)
+
+        // Prepare the queries, keys and values for the attention computation
+        var q = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
+        var k = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        var v = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+
+        if let cache {
+            q = rope(q, offset: cache.offset)
+            k = rope(k, offset: cache.offset)
+            (k, v) = cache.update(keys: k, values: v)
+        } else {
+            q = rope(q)
+            k = rope(k)
+        }
+
+        let output = MLXFast.scaledDotProductAttention(
+            queries: q, keys: k, values: v, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
+
+        return wo(output)
+    }
+}
+
+private class PhiMoESparseMoeBlock: Module {
+    let hiddenDim: Int
+    let ffnDim: Int
+    let numExperts: Int
+    let topK: Int
+
+    @ModuleInfo(key: "gate") var gate: Linear
+    @ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU
+
+    init(_ args: PhiMoEConfiguration) {
+        self.hiddenDim = args.hiddenSize
+        self.ffnDim = args.intermediateSize
+        self.numExperts = args.numLocalExperts
+        self.topK = args.numExpertsPerToken
+
+        self._gate.wrappedValue = Linear(hiddenDim, numExperts, bias: false)
+        self._switchMLP.wrappedValue = SwitchGLU(
+            inputDims: hiddenDim, hiddenDims: ffnDim, numExperts: numExperts)
+    }
+
+    func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let gates = gate(x)
+
+        let k = self.topK
+        let inds = MLX.stopGradient(
+            MLX.argPartition(
+                -gates,
+                kth: k - 1
+                // !! Here the Python has an extra argument that is not available in Swift: axis=-1
+            )[.ellipsis, ..<k]
+        )
+        let scores = MLX.softmax(
+            MLX.takeAlong(gates, inds, axis: -1),
+            axis: -1
+            // !! Here the Python has an extra argument that is not available in Swift: precise=True
+        )
+
+        let y = switchMLP(x, inds)
+        return (y * scores[.ellipsis, .newAxis]).sum(axis: -2)
+    }
+}
+
+private class PhiMoEDecoderLayer: Module {
+    let hiddenSize: Int
+
+    @ModuleInfo(key: "self_attn") var selfAttn: Attention
+    @ModuleInfo(key: "block_sparse_moe") var blockSparseMoe: PhiMoESparseMoeBlock
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: LayerNorm
+
+    init(_ args: PhiMoEConfiguration) {
+        self.hiddenSize = args.hiddenSize
+
+        self._selfAttn.wrappedValue = Attention(args)
+        self._blockSparseMoe.wrappedValue = PhiMoESparseMoeBlock(args)
+        self._inputLayerNorm.wrappedValue = LayerNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postAttentionLayerNorm.wrappedValue = LayerNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
+        var residual = x
+        var hiddenStates = inputLayerNorm(x)
+        hiddenStates = selfAttn(hiddenStates, mask: mask, cache: cache)
+        hiddenStates = residual + hiddenStates
+
+        residual = hiddenStates
+        hiddenStates = postAttentionLayerNorm(hiddenStates)
+        hiddenStates = blockSparseMoe(hiddenStates)
+        hiddenStates = residual + hiddenStates
+
+        return hiddenStates
+    }
+}
+
+private class PhiMoEModelInner: Module {
+    let args: PhiMoEConfiguration
+
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+    let layers: [PhiMoEDecoderLayer]
+    @ModuleInfo(key: "norm") var norm: LayerNorm
+
+    init(_ args: PhiMoEConfiguration) {
+        self.args = args
+
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+        self.layers = (0 ..< args.hiddenLayers).map { _ in PhiMoEDecoderLayer(args) }
+        self._norm.wrappedValue = LayerNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) 
-> MLXArray { + var h = embedTokens(inputs) + + let mask = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } +} + +public class PhiMoEModel: Module, LLMModel, KVCacheDimensionProvider { + public let vocabularySize: Int + public let kvHeads: [Int] + public let headDim: IntOrPair + + fileprivate let model: PhiMoEModelInner + @ModuleInfo(key: "lm_head") var lmHead: Linear + + public init(_ args: PhiMoEConfiguration) { + self.vocabularySize = args.vocabularySize + self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) + self.headDim = .init(args.hiddenSize / args.attentionHeads) + self.model = PhiMoEModelInner(args) + self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + let out = model(inputs, cache: cache) + return lmHead(out) + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var sanitizedWeights = weights + if sanitizedWeights["model.layers.0.block_sparse_moe.experts.0.w1.weight"] == nil { + return sanitizedWeights + } + + for l in 0 ..< model.args.hiddenLayers { + let prefix = "model.layers.\(l)" + for (n, m) in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")] { + for k in ["weight", "scales", "biases"] { + if sanitizedWeights["\(prefix).block_sparse_moe.experts.0.\(n).\(k)"] != nil { + let toJoin = (0 ..< model.args.numLocalExperts).map { e in + sanitizedWeights.removeValue( + forKey: "\(prefix).block_sparse_moe.experts.\(e).\(n).\(k)")! + } + sanitizedWeights["\(prefix).block_sparse_moe.switch_mlp.\(m).\(k)"] = + MLX.stacked(toJoin) + } + } + } + } + + return sanitizedWeights + } +} + +// MARK: - LoRA + +extension PhiMoEModel: LoRAModel { + public func loraLinearLayers() -> LoRALinearLayers { + model.layers.map { ($0.selfAttn, ["q_proj", "v_proj"]) } + } +} + +// MARK: - SwitchGLU + +class SwitchGLU: Module { + @ModuleInfo(key: "gate_proj") var gateProj: SwitchLinear + @ModuleInfo(key: "up_proj") var upProj: SwitchLinear + @ModuleInfo(key: "down_proj") var downProj: SwitchLinear + + let inputDims: Int + let hiddenDims: Int + let numExperts: Int + let activation: (MLXArray) -> MLXArray + + init( + inputDims: Int, + hiddenDims: Int, + numExperts: Int, + activation: @escaping (MLXArray) -> MLXArray = MLXNN.silu, + bias: Bool = false + ) { + self.inputDims = inputDims + self.hiddenDims = hiddenDims + self.numExperts = numExperts + self.activation = activation + + self._gateProj.wrappedValue = SwitchLinear( + inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias) + self._upProj.wrappedValue = SwitchLinear( + inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias) + self._downProj.wrappedValue = SwitchLinear( + inputDims: hiddenDims, outputDims: inputDims, numExperts: numExperts, bias: bias) + + super.init() + } + + func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + let x = MLX.expandedDimensions(x, axes: [-2, -3]) + + let xUp = upProj(x, indices) + let xGate = gateProj(x, indices) + let xDown = downProj(activation(xGate) * xUp, indices) + + return MLX.squeezed(xDown, axis: -2) + } +} + +class SwitchLinear: Module { + @ModuleInfo(key: "weight") var weight: MLXArray + @ModuleInfo(key: "bias") var bias: MLXArray? 
+ + let inputDims: Int + let outputDims: Int + let numExperts: Int + + init(inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true) { + self.inputDims = inputDims + self.outputDims = outputDims + self.numExperts = numExperts + + let scale = sqrt(1.0 / Float(inputDims)) + self._weight.wrappedValue = MLXRandom.uniform( + low: -scale, + high: scale, + [numExperts, outputDims, inputDims] + ) + + if bias { + self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims]) + } + + super.init() + } + + func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + let weightT = self.weight.swappedAxes(-1, -2) + var result = MLX.gatherMatmul(x, weightT, rhsIndices: indices) + + if let bias = self.bias { + result = result + MLX.expandedDimensions(MLX.take(bias, indices), axis: -2) + } + + return result + } + + func toQuantized(groupSize: Int = 64, bits: Int = 4) -> QuantizedSwitchLinear { + let (numExperts, outputDims, inputDims) = ( + self.weight.shape[0], self.weight.shape[1], self.weight.shape[2] + ) + let ql = QuantizedSwitchLinear( + inputDims: inputDims, + outputDims: outputDims, + numExperts: numExperts, + bias: self.bias != nil, + groupSize: groupSize, + bits: bits + ) + + let (quantizedWeight, scales, biases) = MLX.quantized( + self.weight, groupSize: groupSize, bits: bits) + ql.weight = quantizedWeight + ql.scales = scales + ql.biases = biases + + if let bias = self.bias { + ql.bias = bias + } + + return ql + } +} + +class QuantizedSwitchLinear: Module { + @ModuleInfo(key: "weight") var weight: MLXArray + @ModuleInfo(key: "scales") var scales: MLXArray + @ModuleInfo(key: "biases") var biases: MLXArray + @ModuleInfo(key: "bias") var bias: MLXArray? + + let groupSize: Int + let bits: Int + + init( + inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true, groupSize: Int = 64, + bits: Int = 4 + ) { + self.groupSize = groupSize + self.bits = bits + + let scale = sqrt(1.0 / Float(inputDims)) + let uniformWeight = MLXRandom.uniform( + low: -scale, + high: scale, + [numExperts, outputDims, inputDims] + ) + + let (quantizedWeight, scales, biases) = MLX.quantized( + uniformWeight, groupSize: groupSize, bits: bits) + self._weight.wrappedValue = quantizedWeight + self._scales.wrappedValue = scales + self._biases.wrappedValue = biases + + if bias { + self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims]) + } + + super.init() + self.freeze() + } + + var inputDims: Int { + return scales.shape[2] * groupSize + } + + var outputDims: Int { + return weight.shape[1] + } + + var numExperts: Int { + return weight.shape[0] + } + + func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + var result = MLX.gatherQuantizedMatmul( + x, + self.weight, + scales: self.scales, + biases: self.biases, + rhsIndices: indices, + transpose: true, + groupSize: self.groupSize, + bits: self.bits + ) + + if let bias = self.bias { + result = result + MLX.expandedDimensions(MLX.take(bias, indices), axis: -2) + } + + return result + } +} diff --git a/Libraries/LLM/SuScaledRotaryEmbedding.swift b/Libraries/LLM/SuScaledRotaryEmbedding.swift index 16fb970..3dec287 100644 --- a/Libraries/LLM/SuScaledRotaryEmbedding.swift +++ b/Libraries/LLM/SuScaledRotaryEmbedding.swift @@ -16,7 +16,9 @@ public class SuScaledRotaryEmbedding: Module { base: Float = 10000.0, maxPositionEmbeddings: Int = 131072, originalMaxPositionEmbeddings: Int = 4096, - longFactor: [Float] = [1.0] + longFactor: [Float] = [1.0], + // shortMScale: Float? = nil, + longMScale: Float? 
= nil
     ) {
         precondition(dimensions % 2 == 0, "Dimensions must be even")
 
@@ -30,10 +32,12 @@ public class SuScaledRotaryEmbedding: Module {
         let freqs = MLX.pow(MLXArray(base), exponent)
         self._freqs = MLXArray(longFactor).asType(.float32) * freqs
 
-        self.scale = sqrt(
-            1 + log(Float(maxPositionEmbeddings) / Float(originalMaxPositionEmbeddings))
-                / log(Float(originalMaxPositionEmbeddings))
-        )
+        self.scale =
+            longMScale
+            ?? sqrt(
+                1 + log(Float(maxPositionEmbeddings) / Float(originalMaxPositionEmbeddings))
+                    / log(Float(originalMaxPositionEmbeddings))
+            )
     }
 
     public func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 7e5f974..6152a34 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -15,6 +15,7 @@
 		7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */; };
 		81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
 		819BEFF82BAF8B4E0002CCEE /* DeviceStat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */; };
+		927B80422C83769800500C13 /* PhiMoE.swift in Sources */ = {isa = PBXBuildFile; fileRef = 927B80412C83769400500C13 /* PhiMoE.swift */; };
 		927C784E2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift in Sources */ = {isa = PBXBuildFile; fileRef = 927C784D2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift */; };
 		C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */; };
 		C3056BB02BCD97B700A31D04 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAF2BCD97B700A31D04 /* ContentView.swift */; };
@@ -296,6 +297,7 @@
 		52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
 		7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpenELM.swift; sourceTree = "<group>"; };
 		819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
+		927B80412C83769400500C13 /* PhiMoE.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PhiMoE.swift; sourceTree = "<group>"; };
 		927C784D2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuScaledRotaryEmbedding.swift; sourceTree = "<group>"; };
 		C3056BA12BCD973400A31D04 /* test.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test.jsonl; sourceTree = "<group>"; };
 		C3056BA22BCD973400A31D04 /* train.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = train.jsonl; sourceTree = "<group>"; };
@@ -658,6 +660,7 @@
 				C38935E02B869F420037B833 /* LLMModel.swift */,
 				C38935DE2B869DD00037B833 /* Phi.swift */,
 				1CD79C6F2BD80DE100B6C06F /* Phi3.swift */,
+				927B80412C83769400500C13 /* PhiMoE.swift */,
 				C34E48F62B69832600FCB841 /* README.md */,
 				C34E48ED2B696E6500FCB841 /* Load.swift */,
 				C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
@@ -1336,6 +1339,7 @@
 				525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
 				C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
 				C36BEFB02BBCBAC2002D4AFE /* 
Lora.swift in Sources */, + 927B80422C83769800500C13 /* PhiMoE.swift in Sources */, 7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */, C38935DF2B869DD00037B833 /* Phi.swift in Sources */, C38935CE2B869C870037B833 /* Load.swift in Sources */, From f692c71a1232a08f9c1da4849d1f73f847017827 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 4 Sep 2024 15:46:41 -0400 Subject: [PATCH 2/6] Address feedback --- Libraries/LLM/PhiMoE.swift | 73 ++++++++------------------------------ 1 file changed, 15 insertions(+), 58 deletions(-) diff --git a/Libraries/LLM/PhiMoE.swift b/Libraries/LLM/PhiMoE.swift index f84c2c9..64f79ee 100644 --- a/Libraries/LLM/PhiMoE.swift +++ b/Libraries/LLM/PhiMoE.swift @@ -307,7 +307,7 @@ class SwitchGLU: Module { } } -class SwitchLinear: Module { +class SwitchLinear: Module, Quantizable { @ModuleInfo(key: "weight") var weight: MLXArray @ModuleInfo(key: "bias") var bias: MLXArray? @@ -345,83 +345,40 @@ class SwitchLinear: Module { return result } - func toQuantized(groupSize: Int = 64, bits: Int = 4) -> QuantizedSwitchLinear { - let (numExperts, outputDims, inputDims) = ( - self.weight.shape[0], self.weight.shape[1], self.weight.shape[2] - ) - let ql = QuantizedSwitchLinear( - inputDims: inputDims, - outputDims: outputDims, - numExperts: numExperts, - bias: self.bias != nil, - groupSize: groupSize, - bits: bits - ) - - let (quantizedWeight, scales, biases) = MLX.quantized( - self.weight, groupSize: groupSize, bits: bits) - ql.weight = quantizedWeight - ql.scales = scales - ql.biases = biases - - if let bias = self.bias { - ql.bias = bias - } - - return ql + func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module { + QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits) } } -class QuantizedSwitchLinear: Module { - @ModuleInfo(key: "weight") var weight: MLXArray +class QuantizedSwitchLinear: SwitchLinear { @ModuleInfo(key: "scales") var scales: MLXArray @ModuleInfo(key: "biases") var biases: MLXArray - @ModuleInfo(key: "bias") var bias: MLXArray? 
let groupSize: Int let bits: Int - init( - inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true, groupSize: Int = 64, - bits: Int = 4 - ) { + init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) { self.groupSize = groupSize self.bits = bits - let scale = sqrt(1.0 / Float(inputDims)) - let uniformWeight = MLXRandom.uniform( - low: -scale, - high: scale, - [numExperts, outputDims, inputDims] - ) + super.init( + inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts, + bias: other.bias != nil) let (quantizedWeight, scales, biases) = MLX.quantized( - uniformWeight, groupSize: groupSize, bits: bits) - self._weight.wrappedValue = quantizedWeight - self._scales.wrappedValue = scales - self._biases.wrappedValue = biases + other.weight, groupSize: groupSize, bits: bits) + self.weight = quantizedWeight + self.scales = scales + self.biases = biases - if bias { - self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims]) + if let bias = other.bias { + self.bias = bias } - super.init() self.freeze() } - var inputDims: Int { - return scales.shape[2] * groupSize - } - - var outputDims: Int { - return weight.shape[1] - } - - var numExperts: Int { - return weight.shape[0] - } - - func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + override func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { var result = MLX.gatherQuantizedMatmul( x, self.weight, From 732a4859fbe64b2bffe735bd84e5d49cae603a76 Mon Sep 17 00:00:00 2001 From: David Koski Date: Tue, 22 Oct 2024 09:21:59 -0700 Subject: [PATCH 3/6] make sure all models are registered. fix prompt generation --- Libraries/LLM/Models.swift | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift index 41fb058..4fb9dd9 100644 --- a/Libraries/LLM/Models.swift +++ b/Libraries/LLM/Models.swift @@ -169,7 +169,7 @@ extension ModelConfiguration { extraEOSTokens: ["<|end|>"] ) { prompt in - "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" + "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" } public static let gemma2bQuantized = ModelConfiguration( @@ -269,19 +269,22 @@ extension ModelConfiguration { case .idle: bootstrapState = .bootstrapping register(configurations: [ + codeLlama13b4bit, + gemma2bQuantized, + gemma_2_2b_it_4bit, + gemma_2_9b_it_4bit, llama3_1_8B_4bit, llama3_2_1B_4bit, llama3_2_3B_4bit, - mistralNeMo4bit, - smolLM_135M_4bit, + llama3_8B_4bit, mistral7B4bit, - codeLlama13b4bit, - phi4bit, + mistralNeMo4bit, + openelm270m4bit, + phi3_5MoE, phi3_5_4bit, - gemma2bQuantized, - gemma_2_9b_it_4bit, + phi4bit, qwen205b4bit, - openelm270m4bit, + smolLM_135M_4bit, ]) bootstrapState = .bootstrapped From de4ee6ebe69bfb2784a150d2a2c0ccd0e49ae7ea Mon Sep 17 00:00:00 2001 From: David Koski Date: Tue, 22 Oct 2024 09:22:27 -0700 Subject: [PATCH 4/6] make SwitchLinear match how Linear work re: quantization --- Libraries/LLM/PhiMoE.swift | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/Libraries/LLM/PhiMoE.swift b/Libraries/LLM/PhiMoE.swift index 64f79ee..bb92d6a 100644 --- a/Libraries/LLM/PhiMoE.swift +++ b/Libraries/LLM/PhiMoE.swift @@ -334,6 +334,22 @@ class SwitchLinear: Module, Quantizable { super.init() } + /// Initializer meant for subclasses to provide weight and bias arrays directly. + /// + /// This is used e.g. by ``QuantizedSwitchLinear`` to provide quantized weights and biases + /// rather than have ``SwitchLinear`` compute them. 
+ init( + inputDims: Int, outputDims: Int, numExperts: Int, + weight: MLXArray, bias: MLXArray? = nil + ) { + self.inputDims = inputDims + self.outputDims = outputDims + self.numExperts = numExperts + + self._weight.wrappedValue = weight + self._bias.wrappedValue = bias + } + func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { let weightT = self.weight.swappedAxes(-1, -2) var result = MLX.gatherMatmul(x, weightT, rhsIndices: indices) @@ -361,19 +377,15 @@ class QuantizedSwitchLinear: SwitchLinear { self.groupSize = groupSize self.bits = bits - super.init( - inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts, - bias: other.bias != nil) - let (quantizedWeight, scales, biases) = MLX.quantized( other.weight, groupSize: groupSize, bits: bits) - self.weight = quantizedWeight - self.scales = scales - self.biases = biases - if let bias = other.bias { - self.bias = bias - } + self._scales.wrappedValue = scales + self._biases.wrappedValue = biases + + super.init( + inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts, + weight: quantizedWeight, bias: other.bias) self.freeze() } From 9e785c6a487c01c2de44090429e0b2a23f84e237 Mon Sep 17 00:00:00 2001 From: David Koski Date: Tue, 22 Oct 2024 10:22:46 -0700 Subject: [PATCH 5/6] split switch layers into its own file (to better match the python version) --- Libraries/LLM/PhiMoE.swift | 148 ------------------ Libraries/LLM/SwitchLayers.swift | 155 +++++++++++++++++++ mlx-swift-examples.xcodeproj/project.pbxproj | 4 + 3 files changed, 159 insertions(+), 148 deletions(-) create mode 100644 Libraries/LLM/SwitchLayers.swift diff --git a/Libraries/LLM/PhiMoE.swift b/Libraries/LLM/PhiMoE.swift index bb92d6a..fbf40b3 100644 --- a/Libraries/LLM/PhiMoE.swift +++ b/Libraries/LLM/PhiMoE.swift @@ -261,151 +261,3 @@ extension PhiMoEModel: LoRAModel { model.layers.map { ($0.selfAttn, ["q_proj", "v_proj"]) } } } - -// MARK: - SwitchGLU - -class SwitchGLU: Module { - @ModuleInfo(key: "gate_proj") var gateProj: SwitchLinear - @ModuleInfo(key: "up_proj") var upProj: SwitchLinear - @ModuleInfo(key: "down_proj") var downProj: SwitchLinear - - let inputDims: Int - let hiddenDims: Int - let numExperts: Int - let activation: (MLXArray) -> MLXArray - - init( - inputDims: Int, - hiddenDims: Int, - numExperts: Int, - activation: @escaping (MLXArray) -> MLXArray = MLXNN.silu, - bias: Bool = false - ) { - self.inputDims = inputDims - self.hiddenDims = hiddenDims - self.numExperts = numExperts - self.activation = activation - - self._gateProj.wrappedValue = SwitchLinear( - inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias) - self._upProj.wrappedValue = SwitchLinear( - inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias) - self._downProj.wrappedValue = SwitchLinear( - inputDims: hiddenDims, outputDims: inputDims, numExperts: numExperts, bias: bias) - - super.init() - } - - func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { - let x = MLX.expandedDimensions(x, axes: [-2, -3]) - - let xUp = upProj(x, indices) - let xGate = gateProj(x, indices) - let xDown = downProj(activation(xGate) * xUp, indices) - - return MLX.squeezed(xDown, axis: -2) - } -} - -class SwitchLinear: Module, Quantizable { - @ModuleInfo(key: "weight") var weight: MLXArray - @ModuleInfo(key: "bias") var bias: MLXArray? 
- - let inputDims: Int - let outputDims: Int - let numExperts: Int - - init(inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true) { - self.inputDims = inputDims - self.outputDims = outputDims - self.numExperts = numExperts - - let scale = sqrt(1.0 / Float(inputDims)) - self._weight.wrappedValue = MLXRandom.uniform( - low: -scale, - high: scale, - [numExperts, outputDims, inputDims] - ) - - if bias { - self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims]) - } - - super.init() - } - - /// Initializer meant for subclasses to provide weight and bias arrays directly. - /// - /// This is used e.g. by ``QuantizedSwitchLinear`` to provide quantized weights and biases - /// rather than have ``SwitchLinear`` compute them. - init( - inputDims: Int, outputDims: Int, numExperts: Int, - weight: MLXArray, bias: MLXArray? = nil - ) { - self.inputDims = inputDims - self.outputDims = outputDims - self.numExperts = numExperts - - self._weight.wrappedValue = weight - self._bias.wrappedValue = bias - } - - func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { - let weightT = self.weight.swappedAxes(-1, -2) - var result = MLX.gatherMatmul(x, weightT, rhsIndices: indices) - - if let bias = self.bias { - result = result + MLX.expandedDimensions(MLX.take(bias, indices), axis: -2) - } - - return result - } - - func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module { - QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits) - } -} - -class QuantizedSwitchLinear: SwitchLinear { - @ModuleInfo(key: "scales") var scales: MLXArray - @ModuleInfo(key: "biases") var biases: MLXArray - - let groupSize: Int - let bits: Int - - init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) { - self.groupSize = groupSize - self.bits = bits - - let (quantizedWeight, scales, biases) = MLX.quantized( - other.weight, groupSize: groupSize, bits: bits) - - self._scales.wrappedValue = scales - self._biases.wrappedValue = biases - - super.init( - inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts, - weight: quantizedWeight, bias: other.bias) - - self.freeze() - } - - override func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { - var result = MLX.gatherQuantizedMatmul( - x, - self.weight, - scales: self.scales, - biases: self.biases, - rhsIndices: indices, - transpose: true, - groupSize: self.groupSize, - bits: self.bits - ) - - if let bias = self.bias { - result = result + MLX.expandedDimensions(MLX.take(bias, indices), axis: -2) - } - - return result - } -} diff --git a/Libraries/LLM/SwitchLayers.swift b/Libraries/LLM/SwitchLayers.swift new file mode 100644 index 0000000..b7e62d5 --- /dev/null +++ b/Libraries/LLM/SwitchLayers.swift @@ -0,0 +1,155 @@ +import Foundation +import MLX +import MLXFast +import MLXNN +import MLXRandom + +// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/switch_layers.py + +// MARK: - SwitchGLU + +class SwitchGLU: Module { + @ModuleInfo(key: "gate_proj") var gateProj: SwitchLinear + @ModuleInfo(key: "up_proj") var upProj: SwitchLinear + @ModuleInfo(key: "down_proj") var downProj: SwitchLinear + + let inputDims: Int + let hiddenDims: Int + let numExperts: Int + let activation: (MLXArray) -> MLXArray + + init( + inputDims: Int, + hiddenDims: Int, + numExperts: Int, + activation: @escaping (MLXArray) -> MLXArray = MLXNN.silu, + bias: Bool = false + ) { + self.inputDims = inputDims + self.hiddenDims = hiddenDims + self.numExperts = numExperts + self.activation = 
activation + + self._gateProj.wrappedValue = SwitchLinear( + inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias) + self._upProj.wrappedValue = SwitchLinear( + inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias) + self._downProj.wrappedValue = SwitchLinear( + inputDims: hiddenDims, outputDims: inputDims, numExperts: numExperts, bias: bias) + + super.init() + } + + func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + let x = MLX.expandedDimensions(x, axes: [-2, -3]) + + let xUp = upProj(x, indices) + let xGate = gateProj(x, indices) + let xDown = downProj(activation(xGate) * xUp, indices) + + return MLX.squeezed(xDown, axis: -2) + } +} + +class SwitchLinear: Module, Quantizable { + @ModuleInfo(key: "weight") var weight: MLXArray + @ModuleInfo(key: "bias") var bias: MLXArray? + + let inputDims: Int + let outputDims: Int + let numExperts: Int + + init(inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true) { + self.inputDims = inputDims + self.outputDims = outputDims + self.numExperts = numExperts + + let scale = sqrt(1.0 / Float(inputDims)) + self._weight.wrappedValue = MLXRandom.uniform( + low: -scale, + high: scale, + [numExperts, outputDims, inputDims] + ) + + if bias { + self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims]) + } + + super.init() + } + + /// Initializer meant for subclasses to provide weight and bias arrays directly. + /// + /// This is used e.g. by ``QuantizedSwitchLinear`` to provide quantized weights and biases + /// rather than have ``SwitchLinear`` compute them. + init( + inputDims: Int, outputDims: Int, numExperts: Int, + weight: MLXArray, bias: MLXArray? = nil + ) { + self.inputDims = inputDims + self.outputDims = outputDims + self.numExperts = numExperts + + self._weight.wrappedValue = weight + self._bias.wrappedValue = bias + } + + func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + let weightT = self.weight.swappedAxes(-1, -2) + var result = MLX.gatherMatmul(x, weightT, rhsIndices: indices) + + if let bias = self.bias { + result = result + MLX.expandedDimensions(bias[indices], axis: -2) + } + + return result + } + + func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module { + QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits) + } +} + +class QuantizedSwitchLinear: SwitchLinear { + @ModuleInfo(key: "scales") var scales: MLXArray + @ModuleInfo(key: "biases") var biases: MLXArray + + let groupSize: Int + let bits: Int + + init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) { + self.groupSize = groupSize + self.bits = bits + + let (quantizedWeight, scales, biases) = MLX.quantized( + other.weight, groupSize: groupSize, bits: bits) + + self._scales.wrappedValue = scales + self._biases.wrappedValue = biases + + super.init( + inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts, + weight: quantizedWeight, bias: other.bias) + + self.freeze() + } + + override func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { + var result = MLX.gatherQuantizedMatmul( + x, + self.weight, + scales: self.scales, + biases: self.biases, + rhsIndices: indices, + transpose: true, + groupSize: self.groupSize, + bits: self.bits + ) + + if let bias = self.bias { + result = result + MLX.expandedDimensions(bias[indices], axis: -2) + } + + return result + } +} diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index 6152a34..568b145 100644 --- 
a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -28,6 +28,7 @@
 		C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
 		C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
 		C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
+		C343B2782CC8091B00334888 /* SwitchLayers.swift in Sources */ = {isa = PBXBuildFile; fileRef = C343B2772CC8091B00334888 /* SwitchLayers.swift */; };
 		C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
 		C34E49102B69A92900FCB841 /* MNIST.h in Headers */ = {isa = PBXBuildFile; fileRef = C34E490F2B69A92900FCB841 /* MNIST.h */; settings = {ATTRIBUTES = (Public, ); }; };
 		C34E49152B69C1E300FCB841 /* Files.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E49142B69C1E300FCB841 /* Files.swift */; };
@@ -314,6 +315,7 @@
 		C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
 		C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = "<group>"; };
 		C3288D842B6D94BD009FF608 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
+		C343B2772CC8091B00334888 /* SwitchLayers.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SwitchLayers.swift; sourceTree = "<group>"; };
 		C34E48ED2B696E6500FCB841 /* Load.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Load.swift; sourceTree = "<group>"; };
 		C34E48EE2B696E6500FCB841 /* Llama.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Llama.swift; sourceTree = "<group>"; };
 		C34E48EF2B696E6500FCB841 /* Configuration.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Configuration.swift; sourceTree = "<group>"; };
@@ -661,6 +663,7 @@
 				C38935DE2B869DD00037B833 /* Phi.swift */,
 				1CD79C6F2BD80DE100B6C06F /* Phi3.swift */,
 				927B80412C83769400500C13 /* PhiMoE.swift */,
+				C343B2772CC8091B00334888 /* SwitchLayers.swift */,
 				C34E48F62B69832600FCB841 /* README.md */,
 				C34E48ED2B696E6500FCB841 /* Load.swift */,
 				C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
@@ -1348,6 +1351,7 @@
 				1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */,
 				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
 				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
+				C343B2782CC8091B00334888 /* SwitchLayers.swift in Sources */,
 				927C784E2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift in Sources */,
 				52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
 			);
From 32cfffb6623a9011ecbc0a45702c2e18ab68cb58 Mon Sep 17 00:00:00 2001
From: David Koski
Date: Tue, 22 Oct 2024 10:23:39 -0700
Subject: [PATCH 6/6] test & fix commented lines -- was very close!
---
 Libraries/LLM/PhiMoE.swift | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/Libraries/LLM/PhiMoE.swift b/Libraries/LLM/PhiMoE.swift
index fbf40b3..e11542f 100644
--- a/Libraries/LLM/PhiMoE.swift
+++ b/Libraries/LLM/PhiMoE.swift
@@ -134,13 +134,13 @@ private class PhiMoESparseMoeBlock: Module {
         let inds = MLX.stopGradient(
             MLX.argPartition(
                 -gates,
-                kth: k - 1
-                // !! Here the Python has an extra argument that is not available in Swift: axis=-1
+                kth: k - 1,
+                axis: -1
             )[.ellipsis, ..<k]
         )
         let scores = MLX.softmax(
             MLX.takeAlong(gates, inds, axis: -1),
-            axis: -1
-            // !! Here the Python has an extra argument that is not available in Swift: precise=True
+            axis: -1,
+            precise: true
         )
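
Note: the top-k routing that PATCH 6 finishes can be exercised on its own. The following is a minimal, illustrative sketch, not part of the series: the shapes and values are invented, and it assumes only the MLX Swift calls already used in the patches above (argPartition(_:kth:axis:), takeAlong(_:_:axis:), softmax(_:axis:precise:), stopGradient(_:)).

    import MLX
    import MLXRandom

    // Hypothetical routing problem: 4 tokens over 8 experts, top-2 per token,
    // mirroring numLocalExperts / numExpertsPerToken in PhiMoEConfiguration.
    let gates = MLXRandom.normal([4, 8])
    let k = 2

    // Negating the gates makes argPartition place the k largest gate values
    // in the first k positions along the expert axis.
    let inds = MLX.stopGradient(
        MLX.argPartition(-gates, kth: k - 1, axis: -1)[.ellipsis, ..<k]
    )

    // Renormalize the selected gate logits into per-token expert weights.
    let scores = MLX.softmax(
        MLX.takeAlong(gates, inds, axis: -1),
        axis: -1,
        precise: true
    )

    print(inds.shape, scores.shape)  // [4, 2] [4, 2]

Each row of scores sums to 1, and inds is what SwitchGLU consumes: SwitchLinear and QuantizedSwitchLinear gather only the selected experts' weights via gatherMatmul / gatherQuantizedMatmul instead of evaluating all 16 experts per token.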