-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from 1eedaegon/master
gogo
- Loading branch information
Showing
6 changed files
with
484 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"C:\\Users\\STU\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", | ||
" from ._conv import register_converters as _register_converters\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n", | ||
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n", | ||
"Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n", | ||
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", | ||
"Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n", | ||
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", | ||
"Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n", | ||
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n", | ||
"WARNING:tensorflow:From <ipython-input-1-a8044be962d6>:61: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", | ||
"Instructions for updating:\n", | ||
"\n", | ||
"Future major versions of TensorFlow will allow gradients to flow\n", | ||
"into the labels input on backprop by default.\n", | ||
"\n", | ||
"See tf.nn.softmax_cross_entropy_with_logits_v2.\n", | ||
"\n", | ||
"Learning started. It takes sometime.\n", | ||
"Epoch: 0001 cost = 0.345577185\n", | ||
"Epoch: 0002 cost = 0.091736604\n", | ||
"Epoch: 0003 cost = 0.068284046\n", | ||
"Epoch: 0004 cost = 0.056339833\n", | ||
"Epoch: 0005 cost = 0.047010720\n", | ||
"Epoch: 0006 cost = 0.041194586\n", | ||
"Epoch: 0007 cost = 0.036663712\n", | ||
"Epoch: 0008 cost = 0.032757639\n", | ||
"Epoch: 0009 cost = 0.027963868\n", | ||
"Epoch: 0010 cost = 0.025047483\n", | ||
"Epoch: 0011 cost = 0.022065875\n", | ||
"Epoch: 0012 cost = 0.020263703\n", | ||
"Epoch: 0013 cost = 0.016754853\n", | ||
"Epoch: 0014 cost = 0.015507657\n", | ||
"Epoch: 0015 cost = 0.013157484\n", | ||
"Learning Finished!\n", | ||
"Accuracy: 0.9883\n", | ||
"Label: [3]\n", | ||
"Prediction: [3]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"###################################################################################################################\n", | ||
"# B A S I C C O N V O L U S I O N N U R A L _ N E T #\n", | ||
"###################################################################################################################\n", | ||
"import tensorflow as tf\n", | ||
"import random\n", | ||
"import csv\n", | ||
"# import matplotlib.pyplot as plt\n", | ||
"tf.set_random_seed(777) # 랜덤이지만 항상 일정하게\n", | ||
"\n", | ||
"f = open('c:/tensorflow/train.csv')\n", | ||
"\n", | ||
"# 에폭, 배치크기\n", | ||
"learning_rate = 0.001\n", | ||
"training_epochs = 15\n", | ||
"batch_size = 100\n", | ||
"\n", | ||
"# 한번에 넣을 placeholder\n", | ||
"X = tf.placeholder(tf.float32, [None, 784])\n", | ||
"X_img = tf.reshape(X, [-1, 28, 28, 1]) # img 28x28x1 (black/white)\n", | ||
"Y = tf.placeholder(tf.float32, [None, 10])\n", | ||
"\n", | ||
"\n", | ||
"W1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))\n", | ||
"L1 = tf.nn.conv2d(X_img, W1, strides=[1, 1, 1, 1], padding='SAME')\n", | ||
"L1 = tf.nn.relu(L1)\n", | ||
"L1 = tf.nn.max_pool(L1, ksize=[1, 2, 2, 1],\n", | ||
" strides=[1, 2, 2, 1], padding='SAME')\n", | ||
"'''\n", | ||
"Tensor(\"Conv2D:0\", shape=(?, 28, 28, 32), dtype=float32)\n", | ||
"Tensor(\"Relu:0\", shape=(?, 28, 28, 32), dtype=float32)\n", | ||
"Tensor(\"MaxPool:0\", shape=(?, 14, 14, 32), dtype=float32)\n", | ||
"'''\n", | ||
"\n", | ||
"# L2 ImgIn shape=(?, 14, 14, 32)\n", | ||
"W2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))\n", | ||
"# Conv ->(?, 14, 14, 64)\n", | ||
"# Pool ->(?, 7, 7, 64)\n", | ||
"L2 = tf.nn.conv2d(L1, W2, strides=[1, 1, 1, 1], padding='SAME')\n", | ||
"# 14 to 64\n", | ||
"L2 = tf.nn.relu(L2)\n", | ||
"# 14 to 643\n", | ||
"L2 = tf.nn.max_pool(L2, ksize=[1, 2, 2, 1],\n", | ||
" strides=[1, 2, 2, 1], padding='SAME')\n", | ||
"# 7 to 64\n", | ||
"L2_flat = tf.reshape(L2, [-1, 7 * 7 * 64])\n", | ||
"\n", | ||
"\n", | ||
"# 숫자 10개로 아웃\n", | ||
"W3 = tf.get_variable(\"W3\", shape=[7 * 7 * 64, 10],\n", | ||
" initializer=tf.contrib.layers.xavier_initializer())\n", | ||
"b = tf.Variable(tf.random_normal([10]))\n", | ||
"logits = tf.matmul(L2_flat, W3) + b\n", | ||
"\n", | ||
"# 옵티마이저\n", | ||
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n", | ||
" logits=logits, labels=Y))\n", | ||
"optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n", | ||
"\n", | ||
"# 세션영역\n", | ||
"sess = tf.Session()\n", | ||
"sess.run(tf.global_variables_initializer())\n", | ||
"##################################################### 배치 ############################################\n", | ||
"print('Learning started. It takes sometime.')\n", | ||
"for epoch in range(training_epochs):\n", | ||
" avg_cost = 0\n", | ||
" total_batch = # 로우 갯수 / 배치크기\n", | ||
"\n", | ||
" for i in range(total_batch):\n", | ||
" batch_xs, batch_ys = #다음 배치(배치크기만큼)\n", | ||
" feed_dict = {X: batch_xs, Y: batch_ys}\n", | ||
" c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)\n", | ||
" avg_cost += c / total_batch\n", | ||
"##################################################### batch ############################################\n", | ||
" print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))\n", | ||
"\n", | ||
"print('Learning Finished!')\n", | ||
"\n", | ||
"# Test model and check accuracy\n", | ||
"correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))\n", | ||
"accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", | ||
"print('Accuracy:', sess.run(accuracy, feed_dict={\n", | ||
" X: #테스트 이미지 x축, Y: #테스트 이미지 y축}))\n", | ||
"\n", | ||
"# Get one and predict\n", | ||
"r = random.randint(0, mnist.test.num_examples - 1)\n", | ||
"print(\"Label: \", sess.run(tf.argmax(mnist.test.labels[r:r + 1], 1)))\n", | ||
"print(\"Prediction: \", sess.run(\n", | ||
" tf.argmax(logits, 1), feed_dict={X: mnist.test.images[r:r + 1]}))\n", | ||
"\n", | ||
"# plt.imshow(mnist.test.images[r:r + 1].\n", | ||
"# reshape(28, 28), cmap='Greys', interpolation='nearest')\n", | ||
"# plt.show()\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.