
How to handle overfitting in PyTorch models using Early Stopping

Last Updated : 20 Jun, 2025

When we train a machine learning model, it can learn the training data too well, memorizing small details and random noise that do not generalize. This problem is called overfitting. A model that overfits performs very well on the training data but badly on new, unseen data.

What is Early Stopping?

Early stopping means we halt training before the model starts overfitting. During training, we monitor how the model performs on a separate dataset called the validation set. If the model stops improving on this validation set for several consecutive rounds (epochs), we stop the training.

For example, imagine you are training your model and, after each epoch, you check its performance on the validation set. If the model has not improved for 5 epochs in a row, you stop the training. This number of epochs to wait without improvement, 5 here, is called the patience.
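
Here is a minimal sketch of this patience logic, using made-up validation losses and hypothetical variable names (the full EarlyStopping class is implemented in Step 3 below):

Python
best_val_loss = float('inf')
epochs_without_improvement = 0
patience = 5

for val_loss in [0.9, 0.7, 0.65, 0.66, 0.66, 0.67, 0.68, 0.70]:
    if val_loss < best_val_loss:
        # The model improved: remember the new best and reset the counter.
        best_val_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print('Stopping: no improvement for', patience, 'epochs')
            break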

Benefits of Early Stopping

  • Prevents Overfitting: By halting training at the right time, it keeps the model from memorizing noise in the training data.
  • Saves Time and Resources: It avoids unnecessary epochs, cutting training time and computational cost.
  • Optimizes Model Performance: It helps select the version of the model that performs best on unseen data.

Implementing Early Stopping in PyTorch

In this section, we walk through creating, training and evaluating a simple neural network in PyTorch, focusing mainly on the implementation of early stopping to prevent overfitting.

Step 1: Import Libraries

First, we import the necessary libraries: PyTorch (with torchvision for the MNIST dataset), NumPy and the copy module.

Python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import copy  # used in Step 3 to snapshot the best model weights

Step 2: Define the Neural Network Architecture

Next, we define a simple neural network class using PyTorch's nn.Module. The neural network has:

  • fc1, fc2, fc3: Fully connected layers; ReLU activations follow fc1 and fc2, while fc3 outputs the raw class scores (logits).
  • forward method: Defines the forward pass of the network.
Python
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
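
As a quick, illustrative sanity check, a batch of two MNIST-sized images should produce logits of shape (2, 10):

Python
check_model = SimpleNN()
dummy = torch.randn(2, 1, 28, 28)  # batch of 2 grayscale 28x28 images
print(check_model(dummy).shape)  # torch.Size([2, 10])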

Step 3: Implement Early Stopping

We implement an EarlyStopping class to halt training if the validation loss stops improving. Here the parameters are:

  • patience: Number of epochs to wait before stopping if no improvement.
  • delta: Minimum change in the monitored quantity to qualify as an improvement.
  • best_score, best_model_state: Track the best validation score and model state.
  • __call__ method: Updates the early stopping state after each epoch and sets early_stop once the patience is exhausted.
Python
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            # Deep-copy the weights: state_dict() returns references that
            # would otherwise keep changing as training continues.
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)
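
To see the stopping logic in action, here is an illustrative run with made-up validation losses (patience=2 just to keep the example short):

Python
demo_model = SimpleNN()
stopper = EarlyStopping(patience=2, delta=0)
for loss in [0.80, 0.60, 0.65, 0.70]:
    stopper(loss, demo_model)
    print(f'val_loss={loss}, counter={stopper.counter}, early_stop={stopper.early_stop}')
# The loss stops improving after 0.60, so early_stop becomes True
# once the counter reaches the patience of 2.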

Step 4: Load the Data

We load and normalize the MNIST dataset, then split the 60,000 training images 80/20 into training and validation subsets using random_split.

Python
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_size = int(0.8 * len(train_dataset))  
val_size = len(train_dataset) - train_size  
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
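
If you want the train/validation split to be reproducible across runs, random_split also accepts a seeded generator; this optional variation replaces the random_split line above:

Python
generator = torch.Generator().manual_seed(42)  # any fixed seed works
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size], generator=generator)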

Step 5: Initialize the Model, Loss Function and Optimizer

We set up the model with the cross-entropy loss, the Adam optimizer and early stopping with a patience of 5. The delta of 0.01 means the validation loss must decrease by at least 0.01 to count as an improvement and reset the patience counter.

Python
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
early_stopping = EarlyStopping(patience=5, delta=0.01)

Step 6: Train the Model with Early Stopping

We train the model, incorporating early stopping. Here:

  • Train loop: Train the model, update weights and calculate training loss.
  • Validation loop: Evaluate the model on validation data and calculate validation loss.
  • Early stopping check: Apply early stopping logic after each epoch.
Python
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)

    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

early_stopping.load_best_model(model)
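
At this point the model holds the best weights seen during training. Optionally, you can also persist them to disk (the file name here is arbitrary):

Python
torch.save(model.state_dict(), 'best_model.pt')
# Later, restore with: model.load_state_dict(torch.load('best_model.pt'))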

Step 7: Evaluate the Model

Finally, we evaluate the model's accuracy on the test dataset. The evaluation loop computes the accuracy by comparing predicted labels with true labels.

Python
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
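
As an aside, predicted = outputs.argmax(dim=1) is an equivalent, slightly more idiomatic alternative to the torch.max call used above.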

Output:

[Screenshot of the console output: per-epoch train and validation losses, followed by the early stopping message and the final test accuracy.]

