Skip to content

Commit

Permalink
Modify Tensorflow regularizers (#1864)
Browse files Browse the repository at this point in the history
  • Loading branch information
vl-dud authored Nov 6, 2024
1 parent 40cd7e5 commit 4c8b2d2
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions deepxde/nn/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,32 @@


def get(identifier):
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
Args:
identifier (list/tuple): Specifies the type and factor(s) of the regularizer.
The first element should be one of "l1", "l2", or "l1l2" ("l1+l2").
For "l1" and "l2", a single regularization factor is expected.
For "l1l2", provide both "l1" and "l2" factors.
"""

# TODO: other backends
if identifier is None:
if identifier is None or not identifier:
return None
name, scales = identifier[0].lower(), identifier[1:]
return (
tf.keras.regularizers.L1(l1=scales[0])
if name == "l1"
else tf.keras.regularizers.L2(l2=scales[0])
if name == "l2"
else tf.keras.regularizers.L1L2(l1=scales[0], l2=scales[1])
if name in ("l1+l2", "l1l2")
else None
)
if not isinstance(identifier, (list, tuple)):
raise ValueError("Identifier must be a list or a tuple.")

name = identifier[0].lower()
factor = identifier[1:]
if not factor:
raise ValueError("Regularization factor must be provided.")

if name == "l1":
return tf.keras.regularizers.L1(l1=factor[0])
if name == "l2":
return tf.keras.regularizers.L2(l2=factor[0])
if name in ("l1l2", "l1+l2"):
if len(factor) < 2:
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
raise ValueError(f"Unknown regularizer name: {name}")

0 comments on commit 4c8b2d2

Please sign in to comment.