-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_csv.py
executable file
·48 lines (39 loc) · 1.75 KB
/
train_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# coding: utf-8
from __future__ import print_function
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import tensorflow as tf
def main(_):
csv_file_name = './data/period_trend.csv'
reader = tf.contrib.timeseries.CSVReader(csv_file_name)
train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(reader, batch_size=16, window_size=16)
with tf.Session() as sess:
data = reader.read_full()
coord = tf.train.Coordinator()
tf.train.start_queue_runners(sess=sess, coord=coord)
data = sess.run(data)
coord.request_stop()
ar = tf.contrib.timeseries.ARRegressor(
periodicities=100, input_window_size=10, output_window_size=6,
num_features=1,
loss=tf.contrib.timeseries.ARModel.NORMAL_LIKELIHOOD_LOSS)
ar.train(input_fn=train_input_fn, steps=1000)
evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
# keys of evaluation: ['covariance', 'loss', 'mean', 'observed', 'start_tuple', 'times', 'global_step']
evaluation = ar.evaluate(input_fn=evaluation_input_fn, steps=1)
(predictions,) = tuple(ar.predict(
input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
evaluation, steps=250)))
plt.figure(figsize=(15, 5))
plt.plot(data['times'].reshape(-1), data['values'].reshape(-1), label='origin')
plt.plot(evaluation['times'].reshape(-1), evaluation['mean'].reshape(-1), label='evaluation')
plt.plot(predictions['times'].reshape(-1), predictions['mean'].reshape(-1), label='prediction')
plt.xlabel('time_step')
plt.ylabel('values')
plt.legend(loc=4)
plt.savefig('predict_result.jpg')
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()