{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Previous Class Definitions\n", "The previously defined Layer_Dense, Activation_ReLU, Activation_Softmax, Loss, and Loss_CategoricalCrossEntropy classes." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import nnfs\n", "from nnfs.datasets import spiral_data, vertical_data\n", "nnfs.init()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class Layer_Dense:\n", " def __init__(self, n_inputs, n_neurons):\n", " # Initialize the weights and biases\n", " self.weights = 0.01 * np.random.randn(n_inputs, n_neurons) # Normal distribution of weights\n", " self.biases = np.zeros((1, n_neurons))\n", "\n", " def forward(self, inputs):\n", " # Calculate the output values from inputs, weights, and biases\n", " self.output = np.dot(inputs, self.weights) + self.biases # Weights are already transposed\n", "\n", "class Activation_ReLU:\n", " def forward(self, inputs):\n", " self.output = np.maximum(0, inputs)\n", " \n", "class Activation_Softmax:\n", " def forward(self, inputs):\n", " # Get the unnormalized probabilities\n", " # Subtract max from the row to prevent larger numbers\n", " exp_values = np.exp(inputs - np.max(inputs, axis=1, keepdims=True))\n", "\n", " # Normalize the probabilities with element wise division\n", " probabilities = exp_values / np.sum(exp_values, axis=1,keepdims=True)\n", " self.output = probabilities\n", "\n", "# Base class for Loss functions\n", "class Loss:\n", " '''Calculates the data and regularization losses given\n", " model output and ground truth values'''\n", " def calculate(self, output, y):\n", " sample_losses = self.forward(output, y)\n", " data_loss = np.average(sample_losses)\n", " return data_loss\n", "\n", "class Loss_CategoricalCrossEntropy(Loss):\n", " def forward(self, y_pred, y_true):\n", " '''y_pred is the neural network output\n", " y_true is the ideal output of the neural network'''\n", " samples = len(y_pred)\n", " # Bound the predicted values \n", " y_pred_clipped = np.clip(y_pred, 1e-7, 1-1e-7)\n", " \n", " if len(y_true.shape) == 1: # Categorically labeled\n", " correct_confidences = y_pred_clipped[range(samples), y_true]\n", " elif len(y_true.shape) == 2: # One hot encoded\n", " correct_confidences = np.sum(y_pred_clipped*y_true, axis=1)\n", "\n", " # Calculate the losses\n", " negative_log_likelihoods = -np.log(correct_confidences)\n", " return negative_log_likelihoods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Backpropagation of a Single Neuron\n", "Backpropagation helps us find the gradient of the neural network with respect to each of the parameters (weights and biases) of each neuron.\n", "\n", "Imagine a layer that has 3 inputs and 1 neuron. There are 3 inputs (x0, x1, x2), three weights (w0, w1, w2), 1 bias (b0), and 1 output (z). There is a ReLU activation layer after the neuron output going into a square loss function (loss = z^2).\n", "\n", "Loss = (ReLU(sum(mul(x0, w0), mul(x1, w1), mul(x2, w2(, b0)))))^2\n", "\n", "$\\frac{\\delta Loss()}{\\delta w0} = \\frac{\\delta Loss()}{\\delta ReLU()} * \\frac{\\delta ReLU()}{\\delta sum()} * \\frac{\\delta sum()}{\\delta mul(x0, w0)} * \\frac{\\delta mul(x0, w0)}{\\delta w0}$\n", "\n", "$\\frac{\\delta Loss()}{\\delta ReLU()} = 2 * ReLU(sum(...))$\n", "\n", "$\\frac{\\delta ReLU()}{\\delta sum()}$ = 0 if sum(...) is less than 0 and 1 if sum(...) 
, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1, Loss: 36.0\n",
"Iteration 2, Loss: 33.872399999999985\n",
"Iteration 3, Loss: 31.870541159999995\n",
"Iteration 4, Loss: 29.98699217744401\n",
"Iteration 5, Loss: 28.21476093975706\n",
"Iteration 6, Loss: 26.54726856821742\n",
"Iteration 7, Loss: 24.978324995835766\n",
"Iteration 8, Loss: 23.502105988581878\n",
"Iteration 9, Loss: 22.113131524656684\n",
"Iteration 10, Loss: 20.80624545154949\n",
"Iteration 11, Loss: 19.576596345362915\n",
"Iteration 12, Loss: 18.419619501351963\n",
"Iteration 13, Loss: 17.331019988822064\n",
"Iteration 14, Loss: 16.306756707482677\n",
"Iteration 15, Loss: 15.343027386070442\n",
"Iteration 16, Loss: 14.43625446755368\n",
"Iteration 17, Loss: 13.583071828521266\n",
"Iteration 18, Loss: 12.780312283455652\n",
"Iteration 19, Loss: 12.024995827503426\n",
"Iteration 20, Loss: 11.314318574097976\n",
"Iteration 21, Loss: 10.645642346368787\n",
"Iteration 22, Loss: 10.016484883698395\n",
"Iteration 23, Loss: 9.424510627071816\n",
"Iteration 24, Loss: 8.867522049011871\n",
"Iteration 25, Loss: 8.34345149591527\n",
"Iteration 26, Loss: 7.850353512506679\n",
"Iteration 27, Loss: 7.386397619917536\n",
"Iteration 28, Loss: 6.949861520580408\n",
"Iteration 29, Loss: 6.539124704714106\n",
"Iteration 30, Loss: 6.152662434665503\n",
"Iteration 31, Loss: 5.789040084776769\n",
"Iteration 32, Loss: 5.446907815766464\n",
"Iteration 33, Loss: 5.124995563854669\n",
"Iteration 34, Loss: 4.822108326030859\n",
"Iteration 35, Loss: 4.537121723962434\n",
"Iteration 36, Loss: 4.268977830076255\n",
"Iteration 37, Loss: 4.016681240318748\n",
"Iteration 38, Loss: 3.7792953790159096\n",
"Iteration 39, Loss: 3.55593902211607\n",
"Iteration 40, Loss: 3.345783025909011\n",
"Iteration 41, Loss: 3.148047249077789\n",
"Iteration 42, Loss: 2.9619976566572896\n",
"Iteration 43, Loss: 2.786943595148845\n",
"Iteration 44, Loss: 2.622235228675548\n",
"Iteration 45, Loss: 2.4672611266608238\n",
"Iteration 46, Loss: 2.3214459940751673\n",
"Iteration 47, Loss: 2.1842485358253243\n",
"Iteration 48, Loss: 2.055159447358047\n",
"Iteration 49, Loss: 1.9336995240191863\n",
"Iteration 50, Loss: 1.8194178821496518\n",
"Iteration 51, Loss: 1.7118902853146072\n",
"Iteration 52, Loss: 1.6107175694525138\n",
"Iteration 53, Loss: 1.5155241610978685\n",
"Iteration 54, Loss: 1.4259566831769857\n",
"Iteration 55, Loss: 1.3416826432012259\n",
"Iteration 56, Loss: 1.2623891989880334\n",
"Iteration 57, Loss: 1.18778199732784\n",
"Iteration 58, Loss: 1.1175840812857638\n",
"Iteration 59, Loss: 1.0515348620817762\n",
"Iteration 60, Loss: 0.9893891517327436\n",
"Iteration 61, Loss: 0.930916252865338\n",
"Iteration 62, Loss: 0.8758991023209965\n",
"Iteration 63, Loss: 0.8241334653738256\n",
"Iteration 64, Loss: 0.775427177570232\n",
"Iteration 65, Loss: 0.7295994313758314\n",
"Iteration 66, Loss: 0.6864801049815188\n",
"Iteration 67, Loss: 0.6459091307771113\n",
"Iteration 68, Loss: 0.6077359011481849\n",
"Iteration 69, Loss: 0.5718187093903269\n",
"Iteration 70, Loss: 0.538024223665358\n",
"Iteration 71, Loss: 0.5062269920467352\n", "Iteration 72, Loss: 0.4763089768167732\n", "Iteration 73, Loss: 0.44815911628690125\n", "Iteration 74, Loss: 0.4216729125143454\n", "Iteration 75, Loss: 0.3967520433847474\n", "Iteration 76, Loss: 0.3733039976207088\n", "Iteration 77, Loss: 0.35124173136132447\n", "Iteration 78, Loss: 0.3304833450378703\n", "Iteration 79, Loss: 0.3109517793461324\n", "Iteration 80, Loss: 0.29257452918677557\n", "Iteration 81, Loss: 0.275283374511837\n", "Iteration 82, Loss: 0.2590141270781873\n", "Iteration 83, Loss: 0.24370639216786646\n", "Iteration 84, Loss: 0.22930334439074573\n", "Iteration 85, Loss: 0.21575151673725296\n", "Iteration 86, Loss: 0.20300060209808138\n", "Iteration 87, Loss: 0.1910032665140845\n", "Iteration 88, Loss: 0.17971497346310233\n", "Iteration 89, Loss: 0.16909381853143318\n", "Iteration 90, Loss: 0.159100373856225\n", "Iteration 91, Loss: 0.14969754176132244\n", "Iteration 92, Loss: 0.1408504170432283\n", "Iteration 93, Loss: 0.13252615739597354\n", "Iteration 94, Loss: 0.1246938614938715\n", "Iteration 95, Loss: 0.11732445427958361\n", "Iteration 96, Loss: 0.11039057903166032\n", "Iteration 97, Loss: 0.10386649581088914\n", "Iteration 98, Loss: 0.09772798590846545\n", "Iteration 99, Loss: 0.09195226194127527\n", "Iteration 100, Loss: 0.08651788326054573\n", "Iteration 101, Loss: 0.08140467635984756\n", "Iteration 102, Loss: 0.07659365998698067\n", "Iteration 103, Loss: 0.07206697468175016\n", "Iteration 104, Loss: 0.06780781647805846\n", "Iteration 105, Loss: 0.06380037452420505\n", "Iteration 106, Loss: 0.060029772389824425\n", "Iteration 107, Loss: 0.05648201284158581\n", "Iteration 108, Loss: 0.05314392588264792\n", "Iteration 109, Loss: 0.05000311986298341\n", "Iteration 110, Loss: 0.04704793547908098\n", "Iteration 111, Loss: 0.044267402492267266\n", "Iteration 112, Loss: 0.04165119900497416\n", "Iteration 113, Loss: 0.03918961314378044\n", "Iteration 114, Loss: 0.03687350700698295\n", "Iteration 115, Loss: 0.03469428274287037\n", "Iteration 116, Loss: 0.032643850632766785\n", "Iteration 117, Loss: 0.030714599060370343\n", "Iteration 118, Loss: 0.028899366255902458\n", "Iteration 119, Loss: 0.027191413710178605\n", "Iteration 120, Loss: 0.025584401159906987\n", "Iteration 121, Loss: 0.02407236305135653\n", "Iteration 122, Loss: 0.02264968639502141\n", "Iteration 123, Loss: 0.02131108992907558\n", "Iteration 124, Loss: 0.020051604514267202\n", "Iteration 125, Loss: 0.018866554687474092\n", "Iteration 126, Loss: 0.01775154130544445\n", "Iteration 127, Loss: 0.01670242521429262\n", "Iteration 128, Loss: 0.015715311884128023\n", "Iteration 129, Loss: 0.014786536951776045\n", "Iteration 130, Loss: 0.01391265261792606\n", "Iteration 131, Loss: 0.013090414848206555\n", "Iteration 132, Loss: 0.01231677133067759\n", "Iteration 133, Loss: 0.011588850145034609\n", "Iteration 134, Loss: 0.01090394910146302\n", "Iteration 135, Loss: 0.010259525709566512\n", "Iteration 136, Loss: 0.00965318774013127\n", "Iteration 137, Loss: 0.009082684344689475\n", "Iteration 138, Loss: 0.008545897699918257\n", "Iteration 139, Loss: 0.008040835145853137\n", "Iteration 140, Loss: 0.00756562178873318\n", "Iteration 141, Loss: 0.0071184935410191314\n", "Iteration 142, Loss: 0.006697790572744897\n", "Iteration 143, Loss: 0.0063019511498957235\n", "Iteration 144, Loss: 0.0059295058369368625\n", "Iteration 145, Loss: 0.005579072041973895\n", "Iteration 146, Loss: 0.005249348884293221\n", "Iteration 147, Loss: 0.004939112365231496\n", "Iteration 148, Loss: 
0.0046472108244463226\n",
"Iteration 149, Loss: 0.004372560664721515\n",
"Iteration 150, Loss: 0.004114142329436494\n",
"Iteration 151, Loss: 0.0038709965177668067\n",
"Iteration 152, Loss: 0.003642220623566796\n",
"Iteration 153, Loss: 0.003426965384714043\n",
"Iteration 154, Loss: 0.0032244317304774253\n",
"Iteration 155, Loss: 0.003033867815206219\n",
"Iteration 156, Loss: 0.0028545662273275238\n",
"Iteration 157, Loss: 0.002685861363292454\n",
"Iteration 158, Loss: 0.002527126956721865\n",
"Iteration 159, Loss: 0.0023777737535795648\n",
"Iteration 160, Loss: 0.002237247324743051\n",
"Iteration 161, Loss: 0.0021050260078507234\n",
"Iteration 162, Loss: 0.001980618970786757\n",
"Iteration 163, Loss: 0.001863564389613244\n",
"Iteration 164, Loss: 0.0017534277341871227\n",
"Iteration 165, Loss: 0.001649800155096659\n",
"Iteration 166, Loss: 0.0015522969659304577\n",
"Iteration 167, Loss: 0.0014605562152439574\n",
"Iteration 168, Loss: 0.001374237342923055\n",
"Iteration 169, Loss: 0.0012930199159562866\n",
"Iteration 170, Loss: 0.0012166024389232565\n",
"Iteration 171, Loss: 0.0011447012347829103\n",
"Iteration 172, Loss: 0.0010770493918072343\n",
"Iteration 173, Loss: 0.0010133957727514104\n",
"Iteration 174, Loss: 0.0009535040825818146\n",
"Iteration 175, Loss: 0.0008971519913012098\n",
"Iteration 176, Loss: 0.0008441303086153165\n",
"Iteration 177, Loss: 0.0007942422073761319\n",
"Iteration 178, Loss: 0.0007473024929202092\n",
"Iteration 179, Loss: 0.0007031369155886454\n",
"Iteration 180, Loss: 0.0006615815238773228\n",
"Iteration 181, Loss: 0.0006224820558161947\n",
"Iteration 182, Loss: 0.0005856933663174615\n",
"Iteration 183, Loss: 0.0005510788883681067\n",
"Iteration 184, Loss: 0.0005185101260655349\n",
"Iteration 185, Loss: 0.0004878661776150635\n",
"Iteration 186, Loss: 0.00045903328651800607\n",
"Iteration 187, Loss: 0.0004319044192847727\n",
"Iteration 188, Loss: 0.00040637886810505474\n",
"Iteration 189, Loss: 0.0003823618770000461\n",
"Iteration 190, Loss: 0.00035976429006934636\n",
"Iteration 191, Loss: 0.00033850222052625716\n",
"Iteration 192, Loss: 0.0003184967392931672\n",
"Iteration 193, Loss: 0.0002996735820009388\n",
"Iteration 194, Loss: 0.000281962873304691\n",
"Iteration 195, Loss: 0.0002652988674923804\n",
"Iteration 196, Loss: 0.0002496197044235683\n",
"Iteration 197, Loss: 0.00023486717989213552\n",
"Iteration 198, Loss: 0.00022098652956051033\n",
"Iteration 199, Loss: 0.0002079262256634926\n",
"Iteration 200, Loss: 0.00019563778572677975\n",
"Final weights: [-3.3990955 -0.20180899 0.80271349]\n",
"Final bias: 0.6009044964039992\n" ] } ], "source": [ "import numpy as np\n",
"\n",
"# Initial parameters\n",
"weights = np.array([-3.0, -1.0, 2.0])\n",
"bias = 1.0\n",
"inputs = np.array([1.0, -2.0, 3.0])\n",
"target_output = 0.0\n",
"learning_rate = 0.001\n",
"\n",
"def relu(x):\n",
"    return np.maximum(0, x)\n",
"\n",
"def relu_derivative(x):\n",
"    return np.where(x > 0, 1.0, 0.0)\n",
"\n",
"for iteration in range(200):\n",
"    # Forward pass\n",
"    linear_output = np.dot(weights, inputs) + bias\n",
"    output = relu(linear_output)\n",
"    loss = (output - target_output) ** 2\n",
"\n",
"    # Backward pass to calculate gradient\n",
"    dloss_doutput = 2 * (output - target_output)\n",
"    doutput_dlinear = relu_derivative(linear_output)\n",
"    dlinear_dweights = inputs\n",
"    dlinear_dbias = 1.0\n",
"\n",
"    dloss_dlinear = dloss_doutput * doutput_dlinear\n",
"    dloss_dweights = dloss_dlinear * dlinear_dweights\n",
"    dloss_dbias = dloss_dlinear * dlinear_dbias\n",
"\n",
"    # Update weights and bias\n",
"    weights -= learning_rate * dloss_dweights\n",
"    bias -= learning_rate * dloss_dbias\n",
"\n",
"    # Print the loss for this iteration\n",
"    print(f\"Iteration {iteration + 1}, Loss: {loss}\")\n",
"\n",
"print(\"Final weights:\", weights)\n",
"print(\"Final bias:\", bias)\n" ] }
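, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (not part of the original code), we can compare the analytical gradient from the chain rule against a numerical central-difference approximation at the final parameters. The helper name single_neuron_loss and the epsilon value are illustrative choices; the check reuses relu, relu_derivative, inputs, weights, and bias from the cell above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sanity check: analytical gradient vs. numerical (central-difference) gradient\n",
"def single_neuron_loss(w, b):\n",
"    return relu(np.dot(w, inputs) + b) ** 2\n",
"\n",
"# Analytical gradient at the final parameters, using the same chain rule as above\n",
"linear_output = np.dot(weights, inputs) + bias\n",
"analytic_grad_w = 2 * relu(linear_output) * relu_derivative(linear_output) * inputs\n",
"\n",
"# Numerical approximation: nudge each weight by a small epsilon\n",
"eps = 1e-6\n",
"numeric_grad_w = np.zeros_like(weights)\n",
"for i in range(len(weights)):\n",
"    w_plus = weights.copy()\n",
"    w_minus = weights.copy()\n",
"    w_plus[i] += eps\n",
"    w_minus[i] -= eps\n",
"    numeric_grad_w[i] = (single_neuron_loss(w_plus, bias) - single_neuron_loss(w_minus, bias)) / (2 * eps)\n",
"\n",
"print(\"Analytical:\", analytic_grad_w)\n",
"print(\"Numerical: \", numeric_grad_w)" ] }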
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }