Skip to content

Commit

Permalink
Update test_doc2vec.py
Browse files Browse the repository at this point in the history
Temporarily commented out three tests that fail due to the scipy version. Once the Gensim fix is released, we will revert this change.

Signed-off-by: Karthik Uppuluri <[email protected]>
  • Loading branch information
kuppulur authored May 29, 2024
1 parent 65bb6d3 commit 22c1cd2
Showing 1 changed file with 45 additions and 42 deletions.
87 changes: 45 additions & 42 deletions tests/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ def det_hash(x):

class Doc2VecTest(BaseTest):

# TODO: Fails with scipy < 1.13; remove the skip once the Gensim fix is released.
@unittest.skip("Fails due to scipy version being less than 1.13; re-enable after the Gensim fix is released.")
def test_fit_transform(self):
    """Fit + transform with a fixed seed should reproduce the exact embeddings below.

    Single worker, fixed seed, sample=0, and a deterministic hash function are
    presumably what make gensim's Doc2Vec training reproducible here — confirm
    against the Embedding.Doc2Vec implementation.
    """
    tw = TextWiser(Embedding.Doc2Vec(seed=1234, vector_size=2, min_count=1, workers=1, sample=0, negative=5,
                                     hashfxn=det_hash), dtype=torch.float32)
    expected = torch.tensor([[0.2194924355, 0.2886725068],
                             [-0.0268423539, 0.0644853190],
                             [0.1089515761, -0.0599035546]], dtype=torch.float32)
    self._test_fit_transform(tw, expected)

@unittest.skip("Test fails due to downstream library behavior, Gensim")
def test_fit_transform_neg_0(self):
Expand All @@ -53,21 +54,22 @@ def test_fit_transform_neg_0(self):
[0.1089515761, -0.0599035546]], dtype=torch.float32)
self._test_fit_transform(tw, expected)

# TODO: Fails with scipy < 1.13; remove the skip once the Gensim fix is released.
@unittest.skip("Fails due to scipy version being less than 1.13; re-enable after the Gensim fix is released.")
def test_deterministic_transform(self):
    """Specifying the `deterministic` option should make Doc2Vec transformation deterministic.

    By default, running inference with doc2vec is not deterministic in gensim.
    This test makes sure we can get a deterministic result when necessary.
    """
    tw = TextWiser(Embedding.Doc2Vec(deterministic=True, seed=1234, vector_size=2, min_count=1, workers=1, sample=0,
                                     negative=5, hashfxn=det_hash), dtype=torch.float32)
    expected = torch.tensor([[0.2203897089, 0.2896924317],
                             [-0.0264264140, 0.0707252845],
                             [0.1079177931, -0.0554158054]], dtype=torch.float32)
    self._test_fit_before_transform(tw, expected)
    # `pretrained=None` must behave identically to omitting the argument entirely.
    tw = TextWiser(Embedding.Doc2Vec(pretrained=None, deterministic=True, seed=1234, vector_size=2, min_count=1,
                                     workers=1, sample=0, negative=5, hashfxn=det_hash), dtype=torch.float32)
    self._test_fit_before_transform(tw, expected)


def test_tokenizer_validation(self):
Expand All @@ -85,27 +87,28 @@ def test_tokenizer_validation(self):
with self.assertRaises(TypeError):
TextWiser(Embedding.Doc2Vec(tokenizer=lambda doc: [1]))

# TODO: Fails with scipy < 1.13; remove the skip once the Gensim fix is released.
@unittest.skip("Fails due to scipy version being less than 1.13; re-enable after the Gensim fix is released.")
def test_pretrained(self):
    """A pickled Doc2Vec model should load via `pretrained` from both a file object and a file path.

    First trains a deterministic model and checks its embeddings, then pickles
    the underlying gensim model and reloads it both ways, expecting identical output.
    """
    tw = TextWiser(Embedding.Doc2Vec(deterministic=True, seed=1234, vector_size=2, min_count=1, workers=1, sample=0, negative=5, hashfxn=det_hash), dtype=torch.float32)
    expected = torch.tensor([[0.2203897089, 0.2896924317],
                             [-0.0264264140, 0.0707252845],
                             [0.1079177931, -0.0554158054]], dtype=torch.float32)
    self._test_fit_before_transform(tw, expected)
    # Test loading from bytes
    with NamedTemporaryFile() as file:
        pickle.dump(tw._imp[0].model, file)
        file.seek(0)  # rewind so the pretrained loader reads from the start
        tw = TextWiser(Embedding.Doc2Vec(pretrained=file, deterministic=True, seed=1234), dtype=torch.float32)
        predicted = tw.fit_transform(docs)
        self.assertTrue(torch.allclose(predicted, expected.to(device), atol=1e-6))
    # Test loading from file
    file_path = self._get_test_path('data', 'doc2vec.pkl')
    with open(file_path, 'wb') as fp:
        pickle.dump(tw._imp[0].model, fp)
    tw = TextWiser(Embedding.Doc2Vec(pretrained=file_path, deterministic=True, seed=1234), dtype=torch.float32)
    predicted = tw.fit_transform(docs)
    self.assertTrue(torch.allclose(predicted, expected.to(device), atol=1e-6))
    os.remove(file_path)  # clean up the on-disk fixture

def test_pretrained_error(self):
# Not a string
Expand Down

0 comments on commit 22c1cd2

Please sign in to comment.