What does a "difficult" MNIST digit look like?

A question I've asked myself repeatedly. It's always interesting when a new deep learning architecture manages to beat the state of the art. The MNIST test set contains 10,000 images, and at the time of writing Hinton's capsule networks hold the state of the art at 0.25% test error. That translates to just 25 misclassified digits. Not bad at all. But what do those digits look like? And how does this compare to human performance?

In this blog post I'm going to try to get some intuition for how good the state of the art is compared to human performance, by looking at the MNIST digits misclassified by a simple convnet written in Keras.

Model

For the model I'm going to build a simple "off the shelf" convnet (with batch norm and dropout).
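(The data-loading and preprocessing cell isn't reproduced here; the sketch below is the standard Keras MNIST preprocessing I assume was used, with variable names matching the cells that follow.)

In [ ]:
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical

img_rows, img_cols, num_classes = 28, 28, 10

# Load MNIST, add a channel dimension, scale pixels to [0, 1], one-hot encode labels
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255
x_test = x_test.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)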

In [4]:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dropout, Dense

input_shape = (img_rows, img_cols, 1)
model = Sequential()
# Two 5x5 conv blocks, then a 3x3 conv block, each followed by batch norm
model.add(Conv2D(32, kernel_size=(5, 5),
                 activation='relu',
                 input_shape=input_shape, padding='valid'))
model.add(BatchNormalization())
model.add(Conv2D(64, (5, 5), activation='relu', padding='valid'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu', padding='valid'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
# Small dense head with dropout for regularization
model.add(Dropout(0.1))
model.add(Dense(128, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
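
The training cell isn't shown in this section. Below is a minimal sketch of how the model could be compiled and fit for the 20 epochs mentioned later; the optimizer and batch size are my assumptions, not taken from the original run.

In [ ]:
# Assumed training setup: Adam and batch size 128 are guesses; epochs=20 per the text below
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=128, epochs=20)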

Results

In [6]:
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Test loss: 0.024712890361214206
Test accuracy: 0.9938

A test error rate of 0.62%. Not bad for something this simple. But the question I'm dying to ask: what do these misclassified digits look like?
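
The plotting helper plot_img_grid used below is defined earlier in the notebook and isn't shown here. Roughly, it's a small matplotlib routine along these lines (the implementation is my guess based on how it's called):

In [ ]:
import matplotlib.pyplot as plt

def plot_img_grid(imgs, rows, cols):
    """Plot up to rows*cols grayscale images in a grid."""
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for ax in axes.flat:
        ax.axis('off')
    for ax, img in zip(axes.flat, imgs):
        ax.imshow(img.squeeze(), cmap='gray')
    plt.show()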

In [13]:
def show_misclassified_digits(X, Y):
    Yp = model.predict(X)
    # Bucket the misclassified images by their true label
    bad_imgs = [[] for i in range(10)]
    for X0, Y0, Yp0 in zip(X, Y, Yp):
        Y0max = np.argmax(Y0)    # true label
        Yp0max = np.argmax(Yp0)  # predicted label
        if Y0max == Yp0max: continue
        bad_imgs[Y0max].append(X0)
    for i in range(10):
        print("Misclassified images that were actually %d" % (i))
        plot_img_grid(bad_imgs[i], 3, 4)

show_misclassified_digits(x_test, y_test)
Misclassified images that were actually 0
Misclassified images that were actually 1
Misclassified images that were actually 2
Misclassified images that were actually 3
Misclassified images that were actually 4
Misclassified images that were actually 5
Misclassified images that were actually 6
Misclassified images that were actually 7
Misclassified images that were actually 8
Misclassified images that were actually 9

Looking at the digits that were misclassified, I estimate I could probably classify about 2/3 of them correctly, but who knows how many of the digits the model got right that I would fail on.

What about training digits that were misclassified?

Our model didn't reach 100% training accuracy within 20 epochs, presumably due to the regularizing effects of dropout and batch norm. What do the training digits that weren't correctly classified look like?
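
A quick way to check the training accuracy itself (this cell isn't in the original notebook, so I'm not quoting numbers):

In [ ]:
# Evaluate on the training set to see how far from 100% we are
train_score = model.evaluate(x_train, y_train, verbose=0)
print('Train loss:', train_score[0])
print('Train accuracy:', train_score[1])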

In [14]:
show_misclassified_digits(x_train, y_train)
Misclassified images that were actually 0
Misclassified images that were actually 1
Misclassified images that were actually 2
Misclassified images that were actually 3
Misclassified images that were actually 4
Misclassified images that were actually 5
Misclassified images that were actually 6
Misclassified images that were actually 7
Misclassified images that were actually 8
Misclassified images that were actually 9

Interesting! I suspect that some of these training digits were mislabeled. I haven't researched the process used to label the dataset, but I have a hard time believing that a human intended to write "4" and it came out looking like that "7".

Conclusion

Looking at the test digits that were misclassified, I estimate I could probably classify about 2/3 of them correctly. That would imply a best-case human error rate of about 0.23%. But in reality it seems highly unlikely that I'd correctly classify 100% of the test digits the convnet got right, so my real-world accuracy is probably a fair bit lower. After this experiment I'm very impressed with the state of the art (0.25%). It seems likely that ML is achieving superhuman performance on MNIST.

References

  1. http://yann.lecun.com/exdb/mnist/
  2. Hinton, Geoffrey E., Sara Sabour, and Nicholas Frosst. "Matrix capsules with EM routing." (2018).
  3. https://github.com/fchollet/keras
