Contrastive Learning with SimCLR in PyTorch
Last Updated: 25 Jun, 2025
SimCLR (Simple Framework for Contrastive Learning of Visual Representations) is a self-supervised learning approach that learns powerful image representations without labeled data. It works by maximizing agreement between two differently augmented views of the same image (a positive pair) while pushing apart the representations of different images (negatives), using a contrastive loss in a latent embedding space. Implementing SimCLR in PyTorch allows for flexible experimentation and strong performance on image tasks using only unlabeled data.
Core Ideas of SimCLR
- Data Augmentation: Each input image is randomly augmented twice to create two correlated views (positive pair). Common augmentations include random cropping, flipping, color jittering, and Gaussian blur.
- Encoder Network: A deep neural network (often ResNet-18/50) encodes each augmented image into a feature vector. The final classification layer is removed and replaced with a projection head.
- Projection Head: An MLP (multi-layer perceptron) maps the encoder’s output to a lower-dimensional embedding space where the contrastive loss is applied.
SimCLR in PyTorch: Main Components
1. Data Augmentation: Define a set of strong augmentations to generate two different views of each image.
2. Encoder and Projection Head
- Use a backbone (e.g., ResNet-18/50) without the final classification layer.
- Add a projection head (typically a 2-layer MLP) to map features to the embedding space.
3. Contrastive Loss (NT-Xent): The normalized temperature-scaled cross-entropy loss encourages embeddings of positive pairs to be similar and embeddings of different images (negatives) to be dissimilar. A custom loss function computes it for each positive pair in the batch.
4. Training Loop
- For each batch, generate two augmented views per image.
- Pass both views through the encoder and projection head.
- Compute the contrastive loss and update the model.
PyTorch Implementation
1. Install Libraries
Installs the required PyTorch packages to run the model.
Python
!pip install torch torchvision pytorch-lightning --quiet
The standard imports bring in models, datasets, transforms, and helper utilities (like tqdm for progress bars).
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
import torchvision.models as models
import numpy as np
import random
from tqdm import tqdm
2. Data Augmentation
The augmentations create two different views of the same image to learn invariant features.
Python
simclr_transform = T.Compose([
    T.RandomResizedCrop(32),
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=3),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
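Because every transform in the pipeline is random, applying it twice to the same image yields two different views. A quick illustrative check (the demo_dataset name and index 0 are arbitrary choices; the CIFAR-10 download also happens in the training section below):
Python
demo_dataset = CIFAR10(root='./data', train=True, download=True)
image, _ = demo_dataset[0]               # a PIL image
view_1 = simclr_transform(image)         # first random view
view_2 = simclr_transform(image)         # second, differently augmented view
print(view_1.shape, view_2.shape)        # torch.Size([3, 32, 32]) for both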
3. Dataset for Two Views
For each image, return two augmented versions: xi and xj, used as positive pairs in contrastive learning.
Python
class SimCLRDataset(Dataset):
    def __init__(self, base_dataset, transform):
        self.dataset = base_dataset
        self.transform = transform

    def __getitem__(self, index):
        image, _ = self.dataset[index]
        xi = self.transform(image)
        xj = self.transform(image)
        return xi, xj

    def __len__(self):
        return len(self.dataset)
4. SimCLR Model = Encoder + Projection Head
- encoder: a ResNet18 without its final classification head.
- projection_head: maps features to a smaller contrastive space (128-d).
- Output z is used to compute the contrastive loss.
Python
class SimCLRModel(nn.Module):
    def __init__(self, projection_dim=128):
        super().__init__()
        base_model = models.resnet18(weights=None)
        num_ftrs = base_model.fc.in_features
        base_model.fc = nn.Identity()
        self.encoder = base_model
        self.projection_head = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projection_head(h)
        return z
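A quick shape check (illustrative; demo_model and the random batch are throwaway stand-ins, not part of the training script) confirms that the encoder outputs 512-dimensional features and the projection head maps them to 128 dimensions:
Python
demo_model = SimCLRModel()
dummy = torch.randn(4, 3, 32, 32)        # random batch of 4 CIFAR-sized images
h = demo_model.encoder(dummy)            # encoder features
z = demo_model(dummy)                    # projected embeddings
print(h.shape, z.shape)                  # torch.Size([4, 512]) torch.Size([4, 128])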
5. NT-Xent Loss
- All 2N projections in a batch are compared pairwise: each view's positive is the other augmented view of the same image, and the remaining views act as negatives.
- Cosine similarities are scaled by a temperature parameter.
- The NT-Xent (normalized temperature-scaled cross-entropy) loss is then computed for every positive pair.
Python
def nt_xent_loss(z_i, z_j, temperature=0.5):
    # Stack both sets of projections and L2-normalize them
    z = torch.cat([z_i, z_j], dim=0)
    z = F.normalize(z, dim=1)
    similarity = torch.matmul(z, z.T)                      # pairwise cosine similarities (2N x 2N)
    N = z_i.shape[0]
    mask = (~torch.eye(2 * N, dtype=bool)).to(z.device)    # exclude each view's similarity with itself
    sim = similarity / temperature
    exp_sim = torch.exp(sim) * mask
    # Similarity of each positive pair (the two augmented views of the same image)
    positive_sim = torch.exp(F.cosine_similarity(z_i, z_j) / temperature)
    positives = torch.cat([positive_sim, positive_sim], dim=0)
    denominator = exp_sim.sum(dim=1)
    loss = -torch.log(positives / denominator)
    return loss.mean()
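A minimal sanity check with random projections (purely illustrative values) shows the expected behaviour: identical views give a noticeably lower loss than unrelated ones.
Python
z_a = torch.randn(8, 128)                # stand-in projections for 8 images
z_b = torch.randn(8, 128)
print(nt_xent_loss(z_a, z_a).item())     # identical views -> relatively low loss
print(nt_xent_loss(z_a, z_b).item())     # unrelated views -> higher loss on average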
6. Training Loop
- Load data and send to GPU.
- Run through encoder + projection.
- Compute contrastive loss.
- Backpropagate and optimize.
Python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = CIFAR10(root='./data', train=True, download=True)
contrastive_dataset = SimCLRDataset(train_dataset, simclr_transform)
train_loader = DataLoader(contrastive_dataset, batch_size=256, shuffle=True, num_workers=2)

model = SimCLRModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(10):
    model.train()
    total_loss = 0
    for x_i, x_j in tqdm(train_loader):
        x_i, x_j = x_i.to(device), x_j.to(device)
        z_i = model(x_i)
        z_j = model(x_j)
        loss = nt_xent_loss(z_i, z_j)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} | Loss: {total_loss / len(train_loader):.4f}")
7. Evaluation (Linear Probe)
- Freeze the encoder to use it as a feature extractor.
- Train a linear classifier on top of these features.
- This checks how useful the learned representations are for actual classification.
Python
for param in model.encoder.parameters():   # freeze the encoder so it acts as a fixed feature extractor
    param.requires_grad = False

classifier = nn.Linear(512, 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

def get_features_and_labels(loader):
    features, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            h = model.encoder(x)        # 512-d features from the frozen encoder
            features.append(h.cpu())
            labels.append(y)
    return torch.cat(features), torch.cat(labels)
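The snippet above only extracts features; the classifier defined earlier still has to be trained and evaluated. Below is a minimal sketch of that step. It assumes a plain (non-contrastive) evaluation transform and the CIFAR-10 train/test splits; names such as eval_transform, train_eval_loader and test_loader are illustrative choices rather than part of the original code.
Python
# Plain evaluation transform (assumption: normalization only, no heavy augmentation)
eval_transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

train_eval_loader = DataLoader(CIFAR10(root='./data', train=True, transform=eval_transform),
                               batch_size=256, shuffle=False)
test_loader = DataLoader(CIFAR10(root='./data', train=False, download=True, transform=eval_transform),
                         batch_size=256, shuffle=False)

model.eval()
train_feats, train_labels = get_features_and_labels(train_eval_loader)
test_feats, test_labels = get_features_and_labels(test_loader)

# Train the linear classifier on the frozen features
for epoch in range(20):
    perm = torch.randperm(len(train_feats))
    for i in range(0, len(train_feats), 256):
        idx = perm[i:i + 256]
        x, y = train_feats[idx].to(device), train_labels[idx].to(device)
        optimizer.zero_grad()
        loss = criterion(classifier(x), y)
        loss.backward()
        optimizer.step()

# Evaluate on the test split
with torch.no_grad():
    preds = classifier(test_feats.to(device)).argmax(dim=1).cpu()
accuracy = (preds == test_labels).float().mean().item()
print(f"Linear probe accuracy: {accuracy:.4f}")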
Practical Considerations
- Batch Size: Large batch sizes improve performance by providing more negative samples.
- Projection Head: Used only during pretraining; final representations are taken from the encoder, not the projection head.
- Multi-GPU Training: SimCLR scales to large datasets with multi-GPU or distributed training, which also makes large batch sizes practical; see the sketch after this list.
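For the multi-GPU point above, here is a minimal single-node sketch (an illustration only: pretrain_model is just a fresh instance, and more than one CUDA device is presumed visible). nn.DataParallel is the smallest change to the existing script, while DistributedDataParallel is generally preferred for serious scaling.
Python
pretrain_model = SimCLRModel()
if torch.cuda.device_count() > 1:
    pretrain_model = nn.DataParallel(pretrain_model)   # splits each batch across the visible GPUs
pretrain_model = pretrain_model.to(device)
# The training loop itself is unchanged; torch.nn.parallel.DistributedDataParallel
# is generally recommended for multi-node or large-scale runs.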
Training and Evaluation
- Pretraining: Train SimCLR on unlabeled data using the contrastive loss.
- Fine-tuning: After pretraining, remove the projection head and fine-tune or evaluate the encoder on downstream tasks (e.g., classification with a linear head), as sketched below.
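A minimal sketch of that fine-tuning setup (the FineTuneClassifier name and the learning rates are illustrative choices, not from the original article): keep the pretrained encoder, discard the projection head, and train a new linear head, optionally with a smaller learning rate on the encoder.
Python
class FineTuneClassifier(nn.Module):
    def __init__(self, encoder, num_classes=10):
        super().__init__()
        self.encoder = encoder                 # pretrained SimCLR encoder, projection head discarded
        self.head = nn.Linear(512, num_classes)

    def forward(self, x):
        return self.head(self.encoder(x))

finetune_model = FineTuneClassifier(model.encoder).to(device)
for p in finetune_model.encoder.parameters():  # re-enable gradients if the encoder was frozen earlier
    p.requires_grad = True

# Smaller learning rate for the pretrained encoder, larger for the new head (illustrative values)
optimizer = torch.optim.Adam([
    {"params": finetune_model.encoder.parameters(), "lr": 1e-4},
    {"params": finetune_model.head.parameters(), "lr": 1e-3},
])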