-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
29 lines (25 loc) · 974 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import csv

import numpy as np
import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from torch import nn
from torch.nn.utils import clip_grad_value_
from torch.utils.data import DataLoader
from torch.distributions import Categorical
from torch.nn.modules.activation import Sigmoid
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2seq_encoders import (LstmSeq2SeqEncoder,PytorchTransformer)
from tqdm import tqdm

from tokenizer import *
from full_model import *
from discriminator import *
from generator import *
def load_smiles(path="qm9.csv", column=1):
    """Read one column of strings (SMILES) from a CSV file that has a header row.

    Args:
        path: CSV file to read. Its first row is treated as a header and skipped.
        column: Zero-based index of the column to collect. Defaults to 1, the
            column this script has always read from ``qm9.csv``.

    Returns:
        A list with the value of ``column`` from every non-blank data row, in
        file order. An empty or header-only file yields an empty list.
    """
    smiles = []
    # newline="" is what the csv module expects; it handles \n and \r\n itself.
    with open(path, newline="", encoding="utf-8") as handle:
        reader = csv.reader(handle)
        next(reader, None)  # skip the header row; no-op on an empty file
        for row in reader:
            if not row:
                # Blank lines (e.g. a trailing newline) have no fields; the old
                # split(",")[1] parser raised IndexError on them.
                continue
            smiles.append(row[column])
    return smiles


def main():
    """Train a MolGen model on the SMILES strings from ``qm9.csv``."""
    data = load_smiles("qm9.csv", column=1)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    gan_mol = MolGen(data, hidden_dim=64, lr=1e-3, device=device)
    # NOTE(review): create_dataloader presumably wraps torch.utils.data.DataLoader;
    # num_workers=10 then starts worker processes, which is why this code must
    # only run under the __main__ guard below.
    loader = gan_mol.create_dataloader(data, batch_size=64, shuffle=True, num_workers=10)
    gan_mol.train_n_steps(loader, max_step=5000, evaluate_every=1000)


# With the "spawn" start method (default on Windows and macOS), worker processes
# re-import this module. The guard keeps them from re-running training.
if __name__ == "__main__":
    main()