fix: regularizer serialization problem #1250

Merged · 2 commits · May 21, 2024
28 changes: 23 additions & 5 deletions src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
@@ -1,7 +1,25 @@
-namespace Tensorflow.Keras
+using Newtonsoft.Json;
+using System.Collections.Generic;
+using Tensorflow.Keras.Saving.Common;
+
+namespace Tensorflow.Keras
 {
-    public interface IRegularizer
-    {
-        Tensor Apply(RegularizerArgs args);
-    }
+    [JsonConverter(typeof(CustomizedRegularizerJsonConverter))]
+    public interface IRegularizer
+    {
+        [JsonProperty("class_name")]
+        string ClassName { get; }
+        [JsonProperty("config")]
+        IDictionary<string, object> Config { get; }
+        Tensor Apply(RegularizerArgs args);
+    }
+
+    public interface IRegularizerApi
+    {
+        IRegularizer GetRegularizerFromName(string name);
+        IRegularizer L1 { get; }
+        IRegularizer L2 { get; }
+        IRegularizer L1L2 { get; }
+    }
+
 }
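
Illustration only, not part of the diff: a minimal sketch of what the extended interface asks of an implementation. The class name `OrthogonalL2`, its `"factor"` config key, and the `Tensorflow.Keras.Sketches` namespace are invented here; note that the converter added in this PR only recognizes "L1", "L2" and "L1L2" when reading JSON back, so a custom class like this would serialize but not round-trip.

    // Hypothetical custom regularizer exposing the new ClassName/Config members.
    using System.Collections.Generic;
    using static Tensorflow.Binding;

    namespace Tensorflow.Keras.Sketches
    {
        public class OrthogonalL2 : IRegularizer
        {
            private readonly float _factor;

            public OrthogonalL2(float factor = 0.01f) => _factor = factor;

            // Metadata consumed by the JSON converter.
            public string ClassName => "OrthogonalL2";

            public IDictionary<string, object> Config
                => new Dictionary<string, object> { ["factor"] = _factor };

            // Penalty: factor * sum(x^2).
            public Tensor Apply(RegularizerArgs args)
                => _factor * tf.reduce_sum(args.X * args.X);
        }
    }
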
@@ -0,0 +1,57 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.Regularizers;

namespace Tensorflow.Keras.Saving.Common
{
    class RegularizerInfo
    {
        public string class_name { get; set; }
        public JObject config { get; set; }
    }

    public class CustomizedRegularizerJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(IRegularizer);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            var regularizer = value as IRegularizer;
            if (regularizer is null)
            {
                writer.WriteNull();
                return;
            }
            JToken.FromObject(new RegularizerInfo()
            {
                class_name = regularizer.ClassName,
                config = JObject.FromObject(regularizer.Config)
            }, serializer).WriteTo(writer);
        }

        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            var info = serializer.Deserialize<RegularizerInfo>(reader);
            if (info is null)
            {
                return null;
            }
            return info.class_name switch
            {
                "L1L2" => new L1L2(info.config["l1"].ToObject<float>(), info.config["l2"].ToObject<float>()),
                "L1" => new L1(info.config["l1"].ToObject<float>()),
                "L2" => new L2(info.config["l2"].ToObject<float>()),
            };
        }
    }
}
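
Illustration only, not part of the diff: a rough round trip through this converter. The assumption is that the `[JsonConverter]` attribute added to `IRegularizer` is picked up for both calls; the expected JSON shape follows from `RegularizerInfo` and the `JsonProperty` names above.

    // Hedged sketch of serialize/deserialize with the new converter in play.
    using Newtonsoft.Json;
    using Tensorflow.Keras;
    using Tensorflow.Operations.Regularizers;

    IRegularizer reg = new L1L2(l1: 0.01f, l2: 0.001f);

    string json = JsonConvert.SerializeObject(reg);
    // expected shape: {"class_name":"L1L2","config":{"l1":0.01,"l2":0.001}}

    IRegularizer restored = JsonConvert.DeserializeObject<IRegularizer>(json);
    // restored should come back as an L1L2 rebuilt from the "config" values
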
33 changes: 33 additions & 0 deletions src/TensorFlowNET.Core/Operations/Regularizers/L1.cs
@@ -0,0 +1,33 @@
using System;

using Tensorflow.Keras;

namespace Tensorflow.Operations.Regularizers
{
    public class L1 : IRegularizer
    {
        float _l1;
        private readonly Dictionary<string, object> _config;

        public string ClassName => "L1";
        public virtual IDictionary<string, object> Config => _config;

        public L1(float l1 = 0.01f)
        {
            // l1 = 0.01 if l1 is None else l1
            // validate_float_arg(l1, name = "l1")
            // self.l1 = ops.convert_to_tensor(l1)
            this._l1 = l1;

            _config = new();
            _config["l1"] = _l1;
        }

        public Tensor Apply(RegularizerArgs args)
        {
            // return self.l1 * ops.sum(ops.absolute(x))
            return _l1 * math_ops.reduce_sum(math_ops.abs(args.X));
        }
    }
}
48 changes: 48 additions & 0 deletions src/TensorFlowNET.Core/Operations/Regularizers/L1L2.cs
@@ -0,0 +1,48 @@
using System;

using Tensorflow.Keras;

namespace Tensorflow.Operations.Regularizers
{
    public class L1L2 : IRegularizer
    {
        float _l1;
        float _l2;
        private readonly Dictionary<string, object> _config;

        public string ClassName => "L1L2";
        public virtual IDictionary<string, object> Config => _config;

        public L1L2(float l1 = 0.0f, float l2 = 0.0f)
        {
            // l1 = 0.0 if l1 is None else l1
            // l2 = 0.0 if l2 is None else l2
            // validate_float_arg(l1, name = "l1")
            // validate_float_arg(l2, name = "l2")
            // self.l1 = l1
            // self.l2 = l2
            this._l1 = l1;
            this._l2 = l2;

            _config = new();
            _config["l1"] = l1;
            _config["l2"] = l2;
        }

        public Tensor Apply(RegularizerArgs args)
        {
            // regularization = ops.convert_to_tensor(0.0, dtype = x.dtype)
            // if self.l1:
            //     regularization += self.l1 * ops.sum(ops.absolute(x))
            // if self.l2:
            //     regularization += self.l2 * ops.sum(ops.square(x))
            // return regularization
            Tensor regularization = tf.constant(0.0, args.X.dtype);
            regularization += _l1 * math_ops.reduce_sum(math_ops.abs(args.X));
            regularization += _l2 * math_ops.reduce_sum(math_ops.square(args.X));
            return regularization;
        }
    }
}
33 changes: 33 additions & 0 deletions src/TensorFlowNET.Core/Operations/Regularizers/L2.cs
@@ -0,0 +1,33 @@
using System;

using Tensorflow.Keras;

namespace Tensorflow.Operations.Regularizers
{
    public class L2 : IRegularizer
    {
        float _l2;
        private readonly Dictionary<string, object> _config;

        public string ClassName => "L2";
        public virtual IDictionary<string, object> Config => _config;

        public L2(float l2 = 0.01f)
        {
            // l2 = 0.01 if l2 is None else l2
            // validate_float_arg(l2, name = "l2")
            // self.l2 = l2
            this._l2 = l2;

            _config = new();
            _config["l2"] = _l2;
        }

        public Tensor Apply(RegularizerArgs args)
        {
            // return self.l2 * ops.sum(ops.square(x))
            return _l2 * math_ops.reduce_sum(math_ops.square(args.X));
        }
    }
}
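
Illustration only, not part of the diff: the penalties the three classes above compute are the usual Keras definitions — L1 adds l1 * sum(|x|), L2 adds l2 * sum(x^2), and L1L2 adds both terms. Spelled out with the high-level TF.NET bindings (the 0.01 coefficients are just example values):

    // Hedged sketch of the same penalties via the tf entry point.
    using static Tensorflow.Binding;

    var x = tf.constant(new float[] { 1f, -2f, 3f });

    var l1Penalty = 0.01f * tf.reduce_sum(tf.abs(x));     // L1: l1 * sum(|x|)
    var l2Penalty = 0.01f * tf.reduce_sum(x * x);         // L2: l2 * sum(x^2)
    var l1l2Penalty = l1Penalty + l2Penalty;              // L1L2: both terms
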
51 changes: 47 additions & 4 deletions src/TensorFlowNET.Keras/Regularizers.cs
@@ -1,8 +1,51 @@
-namespace Tensorflow.Keras
+using Tensorflow.Operations.Regularizers;
+
+namespace Tensorflow.Keras
 {
-    public class Regularizers
+    public class Regularizers : IRegularizerApi
     {
+        private static Dictionary<string, IRegularizer> _nameActivationMap;
+
+        public IRegularizer l1(float l1 = 0.01f)
+            => new L1(l1);
-        public IRegularizer l2(float l2 = 0.01f)
-            => new L2(l2);
+        public IRegularizer l2(float l2 = 0.01f)
+            => new L2(l2);
+
+        // From TF source:
+        // # The default value for l1 and l2 are different from the value in l1_l2
+        // # for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
+        // # and no l1 penalty.
+        public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
+            => new L1L2(l1, l2);
+
+        static Regularizers()
+        {
+            _nameActivationMap = new Dictionary<string, IRegularizer>();
+            _nameActivationMap["L1"] = new L1();
+            _nameActivationMap["L2"] = new L2();
+            _nameActivationMap["L1L2"] = new L1L2();
+        }
+
+        public IRegularizer L1 => l1();
+
+        public IRegularizer L2 => l2();
+
+        public IRegularizer L1L2 => l1l2();
+
+        public IRegularizer GetRegularizerFromName(string name)
+        {
+            if (name == null)
+            {
+                throw new Exception("Regularizer name cannot be null");
+            }
+            if (!_nameActivationMap.TryGetValue(name, out var res))
+            {
+                throw new Exception($"Regularizer {name} not found");
+            }
+            else
+            {
+                return res;
+            }
+        }
     }
 }
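
Illustration only, not part of the diff: with `IRegularizerApi` implemented, the singleton at `keras.regularizers` can hand out regularizers by property, by name, or with explicit coefficients — the lookup-by-name path is what deserialization and the test below exercise. A usage sketch via the standard `KerasApi` entry point:

    // Hedged sketch of the three access patterns on the new API.
    using static Tensorflow.KerasApi;

    var byProperty = keras.regularizers.L2;                          // default l2 = 0.01
    var byName = keras.regularizers.GetRegularizerFromName("L1L2");
    var custom = keras.regularizers.l1l2(l1: 0.01f, l2: 0.01f);      // both penalties, explicitly
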
19 changes: 0 additions & 19 deletions src/TensorFlowNET.Keras/Regularizers/L1.cs

This file was deleted.

24 changes: 0 additions & 24 deletions src/TensorFlowNET.Keras/Regularizers/L1L2.cs

This file was deleted.

17 changes: 0 additions & 17 deletions src/TensorFlowNET.Keras/Regularizers/L2.cs

This file was deleted.

48 changes: 48 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;
using Tensorflow.Keras.Engine;
@@ -129,6 +130,53 @@ public void TestModelBeforeTF2_5()
}


[TestMethod]
public void BiasRegularizerSaveAndLoad()
{
    var savemodel = keras.Sequential(new List<ILayer>()
    {
        tf.keras.layers.InputLayer((227, 227, 3)),
        tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation: "relu", padding: "valid"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((3, 3), strides: (2, 2)),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer: keras.regularizers.L1L2),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer: keras.regularizers.L2),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer: keras.regularizers.L1),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

        tf.keras.layers.Flatten(),

        tf.keras.layers.Dense(1000, activation: "linear"),
        tf.keras.layers.Softmax(1)
    });

    savemodel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

    var num_epochs = 1;
    var batch_size = 8;

    var trainDataset = new RandomDataSet(new Shape(227, 227, 3), 16);

    savemodel.fit(trainDataset.Data, trainDataset.Labels, batch_size, num_epochs);

    savemodel.save(@"./bias_regularizer_save_and_load", save_format: "tf");

    var loadModel = tf.keras.models.load_model(@"./bias_regularizer_save_and_load");
    loadModel.summary();

    loadModel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

    var fitDataset = new RandomDataSet(new Shape(227, 227, 3), 16);

    loadModel.fit(fitDataset.Data, fitDataset.Labels, batch_size, num_epochs);
}


[TestMethod]
public void CreateConcatenateModelSaveAndLoad()