Skip to content

Commit

Permalink
add graph test
Browse files Browse the repository at this point in the history
  • Loading branch information
transcranial committed Nov 27, 2017
1 parent 695a7bd commit 7097036
Show file tree
Hide file tree
Showing 4 changed files with 430 additions and 1 deletion.
168 changes: 168 additions & 0 deletions notebooks/graph/graph_07.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"import numpy as np\n",
"import json\n",
"from keras.models import Model\n",
"from keras.layers import Input\n",
"from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, BatchNormalization, Concatenate\n",
"from keras import backend as K\n",
"from collections import OrderedDict"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def format_decimal(arr, places=6):\n",
" return [round(x * 10**places) / 10**places for x in arr]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"DATA = OrderedDict()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### graph 7"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"random_seed = 10007\n",
"data_in_shape = (8, 8, 2)\n",
"\n",
"input_layer_0 = Input(shape=data_in_shape)\n",
"branch_0 = Conv2D(4, (3,3), activation='relu', padding='same', strides=(1,1), data_format='channels_last', use_bias=True)(input_layer_0)\n",
"branch_0 = Conv2D(4, (3,3), activation='relu', padding='same', strides=(1,1), data_format='channels_last', use_bias=True)(branch_0)\n",
"\n",
"input_layer_1 = Input(shape=data_in_shape)\n",
"branch_1 = Conv2D(4, (3,3), activation='relu', padding='same', strides=(1,1), data_format='channels_last', use_bias=True)(input_layer_1)\n",
"branch_1 = Conv2D(4, (3,3), activation='relu', padding='same', strides=(1,1), data_format='channels_last', use_bias=True)(branch_1)\n",
"\n",
"branch_2 = Concatenate()([branch_0, branch_1])\n",
"output_layer = Conv2D(4, (3,3), activation='linear', padding='same', strides=(1,1), data_format='channels_last', use_bias=True)(branch_2)\n",
"model = Model(inputs=[input_layer_0, input_layer_1], outputs=output_layer)\n",
"\n",
"data_in = []\n",
"for i in range(2):\n",
" np.random.seed(random_seed + i)\n",
" data_in.append(np.expand_dims(2 * np.random.random(data_in_shape) - 1, axis=0))\n",
"\n",
"# set weights to random (use seed for reproducibility)\n",
"weights = []\n",
"for i, w in enumerate(model.get_weights()):\n",
" np.random.seed(random_seed + i)\n",
" weights.append(2 * np.random.random(w.shape) - 1)\n",
"model.set_weights(weights)\n",
"\n",
"result = model.predict(data_in)\n",
"data_out_shape = result[0].shape\n",
"data_in_formatted = [format_decimal(data_in[i].ravel().tolist()) for i in range(2)]\n",
"data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
"\n",
"DATA['graph_07'] = {\n",
" 'inputs': [{'data': data_in_formatted[i], 'shape': data_in_shape} for i in range(2)],\n",
" 'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
" 'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### export for Keras.js tests"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"filename = '../../test/data/graph/07.json'\n",
"if not os.path.exists(os.path.dirname(filename)):\n",
" os.makedirs(os.path.dirname(filename))\n",
"with open(filename, 'w') as f:\n",
" json.dump(DATA, f)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"graph_07\": {\"inputs\": [{\"data\": [-0.719503, 0.396686, 0.710021, -0.058892, -0.78922, -0.74094, -0.421022, -0.808307, 0.800549, -0.334559, -0.186061, 0.852542, -0.759177, -0.565928, -0.439353, -0.566006, 0.739029, 0.513581, -0.267873, 0.743929, -0.767563, -0.052566, -0.04449, 0.68184, 0.572289, 0.515837, -0.674829, -0.345664, 0.702239, -0.638317, -0.083214, -0.666578, 0.712641, -0.391794, 0.056017, -0.858674, -0.008642, -0.159135, -0.00677, 0.331901, -0.257608, -0.824415, 0.571256, 0.741805, -0.75251, 0.468774, -0.784013, 0.681435, -0.23194, 0.101493, -0.128162, -0.603826, -0.49174, -0.065357, -0.32479, 0.631058, 0.083835, -0.769532, -0.682364, 0.882268, -0.833573, -0.460043, 0.514984, 0.945882, -0.716884, -0.795482, 0.426371, -0.988366, 0.643317, 0.599972, -0.313279, -0.44164, 0.582645, -0.149293, 0.678646, -0.010416, -0.170983, 0.657191, 0.654839, -0.66861, -0.486733, 0.566186, 0.50032, -0.631832, -0.587607, 0.569767, 0.635765, 0.819135, 0.441633, -0.561978, -0.62083, 0.874285, -0.688123, 0.574826, -0.362262, 0.073101, -0.519688, -0.102416, -0.402484, 0.690901, -0.55265, -0.364623, -0.226078, 0.313158, 0.300124, 0.973938, -0.711208, -0.058027, 0.753508, -0.049087, 0.451992, -0.073969, 0.672509, 0.3412, -0.688499, 0.775659, 0.181513, 0.12763, 0.433325, 0.425375, -0.140093, 0.739271, -0.666047, -0.676331, -0.081829, 0.245534, 0.139074, -0.414734], \"shape\": [8, 8, 2]}, {\"data\": [0.839584, -0.988988, -0.527734, -0.797693, -0.884905, -0.146869, -0.267342, 0.449944, -0.282828, -0.728855, -0.587547, -0.355267, -0.669017, 0.196118, 0.567189, 0.859179, 0.975979, -0.814016, 0.903094, -0.21355, -0.109601, -0.768741, 0.27166, 0.030371, -0.275986, -0.819553, 0.554826, -0.864389, 0.096677, 0.957206, 0.754441, 0.330801, 0.5741, 0.36419, 0.987371, -0.530545, 0.591249, -0.273263, 0.897734, -0.304084, 0.940353, 0.135187, -0.058766, -0.038047, -0.415426, -0.621638, -0.92997, -0.631599, 0.03693, 0.434788, 0.436622, -0.591559, -0.857843, 0.426905, -0.663415, 0.457006, 0.764796, -0.413896, -0.71939, -0.479789, -0.317247, -0.016311, 0.148157, 0.362594, -0.908516, -0.776948, 0.518685, 0.231957, -0.205797, 0.925386, -0.2551, -0.131588, -0.287732, 0.194838, -0.478042, 0.006402, 0.49051, -0.218892, -0.082877, -0.598026, 0.878144, 0.797895, 0.578246, 0.222223, -0.143421, 0.189322, 0.283295, -0.602854, 0.74498, 0.850133, -0.421519, 0.816516, 0.098714, 0.595811, 0.926334, -0.1268, -0.84326, 0.723379, -0.824281, -0.992293, -0.526419, -0.331198, -0.447634, -0.964083, -0.638283, 0.186446, 0.388669, 0.704854, -0.214886, -0.50761, 0.961234, 0.373414, 0.216379, 0.391879, -0.856657, -0.65227, 0.838337, 0.61763, -0.7879, -0.437836, 0.517913, -0.217903, -0.300952, 0.023804, -0.861553, -0.649055, 0.42497, 0.183235], \"shape\": [8, 8, 2]}], \"weights\": [{\"data\": [-0.719503, 0.396686, 0.710021, -0.058892, -0.78922, -0.74094, -0.421022, -0.808307, 0.800549, -0.334559, -0.186061, 0.852542, -0.759177, -0.565928, -0.439353, -0.566006, 0.739029, 0.513581, -0.267873, 0.743929, -0.767563, -0.052566, -0.04449, 0.68184, 0.572289, 0.515837, -0.674829, -0.345664, 0.702239, -0.638317, -0.083214, -0.666578, 0.712641, -0.391794, 0.056017, -0.858674, -0.008642, -0.159135, -0.00677, 0.331901, -0.257608, -0.824415, 0.571256, 0.741805, -0.75251, 0.468774, -0.784013, 0.681435, -0.23194, 0.101493, -0.128162, -0.603826, -0.49174, -0.065357, -0.32479, 0.631058, 0.083835, -0.769532, -0.682364, 0.882268, -0.833573, -0.460043, 0.514984, 0.945882, -0.716884, -0.795482, 0.426371, -0.988366, 0.643317, 0.599972, -0.313279, -0.44164], \"shape\": [3, 3, 2, 4]}, {\"data\": [0.839584, -0.988988, -0.527734, -0.797693], \"shape\": [4]}, {\"data\": [-0.677523, 0.027961, 0.394704, -0.409564, -0.263961, -0.578124, 0.478145, -0.038085, 0.966269, 0.73844, 0.750385, -0.864013, 0.739996, -0.969876, 0.94102, -0.969374, -0.266124, 0.642256, -0.690874, -0.122534, -0.757021, -0.548516, 0.110985, 0.664401, 0.881525, 0.895536, -0.49967, -0.29786, -0.068312, 0.425857, 0.294558, 0.888644, 0.969055, 0.974488, -0.415146, 0.515722, -0.548753, 0.701101, 0.104793, -0.525767, -0.899254, 0.019513, 0.663871, 0.396586, -0.78326, -0.255662, -0.300106, -0.963967, -0.924659, 0.732113, -0.957076, -0.672519, -0.089013, -0.954748, -0.834295, -0.312389, -0.226936, 0.283681, 0.878406, -0.842698, 0.944196, 0.371051, 0.552261, -0.42905, 0.215084, 0.192538, 0.255143, 0.566775, 0.718326, 0.252937, -0.655206, 0.567654], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.890769, -0.844694, -0.920267, -0.682613], \"shape\": [4]}, {\"data\": [0.112538, -0.288508, 0.401783, 0.811569, -0.995511, 0.808396, 0.602689, 0.089005, -0.897683, -0.393118, -0.667562, 0.816652, -0.824418, 0.654265, -0.422005, 0.849247, 0.601831, -0.568844, 0.897222, 0.486418, -0.087611, 0.480959, 0.004321, 0.978637, -0.348266, -0.914552, 0.792941, -0.263543, -0.741203, -0.951525, 0.367234, 0.186412, -0.40553, 0.518136, 0.932792, -0.625391, -0.040345, 0.736396, -0.857775, -0.744501, 0.239287, 0.731086, -0.007853, -0.471562, 0.828783, 0.900072, 0.873478, 0.072277, -0.915735, -0.257961, -0.487677, 0.845285, 0.762145, 0.476619, -0.219475, 0.919266, 0.66316, -0.832711, 0.234218, 0.469246, -0.310549, -0.997257, -0.195151, 0.586701, 0.67762, 0.562394, 0.245161, 0.553843, 0.283796, -0.716804, -0.094667, -0.17674, -0.574616, -0.816479, -0.197545, -0.595246, -0.878515, 0.089214, 0.475428, -0.189262, -0.690449, 0.415482, -0.514852, -0.025337, 0.402232, -0.618405, -0.979007, -0.503433, -0.903723, 0.937322, -0.852217, 0.607334, -0.806095, -0.647108, 0.190353, 0.106521, 0.854423, 0.946531, 0.527663, -0.616984, 0.652559, 0.392006, 0.854114, 0.277204, -0.459716, -0.716285, 0.20248, -0.624515, -0.30171, -0.459956, 0.31041, -0.41268, -0.918444, -0.269906, -0.518655, -0.568083, -0.846487, 0.550739, -0.817968, 0.052704, -0.65807, -0.487246, -0.995944, -0.486223, -0.467124, -0.86441, 0.772284, -0.669564, -0.836047, -0.982721, -0.162248, -0.737051, 0.704922, -0.661582, -0.596443, 0.52055, -0.970706, 0.773144, 0.752803, 0.867639, 0.588682, 0.319624, -0.69027, -0.846233], \"shape\": [3, 3, 4, 4]}, {\"data\": [-0.050965, -0.087285, -0.473172, 0.867282], \"shape\": [4]}, {\"data\": [0.314304, 0.087114, -0.192081, 0.687827, -0.05382, -0.449211, 0.152559, 0.729694, 0.99547, -0.487042, -0.486432, 0.714711, -0.292248, -0.199289, -0.242229, -0.439609, -0.575396, -0.053792, 0.860916, -0.065459, 0.353475, 0.680997, 0.830208, 0.469092, -0.265621, -0.687023, -0.45321, 0.530606, -0.319466, 0.898567, -0.021989, -0.726166, 0.350933, 0.471236, 0.214698, -0.218634, -0.656787, 0.487585, -0.45693, 0.095657, 0.951011, -0.969214, -0.540415, 0.42117, 0.268511, -0.931772, -0.026435, -0.416304, -0.229962, 0.715691, -0.609217, 0.957363, -0.197645, -0.423248, 0.29652, 0.2076, 0.413223, 0.727793, -0.155075, 0.589662, 0.438914, 0.892885, -0.387566, 0.081995, 0.90693, 0.976951, -0.092549, -0.104668, 0.701797, 0.631171, 0.652832, 0.209322, 0.222041, -0.436956, -0.919785, -0.375099, 0.659733, 0.459222, -0.057638, -0.001735, 0.579448, 0.507595, 0.708452, -0.907805, 0.948767, -0.013304, 0.863001, 0.632943, -0.996805, -0.914495, -0.739203, -0.681927, 0.257542, 0.009895, -0.878075, 0.83345, -0.012611, -0.950112, 0.427788, 0.838164, -0.107221, -0.143387, 0.744575, 0.857727, -0.637755, -0.18538, -0.122234, 0.134867, -0.394695, 0.220049, -0.723508, -0.399619, 0.171876, 0.93248, 0.729662, 0.786324, -0.977909, -0.955888, 0.848926, -0.977581, 0.690744, 0.157652, -0.032257, -0.704781, 0.712356, -0.732367, -0.663543, 0.821131, 0.62497, 0.88944, -0.154804, -0.501889, 0.520767, 0.62754, -0.584354, -0.544752, 0.320135, 0.068166, -0.592856, 0.749543, 0.773931, -0.086116, 0.84131, -0.656192], \"shape\": [3, 3, 4, 4]}, {\"data\": [0.488004, 0.911709, 0.385817, 0.735948], \"shape\": [4]}, {\"data\": [0.263717, -0.720003, 0.8268, -0.92253, 0.462425, 0.244446, -0.124898, -0.811324, 0.354873, 0.801179, 0.564291, -0.355745, 0.644611, 0.859794, -0.560491, -0.788843, -0.963575, -0.638994, 0.405345, 0.176622, -0.153501, 0.998311, 0.944831, 0.536878, 0.033269, -0.435141, -0.309305, 0.602218, 0.537614, 0.254239, 0.37362, -0.505583, -0.948316, 0.939032, 0.149339, -0.197973, -0.224213, 0.028595, -0.795155, -0.951132, 0.729428, 0.59133, 0.864157, -0.620287, -0.843726, -0.920702, -0.496038, 0.822909, 0.890635, -0.096026, 0.915961, -0.470611, -0.719671, 0.234755, -0.244699, 0.606162, 0.895724, -0.800812, 0.539927, 0.684273, 0.542181, -0.02343, 0.875754, -0.860446, 0.844953, 0.374322, -0.421283, 0.809843, 0.107933, 0.538515, 0.538688, -0.249508, 0.1449, 0.410438, -0.693297, -0.07563, -0.666376, 0.690156, 0.08943, -0.572714, -0.502461, 0.950915, -0.108646, -0.915989, -0.144237, -0.900061, 0.475656, 0.173209, 0.152618, 0.967328, 0.87992, -0.298116, 0.36125, -0.320061, 0.309461, 0.27828, -0.045387, 0.400924, 0.451493, 0.207171, -0.870517, 0.011689, 0.552887, 0.941986, -0.380805, -0.095083, -0.650022, -0.598894, -0.694641, -0.080797, -0.390523, 0.876907, 0.830436, -0.439703, 0.020501, 0.585298, -0.1045, 0.842147, -0.930716, 0.834074, 0.550016, 0.684493, 0.174206, -0.022342, 0.104366, 0.789688, 0.558795, -0.526273, -0.731229, 0.406012, 0.951943, -0.2683, -0.34048, 0.838036, 0.72188, -0.910357, 0.786073, 0.308813, -0.560274, -0.009713, 0.528142, 0.066026, -0.908124, -0.127743, -0.54385, 0.463717, -0.585671, 0.106338, -0.711419, -0.418877, 0.326981, -0.311289, -0.411615, -0.812747, -0.813544, 0.019126, -0.013371, -0.018314, 0.849754, 0.749216, -0.383548, -0.848601, -0.937479, 0.690852, -0.974004, -0.59376, -0.731103, -0.505768, 0.538382, 0.996418, -0.653786, 0.003821, 0.346238, -0.570115, -0.292759, 0.596403, 0.356962, -0.40833, 0.257265, 0.392791, 0.214136, 0.339779, 0.912384, -0.770968, -0.563594, 0.742829, 0.028137, 0.99606, -0.154497, 0.206626, -0.214892, -0.321264, 0.455809, -0.96512, -0.656706, -0.751642, 0.996265, 0.176439, -0.710484, 0.266252, -0.123037, -0.291063, -0.669706, -0.843051, 0.258219, -0.415613, 0.765296, -0.632937, -0.361384, 0.766739, -0.741761, -0.424624, 0.468028, -0.100906, 0.120367, -0.024052, -0.611933, 0.320729, 0.594457, -0.187492, 0.077855, 0.650452, -0.564818, -0.331068, -0.805392, -0.621492, 0.351791, 0.435483, 0.867174, -0.831411, 0.878362, -0.337213, -0.144539, -0.717679, 0.086155, 0.673595, -0.18332, -0.807448, -0.66528, -0.291659, -0.802201, 0.590244, -0.260274, 0.247972, 0.701667, -0.067998, -0.522099, 0.651477, 0.186356, 0.54196, -0.99628, -0.975049, -0.558434, -0.511012, -0.759343, 0.732053, 0.323626, 0.574041, 0.536674, 0.808216, 0.105766, 0.741759, -0.608342, -0.566958, 0.401101, 0.142419, 0.25029, -0.37009, -0.592485, -0.647812, -0.568179, -0.574508, 0.007552, 0.980874, 0.970099, 0.806563, 0.582666, -0.518095, 0.759172, -0.648391, 0.217929, 0.342155, 0.715688, -0.913938, 0.724745, 0.959203, -0.414178, 0.445853], \"shape\": [3, 3, 8, 4]}, {\"data\": [-0.296519, 0.572188, -0.272211, -0.907259], \"shape\": [4]}], \"expected\": {\"data\": [-0.199215, 9.321177, -6.123923, -7.006053, -5.413159, -5.326313, -7.939443, -6.753269, 6.903975, 9.45579, -2.139473, -18.041553, -9.351606, -3.980932, -0.858323, -2.170862, -5.97595, -1.8138, -10.65713, -10.314507, 5.640358, 7.388157, 7.886442, -10.678903, -1.885194, 7.332788, 2.791726, 4.864554, -5.189827, 2.661085, -5.790349, 6.434088, 1.693445, 14.199347, -3.535717, 3.732931, 5.119977, 22.164043, -3.710325, 0.713292, -1.457271, 13.16334, -10.182754, -4.327676, 5.693687, 4.036681, -13.889244, 0.260298, -1.707905, 7.586016, -17.696356, 1.212139, -6.814427, 9.845968, -12.768095, -11.99397, -3.874164, -1.349166, -0.110665, -3.78852, 0.847219, 2.443405, -2.484042, -1.902033, -3.752414, 12.614375, 1.046178, -0.805366, -8.236174, 23.791599, -15.543843, 10.271393, 15.285831, 20.535896, -1.870446, -4.298431, 4.95017, 22.507881, -1.644534, -18.672527, 3.829441, 14.193356, -1.070387, 5.254819, 10.809513, 9.400687, -14.664796, 3.996032, 7.729726, 9.005091, -15.530272, 2.715595, 0.127956, 2.220701, -3.442966, -2.593572, 7.058976, 11.400736, 14.563081, 2.471402, 10.376163, 17.066015, 15.220338, 18.87874, 5.081575, 12.331302, 11.813938, -2.951488, 11.918758, 18.773928, 17.136768, -15.290675, -2.84932, 21.415611, 13.484735, -1.283965, -1.864361, 19.469845, 5.258374, 2.175235, -1.207329, 13.686732, -11.487003, 1.75823, -0.013982, 3.737968, -8.860118, -5.006305, -3.323633, 6.472236, 1.06452, 0.994937, -8.600193, 0.114783, 7.774179, 8.202867, 7.304964, 13.060142, 8.21752, -7.835262, 0.965753, 3.097943, 9.769804, -6.146391, -0.074937, 6.423484, -2.954583, -2.763388, 7.570226, 11.46158, 6.737085, -2.930335, 6.425412, 12.766352, 14.017634, -15.594198, 7.068308, 7.780724, -3.201651, -7.796819, 1.441781, -7.340527, -5.051141, 8.232263, 4.240792, -1.735634, -7.512732, -7.639565, 12.498146, 10.351672, 14.107317, -19.074913, -5.308741, -8.25484, 7.524553, 3.015877, 2.394606, 3.677567, -15.117904, 3.95361, 4.22118, 8.562589, -0.359103, -1.222611, 1.663245, 11.310987, 13.365661, -2.912982, -8.013397, -0.714444, -8.413801, 9.198146, 4.231226, 12.671323, -2.79388, 0.221687, -3.44127, 8.797631, -14.695992, 1.338225, 0.312279, 15.296196, -19.268986, 1.775031, -7.169279, 10.000463, -9.891088, -0.977783, -1.925249, 19.257942, -1.798841, -8.909183, -10.720798, 6.389853, -2.00158, 4.845818, 4.44215, 12.865691, -3.875045, -9.847071, -5.769631, 1.606278, -0.778178, -1.679536, 2.156123, 11.227082, 3.919703, -9.794403, -1.644038, 11.155446, 6.826945, -8.210423, 7.292017, 9.678843, -0.702047, -2.175535, 4.991436, 14.484716, 8.852351, -9.653472, 2.560756, 1.159242, 9.405992, -0.932719, -1.861191, 9.154112, 0.984699, -0.554561, 4.108758, 8.507056, 6.874351, 3.753832, -0.590122, 6.254458, -0.969853, 1.688881], \"shape\": [8, 8, 4]}}}\n"
]
}
],
"source": [
"print(json.dumps(DATA))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 7097036

Please sign in to comment.