"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.33333334 0.33333334 0.33333334]\n",
" [0.33333364 0.3333334 0.3333329 ]\n",
" [0.33333385 0.3333335 0.33333266]\n",
" [0.33333433 0.3333336 0.33333206]\n",
" [0.33333462 0.33333373 0.33333164]]\n"
]
}
],
"source": [
"# Create dataset\n",
"X, y = spiral_data(samples=100, classes=3)\n",
"\n",
"# Network: Dense(2 -> 3) -> ReLU -> Dense(3 -> 3) -> Softmax\n",
"dense1 = Layer_Dense(2, 3)  # 2 input features -> 3 output values\n",
"activation1 = Activation_ReLU()\n",
"dense2 = Layer_Dense(3, 3)  # consumes the 3 ReLU outputs -> 3 output values\n",
"activation2 = Activation_Softmax()\n",
"\n",
"# Forward pass: every layer reads the previous layer's `output` attribute,\n",
"# starting from the raw training data X\n",
"layer_input = X\n",
"for layer in (dense1, activation1, dense2, activation2):\n",
"    layer.forward(layer_input)\n",
"    layer_input = layer.output\n",
"\n",
"# Softmax outputs for the first few samples\n",
"print(activation2.output[:5])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.33333334 0.33333334 0.33333334]\n",
" [0.3333341 0.33333302 0.3333329 ]\n",
" [0.3333341 0.33333302 0.33333296]\n",
" [0.3333341 0.333333 0.33333293]\n",
" [0.3333364 0.33333203 0.33333158]]\n",
"loss: 1.0986193\n",
"acc: 0.28\n"
]
}
],
"source": [
"# Create dataset\n",
"X, y = spiral_data(samples=100, classes=3)\n",
"# Create Dense layer with 2 input features and 3 output values\n",
"dense1 = Layer_Dense(2, 3)\n",
"# Create ReLU activation (to be used with Dense layer):\n",
"activation1 = Activation_ReLU()\n",
"# Create second Dense layer with 3 input features (as we take output\n",
"# of previous layer here) and 3 output values\n",
"dense2 = Layer_Dense(3, 3)\n",
"# Create Softmax activation (to be used with Dense layer):\n",
"activation2 = Activation_Softmax()\n",
"# Create loss function\n",
"loss_function = Loss_CategoricalCrossentropy()\n",
"\n",
"# Forward pass of the training data through both layers and activations\n",
"dense1.forward(X)\n",
"activation1.forward(dense1.output)\n",
"dense2.forward(activation1.output)\n",
"activation2.forward(dense2.output)\n",
"# Let's see output of the first few samples:\n",
"print(activation2.output[:5])\n",
"\n",
"# Loss of the softmax output against the targets\n",
"loss = loss_function.calculate(activation2.output, y)\n",
"print('loss:', loss)\n",
"\n",
"# Accuracy: the predicted class is the argmax across the class columns\n",
"# (axis=1), compared with the target class index of each sample\n",
"predictions = np.argmax(activation2.output, axis=1)\n",
"# Convert one-hot targets to class indices without overwriting `y`,\n",
"# so later cells still see the dataset's original targets\n",
"class_targets = np.argmax(y, axis=1) if len(y.shape) == 2 else y\n",
"accuracy = np.mean(predictions == class_targets)\n",
"print('acc:', accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New set of weights found, iteration: 0 loss: 1.099905 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 5 loss: 1.099738 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 11 loss: 1.0992013 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 13 loss: 1.0977142 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 14 loss: 1.0957412 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 16 loss: 1.0941366 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 29 loss: 1.0926114 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 35 loss: 1.0908598 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 42 loss: 1.0890985 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 44 loss: 1.0882245 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 53 loss: 1.080014 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 61 loss: 1.072066 acc: 0.47\n",
"New set of weights found, iteration: 65 loss: 1.0632849 acc: 0.3433333333333333\n",
"New set of weights found, iteration: 68 loss: 1.0544518 acc: 0.37\n",
"New set of weights found, iteration: 69 loss: 1.0531789 acc: 0.33666666666666667\n",
"New set of weights found, iteration: 70 loss: 1.0516953 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 73 loss: 1.0501534 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 75 loss: 1.0490755 acc: 0.65\n",
"New set of weights found, iteration: 78 loss: 1.0362376 acc: 0.84\n",
"New set of weights found, iteration: 79 loss: 1.0319735 acc: 0.6933333333333334\n",
"New set of weights found, iteration: 82 loss: 1.0308099 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 84 loss: 1.0245655 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 86 loss: 1.0163056 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 91 loss: 1.0100644 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 92 loss: 1.0020715 acc: 0.41333333333333333\n",
"New set of weights found, iteration: 94 loss: 0.99989504 acc: 0.4666666666666667\n",
"New set of weights found, iteration: 95 loss: 0.99057025 acc: 0.6066666666666667\n",
"New set of weights found, iteration: 96 loss: 0.9842712 acc: 0.61\n",
"New set of weights found, iteration: 98 loss: 0.98155546 acc: 0.6\n",
"New set of weights found, iteration: 99 loss: 0.9771661 acc: 0.64\n",
"New set of weights found, iteration: 102 loss: 0.9674396 acc: 0.7266666666666667\n",
"New set of weights found, iteration: 103 loss: 0.94826436 acc: 0.7966666666666666\n",
"New set of weights found, iteration: 107 loss: 0.94145477 acc: 0.8666666666666667\n",
"New set of weights found, iteration: 109 loss: 0.9377437 acc: 0.73\n",
"New set of weights found, iteration: 110 loss: 0.91910625 acc: 0.6433333333333333\n",
"New set of weights found, iteration: 112 loss: 0.9161494 acc: 0.6466666666666666\n",
"New set of weights found, iteration: 114 loss: 0.91611814 acc: 0.6333333333333333\n",
"New set of weights found, iteration: 115 loss: 0.9146271 acc: 0.6\n",
"New set of weights found, iteration: 117 loss: 0.9106173 acc: 0.6333333333333333\n",
"New set of weights found, iteration: 118 loss: 0.9050189 acc: 0.65\n",
"New set of weights found, iteration: 119 loss: 0.89243126 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 123 loss: 0.8768594 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 127 loss: 0.8671168 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 128 loss: 0.86372316 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 132 loss: 0.84759533 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 134 loss: 0.8325577 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 138 loss: 0.8243833 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 140 loss: 0.8126023 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 141 loss: 0.81067485 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 144 loss: 0.8094295 acc: 0.6666666666666666\n",
"New set of weights found, iteration: 146 loss: 0.79853076 acc: 0.67\n",
"New set of weights found, iteration: 149 loss: 0.79235256 acc: 0.71\n",
"New set of weights found, iteration: 150 loss: 0.78341687 acc: 0.77\n",
"New set of weights found, iteration: 151 loss: 0.76394886 acc: 0.8033333333333333\n",
"New set of weights found, iteration: 152 loss: 0.7592695 acc: 0.74\n",
"New set of weights found, iteration: 158 loss: 0.75623393 acc: 0.7166666666666667\n",
"New set of weights found, iteration: 161 loss: 0.7518161 acc: 0.81\n",
"New set of weights found, iteration: 163 loss: 0.74768335 acc: 0.88\n",
"New set of weights found, iteration: 165 loss: 0.7471585 acc: 0.8066666666666666\n",
"New set of weights found, iteration: 166 loss: 0.7458357 acc: 0.8066666666666666\n",
"New set of weights found, iteration: 167 loss: 0.7457454 acc: 0.8566666666666667\n",
"New set of weights found, iteration: 175 loss: 0.73652244 acc: 0.8866666666666667\n",
"New set of weights found, iteration: 180 loss: 0.7326938 acc: 0.7433333333333333\n",
"New set of weights found, iteration: 186 loss: 0.7188873 acc: 0.79\n",
"New set of weights found, iteration: 196 loss: 0.7007391 acc: 0.8366666666666667\n",
"New set of weights found, iteration: 197 loss: 0.69283545 acc: 0.8033333333333333\n",
"New set of weights found, iteration: 198 loss: 0.6778049 acc: 0.8433333333333334\n",
"New set of weights found, iteration: 204 loss: 0.6732369 acc: 0.78\n",
"New set of weights found, iteration: 206 loss: 0.6590504 acc: 0.8466666666666667\n",
"New set of weights found, iteration: 208 loss: 0.64687824 acc: 0.8733333333333333\n",
"New set of weights found, iteration: 210 loss: 0.64197236 acc: 0.8266666666666667\n",
"New set of weights found, iteration: 212 loss: 0.6315755 acc: 0.8566666666666667\n",
"New set of weights found, iteration: 216 loss: 0.62233186 acc: 0.8433333333333334\n",
"New set of weights found, iteration: 217 loss: 0.6089423 acc: 0.87\n",
"New set of weights found, iteration: 220 loss: 0.6020286 acc: 0.89\n",
"New set of weights found, iteration: 229 loss: 0.59444267 acc: 0.8733333333333333\n",
"New set of weights found, iteration: 230 loss: 0.57657397 acc: 0.8633333333333333\n",
"New set of weights found, iteration: 234 loss: 0.56438416 acc: 0.8766666666666667\n",
"New set of weights found, iteration: 238 loss: 0.55371267 acc: 0.8733333333333333\n",
"New set of weights found, iteration: 239 loss: 0.55123174 acc: 0.87\n",
"New set of weights found, iteration: 241 loss: 0.5472712 acc: 0.9066666666666666\n",
"New set of weights found, iteration: 244 loss: 0.5312096 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 245 loss: 0.52256846 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 248 loss: 0.51849365 acc: 0.9033333333333333\n",
"New set of weights found, iteration: 253 loss: 0.5171079 acc: 0.9033333333333333\n",
"New set of weights found, iteration: 254 loss: 0.5170107 acc: 0.8766666666666667\n",
"New set of weights found, iteration: 258 loss: 0.51433927 acc: 0.9033333333333333\n",
"New set of weights found, iteration: 263 loss: 0.51368594 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 267 loss: 0.50735676 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 270 loss: 0.49404424 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 276 loss: 0.48898834 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 281 loss: 0.48475555 acc: 0.8733333333333333\n",
"New set of weights found, iteration: 283 loss: 0.482244 acc: 0.9066666666666666\n",
"New set of weights found, iteration: 284 loss: 0.46764347 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 288 loss: 0.45943847 acc: 0.93\n",
"New set of weights found, iteration: 293 loss: 0.45457697 acc: 0.9033333333333333\n",
"New set of weights found, iteration: 298 loss: 0.45358613 acc: 0.9033333333333333\n",
"New set of weights found, iteration: 301 loss: 0.4479193 acc: 0.92\n",
"New set of weights found, iteration: 302 loss: 0.44754446 acc: 0.9033333333333333\n",
"New set of weights found, iteration: 305 loss: 0.4435407 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 308 loss: 0.439522 acc: 0.92\n",
"New set of weights found, iteration: 310 loss: 0.4312136 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 311 loss: 0.42785105 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 312 loss: 0.42762664 acc: 0.91\n",
"New set of weights found, iteration: 314 loss: 0.4265803 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 319 loss: 0.4237134 acc: 0.91\n",
"New set of weights found, iteration: 321 loss: 0.41802156 acc: 0.9133333333333333\n",
"New set of weights found, iteration: 324 loss: 0.4131552 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 325 loss: 0.4108623 acc: 0.91\n",
"New set of weights found, iteration: 328 loss: 0.41035053 acc: 0.9066666666666666\n",
"New set of weights found, iteration: 335 loss: 0.4067101 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 336 loss: 0.3959319 acc: 0.9066666666666666\n",
"New set of weights found, iteration: 337 loss: 0.3919371 acc: 0.9066666666666666\n",
"New set of weights found, iteration: 338 loss: 0.38703263 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 339 loss: 0.38482046 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 346 loss: 0.37982863 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 347 loss: 0.37918502 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 348 loss: 0.37770292 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 362 loss: 0.3769898 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 364 loss: 0.37081122 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 366 loss: 0.36963394 acc: 0.92\n",
"New set of weights found, iteration: 373 loss: 0.36226436 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 376 loss: 0.35897696 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 381 loss: 0.35690197 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 382 loss: 0.35525045 acc: 0.93\n",
"New set of weights found, iteration: 386 loss: 0.35105768 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 392 loss: 0.34419277 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 394 loss: 0.3400308 acc: 0.9133333333333333\n",
"New set of weights found, iteration: 402 loss: 0.33625075 acc: 0.9\n",
"New set of weights found, iteration: 404 loss: 0.32828385 acc: 0.9133333333333333\n",
"New set of weights found, iteration: 407 loss: 0.32826573 acc: 0.9133333333333333\n",
"New set of weights found, iteration: 417 loss: 0.32695287 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 419 loss: 0.32475516 acc: 0.92\n",
"New set of weights found, iteration: 423 loss: 0.3235583 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 434 loss: 0.31984544 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 437 loss: 0.31507537 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 441 loss: 0.31470028 acc: 0.92\n",
"New set of weights found, iteration: 442 loss: 0.30537918 acc: 0.92\n",
"New set of weights found, iteration: 443 loss: 0.30088764 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 444 loss: 0.29980597 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 445 loss: 0.29180583 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 446 loss: 0.28543833 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 448 loss: 0.27976722 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 450 loss: 0.27820233 acc: 0.93\n",
"New set of weights found, iteration: 453 loss: 0.26993147 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 461 loss: 0.2693614 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 465 loss: 0.26654485 acc: 0.93\n",
"New set of weights found, iteration: 467 loss: 0.26303545 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 469 loss: 0.263004 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 474 loss: 0.26084086 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 480 loss: 0.2575986 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 488 loss: 0.25060445 acc: 0.93\n",
"New set of weights found, iteration: 492 loss: 0.25023142 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 493 loss: 0.24514885 acc: 0.93\n",
"New set of weights found, iteration: 496 loss: 0.24100976 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 510 loss: 0.24016263 acc: 0.93\n",
"New set of weights found, iteration: 516 loss: 0.23990427 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 517 loss: 0.23756668 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 527 loss: 0.23614492 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 529 loss: 0.23547292 acc: 0.93\n",
"New set of weights found, iteration: 530 loss: 0.23477767 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 531 loss: 0.2303279 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 533 loss: 0.22919251 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 543 loss: 0.22844787 acc: 0.94\n",
"New set of weights found, iteration: 545 loss: 0.22844426 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 547 loss: 0.22694202 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 557 loss: 0.22352041 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 563 loss: 0.22272877 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 566 loss: 0.22242208 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 567 loss: 0.22208746 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 568 loss: 0.22132492 acc: 0.93\n",
"New set of weights found, iteration: 600 loss: 0.21976604 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 608 loss: 0.21949868 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 615 loss: 0.21724503 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 618 loss: 0.21679601 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 622 loss: 0.21339032 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 626 loss: 0.21238428 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 629 loss: 0.2072347 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 630 loss: 0.20694281 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 632 loss: 0.20632517 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 638 loss: 0.20593578 acc: 0.9233333333333333\n",
"New set of weights found, iteration: 644 loss: 0.20315775 acc: 0.94\n",
"New set of weights found, iteration: 645 loss: 0.2030436 acc: 0.94\n",
"New set of weights found, iteration: 649 loss: 0.20148328 acc: 0.94\n",
"New set of weights found, iteration: 652 loss: 0.20143524 acc: 0.93\n",
"New set of weights found, iteration: 655 loss: 0.1997256 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 663 loss: 0.19749916 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 674 loss: 0.19700658 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 677 loss: 0.19635075 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 683 loss: 0.19607982 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 691 loss: 0.19607717 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 705 loss: 0.1945014 acc: 0.9166666666666666\n",
"New set of weights found, iteration: 711 loss: 0.19085401 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 718 loss: 0.190301 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 723 loss: 0.18917027 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 725 loss: 0.18868148 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 727 loss: 0.18696322 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 735 loss: 0.18605982 acc: 0.93\n",
"New set of weights found, iteration: 753 loss: 0.18571393 acc: 0.93\n",
"New set of weights found, iteration: 754 loss: 0.18480046 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 756 loss: 0.18298945 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 775 loss: 0.18280008 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 780 loss: 0.18264478 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 838 loss: 0.18199474 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 839 loss: 0.1819026 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 870 loss: 0.18085685 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 879 loss: 0.18026078 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 880 loss: 0.17951058 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 881 loss: 0.17820631 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 894 loss: 0.17801896 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 897 loss: 0.17786425 acc: 0.93\n",
"New set of weights found, iteration: 898 loss: 0.17726201 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 901 loss: 0.17575936 acc: 0.9266666666666666\n",
"New set of weights found, iteration: 911 loss: 0.17463076 acc: 0.93\n",
"New set of weights found, iteration: 917 loss: 0.1736356 acc: 0.93\n",
"New set of weights found, iteration: 934 loss: 0.17291996 acc: 0.93\n",
"New set of weights found, iteration: 935 loss: 0.17224732 acc: 0.93\n",
"New set of weights found, iteration: 947 loss: 0.17191714 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 967 loss: 0.17177528 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 984 loss: 0.17169482 acc: 0.93\n",
"New set of weights found, iteration: 992 loss: 0.17093095 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 996 loss: 0.17091219 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1011 loss: 0.17054932 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1012 loss: 0.17051615 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1014 loss: 0.17050186 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1016 loss: 0.17006627 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1032 loss: 0.17000513 acc: 0.93\n",
"New set of weights found, iteration: 1036 loss: 0.16972722 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1042 loss: 0.16963226 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1067 loss: 0.16902782 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1100 loss: 0.16885021 acc: 0.93\n",
"New set of weights found, iteration: 1125 loss: 0.16875145 acc: 0.94\n",
"New set of weights found, iteration: 1131 loss: 0.16835134 acc: 0.93\n",
"New set of weights found, iteration: 1136 loss: 0.16822483 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1155 loss: 0.16820814 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1165 loss: 0.16806214 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1177 loss: 0.16759826 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1196 loss: 0.16752647 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1221 loss: 0.1674471 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1292 loss: 0.16733679 acc: 0.93\n",
"New set of weights found, iteration: 1349 loss: 0.1672686 acc: 0.93\n",
"New set of weights found, iteration: 1374 loss: 0.16720611 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1524 loss: 0.16713503 acc: 0.93\n",
"New set of weights found, iteration: 1554 loss: 0.16697857 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1566 loss: 0.1667167 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1611 loss: 0.16653392 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1614 loss: 0.16648601 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1623 loss: 0.16623822 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1827 loss: 0.16623314 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1840 loss: 0.16610965 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1844 loss: 0.16608463 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1853 loss: 0.1660742 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1909 loss: 0.16592857 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 1951 loss: 0.16568528 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1959 loss: 0.16554318 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 1989 loss: 0.16552317 acc: 0.93\n",
"New set of weights found, iteration: 1996 loss: 0.16528629 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 2028 loss: 0.16519576 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2043 loss: 0.16503423 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 2062 loss: 0.16495857 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 2072 loss: 0.16491874 acc: 0.93\n",
"New set of weights found, iteration: 2085 loss: 0.16468659 acc: 0.94\n",
"New set of weights found, iteration: 2097 loss: 0.16459472 acc: 0.94\n",
"New set of weights found, iteration: 2099 loss: 0.16432194 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2168 loss: 0.16422528 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 2186 loss: 0.16416107 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 2191 loss: 0.16411497 acc: 0.9333333333333333\n",
"New set of weights found, iteration: 2218 loss: 0.16404505 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2276 loss: 0.16398491 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2303 loss: 0.16397893 acc: 0.94\n",
"New set of weights found, iteration: 2308 loss: 0.16390242 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 2320 loss: 0.1638276 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2322 loss: 0.16374306 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2330 loss: 0.1633055 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2389 loss: 0.16307148 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2419 loss: 0.16291772 acc: 0.94\n",
"New set of weights found, iteration: 2500 loss: 0.1628382 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2520 loss: 0.16280368 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2835 loss: 0.16272317 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2893 loss: 0.16264299 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 2929 loss: 0.16250354 acc: 0.94\n",
"New set of weights found, iteration: 3028 loss: 0.16240475 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 3100 loss: 0.16229728 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 3162 loss: 0.16214713 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 3326 loss: 0.16197507 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 3368 loss: 0.16193718 acc: 0.9466666666666667\n",
"New set of weights found, iteration: 3404 loss: 0.16185224 acc: 0.9366666666666666\n",
"New set of weights found, iteration: 3441 loss: 0.16182107 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 3585 loss: 0.16174312 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 3655 loss: 0.16167627 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 5384 loss: 0.16161478 acc: 0.94\n",
"New set of weights found, iteration: 5428 loss: 0.16156833 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 5463 loss: 0.16155337 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 5504 loss: 0.16146044 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 5861 loss: 0.16145106 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 6448 loss: 0.16144219 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 6716 loss: 0.16144097 acc: 0.9466666666666667\n",
"New set of weights found, iteration: 6848 loss: 0.16140525 acc: 0.9466666666666667\n",
"New set of weights found, iteration: 7269 loss: 0.1614049 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 7333 loss: 0.16136383 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 7362 loss: 0.16134067 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 7661 loss: 0.16133144 acc: 0.9466666666666667\n",
"New set of weights found, iteration: 7968 loss: 0.16128716 acc: 0.9433333333333334\n",
"New set of weights found, iteration: 8526 loss: 0.1612735 acc: 0.9433333333333334\n"
]
}
],
"source": [
"# Create dataset\n",
"X, y = vertical_data(samples=100, classes=3)\n",
"# Create model\n",
"dense1 = Layer_Dense(2, 3)  # first dense layer, 2 inputs\n",
"activation1 = Activation_ReLU()\n",
"dense2 = Layer_Dense(3, 3)  # second dense layer, 3 inputs, 3 outputs\n",
"activation2 = Activation_Softmax()\n",
"# Create loss function\n",
"loss_function = Loss_CategoricalCrossentropy()\n",
"\n",
"# Search settings\n",
"n_iterations = 10000\n",
"step_size = 0.05  # scale of the random perturbation applied each iteration\n",
"# Accuracy compares class indices; convert one-hot targets once, up front,\n",
"# without overwriting `y` (consistent with the loss/accuracy cell)\n",
"class_targets = np.argmax(y, axis=1) if len(y.shape) == 2 else y\n",
"\n",
"# Best parameters seen so far; any finite loss beats the initial value\n",
"lowest_loss = float('inf')\n",
"best_dense1_weights = dense1.weights.copy()\n",
"best_dense1_biases = dense1.biases.copy()\n",
"best_dense2_weights = dense2.weights.copy()\n",
"best_dense2_biases = dense2.biases.copy()\n",
"for iteration in range(n_iterations):\n",
"    # Randomly perturb the current (best) parameters; the noise shapes\n",
"    # follow the layers, so resizing a layer needs no edit here\n",
"    dense1.weights += step_size * np.random.randn(*dense1.weights.shape)\n",
"    dense1.biases += step_size * np.random.randn(*dense1.biases.shape)\n",
"    dense2.weights += step_size * np.random.randn(*dense2.weights.shape)\n",
"    dense2.biases += step_size * np.random.randn(*dense2.biases.shape)\n",
"    # Forward pass of the training data through the whole model\n",
"    dense1.forward(X)\n",
"    activation1.forward(dense1.output)\n",
"    dense2.forward(activation1.output)\n",
"    activation2.forward(dense2.output)\n",
"    # Loss of the softmax output against the targets\n",
"    loss = loss_function.calculate(activation2.output, y)\n",
"    # Accuracy: predicted class is the argmax across the class columns (axis=1)\n",
"    predictions = np.argmax(activation2.output, axis=1)\n",
"    accuracy = np.mean(predictions == class_targets)\n",
"    # If loss is smaller - print and save weights and biases aside\n",
"    if loss < lowest_loss:\n",
"        print('New set of weights found, iteration:', iteration,'loss:', loss, 'acc:', accuracy)\n",
"        best_dense1_weights = dense1.weights.copy()\n",
"        best_dense1_biases = dense1.biases.copy()\n",
"        best_dense2_weights = dense2.weights.copy()\n",
"        best_dense2_biases = dense2.biases.copy()\n",
"        lowest_loss = loss\n",
"    else:\n",
"        # No improvement: revert to the best weights and biases found so far\n",
"        dense1.weights = best_dense1_weights.copy()\n",
"        dense1.biases = best_dense1_biases.copy()\n",
"        dense2.weights = best_dense2_weights.copy()\n",
"        dense2.biases = best_dense2_biases.copy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
\n",
"STRATEGY 2 APPLIED TO THE SPIRAL DATASET - DOES NOT WORK!\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New set of weights found, iteration: 0 loss: 1.1005237 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 4 loss: 1.099272 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 7 loss: 1.0988917 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 12 loss: 1.0981287 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 15 loss: 1.0980351 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 26 loss: 1.0977299 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 30 loss: 1.0965452 acc: 0.39\n",
"New set of weights found, iteration: 33 loss: 1.0960287 acc: 0.38333333333333336\n",
"New set of weights found, iteration: 35 loss: 1.0959212 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 40 loss: 1.0954708 acc: 0.4066666666666667\n",
"New set of weights found, iteration: 42 loss: 1.0949016 acc: 0.35\n",
"New set of weights found, iteration: 45 loss: 1.0948775 acc: 0.43\n",
"New set of weights found, iteration: 51 loss: 1.0944252 acc: 0.3333333333333333\n",
"New set of weights found, iteration: 56 loss: 1.0941274 acc: 0.39666666666666667\n",
"New set of weights found, iteration: 61 loss: 1.0938694 acc: 0.38333333333333336\n",
"New set of weights found, iteration: 62 loss: 1.0938368 acc: 0.33\n",
"New set of weights found, iteration: 68 loss: 1.0937221 acc: 0.35333333333333333\n",
"New set of weights found, iteration: 69 loss: 1.0926502 acc: 0.36\n",
"New set of weights found, iteration: 70 loss: 1.0913934 acc: 0.36\n",
"New set of weights found, iteration: 71 loss: 1.0903853 acc: 0.38666666666666666\n",
"New set of weights found, iteration: 72 loss: 1.0899442 acc: 0.38\n",
"New set of weights found, iteration: 75 loss: 1.0879555 acc: 0.36666666666666664\n",
"New set of weights found, iteration: 76 loss: 1.0864727 acc: 0.4166666666666667\n",
"New set of weights found, iteration: 80 loss: 1.084729 acc: 0.4\n",
"New set of weights found, iteration: 98 loss: 1.0840615 acc: 0.3933333333333333\n",
"New set of weights found, iteration: 107 loss: 1.0840529 acc: 0.37333333333333335\n",
"New set of weights found, iteration: 112 loss: 1.0840163 acc: 0.35333333333333333\n",
"New set of weights found, iteration: 113 loss: 1.081911 acc: 0.3933333333333333\n",
"New set of weights found, iteration: 118 loss: 1.0809788 acc: 0.4066666666666667\n",
"New set of weights found, iteration: 129 loss: 1.0806961 acc: 0.42\n",
"New set of weights found, iteration: 133 loss: 1.0801857 acc: 0.39666666666666667\n",
"New set of weights found, iteration: 135 loss: 1.0792507 acc: 0.38666666666666666\n",
"New set of weights found, iteration: 159 loss: 1.0787331 acc: 0.4\n",
"New set of weights found, iteration: 199 loss: 1.0785424 acc: 0.41\n",
"New set of weights found, iteration: 219 loss: 1.0783063 acc: 0.4033333333333333\n",
"New set of weights found, iteration: 240 loss: 1.0780586 acc: 0.4266666666666667\n",
"New set of weights found, iteration: 251 loss: 1.077764 acc: 0.41333333333333333\n",
"New set of weights found, iteration: 254 loss: 1.0777346 acc: 0.4\n",
"New set of weights found, iteration: 255 loss: 1.0771401 acc: 0.42\n",
"New set of weights found, iteration: 258 loss: 1.0770148 acc: 0.42333333333333334\n",
"New set of weights found, iteration: 259 loss: 1.0763383 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 261 loss: 1.0741296 acc: 0.4266666666666667\n",
"New set of weights found, iteration: 270 loss: 1.0740651 acc: 0.45\n",
"New set of weights found, iteration: 290 loss: 1.0738546 acc: 0.42333333333333334\n",
"New set of weights found, iteration: 312 loss: 1.0727702 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 316 loss: 1.072282 acc: 0.39666666666666667\n",
"New set of weights found, iteration: 317 loss: 1.0719106 acc: 0.4033333333333333\n",
"New set of weights found, iteration: 322 loss: 1.0713525 acc: 0.43666666666666665\n",
"New set of weights found, iteration: 350 loss: 1.071252 acc: 0.43666666666666665\n",
"New set of weights found, iteration: 362 loss: 1.0706744 acc: 0.44\n",
"New set of weights found, iteration: 369 loss: 1.0706185 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 408 loss: 1.0704075 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 420 loss: 1.0703163 acc: 0.4666666666666667\n",
"New set of weights found, iteration: 504 loss: 1.0700997 acc: 0.46\n",
"New set of weights found, iteration: 547 loss: 1.069801 acc: 0.45\n",
"New set of weights found, iteration: 849 loss: 1.0697782 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 1100 loss: 1.0697608 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 2395 loss: 1.069701 acc: 0.43666666666666665\n",
"New set of weights found, iteration: 2432 loss: 1.0696893 acc: 0.43\n",
"New set of weights found, iteration: 2839 loss: 1.0696803 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 2848 loss: 1.0695381 acc: 0.45\n",
"New set of weights found, iteration: 2871 loss: 1.0695071 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 3036 loss: 1.0694591 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 3505 loss: 1.0694478 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 3551 loss: 1.0693843 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 3583 loss: 1.0689726 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 4425 loss: 1.0688362 acc: 0.44\n",
"New set of weights found, iteration: 4458 loss: 1.0686082 acc: 0.4266666666666667\n",
"New set of weights found, iteration: 4464 loss: 1.0682111 acc: 0.45\n",
"New set of weights found, iteration: 4466 loss: 1.0681838 acc: 0.45\n",
"New set of weights found, iteration: 4468 loss: 1.0680927 acc: 0.44\n",
"New set of weights found, iteration: 4490 loss: 1.0676068 acc: 0.46\n",
"New set of weights found, iteration: 4567 loss: 1.0675108 acc: 0.44\n",
"New set of weights found, iteration: 4585 loss: 1.0671601 acc: 0.45\n",
"New set of weights found, iteration: 4608 loss: 1.0668906 acc: 0.46\n",
"New set of weights found, iteration: 4645 loss: 1.0667204 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 4657 loss: 1.0661545 acc: 0.44\n",
"New set of weights found, iteration: 4728 loss: 1.0658569 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 4736 loss: 1.0651729 acc: 0.43\n",
"New set of weights found, iteration: 4804 loss: 1.0649427 acc: 0.43\n",
"New set of weights found, iteration: 4805 loss: 1.0647476 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 4807 loss: 1.0638916 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 4832 loss: 1.0638343 acc: 0.45\n",
"New set of weights found, iteration: 4975 loss: 1.0636576 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 4981 loss: 1.0631315 acc: 0.45\n",
"New set of weights found, iteration: 4994 loss: 1.0625263 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 5086 loss: 1.0624654 acc: 0.46\n",
"New set of weights found, iteration: 5108 loss: 1.06243 acc: 0.4266666666666667\n",
"New set of weights found, iteration: 5116 loss: 1.062381 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 5129 loss: 1.062371 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 5135 loss: 1.0610354 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 5141 loss: 1.0599765 acc: 0.43\n",
"New set of weights found, iteration: 5170 loss: 1.0597756 acc: 0.43666666666666665\n",
"New set of weights found, iteration: 5176 loss: 1.0596782 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 5200 loss: 1.0585961 acc: 0.45\n",
"New set of weights found, iteration: 5443 loss: 1.0584674 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 5447 loss: 1.0584322 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 5456 loss: 1.0581532 acc: 0.43\n",
"New set of weights found, iteration: 5507 loss: 1.0578263 acc: 0.44\n",
"New set of weights found, iteration: 5509 loss: 1.0575186 acc: 0.44\n",
"New set of weights found, iteration: 5551 loss: 1.0564784 acc: 0.43666666666666665\n",
"New set of weights found, iteration: 5562 loss: 1.0564457 acc: 0.4666666666666667\n",
"New set of weights found, iteration: 5618 loss: 1.0563129 acc: 0.46\n",
"New set of weights found, iteration: 5650 loss: 1.056275 acc: 0.43666666666666665\n",
"New set of weights found, iteration: 5697 loss: 1.0562539 acc: 0.43\n",
"New set of weights found, iteration: 5704 loss: 1.0562011 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 5715 loss: 1.0553932 acc: 0.4666666666666667\n",
"New set of weights found, iteration: 5727 loss: 1.0553335 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 5743 loss: 1.05487 acc: 0.4266666666666667\n",
"New set of weights found, iteration: 5824 loss: 1.0542839 acc: 0.45\n",
"New set of weights found, iteration: 5828 loss: 1.054109 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 5831 loss: 1.0539638 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 5868 loss: 1.0535976 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 6097 loss: 1.053527 acc: 0.44\n",
"New set of weights found, iteration: 6269 loss: 1.0534726 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 6382 loss: 1.0534171 acc: 0.4166666666666667\n",
"New set of weights found, iteration: 6383 loss: 1.0532491 acc: 0.41\n",
"New set of weights found, iteration: 6437 loss: 1.0530741 acc: 0.4033333333333333\n",
"New set of weights found, iteration: 6448 loss: 1.0529624 acc: 0.41333333333333333\n",
"New set of weights found, iteration: 6489 loss: 1.0526949 acc: 0.42333333333333334\n",
"New set of weights found, iteration: 6509 loss: 1.0526203 acc: 0.4633333333333333\n",
"New set of weights found, iteration: 6510 loss: 1.0524963 acc: 0.44333333333333336\n",
"New set of weights found, iteration: 6532 loss: 1.0524482 acc: 0.43\n",
"New set of weights found, iteration: 6544 loss: 1.0522528 acc: 0.44\n",
"New set of weights found, iteration: 6632 loss: 1.052054 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 6674 loss: 1.0520425 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 6719 loss: 1.0514716 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 6786 loss: 1.0514234 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 6834 loss: 1.0511931 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 6989 loss: 1.0511316 acc: 0.46\n",
"New set of weights found, iteration: 7071 loss: 1.0510049 acc: 0.4666666666666667\n",
"New set of weights found, iteration: 7213 loss: 1.0507108 acc: 0.4533333333333333\n",
"New set of weights found, iteration: 7618 loss: 1.0503324 acc: 0.46\n",
"New set of weights found, iteration: 7782 loss: 1.0500554 acc: 0.4633333333333333\n",
"New set of weights found, iteration: 8259 loss: 1.049952 acc: 0.45666666666666667\n",
"New set of weights found, iteration: 8287 loss: 1.0498099 acc: 0.42333333333333334\n",
"New set of weights found, iteration: 8343 loss: 1.0492009 acc: 0.44666666666666666\n",
"New set of weights found, iteration: 8352 loss: 1.0491385 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 8408 loss: 1.0489689 acc: 0.41333333333333333\n",
"New set of weights found, iteration: 8431 loss: 1.0489289 acc: 0.42333333333333334\n",
"New set of weights found, iteration: 8573 loss: 1.0487736 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 8704 loss: 1.0485046 acc: 0.4266666666666667\n",
"New set of weights found, iteration: 8753 loss: 1.0484601 acc: 0.4166666666666667\n",
"New set of weights found, iteration: 8999 loss: 1.0484053 acc: 0.4\n",
"New set of weights found, iteration: 9121 loss: 1.0479419 acc: 0.39666666666666667\n",
"New set of weights found, iteration: 9329 loss: 1.0479223 acc: 0.43333333333333335\n",
"New set of weights found, iteration: 9330 loss: 1.0475453 acc: 0.42\n",
"New set of weights found, iteration: 9395 loss: 1.0470929 acc: 0.43\n",
"New set of weights found, iteration: 9490 loss: 1.0470548 acc: 0.4166666666666667\n",
"New set of weights found, iteration: 9607 loss: 1.046979 acc: 0.4166666666666667\n",
"New set of weights found, iteration: 9821 loss: 1.0469537 acc: 0.4\n",
"New set of weights found, iteration: 9835 loss: 1.0469226 acc: 0.41333333333333333\n",
"New set of weights found, iteration: 9858 loss: 1.0468988 acc: 0.41\n",
"New set of weights found, iteration: 9867 loss: 1.0468199 acc: 0.41333333333333333\n",
"New set of weights found, iteration: 9877 loss: 1.0466912 acc: 0.39666666666666667\n",
"New set of weights found, iteration: 9892 loss: 1.0462589 acc: 0.3933333333333333\n",
"New set of weights found, iteration: 9908 loss: 1.0462073 acc: 0.4033333333333333\n"
]
}
],
"source": [
"# Create dataset\n",
"X, y = spiral_data(samples=100, classes=3)\n",
"# Create model\n",
"dense1 = Layer_Dense(2, 3)  # first dense layer, 2 inputs, 3 outputs\n",
"activation1 = Activation_ReLU()\n",
"dense2 = Layer_Dense(3, 3)  # second dense layer, 3 inputs, 3 outputs\n",
"activation2 = Activation_Softmax()\n",
"# Create loss function\n",
"loss_function = Loss_CategoricalCrossentropy()\n",
"# Helper variables: the lowest loss seen so far (start at +inf so any finite\n",
"# loss counts as an improvement) and a copy of the parameters that produced it\n",
"lowest_loss = np.inf\n",
"best_dense1_weights = dense1.weights.copy()\n",
"best_dense1_biases = dense1.biases.copy()\n",
"best_dense2_weights = dense2.weights.copy()\n",
"best_dense2_biases = dense2.biases.copy()\n",
"for iteration in range(10000):\n",
"    # Nudge every parameter with small Gaussian noise (scale 0.05); the noise\n",
"    # shape is taken from each parameter so it stays correct if layer sizes change\n",
"    dense1.weights += 0.05 * np.random.randn(*dense1.weights.shape)\n",
"    dense1.biases += 0.05 * np.random.randn(*dense1.biases.shape)\n",
"    dense2.weights += 0.05 * np.random.randn(*dense2.weights.shape)\n",
"    dense2.biases += 0.05 * np.random.randn(*dense2.biases.shape)\n",
"    # Forward pass of the training data through the whole model\n",
"    dense1.forward(X)\n",
"    activation1.forward(dense1.output)\n",
"    dense2.forward(activation1.output)\n",
"    activation2.forward(dense2.output)\n",
"    # Loss of the softmax output against the targets\n",
"    loss = loss_function.calculate(activation2.output, y)\n",
"    # Accuracy: the predicted class is the index of the largest softmax\n",
"    # output in each row (argmax over axis=1)\n",
"    predictions = np.argmax(activation2.output, axis=1)\n",
"    accuracy = np.mean(predictions == y)\n",
"    # Keep the change only if it lowered the loss: print progress and\n",
"    # save the weights and biases aside\n",
"    if loss < lowest_loss:\n",
"        print('New set of weights found, iteration:', iteration, 'loss:', loss, 'acc:', accuracy)\n",
"        best_dense1_weights = dense1.weights.copy()\n",
"        best_dense1_biases = dense1.biases.copy()\n",
"        best_dense2_weights = dense2.weights.copy()\n",
"        best_dense2_biases = dense2.biases.copy()\n",
"        lowest_loss = loss\n",
"    # Otherwise revert to the best weights and biases found so far\n",
"    else:\n",
"        dense1.weights = best_dense1_weights.copy()\n",
"        dense1.biases = best_dense1_biases.copy()\n",
"        dense2.weights = best_dense2_weights.copy()\n",
"        dense2.biases = best_dense2_biases.copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}