Skip to content

Commit

Permalink
Merge pull request #17 from fidelity/test_fix_ci
Browse files Browse the repository at this point in the history
Fix SVD Instability
  • Loading branch information
kuppulur authored Dec 12, 2024
2 parents 7593961 + 51a2f90 commit 35ca015
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 10 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
TextWiser CHANGELOG
=====================

-------------------------------------------------------------------------------
Dec 9, 2024 2.0.2
-------------------------------------------------------------------------------
minor:
- Updated SVD tests to fix the directional instability of eigen vectors

-------------------------------------------------------------------------------
Sep 9, 2024 2.0.1
-------------------------------------------------------------------------------
Expand Down
18 changes: 12 additions & 6 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,24 @@ def _reset_seed(self, seed=1234):
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def _test_fit_transform(self, tw_model, expected, atol=1e-6):
def _test_fit_transform(self, tw_model, expected, svd=False, atol=1e-6):
predicted = tw_model.fit_transform(docs)
# torch.set_printoptions(precision=10)
# print(predicted)
self.assertTrue(torch.allclose(predicted, expected.to(device), atol=atol))
if svd:
self.assertTrue(torch.allclose(np.abs(predicted), np.abs(expected.to(device)), atol=atol))
else:
self.assertTrue(torch.allclose(predicted, expected.to(device), atol=atol))

def _test_fit_before_transform(self, tw_model, expected, atol=1e-6):
def _test_fit_before_transform(self, tw_model, expected, svd=False, atol=1e-6):
tw_model.fit(docs)
# torch.set_printoptions(precision=10)
# print(tw_model.transform(docs))
self.assertTrue(torch.allclose(tw_model.transform(docs), expected.to(device), atol=atol))
self.assertTrue(torch.allclose(tw_model(docs), expected.to(device), atol=atol))
if svd:
self.assertTrue(torch.allclose(np.abs(tw_model.transform(docs)), np.abs(expected.to(device)), atol=atol))
self.assertTrue(torch.allclose(np.abs(tw_model(docs)), np.abs(expected.to(device)), atol=atol))
else:
self.assertTrue(torch.allclose(tw_model.transform(docs), expected.to(device), atol=atol))
self.assertTrue(torch.allclose(tw_model(docs), expected.to(device), atol=atol))

def _get_test_path(self, *names):
cwd = os.getcwd()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ def test_fit_transform(self):
expected = torch.tensor([[0.8526761532, -0.5070778131],
[0.9837458134, -0.0636523664],
[0.7350711226, 0.6733918786]], dtype=torch.float32)
self._test_fit_transform(tw, expected)
self._test_fit_transform(tw, expected, svd=True)
self._reset_seed()
self._test_fit_before_transform(tw, expected)
self._test_fit_before_transform(tw, expected, svd=True)

def test_min_components(self):
with self.assertRaises(ValueError):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_transformer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
import warnings

import numpy as np

from textwiser import TextWiser, Embedding, Transformation, WordOptions, device
from tests.test_base import BaseTest, docs

Expand Down Expand Up @@ -29,4 +31,5 @@ def test_list_handling(self):
[0.0000000000, 0.0000000000]], dtype=torch.float32)
]
for p, e in zip(predicted, expected):
self.assertTrue(torch.allclose(p, e.to(device), atol=1e-6))
# np.abs due to SVD instability
self.assertTrue(torch.allclose(np.abs(p), np.abs(e.to(device)), atol=1e-5))
2 changes: 1 addition & 1 deletion textwiser/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 FMR LLC <[email protected]>
# SPDX-License-Identifer: Apache-2.0

__version__ = "2.0.1"
__version__ = "2.0.2"

0 comments on commit 35ca015

Please sign in to comment.