Skip to content

Commit 5de5405

Browse files
authored
address issues that prevent using composition for layers like LoRA (#177)
* address issues that prevent using composition for layers like LoRA - see ml-explore/mlx-swift-examples#167 - also fixes issue where quantize() could quantize a quantized layer!
1 parent 118e448 commit 5de5405

File tree

5 files changed

+204
-29
lines changed

5 files changed

+204
-29
lines changed

Source/MLXNN/Linear.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ open class Linear: Module, UnaryLayer, Quantizable {
7373
public let weight: MLXArray
7474
public let bias: MLXArray?
7575

76-
public var shape: (Int, Int) {
76+
open var shape: (Int, Int) {
7777
weight.shape2
7878
}
7979

Source/MLXNN/Module.swift

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ open class Module {
9898

9999
/// Flag to indicate whether the module is being trained. Manipulated via
100100
/// ``train(_:)``.
101+
///
102+
/// ### See Also
103+
/// - ``didSetTrain(_:)``
101104
public private(set) var training = true
102105

103-
/// Set of property names that are frozen. Maniupulated via
104-
/// ``freeze(recursive:keys:strict:)`` and
105-
/// ``unfreeze(recursive:keys:strict:)``.
106-
public private(set) var noGrad = Set<String>()
106+
/// See ``noGrad()``
107+
private var _noGrad = Set<String>()
107108

108109
private var _items: ModuleItems!
109110
private var _setters: [String: TypeErasedSetter]!
@@ -139,7 +140,7 @@ open class Module {
139140
/// and ``update(parameters:)`` for example.
140141
///
141142
/// Subclasses could potentially override this to provide custom introspection.
142-
public func items() -> ModuleItems {
143+
open func items() -> ModuleItems {
143144
_items
144145
}
145146

@@ -222,7 +223,7 @@ open class Module {
222223
/// - ``mapParameters(map:isLeaf:)``
223224
/// - ``modules()``
224225
/// - ``items()``
225-
public func filterMap<Result>(
226+
open func filterMap<Result>(
226227
filter: (Module, String, ModuleItem) -> Bool,
227228
map: (ModuleItem) -> Result? = { $0 },
228229
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
@@ -331,7 +332,7 @@ open class Module {
331332
/// ### See Also
332333
/// - <doc:module-filters>
333334
/// - ``mapParameters(map:)``
334-
public func mapParameters<Result>(
335+
open func mapParameters<Result>(
335336
map: @escaping (MLXArray) -> Result? = { $0 },
336337
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
337338
) -> NestedDictionary<String, Result> {
@@ -343,28 +344,28 @@ open class Module {
343344

344345
/// Return a `NestedDictionary<String, MLXArray>` for all parameters in the
345346
/// model (all layers).
346-
public func parameters() -> ModuleParameters {
347+
open func parameters() -> ModuleParameters {
347348
filterMap(filter: Self.filterValidParameters, map: Self.mapParameters())
348349
}
349350

350351
/// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the
351352
/// model (all layers).
352353
///
353354
/// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
354-
public func trainableParameters() -> ModuleParameters {
355+
open func trainableParameters() -> ModuleParameters {
355356
filterMap(filter: Self.filterTrainableParameters, map: Self.mapParameters())
356357
}
357358

358359
/// Produces a `NestedDictionary<String, Module>` for all direct children of the module.
359-
public func children() -> ModuleChildren {
360+
open func children() -> ModuleChildren {
360361
filterMap(filter: Self.filterValidChild, map: Self.mapModule(), isLeaf: Self.isLeafModule)
361362
}
362363

363364
/// Produces a `NestedDictionary<String, Module>` for all leaf modules module.
364365
///
365366
/// ### See Also
366367
/// - ``isLeafModuleNoChildren``
367-
public func leafModules() -> ModuleChildren {
368+
open func leafModules() -> ModuleChildren {
368369
filterMap(
369370
filter: Self.filterValidChild, map: Self.mapModule(),
370371
isLeaf: Self.isLeafModuleNoChildren)
@@ -714,7 +715,23 @@ open class Module {
714715
return self
715716
}
716717

717-
private func updateModule(key: String, _ value: Any) throws {
718+
/// Set a module to a new value.
719+
///
720+
/// The module property must be wrapped in a ``ModuleInfo``:
721+
///
722+
/// ```swift
723+
/// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
724+
/// ```
725+
///
726+
/// and the value must be a compatible type.
727+
///
728+
/// This method is called via ``update(modules:)`` and is not typically called directly. This
729+
/// is exposed as an overridable method for subclasses.
730+
///
731+
/// - Parameters:
732+
/// - key: module key, see ``ModuleInfo``
733+
/// - value: the replacement module
734+
open func updateModule(key: String, _ value: Any) throws {
718735
if let setter = _setters[key] {
719736
do {
720737
try setter.updateModule(value)
@@ -731,7 +748,7 @@ open class Module {
731748
}
732749

733750
// `apply_to_modules()`
734-
public func visit(modules visitor: (String, Module) throws -> Void) rethrows {
751+
open func visit(modules visitor: (String, Module) throws -> Void) rethrows {
735752
var stack = [(String, Module)]()
736753
stack.append(("", self))
737754

@@ -750,7 +767,7 @@ open class Module {
750767
/// - ``namedModules()``
751768
/// - ``children()``
752769
/// - ``leafModules()``
753-
public func modules() -> [Module] {
770+
open func modules() -> [Module] {
754771
var result = [Module]()
755772
visit {
756773
result.append($1)
@@ -764,7 +781,7 @@ open class Module {
764781
/// - ``modules()``
765782
/// - ``children()``
766783
/// - ``leafModules()``
767-
public func namedModules() -> [(String, Module)] {
784+
open func namedModules() -> [(String, Module)] {
768785
var result = [(String, Module)]()
769786
visit {
770787
result.append(($0, $1))
@@ -826,7 +843,8 @@ open class Module {
826843
/// - ``unfreeze(recursive:keys:strict:)``
827844
open func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
828845
let visitor = freezeVisitor(keys: keys, strict: strict) {
829-
$0.noGrad.formUnion($1)
846+
$0._noGrad.formUnion($1)
847+
$0.didSetNoGrad($0._noGrad)
830848
}
831849

832850
if recursive {
@@ -863,7 +881,8 @@ open class Module {
863881
/// - ``Module/unfreeze(recursive:keys:strict:)``
864882
open func unfreeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
865883
let visitor = freezeVisitor(keys: keys, strict: strict) {
866-
$0.noGrad.subtract($1)
884+
$0._noGrad.subtract($1)
885+
$0.didSetNoGrad($0._noGrad)
867886
}
868887

869888
if recursive {
@@ -873,6 +892,24 @@ open class Module {
873892
}
874893
}
875894

895+
/// Set of property names that are frozen. Manipulated via
896+
/// ``freeze(recursive:keys:strict:)`` and
897+
/// ``unfreeze(recursive:keys:strict:)``.
898+
open func noGrad() -> Set<String> {
899+
_noGrad
900+
}
901+
902+
/// Called when ``noGrad()`` is updated.
903+
///
904+
/// This is provided for subclasses to override.
905+
///
906+
/// - Parameter noGrad: set of properties that are frozen
907+
///
908+
/// ### See Also
909+
/// - ``noGrad()``
910+
open func didSetNoGrad(_ noGrad: Set<String>) {
911+
}
912+
876913
/// Recursively set the model's training mode.
877914
///
878915
/// Training mode only applies to certain layers. For example
@@ -881,11 +918,21 @@ open class Module {
881918
///
882919
/// ### See Also
883920
/// - ``training``
921+
/// - ``didSetTrain(_:)``
884922
public func train(_ mode: Bool = true) {
885923
visit(modules: {
886924
$1.training = mode
925+
$1.didSetTrain(mode)
887926
})
888927
}
928+
929+
/// Called when ``train(_:)`` is updated.
930+
///
931+
/// This is provided for subclasses to override.
932+
///
933+
/// - Parameter mode: `true` is training
934+
open func didSetTrain(_ mode: Bool) {
935+
}
889936
}
890937

891938
extension Module: IndentedDescription {
@@ -926,7 +973,7 @@ extension Module: Updatable, Evaluatable {
926973
/// ### See Also
927974
/// - <doc:layers>
928975
/// - ``Sequential``
929-
public protocol UnaryLayer {
976+
public protocol UnaryLayer: Module {
930977
func callAsFunction(_ x: MLXArray) -> MLXArray
931978
}
932979

@@ -1008,7 +1055,7 @@ extension Module {
10081055
(module: Module, key: String, item: ModuleItem) in
10091056
switch item {
10101057
case .array, .dictionary, .value(.parameters), .value(.module):
1011-
parameterIsValid(key) && !module.noGrad.contains(key)
1058+
parameterIsValid(key) && !module.noGrad().contains(key)
10121059
default: false
10131060
}
10141061
}

Source/MLXNN/Quantized.swift

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,18 @@ public protocol Quantizable {
1111
func toQuantized(groupSize: Int, bits: Int) -> Module
1212
}
1313

14-
public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Module? {
15-
if let quantizable = layer as? Quantizable {
16-
quantizable.toQuantized(groupSize: groupSize, bits: bits)
14+
/// Protocol for layers that are quantized.
15+
public protocol Quantized: Module {
16+
var groupSize: Int { get }
17+
var bits: Int { get }
18+
}
19+
20+
public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? {
21+
if layer is Quantized {
22+
// already quantized
23+
nil
24+
} else if let quantizable = layer as? Quantizable {
25+
quantizable.toQuantized(groupSize: groupSize, bits: bits) as? Quantized
1726
} else {
1827
nil
1928
}
@@ -52,7 +61,7 @@ public func quantize(
5261
}
5362

5463
/// The same as ``Embedding`` but with a quantized weight matrix.
55-
open class QuantizedEmbedding: Embedding {
64+
open class QuantizedEmbedding: Embedding, Quantized {
5665

5766
public let groupSize: Int
5867
public let bits: Int
@@ -121,14 +130,19 @@ open class QuantizedEmbedding: Embedding {
121130
///
122131
/// ### See Also
123132
/// - ``init(weight:bias:groupSize:bits:)``
124-
open class QuantizedLinear: Linear {
133+
open class QuantizedLinear: Linear, Quantized {
125134

126135
public let groupSize: Int
127136
public let bits: Int
128137

129138
public let scales: MLXArray
130139
public let biases: MLXArray
131140

141+
open override var shape: (Int, Int) {
142+
let shape = weight.shape2
143+
return (shape.0, shape.1 * 32 / bits)
144+
}
145+
132146
/// Applies an affine transformation to the input using a quantized weight matrix.
133147
///
134148
/// This is the quantized version of ``Linear``. Typically this is used via ``quantize(model:groupSize:bits:predicate:)``.

Source/MLXNN/Transformer.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ open class MultiHeadAttention: Module {
1212

1313
public let numHeads: Int
1414

15-
@ModuleInfo(key: "query_proj") public var queryProjection: Linear
16-
@ModuleInfo(key: "key_proj") public var keyProjection: Linear
17-
@ModuleInfo(key: "value_proj") public var valueProjection: Linear
18-
@ModuleInfo(key: "out_proj") public var outProjection: Linear
15+
@ModuleInfo(key: "query_proj") public var queryProjection: UnaryLayer
16+
@ModuleInfo(key: "key_proj") public var keyProjection: UnaryLayer
17+
@ModuleInfo(key: "value_proj") public var valueProjection: UnaryLayer
18+
@ModuleInfo(key: "out_proj") public var outProjection: UnaryLayer
1919

2020
/// Implements the scaled dot product attention with multiple heads.
2121
///

0 commit comments

Comments
 (0)