
Contrastive Learning with SimCLR in PyTorch

Last Updated : 25 Jun, 2025

SimCLR (Simple Framework for Contrastive Learning of Visual Representations) is a self-supervised learning approach that learns image representations without labeled data. It maximizes the similarity between two differently augmented views of the same image and minimizes the similarity to views of other images, using a contrastive loss computed in a latent (embedding) space. Implementing SimCLR in PyTorch allows flexible experimentation. The learned encoder can then be reused for downstream image tasks such as classification, even though pretraining uses only unlabeled data.

Core Ideas of SimCLR

  • Data Augmentation: Each input image is randomly augmented twice to create two correlated views (positive pair). Common augmentations include random cropping, flipping, color jittering, and Gaussian blur.
  • Encoder Network: A deep neural network (often ResNet-18/50) encodes each augmented image into a feature vector. The final classification layer is removed and replaced with a projection head.
  • Projection Head: An MLP (multi-layer perceptron) maps the encoder’s output to a lower-dimensional embedding space where the contrastive loss is applied.

SimCLR in PyTorch: Main Components

1. Data Augmentation: Define a set of strong augmentations to generate two different views of each image.

2. Encoder and Projection Head

  • Use a backbone (e.g., ResNet-18/50) without the final classification layer.
  • Add a projection head (typically a 2-layer MLP) to map features to the embedding space.

3. Contrastive Loss (NT-Xent): The normalized temperature-scaled cross-entropy loss encourages the embeddings of positive pairs to be similar and those of different images (negatives) to be dissimilar. A custom loss function computes it for every positive pair in the batch.
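For reference, here is the NT-Xent loss as defined in the SimCLR paper, indexed the way this article's code stores the views. A batch of N images yields 2N projections. Rows 1..N come from the first view and rows N+1..2N from the second, so the positive for anchor k is k+N. The similarity sim(u, v) is cosine similarity and τ is the temperature (the `temperature` argument in the code):

```latex
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}

\ell_{i,j} = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_j)/\tau\big)}
                        {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)}

\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \big[\, \ell_{k,\,k+N} + \ell_{k+N,\,k} \,\big]
```

The denominator runs over all 2N - 1 other embeddings. That is the positive plus 2(N - 1) negatives.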

4. Training Loop

  • For each batch, generate two augmented views per image.
  • Pass both views through the encoder and projection head.
  • Compute the contrastive loss and update the model.

PyTorch Implementation

1. Install Libraries

Install the packages used below: PyTorch, torchvision and tqdm. The leading ! runs the command from a notebook cell, for example in Jupyter or Colab.

Python
!pip install torch torchvision tqdm --quiet


Import the standard modules: models, datasets, transforms, and helper utilities (such as tqdm for progress bars).

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
import torchvision.models as models
import numpy as np
import random
from tqdm import tqdm

2. Data Augmentation

Applying this random pipeline twice to the same image produces two different views. Training the model to match them teaches it features that are invariant to these transformations.

Python
simclr_transform = T.Compose([
    T.RandomResizedCrop(32),
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=3),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

3. Dataset for Two Views

For each image, the dataset returns two independently augmented versions, xi and xj, which form a positive pair for contrastive learning. The original CIFAR-10 label is discarded.

Python
class SimCLRDataset(Dataset):
    def __init__(self, base_dataset, transform):
        self.dataset = base_dataset
        self.transform = transform

    def __getitem__(self, index):
        image, _ = self.dataset[index]
        xi = self.transform(image)
        xj = self.transform(image)
        return xi, xj

    def __len__(self):
        return len(self.dataset)

4. SimCLR Model = Encoder + Projection Head

  • encoder: a ResNet-18 without its final classification layer; its output h is the 512-d representation kept for downstream tasks.
  • projection_head: a 2-layer MLP that maps h to a smaller contrastive space (128-d).
  • The output z is used only to compute the contrastive loss.
Python
class SimCLRModel(nn.Module):
    def __init__(self, projection_dim=128):
        super().__init__()
        base_model = models.resnet18(weights=None)
        num_ftrs = base_model.fc.in_features
        base_model.fc = nn.Identity()
        self.encoder = base_model
        self.projection_head = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projection_head(h)
        return z

5. NT-Xent Loss

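  • Concatenate the projections of both views and L2-normalize them, so that dot products equal cosine similarities.
  • Scale the similarities by the temperature and exclude each embedding's similarity with itself.
  • Use NT-Xent (normalized temperature-scaled cross-entropy): for every embedding, compare the similarity to its positive against the sum over all other embeddings in the batch.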
Python
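def nt_xent_loss(z_i, z_j, temperature=0.5):
    # Stack both views: rows 0..N-1 come from x_i, rows N..2N-1 from x_j
    z = torch.cat([z_i, z_j], dim=0)
    z = F.normalize(z, dim=1)  # unit length, so z @ z.T gives cosine similarities

    similarity = torch.matmul(z, z.T)  # shape (2N, 2N)
    N = z_i.shape[0]

    # Zero out each embedding's similarity with itself (the diagonal)
    mask = ~torch.eye(2 * N, dtype=torch.bool, device=z.device)
    sim = similarity / temperature
    exp_sim = torch.exp(sim) * mask

    # The positive for row i is row i+N (and vice versa): same image, other view
    positive_sim = torch.exp(F.cosine_similarity(z_i, z_j) / temperature)
    positives = torch.cat([positive_sim, positive_sim], dim=0)

    # Denominator: the positive plus all 2N-2 negatives of each anchor
    denominator = exp_sim.sum(dim=1)
    loss = -torch.log(positives / denominator)
    return loss.mean()  # average over all 2N anchors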
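The same loss can also be written as a cross-entropy over similarity logits. This is an equivalent formulation that many SimCLR implementations use. The self-similarity is masked with -inf instead of being multiplied by zero, and the target class for row i is the index of its other view. Because `F.cross_entropy` works in log space, this version also avoids an overflow in the exp-based code. For temperatures below roughly 0.0113, exp(1/τ) on the diagonal exceeds the float32 range, and inf * 0 produces NaN. The sketch below is hedged: it compares the two on random tensors rather than on a trained model.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=0.5):
    # The exp-based version from above, condensed so this block runs on its own
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    N = z_i.shape[0]
    mask = ~torch.eye(2 * N, dtype=torch.bool, device=z.device)
    exp_sim = torch.exp(z @ z.T / temperature) * mask
    positive_sim = torch.exp(F.cosine_similarity(z_i, z_j) / temperature)
    positives = torch.cat([positive_sim, positive_sim], dim=0)
    return (-torch.log(positives / exp_sim.sum(dim=1))).mean()

def nt_xent_loss_ce(z_i, z_j, temperature=0.5):
    # Equivalent NT-Xent written as cross-entropy over similarity logits
    N = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    logits = z @ z.T / temperature
    # Drop self-similarity: exp(-inf) = 0, so the diagonal leaves the softmax
    self_mask = torch.eye(2 * N, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i's positive is column i+N (first half) or i-N (second half)
    targets = torch.cat([torch.arange(N, 2 * N), torch.arange(0, N)]).to(z.device)
    return F.cross_entropy(logits, targets)  # mean over the 2N anchors

torch.manual_seed(0)
z_i, z_j = torch.randn(8, 128), torch.randn(8, 128)
loss_exp = nt_xent_loss(z_i, z_j)
loss_ce = nt_xent_loss_ce(z_i, z_j)
print(torch.allclose(loss_exp, loss_ce, atol=1e-5))  # True

# Very small temperature: the exp-based version overflows, the cross-entropy one does not
print(nt_xent_loss(z_i, z_j, temperature=0.01))  # tensor(nan)
print(torch.isfinite(nt_xent_loss_ce(z_i, z_j, temperature=0.01)))  # tensor(True)
```

Either function can be passed to the training loop. Switching to the cross-entropy form changes only numerical robustness, not the value of the loss.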

6. Training Loop

  • Load batches of view pairs and move them to the GPU if one is available.
  • Pass both views through the encoder and projection head.
  • Compute the contrastive loss.
  • Backpropagate and update the weights.
Python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = CIFAR10(root='./data', train=True, download=True)
contrastive_dataset = SimCLRDataset(train_dataset, simclr_transform)
train_loader = DataLoader(contrastive_dataset, batch_size=256, shuffle=True, num_workers=2)

model = SimCLRModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(10):
    model.train()
    total_loss = 0
    for x_i, x_j in tqdm(train_loader):
        x_i, x_j = x_i.to(device), x_j.to(device)
        z_i = model(x_i)
        z_j = model(x_j)

        loss = nt_xent_loss(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss / len(train_loader):.4f}")
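One reference point helps when reading the printed epoch loss. Suppose every anchor's positive is no more similar than any of its negatives, so that all similarities are equal, as happens when the embeddings collapse to one vector. Then each term of NT-Xent is exactly log(2N - 1), which is about 6.24 for batch size 256. A loss that stays near this value suggests the model is not yet telling positives apart from negatives. The sketch below repeats the loss function and checks this with collapsed embeddings. The all-ones input is an illustrative assumption.

```python
import math
import torch
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=0.5):
    # Same computation as the article's nt_xent_loss, condensed
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    N = z_i.shape[0]
    mask = ~torch.eye(2 * N, dtype=torch.bool)
    exp_sim = torch.exp(z @ z.T / temperature) * mask
    positive_sim = torch.exp(F.cosine_similarity(z_i, z_j) / temperature)
    positives = torch.cat([positive_sim, positive_sim], dim=0)
    return (-torch.log(positives / exp_sim.sum(dim=1))).mean()

N = 256  # batch size used in the training loop above
# Collapsed embeddings: every projection is the same vector, so all similarities equal 1
collapsed = torch.ones(N, 128)
loss = nt_xent_loss(collapsed, collapsed).item()
print(f"loss = {loss:.4f}, log(2N - 1) = {math.log(2 * N - 1):.4f}")
```

Note that the last batch of an epoch may be smaller (50,000 is not a multiple of 256), so its reference value is lower.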


7. Evaluation (Linear Probe)

  • Freeze the encoder to use it as a feature extractor.
  • Train a linear classifier on top of these features.
  • This checks how useful the learned representations are for actual classification.
Python
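model.eval()  # freeze the encoder; eval() also makes BatchNorm use its running statistics
for param in model.encoder.parameters():
    param.requires_grad = False

# Un-augmented inputs (same normalization as pretraining) for feature extraction
eval_transform = T.Compose([T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
train_eval = CIFAR10(root='./data', train=True, download=True, transform=eval_transform)
test_eval = CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

def get_features_and_labels(loader):
    features, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            features.append(model.encoder(x.to(device)).cpu())  # 512-d encoder output h
            labels.append(y)
    return torch.cat(features), torch.cat(labels)

X_train, y_train = get_features_and_labels(DataLoader(train_eval, batch_size=512))
X_test, y_test = get_features_and_labels(DataLoader(test_eval, batch_size=512))

classifier = nn.Linear(512, 10).to(device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
for epoch in range(20):  # train only the linear layer on the frozen features
    for idx in torch.randperm(len(X_train)).split(256):
        loss = F.cross_entropy(classifier(X_train[idx].to(device)), y_train[idx].to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

with torch.no_grad():
    preds = classifier(X_test.to(device)).argmax(dim=1).cpu()
print(f"Linear probe test accuracy: {(preds == y_test).float().mean().item():.4f}")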

You can download the complete code from here: Contrastive Learning using SimCLR

Practical Considerations

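  • Batch Size: Larger batches generally improve results because they provide more negatives. With batch size N, each anchor is contrasted against 2(N - 1) negative views, for example 510 for N = 256.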
  • Projection Head: Used only during pretraining; final representations are taken from the encoder, not the projection head.
  • Multi-GPU Training: Large batch sizes usually require distributing training across several GPUs. The code in this article runs on a single device.

Training and Evaluation

  • Pretraining: Train SimCLR on unlabeled data using the contrastive loss.
  • Fine-tuning: After pretraining, remove the projection head and fine-tune or evaluate the encoder on downstream tasks (e.g., classification with a linear head).
