How to handle overfitting in PyTorch models using Early Stopping
Last Updated: 20 Jun, 2025
When we train a machine learning model, it sometimes learns the training data too well, picking up small details and random noise that don't generalize. This problem is called overfitting. A model that overfits performs very well on the training data but badly on new, unseen data.
What is Early Stopping?
Early stopping means we stop training the model before it starts overfitting. While training, we monitor how the model performs on a separate dataset called the validation set. If the model stops improving on this validation set for a certain number of rounds (epochs), we end the training.
For example, imagine you are training your model and after each epoch you check its performance on the validation set. If the model has not improved for 5 epochs in a row, you stop training. The number of epochs to wait without improvement is called the patience. A minimal sketch of this decision logic is shown below.
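Conceptually, the patience check is just a counter that resets whenever the validation loss improves and stops training once it reaches the limit. The sketch below is illustrative only; the loss values and the patience setting are made-up numbers, and the full implementation follows in Step 3.
Python
patience = 5
best_val_loss = float('inf')
epochs_without_improvement = 0

for epoch, val_loss in enumerate([0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65]):  # toy losses
    if val_loss < best_val_loss:
        best_val_loss = val_loss          # improvement: remember it
        epochs_without_improvement = 0    # and reset the counter
    else:
        epochs_without_improvement += 1   # no improvement this epoch
        if epochs_without_improvement >= patience:
            print(f"Stopping early at epoch {epoch + 1}")
            break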
Benefits of Early Stopping
- Prevents Overfitting: By halting training at the right time, it keeps the model from memorizing the training data.
- Saves Time and Resources: It avoids unnecessary epochs, reducing training time and computational cost.
- Optimizes Model Performance: It helps select the version of the model that performs best on unseen data.
Implementing Early Stopping in PyTorch
In this section, we walk through creating, training and evaluating a simple neural network in PyTorch, focusing on the implementation of early stopping to prevent overfitting.
Step 1: Import Libraries
First, we import the necessary libraries: PyTorch (with torchvision for the MNIST dataset) and NumPy. We also import random_split, used later to carve a validation set out of the training data, and copy, used to snapshot the best model weights.
Python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split  # random_split is used in Step 4
import numpy as np
import copy  # used to snapshot the best model weights in Step 3
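Optionally, you can seed the random number generators so the train/validation split and the weight initialization are reproducible across runs. This is a common convenience, not part of the original walkthrough:
Python
torch.manual_seed(0)  # makes weight init, shuffling and random_split deterministic
np.random.seed(0)     # seeds NumPy for any NumPy-based randomness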
Step 2: Define the Neural Network Architecture
Next, we define a simple neural network class using PyTorch's nn.Module. The network has:
- fc1, fc2: Fully connected hidden layers, each followed by a ReLU activation.
- fc3: The output layer, which returns raw logits (nn.CrossEntropyLoss applies the softmax internally).
- forward method: Defines the forward pass: each 28x28 input image is flattened to a 784-dimensional vector and passed through the layers.
Python
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 28x28 pixels flattened to 784 inputs
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)    # 10 output classes (digits 0-9)

    def forward(self, x):
        x = torch.flatten(x, 1)          # flatten all dims except the batch dim
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)                  # raw logits; no softmax needed here
        return x
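As a quick sanity check (not part of the original walkthrough), you can feed a dummy batch through the network and confirm the output shape is (batch_size, 10):
Python
model = SimpleNN()
dummy = torch.randn(2, 1, 28, 28)  # a fake batch of two 28x28 grayscale images
print(model(dummy).shape)          # expected: torch.Size([2, 10])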
Step 3: Implement Early Stopping
We implement an EarlyStopping class to halt training when the validation loss stops improving. Its parameters are:
- patience: Number of epochs to wait for an improvement before stopping.
- delta: Minimum change in the monitored quantity to qualify as an improvement.
- best_score, best_model_state: Track the best validation score and a copy of the corresponding model weights.
- __call__ method: Runs the early stopping logic after each epoch.
Python
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss  # higher score means lower validation loss
        if self.best_score is None:
            self.best_score = score
            # deepcopy so the snapshot is not mutated by further training
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            # no sufficient improvement this epoch
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)
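To see the logic in isolation, you can drive the class by hand with made-up validation losses (these numbers are not from a real training run):
Python
stopper = EarlyStopping(patience=2, delta=0.0)
net = SimpleNN()
for loss in [0.80, 0.60, 0.59, 0.61, 0.62]:  # toy validation losses
    stopper(loss, net)
    print(f"val_loss={loss:.2f}  counter={stopper.counter}  stop={stopper.early_stop}")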
Step 4: Load the Data
We load the MNIST dataset, normalize it and split the original training set 80/20 into training and validation subsets with random_split.
Python
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Hold out 20% of the training set for validation
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
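An optional check (not in the original walkthrough) confirms the split sizes and batch shapes before training:
Python
print(len(train_dataset), len(val_dataset), len(test_dataset))  # 48000 12000 10000
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])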
Step 5: Initialize the Model, Loss Function and Optimizer
We set up the model with the cross-entropy loss, the Adam optimizer (learning rate 0.001) and early stopping with a patience of 5 epochs and a delta of 0.01.
Python
model = SimpleNN()
criterion = nn.CrossEntropyLoss()                       # multi-class classification loss
optimizer = optim.Adam(model.parameters(), lr=0.001)
early_stopping = EarlyStopping(patience=5, delta=0.01)  # stop after 5 epochs without a 0.01 improvement
Step 6: Train the Model with Early Stopping
We train the model, incorporating early stopping. Here:
- Train loop: Train the model, update weights and calculate training loss.
- Validation loop: Evaluate the model on validation data and calculate validation loss.
- Early stopping check: Apply early stopping logic after each epoch.
Python
num_epochs = 100
for epoch in range(num_epochs):
    # Training loop: update weights and accumulate the training loss
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)  # weight the batch loss by batch size
    train_loss /= len(train_loader.dataset)

    # Validation loop: no gradients, just measure the validation loss
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Early stopping check: stop if the validation loss has plateaued
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

early_stopping.load_best_model(model)
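Once the best weights are loaded back, you may want to persist them to disk. This is an optional extra; the file name best_model.pt is an arbitrary choice, not from the article:
Python
torch.save(model.state_dict(), 'best_model.pt')  # persist the best weights
# Later, or in another script:
restored = SimpleNN()
restored.load_state_dict(torch.load('best_model.pt'))
restored.eval()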
Step 7: Evaluate the Model
Finally, we evaluate the model's accuracy on the test dataset. The evaluation loop computes the accuracy by comparing predicted labels with true labels.
Python
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)  # index of the highest logit per sample
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
Output:
The loop prints the training and validation loss for each epoch. Once the validation loss has failed to improve by at least delta for patience consecutive epochs, "Early stopping" is printed, the best weights are restored and the final test accuracy is reported.