{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layer-size gradient check with MXNet autograd\n",
    "\n",
    "This notebook does two things:\n",
    "\n",
    "1. Computes a reference spectrum with `snlay.calc_spectrum` for the configured material ids `mats` and sizes `sizes` over `num_lpoints` evenly spaced `lams` values (presumably wavelengths; units are set by `snlay`, so confirm there).\n",
    "2. Checks that MXNet autograd plus a hand-written SGD step can move the sizes along the gradient of a toy objective, $|\\sum_i s_i - \\text{target}|$.\n",
    "\n",
    "The toy objective does not use the spectrum. It only exercises the gradient/update machinery."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import mxnet as mx\n",
    "from mxnet import nd, autograd, gluon\n",
    "\n",
    "import snlay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed MXNet's RNG so any stochastic op is reproducible.\n",
    "mx.random.seed(1)\n",
    "\n",
    "# Grid passed to snlay.calc_spectrum (presumably wavelengths; confirm units in snlay).\n",
    "lam_min = 300\n",
    "lam_max = 1200\n",
    "num_lpoints = 250\n",
    "\n",
    "# Size bounds. NOTE(review): not used anywhere below.\n",
    "size_min = 30\n",
    "size_max = 70\n",
    "\n",
    "# Material ids and sizes passed together to snlay.calc_spectrum,\n",
    "# assumed to be paired element-wise (one entry per layer).\n",
    "mats = [3, 4, 3, 4, 3, 4, 3, 4]\n",
    "sizes = np.array([35, 45, 35, 45, 35, 45, 35, 45])\n",
    "assert len(mats) == len(sizes), 'mats and sizes must have the same length'\n",
    "\n",
    "# Toy optimization settings.\n",
    "target_sum = 16  # the toy loss drives sum(sz) toward this value\n",
    "learning_rate = 0.01\n",
    "num_steps = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reference spectrum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lams = np.linspace(lam_min, lam_max, num_lpoints)\n",
    "spec_calc = snlay.calc_spectrum(sizes, mats, lams)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Toy optimization of the sizes\n",
    "\n",
    "- `net2` ignores its input and returns $|\\sum_i s_i - \\text{target}|$ for the current sizes `sz`.\n",
    "- `SGD` applies one in-place gradient step.\n",
    "\n",
    "While the sum stays above the target, the gradient with respect to every size is 1, so each step lowers every size by `learning_rate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def net2(X):\n",
    "    \"\"\"Return the toy loss |sum(sz) - target_sum| as an NDArray.\n",
    "\n",
    "    X is ignored (kept for a model-like call signature). The loss reads the\n",
    "    global `sz` and `target_sum`, so call it only after both exist.\n",
    "    \"\"\"\n",
    "    return nd.abs(nd.sum(sz) - target_sum)\n",
    "\n",
    "\n",
    "def square_loss2(yhat, y):\n",
    "    \"\"\"Return the element-wise absolute error |yhat - y|.\n",
    "\n",
    "    NOTE: despite its name this is an L1 loss, not a squared loss.\n",
    "    It is not used anywhere in this notebook.\n",
    "    \"\"\"\n",
    "    return nd.abs(yhat - y)\n",
    "\n",
    "\n",
    "def SGD(params, lr):\n",
    "    \"\"\"Apply one plain gradient-descent step in place: p <- p - lr * p.grad.\n",
    "\n",
    "    Assigning through `param[:]` updates the arrays registered with\n",
    "    attach_grad() instead of rebinding the loop variable.\n",
    "    \"\"\"\n",
    "    for param in params:\n",
    "        param[:] = param - lr * param.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Re)create the trainable sizes here so that re-running this cell restarts\n",
    "# the optimization instead of continuing from the previous result.\n",
    "sz = nd.array(sizes).reshape((-1, 1))  # column vector, one row per entry of `sizes`\n",
    "params = [sz]\n",
    "for param in params:\n",
    "    param.attach_grad()\n",
    "\n",
    "initial_loss = net2(None).asscalar()\n",
    "for ind in range(num_steps):\n",
    "    with autograd.record():\n",
    "        loss = net2(ind)\n",
    "    loss.backward()\n",
    "    SGD(params, learning_rate)\n",
    "final_loss = net2(None).asscalar()\n",
    "assert final_loss < initial_loss, 'SGD did not reduce the toy loss'\n",
    "\n",
    "# asnumpy() is the documented NDArray -> NumPy conversion (blocks until the updates finish).\n",
    "sz2 = sz.asnumpy()\n",
    "sz2.ravel()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}