Add Phi 3.5 MoE (#116)
* Add Phi 3.5 MoE
* make sure all models are registered; fix prompt generation
* make SwitchLinear match how Linear works re: quantization
* split the switch layers into their own file (to better match the Python version)

---------

Co-authored-by: David Koski <[email protected]>
DePasqualeOrg and davidkoski authored Oct 22, 2024
1 parent 5a7a1a4 commit fad4ad0
Showing 7 changed files with 479 additions and 27 deletions.
5 changes: 5 additions & 0 deletions Libraries/LLM/Configuration.swift
@@ -53,6 +53,11 @@ private class ModelTypeRegistry: @unchecked Sendable {
Phi3Configuration.self, from: Data(contentsOf: url))
return Phi3Model(configuration)
},
"phimoe": { url in
let configuration = try JSONDecoder().decode(
PhiMoEConfiguration.self, from: Data(contentsOf: url))
return PhiMoEModel(configuration)
},
"gemma": { url in
let configuration = try JSONDecoder().decode(
GemmaConfiguration.self, from: Data(contentsOf: url))
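For reference, here is a minimal sketch of the decode path the new "phimoe" entry wires up, assuming the config.json for mlx-community/Phi-3.5-MoE-instruct-4bit has already been downloaded (the path below is a placeholder; the registry normally receives this URL from the model download step):

    import Foundation

    // Placeholder path; in practice the registry passes the URL of the downloaded config.json.
    let url = URL(fileURLWithPath: "/path/to/Phi-3.5-MoE-instruct-4bit/config.json")
    let configuration = try JSONDecoder().decode(
        PhiMoEConfiguration.self, from: Data(contentsOf: url))
    let model = PhiMoEModel(configuration)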
26 changes: 19 additions & 7 deletions Libraries/LLM/Models.swift
@@ -134,6 +134,15 @@ extension ModelConfiguration {
extraEOSTokens: ["<|end|>"]
)

public static let phi3_5MoE = ModelConfiguration(
id: "mlx-community/Phi-3.5-MoE-instruct-4bit",
defaultPrompt: "What is the gravity on Mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}
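The trailing closure above is the prompt-generation fix mentioned in the commit message: it wraps raw user input in Phi's chat template. A minimal sketch of the string it produces for the default prompt (this reproduces the template directly rather than calling into ModelConfiguration):

    let prompt = "What is the gravity on Mars and the moon?"
    let rendered = "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
    // rendered == "<|user|>\nWhat is the gravity on Mars and the moon?<|end|>\n<|assistant|>\n"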

public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer",
@@ -202,19 +211,22 @@
case .idle:
bootstrapState = .bootstrapping
register(configurations: [
codeLlama13b4bit,
gemma2bQuantized,
gemma_2_2b_it_4bit,
gemma_2_9b_it_4bit,
llama3_1_8B_4bit,
llama3_2_1B_4bit,
llama3_2_3B_4bit,
mistralNeMo4bit,
smolLM_135M_4bit,
llama3_8B_4bit,
mistral7B4bit,
codeLlama13b4bit,
phi4bit,
mistralNeMo4bit,
openelm270m4bit,
phi3_5MoE,
phi3_5_4bit,
gemma2bQuantized,
gemma_2_9b_it_4bit,
phi4bit,
qwen205b4bit,
openelm270m4bit,
smolLM_135M_4bit,
])
bootstrapState = .bootstrapped

35 changes: 20 additions & 15 deletions Libraries/LLM/Phi3.swift
@@ -207,21 +207,25 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
}
}

public struct Phi3Configuration: Codable, Sendable {
struct RopeScaling: Codable {
let longFactor: [Float]?
let shortFactor: [Float]?
let factor: Float?
let type: String?

enum CodingKeys: String, CodingKey {
case type
case factor
case longFactor = "long_factor"
case shortFactor = "short_factor"
}
struct RopeScalingWithFactorArrays: Codable {
let longFactor: [Float]?
let shortFactor: [Float]?
let factor: Float?
let type: String?
let longMScale: Float?
let shortMScale: Float?

enum CodingKeys: String, CodingKey {
case type
case factor
case longFactor = "long_factor"
case shortFactor = "short_factor"
case longMScale = "long_mscale"
case shortMScale = "short_mscale"
}
}
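Because every field of RopeScalingWithFactorArrays is optional, a partial rope_scaling entry decodes cleanly. A minimal sketch, assuming it runs inside the LLM module (the struct is internal) and using illustrative factor values:

    import Foundation

    let ropeJSON = """
        {"type": "longrope", "long_factor": [1.0, 1.2], "short_factor": [1.0, 1.0],
         "long_mscale": 1.24, "short_mscale": 1.24}
        """.data(using: .utf8)!
    let scaling = try JSONDecoder().decode(RopeScalingWithFactorArrays.self, from: ropeJSON)
    // scaling.factor is nil; the remaining fields carry the decoded values.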

public struct Phi3Configuration: Codable, Sendable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
@@ -231,7 +235,7 @@ public struct Phi3Configuration: Codable, Sendable {
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: RopeScaling?
var ropeScaling: RopeScalingWithFactorArrays?
var maxPositionEmbeddings: Int
var originalMaxPositionEmbeddings: Int

@@ -273,7 +277,8 @@ public struct Phi3Configuration: Codable, Sendable {
ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
ropeScaling = try container.decodeIfPresent(RopeScaling.self, forKey: .ropeScaling)
ropeScaling = try container.decodeIfPresent(
RopeScalingWithFactorArrays.self, forKey: .ropeScaling)
maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
originalMaxPositionEmbeddings = try container.decode(
Int.self, forKey: .originalMaxPositionEmbeddings)
263 changes: 263 additions & 0 deletions Libraries/LLM/PhiMoE.swift
@@ -0,0 +1,263 @@
import Foundation
import MLX
import MLXFast
import MLXNN
import MLXRandom

// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phimoe.py

public struct PhiMoEConfiguration: Codable, Sendable {
var modelType: String = "phimoe"
var vocabularySize: Int = 32064
var hiddenSize: Int = 4096
var intermediateSize: Int = 6400
var hiddenLayers: Int = 32
var attentionHeads: Int = 32
var kvHeads: Int = 8
var maxPositionEmbeddings: Int = 131072
var originalMaxPositionEmbeddings: Int = 4096
var rmsNormEps: Float = 1e-6
var ropeScaling: RopeScalingWithFactorArrays?
var numLocalExperts: Int = 16
var numExpertsPerToken: Int = 2
var ropeTheta: Float = 10000.0

enum CodingKeys: String, CodingKey {
case modelType = "model_type"
case vocabularySize = "vocab_size"
case hiddenSize = "hidden_size"
case intermediateSize = "intermediate_size"
case hiddenLayers = "num_hidden_layers"
case attentionHeads = "num_attention_heads"
case kvHeads = "num_key_value_heads"
case maxPositionEmbeddings = "max_position_embeddings"
case originalMaxPositionEmbeddings = "original_max_position_embeddings"
case rmsNormEps = "rms_norm_eps"
case ropeScaling = "rope_scaling"
case numLocalExperts = "num_local_experts"
case numExpertsPerToken = "num_experts_per_tok"
case ropeTheta = "rope_theta"
}
}

private class Attention: Module {
let args: PhiMoEConfiguration
let scale: Float

@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear

let rope: SuScaledRotaryEmbedding

init(_ args: PhiMoEConfiguration) {
self.args = args

let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads

let headDim = args.hiddenSize / heads
self.scale = pow(Float(headDim), -0.5)

self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: true)

self.rope = SuScaledRotaryEmbedding(
dimensions: headDim,
base: args.ropeTheta,
maxPositionEmbeddings: args.maxPositionEmbeddings,
originalMaxPositionEmbeddings: args.originalMaxPositionEmbeddings,
longFactor: args.ropeScaling?.longFactor ?? [1.0],
longMScale: args.ropeScaling?.longMScale
)
}

func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2))

let queries = wq(x)
let keys = wk(x)
let values = wv(x)

// Prepare the queries, keys and values for the attention computation
var q = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
var k = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
var v = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
q = rope(q, offset: cache.offset)
k = rope(k, offset: cache.offset)
(k, v) = cache.update(keys: k, values: v)
} else {
q = rope(q)
k = rope(k)
}

let output = MLXFast.scaledDotProductAttention(
queries: q, keys: k, values: v, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)

return wo(output)
}
}
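A quick check of the per-head arithmetic above, using PhiMoEConfiguration's defaults (hiddenSize 4096, 32 attention heads, 8 KV heads):

    let hiddenSize = 4096
    let heads = 32
    let headDim = hiddenSize / heads               // 128
    let scale = 1 / Float(headDim).squareRoot()    // 1/sqrt(128) ≈ 0.0884, same as pow(128, -0.5)
    // Queries reshape to [B, 32, L, 128]; keys and values to [B, 8, L, 128], and
    // scaledDotProductAttention shares each KV head across 4 query heads (grouped-query attention).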

private class PhiMoESparseMoeBlock: Module {
let hiddenDim: Int
let ffnDim: Int
let numExperts: Int
let topK: Int

@ModuleInfo(key: "gate") var gate: Linear
@ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU

init(_ args: PhiMoEConfiguration) {
self.hiddenDim = args.hiddenSize
self.ffnDim = args.intermediateSize
self.numExperts = args.numLocalExperts
self.topK = args.numExpertsPerToken

self._gate.wrappedValue = Linear(hiddenDim, numExperts, bias: false)
self._switchMLP.wrappedValue = SwitchGLU(
inputDims: hiddenDim, hiddenDims: ffnDim, numExperts: numExperts)
}

func callAsFunction(_ x: MLXArray) -> MLXArray {
let gates = gate(x)

let k = self.topK
let inds = MLX.stopGradient(
MLX.argPartition(
-gates,
kth: k - 1,
axis: -1
)[.ellipsis, ..<k])
let scores = MLX.softmax(MLX.takeAlong(gates, inds, axis: -1), axis: -1, precise: true)

let y = switchMLP(x, inds)
return (y * scores[.ellipsis, .newAxis]).sum(axis: -2)
}
}
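A toy illustration of the routing above, with four experts, k = 2, and made-up gate logits: argPartition over the negated logits picks the two largest logits (experts 1 and 3 here), and the softmax over those logits gives the mixing weights.

    import MLX

    let gates = MLXArray([0.1, 2.0, -1.0, 1.5] as [Float])
    let k = 2
    let inds = MLX.argPartition(-gates, kth: k - 1, axis: -1)[.ellipsis, ..<k]
    let scores = MLX.softmax(MLX.takeAlong(gates, inds, axis: -1), axis: -1, precise: true)
    // inds holds experts 1 and 3 (in either order); their weights are ≈ 0.62 and 0.38.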

private class PhiMoEDecoderLayer: Module {
let hiddenSize: Int

@ModuleInfo(key: "self_attn") var selfAttn: Attention
@ModuleInfo(key: "block_sparse_moe") var blockSparseMoe: PhiMoESparseMoeBlock
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: LayerNorm

init(_ args: PhiMoEConfiguration) {
self.hiddenSize = args.hiddenSize

self._selfAttn.wrappedValue = Attention(args)
self._blockSparseMoe.wrappedValue = PhiMoESparseMoeBlock(args)
self._inputLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
var residual = x
var hiddenStates = inputLayerNorm(x)
hiddenStates = selfAttn(hiddenStates, mask: mask, cache: cache)
hiddenStates = residual + hiddenStates

residual = hiddenStates
hiddenStates = postAttentionLayerNorm(hiddenStates)
hiddenStates = blockSparseMoe(hiddenStates)
hiddenStates = residual + hiddenStates

return hiddenStates
}
}

private class PhiMoEModelInner: Module {
let args: PhiMoEConfiguration

@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
let layers: [PhiMoEDecoderLayer]
@ModuleInfo(key: "norm") var norm: LayerNorm

init(_ args: PhiMoEConfiguration) {
self.args = args

self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers).map { _ in PhiMoEDecoderLayer(args) }
self._norm.wrappedValue = LayerNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
var h = embedTokens(inputs)

let mask = createAttentionMask(h: h, cache: cache)

for (i, layer) in layers.enumerated() {
h = layer(h, mask: mask, cache: cache?[i])
}

return norm(h)
}
}

public class PhiMoEModel: Module, LLMModel, KVCacheDimensionProvider {
public let vocabularySize: Int
public let kvHeads: [Int]
public let headDim: IntOrPair

fileprivate let model: PhiMoEModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear

public init(_ args: PhiMoEConfiguration) {
self.vocabularySize = args.vocabularySize
self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
self.headDim = .init(args.hiddenSize / args.attentionHeads)
self.model = PhiMoEModelInner(args)
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
}

public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
let out = model(inputs, cache: cache)
return lmHead(out)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var sanitizedWeights = weights
if sanitizedWeights["model.layers.0.block_sparse_moe.experts.0.w1.weight"] == nil {
return sanitizedWeights
}

for l in 0 ..< model.args.hiddenLayers {
let prefix = "model.layers.\(l)"
for (n, m) in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")] {
for k in ["weight", "scales", "biases"] {
if sanitizedWeights["\(prefix).block_sparse_moe.experts.0.\(n).\(k)"] != nil {
let toJoin = (0 ..< model.args.numLocalExperts).map { e in
sanitizedWeights.removeValue(
forKey: "\(prefix).block_sparse_moe.experts.\(e).\(n).\(k)")!
}
sanitizedWeights["\(prefix).block_sparse_moe.switch_mlp.\(m).\(k)"] =
MLX.stacked(toJoin)
}
}
}
}

return sanitizedWeights
}
}
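The sanitize step above converts per-expert tensors from the original checkpoint layout into the stacked layout SwitchGLU expects. A stand-alone sketch of that remapping for a single (layer, projection, key) triple, using two small placeholder experts instead of numLocalExperts = 16:

    import MLX

    var weights: [String: MLXArray] = [
        "model.layers.0.block_sparse_moe.experts.0.w1.weight": MLXArray.zeros([8, 4]),
        "model.layers.0.block_sparse_moe.experts.1.w1.weight": MLXArray.zeros([8, 4]),
    ]
    let stacked = MLX.stacked((0 ..< 2).map { e in
        weights.removeValue(forKey: "model.layers.0.block_sparse_moe.experts.\(e).w1.weight")!
    })
    weights["model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight"] = stacked
    // stacked.shape == [2, 8, 4]: a new leading expert axis over the stacked w1 weights.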

// MARK: - LoRA

extension PhiMoEModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.selfAttn, ["q_proj", "v_proj"]) }
}
}