From 2d522b3937b7f73d2dbb29c4657f556b8bf5e707 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Tue, 30 Apr 2024 11:19:36 +0100 Subject: [PATCH] Added tests for skutils --- stemu/emu.py | 4 ++-- tests/test_skutils.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 tests/test_skutils.py diff --git a/stemu/emu.py b/stemu/emu.py index df798e9..da90877 100644 --- a/stemu/emu.py +++ b/stemu/emu.py @@ -68,8 +68,8 @@ def fit(self, X, t, y): self.t = t X = self.X_pipeline.fit_transform(X) - t = self.t_pipeline.fit_transform(t, y) y = self.y_pipeline.fit_transform(y) + t = self.t_pipeline.fit_transform(t, y) ty = self.ty_pipeline.fit_transform(np.block([[t], [y]])) t, y = ty[0], ty[1:] @@ -112,6 +112,6 @@ def predict(self, X, t=None): y = self.model.predict(X) _, _, y = unstack(X, y) ty = self.ty_pipeline.inverse_transform(np.block([[t], [y]])) - t, y = ty[0], ty[1:] + _, y = ty[0], ty[1:] y = self.y_pipeline.inverse_transform(y) return y diff --git a/tests/test_skutils.py b/tests/test_skutils.py new file mode 100644 index 0000000..3574524 --- /dev/null +++ b/tests/test_skutils.py @@ -0,0 +1,41 @@ +import numpy as np +from numpy.testing import assert_allclose + +from stemu.skutils import CDFTransformer, FunctionScaler, IdentityTransformer + + +def test_CDFTransformer(): + t = np.linspace(0, 1, 100) + y = t**3 - t + 1 + y = y * np.random.rand(20, len(t)) + cdf = CDFTransformer() + assert isinstance(cdf.fit(t, y), CDFTransformer) + t_ = cdf.transform(t) + assert not (t == t_).all() + t_ = cdf.inverse_transform(t_) + assert (t == t_).all() + + +def test_FunctionScaler(): + t = np.linspace(0, 1, 100) + y = t**3 - t + 1 + y = y * np.random.rand(20, len(t)) + X = np.block([[t], [y]]) + + fs = FunctionScaler() + assert isinstance(fs.fit(X), FunctionScaler) + X_ = fs.transform(X) + assert_allclose(X_[1:].mean(axis=0), 0, atol=1e-15) + assert_allclose(X_[1:].std(axis=0), 1, atol=1e-15) + assert_allclose(fs.inverse_transform(X_), X) + + +def test_IdentityTransformer(): + y = np.random.rand(10, 20) + + identity = IdentityTransformer() + assert isinstance(identity.fit(y), IdentityTransformer) + t_ = identity.transform(t) + assert (t == t_).all() + t_ = identity.inverse_transform(t_) + assert (t == t_).all()