Lecture 26, RMSProp optimizer

judsonupchurch 2025-01-21 02:13:32 +00:00
parent 2192bf3050
commit 036d06a652
4 changed files with 558 additions and 1220 deletions

File diff suppressed because it is too large

Binary file not shown.


@@ -221,7 +221,7 @@ class Optimizer_SGD():
self.iterations += 1
# %% [markdown]
# # Testing the Learning Rate Decay
# ## Testing the Learning Rate Decay
# %%
# Create dataset
@@ -270,7 +270,8 @@ for epoch in range(10001):
if not epoch % 100:
print(f'epoch: {epoch}, ' +
f'acc: {accuracy:.3f}, ' +
f'loss: {loss:.3f}')
f'loss: {loss:.3f}, ' +
f'lr: {optimizer.current_learning_rate}')
# Backward pass
loss_activation.backward(loss_activation.output, y)
@@ -335,7 +336,7 @@ class Optimizer_SGD():
self.iterations += 1
# %% [markdown]
# # Testing the Gradient Optimizer with Momentum
# ## Testing the Gradient Optimizer with Momentum
# %%
# Create dataset
@@ -384,7 +385,8 @@ for epoch in range(10001):
if not epoch % 100:
print(f'epoch: {epoch}, ' +
f'acc: {accuracy:.3f}, ' +
f'loss: {loss:.3f}')
f'loss: {loss:.3f}, ' +
f'lr: {optimizer.current_learning_rate}')
# Backward pass
loss_activation.backward(loss_activation.output, y)
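
The `lr:` field added by these hunks prints the decayed learning rate every 100 epochs. The optimizers in this notebook decay it as `current_learning_rate = initial_learning_rate / (1 + decay * iterations)` (the formula is visible in `pre_update_params` in the notebook diff below), so the printed values can be sanity-checked by hand. A minimal illustrative check, using the RMSProp run's hyperparameters from later in this commit:

```python
# Sketch of the decay schedule the notebook's optimizers apply in pre_update_params
def decayed_lr(initial_lr, decay, iterations):
    return initial_lr / (1 + decay * iterations)

# The RMSProp run below uses learning_rate=0.02 and decay=1e-5; near epoch
# 10000 this yields ~0.01818, in line with the final `lr:` in its training log.
print(decayed_lr(0.02, 1e-5, 10000))  # 0.018181818...
```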


@@ -254,12 +254,14 @@
"metadata": {},
"source": [
"# AdaGrad Optimizer\n",
"Different weights should have different learning rates. If one weight affects the loss much more strongly than the other, then consider using smaller learning rates with it. We can do this by maintaining a \"cache\" of the last gradients and normalizing based on this."
"Different weights should have different learning rates. If one weight affects the loss much more strongly than the other, then consider using smaller learning rates with it. We can do this by maintaining a \"cache\" of the last gradients and normalizing based on this.\n",
"\n",
"A downside is that as the cache keeps accumulating, some neurons will have such a small learning rate that the neuron basically becomes fixed."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -269,7 +271,7 @@
" self.current_learning_rate = self.initial_learning_rate\n",
" self.decay = decay\n",
" self.iterations = 0\n",
" self.epsilon = 0\n",
" self.epsilon = epsilon\n",
"\n",
" def pre_update_params(self):\n",
" if self.decay:\n",
@@ -299,114 +301,114 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, acc: 0.360, loss: 1.099\n",
"epoch: 100, acc: 0.477, loss: 0.992\n",
"epoch: 200, acc: 0.560, loss: 0.934\n",
"epoch: 300, acc: 0.600, loss: 0.881\n",
"epoch: 400, acc: 0.637, loss: 0.826\n",
"epoch: 500, acc: 0.617, loss: 0.793\n",
"epoch: 600, acc: 0.643, loss: 0.748\n",
"epoch: 700, acc: 0.657, loss: 0.726\n",
"epoch: 800, acc: 0.663, loss: 0.698\n",
"epoch: 900, acc: 0.680, loss: 0.681\n",
"epoch: 1000, acc: 0.683, loss: 0.664\n",
"epoch: 1100, acc: 0.693, loss: 0.653\n",
"epoch: 1200, acc: 0.693, loss: 0.642\n",
"epoch: 1300, acc: 0.707, loss: 0.632\n",
"epoch: 1400, acc: 0.720, loss: 0.625\n",
"epoch: 1500, acc: 0.717, loss: 0.615\n",
"epoch: 1600, acc: 0.723, loss: 0.608\n",
"epoch: 1700, acc: 0.730, loss: 0.600\n",
"epoch: 1800, acc: 0.740, loss: 0.588\n",
"epoch: 1900, acc: 0.740, loss: 0.581\n",
"epoch: 2000, acc: 0.743, loss: 0.576\n",
"epoch: 2100, acc: 0.740, loss: 0.570\n",
"epoch: 2200, acc: 0.750, loss: 0.563\n",
"epoch: 2300, acc: 0.747, loss: 0.562\n",
"epoch: 2400, acc: 0.757, loss: 0.557\n",
"epoch: 2500, acc: 0.760, loss: 0.554\n",
"epoch: 2600, acc: 0.770, loss: 0.550\n",
"epoch: 2700, acc: 0.777, loss: 0.546\n",
"epoch: 2800, acc: 0.777, loss: 0.543\n",
"epoch: 2900, acc: 0.780, loss: 0.540\n",
"epoch: 3000, acc: 0.777, loss: 0.537\n",
"epoch: 3100, acc: 0.773, loss: 0.533\n",
"epoch: 3200, acc: 0.783, loss: 0.531\n",
"epoch: 3300, acc: 0.777, loss: 0.528\n",
"epoch: 3400, acc: 0.777, loss: 0.526\n",
"epoch: 3500, acc: 0.773, loss: 0.523\n",
"epoch: 3600, acc: 0.780, loss: 0.522\n",
"epoch: 3700, acc: 0.773, loss: 0.520\n",
"epoch: 3800, acc: 0.777, loss: 0.518\n",
"epoch: 3900, acc: 0.780, loss: 0.516\n",
"epoch: 4000, acc: 0.780, loss: 0.515\n",
"epoch: 4100, acc: 0.773, loss: 0.513\n",
"epoch: 4200, acc: 0.770, loss: 0.511\n",
"epoch: 4300, acc: 0.773, loss: 0.510\n",
"epoch: 4400, acc: 0.777, loss: 0.509\n",
"epoch: 4500, acc: 0.777, loss: 0.508\n",
"epoch: 4600, acc: 0.780, loss: 0.507\n",
"epoch: 4700, acc: 0.777, loss: 0.506\n",
"epoch: 4800, acc: 0.777, loss: 0.504\n",
"epoch: 4900, acc: 0.783, loss: 0.503\n",
"epoch: 5000, acc: 0.783, loss: 0.503\n",
"epoch: 5100, acc: 0.787, loss: 0.502\n",
"epoch: 5200, acc: 0.790, loss: 0.501\n",
"epoch: 5300, acc: 0.787, loss: 0.500\n",
"epoch: 5400, acc: 0.787, loss: 0.498\n",
"epoch: 5500, acc: 0.783, loss: 0.497\n",
"epoch: 5600, acc: 0.787, loss: 0.496\n",
"epoch: 5700, acc: 0.780, loss: 0.495\n",
"epoch: 5800, acc: 0.780, loss: 0.495\n",
"epoch: 5900, acc: 0.783, loss: 0.494\n",
"epoch: 6000, acc: 0.790, loss: 0.494\n",
"epoch: 6100, acc: 0.777, loss: 0.493\n",
"epoch: 6200, acc: 0.783, loss: 0.492\n",
"epoch: 6300, acc: 0.783, loss: 0.491\n",
"epoch: 6400, acc: 0.790, loss: 0.490\n",
"epoch: 6500, acc: 0.780, loss: 0.488\n",
"epoch: 6600, acc: 0.780, loss: 0.487\n",
"epoch: 6700, acc: 0.777, loss: 0.485\n",
"epoch: 6800, acc: 0.780, loss: 0.483\n",
"epoch: 6900, acc: 0.783, loss: 0.482\n",
"epoch: 7000, acc: 0.790, loss: 0.480\n",
"epoch: 7100, acc: 0.790, loss: 0.479\n",
"epoch: 7200, acc: 0.797, loss: 0.477\n",
"epoch: 7300, acc: 0.803, loss: 0.476\n",
"epoch: 7400, acc: 0.813, loss: 0.475\n",
"epoch: 7500, acc: 0.813, loss: 0.474\n",
"epoch: 7600, acc: 0.813, loss: 0.472\n",
"epoch: 7700, acc: 0.813, loss: 0.471\n",
"epoch: 7800, acc: 0.810, loss: 0.470\n",
"epoch: 7900, acc: 0.810, loss: 0.469\n",
"epoch: 8000, acc: 0.810, loss: 0.468\n",
"epoch: 8100, acc: 0.810, loss: 0.465\n",
"epoch: 8200, acc: 0.807, loss: 0.463\n",
"epoch: 8300, acc: 0.803, loss: 0.462\n",
"epoch: 8400, acc: 0.803, loss: 0.461\n",
"epoch: 8500, acc: 0.807, loss: 0.459\n",
"epoch: 8600, acc: 0.810, loss: 0.458\n",
"epoch: 8700, acc: 0.813, loss: 0.458\n",
"epoch: 8800, acc: 0.810, loss: 0.456\n",
"epoch: 8900, acc: 0.810, loss: 0.455\n",
"epoch: 9000, acc: 0.813, loss: 0.452\n",
"epoch: 9100, acc: 0.813, loss: 0.450\n",
"epoch: 9200, acc: 0.817, loss: 0.448\n",
"epoch: 9300, acc: 0.810, loss: 0.447\n",
"epoch: 9400, acc: 0.810, loss: 0.446\n",
"epoch: 9500, acc: 0.813, loss: 0.444\n",
"epoch: 9600, acc: 0.813, loss: 0.441\n",
"epoch: 9700, acc: 0.817, loss: 0.440\n",
"epoch: 9800, acc: 0.817, loss: 0.438\n",
"epoch: 9900, acc: 0.813, loss: 0.436\n",
"epoch: 10000, acc: 0.813, loss: 0.435\n"
"epoch: 0, acc: 0.353, loss: 1.099, lr: 1.0\n",
"epoch: 100, acc: 0.497, loss: 0.986, lr: 0.9901970492127933\n",
"epoch: 200, acc: 0.527, loss: 0.936, lr: 0.9804882831650161\n",
"epoch: 300, acc: 0.513, loss: 0.918, lr: 0.9709680551509855\n",
"epoch: 400, acc: 0.580, loss: 0.904, lr: 0.9616309260505818\n",
"epoch: 500, acc: 0.550, loss: 0.910, lr: 0.9524716639679969\n",
"epoch: 600, acc: 0.563, loss: 0.860, lr: 0.9434852344560807\n",
"epoch: 700, acc: 0.600, loss: 0.838, lr: 0.9346667912889054\n",
"epoch: 800, acc: 0.603, loss: 0.815, lr: 0.9260116677470135\n",
"epoch: 900, acc: 0.643, loss: 0.787, lr: 0.9175153683824203\n",
"epoch: 1000, acc: 0.617, loss: 0.810, lr: 0.9091735612328392\n",
"epoch: 1100, acc: 0.637, loss: 0.750, lr: 0.9009820704567978\n",
"epoch: 1200, acc: 0.677, loss: 0.744, lr: 0.892936869363336\n",
"epoch: 1300, acc: 0.670, loss: 0.722, lr: 0.8850340738118416\n",
"epoch: 1400, acc: 0.693, loss: 0.704, lr: 0.8772699359592947\n",
"epoch: 1500, acc: 0.680, loss: 0.708, lr: 0.8696408383337683\n",
"epoch: 1600, acc: 0.697, loss: 0.664, lr: 0.8621432882145013\n",
"epoch: 1700, acc: 0.650, loss: 0.699, lr: 0.8547739123001966\n",
"epoch: 1800, acc: 0.707, loss: 0.646, lr: 0.8475294516484448\n",
"epoch: 1900, acc: 0.693, loss: 0.633, lr: 0.8404067568703253\n",
"epoch: 2000, acc: 0.713, loss: 0.620, lr: 0.8334027835652972\n",
"epoch: 2100, acc: 0.710, loss: 0.610, lr: 0.8265145879824779\n",
"epoch: 2200, acc: 0.710, loss: 0.599, lr: 0.8197393228953193\n",
"epoch: 2300, acc: 0.713, loss: 0.588, lr: 0.8130742336775347\n",
"epoch: 2400, acc: 0.733, loss: 0.576, lr: 0.8065166545689169\n",
"epoch: 2500, acc: 0.763, loss: 0.579, lr: 0.8000640051204096\n",
"epoch: 2600, acc: 0.800, loss: 0.558, lr: 0.7937137868084768\n",
"epoch: 2700, acc: 0.803, loss: 0.551, lr: 0.7874635798094338\n",
"epoch: 2800, acc: 0.800, loss: 0.545, lr: 0.7813110399249941\n",
"epoch: 2900, acc: 0.807, loss: 0.541, lr: 0.7752538956508256\n",
"epoch: 3000, acc: 0.757, loss: 0.538, lr: 0.7692899453804138\n",
"epoch: 3100, acc: 0.757, loss: 0.532, lr: 0.7634170547370028\n",
"epoch: 3200, acc: 0.740, loss: 0.532, lr: 0.7576331540268202\n",
"epoch: 3300, acc: 0.763, loss: 0.514, lr: 0.7519362358072035\n",
"epoch: 3400, acc: 0.770, loss: 0.512, lr: 0.7463243525636241\n",
"epoch: 3500, acc: 0.757, loss: 0.507, lr: 0.7407956144899621\n",
"epoch: 3600, acc: 0.780, loss: 0.496, lr: 0.735348187366718\n",
"epoch: 3700, acc: 0.787, loss: 0.497, lr: 0.7299802905321557\n",
"epoch: 3800, acc: 0.767, loss: 0.491, lr: 0.7246901949416624\n",
"epoch: 3900, acc: 0.783, loss: 0.483, lr: 0.7194762213108857\n",
"epoch: 4000, acc: 0.790, loss: 0.484, lr: 0.7143367383384527\n",
"epoch: 4100, acc: 0.783, loss: 0.477, lr: 0.7092701610043266\n",
"epoch: 4200, acc: 0.780, loss: 0.487, lr: 0.7042749489400663\n",
"epoch: 4300, acc: 0.787, loss: 0.467, lr: 0.6993496048674733\n",
"epoch: 4400, acc: 0.793, loss: 0.465, lr: 0.6944926731022988\n",
"epoch: 4500, acc: 0.787, loss: 0.460, lr: 0.6897027381198704\n",
"epoch: 4600, acc: 0.777, loss: 0.477, lr: 0.6849784231796698\n",
"epoch: 4700, acc: 0.783, loss: 0.454, lr: 0.6803183890060548\n",
"epoch: 4800, acc: 0.800, loss: 0.448, lr: 0.6757213325224677\n",
"epoch: 4900, acc: 0.800, loss: 0.441, lr: 0.6711859856366199\n",
"epoch: 5000, acc: 0.797, loss: 0.437, lr: 0.6667111140742716\n",
"epoch: 5100, acc: 0.800, loss: 0.433, lr: 0.6622955162593549\n",
"epoch: 5200, acc: 0.813, loss: 0.429, lr: 0.6579380222383051\n",
"epoch: 5300, acc: 0.810, loss: 0.426, lr: 0.6536374926465782\n",
"epoch: 5400, acc: 0.813, loss: 0.423, lr: 0.649392817715436\n",
"epoch: 5500, acc: 0.823, loss: 0.420, lr: 0.6452029163171817\n",
"epoch: 5600, acc: 0.820, loss: 0.416, lr: 0.6410667350471184\n",
"epoch: 5700, acc: 0.820, loss: 0.413, lr: 0.6369832473405949\n",
"epoch: 5800, acc: 0.820, loss: 0.410, lr: 0.6329514526235838\n",
"epoch: 5900, acc: 0.820, loss: 0.407, lr: 0.6289703754953141\n",
"epoch: 6000, acc: 0.823, loss: 0.404, lr: 0.6250390649415589\n",
"epoch: 6100, acc: 0.823, loss: 0.401, lr: 0.6211565935772407\n",
"epoch: 6200, acc: 0.827, loss: 0.398, lr: 0.6173220569170937\n",
"epoch: 6300, acc: 0.833, loss: 0.395, lr: 0.6135345726731701\n",
"epoch: 6400, acc: 0.830, loss: 0.390, lr: 0.6097932800780536\n",
"epoch: 6500, acc: 0.827, loss: 0.387, lr: 0.6060973392326807\n",
"epoch: 6600, acc: 0.827, loss: 0.384, lr: 0.6024459304777396\n",
"epoch: 6700, acc: 0.830, loss: 0.381, lr: 0.5988382537876519\n",
"epoch: 6800, acc: 0.830, loss: 0.378, lr: 0.5952735281862016\n",
"epoch: 6900, acc: 0.830, loss: 0.375, lr: 0.5917509911829102\n",
"epoch: 7000, acc: 0.833, loss: 0.373, lr: 0.5882698982293076\n",
"epoch: 7100, acc: 0.837, loss: 0.370, lr: 0.5848295221942803\n",
"epoch: 7200, acc: 0.833, loss: 0.368, lr: 0.5814291528577243\n",
"epoch: 7300, acc: 0.833, loss: 0.366, lr: 0.5780680964217585\n",
"epoch: 7400, acc: 0.837, loss: 0.364, lr: 0.5747456750387954\n",
"epoch: 7500, acc: 0.833, loss: 0.362, lr: 0.5714612263557918\n",
"epoch: 7600, acc: 0.833, loss: 0.360, lr: 0.5682141030740383\n",
"epoch: 7700, acc: 0.837, loss: 0.358, lr: 0.5650036725238714\n",
"epoch: 7800, acc: 0.840, loss: 0.357, lr: 0.5618293162537221\n",
"epoch: 7900, acc: 0.840, loss: 0.355, lr: 0.5586904296329404\n",
"epoch: 8000, acc: 0.840, loss: 0.353, lr: 0.5555864214678593\n",
"epoch: 8100, acc: 0.843, loss: 0.351, lr: 0.5525167136305873\n",
"epoch: 8200, acc: 0.843, loss: 0.350, lr: 0.5494807407000385\n",
"epoch: 8300, acc: 0.843, loss: 0.348, lr: 0.5464779496147331\n",
"epoch: 8400, acc: 0.843, loss: 0.346, lr: 0.5435077993369205\n",
"epoch: 8500, acc: 0.847, loss: 0.345, lr: 0.5405697605275961\n",
"epoch: 8600, acc: 0.847, loss: 0.343, lr: 0.5376633152320017\n",
"epoch: 8700, acc: 0.850, loss: 0.342, lr: 0.5347879565752179\n",
"epoch: 8800, acc: 0.847, loss: 0.340, lr: 0.5319431884674717\n",
"epoch: 8900, acc: 0.847, loss: 0.338, lr: 0.5291285253188\n",
"epoch: 9000, acc: 0.847, loss: 0.337, lr: 0.5263434917627243\n",
"epoch: 9100, acc: 0.850, loss: 0.336, lr: 0.5235876223886068\n",
"epoch: 9200, acc: 0.850, loss: 0.335, lr: 0.5208604614823689\n",
"epoch: 9300, acc: 0.847, loss: 0.334, lr: 0.5181615627752734\n",
"epoch: 9400, acc: 0.847, loss: 0.333, lr: 0.5154904892004742\n",
"epoch: 9500, acc: 0.850, loss: 0.331, lr: 0.5128468126570593\n",
"epoch: 9600, acc: 0.850, loss: 0.330, lr: 0.5102301137813153\n",
"epoch: 9700, acc: 0.850, loss: 0.329, lr: 0.5076399817249606\n",
"epoch: 9800, acc: 0.850, loss: 0.328, lr: 0.5050760139400979\n",
"epoch: 9900, acc: 0.850, loss: 0.326, lr: 0.5025378159706518\n",
"epoch: 10000, acc: 0.850, loss: 0.325, lr: 0.5000250012500626\n"
]
}
],
@@ -457,7 +459,239 @@
" if not epoch % 100:\n",
" print(f'epoch: {epoch}, ' +\n",
" f'acc: {accuracy:.3f}, ' +\n",
" f'loss: {loss:.3f}')\n",
" f'loss: {loss:.3f}, ' +\n",
" f'lr: {optimizer.current_learning_rate}')\n",
" \n",
" # Backward pass\n",
" loss_activation.backward(loss_activation.output, y)\n",
" dense2.backward(loss_activation.dinputs)\n",
" activation1.backward(dense2.dinputs)\n",
" dense1.backward(activation1.dinputs)\n",
" \n",
" # Update weights and biases\n",
" optimizer.pre_update_params()\n",
" optimizer.update_params(dense1)\n",
" optimizer.update_params(dense2)\n",
" optimizer.post_update_params()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RMSProp Optimizer\n",
"Root Meas Square Propagation optimizer. It is similar to AdaGrad in that you apply different learning rates to different weights. However, the way you change the learning rate focuses more on the past cache rather than the current gradient."
]
},
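
In symbols, the cache update described above is an exponential moving average of squared gradients. Restating the class below, with \(\rho\) = `rho`, \(\epsilon\) = `epsilon`, \(\eta\) = the current learning rate, and \(g_t\) a parameter's gradient:

```latex
\begin{aligned}
c_t &= \rho\, c_{t-1} + (1 - \rho)\, g_t^2 && \text{(moving average, weighted toward the past cache)} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{g_t}{\sqrt{c_t} + \epsilon} && \text{(normalized step)}
\end{aligned}
```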
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class Optimizer_RMSProp():\n",
" def __init__(self, learning_rate=1e-3, decay=0.0, epsilon=1e-7, rho=0.9):\n",
" self.initial_learning_rate = learning_rate\n",
" self.current_learning_rate = self.initial_learning_rate\n",
" self.decay = decay\n",
" self.iterations = 0\n",
" self.epsilon = epsilon\n",
" self.rho = rho\n",
"\n",
" def pre_update_params(self):\n",
" if self.decay:\n",
" self.current_learning_rate = self.initial_learning_rate / (1 + self.decay * self.iterations)\n",
"\n",
" def update_params(self, layer):\n",
" if not hasattr(layer, 'weight_cache'):\n",
" layer.weight_cache = np.zeros_like(layer.weights)\n",
" layer.bias_cache = np.zeros_like(layer.biases)\n",
"\n",
" layer.weight_cache = self.rho * layer.weight_cache + (1 - self.rho) * layer.dweights**2\n",
" layer.bias_cache = self.rho * layer.bias_cache + (1 - self.rho) * layer.dbiases**2\n",
"\n",
" layer.weights += -self.current_learning_rate * layer.dweights / (np.sqrt(layer.weight_cache) + self.epsilon)\n",
" layer.biases += -self.current_learning_rate * layer.dbiases / (np.sqrt(layer.bias_cache) + self.epsilon)\n",
"\n",
" def post_update_params(self):\n",
" self.iterations += 1"
]
},
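
To see the practical difference from AdaGrad, here is a small hypothetical experiment (not part of the notebook): feed both caches a constant gradient and compare the effective step size `lr * g / (sqrt(cache) + eps)`.

```python
import numpy as np

# Constant gradient of 1.0 for 1000 steps; default hyperparameters assumed
g, lr, eps, rho = 1.0, 1.0, 1e-7, 0.9
ada_cache = rms_cache = 0.0
for _ in range(1000):
    ada_cache += g**2                               # AdaGrad: unbounded running sum
    rms_cache = rho * rms_cache + (1 - rho) * g**2  # RMSProp: decaying average

print(lr * g / (np.sqrt(ada_cache) + eps))  # ~0.0316 and still shrinking
print(lr * g / (np.sqrt(rms_cache) + eps))  # ~1.0, stabilized
```

With `rho=0.999`, as in the test run below, the cache averages over roughly the last `1 / (1 - rho)` ≈ 1000 gradients, smoothing the per-weight scaling further.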
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing the RMSProp Optimizer"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, acc: 0.413, loss: 1.099, lr: 0.02\n",
"epoch: 100, acc: 0.467, loss: 0.980, lr: 0.01998021958261321\n",
"epoch: 200, acc: 0.480, loss: 0.904, lr: 0.019960279044701046\n",
"epoch: 300, acc: 0.507, loss: 0.864, lr: 0.019940378268975763\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 400, acc: 0.490, loss: 0.861, lr: 0.01992051713662487\n",
"epoch: 500, acc: 0.510, loss: 0.843, lr: 0.01990069552930875\n",
"epoch: 600, acc: 0.613, loss: 0.779, lr: 0.019880913329158343\n",
"epoch: 700, acc: 0.633, loss: 0.741, lr: 0.019861170418772778\n",
"epoch: 800, acc: 0.577, loss: 0.722, lr: 0.019841466681217078\n",
"epoch: 900, acc: 0.627, loss: 0.698, lr: 0.01982180200001982\n",
"epoch: 1000, acc: 0.630, loss: 0.675, lr: 0.019802176259170884\n",
"epoch: 1100, acc: 0.657, loss: 0.661, lr: 0.01978258934311912\n",
"epoch: 1200, acc: 0.623, loss: 0.654, lr: 0.01976304113677013\n",
"epoch: 1300, acc: 0.663, loss: 0.640, lr: 0.019743531525483964\n",
"epoch: 1400, acc: 0.627, loss: 0.672, lr: 0.01972406039507293\n",
"epoch: 1500, acc: 0.663, loss: 0.618, lr: 0.019704627631799327\n",
"epoch: 1600, acc: 0.693, loss: 0.607, lr: 0.019685233122373254\n",
"epoch: 1700, acc: 0.657, loss: 0.658, lr: 0.019665876753950384\n",
"epoch: 1800, acc: 0.730, loss: 0.587, lr: 0.019646558414129805\n",
"epoch: 1900, acc: 0.690, loss: 0.623, lr: 0.019627277990951823\n",
"epoch: 2000, acc: 0.730, loss: 0.573, lr: 0.019608035372895814\n",
"epoch: 2100, acc: 0.743, loss: 0.576, lr: 0.019588830448878047\n",
"epoch: 2200, acc: 0.740, loss: 0.560, lr: 0.019569663108249594\n",
"epoch: 2300, acc: 0.710, loss: 0.567, lr: 0.019550533240794143\n",
"epoch: 2400, acc: 0.740, loss: 0.548, lr: 0.019531440736725945\n",
"epoch: 2500, acc: 0.687, loss: 0.576, lr: 0.019512385486687673\n",
"epoch: 2600, acc: 0.710, loss: 0.550, lr: 0.01949336738174836\n",
"epoch: 2700, acc: 0.707, loss: 0.573, lr: 0.019474386313401298\n",
"epoch: 2800, acc: 0.760, loss: 0.523, lr: 0.019455442173562\n",
"epoch: 2900, acc: 0.770, loss: 0.525, lr: 0.019436534854566128\n",
"epoch: 3000, acc: 0.787, loss: 0.512, lr: 0.01941766424916747\n",
"epoch: 3100, acc: 0.763, loss: 0.517, lr: 0.019398830250535893\n",
"epoch: 3200, acc: 0.750, loss: 0.551, lr: 0.019380032752255354\n",
"epoch: 3300, acc: 0.803, loss: 0.498, lr: 0.01936127164832186\n",
"epoch: 3400, acc: 0.780, loss: 0.493, lr: 0.01934254683314152\n",
"epoch: 3500, acc: 0.780, loss: 0.514, lr: 0.019323858201528515\n",
"epoch: 3600, acc: 0.793, loss: 0.522, lr: 0.019305205648703173\n",
"epoch: 3700, acc: 0.780, loss: 0.516, lr: 0.01928658907028997\n",
"epoch: 3800, acc: 0.790, loss: 0.508, lr: 0.01926800836231563\n",
"epoch: 3900, acc: 0.750, loss: 0.523, lr: 0.019249463421207133\n",
"epoch: 4000, acc: 0.763, loss: 0.516, lr: 0.019230954143789846\n",
"epoch: 4100, acc: 0.787, loss: 0.499, lr: 0.019212480427285565\n",
"epoch: 4200, acc: 0.770, loss: 0.512, lr: 0.019194042169310647\n",
"epoch: 4300, acc: 0.790, loss: 0.496, lr: 0.019175639267874092\n",
"epoch: 4400, acc: 0.777, loss: 0.501, lr: 0.019157271621375684\n",
"epoch: 4500, acc: 0.810, loss: 0.479, lr: 0.0191389391286041\n",
"epoch: 4600, acc: 0.783, loss: 0.482, lr: 0.019120641688735073\n",
"epoch: 4700, acc: 0.797, loss: 0.463, lr: 0.019102379201329525\n",
"epoch: 4800, acc: 0.787, loss: 0.469, lr: 0.01908415156633174\n",
"epoch: 4900, acc: 0.777, loss: 0.497, lr: 0.01906595868406753\n",
"epoch: 5000, acc: 0.810, loss: 0.445, lr: 0.01904780045524243\n",
"epoch: 5100, acc: 0.793, loss: 0.450, lr: 0.019029676780939874\n",
"epoch: 5200, acc: 0.810, loss: 0.438, lr: 0.019011587562619416\n",
"epoch: 5300, acc: 0.797, loss: 0.452, lr: 0.01899353270211493\n",
"epoch: 5400, acc: 0.800, loss: 0.453, lr: 0.018975512101632844\n",
"epoch: 5500, acc: 0.820, loss: 0.419, lr: 0.018957525663750367\n",
"epoch: 5600, acc: 0.817, loss: 0.433, lr: 0.018939573291413745\n",
"epoch: 5700, acc: 0.763, loss: 0.533, lr: 0.018921654887936498\n",
"epoch: 5800, acc: 0.820, loss: 0.411, lr: 0.018903770356997703\n",
"epoch: 5900, acc: 0.817, loss: 0.424, lr: 0.01888591960264025\n",
"epoch: 6000, acc: 0.810, loss: 0.419, lr: 0.018868102529269144\n",
"epoch: 6100, acc: 0.827, loss: 0.403, lr: 0.018850319041649778\n",
"epoch: 6200, acc: 0.820, loss: 0.413, lr: 0.018832569044906263\n",
"epoch: 6300, acc: 0.820, loss: 0.414, lr: 0.018814852444519702\n",
"epoch: 6400, acc: 0.810, loss: 0.410, lr: 0.018797169146326564\n",
"epoch: 6500, acc: 0.830, loss: 0.389, lr: 0.018779519056516963\n",
"epoch: 6600, acc: 0.823, loss: 0.407, lr: 0.018761902081633038\n",
"epoch: 6700, acc: 0.823, loss: 0.403, lr: 0.018744318128567278\n",
"epoch: 6800, acc: 0.820, loss: 0.405, lr: 0.018726767104560903\n",
"epoch: 6900, acc: 0.823, loss: 0.386, lr: 0.018709248917202218\n",
"epoch: 7000, acc: 0.827, loss: 0.393, lr: 0.018691763474424996\n",
"epoch: 7100, acc: 0.800, loss: 0.428, lr: 0.018674310684506857\n",
"epoch: 7200, acc: 0.833, loss: 0.378, lr: 0.018656890456067686\n",
"epoch: 7300, acc: 0.820, loss: 0.382, lr: 0.01863950269806802\n",
"epoch: 7400, acc: 0.810, loss: 0.438, lr: 0.018622147319807447\n",
"epoch: 7500, acc: 0.833, loss: 0.379, lr: 0.018604824230923078\n",
"epoch: 7600, acc: 0.837, loss: 0.351, lr: 0.01858753334138793\n",
"epoch: 7700, acc: 0.827, loss: 0.397, lr: 0.018570274561509396\n",
"epoch: 7800, acc: 0.837, loss: 0.369, lr: 0.018553047801927663\n",
"epoch: 7900, acc: 0.787, loss: 0.458, lr: 0.018535852973614212\n",
"epoch: 8000, acc: 0.840, loss: 0.369, lr: 0.01851868998787026\n",
"epoch: 8100, acc: 0.863, loss: 0.336, lr: 0.018501558756325222\n",
"epoch: 8200, acc: 0.840, loss: 0.366, lr: 0.01848445919093522\n",
"epoch: 8300, acc: 0.833, loss: 0.369, lr: 0.018467391203981567\n",
"epoch: 8400, acc: 0.837, loss: 0.355, lr: 0.01845035470806926\n",
"epoch: 8500, acc: 0.840, loss: 0.357, lr: 0.018433349616125496\n",
"epoch: 8600, acc: 0.857, loss: 0.329, lr: 0.018416375841398172\n",
"epoch: 8700, acc: 0.843, loss: 0.352, lr: 0.018399433297454436\n",
"epoch: 8800, acc: 0.843, loss: 0.356, lr: 0.01838252189817921\n",
"epoch: 8900, acc: 0.797, loss: 0.447, lr: 0.018365641557773718\n",
"epoch: 9000, acc: 0.847, loss: 0.354, lr: 0.018348792190754044\n",
"epoch: 9100, acc: 0.840, loss: 0.349, lr: 0.0183319737119497\n",
"epoch: 9200, acc: 0.853, loss: 0.337, lr: 0.018315186036502167\n",
"epoch: 9300, acc: 0.847, loss: 0.350, lr: 0.018298429079863496\n",
"epoch: 9400, acc: 0.823, loss: 0.383, lr: 0.018281702757794862\n",
"epoch: 9500, acc: 0.853, loss: 0.338, lr: 0.018265006986365174\n",
"epoch: 9600, acc: 0.780, loss: 0.522, lr: 0.018248341681949654\n",
"epoch: 9700, acc: 0.840, loss: 0.335, lr: 0.018231706761228456\n",
"epoch: 9800, acc: 0.853, loss: 0.334, lr: 0.01821510214118526\n",
"epoch: 9900, acc: 0.723, loss: 0.696, lr: 0.018198527739105907\n",
"epoch: 10000, acc: 0.860, loss: 0.314, lr: 0.018181983472577025\n"
]
}
],
"source": [
"# Create dataset\n",
"X, y = spiral_data(samples=100, classes=3)\n",
"\n",
"# Create Dense layer with 2 input features and 64 output values\n",
"dense1 = Layer_Dense(2, 64)\n",
"\n",
"# Create ReLU activation (to be used with Dense layer)\n",
"activation1 = Activation_ReLU()\n",
"\n",
"# Create second Dense layer with 64 input features (as we take output\n",
"# of previous layer here) and 3 output values (output values)\n",
"dense2 = Layer_Dense(64, 3)\n",
"\n",
"# Create Softmax classifier's combined loss and activation\n",
"loss_activation = Activation_Softmax_Loss_CategoricalCrossentropy()\n",
"\n",
"# Create optimizer\n",
"optimizer = Optimizer_RMSProp(learning_rate=0.02, decay=1e-5, rho=0.999)\n",
"\n",
"# Train in loop\n",
"for epoch in range(10001):\n",
" # Perform a forward pass of our training data through this layer\n",
" dense1.forward(X)\n",
" \n",
" # Perform a forward pass through activation function\n",
" # takes the output of first dense layer here\n",
" activation1.forward(dense1.output)\n",
" \n",
" # Perform a forward pass through second Dense layer\n",
" # takes outputs of activation function of first layer as inputs\n",
" dense2.forward(activation1.output)\n",
" \n",
" # Perform a forward pass through the activation/loss function\n",
" # takes the output of second dense layer here and returns loss\n",
" loss = loss_activation.forward(dense2.output, y)\n",
" \n",
" # Calculate accuracy from output of activation2 and targets\n",
" # calculate values along first axis\n",
" predictions = np.argmax(loss_activation.output, axis=1)\n",
" if len(y.shape) == 2:\n",
" y = np.argmax(y, axis=1)\n",
" accuracy = np.mean(predictions == y)\n",
" \n",
" if not epoch % 100:\n",
" print(f'epoch: {epoch}, ' +\n",
" f'acc: {accuracy:.3f}, ' +\n",
" f'loss: {loss:.3f}, ' +\n",
" f'lr: {optimizer.current_learning_rate}')\n",
" \n",
" # Backward pass\n",
" loss_activation.backward(loss_activation.output, y)\n",