Skip to content

Commit 6f8e4b5

Browse files
committed
Simple example of training a KGE with pytorch setup
1 parent fc9b352 commit 6f8e4b5

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

examples/training.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from dicee.models import Keci
2+
import numpy as np
3+
import torch
4+
# Indexed Triples
5+
X=torch.from_numpy(np.array([[0,0,1],
6+
[0,1,1],
7+
[1,1,1]])).long()
8+
9+
# Labels
10+
y=torch.from_numpy(np.array([1,1,0])).float()
11+
12+
# Model
13+
model=Keci(args={"num_entities":2,"num_relations":2,"embedding_dim":8,"optim":"Adopt"})
14+
# Optim
15+
optim=torch.optim.Adam(model.parameters(),lr=0.1)
16+
#
17+
model.train()
18+
for i in range(10):
19+
optim.zero_grad()
20+
21+
yhat=model(X)
22+
23+
loss=torch.nn.functional.binary_cross_entropy_with_logits(yhat,y)
24+
25+
print(loss.item())
26+
27+
loss.backward()
28+
29+
optim.step()
30+
31+
model.eval()
32+
33+
with torch.no_grad():
34+
print(model(X).mean())

0 commit comments

Comments
 (0)