{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# we will study the hyperparameter tuning of fully connected scatternet\n", "\n", "# first find optimal number of layers and neuron numbers\n", "# second optimize the batch size and number of epochs for the best learned architecture\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading the dataset here" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T17:46:08.075289Z", "start_time": "2018-09-04T17:46:07.653719Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/hegder/anaconda3/envs/deep/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", " return f(*args, **kwds)\n", "/home/hegder/anaconda3/envs/deep/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", " return f(*args, **kwds)\n", "/home/hegder/anaconda3/envs/deep/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. 
Expected 96, got 88\n", " return f(*args, **kwds)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Dataset has been loaded\n", "x-train (44999, 8)\n", "x-test (55001, 8)\n", "y-train (44999, 256)\n", "y-test (55001, 256)\n" ] } ], "source": [ "import numpy as np\n", "\n", "import h5py\n", "from sklearn.model_selection import train_test_split\n", "\n", "# load the dataset; the context manager closes the HDF5 file as soon as\n", "# both arrays have been copied into memory by the [:] slices\n", "with h5py.File('./datasets/s8_sio2tio2_v2.h5','r') as h5f:\n", "    X = h5f['sizes'][:]\n", "    Y = h5f['spectrum'][:]\n", "\n", "#get the ranges of the loaded data\n", "num_layers = X.shape[1]\n", "num_lpoints = Y.shape[1]\n", "size_max = np.amax(X)\n", "size_min = np.amin(X)\n", "size_av = 0.5*(size_max + size_min)\n", "\n", "#this information is not given in the dataset\n", "lam_min = 300\n", "lam_max = 1200\n", "lams = np.linspace(lam_min, lam_max, num_lpoints)\n", "\n", "#create a train - test split of the dataset (55% of the samples go to the test set)\n", "x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.55, random_state=42)\n", "\n", "# normalize inputs with a fixed centre of 50 and scale of 20\n", "# (hard-coded, not derived from size_min/size_max computed above)\n", "x_train = (x_train - 50)/20 \n", "x_test = (x_test - 50)/20 \n", "\n", "print(\"Dataset has been loaded\")\n", "print(\"x-train\", x_train.shape)\n", "print(\"x-test \", x_test.shape)\n", "print(\"y-train\", y_train.shape)\n", "print(\"y-test \", y_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### create models here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T08:41:45.424541Z", "start_time": "2018-09-04T08:41:44.270792Z" } }, "outputs": [], "source": [ "import scnets as scn\n", "from IPython.display import SVG\n", "from keras.utils.vis_utils import model_to_dot\n", "\n", "#define and visualize the model here\n", "#model = scn.fullycon(num_layers, num_lpoints, 4, 500, 2)\n", "\n", "model = scn.convprel(in_size=8, \n", " out_size=256,\n", " c1_nf=64,\n", " clayers=3,\n", " ker_size=3)\n", "model.summary()\n", "#SVG(model_to_dot(model).create(prog='dot', 
format='svg'))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T08:53:28.967090Z", "start_time": "2018-09-04T08:42:58.990636Z" } }, "outputs": [], "source": [ "# hold out 20% of the training set for validation while fitting\n", "x_t, x_v, y_t, y_v = train_test_split(x_train, y_train, test_size=0.2, random_state=42)\n", "history = model.fit(x_t, y_t,\n", " batch_size=64,\n", " epochs=250, \n", " verbose=1,\n", " validation_data=(x_v, y_v))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T09:13:20.325358Z", "start_time": "2018-09-04T09:13:20.044281Z" } }, "outputs": [], "source": [ "scn.plot_training_history(history, 64*2.56)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "start_time": "2018-09-04T17:16:38.203Z" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 3 folds for each of 12 candidates, totalling 36 fits\n" ] } ], "source": [ "\n", "\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "from sklearn.metrics import fbeta_score, make_scorer\n", "\n", "def my_custom_score_func(ground_truth, predictions):\n", "    \"\"\"Return the mean relative absolute error of predictions.\n", "\n", "    Computes mean(|ground_truth - predictions| / |ground_truth|) over all\n", "    elements. The denominator is clamped to machine epsilon so that targets\n", "    equal to zero do not cause a division by zero.\n", "    \"\"\"\n", "    denom = np.maximum(np.abs(ground_truth), np.finfo(np.float64).eps)\n", "    return np.mean(np.abs(ground_truth - predictions) / denom)\n", "\n", "\n", "\n", "from sklearn.model_selection import GridSearchCV\n", "from keras.models import Sequential\n", "from keras.layers import Dense\n", "from keras.wrappers.scikit_learn import KerasRegressor\n", "import scnets as scn\n", "#model = KerasClassifier(build_fn=scn.fullycon, in_size=8, out_size=250, N_gpus=1, epochs=500, verbose=0)\n", "\n", "model = KerasRegressor(build_fn=scn.conv1dmodel, \n", " in_size=8, \n", " out_size=256, \n", " c1_nf=64,\n", " clayers=3,\n", " ker_size=3,\n", " epochs=250, \n", " verbose=0)\n", "# NOTE: my_score wraps the relative-error function as a lower-is-better scorer,\n", "# but the GridSearchCV call in this cell is configured with scoring='explained_variance'\n", "my_score = make_scorer(my_custom_score_func, greater_is_better=False)\n", "\n", "\n", "\n", "param_grid = dict(ker_size=[3, 5],\n", " clayers=[3,4,5],\n", " 
batch_size=[32,64]) \n", "grid = GridSearchCV(estimator=model, \n", " param_grid=param_grid, \n", " n_jobs=1, \n", " scoring='explained_variance',\n", " verbose=1)\n", "grid_result = grid.fit(x_train, y_train)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T14:25:50.558503Z", "start_time": "2018-09-04T14:25:50.539227Z" } }, "outputs": [], "source": [ "grid_result.cv_results_" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T14:26:16.796674Z", "start_time": "2018-09-04T14:26:16.790511Z" } }, "outputs": [], "source": [ "grid_result.best_params_" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-09-04T15:01:18.552044Z", "start_time": "2018-09-04T15:01:18.538960Z" } }, "outputs": [], "source": [ "# candidate indices ordered from best to worst mean test score\n", "# (scikit-learn scorers follow the convention that higher is better)\n", "bestidx = np.argsort(grid_result.cv_results_['mean_test_score'])[::-1]\n", "print(bestidx)\n", "# same indices in worst-to-best order; axis is given explicitly for older NumPy\n", "print(np.flip(bestidx, axis=0))\n", "parlist = grid_result.cv_results_['params']\n", "# parameter dictionaries of all candidates, best first\n", "bestlist = [parlist[indx] for indx in bestidx]\n", "bestlist\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "'mean_fit_time': array([ 440.59602499, 481.73871398, 477.10135643, 504.1547633 ,\n", " 488.27373465, 564.76856009, 571.80046884, 569.00620119,\n", " 650.78964178, 740.00103672, 597.98015889, 659.03726633,\n", " 476.14166268, 508.75317804, 505.10563159, 556.87875859,\n", " 556.58007542, 542.48226706, 591.323385 , 616.57070398,\n", " 658.64879243, 891.65342418, 1055.23733377, 993.75831501,\n", " 647.80959813, 629.8490928 , 516.74331371, 787.65950712,\n", " 635.2311732 , 545.72071338, 904.7078863 , 685.55618747,\n", " 530.5135781 , 815.14902997, 821.7310555 , 685.34639025,\n", " 240.90511258, 246.83409182, 259.09067996, 262.85167106,\n", " 228.2099692 , 392.25437156, 374.78024157, 452.21288816,\n", " 459.90044618, 495.74129097, 501.12204981, 368.91637723,\n", " 240.11106253, 237.93171946, 261.74841015, 
265.67413298,\n", " 274.72749496, 378.51081745, 436.21683431, 241.44164371,\n", " 274.60629002, 330.47249166, 394.98302094, 318.71901139,\n", " 237.12692181, 248.02535923, 254.61501988, 243.65626915,\n", " 298.34629599, 273.75161084, 318.28928526, 312.33878072,\n", " 309.61013468, 335.33733678, 316.16733519, 284.3205506 ]),\n", "\n", " 'params': [{'batch_size': 32, 'c1_nf': 32, 'clayers': 1, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 1, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 1, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 2, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 2, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 2, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 3, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 3, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 3, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 4, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 4, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 32, 'clayers': 4, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 1, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 1, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 1, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 2, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 2, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 2, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 3, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 3, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 3, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 4, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 4, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 64, 'clayers': 4, 'ker_size': 7},\n", " {'batch_size': 32, 
'c1_nf': 96, 'clayers': 1, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 1, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 1, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 2, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 2, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 2, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 3, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 3, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 3, 'ker_size': 7},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 4, 'ker_size': 3},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 4, 'ker_size': 5},\n", " {'batch_size': 32, 'c1_nf': 96, 'clayers': 4, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 1, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 1, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 1, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 2, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 2, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 2, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 3, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 3, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 3, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 4, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 4, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 32, 'clayers': 4, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 1, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 1, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 1, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 2, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 2, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 2, 
'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 3, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 3, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 3, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 4, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 4, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 64, 'clayers': 4, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 1, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 1, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 1, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 2, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 2, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 2, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 3, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 3, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 3, 'ker_size': 7},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 4, 'ker_size': 3},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 4, 'ker_size': 5},\n", " {'batch_size': 64, 'c1_nf': 96, 'clayers': 4, 'ker_size': 7}],\n", "\n", "\n", " 'mean_test_score': array([0.00235561, 0.00175559, 0.00191115, 0.00255561, 0.00228894,\n", " 0.00233339, 0.00240005, 0.00260006, 0.00262228, 0.00222227,\n", " 0.00222227, 0.00235561, 0.00191115, 0.0021556 , 0.00242228,\n", " 0.00217783, 0.00217783, 0.00233339, 0.00231116, 0.00202227,\n", " 0.00206671, 0.00251117, 0.00237783, 0.00253339, 0.00197782,\n", " 0.00222227, 0.00220005, 0.0019556 , 0.00208894, 0.00240005,\n", " 0.00220005, 0.00204449, 0.00222227, 0.00228894, 0.00231116,\n", " 0.00246672, 0.00200004, 0.00182226, 0.00164448, 0.00197782,\n", " 0.0026445 , 0.00208894, 0.00222227, 0.00248894, 0.00217783,\n", " 0.00208894, 0.00253339, 0.0024445 , 0.00204449, 0.00188893,\n", " 0.00197782, 0.00233339, 0.00231116, 0.00231116, 
0.00233339,\n", " 0.0021556 , 0.00231116, 0.0021556 , 0.00220005, 0.00233339,\n", " 0.0021556 , 0.00220005, 0.00251117, 0.00191115, 0.00197782,\n", " 0.0021556 , 0.00224449, 0.00231116, 0.00228894, 0.00220005,\n", " 0.00235561, 0.00211116]),\n", " 'std_test_score': array([5.45313915e-04, 1.36982222e-04, 5.71830319e-04, 3.09583745e-04,\n", " 4.63034070e-04, 1.88601629e-04, 3.57010463e-04, 2.49388000e-04,\n", " 2.06134688e-04, 4.75566024e-04, 1.74937538e-04, 4.01290372e-04,\n", " 2.06146739e-04, 5.14564491e-04, 3.28118192e-04, 3.09506255e-04,\n", " 3.45776617e-04, 5.65777616e-04, 6.28186970e-05, 2.45451156e-04,\n", " 3.03064138e-04, 4.69393085e-04, 3.70577110e-04, 2.37355340e-04,\n", " 5.14534288e-04, 2.99865023e-04, 1.08929121e-04, 3.28165699e-04,\n", " 1.91227745e-04, 1.44095217e-04, 1.96333991e-04, 1.13377551e-04,\n", " 6.28899342e-05, 4.93994775e-04, 5.05855189e-04, 4.25227642e-04,\n", " 5.25013167e-04, 1.66308696e-04, 1.91169192e-04, 7.89511257e-04,\n", " 3.09609747e-04, 2.20000709e-04, 1.36980546e-04, 2.99872491e-04,\n", " 2.06155046e-04, 3.70481852e-04, 1.88604774e-04, 2.74052629e-04,\n", " 4.15813294e-04, 2.99823569e-04, 2.06148816e-04, 1.44067451e-04,\n", " 3.50039135e-04, 2.79388549e-04, 3.81084580e-04, 3.62491814e-04,\n", " 2.26678399e-04, 4.08631950e-04, 4.35536821e-04, 2.49505205e-04,\n", " 4.53303035e-04, 1.44064707e-04, 5.77012935e-04, 8.31603039e-05,\n", " 5.34338425e-04, 2.79402571e-04, 2.45459743e-04, 1.36996891e-04,\n", " 3.86260180e-04, 1.88498951e-04, 3.09579276e-04, 3.70571422e-04]),\n", " 'rank_test_score': array([16, 71, 66, 4, 30, 19, 13, 3, 2, 34, 34, 16, 66, 47, 12, 44, 44,\n", " 19, 24, 59, 56, 7, 15, 5, 61, 34, 39, 65, 53, 13, 39, 57, 34, 30,\n", " 24, 10, 60, 70, 72, 61, 1, 53, 34, 9, 44, 53, 5, 11, 57, 69, 61,\n", " 19, 24, 24, 19, 47, 24, 47, 39, 19, 47, 39, 7, 66, 61, 47, 33, 24,\n", " 30, 39, 16, 52], dtype=int32),\n", "\n", "\n", " 'mean_train_score': array([0.00231116, 0.00187782, 0.0018556 , 0.00231116, 0.00222227,\n", " 
0.00217783, 0.00228894, 0.00236672, 0.00232227, 0.00252228,\n", " 0.00213338, 0.00241117, 0.00214449, 0.00204449, 0.00235561,\n", " 0.00205561, 0.00217783, 0.00238894, 0.00213338, 0.00230005,\n", " 0.0021556 , 0.0023445 , 0.00238894, 0.00236672, 0.00203338,\n", " 0.00208894, 0.00232228, 0.00203338, 0.00258895, 0.00216671,\n", " 0.00221117, 0.00217783, 0.0022556 , 0.00224449, 0.00246672,\n", " 0.00240006, 0.00187782, 0.00183338, 0.00153337, 0.00223338,\n", " 0.00186671, 0.00204449, 0.00218894, 0.00227783, 0.00203338,\n", " 0.00257784, 0.00235561, 0.0023445 , 0.00200004, 0.00210005,\n", " 0.00204449, 0.00247784, 0.00233338, 0.00230005, 0.00223338,\n", " 0.00222227, 0.00228894, 0.0023445 , 0.00228894, 0.00248895,\n", " 0.00188893, 0.00211116, 0.00233339, 0.00236672, 0.00213338,\n", " 0.00204449, 0.00248895, 0.00238895, 0.00233338, 0.00231116,\n", " 0.0023445 , 0.00241116]),\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "{'mean_fit_time': array([ 893.33654523, 898.53078222, 1119.14394736, 1130.1775128 ,\n", " 956.51246222, 964.45365715, 1209.3984166 , 1582.91039157,\n", " 1394.26560704, 1616.26630108, 1266.47200227, 1116.83488099,\n", " 1205.42738708, 1201.92515103, 1210.92550143]),\n", " 'std_fit_time': array([ 3.08891285, 6.81113186, 212.20371026, 238.93357922,\n", " 6.97112622, 18.87349827, 223.00445851, 57.7875855 ,\n", " 100.70476936, 43.95356933, 134.83849082, 11.29690679,\n", " 5.42330543, 6.19267952, 4.92641743]),\n", " 'params': [{'N_hidden': 1, 'N_neurons': 250},\n", " {'N_hidden': 1, 'N_neurons': 500},\n", " {'N_hidden': 1, 'N_neurons': 1000},\n", " {'N_hidden': 2, 'N_neurons': 250},\n", " {'N_hidden': 2, 'N_neurons': 500},\n", " {'N_hidden': 2, 'N_neurons': 1000},\n", " {'N_hidden': 3, 'N_neurons': 250},\n", " {'N_hidden': 3, 'N_neurons': 500},\n", " {'N_hidden': 3, 'N_neurons': 1000},\n", " {'N_hidden': 4, 'N_neurons': 250},\n", " {'N_hidden': 4, 'N_neurons': 500},\n", " {'N_hidden': 4, 
'N_neurons': 1000},\n", " {'N_hidden': 5, 'N_neurons': 250},\n", " {'N_hidden': 5, 'N_neurons': 500},\n", " {'N_hidden': 5, 'N_neurons': 1000}],\n", " 'split0_test_score': array([0.00235, 0.00225, 0.0025 , 0.0024 , 0.0026 , 0.0021 , 0.00165,\n", " 0.0023 , 0.00255, 0.00255, 0.0024 , 0.0027 , 0.0021 , 0.00225,\n", " 0.0022 ]),\n", " 'split1_test_score': array([0.0024 , 0.00225, 0.0022 , 0.0022 , 0.00235, 0.0022 , 0.0021 ,\n", " 0.00195, 0.00215, 0.0022 , 0.00185, 0.00195, 0.002 , 0.00195,\n", " 0.0021 ]),\n", " 'split2_test_score': array([0.00255, 0.00315, 0.00275, 0.00295, 0.00355, 0.0032 , 0.00275,\n", " 0.0031 , 0.0028 , 0.00305, 0.00305, 0.00355, 0.00275, 0.00285,\n", " 0.00305]),\n", " 'mean_test_score': array([0.00243333, 0.00255 , 0.00248333, 0.00251667, 0.00283333,\n", " 0.0025 , 0.00216667, 0.00245 , 0.0025 , 0.0026 ,\n", " 0.00243333, 0.00273333, 0.00228333, 0.00235 , 0.00245 ]),\n", " 'std_test_score': array([8.49836586e-05, 4.24264069e-04, 2.24845626e-04, 3.17104960e-04,\n", " 5.16935414e-04, 4.96655481e-04, 4.51540573e-04, 4.81317636e-04,\n", " 2.67706307e-04, 3.48807492e-04, 4.90464632e-04, 6.53622385e-04,\n", " 3.32498956e-04, 3.74165739e-04, 4.26223728e-04]),\n", " 'rank_test_score': array([11, 4, 8, 5, 1, 6, 15, 9, 6, 3, 11, 2, 14, 13, 9],\n", " dtype=int32),\n", " 'split0_train_score': array([0.00255 , 0.0024 , 0.0023 , 0.00255 , 0.00255 , 0.002375,\n", " 0.002125, 0.002625, 0.0025 , 0.002775, 0.002625, 0.0027 ,\n", " 0.0023 , 0.00245 , 0.002275]),\n", " 'split1_train_score': array([0.002975, 0.0025 , 0.00255 , 0.0028 , 0.0029 , 0.002475,\n", " 0.002575, 0.0027 , 0.002575, 0.002375, 0.002475, 0.002475,\n", " 0.002325, 0.002425, 0.002375]),\n", " 'split2_train_score': array([0.00215 , 0.0022 , 0.00215 , 0.002375, 0.00205 , 0.002175,\n", " 0.00215 , 0.0022 , 0.002 , 0.00235 , 0.00235 , 0.002525,\n", " 0.00215 , 0.002175, 0.002175]),\n", " 'mean_train_score': array([0.00255833, 0.00236667, 0.00233333, 0.002575 , 0.0025 ,\n", " 0.00234167, 0.00228333, 
0.00250833, 0.00235833, 0.0025 ,\n", " 0.00248333, 0.00256667, 0.00225833, 0.00235 , 0.002275 ]),\n", " 'std_train_score': array([3.36856382e-04, 1.24721913e-04, 1.64991582e-04, 1.74403746e-04,\n", " 3.48807492e-04, 1.24721913e-04, 2.06491862e-04, 2.20164080e-04,\n", " 2.55223214e-04, 1.94722024e-04, 1.12422813e-04, 9.64653075e-05,\n", " 7.72801541e-05, 1.24163870e-04, 8.16496581e-05])}\n", " \n", " \n", " {'mean_fit_time': array([ 685.01906315, 1809.28454868, 336.60541034, 878.23016135]),\n", " 'mean_score_time': array([ 1.38006322, 1.27389534, 0.6934317 , 0.69225407]),\n", " 'mean_test_score': array([ 0.00241667, 0.00251667, 0.00243333, 0.00261667]),\n", " 'mean_train_score': array([ 0.00245833, 0.00236667, 0.00248333, 0.00253333]),\n", " 'param_batch_size': masked_array(data = [32 32 64 64],\n", " mask = [False False False False],\n", " fill_value = ?),\n", " 'param_epochs': masked_array(data = [200 500 200 500],\n", " mask = [False False False False],\n", " fill_value = ?),\n", " 'params': ({'batch_size': 32, 'epochs': 200},\n", " {'batch_size': 32, 'epochs': 500},\n", " {'batch_size': 64, 'epochs': 200},\n", " {'batch_size': 64, 'epochs': 500}),\n", " 'rank_test_score': array([4, 2, 3, 1], dtype=int32),\n", " 'split0_test_score': array([ 0.0021 , 0.00225, 0.00215, 0.00225]),\n", " 'split0_train_score': array([ 0.00235 , 0.0023 , 0.002625, 0.002575]),\n", " 'split1_test_score': array([ 0.00225, 0.00225, 0.00215, 0.00235]),\n", " 'split1_train_score': array([ 0.002675, 0.002725, 0.002675, 0.002825]),\n", " 'split2_test_score': array([ 0.0029 , 0.00305, 0.003 , 0.00325]),\n", " 'split2_train_score': array([ 0.00235 , 0.002075, 0.00215 , 0.0022 ]),\n", " 'std_fit_time': array([ 27.85582158, 121.41697465, 1.58335506, 11.64839192]),\n", " 'std_score_time': array([ 0.01602076, 0.06291871, 0.03384719, 0.05541393]),\n", " 'std_test_score': array([ 0.00034721, 0.00037712, 0.00040069, 0.00044969]),\n", " 'std_train_score': array([ 0.00015321, 0.00026952, 0.00023658, 
0.00025685])}\n", " \n", " \n", " 'mean_fit_time': array([1236.77363722, 1263.8373781 , 1283.07971772, 617.23694984,\n", " 644.64875857, 630.75466394]),\n", " 'std_fit_time': array([15.23634435, 1.04774932, 70.7173362 , 19.87266061, 3.13235316,\n", " 19.13357172]),\n", " 'mean_score_time': array([1.9509182 , 2.09144211, 2.07234033, 1.05850196, 1.09700545,\n", " 1.07024908]),\n", " 'std_score_time': array([0.09494565, 0.09207867, 0.09335411, 0.04954752, 0.04209056,\n", " 0.05320864]),\n", " 'param_batch_size': masked_array(data=[32, 32, 32, 64, 64, 64],\n", " mask=[False, False, False, False, False, False],\n", " fill_value='?',\n", " dtype=object),\n", " 'param_ker_size': masked_array(data=[3, 5, 7, 3, 5, 7],\n", " mask=[False, False, False, False, False, False],\n", " fill_value='?',\n", " dtype=object),\n", " 'params': [{'batch_size': 32, 'ker_size': 3},\n", " {'batch_size': 32, 'ker_size': 5},\n", " {'batch_size': 32, 'ker_size': 7},\n", " {'batch_size': 64, 'ker_size': 3},\n", " {'batch_size': 64, 'ker_size': 5},\n", " {'batch_size': 64, 'ker_size': 7}],\n", " 'split0_test_score': array([0.00225, 0.0017 , 0.0027 , 0.00225, 0.00205, 0.0024 ]),\n", " 'split1_test_score': array([0.0021 , 0.0023 , 0.0023 , 0.0019 , 0.002 , 0.00165]),\n", " 'split2_test_score': array([0.00325, 0.00215, 0.0021 , 0.00245, 0.0024 , 0.0024 ]),\n", " 'mean_test_score': array([0.00253333, 0.00205 , 0.00236667, 0.0022 , 0.00215 ,\n", " 0.00215 ]),\n", " 'std_test_score': array([0.00051045, 0.00025495, 0.00024944, 0.0002273 , 0.00017795,\n", " 0.00035355]),\n", " 'rank_test_score': array([1, 6, 2, 3, 4, 4], dtype=int32),\n", " 'split0_train_score': array([0.001875, 0.001625, 0.002525, 0.002175, 0.0019 , 0.002275]),\n", " 'split1_train_score': array([0.0027 , 0.00275 , 0.00245 , 0.0022 , 0.002125, 0.002475]),\n", " 'split2_train_score': array([0.0024 , 0.00205 , 0.0019 , 0.0021 , 0.00225 , 0.001775]),\n", " 'mean_train_score': array([0.002325 , 0.00214167, 0.00229167, 0.00215833, 
0.00209167,\n", " 0.002175 ]),\n", " 'std_train_score': array([3.40954542e-04, 4.63830668e-04, 2.78637558e-04, 4.24918293e-05,\n", " 1.44817893e-04, 2.94392029e-04])}\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }