address issues that prevent using composition for layers like LoRA #177

Open · wants to merge 1 commit into main
2 changes: 1 addition & 1 deletion Source/MLXNN/Linear.swift
@@ -73,7 +73,7 @@ open class Linear: Module, UnaryLayer, Quantizable {
public let weight: MLXArray
public let bias: MLXArray?

- public var shape: (Int, Int) {
+ open var shape: (Int, Int) {
Collaborator Author:

A lot of changes here are public -> open to allow subclasses to override.
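For example (a hypothetical subclass in a downstream module, not part of this PR), `open` makes overrides like this possible:

```swift
import MLXNN

// Hypothetical: a subclass that customizes how the layer reports its shape.
final class AugmentedLinear: Linear {
    override var shape: (Int, Int) {
        // a real adapter might derive this from a wrapped layer instead
        weight.shape2
    }
}
```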

weight.shape2
}

85 changes: 66 additions & 19 deletions Source/MLXNN/Module.swift
@@ -98,12 +98,13 @@ open class Module {

/// Flag to indicate whether the module is being trained. Manipulated via
/// ``train(_:)``.
+ ///
+ /// ### See Also
+ /// - ``didSetTrain(_:)``
public private(set) var training = true

- /// Set of property names that are frozen. Maniupulated via
- /// ``freeze(recursive:keys:strict:)`` and
- /// ``unfreeze(recursive:keys:strict:)``.
- public private(set) var noGrad = Set<String>()
+ /// See ``noGrad()``
+ private var _noGrad = Set<String>()
Collaborator Author:

This is a breaking change, but the property is unlikely to be used directly (there are methods to manipulate it). Subclasses can't override stored properties (they can only override computed properties), so the stored property is replaced with methods that can be overridden.
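A rough sketch of what the overridable methods allow (hypothetical subclass, not from this PR):

```swift
import MLXNN

// Hypothetical: react when the frozen-key set changes, e.g. to keep a cached
// view of trainable parameters in sync.
class ObservingModule: Module {
    override func didSetNoGrad(_ noGrad: Set<String>) {
        super.didSetNoGrad(noGrad)
        print("frozen property names: \(noGrad.sorted())")
    }
}
```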


private var _items: ModuleItems!
private var _setters: [String: TypeErasedSetter]!
@@ -139,7 +140,7 @@ open class Module {
/// and ``update(parameters:)`` for example.
///
/// Subclasses could potentially override this to provide custom introspection.
- public func items() -> ModuleItems {
+ open func items() -> ModuleItems {
_items
}

@@ -222,7 +223,7 @@ open class Module {
/// - ``mapParameters(map:isLeaf:)``
/// - ``modules()``
/// - ``items()``
- public func filterMap<Result>(
+ open func filterMap<Result>(
filter: (Module, String, ModuleItem) -> Bool,
map: (ModuleItem) -> Result? = { $0 },
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
@@ -331,7 +332,7 @@ open class Module {
/// ### See Also
/// - <doc:module-filters>
/// - ``mapParameters(map:)``
- public func mapParameters<Result>(
+ open func mapParameters<Result>(
map: @escaping (MLXArray) -> Result? = { $0 },
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault
) -> NestedDictionary<String, Result> {
@@ -343,28 +344,28 @@

/// Return a `NestedDictionary<String, MLXArray>` for all parameters in the
/// model (all layers).
- public func parameters() -> ModuleParameters {
+ open func parameters() -> ModuleParameters {
filterMap(filter: Self.filterValidParameters, map: Self.mapParameters())
}

/// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the
/// model (all layers).
///
/// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
- public func trainableParameters() -> ModuleParameters {
+ open func trainableParameters() -> ModuleParameters {
filterMap(filter: Self.filterTrainableParameters, map: Self.mapParameters())
}

/// Produces a `NestedDictionary<String, Module>` for all direct children of the module.
- public func children() -> ModuleChildren {
+ open func children() -> ModuleChildren {
filterMap(filter: Self.filterValidChild, map: Self.mapModule(), isLeaf: Self.isLeafModule)
}

/// Produces a `NestedDictionary<String, Module>` for all leaf modules module.
///
/// ### See Also
/// - ``isLeafModuleNoChildren``
- public func leafModules() -> ModuleChildren {
+ open func leafModules() -> ModuleChildren {
filterMap(
filter: Self.filterValidChild, map: Self.mapModule(),
isLeaf: Self.isLeafModuleNoChildren)
@@ -710,7 +711,23 @@ open class Module {
return self
}

- private func updateModule(key: String, _ value: Any) throws {
+ /// Set a module to a new value.
+ ///
+ /// The module property must be wrapped in a ``ModuleInfo``:
+ ///
+ /// ```swift
+ /// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+ /// ```
+ ///
+ /// and the value must be a compatible type.
+ ///
+ /// This method is called via ``update(modules:)`` and is not typically called directly. This
+ /// is exposed as an overridable method for subclasses.
+ ///
+ /// - Parameters:
+ ///   - key: module key, see ``ModuleInfo``
+ ///   - value: the replacement module
+ open func updateModule(key: String, _ value: Any) throws {
Collaborator Author:

Primarily exposed for subclasses to override (see test case)
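For example, a container could vet or adapt replacements before they are stored (hypothetical sketch, not the PR's actual test case):

```swift
import MLXNN

// Hypothetical: only accept replacement modules that can be applied as f(x) -> y.
class ValidatingModule: Module {
    enum UpdateError: Error { case unsupportedReplacement(key: String) }

    override func updateModule(key: String, _ value: Any) throws {
        guard value is UnaryLayer else {
            throw UpdateError.unsupportedReplacement(key: key)
        }
        try super.updateModule(key: key, value)
    }
}
```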

if let setter = _setters[key] {
do {
try setter.updateModule(value)
@@ -727,7 +744,7 @@
}

// `apply_to_modules()`
- public func visit(modules visitor: (String, Module) throws -> Void) rethrows {
+ open func visit(modules visitor: (String, Module) throws -> Void) rethrows {
var stack = [(String, Module)]()
stack.append(("", self))

@@ -746,7 +763,7 @@
/// - ``namedModules()``
/// - ``children()``
/// - ``leafModules()``
- public func modules() -> [Module] {
+ open func modules() -> [Module] {
var result = [Module]()
visit {
result.append($1)
@@ -760,7 +777,7 @@
/// - ``modules()``
/// - ``children()``
/// - ``leafModules()``
- public func namedModules() -> [(String, Module)] {
+ open func namedModules() -> [(String, Module)] {
var result = [(String, Module)]()
visit {
result.append(($0, $1))
@@ -822,7 +839,8 @@ open class Module {
/// - ``unfreeze(recursive:keys:strict:)``
open func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
let visitor = freezeVisitor(keys: keys, strict: strict) {
- $0.noGrad.formUnion($1)
+ $0._noGrad.formUnion($1)
+ $0.didSetNoGrad($0._noGrad)
}

if recursive {
@@ -859,7 +877,8 @@
/// - ``Module/unfreeze(recursive:keys:strict:)``
open func unfreeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws {
let visitor = freezeVisitor(keys: keys, strict: strict) {
- $0.noGrad.subtract($1)
+ $0._noGrad.subtract($1)
+ $0.didSetNoGrad($0._noGrad)
}

if recursive {
@@ -869,6 +888,24 @@
}
}

+ /// Set of property names that are frozen. Manipulated via
+ /// ``freeze(recursive:keys:strict:)`` and
+ /// ``unfreeze(recursive:keys:strict:)``.
+ open func noGrad() -> Set<String> {
+     _noGrad
+ }
+
+ /// Called when ``noGrad()`` is updated.
+ ///
+ /// This is provided for subclasses to override.
+ ///
+ /// - Parameter noGrad: set of properties that are frozen
+ ///
+ /// ### See Also
+ /// - ``noGrad()``
+ open func didSetNoGrad(_ noGrad: Set<String>) {
+ }

/// Recursively set the model's training mode.
///
/// Training mode only applies to certain layers. For example
@@ -877,11 +914,21 @@
///
/// ### See Also
/// - ``training``
+ /// - ``didSetTrain(_:)``
public func train(_ mode: Bool = true) {
visit(modules: {
$1.training = mode
+ $1.didSetTrain(mode)
})
}

+ /// Called when ``train(_:)`` is updated.
+ ///
+ /// This is provided for subclasses to override.
+ ///
+ /// - Parameter mode: `true` is training
+ open func didSetTrain(_ mode: Bool) {
+ }
}

extension Module: IndentedDescription {
@@ -922,7 +969,7 @@ extension Module: Updatable, Evaluatable {
/// ### See Also
/// - <doc:layers>
/// - ``Sequential``
- public protocol UnaryLayer {
+ public protocol UnaryLayer: Module {
func callAsFunction(_ x: MLXArray) -> MLXArray
}

@@ -996,7 +1043,7 @@ extension Module {
(module: Module, key: String, item: ModuleItem) in
switch item {
case .array, .dictionary, .value(.parameters), .value(.module):
- !key.hasPrefix("_") && !module.noGrad.contains(key)
+ !key.hasPrefix("_") && !module.noGrad().contains(key)
default: false
}
}
24 changes: 19 additions & 5 deletions Source/MLXNN/Quantized.swift
@@ -11,9 +11,18 @@ public protocol Quantizable {
func toQuantized(groupSize: Int, bits: Int) -> Module
}

- public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Module? {
- if let quantizable = layer as? Quantizable {
- quantizable.toQuantized(groupSize: groupSize, bits: bits)
+ /// Protocol for layers that are quantized.
+ public protocol Quantized: Module {
+ var groupSize: Int { get }
+ var bits: Int { get }
+ }
+
+ public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? {
+ if layer is Quantized {
Collaborator Author:

This has been observed in the past in this same area -- do not quantize already-quantized layers.
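With this check, quantizing twice is a no-op for the already-quantized layer (usage sketch; the `Linear` initializer shown is assumed):

```swift
let layer = Linear(256, 256)
if let once = quantizeSingle(layer: layer) {   // a QuantizedLinear
    let twice = quantizeSingle(layer: once)    // nil -- already Quantized
}
```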

+ // already quantized
+ nil
+ } else if let quantizable = layer as? Quantizable {
+ quantizable.toQuantized(groupSize: groupSize, bits: bits) as? Quantized
} else {
nil
}
@@ -52,7 +61,7 @@ public func quantize(
}

/// The same as ``Embedding`` but with a quantized weight matrix.
- open class QuantizedEmbedding: Embedding {
+ open class QuantizedEmbedding: Embedding, Quantized {

public let groupSize: Int
public let bits: Int
@@ -121,14 +130,19 @@ open class QuantizedEmbedding: Embedding {
///
/// ### See Also
/// - ``init(weight:bias:groupSize:bits:)``
- open class QuantizedLinear: Linear {
+ open class QuantizedLinear: Linear, Quantized {

public let groupSize: Int
public let bits: Int

public let scales: MLXArray
public let biases: MLXArray

+ open override var shape: (Int, Int) {
+     let shape = weight.shape2
+     return (shape.0, shape.1 * 32 / bits)
+ }
Collaborator Author:

This is also a breaking change, but IMHO it was broken before (it returned the size of the packed, quantized arrays, which was not useful).
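For illustration (the dimensions and the convenience initializer below are assumed): at `bits = 4`, each `UInt32` of the packed weight holds `32 / 4 = 8` values, so the stored and logical shapes differ:

```swift
// Assumed example: Linear(512, 1024) stores a (1024, 512) weight, so the
// 4-bit quantized version packs it into a (1024, 64) array.
let q = QuantizedLinear(512, 1024, bits: 4)
print(q.weight.shape2)   // (1024, 64)  -- packed quantized storage
print(q.shape)           // (1024, 512) -- logical shape with this override
```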


/// Applies an affine transformation to the input using a quantized weight matrix.
///
/// This is the quantized version of ``Linear``. Typically this is used via ``quantize(model:groupSize:bits:predicate:)``.
8 changes: 4 additions & 4 deletions Source/MLXNN/Transformer.swift
@@ -12,10 +12,10 @@ open class MultiHeadAttention: Module {

public let numHeads: Int

@ModuleInfo(key: "query_proj") public var queryProjection: Linear
@ModuleInfo(key: "key_proj") public var keyProjection: Linear
@ModuleInfo(key: "value_proj") public var valueProjection: Linear
@ModuleInfo(key: "out_proj") public var outProjection: Linear
@ModuleInfo(key: "query_proj") public var queryProjection: UnaryLayer
@ModuleInfo(key: "key_proj") public var keyProjection: UnaryLayer
@ModuleInfo(key: "value_proj") public var valueProjection: UnaryLayer
@ModuleInfo(key: "out_proj") public var outProjection: UnaryLayer
Collaborator Author:

This is how we would change e.g. attention layers if we wanted to use composition for LoRA.
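Roughly, the kind of composition this enables (the adapter below is hypothetical and not part of this PR; the random/zeros initialization details are a sketch):

```swift
import MLX
import MLXNN
import MLXRandom

// Hypothetical LoRA adapter: wraps an existing Linear and adds a trainable
// low-rank update on top of its output. Because it conforms to UnaryLayer it
// can be assigned to the UnaryLayer-typed projection properties above.
class LoRAAdapter: Module, UnaryLayer {
    let base: Linear
    let loraA: MLXArray
    let loraB: MLXArray

    init(_ base: Linear, rank: Int = 8) {
        self.base = base
        let (outputDims, inputDims) = base.shape
        // standard LoRA init: small random A, zero B
        self.loraA = MLXRandom.normal([inputDims, rank]) * 0.01
        self.loraB = zeros([rank, outputDims])
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // frozen base output plus the trainable low-rank path
        base(x) + matmul(matmul(x, loraA), loraB)
    }
}
```

An adapter like this could then be swapped in for `queryProjection` etc. via `update(modules:)`, with the wrapped `Linear` kept frozen via `freeze(recursive:keys:strict:)`.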


/// Implements the scaled dot product attention with multiple heads.
///