Skip to content

Commit

Permalink
Added tests for skutils
Browse files Browse the repository at this point in the history
  • Loading branch information
williamjameshandley committed Apr 30, 2024
1 parent 3db2b4e commit 2d522b3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stemu/emu.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def fit(self, X, t, y):
self.t = t

X = self.X_pipeline.fit_transform(X)
t = self.t_pipeline.fit_transform(t, y)
y = self.y_pipeline.fit_transform(y)
t = self.t_pipeline.fit_transform(t, y)

ty = self.ty_pipeline.fit_transform(np.block([[t], [y]]))
t, y = ty[0], ty[1:]
Expand Down Expand Up @@ -112,6 +112,6 @@ def predict(self, X, t=None):
y = self.model.predict(X)
_, _, y = unstack(X, y)
ty = self.ty_pipeline.inverse_transform(np.block([[t], [y]]))
t, y = ty[0], ty[1:]
_, y = ty[0], ty[1:]
y = self.y_pipeline.inverse_transform(y)
return y
41 changes: 41 additions & 0 deletions tests/test_skutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
from numpy.testing import assert_allclose

from stemu.skutils import CDFTransformer, FunctionScaler, IdentityTransformer


def test_CDFTransformer():
t = np.linspace(0, 1, 100)
y = t**3 - t + 1
y = y * np.random.rand(20, len(t))
cdf = CDFTransformer()
assert isinstance(cdf.fit(t, y), CDFTransformer)
t_ = cdf.transform(t)
assert not (t == t_).all()
t_ = cdf.inverse_transform(t_)
assert (t == t_).all()


def test_FunctionScaler():
t = np.linspace(0, 1, 100)
y = t**3 - t + 1
y = y * np.random.rand(20, len(t))
X = np.block([[t], [y]])

fs = FunctionScaler()
assert isinstance(fs.fit(X), FunctionScaler)
X_ = fs.transform(X)
assert_allclose(X_[1:].mean(axis=0), 0, atol=1e-15)
assert_allclose(X_[1:].std(axis=0), 1, atol=1e-15)
assert_allclose(fs.inverse_transform(X_), X)


def test_IdentityTransformer():
y = np.random.rand(10, 20)

identity = IdentityTransformer()
assert isinstance(identity.fit(y), IdentityTransformer)
t_ = identity.transform(t)
assert (t == t_).all()
t_ = identity.inverse_transform(t_)
assert (t == t_).all()

0 comments on commit 2d522b3

Please sign in to comment.