diff --git a/Source/MLXNN/Linear.swift b/Source/MLXNN/Linear.swift
index 10e09c1..6cdf37c 100644
--- a/Source/MLXNN/Linear.swift
+++ b/Source/MLXNN/Linear.swift
@@ -73,7 +73,7 @@ open class Linear: Module, UnaryLayer, Quantizable {
     public let weight: MLXArray
     public let bias: MLXArray?
 
-    public var shape: (Int, Int) {
+    open var shape: (Int, Int) {
         weight.shape2
     }
 
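Note that `shape` reports `(outputDimensions, inputDimensions)` read from the stored weight; opening it up lets `QuantizedLinear` (below) override it to report the logical, un-packed width. A small sketch of how callers read it, using only the `Linear(_:_:)` initializer that the tests use:

```swift
import MLXNN

// shape is (outputDimensions, inputDimensions), taken from the stored weight
let linear = Linear(10, 20)
let (outputDimensions, inputDimensions) = linear.shape
print(outputDimensions, inputDimensions)  // 20 10
```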
diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift
index 8f4b2e5..bd64546 100644
--- a/Source/MLXNN/Module.swift
+++ b/Source/MLXNN/Module.swift
@@ -98,12 +98,13 @@ open class Module {
 
     /// Flag to indicate whether the module is being trained. Manipulated via
     /// ``train(_:)``.
+    ///
+    /// ### See Also
+    /// - ``didSetTrain(_:)``
     public private(set) var training = true
 
-    /// Set of property names that are frozen. Maniupulated via
-    /// ``freeze(recursive:keys:strict:)`` and
-    /// ``unfreeze(recursive:keys:strict:)``.
-    public private(set) var noGrad = Set<String>()
+    /// See ``noGrad()``
+    private var _noGrad = Set<String>()
 
     private var _items: ModuleItems!
     private var _setters: [String: TypeErasedSetter]!
@@ -139,7 +140,7 @@ open class Module {
     /// and ``update(parameters:)`` for example.
     ///
     /// Subclasses could potentially override this to provide custom introspection.
-    public func items() -> ModuleItems {
+    open func items() -> ModuleItems {
         _items
     }
 
@@ -222,7 +223,7 @@ open class Module {
     /// - ``mapParameters(map:isLeaf:)``
     /// - ``modules()``
     /// - ``items()``
-    public func filterMap<Result>(
+    open func filterMap<Result>(
         filter: (Module, String, ModuleItem) -> Bool,
         map: (ModuleItem) -> Result? = { $0 },
         isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
@@ -331,7 +332,7 @@ open class Module {
     /// ### See Also
    /// -
     /// - ``mapParameters(map:)``
-    public func mapParameters<Result>(
+    open func mapParameters<Result>(
         map: @escaping (MLXArray) -> Result? = { $0 },
         isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
     ) -> NestedDictionary<String, Result> {
@@ -343,7 +344,7 @@ open class Module {
 
     /// Return a `NestedDictionary` for all parameters in the
     /// model (all layers).
-    public func parameters() -> ModuleParameters {
+    open func parameters() -> ModuleParameters {
         filterMap(filter: Self.filterValidParameters, map: Self.mapParameters())
     }
 
@@ -351,12 +352,12 @@ open class Module {
     /// model (all layers).
     ///
     /// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
-    public func trainableParameters() -> ModuleParameters {
+    open func trainableParameters() -> ModuleParameters {
         filterMap(filter: Self.filterTrainableParameters, map: Self.mapParameters())
     }
 
     /// Produces a `NestedDictionary` for all direct children of the module.
-    public func children() -> ModuleChildren {
+    open func children() -> ModuleChildren {
         filterMap(filter: Self.filterValidChild, map: Self.mapModule(), isLeaf: Self.isLeafModule)
     }
 
@@ -364,7 +365,7 @@ open class Module {
     ///
     /// ### See Also
     /// - ``isLeafModuleNoChildren``
-    public func leafModules() -> ModuleChildren {
+    open func leafModules() -> ModuleChildren {
         filterMap(
             filter: Self.filterValidChild, map: Self.mapModule(),
             isLeaf: Self.isLeafModuleNoChildren)
@@ -710,7 +711,23 @@ open class Module {
         return self
     }
 
-    private func updateModule(key: String, _ value: Any) throws {
+    /// Set a module to a new value.
+    ///
+    /// The module property must be wrapped in a ``ModuleInfo``:
+    ///
+    /// ```swift
+    /// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    /// ```
+    ///
+    /// and the value must be a compatible type.
+    ///
+    /// This method is called via ``update(modules:)`` and is not typically called directly. This
+    /// is exposed as an overridable method for subclasses.
+    ///
+    /// - Parameters:
+    ///   - key: module key, see ``ModuleInfo``
+    ///   - value: the replacement module
+    open func updateModule(key: String, _ value: Any) throws {
         if let setter = _setters[key] {
             do {
                 try setter.updateModule(value)
@@ -727,7 +744,7 @@ open class Module {
     }
 
     // `apply_to_modules()`
-    public func visit(modules visitor: (String, Module) throws -> Void) rethrows {
+    open func visit(modules visitor: (String, Module) throws -> Void) rethrows {
         var stack = [(String, Module)]()
         stack.append(("", self))
 
@@ -746,7 +763,7 @@ open class Module {
     /// - ``namedModules()``
     /// - ``children()``
     /// - ``leafModules()``
-    public func modules() -> [Module] {
+    open func modules() -> [Module] {
         var result = [Module]()
         visit {
             result.append($1)
@@ -760,7 +777,7 @@ open class Module {
     /// - ``modules()``
     /// - ``children()``
     /// - ``leafModules()``
-    public func namedModules() -> [(String, Module)] {
+    open func namedModules() -> [(String, Module)] {
         var result = [(String, Module)]()
         visit {
             result.append(($0, $1))
@@ -822,7 +839,8 @@ open class Module {
     ///   - ``unfreeze(recursive:keys:strict:)``
     open func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
         let visitor = freezeVisitor(keys: keys, strict: strict) {
-            $0.noGrad.formUnion($1)
+            $0._noGrad.formUnion($1)
+            $0.didSetNoGrad($0._noGrad)
         }
 
         if recursive {
@@ -859,7 +877,8 @@ open class Module {
     ///   - ``Module/unfreeze(recursive:keys:strict:)``
     open func unfreeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
         let visitor = freezeVisitor(keys: keys, strict: strict) {
-            $0.noGrad.subtract($1)
+            $0._noGrad.subtract($1)
+            $0.didSetNoGrad($0._noGrad)
         }
 
         if recursive {
@@ -869,6 +888,24 @@ open class Module {
         }
     }
 
+    /// Set of property names that are frozen. Manipulated via
+    /// ``freeze(recursive:keys:strict:)`` and
+    /// ``unfreeze(recursive:keys:strict:)``.
+    open func noGrad() -> Set<String> {
+        _noGrad
+    }
+
+    /// Called when ``noGrad()`` is updated.
+    ///
+    /// This is provided for subclasses to override.
+    ///
+    /// - Parameter noGrad: set of properties that are frozen
+    ///
+    /// ### See Also
+    /// - ``noGrad()``
+    open func didSetNoGrad(_ noGrad: Set<String>) {
+    }
+
     /// Recursively set the model's training mode.
     ///
     /// Training mode only applies to certain layers. For example
@@ -877,11 +914,21 @@ open class Module {
     ///
     /// ### See Also
     /// - ``training``
+    /// - ``didSetTrain(_:)``
     public func train(_ mode: Bool = true) {
         visit(modules: {
             $1.training = mode
+            $1.didSetTrain(mode)
         })
     }
+
+    /// Called when ``train(_:)`` is updated.
+    ///
+    /// This is provided for subclasses to override.
+    ///
+    /// - Parameter mode: `true` is training
+    open func didSetTrain(_ mode: Bool) {
+    }
 }
 
 extension Module: IndentedDescription {
@@ -922,7 +969,7 @@ extension Module: Updatable, Evaluatable {
 /// ### See Also
 /// -
 /// - ``Sequential``
-public protocol UnaryLayer {
+public protocol UnaryLayer: Module {
     func callAsFunction(_ x: MLXArray) -> MLXArray
 }
 
@@ -996,7 +1043,7 @@ extension Module {
             (module: Module, key: String, item: ModuleItem) in
             switch item {
             case .array, .dictionary, .value(.parameters), .value(.module):
-                !key.hasPrefix("_") && !module.noGrad.contains(key)
+                !key.hasPrefix("_") && !module.noGrad().contains(key)
             default: false
             }
         }
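The new `didSetTrain(_:)` and `didSetNoGrad(_:)` methods are empty on `Module` and exist purely as override points. A minimal sketch of a subclass reacting to `train(_:)`; the `cacheEnabled` property is illustrative and not part of MLXNN:

```swift
import MLX
import MLXNN

class CachingLayer: Module, UnaryLayer {
    // hypothetical state that is only valid outside of training
    var cacheEnabled = false

    override func didSetTrain(_ mode: Bool) {
        cacheEnabled = !mode
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        x
    }
}

let layer = CachingLayer()
layer.train(false)          // visits the module tree, then calls didSetTrain(false)
print(layer.cacheEnabled)   // true
```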
diff --git a/Source/MLXNN/Quantized.swift b/Source/MLXNN/Quantized.swift
index 9fd35ed..1190d8b 100644
--- a/Source/MLXNN/Quantized.swift
+++ b/Source/MLXNN/Quantized.swift
@@ -11,9 +11,18 @@ public protocol Quantizable {
     func toQuantized(groupSize: Int, bits: Int) -> Module
 }
 
-public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Module? {
-    if let quantizable = layer as? Quantizable {
-        quantizable.toQuantized(groupSize: groupSize, bits: bits)
+/// Protocol for layers that are quantized.
+public protocol Quantized: Module {
+    var groupSize: Int { get }
+    var bits: Int { get }
+}
+
+public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? {
+    if layer is Quantized {
+        // already quantized
+        nil
+    } else if let quantizable = layer as? Quantizable {
+        quantizable.toQuantized(groupSize: groupSize, bits: bits) as? Quantized
     } else {
         nil
     }
 }
@@ -52,7 +61,7 @@ public func quantize(
 }
 
 /// The same as ``Embedding`` but with a quantized weight matrix.
-open class QuantizedEmbedding: Embedding {
+open class QuantizedEmbedding: Embedding, Quantized {
 
     public let groupSize: Int
     public let bits: Int
@@ -121,7 +130,7 @@ open class QuantizedEmbedding: Embedding {
 ///
 /// ### See Also
 /// - ``init(weight:bias:groupSize:bits:)``
-open class QuantizedLinear: Linear {
+open class QuantizedLinear: Linear, Quantized {
 
     public let groupSize: Int
     public let bits: Int
@@ -129,6 +138,11 @@ open class QuantizedLinear: Linear {
     public let scales: MLXArray
     public let biases: MLXArray
 
+    open override var shape: (Int, Int) {
+        let shape = weight.shape2
+        return (shape.0, shape.1 * 32 / bits)
+    }
+
     /// Applies an affine transformation to the input using a quantized weight matrix.
     ///
     /// This is the quantized version of ``Linear``. Typically this is used via ``quantize(model:groupSize:bits:predicate:)``.
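With the `Quantized` protocol in place, `quantizeSingle` now returns `Quantized?` and refuses to quantize twice, which is the behavior `testAlreadyQuantized` below locks in. A quick sketch using the default `groupSize`/`bits`:

```swift
import MLXNN

// a plain Linear is Quantizable, so it converts to a QuantizedLinear
if let q = quantizeSingle(layer: Linear(256, 256)) {
    print(type(of: q), q.groupSize, q.bits)  // QuantizedLinear 64 4
}

// an already-quantized layer is skipped rather than double-quantized
print(quantizeSingle(layer: QuantizedLinear(256, 256)) == nil)  // true
```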
diff --git a/Source/MLXNN/Transformer.swift b/Source/MLXNN/Transformer.swift
index d8b3680..8831b68 100644
--- a/Source/MLXNN/Transformer.swift
+++ b/Source/MLXNN/Transformer.swift
@@ -12,10 +12,10 @@ open class MultiHeadAttention: Module {
 
     public let numHeads: Int
 
-    @ModuleInfo(key: "query_proj") public var queryProjection: Linear
-    @ModuleInfo(key: "key_proj") public var keyProjection: Linear
-    @ModuleInfo(key: "value_proj") public var valueProjection: Linear
-    @ModuleInfo(key: "out_proj") public var outProjection: Linear
+    @ModuleInfo(key: "query_proj") public var queryProjection: UnaryLayer
+    @ModuleInfo(key: "key_proj") public var keyProjection: UnaryLayer
+    @ModuleInfo(key: "value_proj") public var valueProjection: UnaryLayer
+    @ModuleInfo(key: "out_proj") public var outProjection: UnaryLayer
 
     /// Implements the scaled dot product attention with multiple heads.
     ///
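Typing the projections as `UnaryLayer` (which now requires `Module`) means `MultiHeadAttention` no longer cares whether a projection is a plain `Linear`, a `QuantizedLinear`, or a composite wrapper like the LoRA layer in the tests below. For instance, quantization still swaps the projections in place; this sketch assumes the `MultiHeadAttention(dimensions:numHeads:)` initializer:

```swift
import MLXNN

let attention = MultiHeadAttention(dimensions: 256, numHeads: 8)
quantize(model: attention)

// the UnaryLayer-typed properties now hold QuantizedLinear instances
print(type(of: attention.queryProjection))  // QuantizedLinear
```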
diff --git a/Tests/MLXTests/ModuleTests.swift b/Tests/MLXTests/ModuleTests.swift
index dec31bc..782323f 100644
--- a/Tests/MLXTests/ModuleTests.swift
+++ b/Tests/MLXTests/ModuleTests.swift
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXRandom
 import XCTest
 
 @testable import MLXNN
@@ -630,6 +631,26 @@ class ModuleTests: XCTestCase {
         quantize(model: m)
 
         XCTAssertTrue(m.module.child is QuantizedLinear)
+        XCTAssertTrue(m.module.child.shape == (256, 256))
+    }
+
+    func testAlreadyQuantized() throws {
+        // should not quantize an already quantized layer -- that would be silly
+        class C: Module {
+            @ModuleInfo
+            var child: Linear = QuantizedLinear(256, 256)
+
+            var other = Sigmoid()
+        }
+        class M: Module {
+            let module = C()
+        }
+
+        let m = M()
+        quantize(model: m)
+
+        XCTAssertTrue(m.module.child is QuantizedLinear)
+        XCTAssertTrue(m.module.child.shape == (256, 256))
     }
 
     func testQuantizePredicate() throws {
@@ -770,4 +791,97 @@ class ModuleTests: XCTestCase {
         XCTAssertTrue(pm.mlp.0 is QuantizedLinear)
         XCTAssertTrue(pm.mlp.2 is QuantizedLinear)
     }
+
+    func testCompositeLayer() throws {
+        // a test of making a LoRA layer as a composite -- it adapts
+        // another layer but mixes their properties. this is not
+        // necessarily the way to implement LoRA (you can also use subclasses
+        // like Linear/QuantizedLinear) but this verifies that it can
+        // be written this way.
+
+        class LoRA: Module, UnaryLayer {
+
+            let adapts: UnaryLayer
+            let scale: Float
+
+            @ParameterInfo(key: "lora_a") var loraA: MLXArray
+            @ParameterInfo(key: "lora_b") var loraB: MLXArray
+
+            public init(
+                adapts: UnaryLayer, inputDimensions: Int, outputDimensions: Int, rank: Int = 8,
+                scale: Float = 20.0
+            ) {
+                self.adapts = adapts
+
+                self.scale = scale
+
+                let loraScale = 1 / sqrt(Float(inputDimensions))
+                self._loraA.wrappedValue = MLXRandom.uniform(
+                    low: -loraScale, high: loraScale, [inputDimensions, rank])
+                self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])
+
+                super.init()
+
+                freeze()
+            }
+
+            public convenience init(linear: Linear, rank: Int = 8, scale: Float = 20.0) {
+                let (outputDimensions, inputDimensions) = linear.shape
+                self.init(
+                    adapts: linear,
+                    inputDimensions: inputDimensions, outputDimensions: outputDimensions,
+                    rank: rank, scale: scale)
+            }
+
+            // produce a merged view of properties (flatten LoRA into adapts)
+            override func items() -> ModuleItems {
+                var result = adapts.items()
+                for (key, value) in super.items() {
+                    if key == "adapts" { continue }
+                    result[key] = value
+                }
+                return result
+            }
+
+            // forward module updates -> adapts
+            override func updateModule(key: String, _ value: Any) throws {
+                try adapts.updateModule(key: key, value)
+            }
+
+            override func modules() -> [Module] {
+                adapts.modules()
+            }
+
+            override func freeze(
+                recursive: Bool = true, keys: [String]? = nil, strict: Bool = false
+            ) throws {
+                try adapts.freeze(recursive: recursive, keys: keys, strict: strict)
+            }
+
+            override func noGrad() -> Set<String> {
+                adapts.noGrad()
+            }
+
+            public func callAsFunction(_ x: MLXArray) -> MLXArray {
+                let y = adapts(x)
+                let z = matmul(matmul(x, self.loraA), self.loraB)
+                return y + scale * z
+            }
+        }
+
+        let linear = Linear(10, 20)
+        let lora = LoRA(linear: linear)
+
+        let linearProperties = linear.parameters()
+        let loraProperties = lora.parameters()
+
+        XCTAssertEqual(linearProperties.count, 2)
+        XCTAssertEqual(loraProperties.count, 4)
+
+        try lora.update(parameters: loraProperties.mapValues { $0 + 1 }, verify: .all)
+
+        let trainable = lora.trainableParameters()
+        XCTAssertEqual(trainable.count, 2)
+        XCTAssertTrue(trainable["lora_a"] != nil)
+    }
 }
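For reference, the composite layer above is used like any other `UnaryLayer`: the wrapped `Linear` stays frozen (its frozen keys are surfaced through the `noGrad()` override) and only `lora_a`/`lora_b` train. A usage sketch, assuming the `LoRA` class is hoisted out of `testCompositeLayer`:

```swift
import MLX
import MLXNN
import MLXRandom

let linear = Linear(10, 20)
let lora = LoRA(linear: linear)

// forward pass: frozen Linear output plus the low-rank update
let x = MLXRandom.uniform(low: Float(-1), high: Float(1), [1, 10])
let y = lora(x)
print(y.shape)  // [1, 20]

// only lora_a / lora_b remain trainable
print(lora.trainableParameters().count)  // 2
```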