@@ -98,12 +98,13 @@ open class Module {
9898
9999 /// Flag to indicate whether the module is being trained. Manipulated via
100100 /// ``train(_:)``.
101+ ///
102+ /// ### See Also
103+ /// - ``didSetTrain(_:)``
101104 public private( set) var training = true
102105
103- /// Set of property names that are frozen. Maniupulated via
104- /// ``freeze(recursive:keys:strict:)`` and
105- /// ``unfreeze(recursive:keys:strict:)``.
106- public private( set) var noGrad = Set < String > ( )
106+ /// See ``noGrad()``
107+ private var _noGrad = Set < String > ( )
107108
108109 private var _items : ModuleItems !
109110 private var _setters : [ String : TypeErasedSetter ] !
@@ -139,7 +140,7 @@ open class Module {
139140 /// and ``update(parameters:)`` for example.
140141 ///
141142 /// Subclasses could potentially override this to provide custom introspection.
142- public func items( ) -> ModuleItems {
143+ open func items( ) -> ModuleItems {
143144 _items
144145 }
145146
@@ -222,7 +223,7 @@ open class Module {
222223 /// - ``mapParameters(map:isLeaf:)``
223224 /// - ``modules()``
224225 /// - ``items()``
225- public func filterMap< Result> (
226+ open func filterMap< Result> (
226227 filter: ( Module , String , ModuleItem ) -> Bool ,
227228 map: ( ModuleItem ) -> Result ? = { $0 } ,
228229 isLeaf: ( Module , String , ModuleItem ) -> Bool = Module . isLeafDefault
@@ -331,7 +332,7 @@ open class Module {
331332 /// ### See Also
332333 /// - <doc:module-filters>
333334 /// - ``mapParameters(map:)``
334- public func mapParameters< Result> (
335+ open func mapParameters< Result> (
335336 map: @escaping ( MLXArray ) -> Result ? = { $0 } ,
336337 isLeaf: ( Module , String , ModuleItem ) -> Bool = Module . isLeafDefault
337338 ) -> NestedDictionary < String , Result > {
@@ -343,28 +344,28 @@ open class Module {
343344
344345 /// Return a `NestedDictionary<String, MLXArray>` for all parameters in the
345346 /// model (all layers).
346- public func parameters( ) -> ModuleParameters {
347+ open func parameters( ) -> ModuleParameters {
347348 filterMap ( filter: Self . filterValidParameters, map: Self . mapParameters ( ) )
348349 }
349350
350351 /// Return a `NestedDictionary<String, MLXArray>` for all trainable parameters in the
351352 /// model (all layers).
352353 ///
353354 /// This omits ``freeze(recursive:keys:strict:)`` (frozen) parameters.
354- public func trainableParameters( ) -> ModuleParameters {
355+ open func trainableParameters( ) -> ModuleParameters {
355356 filterMap ( filter: Self . filterTrainableParameters, map: Self . mapParameters ( ) )
356357 }
357358
358359 /// Produces a `NestedDictionary<String, Module>` for all direct children of the module.
359- public func children( ) -> ModuleChildren {
360+ open func children( ) -> ModuleChildren {
360361 filterMap ( filter: Self . filterValidChild, map: Self . mapModule ( ) , isLeaf: Self . isLeafModule)
361362 }
362363
363364 /// Produces a `NestedDictionary<String, Module>` for all leaf modules module.
364365 ///
365366 /// ### See Also
366367 /// - ``isLeafModuleNoChildren``
367- public func leafModules( ) -> ModuleChildren {
368+ open func leafModules( ) -> ModuleChildren {
368369 filterMap (
369370 filter: Self . filterValidChild, map: Self . mapModule ( ) ,
370371 isLeaf: Self . isLeafModuleNoChildren)
@@ -714,7 +715,23 @@ open class Module {
714715 return self
715716 }
716717
717- private func updateModule( key: String , _ value: Any ) throws {
718+ /// Set a module to a new value.
719+ ///
720+ /// The module property must be wrapped in a ``ModuleInfo``:
721+ ///
722+ /// ```swift
723+ /// @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
724+ /// ```
725+ ///
726+ /// and the value must be a compatible type.
727+ ///
728+ /// This method is called via ``update(modules:)`` and is not typically called directly. This
729+ /// is exposed as an overridable method for subclasses.
730+ ///
731+ /// - Parameters:
732+ /// - key: module key, see ``ModuleInfo``
733+ /// - value: the replacement module
734+ open func updateModule( key: String , _ value: Any ) throws {
718735 if let setter = _setters [ key] {
719736 do {
720737 try setter. updateModule ( value)
@@ -731,7 +748,7 @@ open class Module {
731748 }
732749
733750 // `apply_to_modules()`
734- public func visit( modules visitor: ( String , Module ) throws -> Void ) rethrows {
751+ open func visit( modules visitor: ( String , Module ) throws -> Void ) rethrows {
735752 var stack = [ ( String, Module) ] ( )
736753 stack. append ( ( " " , self ) )
737754
@@ -750,7 +767,7 @@ open class Module {
750767 /// - ``namedModules()``
751768 /// - ``children()``
752769 /// - ``leafModules()``
753- public func modules( ) -> [ Module ] {
770+ open func modules( ) -> [ Module ] {
754771 var result = [ Module] ( )
755772 visit {
756773 result. append ( $1)
@@ -764,7 +781,7 @@ open class Module {
764781 /// - ``modules()``
765782 /// - ``children()``
766783 /// - ``leafModules()``
767- public func namedModules( ) -> [ ( String , Module ) ] {
784+ open func namedModules( ) -> [ ( String , Module ) ] {
768785 var result = [ ( String, Module) ] ( )
769786 visit {
770787 result. append ( ( $0, $1) )
@@ -826,7 +843,8 @@ open class Module {
826843 /// - ``unfreeze(recursive:keys:strict:)``
827844 open func freeze( recursive: Bool = true , keys: [ String ] ? = nil , strict: Bool = false ) throws {
828845 let visitor = freezeVisitor ( keys: keys, strict: strict) {
829- $0. noGrad. formUnion ( $1)
846+ $0. _noGrad. formUnion ( $1)
847+ $0. didSetNoGrad ( $0. _noGrad)
830848 }
831849
832850 if recursive {
@@ -863,7 +881,8 @@ open class Module {
863881 /// - ``Module/unfreeze(recursive:keys:strict:)``
864882 open func unfreeze( recursive: Bool = true , keys: [ String ] ? = nil , strict: Bool = false ) throws {
865883 let visitor = freezeVisitor ( keys: keys, strict: strict) {
866- $0. noGrad. subtract ( $1)
884+ $0. _noGrad. subtract ( $1)
885+ $0. didSetNoGrad ( $0. _noGrad)
867886 }
868887
869888 if recursive {
@@ -873,6 +892,24 @@ open class Module {
873892 }
874893 }
875894
895+ /// Set of property names that are frozen. Manipulated via
896+ /// ``freeze(recursive:keys:strict:)`` and
897+ /// ``unfreeze(recursive:keys:strict:)``.
898+ open func noGrad( ) -> Set < String > {
899+ _noGrad
900+ }
901+
902+ /// Called after freezing or unfreezing changes the set returned by ``noGrad()``.
903+ ///
904+ /// This is provided for subclasses to override.
905+ ///
906+ /// - Parameter noGrad: set of properties that are frozen
907+ ///
908+ /// ### See Also
909+ /// - ``noGrad()``
910+ open func didSetNoGrad( _ noGrad: Set < String > ) {
911+ }
912+
876913 /// Recursively set the model's training mode.
877914 ///
878915 /// Training mode only applies to certain layers. For example
@@ -881,11 +918,21 @@ open class Module {
881918 ///
882919 /// ### See Also
883920 /// - ``training``
921+ /// - ``didSetTrain(_:)``
884922 public func train( _ mode: Bool = true ) {
885923 visit ( modules: {
886924 $1. training = mode
925+ $1. didSetTrain ( mode)
887926 } )
888927 }
928+
929+ /// Called after ``train(_:)`` sets ``training`` on this module.
930+ ///
931+ /// This is provided for subclasses to override.
932+ ///
933+ /// - Parameter mode: the new training mode, `true` when training
934+ open func didSetTrain( _ mode: Bool ) {
935+ }
889936}
890937
891938extension Module : IndentedDescription {
@@ -926,7 +973,7 @@ extension Module: Updatable, Evaluatable {
926973/// ### See Also
927974/// - <doc:layers>
928975/// - ``Sequential``
929- public protocol UnaryLayer {
976+ public protocol UnaryLayer : Module {
930977 func callAsFunction( _ x: MLXArray ) -> MLXArray
931978}
932979
@@ -1008,7 +1055,7 @@ extension Module {
10081055 ( module: Module , key: String , item: ModuleItem ) in
10091056 switch item {
10101057 case . array, . dictionary, . value( . parameters) , . value( . module) :
1011- parameterIsValid ( key) && !module. noGrad. contains ( key)
1058+ parameterIsValid ( key) && !module. noGrad ( ) . contains ( key)
10121059 default : false
10131060 }
10141061 }
0 commit comments