
How to Freeze Model Weights in PyTorch for Transfer Learning: Step-by-Step Tutorial

A step-by-step guide to freezing weights in PyTorch for transfer learning, using a simple example.

Transfer learning is a machine learning technique in which a model pre-trained on one task is adapted to a new but related problem. A key step in transfer learning is freezing layers of the pre-trained model so that only some portions of the network are updated during training. Freezing is crucial when you want to preserve the features the pre-trained model has already learned.

In this tutorial, we will walk through the process of freezing weights in PyTorch for transfer learning, using a simple example.

Prerequisites

If you don’t have the torch and torchvision libraries installed, you can install them from the terminal:

pip install torch torchvision

Import Libraries

Let’s start with the Python code. First, we import the libraries for this tutorial:

import torch
import torch.nn as nn
import torchvision.models as models

Load a Pretrained Model

We’ll use the pre-trained ResNet-18 model for this example:

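# Load ResNet-18 with pre-trained ImageNet weights.
# The weights= API requires torchvision >= 0.13; the older pretrained=True argument is deprecated.
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)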

Freezing Layers

To freeze a layer, we set the requires_grad attribute of its parameters to False. This prevents PyTorch from computing gradients for those parameters during backpropagation, so the optimizer does not update them.

# Freeze all layers
for param in resnet18.parameters():
    param.requires_grad = False

Unfreezing Some Layers

To get the best results, we typically fine-tune some of the later layers in the network. Here we unfreeze the final fully connected layer:

# Unfreeze last layer
for param in resnet18.fc.parameters():
    param.requires_grad = True

Modifying the Network Architecture

We’ll replace the last fully connected layer to adapt the model to a new problem with a different number of output classes (say, 10). A newly created layer has requires_grad=True by default, so it is trainable even though the rest of the network is frozen. Replacing the head also lets us reuse the pre-trained network for tasks other than classification, such as segmentation, where the final layer is replaced with a convolutional layer instead. For this example, we continue with a 10-class classification task.

# Replace last layer
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 10)

Training the Modified Model

Let’s define a simple training loop. For demonstration purposes, we’ll use random data:

# Create random data
inputs = torch.randn(5, 3, 224, 224)
labels = torch.randint(0, 10, (5,))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resnet18.fc.parameters(), lr=0.001, momentum=0.9)

# Training loop
for epoch in range(5):
    optimizer.zero_grad()
    outputs = resnet18(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}/5, Loss: {loss.item():.4f}")

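In this example, only the weights and bias of the final layer are updated during training, because every other parameter is frozen and the optimizer receives only resnet18.fc.parameters().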

Conclusion

Freezing layers in PyTorch is straightforward. By setting the requires_grad attribute to False, you prevent specific layers from being updated during training, so you can reuse the features a pre-trained model has already learned.

Understanding how to freeze and unfreeze layers in PyTorch is crucial for effective transfer learning, as it lets you adapt pre-trained models to related tasks. Because frozen parameters need no weight gradients or optimizer updates, this simple technique can also save time and computational resources when training deep neural networks.

Thank you for reading!



