forked from valgur/SuperGluePretrainedNetwork
-
Notifications
You must be signed in to change notification settings - Fork 1
/
jit.py
31 lines (22 loc) · 733 Bytes
/
jit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from models.superglue_triton import SuperGlue
from models.superpoint_triton import SuperPoint
from models.superpoint_descriptor import SuperPointDescriptor
def export_sp(path='superpoint_model.pt'):
    """Compile the SuperPoint model to TorchScript and save it to disk.

    The model is constructed with a fixed config dict
    (``max_keypoints=5000``, ``keypoint_threshold=0.01``), switched to
    eval mode, compiled with ``torch.jit.script`` and serialized with
    ``torch.jit.save``.

    Args:
        path: Destination file for the TorchScript archive. Defaults to
            ``'superpoint_model.pt'``, the previously hard-coded name.
    """
    superpoint = SuperPoint({
        'max_keypoints': 5000,
        'keypoint_threshold': 0.01
    }).eval()
    # torch.jit.save only accepts ScriptModules; an eager nn.Module has no
    # .save() method. torch.jit.script returns an object that is already a
    # ScriptModule unchanged, so this is safe even if SuperPoint is scripted.
    torch.jit.save(torch.jit.script(superpoint), path)
def export_sg(path='superglue_model.pt'):
    """Compile the SuperGlue model to TorchScript and save it to disk.

    The model is constructed with a fixed config dict
    (``weights='outdoor'``, ``sinkhorn_iterations=10``), switched to eval
    mode, compiled with ``torch.jit.script`` and serialized with
    ``torch.jit.save``.

    NOTE(review): ``'outdoor'`` presumably selects a pretrained weight set
    inside SuperGlue; confirm against models.superglue_triton.

    Args:
        path: Destination file for the TorchScript archive. Defaults to
            ``'superglue_model.pt'``, the previously hard-coded name.
    """
    superglue = SuperGlue({
        'weights': 'outdoor',
        'sinkhorn_iterations': 10
    }).eval()
    # torch.jit.save requires a ScriptModule; scripting is a no-op if the
    # model is already one, and makes an eager nn.Module serializable.
    torch.jit.save(torch.jit.script(superglue), path)
def export_superpoint_desc(path='superpoint_desc_model.pt'):
    """Compile the SuperPointDescriptor model to TorchScript and save it.

    The model is built with its default constructor arguments, switched to
    eval mode, compiled with ``torch.jit.script`` and serialized with
    ``torch.jit.save``.

    Args:
        path: Destination file for the TorchScript archive. Defaults to
            ``'superpoint_desc_model.pt'``, the previously hard-coded name.
    """
    superpoint_desc = SuperPointDescriptor().eval()
    # torch.jit.save requires a ScriptModule; scripting is a no-op if the
    # model is already one, and makes an eager nn.Module serializable.
    torch.jit.save(torch.jit.script(superpoint_desc), path)
if __name__ == "__main__":
    # Run the export only when executed as a script, not on import.
    # By default only the SuperPoint model is exported; uncomment the
    # calls below to also export the SuperGlue and descriptor models.
    export_sp()
    # export_sg()
    # export_superpoint_desc()