{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we will study the hyperparameter tuning of the fully connected scatternet\n",
    "\n",
    "# first, find the optimal number of layers and number of neurons\n",
    "# second, optimize the batch size and number of epochs for the best learned architecture\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the dataset here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T17:46:08.075289Z",
     "start_time": "2018-09-04T17:46:07.653719Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hegder/anaconda3/envs/deep/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/hegder/anaconda3/envs/deep/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/hegder/anaconda3/envs/deep/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset has been loaded\n",
      "x-train (44999, 8)\n",
      "x-test  (55001, 8)\n",
      "y-train (44999, 256)\n",
      "y-test  (55001, 256)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "import h5py\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Read the 'sizes' (inputs) and 'spectrum' (targets) datasets fully into memory.\n",
    "# The context manager closes the HDF5 file as soon as both arrays are copied out.\n",
    "with h5py.File('./datasets/s8_sio2tio2_v2.h5','r') as h5f:\n",
    "    X = h5f['sizes'][:]\n",
    "    Y = h5f['spectrum'][:]\n",
    "\n",
    "# get the ranges of the loaded data\n",
    "num_layers = X.shape[1]   # number of input columns\n",
    "num_lpoints = Y.shape[1]  # number of points per spectrum\n",
    "size_max = np.amax(X)\n",
    "size_min = np.amin(X)\n",
    "size_av = 0.5*(size_max + size_min)\n",
    "\n",
    "# this information is not given in the dataset\n",
    "lam_min = 300\n",
    "lam_max = 1200\n",
    "lams = np.linspace(lam_min, lam_max, num_lpoints)\n",
    "\n",
    "# create a train - test split of the dataset (55% held out, fixed seed)\n",
    "x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.55, random_state=42)\n",
    "\n",
    "# normalize inputs\n",
    "# NOTE(review): the shift (50) and scale (20) are hard-coded; size_av, size_min and\n",
    "# size_max computed above are not used here -- TODO confirm they match the data range.\n",
    "x_train = (x_train - 50)/20\n",
    "x_test = (x_test - 50)/20\n",
    "\n",
    "print(\"Dataset has been loaded\")\n",
    "print(\"x-train\", x_train.shape)\n",
    "print(\"x-test \", x_test.shape)\n",
    "print(\"y-train\", y_train.shape)\n",
    "print(\"y-test \", y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### create models here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T08:41:45.424541Z",
     "start_time": "2018-09-04T08:41:44.270792Z"
    }
   },
   "outputs": [],
   "source": [
    "import scnets as scn\n",
    "from IPython.display import SVG\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "\n",
    "#define and visualize the model here\n",
    "#model = scn.fullycon(num_layers, num_lpoints, 4, 500, 2)\n",
    "\n",
    "# The input/output widths are taken from the loaded dataset\n",
    "# (num_layers = X.shape[1], num_lpoints = Y.shape[1]) instead of repeating the\n",
    "# literals 8 and 256, so the model follows the data if the dataset changes.\n",
    "# Requires the dataset-loading cell to have been run first.\n",
    "model = scn.convprel(in_size=num_layers,\n",
    "        out_size=num_lpoints,\n",
    "        c1_nf=64,\n",
    "        clayers=3,\n",
    "        ker_size=3)\n",
    "model.summary()\n",
    "#SVG(model_to_dot(model).create(prog='dot', format='svg'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T08:53:28.967090Z",
     "start_time": "2018-09-04T08:42:58.990636Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hold out 20% of the training split for validation (fixed seed).\n",
    "x_t, x_v, y_t, y_v = train_test_split(\n",
    "    x_train, y_train,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# Fit on the remaining 80%; the returned history is used by the plotting cell.\n",
    "history = model.fit(\n",
    "    x_t, y_t,\n",
    "    validation_data=(x_v, y_v),\n",
    "    epochs=250,\n",
    "    batch_size=64,\n",
    "    verbose=1,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T09:13:20.325358Z",
     "start_time": "2018-09-04T09:13:20.044281Z"
    }
   },
   "outputs": [],
   "source": [
    "# Pass the history returned by model.fit to the scnets plotting helper.\n",
    "# NOTE(review): 64*2.56 (= 163.84) is an unexplained constant; its meaning is defined\n",
    "# by scn.plot_training_history -- TODO confirm and give it a name.\n",
    "scn.plot_training_history(history, 64*2.56)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2018-09-04T17:16:38.203Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 12 candidates, totalling 36 fits\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "# NOTE(review): this silences every warning for the rest of the kernel session.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from sklearn.metrics import fbeta_score, make_scorer\n",
    "\n",
    "def my_custom_score_func(ground_truth, predictions):\n",
    "    \"\"\"Return the mean of |ground_truth - predictions| / |ground_truth|.\n",
    "\n",
    "    NOTE(review): entries where ground_truth is 0 divide by zero -- confirm the\n",
    "    targets are strictly non-zero before relying on this score.\n",
    "    \"\"\"\n",
    "    diff = np.abs(ground_truth - predictions)/np.abs(ground_truth)\n",
    "    return np.mean(diff)\n",
    "\n",
    "\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasRegressor\n",
    "import scnets as scn\n",
    "#model = KerasClassifier(build_fn=scn.fullycon, in_size=8, out_size=250, N_gpus=1, epochs=500, verbose=0)\n",
    "\n",
    "# in_size / out_size follow the dataset dimensions computed in the loading cell\n",
    "# (num_layers = X.shape[1], num_lpoints = Y.shape[1]) instead of the literals 8 / 256.\n",
    "# NOTE(review): this rebinds `model`, replacing the network trained in the earlier cell.\n",
    "model = KerasRegressor(build_fn=scn.conv1dmodel,\n",
    "                        in_size=num_layers,\n",
    "                        out_size=num_lpoints,\n",
    "                        c1_nf=64,\n",
    "                        clayers=3,\n",
    "                        ker_size=3,\n",
    "                        epochs=250,\n",
    "                        verbose=0)\n",
    "\n",
    "# Lower-is-better scorer built from the relative error above. It is NOT passed to\n",
    "# GridSearchCV below, which scores with 'explained_variance'.\n",
    "my_score = make_scorer(my_custom_score_func, greater_is_better=False)\n",
    "\n",
    "# 2 x 3 x 2 = 12 candidate settings\n",
    "param_grid = dict(ker_size=[3, 5],\n",
    "                  clayers=[3,4,5],\n",
    "                  batch_size=[32,64])\n",
    "\n",
    "# cv is pinned to 3 folds so the search matches the recorded run\n",
    "# (\"Fitting 3 folds ...\") regardless of the installed scikit-learn default.\n",
    "grid = GridSearchCV(estimator=model,\n",
    "                    param_grid=param_grid,\n",
    "                    n_jobs=1,\n",
    "                    scoring='explained_variance',\n",
    "                    cv=3,\n",
    "                    verbose=1)\n",
    "grid_result = grid.fit(x_train, y_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T14:25:50.558503Z",
     "start_time": "2018-09-04T14:25:50.539227Z"
    }
   },
   "outputs": [],
   "source": [
    "# Display the full cross-validation results of the grid search (large output).\n",
    "grid_result.cv_results_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T14:26:16.796674Z",
     "start_time": "2018-09-04T14:26:16.790511Z"
    }
   },
   "outputs": [],
   "source": [
    "# Display the parameter combination selected as best by the grid search.\n",
    "grid_result.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-09-04T15:01:18.552044Z",
     "start_time": "2018-09-04T15:01:18.538960Z"
    }
   },
   "outputs": [],
   "source": [
    "# Candidate indices sorted by ascending mean cross-validated test score.\n",
    "bestidx = np.argsort(grid_result.cv_results_['mean_test_score'])\n",
    "print(bestidx)\n",
    "# Same indices reversed, i.e. highest mean test score first.\n",
    "print(bestidx[::-1])\n",
    "parlist = grid_result.cv_results_['params']\n",
    "# Parameter dicts in the same (ascending-score) order as bestidx, so the\n",
    "# highest-scoring candidate is the last entry.\n",
    "bestlist = [parlist[indx] for indx in bestidx]\n",
    "bestlist\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "'mean_fit_time': array([ 440.59602499,  481.73871398,  477.10135643,  504.1547633 ,\n",
    "         488.27373465,  564.76856009,  571.80046884,  569.00620119,\n",
    "         650.78964178,  740.00103672,  597.98015889,  659.03726633,\n",
    "         476.14166268,  508.75317804,  505.10563159,  556.87875859,\n",
    "         556.58007542,  542.48226706,  591.323385  ,  616.57070398,\n",
    "         658.64879243,  891.65342418, 1055.23733377,  993.75831501,\n",
    "         647.80959813,  629.8490928 ,  516.74331371,  787.65950712,\n",
    "         635.2311732 ,  545.72071338,  904.7078863 ,  685.55618747,\n",
    "         530.5135781 ,  815.14902997,  821.7310555 ,  685.34639025,\n",
    "         240.90511258,  246.83409182,  259.09067996,  262.85167106,\n",
    "         228.2099692 ,  392.25437156,  374.78024157,  452.21288816,\n",
    "         459.90044618,  495.74129097,  501.12204981,  368.91637723,\n",
    "         240.11106253,  237.93171946,  261.74841015,  265.67413298,\n",
    "         274.72749496,  378.51081745,  436.21683431,  241.44164371,\n",
    "         274.60629002,  330.47249166,  394.98302094,  318.71901139,\n",
    "         237.12692181,  248.02535923,  254.61501988,  243.65626915,\n",
    "         298.34629599,  273.75161084,  318.28928526,  312.33878072,\n",
    "         309.61013468,  335.33733678,  316.16733519,  284.3205506 ]),\n",
    "\n",
    " 'params': [{'batch_size': 32, 'c1_nf': 32, 'clayers': 1, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 1, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 1, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 2, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 2, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 2, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 3, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 3, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 3, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 4, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 4, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 32, 'clayers': 4, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 1, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 1, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 1, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 2, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 2, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 2, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 3, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 3, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 3, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 4, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 4, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 64, 'clayers': 4, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 1, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 1, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 1, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 2, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 2, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 2, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 3, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 3, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 3, 'ker_size': 7},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 4, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 4, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'c1_nf': 96, 'clayers': 4, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 1, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 1, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 1, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 2, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 2, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 2, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 3, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 3, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 3, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 4, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 4, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 32, 'clayers': 4, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 1, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 1, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 1, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 2, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 2, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 2, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 3, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 3, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 3, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 4, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 4, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 64, 'clayers': 4, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 1, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 1, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 1, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 2, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 2, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 2, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 3, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 3, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 3, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 4, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 4, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'c1_nf': 96, 'clayers': 4, 'ker_size': 7}],\n",
    "\n",
    "\n",
    " 'mean_test_score': array([0.00235561, 0.00175559, 0.00191115, 0.00255561, 0.00228894,\n",
    "        0.00233339, 0.00240005, 0.00260006, 0.00262228, 0.00222227,\n",
    "        0.00222227, 0.00235561, 0.00191115, 0.0021556 , 0.00242228,\n",
    "        0.00217783, 0.00217783, 0.00233339, 0.00231116, 0.00202227,\n",
    "        0.00206671, 0.00251117, 0.00237783, 0.00253339, 0.00197782,\n",
    "        0.00222227, 0.00220005, 0.0019556 , 0.00208894, 0.00240005,\n",
    "        0.00220005, 0.00204449, 0.00222227, 0.00228894, 0.00231116,\n",
    "        0.00246672, 0.00200004, 0.00182226, 0.00164448, 0.00197782,\n",
    "        0.0026445 , 0.00208894, 0.00222227, 0.00248894, 0.00217783,\n",
    "        0.00208894, 0.00253339, 0.0024445 , 0.00204449, 0.00188893,\n",
    "        0.00197782, 0.00233339, 0.00231116, 0.00231116, 0.00233339,\n",
    "        0.0021556 , 0.00231116, 0.0021556 , 0.00220005, 0.00233339,\n",
    "        0.0021556 , 0.00220005, 0.00251117, 0.00191115, 0.00197782,\n",
    "        0.0021556 , 0.00224449, 0.00231116, 0.00228894, 0.00220005,\n",
    "        0.00235561, 0.00211116]),\n",
    " 'std_test_score': array([5.45313915e-04, 1.36982222e-04, 5.71830319e-04, 3.09583745e-04,\n",
    "        4.63034070e-04, 1.88601629e-04, 3.57010463e-04, 2.49388000e-04,\n",
    "        2.06134688e-04, 4.75566024e-04, 1.74937538e-04, 4.01290372e-04,\n",
    "        2.06146739e-04, 5.14564491e-04, 3.28118192e-04, 3.09506255e-04,\n",
    "        3.45776617e-04, 5.65777616e-04, 6.28186970e-05, 2.45451156e-04,\n",
    "        3.03064138e-04, 4.69393085e-04, 3.70577110e-04, 2.37355340e-04,\n",
    "        5.14534288e-04, 2.99865023e-04, 1.08929121e-04, 3.28165699e-04,\n",
    "        1.91227745e-04, 1.44095217e-04, 1.96333991e-04, 1.13377551e-04,\n",
    "        6.28899342e-05, 4.93994775e-04, 5.05855189e-04, 4.25227642e-04,\n",
    "        5.25013167e-04, 1.66308696e-04, 1.91169192e-04, 7.89511257e-04,\n",
    "        3.09609747e-04, 2.20000709e-04, 1.36980546e-04, 2.99872491e-04,\n",
    "        2.06155046e-04, 3.70481852e-04, 1.88604774e-04, 2.74052629e-04,\n",
    "        4.15813294e-04, 2.99823569e-04, 2.06148816e-04, 1.44067451e-04,\n",
    "        3.50039135e-04, 2.79388549e-04, 3.81084580e-04, 3.62491814e-04,\n",
    "        2.26678399e-04, 4.08631950e-04, 4.35536821e-04, 2.49505205e-04,\n",
    "        4.53303035e-04, 1.44064707e-04, 5.77012935e-04, 8.31603039e-05,\n",
    "        5.34338425e-04, 2.79402571e-04, 2.45459743e-04, 1.36996891e-04,\n",
    "        3.86260180e-04, 1.88498951e-04, 3.09579276e-04, 3.70571422e-04]),\n",
    " 'rank_test_score': array([16, 71, 66,  4, 30, 19, 13,  3,  2, 34, 34, 16, 66, 47, 12, 44, 44,\n",
    "        19, 24, 59, 56,  7, 15,  5, 61, 34, 39, 65, 53, 13, 39, 57, 34, 30,\n",
    "        24, 10, 60, 70, 72, 61,  1, 53, 34,  9, 44, 53,  5, 11, 57, 69, 61,\n",
    "        19, 24, 24, 19, 47, 24, 47, 39, 19, 47, 39,  7, 66, 61, 47, 33, 24,\n",
    "        30, 39, 16, 52], dtype=int32),\n",
    "\n",
    "\n",
    " 'mean_train_score': array([0.00231116, 0.00187782, 0.0018556 , 0.00231116, 0.00222227,\n",
    "        0.00217783, 0.00228894, 0.00236672, 0.00232227, 0.00252228,\n",
    "        0.00213338, 0.00241117, 0.00214449, 0.00204449, 0.00235561,\n",
    "        0.00205561, 0.00217783, 0.00238894, 0.00213338, 0.00230005,\n",
    "        0.0021556 , 0.0023445 , 0.00238894, 0.00236672, 0.00203338,\n",
    "        0.00208894, 0.00232228, 0.00203338, 0.00258895, 0.00216671,\n",
    "        0.00221117, 0.00217783, 0.0022556 , 0.00224449, 0.00246672,\n",
    "        0.00240006, 0.00187782, 0.00183338, 0.00153337, 0.00223338,\n",
    "        0.00186671, 0.00204449, 0.00218894, 0.00227783, 0.00203338,\n",
    "        0.00257784, 0.00235561, 0.0023445 , 0.00200004, 0.00210005,\n",
    "        0.00204449, 0.00247784, 0.00233338, 0.00230005, 0.00223338,\n",
    "        0.00222227, 0.00228894, 0.0023445 , 0.00228894, 0.00248895,\n",
    "        0.00188893, 0.00211116, 0.00233339, 0.00236672, 0.00213338,\n",
    "        0.00204449, 0.00248895, 0.00238895, 0.00233338, 0.00231116,\n",
    "        0.0023445 , 0.00241116]),\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "{'mean_fit_time': array([ 893.33654523,  898.53078222, 1119.14394736, 1130.1775128 ,\n",
    "         956.51246222,  964.45365715, 1209.3984166 , 1582.91039157,\n",
    "        1394.26560704, 1616.26630108, 1266.47200227, 1116.83488099,\n",
    "        1205.42738708, 1201.92515103, 1210.92550143]),\n",
    " 'std_fit_time': array([  3.08891285,   6.81113186, 212.20371026, 238.93357922,\n",
    "          6.97112622,  18.87349827, 223.00445851,  57.7875855 ,\n",
    "        100.70476936,  43.95356933, 134.83849082,  11.29690679,\n",
    "          5.42330543,   6.19267952,   4.92641743]),\n",
    " 'params': [{'N_hidden': 1, 'N_neurons': 250},\n",
    "  {'N_hidden': 1, 'N_neurons': 500},\n",
    "  {'N_hidden': 1, 'N_neurons': 1000},\n",
    "  {'N_hidden': 2, 'N_neurons': 250},\n",
    "  {'N_hidden': 2, 'N_neurons': 500},\n",
    "  {'N_hidden': 2, 'N_neurons': 1000},\n",
    "  {'N_hidden': 3, 'N_neurons': 250},\n",
    "  {'N_hidden': 3, 'N_neurons': 500},\n",
    "  {'N_hidden': 3, 'N_neurons': 1000},\n",
    "  {'N_hidden': 4, 'N_neurons': 250},\n",
    "  {'N_hidden': 4, 'N_neurons': 500},\n",
    "  {'N_hidden': 4, 'N_neurons': 1000},\n",
    "  {'N_hidden': 5, 'N_neurons': 250},\n",
    "  {'N_hidden': 5, 'N_neurons': 500},\n",
    "  {'N_hidden': 5, 'N_neurons': 1000}],\n",
    " 'split0_test_score': array([0.00235, 0.00225, 0.0025 , 0.0024 , 0.0026 , 0.0021 , 0.00165,\n",
    "        0.0023 , 0.00255, 0.00255, 0.0024 , 0.0027 , 0.0021 , 0.00225,\n",
    "        0.0022 ]),\n",
    " 'split1_test_score': array([0.0024 , 0.00225, 0.0022 , 0.0022 , 0.00235, 0.0022 , 0.0021 ,\n",
    "        0.00195, 0.00215, 0.0022 , 0.00185, 0.00195, 0.002  , 0.00195,\n",
    "        0.0021 ]),\n",
    " 'split2_test_score': array([0.00255, 0.00315, 0.00275, 0.00295, 0.00355, 0.0032 , 0.00275,\n",
    "        0.0031 , 0.0028 , 0.00305, 0.00305, 0.00355, 0.00275, 0.00285,\n",
    "        0.00305]),\n",
    " 'mean_test_score': array([0.00243333, 0.00255   , 0.00248333, 0.00251667, 0.00283333,\n",
    "        0.0025    , 0.00216667, 0.00245   , 0.0025    , 0.0026    ,\n",
    "        0.00243333, 0.00273333, 0.00228333, 0.00235   , 0.00245   ]),\n",
    " 'std_test_score': array([8.49836586e-05, 4.24264069e-04, 2.24845626e-04, 3.17104960e-04,\n",
    "        5.16935414e-04, 4.96655481e-04, 4.51540573e-04, 4.81317636e-04,\n",
    "        2.67706307e-04, 3.48807492e-04, 4.90464632e-04, 6.53622385e-04,\n",
    "        3.32498956e-04, 3.74165739e-04, 4.26223728e-04]),\n",
    " 'rank_test_score': array([11,  4,  8,  5,  1,  6, 15,  9,  6,  3, 11,  2, 14, 13,  9],\n",
    "       dtype=int32),\n",
    " 'split0_train_score': array([0.00255 , 0.0024  , 0.0023  , 0.00255 , 0.00255 , 0.002375,\n",
    "        0.002125, 0.002625, 0.0025  , 0.002775, 0.002625, 0.0027  ,\n",
    "        0.0023  , 0.00245 , 0.002275]),\n",
    " 'split1_train_score': array([0.002975, 0.0025  , 0.00255 , 0.0028  , 0.0029  , 0.002475,\n",
    "        0.002575, 0.0027  , 0.002575, 0.002375, 0.002475, 0.002475,\n",
    "        0.002325, 0.002425, 0.002375]),\n",
    " 'split2_train_score': array([0.00215 , 0.0022  , 0.00215 , 0.002375, 0.00205 , 0.002175,\n",
    "        0.00215 , 0.0022  , 0.002   , 0.00235 , 0.00235 , 0.002525,\n",
    "        0.00215 , 0.002175, 0.002175]),\n",
    " 'mean_train_score': array([0.00255833, 0.00236667, 0.00233333, 0.002575  , 0.0025    ,\n",
    "        0.00234167, 0.00228333, 0.00250833, 0.00235833, 0.0025    ,\n",
    "        0.00248333, 0.00256667, 0.00225833, 0.00235   , 0.002275  ]),\n",
    " 'std_train_score': array([3.36856382e-04, 1.24721913e-04, 1.64991582e-04, 1.74403746e-04,\n",
    "        3.48807492e-04, 1.24721913e-04, 2.06491862e-04, 2.20164080e-04,\n",
    "        2.55223214e-04, 1.94722024e-04, 1.12422813e-04, 9.64653075e-05,\n",
    "        7.72801541e-05, 1.24163870e-04, 8.16496581e-05])}\n",
    "        \n",
    " \n",
    " {'mean_fit_time': array([  685.01906315,  1809.28454868,   336.60541034,   878.23016135]),\n",
    " 'mean_score_time': array([ 1.38006322,  1.27389534,  0.6934317 ,  0.69225407]),\n",
    " 'mean_test_score': array([ 0.00241667,  0.00251667,  0.00243333,  0.00261667]),\n",
    " 'mean_train_score': array([ 0.00245833,  0.00236667,  0.00248333,  0.00253333]),\n",
    " 'param_batch_size': masked_array(data = [32 32 64 64],\n",
    "              mask = [False False False False],\n",
    "        fill_value = ?),\n",
    " 'param_epochs': masked_array(data = [200 500 200 500],\n",
    "              mask = [False False False False],\n",
    "        fill_value = ?),\n",
    " 'params': ({'batch_size': 32, 'epochs': 200},\n",
    "  {'batch_size': 32, 'epochs': 500},\n",
    "  {'batch_size': 64, 'epochs': 200},\n",
    "  {'batch_size': 64, 'epochs': 500}),\n",
    " 'rank_test_score': array([4, 2, 3, 1], dtype=int32),\n",
    " 'split0_test_score': array([ 0.0021 ,  0.00225,  0.00215,  0.00225]),\n",
    " 'split0_train_score': array([ 0.00235 ,  0.0023  ,  0.002625,  0.002575]),\n",
    " 'split1_test_score': array([ 0.00225,  0.00225,  0.00215,  0.00235]),\n",
    " 'split1_train_score': array([ 0.002675,  0.002725,  0.002675,  0.002825]),\n",
    " 'split2_test_score': array([ 0.0029 ,  0.00305,  0.003  ,  0.00325]),\n",
    " 'split2_train_score': array([ 0.00235 ,  0.002075,  0.00215 ,  0.0022  ]),\n",
    " 'std_fit_time': array([  27.85582158,  121.41697465,    1.58335506,   11.64839192]),\n",
    " 'std_score_time': array([ 0.01602076,  0.06291871,  0.03384719,  0.05541393]),\n",
    " 'std_test_score': array([ 0.00034721,  0.00037712,  0.00040069,  0.00044969]),\n",
    " 'std_train_score': array([ 0.00015321,  0.00026952,  0.00023658,  0.00025685])}\n",
    "        \n",
    "        \n",
    "        'mean_fit_time': array([1236.77363722, 1263.8373781 , 1283.07971772,  617.23694984,\n",
    "         644.64875857,  630.75466394]),\n",
    " 'std_fit_time': array([15.23634435,  1.04774932, 70.7173362 , 19.87266061,  3.13235316,\n",
    "        19.13357172]),\n",
    " 'mean_score_time': array([1.9509182 , 2.09144211, 2.07234033, 1.05850196, 1.09700545,\n",
    "        1.07024908]),\n",
    " 'std_score_time': array([0.09494565, 0.09207867, 0.09335411, 0.04954752, 0.04209056,\n",
    "        0.05320864]),\n",
    " 'param_batch_size': masked_array(data=[32, 32, 32, 64, 64, 64],\n",
    "              mask=[False, False, False, False, False, False],\n",
    "        fill_value='?',\n",
    "             dtype=object),\n",
    " 'param_ker_size': masked_array(data=[3, 5, 7, 3, 5, 7],\n",
    "              mask=[False, False, False, False, False, False],\n",
    "        fill_value='?',\n",
    "             dtype=object),\n",
    " 'params': [{'batch_size': 32, 'ker_size': 3},\n",
    "  {'batch_size': 32, 'ker_size': 5},\n",
    "  {'batch_size': 32, 'ker_size': 7},\n",
    "  {'batch_size': 64, 'ker_size': 3},\n",
    "  {'batch_size': 64, 'ker_size': 5},\n",
    "  {'batch_size': 64, 'ker_size': 7}],\n",
    " 'split0_test_score': array([0.00225, 0.0017 , 0.0027 , 0.00225, 0.00205, 0.0024 ]),\n",
    " 'split1_test_score': array([0.0021 , 0.0023 , 0.0023 , 0.0019 , 0.002  , 0.00165]),\n",
    " 'split2_test_score': array([0.00325, 0.00215, 0.0021 , 0.00245, 0.0024 , 0.0024 ]),\n",
    " 'mean_test_score': array([0.00253333, 0.00205   , 0.00236667, 0.0022    , 0.00215   ,\n",
    "        0.00215   ]),\n",
    " 'std_test_score': array([0.00051045, 0.00025495, 0.00024944, 0.0002273 , 0.00017795,\n",
    "        0.00035355]),\n",
    " 'rank_test_score': array([1, 6, 2, 3, 4, 4], dtype=int32),\n",
    " 'split0_train_score': array([0.001875, 0.001625, 0.002525, 0.002175, 0.0019  , 0.002275]),\n",
    " 'split1_train_score': array([0.0027  , 0.00275 , 0.00245 , 0.0022  , 0.002125, 0.002475]),\n",
    " 'split2_train_score': array([0.0024  , 0.00205 , 0.0019  , 0.0021  , 0.00225 , 0.001775]),\n",
    " 'mean_train_score': array([0.002325  , 0.00214167, 0.00229167, 0.00215833, 0.00209167,\n",
    "        0.002175  ]),\n",
    " 'std_train_score': array([3.40954542e-04, 4.63830668e-04, 2.78637558e-04, 4.24918293e-05,\n",
    "        1.44817893e-04, 2.94392029e-04])}\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}