Neural Net Vs MNIST (Part 1)

Nearly every machine learning blog inevitably does a post on the MNIST data set. I don’t want to break with tradition…so here is mine.

What is MNIST?

MNIST is a data set of handwritten digits (0-9): 60,000 training images and 10,000 test images, each paired with a label indicating which digit it represents. Each image is 28×28 pixels in gray-scale. A sample of what the data looks like can be seen below. This data set is often used to test machine learning approaches since it comes in a convenient format (normalized and already split into training and test sets). The data can be downloaded from Yann LeCun's website.
[Figure: a sample of handwritten digit images from MNIST]
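For anyone who wants to work with the raw files directly, here is a minimal loading sketch in Python. The helper name load_idx and the file names are just illustrative; the files come gzipped in the IDX format, with a big-endian header followed by raw bytes:

```python
import gzip
import struct
import numpy as np

def load_idx(path):
    # The MNIST files use the IDX format: a big-endian magic number
    # (whose last byte is the number of dimensions), one 32-bit size
    # per dimension, then the raw uint8 data.
    with gzip.open(path, "rb") as f:
        magic, = struct.unpack(">I", f.read(4))
        ndim = magic & 0xFF
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)

# Scale pixel values to [0, 1] and flatten each image to a 784-vector.
train_images = load_idx("train-images-idx3-ubyte.gz").reshape(-1, 784) / 255.0
train_labels = load_idx("train-labels-idx1-ubyte.gz")
```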

The Model

The model I am going to use for this example is extremely basic: it consists of a single layer (the output layer) of 10 perceptrons with a sigmoid activation function. Each perceptron has 784 weights, one for each pixel of the 28×28 image. Each perceptron will (hopefully) learn to recognize a specific digit, producing a large output relative to the other perceptrons when its digit is presented to the network.
[Figure: a perceptron unit. Image from Wikimedia]
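In NumPy terms, the model boils down to roughly the following. This is a condensed sketch rather than the exact code from the repo, and I've added a bias term per perceptron:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.01, size=(10, 784))  # one weight per pixel, per digit
b = np.zeros(10)                           # one bias per perceptron

def forward(X):
    # X has shape (batch, 784); the result is (batch, 10),
    # one sigmoid output per perceptron.
    return sigmoid(X @ W.T + b)
```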

Training the model

I trained the model with a batch_size of 16; the images used in training can be seen on the left side of the video. Before training on the images I had the network try to classify them, and colored them accordingly (RED = wrongly classified, GREEN = correct). Of course, this does not give an accurate indication of the network's performance, since the network is classifying the same data it is training on. On the right-hand side I have also included a graph of the network's average error over the test data set, which should give us a more accurate idea of its performance.
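The training step itself is plain stochastic gradient descent on that single layer. Here is a rough sketch, assuming a squared-error loss and one-hot targets; the learning rate is a placeholder, and the real settings are in the repo:

```python
def train_step(X, Y, lr=0.5):
    # X: (16, 784) batch of flattened images, Y: (16, 10) one-hot labels.
    global W, b
    out = forward(X)
    # Gradient of the squared error through the sigmoid
    # (the constant factor of 2 is folded into the learning rate).
    delta = (out - Y) * out * (1.0 - out)
    W -= lr * (delta.T @ X) / len(X)
    b -= lr * delta.mean(axis=0)
    return ((out - Y) ** 2).mean()  # average error, for the graph

# One pass over the training data in batches of 16:
for i in range(0, len(train_images), 16):
    X = train_images[i:i + 16]
    Y = np.eye(10)[train_labels[i:i + 16]]  # integer labels -> one-hot rows
    train_step(X, Y)
```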

As you can see, the accuracy of the network over the test set increases rapidly, reaching 80% in only a few hundred training steps. When I ran the training for longer, I saw this percentage go up to 90%. That might seem incredibly high, but keep in mind that error rates below 0.5% have been achieved on this data set with other network architectures.
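Measuring that accuracy just means picking the perceptron with the largest output for each test image and comparing it to the label, something like:

```python
def accuracy(X, y):
    # The predicted digit is the perceptron that fires hardest.
    preds = forward(X).argmax(axis=1)
    return (preds == y).mean()  # y holds the integer labels
```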

In part 2 we will run a few predictions with the model and look into what it actually learned, and why this model is not very useful despite its 90% accuracy on the test set. The code for this part and part 2 is on my GitHub. The repository also contains the code for training a convolutional neural net on the same data (I'll make a post about that later). If you did/didn't enjoy this and/or have suggestions, please leave some feedback below 🙂
