Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg authored and davidkoski committed Oct 22, 2024
1 parent e6bdadb commit f692c71
Showing 1 changed file with 15 additions and 58 deletions.
73 changes: 15 additions & 58 deletions Libraries/LLM/PhiMoE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class SwitchGLU: Module {
}
}

class SwitchLinear: Module {
class SwitchLinear: Module, Quantizable {
@ModuleInfo(key: "weight") var weight: MLXArray
@ModuleInfo(key: "bias") var bias: MLXArray?

Expand Down Expand Up @@ -345,83 +345,40 @@ class SwitchLinear: Module {
return result
}

func toQuantized(groupSize: Int = 64, bits: Int = 4) -> QuantizedSwitchLinear {
let (numExperts, outputDims, inputDims) = (
self.weight.shape[0], self.weight.shape[1], self.weight.shape[2]
)
let ql = QuantizedSwitchLinear(
inputDims: inputDims,
outputDims: outputDims,
numExperts: numExperts,
bias: self.bias != nil,
groupSize: groupSize,
bits: bits
)

let (quantizedWeight, scales, biases) = MLX.quantized(
self.weight, groupSize: groupSize, bits: bits)
ql.weight = quantizedWeight
ql.scales = scales
ql.biases = biases

if let bias = self.bias {
ql.bias = bias
}

return ql
func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module {
QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits)
}
}

class QuantizedSwitchLinear: Module {
@ModuleInfo(key: "weight") var weight: MLXArray
class QuantizedSwitchLinear: SwitchLinear {
@ModuleInfo(key: "scales") var scales: MLXArray
@ModuleInfo(key: "biases") var biases: MLXArray
@ModuleInfo(key: "bias") var bias: MLXArray?

let groupSize: Int
let bits: Int

init(
inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true, groupSize: Int = 64,
bits: Int = 4
) {
init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) {
self.groupSize = groupSize
self.bits = bits

let scale = sqrt(1.0 / Float(inputDims))
let uniformWeight = MLXRandom.uniform(
low: -scale,
high: scale,
[numExperts, outputDims, inputDims]
)
super.init(
inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts,
bias: other.bias != nil)

let (quantizedWeight, scales, biases) = MLX.quantized(
uniformWeight, groupSize: groupSize, bits: bits)
self._weight.wrappedValue = quantizedWeight
self._scales.wrappedValue = scales
self._biases.wrappedValue = biases
other.weight, groupSize: groupSize, bits: bits)
self.weight = quantizedWeight
self.scales = scales
self.biases = biases

if bias {
self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims])
if let bias = other.bias {
self.bias = bias
}

super.init()
self.freeze()
}

var inputDims: Int {
return scales.shape[2] * groupSize
}

var outputDims: Int {
return weight.shape[1]
}

var numExperts: Int {
return weight.shape[0]
}

func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
override func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
var result = MLX.gatherQuantizedMatmul(
x,
self.weight,
Expand Down

0 comments on commit f692c71

Please sign in to comment.