Learning Deep Learning — MNIST with FastAI (Part 2)

David Clark
Jun 6, 2021

In this series of posts my goal is to document and illustrate my journey as I learn the art and science of “deep learning”. I know these posts will be useful to myself as I look back and reflect on how far I’ve come, and I hope they can be great starting points for others as well.

In Part 1 I set the stage, and we solved a subset of the full MNIST problem (just two digits, 3s and 7s). Here in Part 2 we’ll take what we learned and apply it to build a full 10-digit MNIST classifier.

Let’s get to it!

Setup and Data

If you’re here without reading Part 1, that’s OK, but if you have questions about setup it might be good to peek back there. I’m going to jump right in this time around.

# Notebook setup
!pip install -q fastbook
import fastbook
fastbook.setup_book()
from fastai.vision.all import *
from fastbook import *
# Set the default image color map to grayscale
matplotlib.rc('image', cmap='Greys')
# Download the dataset
path = untar_data(URLs.MNIST)
Path.BASE_PATH = path
path.ls()

OUTPUT:
[Path('testing'), Path('training')]

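Here we see a slightly different folder structure from the reduced URLs.MNIST_SAMPLE dataset we used in Part 1. That one used folders named train and valid, which are the names from_folder() looks for by default, while the full dataset uses training and testing. We can still create our DataLoaders the same way, but we need to pass the appropriate subfolder names.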

# Create the DataLoaders, which will feed batches of training and validation images
dls = ImageDataLoaders.from_folder(
    path,
    train='training',
    valid='testing',
)
# LOOK AT YOUR DATA
dls.train.show_batch()
Images with their labels from a training batch

Here we peek at some example data from a training batch. This time we see all manner of digits, not just 3s and 7s. The labels above each image look correct, too.

At this point it would be wise to investigate the data you’re dealing with a LOT more. Use dls.valid.show_batch() to look at some images from a validation batch. Count the total number of images in train vs. validation. Look at the distribution of each digit within train and validation. These are fundamental things that, if wrong, will screw up the whole training and model-building process. (Imagine if the entire dataset were missing 9s.)
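If you want to do that counting yourself, here is a minimal sketch using only the standard library. The helpers count_labels() and label_fractions() are hypothetical; they are not part of FastAI. Like FastAI’s parent_label(), they assume each image’s label is the name of its parent folder.

```python
from collections import Counter
from pathlib import Path

# Hypothetical helpers (not part of FastAI). Like parent_label(), they assume
# the label is the name of the image's parent folder, e.g. training/3/xyz.png -> "3".
def count_labels(paths):
    """Map each label to the number of images that carry it."""
    return Counter(Path(p).parent.name for p in paths)

def label_fractions(paths):
    """Map each label to its share of all images, sorted by label."""
    counts = count_labels(paths)
    total = sum(counts.values())
    return {label: n / total for label, n in sorted(counts.items())}

# Made-up paths so the example runs without downloading anything
example = [Path("training/3/a.png"), Path("training/3/b.png"), Path("training/7/c.png")]
print(count_labels(example))  # Counter({'3': 2, '7': 1})
print(label_fractions(example))
```

On the real data you could pass get_image_files(path/'training') and get_image_files(path/'testing') to each helper. Then check that all 10 digits show up in both splits in similar proportions.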

Building a Baseline: Average Image Creation

Just like in Part 1, we need to establish a baseline to gauge how effective our fancy model is compared to some simpler, cheaper method. Let’s use that “average image” comparison technique again. Last time it gave us a baseline of about 94% accuracy across both digits. This time we have 10 digits instead of 2, so we should expect it to be lower.

This time with 10 classification categories the process will be a bit more involved. Let’s see what we get.

# Gather the validation subdirectories representing each digit
valid_paths = (path / 'testing').ls()
valid_paths.sort()
valid_paths

OUTPUT:
[
Path('testing/0'),
Path('testing/1'),
Path('testing/2'),
Path('testing/3'),
Path('testing/4'),
Path('testing/5'),
Path('testing/6'),
Path('testing/7'),
Path('testing/8'),
Path('testing/9'),
]

In the downloaded MNIST dataset, the “testing” folder has 10 subdirectories, one for each digit. For each of these subfolders we want to create an averaged image.

# Converts image paths to pixel data (scaled to 0-1) and stacks them into a single tensor
def stack(paths):
    return torch.stack([
        tensor(Image.open(p)).float() / 255 for p in paths
    ])
# Helper to get us the "average" image for all images under a given folder
def get_mean_img(folder):
    imgs = stack(folder.ls())
    return imgs.mean(0)
# Form the "average" images for each validation (testing) sub-dir
mean_imgs = torch.stack([get_mean_img(p) for p in valid_paths])
mean_imgs.shape

OUTPUT:
torch.Size([10, 28, 28])

Nice. We should have formed an average image for each digit and stacked them into a single tensor, and the shape is consistent with that: a stack of ten 28x28 things (our image size). Not taking anything for granted, let’s plot these images to make sure they look correct.
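If the .mean(0) step feels like magic, here is a tiny worked example. It is a sketch in NumPy, which is my choice for illustration. The post itself uses PyTorch tensors, whose stack() and mean() behave the same way here. The two 2x2 “images” are made up. Stacking adds a leading image axis, and averaging over that axis leaves one image whose pixels are the per-position averages.

```python
import numpy as np

# Two made-up 2x2 "images", pixel values already scaled to the 0-1 range
img_a = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
img_b = np.array([[0.0, 0.5],
                  [1.0, 1.0]])

stacked = np.stack([img_a, img_b])  # shape (2, 2, 2): (n_images, height, width)
mean_img = stacked.mean(axis=0)     # average over the image axis -> shape (2, 2)

print(stacked.shape)   # (2, 2, 2)
print(mean_img.shape)  # (2, 2)
# Each pixel is the average of the two inputs, e.g. top-right: (1.0 + 0.5) / 2 = 0.75
print(mean_img[0, 1])  # 0.75
```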

_, axs = plt.subplots(ncols=mean_imgs.shape[0], figsize=(12, 12))
for img, ax in zip(mean_imgs.tolist(), axs):
    show_image(img, ax)
Beautiful! Our 10 average, blurry images.

If the image-plotting code above looks unfamiliar, I totally get how you feel. I’m pretty new to matplotlib myself and barely know how it works. It’s been around for a while and is widely used, though, which means there are LOTS of helpful examples online.

In this instance we used plt.subplots(), which essentially creates “slots” so we can show several images at once. Each slot is called an “axis” (matplotlib’s Axes object). Thankfully, FastAI’s show_image() method takes a plot axis as its second argument, which makes things easy.

Baseline Continued: Average Image Comparison

With our average images created for each digit and looking believable, now we need a way to compare any single image to this set of platonic ideals. Whichever “average” the image we’re testing is closest to will be the classification we give it.

This comparison is quite a bit trickier than simply asking “Are you closer to 3 or 7?” as we did in Part 1. It would be great if we could reduce each training image to an array of 10 scores, one per digit, each representing how well the image matches that digit’s average. We’ll make a lower score better, so it acts as a measure of how “off” we are from each average image.

# Stack up the training images
train_img_paths = get_image_files(path / 'training')
train_x = stack(train_img_paths)
train_x.shape

OUTPUT:
torch.Size([60000, 28, 28])

We stack up all of our training images with some help from FastAI’s get_image_files() which recursively finds images under the given path. We can see by the shape of the resulting tensor that we have 60,000 training images.

# Method to compare image pixel similarity
def abs_dist(a, b): return (a-b).abs().mean((-1, -2))
scores = torch.stack([abs_dist(mean_imgs, x) for x in train_x])
scores.shape

OUTPUT:
torch.Size([60000, 10])

Now, if we did this correctly, scores should be exactly what we were trying to build: a tensor where, instead of pixel values, each training image has 10 scores, one for each digit’s “average” image. A shape of 60,000x10 looks promising! (If you’ve given the code above a good look and still don’t understand how it works, it might help to look up NumPy broadcasting; PyTorch follows the same broadcasting rules.)
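Here is a small sketch of the broadcasting at work. I use NumPy arrays and random stand-in data instead of the real PyTorch tensors; the shapes and the broadcasting rules are the same. The first call compares one image against all 10 means at once. The loop rebuilds the post’s scores tensor. The last line is an alternative, fully vectorized version that is my own variation, not the post’s code.

```python
import numpy as np

def abs_dist(a, b):
    # Mean absolute pixel difference over the last two axes (height, width)
    return np.abs(a - b).mean(axis=(-1, -2))

rng = np.random.default_rng(0)
mean_imgs = rng.random((10, 28, 28))  # stand-in for the 10 "average" digit images
train_x = rng.random((5, 28, 28))     # stand-in for 5 training images

# One image vs. all 10 means: (10, 28, 28) - (28, 28) broadcasts to (10, 28, 28)
print(abs_dist(mean_imgs, train_x[0]).shape)  # (10,)

# Loop version, as in the post: one row of 10 scores per image
scores_loop = np.stack([abs_dist(mean_imgs, x) for x in train_x])

# Fully vectorized: (5, 1, 28, 28) - (1, 10, 28, 28) -> (5, 10, 28, 28) -> (5, 10)
scores_vec = abs_dist(train_x[:, None], mean_imgs[None])

print(scores_loop.shape, scores_vec.shape)  # (5, 10) (5, 10)
print(np.allclose(scores_loop, scores_vec))  # True
```

The vectorized version avoids the Python loop. On the full training set, though, the intermediate (60,000, 10, 28, 28) array of float32 values is about 60,000 × 10 × 784 × 4 bytes ≈ 1.9 GB. In practice you may want to process train_x in chunks.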

Let’s try to verify that we got what we were looking for.

# Grab an example training image
im = train_x[0]
print(scores[0])
# Plot the training image along with how well it matches each "average" digit image
_, axs = plt.subplots(ncols=10+1, figsize=(12, 12))
show_image(1-im, axs[0])
for i in range(10):
    show_image((im - mean_imgs[i]).abs(), axs[i+1])
Visualizing how well a sample training image matches each of the “average” digit images, with scores.

At the top of the image above we see the score values per average digit image (remember that a lower score is a better match). The 4th entry (index 3), 0.1327, is the lowest in the list, which means our example image matches “3” best. That’s good, because the inverted image on the far left is our example image, and it sure looks like a “3” to me!

Roughly below each score in the image above we can see a visualization of how well our example image matches the average for that digit. (The more black in the image, the further away we were from a perfect match.) This visual sanity-check seems to line up with our scores, so I’m fairly confident we got the results we were looking for.

Finally, to complete our baseline, we need to convert these scores into classifications and compare to the actual labels to measure how accurate we were.

# Get predicted classifications from scores
preds = scores.argmin(-1)
# Get actual classifications (labels).
# Make sure these are in the same order as train_x!
train_y = tensor([
    int(parent_label(img_path))
    for img_path in train_img_paths
])
# Measure prediction accuracy
(preds == train_y).float().mean().item()

OUTPUT:
0.6472166776657104

And there’s our baseline: 64.7%!

Conveniently, the argmin() function (which returns the index of the lowest item) applied for each individual set of scores gives us exactly our predicted label (since in this case the index is the label).

Then, to load the actual labels, we use FastAI’s helper method parent_label(). Recall that each image’s label is the name of its parent directory.

Finally, we do the equality comparison (which gives us Booleans) and convert to floats so we have 0s and 1s. Then we take the mean to get the average accuracy across our predictions for every training image. A baseline of 64.7% seems very reasonable. Random guessing would score about 10%, and we saw 94% using this same method with just two classification buckets.
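To make the argmin-then-compare steps concrete, here is a toy version with four made-up images and only three digits. It is a NumPy sketch rather than the post’s PyTorch code. The index of the lowest score equals the digit only because the average images are stacked in sorted digit order. The labels come from each path’s parent folder, which is the same idea as parent_label().

```python
import numpy as np
from pathlib import Path

# Toy scores for 4 images against 3 "average" digits (lower = closer match)
scores = np.array([
    [0.10, 0.30, 0.25],  # closest to digit 0
    [0.40, 0.12, 0.20],  # closest to digit 1
    [0.33, 0.31, 0.09],  # closest to digit 2
    [0.15, 0.14, 0.50],  # closest to digit 1
])
preds = scores.argmin(axis=-1)  # index of the lowest score = predicted digit

# Labels come from each image's parent folder (made-up paths)
paths = [Path("training/0/a.png"), Path("training/1/b.png"),
         Path("training/2/c.png"), Path("training/0/d.png")]
labels = np.array([int(p.parent.name) for p in paths])

# Booleans -> 0.0/1.0 -> mean = fraction correct
accuracy = (preds == labels).astype(float).mean()
print(preds)     # [0 1 2 1]
print(accuracy)  # 0.75
```

The last image is a “0” that scored marginally closer to the “1” average, so 3 of 4 predictions are correct.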

Training a Neural Net

With the hard work of forming a baseline done, it’s time to blow it out of the water! Hopefully. For starters, let’s use the same resnet18 architecture and the same quick fit_one_cycle(1) training (a single epoch) we did in Part 1, so we can compare results.

# Just a reminder of how we loaded our data before
path = untar_data(URLs.MNIST)
Path.BASE_PATH = path
dls = ImageDataLoaders.from_folder(
    path,
    train='training',
    valid='testing',
)
# The Learner handles the training loop for us
# (newer fastai releases rename cnn_learner to vision_learner)
learner = cnn_learner(
    dls,
    resnet18,
    pretrained=False,
    metrics=accuracy,
)
# Train for one epoch using the one-cycle learning-rate schedule
learner.fit_one_cycle(1)
With a single epoch of training we’re already at 99.0% in about 1 minute!

99.0% in less than a minute of training — that’s better than I expected! In Part 1 with just the two classification categories we hit 99.6% with this method, so I’m glad we got lower here. Something would be off if we performed better on this harder problem.

While looking at accuracy numbers is helpful, a similar principle applies to your model as to your data: LOOK AT YOUR MODEL.

interp = ClassificationInterpretation.from_learner(learner)
interp.plot_confusion_matrix()
interp.plot_top_losses(6)
A classification “confusion matrix” helps you see where your model gets things confused.
The images our model finds the most confusing.

Dang, FastAI makes this easy. Three lines of code and we get beautiful images showing wonderful detail about where our model is struggling. The confusion matrix isn’t too useful in this instance since we’re already at 99% accuracy. Still, we can see the highest off-diagonal count is 10, for confusion between “4” and “9”. That makes sense to me, since those digits look pretty similar.
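To see what the confusion matrix is counting, here is a small pure-Python sketch. The helper names and the handful of predictions are made up; they are not FastAI’s implementation. Rows are the actual digit and columns the predicted one, so every off-diagonal cell is one kind of mistake. FastAI’s ClassificationInterpretation also has a most_confused() method that lists these pairs for you.

```python
def confusion_matrix(actual, predicted, n_classes):
    """Rows = actual class, columns = predicted class."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for a, p in zip(actual, predicted):
        cm[a][p] += 1
    return cm

def most_confused_pairs(cm):
    """Off-diagonal (actual, predicted, count) entries, largest count first."""
    pairs = [(a, p, cm[a][p])
             for a in range(len(cm)) for p in range(len(cm))
             if a != p and cm[a][p] > 0]
    return sorted(pairs, key=lambda t: t[2], reverse=True)

# Made-up labels and predictions for seven images
actual = [4, 4, 4, 9, 9, 1, 7]
predicted = [4, 9, 9, 9, 4, 1, 1]
cm = confusion_matrix(actual, predicted, n_classes=10)
print(most_confused_pairs(cm))  # [(4, 9, 2), (7, 1, 1), (9, 4, 1)]
```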

Now I absolutely love the plot_top_losses() function that FastAI gives us. The “loss” is what we’re optimizing (minimizing) during training. It measures how far the model’s predictions are from the correct labels, and a confidently wrong prediction gets a much larger loss than an unsure one. plot_top_losses() shows us the images responsible for the largest losses, the ones that “confuse” our model the most.
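For single-label classification like this, FastAI picks a cross-entropy loss by default. Here is a small pure-Python sketch of what that loss measures and how ranking by it surfaces the most “confusing” images. The probabilities are made up, and make_probs() is a hypothetical helper. In the real notebook, plot_top_losses() does this ranking over the validation set for you.

```python
import math

def make_probs(top_class, p_top, n_classes=10):
    """Hypothetical softmax output: p_top on one class, the rest spread evenly."""
    rest = (1 - p_top) / (n_classes - 1)
    return [p_top if i == top_class else rest for i in range(n_classes)]

def cross_entropy(probs, target):
    """Cross-entropy loss for one image: -log(probability given to the true class)."""
    return -math.log(probs[target])

# Three made-up predictions for images whose true label is 6
examples = {
    "A": make_probs(6, 0.91),  # confident and correct
    "B": make_probs(1, 0.55),  # wrong, but not very sure
    "C": make_probs(1, 0.99),  # wrong and very sure
}
losses = {name: cross_entropy(probs, target=6) for name, probs in examples.items()}

# Rank from largest loss to smallest, like a "top losses" view
top = sorted(losses, key=losses.get, reverse=True)
print(top)  # ['C', 'B', 'A']
```

The confidently wrong prediction (C) tops the list. That is why top-loss views tend to surface either genuinely ambiguous digits or images that may be mislabeled.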

Looking at the “top loss” examples, for that first image our model predicted “1”, but the actual label was “6”. I can totally see it both ways, so I’m not surprised our model was confused too. Some of the other examples seem more like errors, yet some are just as confusing.

Looking at your model’s top losses along with the confusion matrix is critical for debugging both your model and your data.

Can We Do Even Better?

If we got 99.0% with a single epoch, can we do better with more training? What IS state-of-the-art, anyway? PapersWithCode reports state-of-the-art results of about 99.8%. I don’t expect to reach that without significant fine-tuning and special effort, but let’s see how far we can get.

# 15 epochs this time (still a single one-cycle schedule): 15x as much training
learner = cnn_learner(
    dls,
    resnet18,
    pretrained=False,
    metrics=accuracy,
)
learner.fit_one_cycle(15)
With 15 epochs of training instead of just 1, we improve from 99.0% to 99.6% accuracy.

This time we trained for 15 epochs, and accuracy seemed to level out around 99.6% by the 12th epoch. We could train for longer, but I’m very satisfied with this result. It is just 0.2 percentage points off state-of-the-art; on the 10,000-image test set, that’s roughly 40 misclassified images versus 20. I also doubt that further training would improve things without changes to the model or the data.
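Worth knowing: despite the name, fit_one_cycle(15) runs a single one-cycle schedule stretched across all 15 epochs, not 15 separate cycles. Below is a rough pure-Python sketch of the learning-rate side of that schedule. The defaults reflect my understanding of FastAI’s fit_one_cycle() and may differ between versions: lr_max=1e-3, div=25, div_final=1e5, pct_start=0.25, a cosine warm-up, then cosine annealing. FastAI also cycles momentum in the opposite direction, which this sketch leaves out.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max=1e-3, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1] (sketch of a 1cycle schedule)."""
    if pos < pct_start:
        # Warm-up: lr_max / div -> lr_max over the first pct_start of training
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    # Annealing: lr_max -> lr_max / div_final over the remaining steps
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

# pos is the fraction of ALL training steps completed, across every epoch
for pos in [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]:
    print(f"{pos:.2f}: {one_cycle_lr(pos):.2e}")
```

The very small learning rate near the end of the schedule is one reason accuracy tends to settle over the last few epochs. As I understand it, FastAI’s learner.recorder.plot_sched() plots the real schedule after training.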

Let’s take a look at those top losses again to see what changed!

“Top loss” images after 15 epochs of training (99.6% accuracy).

Yep, I can officially say these images are confusing. Good job model, I don’t blame you.

Wrap Up

If you made it this far, thank you, and I hope you enjoyed what you found here. If you find deep learning really interesting, go try it. FastAI is a great place to start, and I can’t wait to dig into it more myself.

If you want more articles like this, comment to let me know, and follow me for new articles I hope to put out in the future.
