
[Targeting 2024 Q2] Dataloader crashes after enabling persistent_workers=True  #48964

Closed
@Wong4j


Describe the Bug

For benchmarking, I only need to train a few steps per epoch, so I added a break to the inner loop. For example:

train_dataloader = paddle.io.DataLoader(dataset, batch_size=16, num_workers=4, persistent_workers=True)
bench_epochs = 3
bench_steps = 10
for epoch in range(bench_epochs):
    for i, batch in enumerate(train_dataloader):
        if i > bench_steps:
            break
        do_training_process()
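
A side note on the loop above: because the check is `i > bench_steps` and `enumerate` starts at 0, each epoch trains `bench_steps + 1` batches before breaking, which matches the log further down (batches 0 through 10). The sketch below only counts iterations in plain Python; `steps_executed` and its `strict` flag are hypothetical names used for illustration and are not part of Paddle.

```python
def steps_executed(num_batches, bench_steps, strict=True):
    """Count how many batches are processed before the loop breaks.

    strict=True  models `if i > bench_steps: break`  (as in the snippet above)
    strict=False models `if i >= bench_steps: break` (exactly bench_steps batches)
    """
    executed = 0
    for i in range(num_batches):
        stop = i > bench_steps if strict else i >= bench_steps
        if stop:
            break
        executed += 1  # stands in for do_training_process()
    return executed


print(steps_executed(num_batches=20, bench_steps=10))                # 11
print(steps_executed(num_batches=20, bench_steps=10, strict=False))  # 10
```

This off-by-one is unrelated to the crash itself; it only explains why 11 loss lines appear per epoch.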

This works fine with persistent_workers=False. With persistent_workers=True, however, I get this error:

$ python test_dataloader.py 
W1209 00:24:39.929682 3970167 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 11.8, Runtime API Version: 11.7
W1209 00:24:39.943722 3970167 gpu_resources.cc:91] device: 0, cuDNN Version: 8.7.
Epoch 0 batch 0: loss = 2.582632303237915
Epoch 0 batch 1: loss = 2.553558588027954
Epoch 0 batch 2: loss = 2.5804834365844727
Epoch 0 batch 3: loss = 2.531757354736328
Epoch 0 batch 4: loss = 2.3217196464538574
Epoch 0 batch 5: loss = 2.3962247371673584
Epoch 0 batch 6: loss = 2.3609089851379395
Epoch 0 batch 7: loss = 2.398348808288574
Epoch 0 batch 8: loss = 2.594115734100342
Epoch 0 batch 9: loss = 2.648672342300415
Epoch 0 batch 10: loss = 2.4073853492736816
Traceback (most recent call last):
  File "test_dataloader.py", line 51, in <module>
    for i, (image, label) in enumerate(loader()):
  File "/usr/local/lib/python3.8/dist-packages/paddle/fluid/dataloader/dataloader_iter.py", line 746, in __next__
    data = _restore_batch(data, self._structure_infos.pop(0))
IndexError: pop from empty list

Here is the complete code to reproduce:

# cat test_dataloader.py 
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, BatchSampler, DataLoader

BATCH_NUM = 20
BATCH_SIZE = 16
EPOCH_NUM = 100
STEPS_PER_EPOCH = 10

IMAGE_SIZE = 784
CLASS_NUM = 10

USE_GPU = False # whether to use the GPU to run the model

# define a random dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([IMAGE_SIZE]).astype('float32')
        label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)

class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(IMAGE_SIZE, CLASS_NUM)

    def forward(self, image, label=None):
        return self.fc(image)

simple_net = SimpleNet()
opt = paddle.optimizer.SGD(learning_rate=1e-3,
                          parameters=simple_net.parameters())

loader = DataLoader(dataset,
                    batch_size=16,
                    num_workers=4,
                    persistent_workers=True)

for e in range(EPOCH_NUM):
    for i, (image, label) in enumerate(loader()):
        if i > STEPS_PER_EPOCH:
            break
        out = simple_net(image)
        loss = F.cross_entropy(out, label)
        avg_loss = paddle.mean(loss)
        avg_loss.backward()
        opt.minimize(avg_loss)
        simple_net.clear_gradients()
        print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
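
One possible workaround, not verified against this bug, is to avoid the early `break` altogether by limiting the dataset so that each epoch ends naturally after the desired number of batches. The helper below is a minimal pure-Python sketch that computes which sample indices to keep; `indices_for_steps` is a hypothetical name, and it assumes the loader does not shuffle and does not drop the last batch.

```python
def indices_for_steps(num_samples, batch_size, steps):
    """Return the leading sample indices that fill `steps` batches.

    If the dataset is smaller than batch_size * steps, every index is kept.
    """
    wanted = batch_size * steps
    return list(range(min(num_samples, wanted)))


# With the constants from the script above: 20 * 16 = 320 samples,
# batch size 16, 10 steps per epoch -> keep the first 160 samples.
indices = indices_for_steps(num_samples=320, batch_size=16, steps=10)
print(len(indices))  # 160
```

In the reproduction script, these indices could presumably be passed to `paddle.io.Subset(dataset, indices)` and the resulting subset given to `DataLoader`, so the inner loop runs to completion without a `break`. Whether that sidesteps the `IndexError` with persistent_workers=True is an assumption and has not been tested here.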

Additional Supplementary Information

No response

Labels

NVIDIA; PFCC (Paddle Framework Contributor Club, https://github.com/PaddlePaddle/community/tree/master/pfcc); status/close (closed); type/bug-report (bug report)
