Pytorch Dataloader 模块源码分析(二):Sampler / Fetcher 组件及 Dataloader 核心代码

本文深入剖析Pytorch Dataloader的内部组件,包括Sampler(SequentialSampler, RandomSampler, BatchSampler)和Fetcher的工作原理。Sampler负责生成访问Dataset的index,Fetcher对Dataset做封装并转换为Pytorch Tensor。Dataloader的单线程和多线程场景中,Fetcher的作用是减少I/O瓶颈,提高数据加载效率。通过理解和优化这些组件,能提升Pytorch模型训练的效率。" 51906253,1841921,Git 进阶:掌握Submodule使用,"['Git', '版本管理', 'Submodule', '团队协作', '代码管理']

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Dataloader 组件

Sampler 类

在看 Sampler 的具体实现之前,我们先看看 Dataloader 在什么时候产生 Sampler 对象:

class DataLoader(object):
    def __init__(self, ...):
        ...
        if sampler is None:  
            ...
             # 如果指定shuffle就使用随机采样,否则使用顺序采样
                if shuffle: 
                    sampler = RandomSampler(dataset, generator=generator)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # 如果指定了batch_size又没有指定自定义的batch_sampler,就开启自动批采样
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        ...

我们可以看到 Sampler 对象的主要职责就是生成用于访问 Dataset 的 index。其中 Sampler 的子类如下:

  • SequentialSampler 顺序采样
  • RandomSampler 随机采样
  • BatchSampler 批采样

实际上还有其他的采样方法,但是因为使用的不多,本文主要讲解上述的三种 Sampler。上述提到的几种采样类都是 Sampler 的子类,Sampler 中的__iter__方法定义为 raise NotImplementedError:

class Sampler(Generic[T_co]):
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

SequentialSampler

SequentialSampler 实现:

class SequentialSampler(Sampler[int]):
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
    	# 创建一个迭代器
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

这里主要关注__Iter__方法,实际上返回的 index 就是 range(len(self.data_source)) 顺序递增的结果:len(data_source) 实际上就是 Dataset 返回的 samples 的长度。创建迭代器之后,当对这个迭代器调用__next__方法,就会返回 0, 1, 2, 3, 4, … 顺序递增的 index。

RandomSampler

RandomSampler 实现:

class RandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator    
		...
    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
		# replacement 表示是否可以生成重复 index
        if self.replacement:
        	# num_samples 表示一次性采样的数据量
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield
### PyTorch DataLoader 源码解析 #### 1. 整体架构概述 `DataLoader` 是 PyTorch 中用于加载数据的核心模块之一。该类的设计旨在简化大规模机器学习训练过程中数据预处理和批量读取的任务[^2]。 #### 2. 迭代器机制 当 `for` 循环遍历 `DataLoader` 对象时,实际上是在调用 `_SingleProcessDataLoaderIter` 或者 `_MultiProcessDataLoaderIter` 类实例的 `__iter__()` 方法来获取迭代器对象。随后,在每一次循环中,程序会执行此迭代器上的 `__next__()` 函数以取得下一个批次的数据直至所有样本被访问完毕[^4]。 #### 3. 关键组件说明 为了实现高效并行化的数据加载过程,`DataLoader` 使用了几种重要的内部组件: - **Dataset**: 定义了如何从磁盘或其他存储介质中读入单个样本的方法;支持两种形式——基于索引随机存取的标准版 (`Map-style`) 和顺序流式的可迭代版本 (`IterableStyle`) [^3]。 - **Sampler/BatchSampler**: 负责决定哪些样本应该被打包成一批次传给模型训练环节。前者负责生成单一样本ID序列而后者则进一步将其划分为多个子集作为输入批次。 - **Fetcher**: 实现多线程或多进程环境下安全地抓取由 Sampler 提供 ID 所指向的具体样本内容的功能。 - **Collate Function**: 用户自定义或默认提供的函数用来组合若干单独取出的小样成本批数据结构的一部分。 - **Pin Memory Option**: 如果 GPU 可用,则开启这项特性可以让 CPU 上准备好的张量直接复制到固定内存区域以便快速传输至显卡端参与计算。 #### 4. 数据流动路径 整个流程始于用户配置好必要的参数后初始化了一个新的 `DataLoader` 实例。接着这个实例会在后台默默构建起上述提到的各种辅助工具,并最终产出一个合适的迭代器供外部使用。每当请求新一批数据时,迭代器就会协调各部分协同工作:先让 Sampler 生产出待处理项列表交给 Fetcher 去实际提取文件内容,再经 Collate Fn 加工整理成为标准格式最后交付给使用者。 ```python from torch.utils.data import DataLoader, TensorDataset import numpy as np # 创建简单的TensorDataset data = TensorDataset( *[np.random.rand(8), np.random.rand(8)] ) loader = DataLoader(data, batch_size=2) for inputs, targets in loader: print(inputs.shape) # 输出每轮得到的batch大小 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值