Skip to content

Commit

Permalink
be more tolerant
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Jul 17, 2024
1 parent 17fe33e commit 8c0ab2a
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions tests/test_proc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_scale_linear(tid: MemberId):
op(sample)

expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c"))
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


def test_scale_linear_no_channel(tid: MemberId):
Expand All @@ -42,7 +42,7 @@ def test_scale_linear_no_channel(tid: MemberId):
op(sample)

expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y"))
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


T = TypeVar("T")
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_zero_mean_unit_variance(tid: MemberId):
),
dims=("x", "y"),
)
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


def test_zero_mean_unit_variance_fixed(tid: MemberId):
Expand All @@ -102,7 +102,7 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId):
)
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
op(sample)
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


def test_zero_mean_unit_across_axes(tid: MemberId):
Expand All @@ -123,7 +123,7 @@ def test_zero_mean_unit_across_axes(tid: MemberId):
[(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c"
)
op(sample)
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


def test_zero_mean_unit_variance_fixed2(tid: MemberId):
Expand All @@ -139,7 +139,7 @@ def test_zero_mean_unit_variance_fixed2(tid: MemberId):
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y"))
op(sample)
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


def test_binarize(tid: MemberId):
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: MemberId):
)

op(sample)
xr.testing.assert_allclose(expected, sample.members[tid].data)
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]):
)
sample.stat = compute_measures(op.required_measures, [sample])
op(sample)
xr.testing.assert_allclose(ref_data, sample.members[tid].data)
xr.testing.assert_allclose(ref_data, sample.members[tid].data, rtol=1e-6, atol=1e-7)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -290,11 +290,15 @@ def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str])

if axes is not None and AxisId("c") not in axes:
# mean,std per channel should match exactly
xr.testing.assert_allclose(ref_data, sample.members[tid].data)
xr.testing.assert_allclose(
ref_data, sample.members[tid].data, rtol=1e-6, atol=1e-7
)
else:
# mean,std across channels should not match
with pytest.raises(AssertionError):
xr.testing.assert_allclose(ref_data, sample.members[tid].data)
xr.testing.assert_allclose(
ref_data, sample.members[tid].data, rtol=1e-6, atol=1e-7
)


def test_scale_range(tid: MemberId):
Expand All @@ -313,7 +317,7 @@ def test_scale_range(tid: MemberId):

op(sample)
# NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct
np.testing.assert_allclose(expected, sample.members[tid].data)
np.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)


def test_scale_range_axes(tid: MemberId):
Expand Down Expand Up @@ -363,4 +367,4 @@ def test_sigmoid(tid: MemberId):
sigmoid(sample)

exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes)
xr.testing.assert_allclose(exp, sample.members[tid].data)
xr.testing.assert_allclose(exp, sample.members[tid].data, rtol=1e-6, atol=1e-7)

0 comments on commit 8c0ab2a

Please sign in to comment.