-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
address issues that prevent using composition for layers like LoRA #177
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,12 +98,13 @@ open class Module { | |
|
||
/// Flag to indicate whether the module is being trained. Manipulated via | ||
/// ``train(_:)``. | ||
/// | ||
/// ### See Also | ||
/// - ``didSetTrain(_:)`` | ||
public private(set) var training = true | ||
|
||
/// Set of property names that are frozen. Manipulated via | ||
/// ``freeze(recursive:keys:strict:)`` and | ||
/// ``unfreeze(recursive:keys:strict:)``. | ||
public private(set) var noGrad = Set<String>() | ||
/// See ``noGrad()`` | ||
private var _noGrad = Set<String>() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a breaking change but unlikely to be used directly (there are methods to manipulate). Subclasses can't override stored properties (they can do so for computed properties). This is replaced with methods that can be overridden. |
||
|
||
private var _items: ModuleItems! | ||
private var _setters: [String: TypeErasedSetter]! | ||
|
@@ -139,7 +140,7 @@ open class Module { | |
/// and ``update(parameters:)`` for example. | ||
/// | ||
/// Subclasses could potentially override this to provide custom introspection. | ||
public func items() -> ModuleItems { | ||
open func items() -> ModuleItems { | ||
_items | ||
} | ||
|
||
|
@@ -222,7 +223,7 @@ open class Module { | |
/// - ``mapParameters(map:isLeaf:)`` | ||
/// - ``modules()`` | ||
/// - ``items()`` | ||
public func filterMap<Result>( | ||
open func filterMap<Result>( | ||
filter: (Module, String, ModuleItem) -> Bool, | ||
map: (ModuleItem) -> Result? = { $0 }, | ||
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault | ||
|
@@ -331,7 +332,7 @@ open class Module { | |
/// ### See Also | ||
/// - <doc:module-filters> | ||
/// - ``mapParameters(map:)`` | ||
public func mapParameters<Result>( | ||
open func mapParameters<Result>( | ||
map: @escaping (MLXArray) -> Result? = { $0 }, | ||
isLeaf: (Module, String, ModuleItem) -> Bool = Module.isLeafDefault | ||
) -> NestedDictionary<String, Result> { | ||
|
@@ -343,28 +344,28 @@ open class Module { | |
|
||
/// Return a `NestedDictionary<String, MLXArray>` for all parameters in the | ||
/// model (all layers). | ||
public func parameters() -> ModuleParameters { | ||
open func parameters() -> ModuleParameters { | ||
filterMap(filter: Self.filterValidParameters, map: Self.mapParameters()) | ||
} | ||
|
||
/// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the | ||
/// model (all layers). | ||
/// | ||
/// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters. | ||
public func trainableParameters() -> ModuleParameters { | ||
open func trainableParameters() -> ModuleParameters { | ||
filterMap(filter: Self.filterTrainableParameters, map: Self.mapParameters()) | ||
} | ||
|
||
/// Produces a `NestedDictionary<String, Module>` for all direct children of the module. | ||
public func children() -> ModuleChildren { | ||
open func children() -> ModuleChildren { | ||
filterMap(filter: Self.filterValidChild, map: Self.mapModule(), isLeaf: Self.isLeafModule) | ||
} | ||
|
||
/// Produces a `NestedDictionary<String, Module>` for all leaf modules. | ||
/// | ||
/// ### See Also | ||
/// - ``isLeafModuleNoChildren`` | ||
public func leafModules() -> ModuleChildren { | ||
open func leafModules() -> ModuleChildren { | ||
filterMap( | ||
filter: Self.filterValidChild, map: Self.mapModule(), | ||
isLeaf: Self.isLeafModuleNoChildren) | ||
|
@@ -710,7 +711,23 @@ open class Module { | |
return self | ||
} | ||
|
||
private func updateModule(key: String, _ value: Any) throws { | ||
/// Set a module to a new value. | ||
/// | ||
/// The module property must be wrapped in a ``ModuleInfo``: | ||
/// | ||
/// ```swift | ||
/// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm | ||
/// ``` | ||
/// | ||
/// and the value must be a compatible type. | ||
/// | ||
/// This method is called via ``update(modules:)`` and is not typically called directly. This | ||
/// is exposed as an overridable method for subclasses. | ||
/// | ||
/// - Parameters: | ||
/// - key: module key, see ``ModuleInfo`` | ||
/// - value: the replacement module | ||
open func updateModule(key: String, _ value: Any) throws { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Primarily exposed for subclasses to override (see test case) |
||
if let setter = _setters[key] { | ||
do { | ||
try setter.updateModule(value) | ||
|
@@ -727,7 +744,7 @@ open class Module { | |
} | ||
|
||
// `apply_to_modules()` | ||
public func visit(modules visitor: (String, Module) throws -> Void) rethrows { | ||
open func visit(modules visitor: (String, Module) throws -> Void) rethrows { | ||
var stack = [(String, Module)]() | ||
stack.append(("", self)) | ||
|
||
|
@@ -746,7 +763,7 @@ open class Module { | |
/// - ``namedModules()`` | ||
/// - ``children()`` | ||
/// - ``leafModules()`` | ||
public func modules() -> [Module] { | ||
open func modules() -> [Module] { | ||
var result = [Module]() | ||
visit { | ||
result.append($1) | ||
|
@@ -760,7 +777,7 @@ open class Module { | |
/// - ``modules()`` | ||
/// - ``children()`` | ||
/// - ``leafModules()`` | ||
public func namedModules() -> [(String, Module)] { | ||
open func namedModules() -> [(String, Module)] { | ||
var result = [(String, Module)]() | ||
visit { | ||
result.append(($0, $1)) | ||
|
@@ -822,7 +839,8 @@ open class Module { | |
/// - ``unfreeze(recursive:keys:strict:)`` | ||
open func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws { | ||
let visitor = freezeVisitor(keys: keys, strict: strict) { | ||
$0.noGrad.formUnion($1) | ||
$0._noGrad.formUnion($1) | ||
$0.didSetNoGrad($0._noGrad) | ||
} | ||
|
||
if recursive { | ||
|
@@ -859,7 +877,8 @@ open class Module { | |
/// - ``Module/unfreeze(recursive:keys:strict:)`` | ||
open func unfreeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) throws { | ||
let visitor = freezeVisitor(keys: keys, strict: strict) { | ||
$0.noGrad.subtract($1) | ||
$0._noGrad.subtract($1) | ||
$0.didSetNoGrad($0._noGrad) | ||
} | ||
|
||
if recursive { | ||
|
@@ -869,6 +888,24 @@ open class Module { | |
} | ||
} | ||
|
||
/// Set of property names that are frozen. Manipulated via | ||
/// ``freeze(recursive:keys:strict:)`` and | ||
/// ``unfreeze(recursive:keys:strict:)``. | ||
open func noGrad() -> Set<String> { | ||
_noGrad | ||
} | ||
|
||
/// Called when ``noGrad()`` is updated. | ||
/// | ||
/// This is provided for subclasses to override. | ||
/// | ||
/// - Parameter noGrad: set of properties that are frozen | ||
/// | ||
/// ### See Also | ||
/// - ``noGrad()`` | ||
open func didSetNoGrad(_ noGrad: Set<String>) { | ||
} | ||
|
||
/// Recursively set the model's training mode. | ||
/// | ||
/// Training mode only applies to certain layers. For example | ||
|
@@ -877,11 +914,21 @@ open class Module { | |
/// | ||
/// ### See Also | ||
/// - ``training`` | ||
/// - ``didSetTrain(_:)`` | ||
public func train(_ mode: Bool = true) { | ||
visit(modules: { | ||
$1.training = mode | ||
$1.didSetTrain(mode) | ||
}) | ||
} | ||
|
||
/// Called when ``train(_:)`` is updated. | ||
/// | ||
/// This is provided for subclasses to override. | ||
/// | ||
/// - Parameter mode: `true` is training | ||
open func didSetTrain(_ mode: Bool) { | ||
} | ||
} | ||
|
||
extension Module: IndentedDescription { | ||
|
@@ -922,7 +969,7 @@ extension Module: Updatable, Evaluatable { | |
/// ### See Also | ||
/// - <doc:layers> | ||
/// - ``Sequential`` | ||
public protocol UnaryLayer { | ||
public protocol UnaryLayer: Module { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per observation in ml-explore/mlx-swift-examples#167 |
||
func callAsFunction(_ x: MLXArray) -> MLXArray | ||
} | ||
|
||
|
@@ -996,7 +1043,7 @@ extension Module { | |
(module: Module, key: String, item: ModuleItem) in | ||
switch item { | ||
case .array, .dictionary, .value(.parameters), .value(.module): | ||
!key.hasPrefix("_") && !module.noGrad.contains(key) | ||
!key.hasPrefix("_") && !module.noGrad().contains(key) | ||
default: false | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,9 +11,18 @@ public protocol Quantizable { | |
func toQuantized(groupSize: Int, bits: Int) -> Module | ||
} | ||
|
||
public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Module? { | ||
if let quantizable = layer as? Quantizable { | ||
quantizable.toQuantized(groupSize: groupSize, bits: bits) | ||
/// Protocol for layers that are quantized. | ||
public protocol Quantized: Module { | ||
var groupSize: Int { get } | ||
var bits: Int { get } | ||
} | ||
|
||
public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? { | ||
if layer is Quantized { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Observed in the past and in the same area -- do not quantize already quantized layers. |
||
// already quantized | ||
nil | ||
} else if let quantizable = layer as? Quantizable { | ||
quantizable.toQuantized(groupSize: groupSize, bits: bits) as? Quantized | ||
} else { | ||
nil | ||
} | ||
|
@@ -52,7 +61,7 @@ public func quantize( | |
} | ||
|
||
/// The same as ``Embedding`` but with a quantized weight matrix. | ||
open class QuantizedEmbedding: Embedding { | ||
open class QuantizedEmbedding: Embedding, Quantized { | ||
|
||
public let groupSize: Int | ||
public let bits: Int | ||
|
@@ -121,14 +130,19 @@ open class QuantizedEmbedding: Embedding { | |
/// | ||
/// ### See Also | ||
/// - ``init(weight:bias:groupSize:bits:)`` | ||
open class QuantizedLinear: Linear { | ||
open class QuantizedLinear: Linear, Quantized { | ||
|
||
public let groupSize: Int | ||
public let bits: Int | ||
|
||
public let scales: MLXArray | ||
public let biases: MLXArray | ||
|
||
open override var shape: (Int, Int) { | ||
let shape = weight.shape2 | ||
return (shape.0, shape.1 * 32 / bits) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is also a breaking change but IMHO it was broken before (returning the size of the quantized arrays which was useless) |
||
|
||
/// Applies an affine transformation to the input using a quantized weight matrix. | ||
/// | ||
/// This is the quantized version of ``Linear``. Typically this is used via ``quantize(model:groupSize:bits:predicate:)``. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,10 +12,10 @@ open class MultiHeadAttention: Module { | |
|
||
public let numHeads: Int | ||
|
||
@ModuleInfo(key: "query_proj") public var queryProjection: Linear | ||
@ModuleInfo(key: "key_proj") public var keyProjection: Linear | ||
@ModuleInfo(key: "value_proj") public var valueProjection: Linear | ||
@ModuleInfo(key: "out_proj") public var outProjection: Linear | ||
@ModuleInfo(key: "query_proj") public var queryProjection: UnaryLayer | ||
@ModuleInfo(key: "key_proj") public var keyProjection: UnaryLayer | ||
@ModuleInfo(key: "value_proj") public var valueProjection: UnaryLayer | ||
@ModuleInfo(key: "out_proj") public var outProjection: UnaryLayer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is how we would change e.g. attention layers if we wanted to use composition for LoRA. |
||
|
||
/// Implements the scaled dot product attention with multiple heads. | ||
/// | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lot of the changes here are `public` -> `open`, to allow subclasses to override.