{ "cells": [ { "cell_type": "code", "execution_count": 196, "metadata": { "ExecuteTime": { "end_time": "2018-08-24T15:02:25.827112Z", "start_time": "2018-08-24T15:02:25.721467Z" } }, "outputs": [], "source": [ "import mxnet as mx\n", "from mxnet import nd, autograd, gluon\n", "mx.random.seed(1)\n", "import snlay as snlay\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "lam_min = 300\n", "lam_max = 1200\n", "size_min = 30\n", "size_max = 70\n", "num_lpoints = 250\n", "\n", "\n", "lams = np.linspace(lam_min, lam_max, num_lpoints)\n", "mats=[3,4,3,4,3,4,3,4]\n", "sizes=np.array([35, 45, 35, 45, 35, 45, 35, 45])\n", "\n", "sz = nd.array(sizes).reshape((8,1))\n", "\n", "spec_calc = snlay.calc_spectrum(sizes, mats, lams)\n", "#plt.plot(lams, spec_calc)\n", "\n", "\n", "\n", "params = [sz]\n", "\n", "\n", "for param in params:\n", " param.attach_grad()\n", "params\n", "\n", "\n", "\n", "\n", "def net2(X):\n", " return nd.abs(nd.sum(sz) - 16)\n", " \n", "def square_loss2(yhat, y):\n", " return nd.abs(yhat - y)\n", "\n", "def SGD(params, lr):\n", " for param in params:\n", " param[:] = param - lr * param.grad\n", "\n", "epochs = 10\n", "learning_rate = .01\n", "\n", "\n", "\n", "\n", "for ind in range(200):\n", " #print(ind)\n", " with autograd.record():\n", " loss = net2(ind)\n", " loss.backward()\n", " SGD(params, learning_rate)\n", " \n", "\n", "sz2 = np.array(sz)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }