# predict.py
# Loads saved model weights, runs prediction over a test file, and writes the
# inputs plus a "predicted" column to a CSV file.
import datetime
import random
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from absl import app, flags, logging
from loguru import logger
from scipy import stats
from sklearn import metrics, model_selection
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.tensorboard import SummaryWriter
import config
import dataset
import engine
from model import BERTBaseUncased
# Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators with
# one shared value, and ask cuDNN for deterministic algorithms.
SEED = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
):
    _seed_fn(SEED)
torch.backends.cudnn.deterministic = True

# TensorBoard writer created with its default arguments, plus a loguru file
# sink so log messages are also written to "experiment.log".
writer = SummaryWriter()
logger.add("experiment.log")

# Command-line flags. The three string flags default to None, which main()
# treats as "use the built-in default path".
flags.DEFINE_boolean("features", True, "")
for _flag_name in ("input", "output", "model_path"):
    flags.DEFINE_string(_flag_name, None, "")
FLAGS = flags.FLAGS
def main(_):
    """Run the saved model over a test file and write predictions to CSV.

    Paths default to ``config.EVAL_PROC`` (input), ``'predictions.csv'``
    (output) and ``config.MODEL_PATH`` (weights). Each one is overridden by
    the matching ``--input`` / ``--output`` / ``--model_path`` flag when that
    flag is set to a non-empty value.

    Args:
        _: Positional argument supplied by ``app.run``; unused.
    """
    input_path = FLAGS.input or config.EVAL_PROC
    # Fix: the original assigned FLAGS.input here, so passing --output sent
    # the predictions to the input path (overwriting the test file) instead
    # of the requested output path.
    output_path = FLAGS.output or 'predictions.csv'
    model_path = FLAGS.model_path or config.MODEL_PATH

    # The test file is parsed as fixed-width formatted text.
    df_test = pd.read_fwf(input_path)

    logger.info(f"Bert Model: {config.BERT_PATH}")
    logger.info(
        f"Current date and time :{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")
    logger.info(f"Test file: {input_path}")
    logger.info(f"Test size : {len(df_test):.4f}")

    # BERTDataset is given one target per row; at prediction time there are
    # no labels, so every row gets a placeholder 0.
    placeholder_targets = [0] * len(df_test)
    test_dataset = dataset.BERTDataset(
        text=df_test.values,
        target=placeholder_targets
    )
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3
    )

    # Use the GPU when available; map_location lets weights saved on another
    # device load onto the one selected here.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BERTBaseUncased(config.DROPOUT)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    # predict_fn returns a second value alongside the predictions; this
    # script does not use it.
    outputs, _extracted_features = engine.predict_fn(
        test_data_loader, model, device, extract_features=FLAGS.features)

    df_test["predicted"] = outputs
    # save file (no header row, no index column)
    df_test.to_csv(output_path, header=None, index=False)
# Script entry point: hand main() to absl's app runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    app.run(main)