Skip to content

Commit

Permalink
Add PhiMoE
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Aug 31, 2024
1 parent 9266a62 commit c52b1e1
Show file tree
Hide file tree
Showing 6 changed files with 489 additions and 20 deletions.
5 changes: 5 additions & 0 deletions Libraries/LLM/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public enum ModelType: String, Codable, Sendable {
case llama
case phi
case phi3
case phimoe
case gemma
case gemma2
case qwen2
Expand All @@ -52,6 +53,10 @@ public enum ModelType: String, Codable, Sendable {
let configuration = try JSONDecoder().decode(
Phi3Configuration.self, from: Data(contentsOf: configuration))
return Phi3Model(configuration)
case .phimoe:
let configuration = try JSONDecoder().decode(
PhiMoEConfiguration.self, from: Data(contentsOf: configuration))
return PhiMoEModel(configuration)
case .gemma:
let configuration = try JSONDecoder().decode(
GemmaConfiguration.self, from: Data(contentsOf: configuration))
Expand Down
9 changes: 9 additions & 0 deletions Libraries/LLM/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ extension ModelConfiguration {
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}

/// Configuration for the model published as `mlx-community/Phi-3.5-MoE-instruct-4bit`.
///
/// Registers `<|end|>` as an extra end-of-sequence token and wraps the raw
/// prompt in the `<|user|>` / `<|assistant|>` chat markers before generation.
public static let phi3_5MoE = ModelConfiguration(
    id: "mlx-community/Phi-3.5-MoE-instruct-4bit",
    defaultPrompt: "What is the gravity on Mars and the moon?",
    extraEOSTokens: ["<|end|>"]
) { prompt in
    // Single user turn followed by an open assistant turn.
    "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}

public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer",
Expand Down
35 changes: 20 additions & 15 deletions Libraries/LLM/Phi3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -207,21 +207,25 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
}
}

public struct Phi3Configuration: Codable, Sendable {
struct RopeScaling: Codable {
let longFactor: [Float]?
let shortFactor: [Float]?
let factor: Float?
let type: String?

enum CodingKeys: String, CodingKey {
case type
case factor
case longFactor = "long_factor"
case shortFactor = "short_factor"
}
/// RoPE scaling settings decoded from the `rope_scaling`-style section of a
/// model configuration.
///
/// Every field is optional, so a configuration that omits any (or all) of the
/// keys still decodes; missing keys become `nil`.
struct RopeScalingWithFactorArrays: Codable {
    /// Decoded from the `long_factor` JSON key.
    let longFactor: [Float]?
    /// Decoded from the `short_factor` JSON key.
    let shortFactor: [Float]?
    /// Decoded from the `factor` JSON key.
    let factor: Float?
    /// Decoded from the `type` JSON key.
    let type: String?
    /// Decoded from the `long_mscale` JSON key.
    let longMScale: Float?
    /// Decoded from the `short_mscale` JSON key.
    let shortMScale: Float?

    /// Maps the camelCase property names onto the snake_case keys used in the
    /// JSON configuration.
    enum CodingKeys: String, CodingKey {
        // Keys whose JSON spelling matches the property name.
        case type, factor
        // Keys that need an explicit snake_case spelling.
        case longFactor = "long_factor"
        case shortFactor = "short_factor"
        case longMScale = "long_mscale"
        case shortMScale = "short_mscale"
    }
}

public struct Phi3Configuration: Codable, Sendable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
Expand All @@ -231,7 +235,7 @@ public struct Phi3Configuration: Codable, Sendable {
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: RopeScaling?
var ropeScaling: RopeScalingWithFactorArrays?
var maxPositionEmbeddings: Int
var originalMaxPositionEmbeddings: Int

Expand Down Expand Up @@ -273,7 +277,8 @@ public struct Phi3Configuration: Codable, Sendable {
ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
ropeScaling = try container.decodeIfPresent(RopeScaling.self, forKey: .ropeScaling)
ropeScaling = try container.decodeIfPresent(
RopeScalingWithFactorArrays.self, forKey: .ropeScaling)
maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
originalMaxPositionEmbeddings = try container.decode(
Int.self, forKey: .originalMaxPositionEmbeddings)
Expand Down
Loading

0 comments on commit c52b1e1

Please sign in to comment.