-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun.py
72 lines (69 loc) · 2.24 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import mmns
from mmns.config import Trainer, Tester
from mmns.module.model import MMTransE
from mmns.module.loss import MarginLoss
from mmns.module.strategy import NegativeSampling
from mmns.data import TrainDataLoader, TestDataLoader
from args import get_args
if __name__ == "__main__":
args = get_args()
print(args)
# dataloader for training
train_dataloader = TrainDataLoader(
in_path="./benchmarks/" + args.dataset + '/',
nbatches=args.num_batch,
threads=8,
# 当dismult的时候是cross
sampling_mode="normal",
bern_flag=1,
filter_flag=1,
neg_ent=args.neg_num,
neg_rel=0
)
# dataloader for test
test_dataloader = TestDataLoader("./benchmarks/" + args.dataset + '/', "link")
img_emb = torch.load('./visual/' + args.dataset + '.pth')
if args.kernel == 'transe':
# define the model
transe = MMTransE(
ent_tot=train_dataloader.get_ent_tot(),
rel_tot=train_dataloader.get_rel_tot(),
dim=128,
p_norm=1,
norm_flag=True,
img_dim=args.img_dim,
img_emb=img_emb,
test_mode=args.test_mode,
beta=args.beta
)
print(transe)
# define the loss function
model = NegativeSampling(
model=transe,
loss=MarginLoss(margin=args.margin),
batch_size=train_dataloader.get_batch_size(),
neg_mode=args.neg_mode
)
# train the model
trainer = Trainer(
model=model,
data_loader=train_dataloader,
train_times=args.epoch,
alpha=1.0,
use_gpu=True,
opt_method='Adam',
train_mode=args.train_mode
)
trainer.run()
transe.save_checkpoint(args.save)
# test the model
transe.load_checkpoint(args.save)
tester = Tester(model=transe, data_loader=test_dataloader, use_gpu=True)
# link prediction task
tester.run_link_prediction(type_constrain=False)
# triple classification task
acc, p, r, f, _ = tester.run_triple_classification_four_metrics()
print(acc, p, r, f)
else:
raise NotImplementedError