Exploring image classification with Fashion-MNIST

A fun little benchmarking exercise

January 10, 2021

Over the Christmas break I’ve spent a little time creating a GitHub repo for image classification benchmarking. I’m using the Fashion-MNIST dataset for training and testing. It’s a little more challenging than the traditional handwriting MNIST dataset but still manageable enough that I can train the models in a reasonable amount of time on my laptop. I was largely inspired by a developer friend of mine who has a stable of programming problems he tackles in new languages. As of the time of writing, I’ve implemented three Convolutional Neural Networks (CNNs) and one Recurrent Neural Network (RNN) using TensorFlow. Here’s a brief overview of the data and models.

Fashion-MNIST dataset

The Fashion-MNIST dataset is a collection of small (28 x 28 resolution) greyscale images of ten different types of clothing. The collection is divided into 60,000 training images and 10,000 testing images.

Example for each class from the Fashion-MNIST dataset

Convolutional Neural Networks

CNNs are deep learning models widely used in image recognition and computer vision tasks. They perform excellently on data with a spatial, grid-like structure, which includes images and time-series data. The “convolution” in CNN refers to the mathematical operation of the same name. CNNs consist of multiple convolutional and pooling layers followed by fully-connected layers. A terminal fully-connected layer with a softmax activation function is required for classification.

CNN architecture showing multiple convolution and pooling layers followed by fully connected layers. Géron (2017)

CNNs have a history dating back to the 1950s, although modern, backpropagation-trained CNNs are generally attributed to the work of Yann LeCun et al in the 1980s and 1990s. I implemented the following CNNs using TensorFlow:

  1. LeNet-5: developed by Yann LeCun et al in 1998, this is a fairly small and simple CNN.
  2. AlexNet-11: developed by Alex Krizhevsky et al in 2012, this was the first CNN to stack convolutional layers directly on top of one another.
  3. ResNet-34: developed by Kaiming He et al in 2015, this CNN popularised the use of skip connections.
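
As a concrete illustration of the conv-pool-dense pattern described above, here’s a minimal Keras sketch of a LeNet-5-style model for 28 x 28 greyscale input. The layer sizes follow the 1998 paper, but the `padding="same"` on the first convolution is my assumption to make the dimensions work for 28 x 28 input rather than LeNet’s original 32 x 32; it’s a sketch, not necessarily what my repo does.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# LeNet-5-style CNN: two convolution + pooling stages, then fully-connected layers,
# ending in the terminal softmax layer for the ten Fashion-MNIST classes.
lenet = tf.keras.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(6, kernel_size=5, padding="same", activation="tanh"),
    layers.AveragePooling2D(pool_size=2),
    layers.Conv2D(16, kernel_size=5, activation="tanh"),
    layers.AveragePooling2D(pool_size=2),
    layers.Flatten(),
    layers.Dense(120, activation="tanh"),
    layers.Dense(84, activation="tanh"),
    layers.Dense(10, activation="softmax"),  # one probability per class
])

# Dummy forward pass: one fake image in, ten class probabilities out.
probs = lenet(np.zeros((1, 28, 28, 1), dtype="float32"))
print(probs.shape)  # (1, 10)
```

The tanh activations and average pooling are period-accurate; a modern variant would typically swap in ReLU and max pooling.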

The AlexNet and ResNet models were developed for the ImageNet dataset, with images of 227 x 227 resolution. For my implementation, I therefore added an input layer which increases the resolution of the Fashion-MNIST images from 28 x 28 to 227 x 227. It should also be possible to apply these models to low-resolution images by adjusting the size of the filters in the convolutional layers instead of increasing the image resolution; that approach would be more computationally efficient, as it reduces the memory requirements.
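
The upscaling can live inside the model itself. A sketch of such an input stage using Keras’ Resizing layer (bilinear interpolation is the layer’s default; whether my repo resizes this way or in a preprocessing step is not shown here):

```python
import numpy as np
import tensorflow as tf

# Input stage that upscales 28 x 28 Fashion-MNIST images to the 227 x 227
# resolution that AlexNet/ResNet expect, before any convolutional layers.
upscale = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Resizing(227, 227),  # bilinear interpolation by default
])

big = upscale(np.zeros((1, 28, 28, 1), dtype="float32"))
print(big.shape)  # (1, 227, 227, 1)
```

The memory cost is obvious here: each image grows from 784 pixels to over 51,000, which is why adapting the filter sizes instead would be cheaper.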

Recurrent Neural Networks

RNNs are deep learning models specialised for sequential data. They are often used for solving Natural Language Processing (NLP) problems. Whilst they are less commonly used for computer vision tasks than CNNs, it is possible to use an RNN to classify images. Indeed, they excel with time-series data and can therefore be used for learning tasks involving video. RNNs consist of memory cells, which contain one or more neurons whose output is fed back in as input at the next time step (hence the name “recurrent”). The same weights are used at each time step, so the models end up being rather compact. For image data, one dimension is treated as the time dimension: for example, if height is the time dimension, each row represents the data for one time step. Just like the CNNs, a terminal fully-connected layer with a softmax activation function is required for classification.

Classification is performed using RNNs with architecture consisting of a memory cell followed by a fully connected layer with softmax activation function. The memory cell is shown "unrolled", i.e. showing the individual iterations through the RNN graph. Géron (2017)

I implemented a really simple one-layer Long Short-Term Memory (LSTM) RNN using TensorFlow. LSTM networks outperform standard RNNs on long sequences and differ in their memory cell structure (which I won’t get into here). See this excellent blog post by Christopher Olah about LSTM networks.
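
A sketch of such a one-layer LSTM classifier, treating each image row as one time step (the 128-unit cell size is an illustrative assumption, not necessarily what my repo uses):

```python
import numpy as np
import tensorflow as tf

# Each 28 x 28 image is read as a sequence of 28 time steps, one 28-pixel row
# per step; the LSTM's final hidden state feeds the terminal softmax layer.
rnn = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28)),   # (time steps, features per step)
    tf.keras.layers.LSTM(128),               # single memory-cell layer
    tf.keras.layers.Dense(10, activation="softmax"),
])

probs = rnn(np.zeros((1, 28, 28), dtype="float32"))
print(probs.shape)  # (1, 10)
```

Note how little there is to it compared with the CNNs: the recurrence means the same weights are reused across all 28 rows.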

Classification Evaluation

After training with 10-fold cross-validation on the set of 60,000 training images, I evaluated the models on the test set of 10,000 images using accuracy, precision and recall. Overall, all of the models performed pretty well, with accuracy ranging between 88.8% and 91.1%. As expected, the simple LeNet model performed the worst, though not that much worse than the more complex, modern models! Precision and recall were also great, indicating the models could discern between the classes rather effectively.

Model     Type   Accuracy   Precision   Recall
ResNet    CNN    0.9111     0.9133      0.9096
AlexNet   CNN    0.9060     0.9104      0.9010
LSTM      RNN    0.8959     0.9047      0.8884
LeNet     CNN    0.8882     0.8960      0.8839
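
For reference, here’s how accuracy and macro-averaged precision and recall can be computed from predicted labels. (Macro-averaging across the ten classes is my assumption about how the single per-model figures above were aggregated; the repo may average differently.)

```python
import numpy as np

def classification_metrics(y_true, y_pred, n_classes):
    """Accuracy, plus precision and recall macro-averaged over classes."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    accuracy = np.mean(y_true == y_pred)
    precisions, recalls = [], []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        predicted = np.sum(y_pred == c)   # TP + FP for class c
        actual = np.sum(y_true == c)      # TP + FN for class c
        precisions.append(tp / predicted if predicted else 0.0)
        recalls.append(tp / actual if actual else 0.0)
    return accuracy, np.mean(precisions), np.mean(recalls)

# Toy example with three classes:
acc, prec, rec = classification_metrics([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], 3)
print(round(acc, 3), round(prec, 3), round(rec, 3))  # 0.667 0.722 0.667
```

In practice sklearn.metrics.precision_score and recall_score (with average="macro") do the same job.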

A great way to visualise the performance of a classification is with a confusion matrix. At a glance, it’s possible to see how many examples were misclassified, and what they were misclassified as. Here’s the confusion matrix for the AlexNet test set:

Confusion matrix for classification using AlexNet

As you can see, most classes were well classified, with at least 80% of the instances of each class being correctly classified. The greatest source of misclassification was between shirts and T-shirts… perhaps this is not surprising, as these items of clothing look very similar (especially at such low resolutions). The confusion matrices for the other models are almost identical, with LeNet experiencing even more misclassification of shirts as T-shirts, pullovers and coats!
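
A confusion matrix is also straightforward to compute by hand, which makes it a nice sanity check. A minimal sketch (TensorFlow’s tf.math.confusion_matrix does the same thing):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """cm[i, j] counts instances of true class i predicted as class j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

# Toy example with three classes: the diagonal holds correct classifications,
# everything off-diagonal is a confusion between two classes.
cm = confusion_matrix([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], 3)
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
```

Row sums give the true class counts, so dividing each row by its sum turns this into the per-class recall view shown in the figure above.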

The code to reproduce the image classification and visualisation is available in the repo and takes anywhere from a couple of minutes to an hour to run, depending on the model chosen and the available hardware (running on a GPU is highly recommended).

To be continued…

I had fun setting up this benchmarking repo over Christmas. I also read a few chapters of Hands-on machine learning with Scikit-Learn and TensorFlow: concepts, tools, and techniques to build intelligent systems by Aurélien Géron and can heartily recommend it. I’m not sure how much time I’ll have to devote to this little benchmarking project as semester two of my MSc starts tomorrow! I hope I’ll be able to explore further models and technologies in the future though. I’m also intrigued by generative models and hope to explore this aspect further (e.g. RNNs generating poems, Wikipedia articles and images of house numbers).

© 2021, Katie Baker