-
Notifications
You must be signed in to change notification settings - Fork 1
/
embeddings.py
73 lines (56 loc) · 2.11 KB
/
embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from data_loader import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
SOS_token = 0  # index reserved for the "SOS" marker in every vocabulary
EOS_token = 1  # index reserved for the "EOS" marker in every vocabulary


class _Vocabulary:
    """Word <-> index vocabulary built incrementally from space-separated text.

    The marker strings ``"SOS"`` and ``"EOS"`` are pre-assigned the indices
    ``SOS_token`` and ``EOS_token``; every previously unseen word receives the
    next free index.

    Attributes:
        name: Label given at construction time.
        word2index: Maps each known word (and the two markers) to its index.
        word2count: Number of times each word was passed to ``addWord``.
            The markers start uncounted and only appear here once they
            occur in added text.
        index2word: Inverse mapping of ``word2index``.
        n_words: Number of indices assigned so far, markers included.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {"SOS": SOS_token, "EOS": EOS_token}
        self.word2count = {}
        self.index2word = {SOS_token: "SOS", EOS_token: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        """Add every word of *sentence* to the vocabulary.

        Words come from ``sentence.split(' ')``, so consecutive, leading or
        trailing spaces produce empty-string tokens, which are added too.
        """
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Record one occurrence of *word*, assigning it a new index if unseen."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            # "SOS"/"EOS" are pre-indexed but have no count entry yet, so a
            # plain ``+= 1`` would raise KeyError the first time they appear.
            self.word2count[word] = self.word2count.get(word, 0) + 1


class Input(_Vocabulary):
    """Vocabulary for input-side sentences (``pair[0]`` in ``tensorsFromPair``)."""


class Output(_Vocabulary):
    """Vocabulary for output-side sentences (``pair[1]`` in ``tensorsFromPair``)."""
def get_embedding(word, lookup_dict, embeds):
    """Look up *word*'s index in *lookup_dict* and pass it through *embeds*.

    Args:
        word: Key to look up.
        lookup_dict: Mapping from word to integer index (for example a
            vocabulary's ``word2index``).
        embeds: Callable applied to a one-element ``torch.long`` tensor that
            holds the index; presumably an ``nn.Embedding`` — confirm at
            call sites.

    Returns:
        Whatever ``embeds`` returns for that one-element tensor.

    Raises:
        KeyError: If *word* is not a key of *lookup_dict*.
    """
    position = lookup_dict[word]
    position_batch = torch.tensor([position], dtype=torch.long)
    return embeds(position_batch)
def indexesFromSentence(lang, sentence):
    """Return the vocabulary index of each word of *sentence*, in order.

    The sentence is split with ``sentence.split(' ')`` and every token is
    looked up in ``lang.word2index``.

    Raises:
        KeyError: If any token (including an empty token produced by
            repeated spaces) is not in ``lang.word2index``.
    """
    tokens = sentence.split(' ')
    lookup = lang.word2index
    return list(map(lookup.__getitem__, tokens))
def tensorFromSentence(lang, sentence):
    """Encode *sentence* as a column tensor of word indices ending with EOS.

    Returns:
        A ``torch.long`` tensor of shape ``(number_of_words + 1, 1)``: the
        indices from ``indexesFromSentence`` followed by ``EOS_token``.
    """
    ids = indexesFromSentence(lang, sentence) + [EOS_token]
    # A 1-D tensor of length n becomes an (n, 1) column.
    return torch.tensor(ids, dtype=torch.long).unsqueeze(1)
def tensorsFromPair(pair, input_lang, output_lang):
    """Encode a sentence pair as two EOS-terminated column tensors.

    Args:
        pair: Indexable whose items 0 and 1 are the input and output
            sentences, respectively.
        input_lang: Vocabulary used to encode ``pair[0]``.
        output_lang: Vocabulary used to encode ``pair[1]``.

    Returns:
        ``(input_tensor, output_tensor)``, each produced by
        ``tensorFromSentence``.
    """
    # Tuple elements evaluate left to right, so the input side is encoded first.
    return (
        tensorFromSentence(input_lang, pair[0]),
        tensorFromSentence(output_lang, pair[1]),
    )