In this article, we’ll adapt the base ResNet50 model to a new classification task: COVID-19 versus Normal chest X-rays, using a dataset of 2,484 images – small compared to ImageNet.
In this series of articles, we’ll apply a Deep Learning (DL) network, ResNet50, to diagnose COVID-19 in chest X-ray images. We’ll use Python’s TensorFlow library to train the neural network in a Jupyter Notebook.
The tools and libraries you’ll need for this project are:
- IDE: Jupyter Notebook
- Libraries: TensorFlow (we use its built-in tf.keras API)
We are assuming that you are familiar with deep learning with Python and Jupyter notebooks. If you're new to Python, start with this tutorial. And if you aren't yet familiar with Jupyter, start here.
In the previous article, we loaded the base model and showed its layers (a short recap of that loading code appears after the list below). Now, we’ll make this base model fit a new classification task: COVID-19 and Normal chest X-rays. We’ll use the ResNet50 model with a dataset that contains 2,484 images – a small dataset compared to ImageNet. To adapt the model to this new task, we need to:
- Remove the fully connected layers of the network and add a global average pooling layer to condense all the feature maps
- Replace the fully connected layers of the base model with new layers
- Add a new dense output layer with two nodes that represent the two target classes: COVID-19 and Normal
- Freeze the weights of the pretrained layers in the feature extraction part and randomly initialize those of the new fully connected layers
- Train ResNet50 to update only the weights of the fully connected layers
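For reference, here is a minimal sketch of how the base model from the previous article can be created; the exact arguments (ImageNet weights, 224×224×3 input shape) are assumptions and may differ slightly from the earlier code:

import tensorflow as tf

# Load ResNet50 pretrained on ImageNet, dropping its original fully connected head
base_model = tf.keras.applications.ResNet50(weights='imagenet',
                                            include_top=False,
                                            input_shape=(224, 224, 3))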
Restructuring the Base Model
As mentioned above, the first step toward reshaping our ResNet50 for transfer learning is to remove the fully connected layers and add a global average pooling layer, which condenses all the feature maps coming out of the base model. On top of it, we add new dense layers – three with 1,024 nodes each and one with 512 nodes – followed by an output layer with 2 nodes that represent the two target classes.
# Condense all the feature maps from the base model into one vector per image
x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# New fully connected layers that will learn the new task
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
# Output layer: two nodes for the two classes (COVID-19 and Normal)
preds = tf.keras.layers.Dense(2, activation='softmax')(x)
Now, we can create a model with the new structure: it takes the same input as the base model (the feature extraction part) and produces the new output (preds).
model = tf.keras.models.Model(inputs=base_model.input, outputs=preds)
model.summary()
Figure 5 shows that the reshaped model is similar to the base one. The differences are the added global average pooling layer and the new fully connected dense layers, which change the network output to fit our new target classification task.
Figure 5: A snapshot of the reshaped ResNet50 model
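If the full summary is long to scroll through, a quick optional check (a small sketch, not part of the original code) is to print only the six layers we just appended:

# Print only the six layers added on top of the base model
for layer in model.layers[-6:]:
    print(layer.name, layer.output_shape)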
Freezing Weights
Now it’s time to freeze the weights of all layers before the global average pooling layer.
We’ll use the same code we used in the previous article to enumerate the layers, and then we’ll freeze them by setting their trainable attribute to False.
# List the layer indices so we know where the new layers start
for i, layer in enumerate(model.layers):
    print(i, layer.name)

# Freeze every layer in the pretrained feature extraction part
for layer in model.layers[:175]:
    layer.trainable = False
After freezing the weights of those layers, we’ll make the newly added layers trainable by setting their trainable attribute to True.

# Keep the newly added layers trainable
for layer in model.layers[175:]:
    layer.trainable = True
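As an optional sanity check (a small sketch, not from the original article), you can count the trainable and frozen parameters to confirm that only the new layers will be updated during training:

# Count trainable vs. frozen parameters to verify the freeze worked as intended
trainable_params = sum(tf.keras.backend.count_params(w) for w in model.trainable_weights)
frozen_params = sum(tf.keras.backend.count_params(w) for w in model.non_trainable_weights)
print('Trainable parameters:', trainable_params)
print('Frozen parameters:', frozen_params)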
Next Step
In the next article, we’ll fine-tune our ResNet50 model. Stay tuned!
Dr. Helwan is a machine learning and medical image analysis enthusiast.
His research interests include, but are not limited to, machine and deep learning in medicine, medical computational intelligence, biomedical image processing, and biomedical engineering and systems.