Skip to content

Commit ba0ab0d

Browse files
committed
fix hidden state bug
1 parent 89eb1e3 commit ba0ab0d

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

inltk/inltk.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def get_embedding_vectors(input: str, language_code: str):
8989
path = Path(__file__).parent
9090
learn = load_learner(path / 'models' / f'{language_code}')
9191
encoder = get_model(learn.model)[0]
92+
encoder.reset()
9293
embeddings = encoder.state_dict()['encoder.weight']
9394
embeddings = np.array(embeddings)
9495
embedding_vectors = []
@@ -105,8 +106,9 @@ def get_sentence_encoding(input: str, language_code: str):
105106
defaults.device = torch.device('cpu')
106107
path = Path(__file__).parent
107108
learn = load_learner(path / 'models' / f'{language_code}')
108-
m = learn.model
109-
kk0 = m[0](Tensor([token_ids]).to(torch.int64))
109+
encoder = learn.model[0]
110+
encoder.reset()
111+
kk0 = encoder(Tensor([token_ids]).to(torch.int64))
110112
return np.array(kk0[0][-1][0][-1])
111113

112114

@@ -128,6 +130,7 @@ def get_similar_sentences(sen: str, no_of_variations: int, language_code: str):
128130
path = Path(__file__).parent
129131
learn = load_learner(path / 'models' / f'{language_code}')
130132
encoder = get_model(learn.model)[0]
133+
encoder.reset()
131134
embeddings = encoder.state_dict()['encoder.weight']
132135
embeddings = np.array(embeddings)
133136
# cos similarity of vectors

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="inltk",
8-
version="0.7.1",
8+
version="0.7.2",
99
author="Gaurav",
1010
author_email="[email protected]",
1111
description="Natural Language Toolkit for Indian Languages (iNLTK)",

0 commit comments

Comments
 (0)