Given a set of images with a relatively high degree of separation between classes, it is perfectly feasible to train a CNN to classify those images on a typical laptop or PC. A great example is the famous MNIST dataset, which contains 60,000 training images of scanned, handwritten digits, each measuring 28 x 28 pixels, plus 10,000 test images. Here are the first 50 scans in the training set:
The following statements load the dataset and use it to train a CNN:
```python
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten

# Load the 60,000 training images and 10,000 test images
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Add a channel dimension and scale pixel values to the range 0 to 1
x_train = train_images.reshape(60000, 28, 28, 1) / 255
x_test = test_images.reshape(10000, 28, 28, 1) / 255

# One-hot encode the labels for categorical_crossentropy
y_train = to_categorical(train_labels)
y_test = to_categorical(test_labels)

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

hist = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, batch_size=50)
```
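If you want to know how big this network is before training it, the shape and parameter arithmetic can be worked out by hand. The sketch below assumes Keras's defaults for these layers (stride 1 and "valid" padding for Conv2D, and 2 x 2 pooling that drops an odd leftover row and column); the helper names are hypothetical. Calling model.summary() on the model above should report the same total.

```python
def conv_params(k, c_in, c_out):
    # k*k*c_in weights per filter, plus one bias per filter
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    # one weight per input/output pair, plus one bias per output
    return n_in * n_out + n_out

h = 28                         # input images are 28 x 28 x 1
p1 = conv_params(3, 1, 32)     # Conv2D(32, (3, 3)): 28 -> 26
h = (h - 3 + 1) // 2           # then 2 x 2 pooling: 26 -> 13
p2 = conv_params(3, 32, 64)    # Conv2D(64, (3, 3)): 13 -> 11
h = (h - 3 + 1) // 2           # then 2 x 2 pooling floors 11 -> 5
flat = h * h * 64              # values produced by Flatten
p3 = dense_params(flat, 128)   # Dense(128)
p4 = dense_params(128, 10)     # Dense(10)
total = p1 + p2 + p3 + p4
print(h, flat, total)          # 5 1600 225034
```

Most of the parameters (204,928 of them) live in the first Dense layer, which is typical of small CNNs that flatten feature maps straight into a fully connected layer.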
Once trained, this simple CNN can achieve 99% accuracy classifying handwritten digits:
One reason it can attain such accuracy is the number of training samples: roughly 6,000 per class. (As a test, I trained the network with just 100 samples of each class and got 92% accuracy.) Another factor is that a 2 looks very different from an 8. In other words, different digits aren’t terribly difficult to distinguish from one another.
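The post doesn't say how the 100-per-class subset was drawn. One plausible way, sketched here as an assumption rather than the author's method, is to pick random indices for each label without replacement and train on only those rows. `sample_per_class` is a hypothetical helper, and the demo uses synthetic labels so it runs without downloading MNIST.

```python
import numpy as np

def sample_per_class(labels, n_per_class, seed=0):
    """Return shuffled indices selecting n_per_class examples of each label.

    Hypothetical helper for illustration; labels is a 1-D array of class IDs.
    """
    rng = np.random.default_rng(seed)
    picks = []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)    # positions of this class
        picks.append(rng.choice(cls_idx, size=n_per_class, replace=False))
    return rng.permutation(np.concatenate(picks))  # mix the classes together

# Demo on synthetic labels: 10 classes x 50 examples each
labels = np.repeat(np.arange(10), 50)
idx = sample_per_class(labels, 5)
print(len(idx), np.bincount(labels[idx]))  # 50 [5 5 5 5 5 5 5 5 5 5]
```

With the MNIST arrays from the earlier listing, you would call `idx = sample_per_class(train_labels, 100)` and pass `x_train[idx]` and `y_train[idx]` to `model.fit`. The fixed seed makes the experiment repeatable.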
That’s not the case when the problem is more perceptual – for example, when you train a CNN to determine whether a photo contains an Arctic fox or a polar bear as described in my previous post. More training samples are crucial here because you need lots of wildlife pictures with different poses taken from different angles. Even if you can come up with all the photos you need, you’ll probably need a GPU to train the network in a reasonable amount of time. And even then, the accuracy might not be what you want it to be.
The good news is that you might not have to train a CNN from scratch. Microsoft, Google, and other tech companies use a subset of the ImageNet dataset containing almost 1.3 million images to train state-of-the-art CNNs to recognize 1,000 classes of objects, including Arctic foxes and polar bears. Then they publish the trained networks for others to use. Called pretrained CNNs, they are more sophisticated than anything you’re likely to train yourself. And if that’s not awesome enough, Keras reduces the process of loading a pretrained CNN to one line of code.
Keras provides classes that wrap more than two dozen popular pretrained CNNs. The full list is shown below and is documented at https://keras.io/api/applications/. Most of these CNNs are documented in scholarly papers such as Deep Residual Learning for Image Recognition and EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. Some have won prestigious competitions such as the ImageNet Large Scale Visual Recognition Challenge and the COCO Detection Challenge. Among the most notable are the ResNet family of networks from Microsoft and the Inception networks from Google. Also noteworthy is MobileNet, which gives up a little accuracy in exchange for a small memory footprint, making it ideal for mobile devices.
The following statement instantiates Keras’s MobileNetV2 class and initializes it with the weights, biases, and kernel values arrived at when the network was trained on the ImageNet dataset:
```python
from tensorflow.keras.applications import MobileNetV2

model = MobileNetV2(weights='imagenet')
```
The weights='imagenet' parameter tells Keras which set of weights to load to recreate the network in its trained state. You can also pass the path to a file containing custom weights, but 'imagenet' is the only set of predefined weights currently supported.
Before an image is submitted to a pretrained CNN for classification, it must be resized to the dimensions the CNN expects – typically 224 x 224 – and preprocessed. Different CNNs expect images to be preprocessed in different ways, so Keras provides a preprocess_input function for each pretrained CNN. It also includes utility functions for loading and resizing images. The following statements load an image from the file system and preprocess it for input to the MobileNetV2 network:
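```python
import numpy as np
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing import image

x = image.load_img('arctic_fox.jpg', target_size=(224, 224))  # load and resize
x = image.img_to_array(x)       # image -> float array of shape (224, 224, 3)
x = np.expand_dims(x, axis=0)   # add a batch dimension: (1, 224, 224, 3)
x = preprocess_input(x)         # scale pixels the way MobileNetV2 expects
```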
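In most cases, preprocess_input does all the work that’s needed, which typically means scaling or normalizing the pixel values and, for some networks, converting RGB images to BGR format. The catch is that you must use the version of preprocess_input that matches the network. ResNet50V2, for example, expects the one in the resnet_v2 module, which scales pixel values to the range -1 to 1. The one in the resnet50 module is meant for the original ResNet50; it converts images to BGR and subtracts the ImageNet channel means instead, so it isn’t a substitute:
```python
import numpy as np
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.preprocessing import image

x = image.load_img('arctic_fox.jpg', target_size=(224, 224))
x = image.img_to_array(x)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)   # scales pixels to [-1, 1]; no extra division by 255 needed
```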
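To see what these preprocessing functions typically do to the pixels, here is a rough NumPy re-creation of two styles Keras uses. It is a sketch for illustration, not Keras's code, and the function names are hypothetical: "tf" mode (used, according to the Keras documentation, by MobileNetV2 and ResNet50V2) maps 0 to 255 onto -1 to 1, while "caffe" mode (used by the original ResNet50) flips RGB to BGR and subtracts per-channel ImageNet means.

```python
import numpy as np

# Per-channel ImageNet means, in BGR order, used by "caffe"-style preprocessing
CAFFE_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def preprocess_tf_mode(batch):
    """Scale RGB pixels from [0, 255] to [-1, 1]."""
    return batch.astype(np.float32) / 127.5 - 1.0

def preprocess_caffe_mode(batch):
    """Flip RGB to BGR, then subtract the ImageNet channel means (no scaling)."""
    bgr = batch.astype(np.float32)[..., ::-1]   # reverse the channel axis
    return bgr - CAFFE_MEAN_BGR

# A fake all-white 224 x 224 RGB image with a batch dimension added
img = np.full((224, 224, 3), 255, dtype=np.uint8)
batch = np.expand_dims(img, axis=0)   # shape (1, 224, 224, 3)

print(preprocess_tf_mode(batch)[0, 0, 0])   # [1. 1. 1.]
print(preprocess_caffe_mode(batch)[0, 0, 0])
```

Both functions return an array with the same (1, 224, 224, 3) shape; only the value range and channel order change. That is why feeding a network the wrong style doesn't raise an error but can quietly degrade its predictions.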
Once an image is preprocessed, making a prediction is as simple as calling the network’s predict method:
```python
y = model.predict(x)   # one row per image, one probability per ImageNet class
```
To help you interpret the output, each network’s module in Keras also provides a decode_predictions function. Here’s what it returned when the photo below was submitted to ResNet50V2:
ResNet50V2 is 89% sure that the photo contains an Arctic fox – which, it just so happens, it does. MobileNetV2 predicted with 92% certainty that the photo contains an Arctic fox. Both networks were trained on the same dataset, but different pretrained CNNs classify images slightly differently.
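decode_predictions essentially turns the row of probabilities in y into readable labels, best first. The simplified stand-in below shows the idea; it is not Keras's implementation, which returns (WordNet ID, class name, score) tuples looked up in an ImageNet class-index file and defaults to the top 5. The helper name, the four class names, and the scores are made up for illustration.

```python
import numpy as np

def top_k(probs, class_names, k=5):
    """For each row of a (batch, num_classes) array of softmax scores,
    return the k most likely (class_name, score) pairs, best first."""
    results = []
    for row in probs:
        best = np.argsort(row)[::-1][:k]   # class indices, highest score first
        results.append([(class_names[i], float(row[i])) for i in best])
    return results

# Hypothetical 4-class model output for a batch of one image
names = ["arctic_fox", "polar_bear", "armadillo", "sea_lion"]
y = np.array([[0.70, 0.20, 0.04, 0.06]])
print(top_k(y, names, k=2))   # [[('arctic_fox', 0.7), ('polar_bear', 0.2)]]
```

Because the scores in each row sum to 1, a low top score (as with the walrus later in this post) is a hint that the network is unsure.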
Use ResNet50V2 to Classify Images
Let’s use Keras to load a pretrained CNN and classify a pair of images. Begin by downloading the images: arctic_fox_140.jpeg and walrus_143.png. Save them in the directory where your Jupyter notebooks are hosted. Then fire up a notebook and use the following statements to load ResNet50V2:
```python
from tensorflow.keras.applications import ResNet50V2

model = ResNet50V2(weights='imagenet')
model.summary()
```
Next, load the Arctic-fox image and show it in the notebook:
```python
%matplotlib inline
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image

x = image.load_img('arctic_fox_140.jpeg', target_size=(224, 224))
plt.xticks([])   # hide the axis ticks
plt.yticks([])
plt.imshow(x)
```
Now preprocess the image with the preprocess_input function from the resnet_v2 module, which scales pixel values to the range -1 to 1 as ResNet50V2 expects, and pass it to the CNN for classification:
```python
import numpy as np
from tensorflow.keras.applications.resnet_v2 import preprocess_input, decode_predictions

x = image.img_to_array(x)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
y = model.predict(x)
decode_predictions(y)
```
The output should look like this:
ResNet50V2 is virtually certain that the image contains an Arctic fox. But now load the walrus image:
```python
x = image.load_img('walrus_143.png', target_size=(224, 224))
plt.xticks([])
plt.yticks([])
plt.imshow(x)
```
And ask ResNet50V2 to classify it:
```python
x = image.img_to_array(x)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
y = model.predict(x)
decode_predictions(y)
```
Here’s the output:
ResNet50V2 thinks the image is most likely an armadillo, but it’s not even very sure about that. Can you guess why?
ResNet50V2 was trained with almost 1.3 million images, but walrus isn’t one of the 1,000 classes it was trained to recognize, so none of those images was labeled as a walrus. The ImageNet 1000 Class List shows the complete list of classes it was trained on. A pretrained CNN is great when you need it to classify images using the classes it was trained with, but it is powerless to handle domain-specific tasks that it wasn’t trained for.
But all is not lost. In my next post, I’ll introduce a powerful technique called transfer learning that enables pretrained CNNs to be repurposed to solve domain-specific problems. The repurposing can usually be done on an ordinary CPU; no GPU required. Transfer learning sometimes achieves 95% accuracy with just a few hundred training images. Once you learn about it, you’ll have a completely different perspective on the efficacy of CNNs.
Get the Code
You can download a Jupyter notebook containing examples utilizing pretrained CNNs from the deep-learning repo that I maintain on GitHub. Feel free to check out the other notebooks in the repo while you’re at it. Also be sure to check back from time to time because I am constantly uploading new samples and updating existing ones.