Lecture 26, RMSProp optimizer
This commit is contained in:
parent
2192bf3050
commit
036d06a652
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@ -221,7 +221,7 @@ class Optimizer_SGD():
|
||||
self.iterations += 1
|
||||
|
||||
# %% [markdown]
|
||||
# # Testing the Learning Rate Decay
|
||||
# ## Testing the Learning Rate Decay
|
||||
|
||||
# %%
|
||||
# Create dataset
|
||||
@ -270,7 +270,8 @@ for epoch in range(10001):
|
||||
if not epoch % 100:
|
||||
print(f'epoch: {epoch}, ' +
|
||||
f'acc: {accuracy:.3f}, ' +
|
||||
f'loss: {loss:.3f}')
|
||||
f'loss: {loss:.3f}, ' +
|
||||
f'lr: {optimizer.current_learning_rate}')
|
||||
|
||||
# Backward pass
|
||||
loss_activation.backward(loss_activation.output, y)
|
||||
@ -335,7 +336,7 @@ class Optimizer_SGD():
|
||||
self.iterations += 1
|
||||
|
||||
# %% [markdown]
|
||||
# # Testing the Gradient Optimizer with Momentum
|
||||
# ## Testing the Gradient Optimizer with Momentum
|
||||
|
||||
# %%
|
||||
# Create dataset
|
||||
@ -384,7 +385,8 @@ for epoch in range(10001):
|
||||
if not epoch % 100:
|
||||
print(f'epoch: {epoch}, ' +
|
||||
f'acc: {accuracy:.3f}, ' +
|
||||
f'loss: {loss:.3f}')
|
||||
f'loss: {loss:.3f}, ' +
|
||||
f'lr: {optimizer.current_learning_rate}')
|
||||
|
||||
# Backward pass
|
||||
loss_activation.backward(loss_activation.output, y)
|
||||
|
||||
@ -254,12 +254,14 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AdaGrad Optimizer\n",
|
||||
"Different weights should have different learning rates. If one weight affects the loss much more strongly than the other, then consider using smaller learning rates with it. We can do this by maintaining a \"cache\" of the last gradients and normalizing based on this."
|
||||
"Different weights should have different learning rates. If one weight affects the loss much more strongly than the other, then consider using smaller learning rates with it. We can do this by maintaining a \"cache\" of the last gradients and normalizing based on this.\n",
|
||||
"\n",
|
||||
"A downside is that as the cache keeps accumulating, some neurons will have such a small learning rate that the neuron basically becomes fixed."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -269,7 +271,7 @@
|
||||
" self.current_learning_rate = self.initial_learning_rate\n",
|
||||
" self.decay = decay\n",
|
||||
" self.iterations = 0\n",
|
||||
" self.epsilon = 0\n",
|
||||
" self.epsilon = epsilon\n",
|
||||
"\n",
|
||||
" def pre_update_params(self):\n",
|
||||
" if self.decay:\n",
|
||||
@ -299,114 +301,114 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch: 0, acc: 0.360, loss: 1.099\n",
|
||||
"epoch: 100, acc: 0.477, loss: 0.992\n",
|
||||
"epoch: 200, acc: 0.560, loss: 0.934\n",
|
||||
"epoch: 300, acc: 0.600, loss: 0.881\n",
|
||||
"epoch: 400, acc: 0.637, loss: 0.826\n",
|
||||
"epoch: 500, acc: 0.617, loss: 0.793\n",
|
||||
"epoch: 600, acc: 0.643, loss: 0.748\n",
|
||||
"epoch: 700, acc: 0.657, loss: 0.726\n",
|
||||
"epoch: 800, acc: 0.663, loss: 0.698\n",
|
||||
"epoch: 900, acc: 0.680, loss: 0.681\n",
|
||||
"epoch: 1000, acc: 0.683, loss: 0.664\n",
|
||||
"epoch: 1100, acc: 0.693, loss: 0.653\n",
|
||||
"epoch: 1200, acc: 0.693, loss: 0.642\n",
|
||||
"epoch: 1300, acc: 0.707, loss: 0.632\n",
|
||||
"epoch: 1400, acc: 0.720, loss: 0.625\n",
|
||||
"epoch: 1500, acc: 0.717, loss: 0.615\n",
|
||||
"epoch: 1600, acc: 0.723, loss: 0.608\n",
|
||||
"epoch: 1700, acc: 0.730, loss: 0.600\n",
|
||||
"epoch: 1800, acc: 0.740, loss: 0.588\n",
|
||||
"epoch: 1900, acc: 0.740, loss: 0.581\n",
|
||||
"epoch: 2000, acc: 0.743, loss: 0.576\n",
|
||||
"epoch: 2100, acc: 0.740, loss: 0.570\n",
|
||||
"epoch: 2200, acc: 0.750, loss: 0.563\n",
|
||||
"epoch: 2300, acc: 0.747, loss: 0.562\n",
|
||||
"epoch: 2400, acc: 0.757, loss: 0.557\n",
|
||||
"epoch: 2500, acc: 0.760, loss: 0.554\n",
|
||||
"epoch: 2600, acc: 0.770, loss: 0.550\n",
|
||||
"epoch: 2700, acc: 0.777, loss: 0.546\n",
|
||||
"epoch: 2800, acc: 0.777, loss: 0.543\n",
|
||||
"epoch: 2900, acc: 0.780, loss: 0.540\n",
|
||||
"epoch: 3000, acc: 0.777, loss: 0.537\n",
|
||||
"epoch: 3100, acc: 0.773, loss: 0.533\n",
|
||||
"epoch: 3200, acc: 0.783, loss: 0.531\n",
|
||||
"epoch: 3300, acc: 0.777, loss: 0.528\n",
|
||||
"epoch: 3400, acc: 0.777, loss: 0.526\n",
|
||||
"epoch: 3500, acc: 0.773, loss: 0.523\n",
|
||||
"epoch: 3600, acc: 0.780, loss: 0.522\n",
|
||||
"epoch: 3700, acc: 0.773, loss: 0.520\n",
|
||||
"epoch: 3800, acc: 0.777, loss: 0.518\n",
|
||||
"epoch: 3900, acc: 0.780, loss: 0.516\n",
|
||||
"epoch: 4000, acc: 0.780, loss: 0.515\n",
|
||||
"epoch: 4100, acc: 0.773, loss: 0.513\n",
|
||||
"epoch: 4200, acc: 0.770, loss: 0.511\n",
|
||||
"epoch: 4300, acc: 0.773, loss: 0.510\n",
|
||||
"epoch: 4400, acc: 0.777, loss: 0.509\n",
|
||||
"epoch: 4500, acc: 0.777, loss: 0.508\n",
|
||||
"epoch: 4600, acc: 0.780, loss: 0.507\n",
|
||||
"epoch: 4700, acc: 0.777, loss: 0.506\n",
|
||||
"epoch: 4800, acc: 0.777, loss: 0.504\n",
|
||||
"epoch: 4900, acc: 0.783, loss: 0.503\n",
|
||||
"epoch: 5000, acc: 0.783, loss: 0.503\n",
|
||||
"epoch: 5100, acc: 0.787, loss: 0.502\n",
|
||||
"epoch: 5200, acc: 0.790, loss: 0.501\n",
|
||||
"epoch: 5300, acc: 0.787, loss: 0.500\n",
|
||||
"epoch: 5400, acc: 0.787, loss: 0.498\n",
|
||||
"epoch: 5500, acc: 0.783, loss: 0.497\n",
|
||||
"epoch: 5600, acc: 0.787, loss: 0.496\n",
|
||||
"epoch: 5700, acc: 0.780, loss: 0.495\n",
|
||||
"epoch: 5800, acc: 0.780, loss: 0.495\n",
|
||||
"epoch: 5900, acc: 0.783, loss: 0.494\n",
|
||||
"epoch: 6000, acc: 0.790, loss: 0.494\n",
|
||||
"epoch: 6100, acc: 0.777, loss: 0.493\n",
|
||||
"epoch: 6200, acc: 0.783, loss: 0.492\n",
|
||||
"epoch: 6300, acc: 0.783, loss: 0.491\n",
|
||||
"epoch: 6400, acc: 0.790, loss: 0.490\n",
|
||||
"epoch: 6500, acc: 0.780, loss: 0.488\n",
|
||||
"epoch: 6600, acc: 0.780, loss: 0.487\n",
|
||||
"epoch: 6700, acc: 0.777, loss: 0.485\n",
|
||||
"epoch: 6800, acc: 0.780, loss: 0.483\n",
|
||||
"epoch: 6900, acc: 0.783, loss: 0.482\n",
|
||||
"epoch: 7000, acc: 0.790, loss: 0.480\n",
|
||||
"epoch: 7100, acc: 0.790, loss: 0.479\n",
|
||||
"epoch: 7200, acc: 0.797, loss: 0.477\n",
|
||||
"epoch: 7300, acc: 0.803, loss: 0.476\n",
|
||||
"epoch: 7400, acc: 0.813, loss: 0.475\n",
|
||||
"epoch: 7500, acc: 0.813, loss: 0.474\n",
|
||||
"epoch: 7600, acc: 0.813, loss: 0.472\n",
|
||||
"epoch: 7700, acc: 0.813, loss: 0.471\n",
|
||||
"epoch: 7800, acc: 0.810, loss: 0.470\n",
|
||||
"epoch: 7900, acc: 0.810, loss: 0.469\n",
|
||||
"epoch: 8000, acc: 0.810, loss: 0.468\n",
|
||||
"epoch: 8100, acc: 0.810, loss: 0.465\n",
|
||||
"epoch: 8200, acc: 0.807, loss: 0.463\n",
|
||||
"epoch: 8300, acc: 0.803, loss: 0.462\n",
|
||||
"epoch: 8400, acc: 0.803, loss: 0.461\n",
|
||||
"epoch: 8500, acc: 0.807, loss: 0.459\n",
|
||||
"epoch: 8600, acc: 0.810, loss: 0.458\n",
|
||||
"epoch: 8700, acc: 0.813, loss: 0.458\n",
|
||||
"epoch: 8800, acc: 0.810, loss: 0.456\n",
|
||||
"epoch: 8900, acc: 0.810, loss: 0.455\n",
|
||||
"epoch: 9000, acc: 0.813, loss: 0.452\n",
|
||||
"epoch: 9100, acc: 0.813, loss: 0.450\n",
|
||||
"epoch: 9200, acc: 0.817, loss: 0.448\n",
|
||||
"epoch: 9300, acc: 0.810, loss: 0.447\n",
|
||||
"epoch: 9400, acc: 0.810, loss: 0.446\n",
|
||||
"epoch: 9500, acc: 0.813, loss: 0.444\n",
|
||||
"epoch: 9600, acc: 0.813, loss: 0.441\n",
|
||||
"epoch: 9700, acc: 0.817, loss: 0.440\n",
|
||||
"epoch: 9800, acc: 0.817, loss: 0.438\n",
|
||||
"epoch: 9900, acc: 0.813, loss: 0.436\n",
|
||||
"epoch: 10000, acc: 0.813, loss: 0.435\n"
|
||||
"epoch: 0, acc: 0.353, loss: 1.099, lr: 1.0\n",
|
||||
"epoch: 100, acc: 0.497, loss: 0.986, lr: 0.9901970492127933\n",
|
||||
"epoch: 200, acc: 0.527, loss: 0.936, lr: 0.9804882831650161\n",
|
||||
"epoch: 300, acc: 0.513, loss: 0.918, lr: 0.9709680551509855\n",
|
||||
"epoch: 400, acc: 0.580, loss: 0.904, lr: 0.9616309260505818\n",
|
||||
"epoch: 500, acc: 0.550, loss: 0.910, lr: 0.9524716639679969\n",
|
||||
"epoch: 600, acc: 0.563, loss: 0.860, lr: 0.9434852344560807\n",
|
||||
"epoch: 700, acc: 0.600, loss: 0.838, lr: 0.9346667912889054\n",
|
||||
"epoch: 800, acc: 0.603, loss: 0.815, lr: 0.9260116677470135\n",
|
||||
"epoch: 900, acc: 0.643, loss: 0.787, lr: 0.9175153683824203\n",
|
||||
"epoch: 1000, acc: 0.617, loss: 0.810, lr: 0.9091735612328392\n",
|
||||
"epoch: 1100, acc: 0.637, loss: 0.750, lr: 0.9009820704567978\n",
|
||||
"epoch: 1200, acc: 0.677, loss: 0.744, lr: 0.892936869363336\n",
|
||||
"epoch: 1300, acc: 0.670, loss: 0.722, lr: 0.8850340738118416\n",
|
||||
"epoch: 1400, acc: 0.693, loss: 0.704, lr: 0.8772699359592947\n",
|
||||
"epoch: 1500, acc: 0.680, loss: 0.708, lr: 0.8696408383337683\n",
|
||||
"epoch: 1600, acc: 0.697, loss: 0.664, lr: 0.8621432882145013\n",
|
||||
"epoch: 1700, acc: 0.650, loss: 0.699, lr: 0.8547739123001966\n",
|
||||
"epoch: 1800, acc: 0.707, loss: 0.646, lr: 0.8475294516484448\n",
|
||||
"epoch: 1900, acc: 0.693, loss: 0.633, lr: 0.8404067568703253\n",
|
||||
"epoch: 2000, acc: 0.713, loss: 0.620, lr: 0.8334027835652972\n",
|
||||
"epoch: 2100, acc: 0.710, loss: 0.610, lr: 0.8265145879824779\n",
|
||||
"epoch: 2200, acc: 0.710, loss: 0.599, lr: 0.8197393228953193\n",
|
||||
"epoch: 2300, acc: 0.713, loss: 0.588, lr: 0.8130742336775347\n",
|
||||
"epoch: 2400, acc: 0.733, loss: 0.576, lr: 0.8065166545689169\n",
|
||||
"epoch: 2500, acc: 0.763, loss: 0.579, lr: 0.8000640051204096\n",
|
||||
"epoch: 2600, acc: 0.800, loss: 0.558, lr: 0.7937137868084768\n",
|
||||
"epoch: 2700, acc: 0.803, loss: 0.551, lr: 0.7874635798094338\n",
|
||||
"epoch: 2800, acc: 0.800, loss: 0.545, lr: 0.7813110399249941\n",
|
||||
"epoch: 2900, acc: 0.807, loss: 0.541, lr: 0.7752538956508256\n",
|
||||
"epoch: 3000, acc: 0.757, loss: 0.538, lr: 0.7692899453804138\n",
|
||||
"epoch: 3100, acc: 0.757, loss: 0.532, lr: 0.7634170547370028\n",
|
||||
"epoch: 3200, acc: 0.740, loss: 0.532, lr: 0.7576331540268202\n",
|
||||
"epoch: 3300, acc: 0.763, loss: 0.514, lr: 0.7519362358072035\n",
|
||||
"epoch: 3400, acc: 0.770, loss: 0.512, lr: 0.7463243525636241\n",
|
||||
"epoch: 3500, acc: 0.757, loss: 0.507, lr: 0.7407956144899621\n",
|
||||
"epoch: 3600, acc: 0.780, loss: 0.496, lr: 0.735348187366718\n",
|
||||
"epoch: 3700, acc: 0.787, loss: 0.497, lr: 0.7299802905321557\n",
|
||||
"epoch: 3800, acc: 0.767, loss: 0.491, lr: 0.7246901949416624\n",
|
||||
"epoch: 3900, acc: 0.783, loss: 0.483, lr: 0.7194762213108857\n",
|
||||
"epoch: 4000, acc: 0.790, loss: 0.484, lr: 0.7143367383384527\n",
|
||||
"epoch: 4100, acc: 0.783, loss: 0.477, lr: 0.7092701610043266\n",
|
||||
"epoch: 4200, acc: 0.780, loss: 0.487, lr: 0.7042749489400663\n",
|
||||
"epoch: 4300, acc: 0.787, loss: 0.467, lr: 0.6993496048674733\n",
|
||||
"epoch: 4400, acc: 0.793, loss: 0.465, lr: 0.6944926731022988\n",
|
||||
"epoch: 4500, acc: 0.787, loss: 0.460, lr: 0.6897027381198704\n",
|
||||
"epoch: 4600, acc: 0.777, loss: 0.477, lr: 0.6849784231796698\n",
|
||||
"epoch: 4700, acc: 0.783, loss: 0.454, lr: 0.6803183890060548\n",
|
||||
"epoch: 4800, acc: 0.800, loss: 0.448, lr: 0.6757213325224677\n",
|
||||
"epoch: 4900, acc: 0.800, loss: 0.441, lr: 0.6711859856366199\n",
|
||||
"epoch: 5000, acc: 0.797, loss: 0.437, lr: 0.6667111140742716\n",
|
||||
"epoch: 5100, acc: 0.800, loss: 0.433, lr: 0.6622955162593549\n",
|
||||
"epoch: 5200, acc: 0.813, loss: 0.429, lr: 0.6579380222383051\n",
|
||||
"epoch: 5300, acc: 0.810, loss: 0.426, lr: 0.6536374926465782\n",
|
||||
"epoch: 5400, acc: 0.813, loss: 0.423, lr: 0.649392817715436\n",
|
||||
"epoch: 5500, acc: 0.823, loss: 0.420, lr: 0.6452029163171817\n",
|
||||
"epoch: 5600, acc: 0.820, loss: 0.416, lr: 0.6410667350471184\n",
|
||||
"epoch: 5700, acc: 0.820, loss: 0.413, lr: 0.6369832473405949\n",
|
||||
"epoch: 5800, acc: 0.820, loss: 0.410, lr: 0.6329514526235838\n",
|
||||
"epoch: 5900, acc: 0.820, loss: 0.407, lr: 0.6289703754953141\n",
|
||||
"epoch: 6000, acc: 0.823, loss: 0.404, lr: 0.6250390649415589\n",
|
||||
"epoch: 6100, acc: 0.823, loss: 0.401, lr: 0.6211565935772407\n",
|
||||
"epoch: 6200, acc: 0.827, loss: 0.398, lr: 0.6173220569170937\n",
|
||||
"epoch: 6300, acc: 0.833, loss: 0.395, lr: 0.6135345726731701\n",
|
||||
"epoch: 6400, acc: 0.830, loss: 0.390, lr: 0.6097932800780536\n",
|
||||
"epoch: 6500, acc: 0.827, loss: 0.387, lr: 0.6060973392326807\n",
|
||||
"epoch: 6600, acc: 0.827, loss: 0.384, lr: 0.6024459304777396\n",
|
||||
"epoch: 6700, acc: 0.830, loss: 0.381, lr: 0.5988382537876519\n",
|
||||
"epoch: 6800, acc: 0.830, loss: 0.378, lr: 0.5952735281862016\n",
|
||||
"epoch: 6900, acc: 0.830, loss: 0.375, lr: 0.5917509911829102\n",
|
||||
"epoch: 7000, acc: 0.833, loss: 0.373, lr: 0.5882698982293076\n",
|
||||
"epoch: 7100, acc: 0.837, loss: 0.370, lr: 0.5848295221942803\n",
|
||||
"epoch: 7200, acc: 0.833, loss: 0.368, lr: 0.5814291528577243\n",
|
||||
"epoch: 7300, acc: 0.833, loss: 0.366, lr: 0.5780680964217585\n",
|
||||
"epoch: 7400, acc: 0.837, loss: 0.364, lr: 0.5747456750387954\n",
|
||||
"epoch: 7500, acc: 0.833, loss: 0.362, lr: 0.5714612263557918\n",
|
||||
"epoch: 7600, acc: 0.833, loss: 0.360, lr: 0.5682141030740383\n",
|
||||
"epoch: 7700, acc: 0.837, loss: 0.358, lr: 0.5650036725238714\n",
|
||||
"epoch: 7800, acc: 0.840, loss: 0.357, lr: 0.5618293162537221\n",
|
||||
"epoch: 7900, acc: 0.840, loss: 0.355, lr: 0.5586904296329404\n",
|
||||
"epoch: 8000, acc: 0.840, loss: 0.353, lr: 0.5555864214678593\n",
|
||||
"epoch: 8100, acc: 0.843, loss: 0.351, lr: 0.5525167136305873\n",
|
||||
"epoch: 8200, acc: 0.843, loss: 0.350, lr: 0.5494807407000385\n",
|
||||
"epoch: 8300, acc: 0.843, loss: 0.348, lr: 0.5464779496147331\n",
|
||||
"epoch: 8400, acc: 0.843, loss: 0.346, lr: 0.5435077993369205\n",
|
||||
"epoch: 8500, acc: 0.847, loss: 0.345, lr: 0.5405697605275961\n",
|
||||
"epoch: 8600, acc: 0.847, loss: 0.343, lr: 0.5376633152320017\n",
|
||||
"epoch: 8700, acc: 0.850, loss: 0.342, lr: 0.5347879565752179\n",
|
||||
"epoch: 8800, acc: 0.847, loss: 0.340, lr: 0.5319431884674717\n",
|
||||
"epoch: 8900, acc: 0.847, loss: 0.338, lr: 0.5291285253188\n",
|
||||
"epoch: 9000, acc: 0.847, loss: 0.337, lr: 0.5263434917627243\n",
|
||||
"epoch: 9100, acc: 0.850, loss: 0.336, lr: 0.5235876223886068\n",
|
||||
"epoch: 9200, acc: 0.850, loss: 0.335, lr: 0.5208604614823689\n",
|
||||
"epoch: 9300, acc: 0.847, loss: 0.334, lr: 0.5181615627752734\n",
|
||||
"epoch: 9400, acc: 0.847, loss: 0.333, lr: 0.5154904892004742\n",
|
||||
"epoch: 9500, acc: 0.850, loss: 0.331, lr: 0.5128468126570593\n",
|
||||
"epoch: 9600, acc: 0.850, loss: 0.330, lr: 0.5102301137813153\n",
|
||||
"epoch: 9700, acc: 0.850, loss: 0.329, lr: 0.5076399817249606\n",
|
||||
"epoch: 9800, acc: 0.850, loss: 0.328, lr: 0.5050760139400979\n",
|
||||
"epoch: 9900, acc: 0.850, loss: 0.326, lr: 0.5025378159706518\n",
|
||||
"epoch: 10000, acc: 0.850, loss: 0.325, lr: 0.5000250012500626\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -457,7 +459,239 @@
|
||||
" if not epoch % 100:\n",
|
||||
" print(f'epoch: {epoch}, ' +\n",
|
||||
" f'acc: {accuracy:.3f}, ' +\n",
|
||||
" f'loss: {loss:.3f}')\n",
|
||||
" f'loss: {loss:.3f}, ' +\n",
|
||||
" f'lr: {optimizer.current_learning_rate}')\n",
|
||||
" \n",
|
||||
" # Backward pass\n",
|
||||
" loss_activation.backward(loss_activation.output, y)\n",
|
||||
" dense2.backward(loss_activation.dinputs)\n",
|
||||
" activation1.backward(dense2.dinputs)\n",
|
||||
" dense1.backward(activation1.dinputs)\n",
|
||||
" \n",
|
||||
" # Update weights and biases\n",
|
||||
" optimizer.pre_update_params()\n",
|
||||
" optimizer.update_params(dense1)\n",
|
||||
" optimizer.update_params(dense2)\n",
|
||||
" optimizer.post_update_params()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# RMSProp Optimizer\n",
|
||||
"Root Meas Square Propagation optimizer. It is similar to AdaGrad in that you apply different learning rates to different weights. However, the way you change the learning rate focuses more on the past cache rather than the current gradient."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Optimizer_RMSProp():\n",
|
||||
" def __init__(self, learning_rate=1e-3, decay=0.0, epsilon=1e-7, rho=0.9):\n",
|
||||
" self.initial_learning_rate = learning_rate\n",
|
||||
" self.current_learning_rate = self.initial_learning_rate\n",
|
||||
" self.decay = decay\n",
|
||||
" self.iterations = 0\n",
|
||||
" self.epsilon = epsilon\n",
|
||||
" self.rho = rho\n",
|
||||
"\n",
|
||||
" def pre_update_params(self):\n",
|
||||
" if self.decay:\n",
|
||||
" self.current_learning_rate = self.initial_learning_rate / (1 + self.decay * self.iterations)\n",
|
||||
"\n",
|
||||
" def update_params(self, layer):\n",
|
||||
" if not hasattr(layer, 'weight_cache'):\n",
|
||||
" layer.weight_cache = np.zeros_like(layer.weights)\n",
|
||||
" layer.bias_cache = np.zeros_like(layer.biases)\n",
|
||||
"\n",
|
||||
" layer.weight_cache = self.rho * layer.weight_cache + (1 - self.rho) * layer.dweights**2\n",
|
||||
" layer.bias_cache = self.rho * layer.bias_cache + (1 - self.rho) * layer.dbiases**2\n",
|
||||
"\n",
|
||||
" layer.weights += -self.current_learning_rate * layer.dweights / (np.sqrt(layer.weight_cache) + self.epsilon)\n",
|
||||
" layer.biases += -self.current_learning_rate * layer.dbiases / (np.sqrt(layer.bias_cache) + self.epsilon)\n",
|
||||
"\n",
|
||||
" def post_update_params(self):\n",
|
||||
" self.iterations += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Testing the RMSProp Optimizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch: 0, acc: 0.413, loss: 1.099, lr: 0.02\n",
|
||||
"epoch: 100, acc: 0.467, loss: 0.980, lr: 0.01998021958261321\n",
|
||||
"epoch: 200, acc: 0.480, loss: 0.904, lr: 0.019960279044701046\n",
|
||||
"epoch: 300, acc: 0.507, loss: 0.864, lr: 0.019940378268975763\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch: 400, acc: 0.490, loss: 0.861, lr: 0.01992051713662487\n",
|
||||
"epoch: 500, acc: 0.510, loss: 0.843, lr: 0.01990069552930875\n",
|
||||
"epoch: 600, acc: 0.613, loss: 0.779, lr: 0.019880913329158343\n",
|
||||
"epoch: 700, acc: 0.633, loss: 0.741, lr: 0.019861170418772778\n",
|
||||
"epoch: 800, acc: 0.577, loss: 0.722, lr: 0.019841466681217078\n",
|
||||
"epoch: 900, acc: 0.627, loss: 0.698, lr: 0.01982180200001982\n",
|
||||
"epoch: 1000, acc: 0.630, loss: 0.675, lr: 0.019802176259170884\n",
|
||||
"epoch: 1100, acc: 0.657, loss: 0.661, lr: 0.01978258934311912\n",
|
||||
"epoch: 1200, acc: 0.623, loss: 0.654, lr: 0.01976304113677013\n",
|
||||
"epoch: 1300, acc: 0.663, loss: 0.640, lr: 0.019743531525483964\n",
|
||||
"epoch: 1400, acc: 0.627, loss: 0.672, lr: 0.01972406039507293\n",
|
||||
"epoch: 1500, acc: 0.663, loss: 0.618, lr: 0.019704627631799327\n",
|
||||
"epoch: 1600, acc: 0.693, loss: 0.607, lr: 0.019685233122373254\n",
|
||||
"epoch: 1700, acc: 0.657, loss: 0.658, lr: 0.019665876753950384\n",
|
||||
"epoch: 1800, acc: 0.730, loss: 0.587, lr: 0.019646558414129805\n",
|
||||
"epoch: 1900, acc: 0.690, loss: 0.623, lr: 0.019627277990951823\n",
|
||||
"epoch: 2000, acc: 0.730, loss: 0.573, lr: 0.019608035372895814\n",
|
||||
"epoch: 2100, acc: 0.743, loss: 0.576, lr: 0.019588830448878047\n",
|
||||
"epoch: 2200, acc: 0.740, loss: 0.560, lr: 0.019569663108249594\n",
|
||||
"epoch: 2300, acc: 0.710, loss: 0.567, lr: 0.019550533240794143\n",
|
||||
"epoch: 2400, acc: 0.740, loss: 0.548, lr: 0.019531440736725945\n",
|
||||
"epoch: 2500, acc: 0.687, loss: 0.576, lr: 0.019512385486687673\n",
|
||||
"epoch: 2600, acc: 0.710, loss: 0.550, lr: 0.01949336738174836\n",
|
||||
"epoch: 2700, acc: 0.707, loss: 0.573, lr: 0.019474386313401298\n",
|
||||
"epoch: 2800, acc: 0.760, loss: 0.523, lr: 0.019455442173562\n",
|
||||
"epoch: 2900, acc: 0.770, loss: 0.525, lr: 0.019436534854566128\n",
|
||||
"epoch: 3000, acc: 0.787, loss: 0.512, lr: 0.01941766424916747\n",
|
||||
"epoch: 3100, acc: 0.763, loss: 0.517, lr: 0.019398830250535893\n",
|
||||
"epoch: 3200, acc: 0.750, loss: 0.551, lr: 0.019380032752255354\n",
|
||||
"epoch: 3300, acc: 0.803, loss: 0.498, lr: 0.01936127164832186\n",
|
||||
"epoch: 3400, acc: 0.780, loss: 0.493, lr: 0.01934254683314152\n",
|
||||
"epoch: 3500, acc: 0.780, loss: 0.514, lr: 0.019323858201528515\n",
|
||||
"epoch: 3600, acc: 0.793, loss: 0.522, lr: 0.019305205648703173\n",
|
||||
"epoch: 3700, acc: 0.780, loss: 0.516, lr: 0.01928658907028997\n",
|
||||
"epoch: 3800, acc: 0.790, loss: 0.508, lr: 0.01926800836231563\n",
|
||||
"epoch: 3900, acc: 0.750, loss: 0.523, lr: 0.019249463421207133\n",
|
||||
"epoch: 4000, acc: 0.763, loss: 0.516, lr: 0.019230954143789846\n",
|
||||
"epoch: 4100, acc: 0.787, loss: 0.499, lr: 0.019212480427285565\n",
|
||||
"epoch: 4200, acc: 0.770, loss: 0.512, lr: 0.019194042169310647\n",
|
||||
"epoch: 4300, acc: 0.790, loss: 0.496, lr: 0.019175639267874092\n",
|
||||
"epoch: 4400, acc: 0.777, loss: 0.501, lr: 0.019157271621375684\n",
|
||||
"epoch: 4500, acc: 0.810, loss: 0.479, lr: 0.0191389391286041\n",
|
||||
"epoch: 4600, acc: 0.783, loss: 0.482, lr: 0.019120641688735073\n",
|
||||
"epoch: 4700, acc: 0.797, loss: 0.463, lr: 0.019102379201329525\n",
|
||||
"epoch: 4800, acc: 0.787, loss: 0.469, lr: 0.01908415156633174\n",
|
||||
"epoch: 4900, acc: 0.777, loss: 0.497, lr: 0.01906595868406753\n",
|
||||
"epoch: 5000, acc: 0.810, loss: 0.445, lr: 0.01904780045524243\n",
|
||||
"epoch: 5100, acc: 0.793, loss: 0.450, lr: 0.019029676780939874\n",
|
||||
"epoch: 5200, acc: 0.810, loss: 0.438, lr: 0.019011587562619416\n",
|
||||
"epoch: 5300, acc: 0.797, loss: 0.452, lr: 0.01899353270211493\n",
|
||||
"epoch: 5400, acc: 0.800, loss: 0.453, lr: 0.018975512101632844\n",
|
||||
"epoch: 5500, acc: 0.820, loss: 0.419, lr: 0.018957525663750367\n",
|
||||
"epoch: 5600, acc: 0.817, loss: 0.433, lr: 0.018939573291413745\n",
|
||||
"epoch: 5700, acc: 0.763, loss: 0.533, lr: 0.018921654887936498\n",
|
||||
"epoch: 5800, acc: 0.820, loss: 0.411, lr: 0.018903770356997703\n",
|
||||
"epoch: 5900, acc: 0.817, loss: 0.424, lr: 0.01888591960264025\n",
|
||||
"epoch: 6000, acc: 0.810, loss: 0.419, lr: 0.018868102529269144\n",
|
||||
"epoch: 6100, acc: 0.827, loss: 0.403, lr: 0.018850319041649778\n",
|
||||
"epoch: 6200, acc: 0.820, loss: 0.413, lr: 0.018832569044906263\n",
|
||||
"epoch: 6300, acc: 0.820, loss: 0.414, lr: 0.018814852444519702\n",
|
||||
"epoch: 6400, acc: 0.810, loss: 0.410, lr: 0.018797169146326564\n",
|
||||
"epoch: 6500, acc: 0.830, loss: 0.389, lr: 0.018779519056516963\n",
|
||||
"epoch: 6600, acc: 0.823, loss: 0.407, lr: 0.018761902081633038\n",
|
||||
"epoch: 6700, acc: 0.823, loss: 0.403, lr: 0.018744318128567278\n",
|
||||
"epoch: 6800, acc: 0.820, loss: 0.405, lr: 0.018726767104560903\n",
|
||||
"epoch: 6900, acc: 0.823, loss: 0.386, lr: 0.018709248917202218\n",
|
||||
"epoch: 7000, acc: 0.827, loss: 0.393, lr: 0.018691763474424996\n",
|
||||
"epoch: 7100, acc: 0.800, loss: 0.428, lr: 0.018674310684506857\n",
|
||||
"epoch: 7200, acc: 0.833, loss: 0.378, lr: 0.018656890456067686\n",
|
||||
"epoch: 7300, acc: 0.820, loss: 0.382, lr: 0.01863950269806802\n",
|
||||
"epoch: 7400, acc: 0.810, loss: 0.438, lr: 0.018622147319807447\n",
|
||||
"epoch: 7500, acc: 0.833, loss: 0.379, lr: 0.018604824230923078\n",
|
||||
"epoch: 7600, acc: 0.837, loss: 0.351, lr: 0.01858753334138793\n",
|
||||
"epoch: 7700, acc: 0.827, loss: 0.397, lr: 0.018570274561509396\n",
|
||||
"epoch: 7800, acc: 0.837, loss: 0.369, lr: 0.018553047801927663\n",
|
||||
"epoch: 7900, acc: 0.787, loss: 0.458, lr: 0.018535852973614212\n",
|
||||
"epoch: 8000, acc: 0.840, loss: 0.369, lr: 0.01851868998787026\n",
|
||||
"epoch: 8100, acc: 0.863, loss: 0.336, lr: 0.018501558756325222\n",
|
||||
"epoch: 8200, acc: 0.840, loss: 0.366, lr: 0.01848445919093522\n",
|
||||
"epoch: 8300, acc: 0.833, loss: 0.369, lr: 0.018467391203981567\n",
|
||||
"epoch: 8400, acc: 0.837, loss: 0.355, lr: 0.01845035470806926\n",
|
||||
"epoch: 8500, acc: 0.840, loss: 0.357, lr: 0.018433349616125496\n",
|
||||
"epoch: 8600, acc: 0.857, loss: 0.329, lr: 0.018416375841398172\n",
|
||||
"epoch: 8700, acc: 0.843, loss: 0.352, lr: 0.018399433297454436\n",
|
||||
"epoch: 8800, acc: 0.843, loss: 0.356, lr: 0.01838252189817921\n",
|
||||
"epoch: 8900, acc: 0.797, loss: 0.447, lr: 0.018365641557773718\n",
|
||||
"epoch: 9000, acc: 0.847, loss: 0.354, lr: 0.018348792190754044\n",
|
||||
"epoch: 9100, acc: 0.840, loss: 0.349, lr: 0.0183319737119497\n",
|
||||
"epoch: 9200, acc: 0.853, loss: 0.337, lr: 0.018315186036502167\n",
|
||||
"epoch: 9300, acc: 0.847, loss: 0.350, lr: 0.018298429079863496\n",
|
||||
"epoch: 9400, acc: 0.823, loss: 0.383, lr: 0.018281702757794862\n",
|
||||
"epoch: 9500, acc: 0.853, loss: 0.338, lr: 0.018265006986365174\n",
|
||||
"epoch: 9600, acc: 0.780, loss: 0.522, lr: 0.018248341681949654\n",
|
||||
"epoch: 9700, acc: 0.840, loss: 0.335, lr: 0.018231706761228456\n",
|
||||
"epoch: 9800, acc: 0.853, loss: 0.334, lr: 0.01821510214118526\n",
|
||||
"epoch: 9900, acc: 0.723, loss: 0.696, lr: 0.018198527739105907\n",
|
||||
"epoch: 10000, acc: 0.860, loss: 0.314, lr: 0.018181983472577025\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create dataset\n",
|
||||
"X, y = spiral_data(samples=100, classes=3)\n",
|
||||
"\n",
|
||||
"# Create Dense layer with 2 input features and 64 output values\n",
|
||||
"dense1 = Layer_Dense(2, 64)\n",
|
||||
"\n",
|
||||
"# Create ReLU activation (to be used with Dense layer)\n",
|
||||
"activation1 = Activation_ReLU()\n",
|
||||
"\n",
|
||||
"# Create second Dense layer with 64 input features (as we take output\n",
|
||||
"# of previous layer here) and 3 output values (output values)\n",
|
||||
"dense2 = Layer_Dense(64, 3)\n",
|
||||
"\n",
|
||||
"# Create Softmax classifier's combined loss and activation\n",
|
||||
"loss_activation = Activation_Softmax_Loss_CategoricalCrossentropy()\n",
|
||||
"\n",
|
||||
"# Create optimizer\n",
|
||||
"optimizer = Optimizer_RMSProp(learning_rate=0.02, decay=1e-5, rho=0.999)\n",
|
||||
"\n",
|
||||
"# Train in loop\n",
|
||||
"for epoch in range(10001):\n",
|
||||
" # Perform a forward pass of our training data through this layer\n",
|
||||
" dense1.forward(X)\n",
|
||||
" \n",
|
||||
" # Perform a forward pass through activation function\n",
|
||||
" # takes the output of first dense layer here\n",
|
||||
" activation1.forward(dense1.output)\n",
|
||||
" \n",
|
||||
" # Perform a forward pass through second Dense layer\n",
|
||||
" # takes outputs of activation function of first layer as inputs\n",
|
||||
" dense2.forward(activation1.output)\n",
|
||||
" \n",
|
||||
" # Perform a forward pass through the activation/loss function\n",
|
||||
" # takes the output of second dense layer here and returns loss\n",
|
||||
" loss = loss_activation.forward(dense2.output, y)\n",
|
||||
" \n",
|
||||
" # Calculate accuracy from output of activation2 and targets\n",
|
||||
" # calculate values along first axis\n",
|
||||
" predictions = np.argmax(loss_activation.output, axis=1)\n",
|
||||
" if len(y.shape) == 2:\n",
|
||||
" y = np.argmax(y, axis=1)\n",
|
||||
" accuracy = np.mean(predictions == y)\n",
|
||||
" \n",
|
||||
" if not epoch % 100:\n",
|
||||
" print(f'epoch: {epoch}, ' +\n",
|
||||
" f'acc: {accuracy:.3f}, ' +\n",
|
||||
" f'loss: {loss:.3f}, ' +\n",
|
||||
" f'lr: {optimizer.current_learning_rate}')\n",
|
||||
" \n",
|
||||
" # Backward pass\n",
|
||||
" loss_activation.backward(loss_activation.output, y)\n",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user