(一) Starting training
def run(self):
    assert self.mode in ["train", "train_only", "val"]
    if self.mode == "train":
        if self.epoch > 0:
            logging.info(f"Resuming training from epoch: {self.epoch}")
            # resuming from a checkpoint
            if self.is_intermediate_val_epoch(self.epoch - 1):
                logging.info("Running previous val epoch")
                self.epoch -= 1
                self.run_val()
                self.epoch += 1
        self.run_train()
        self.run_val()
    elif self.mode == "val":
        self.run_val()
    elif self.mode == "train_only":
        self.run_train()
The whole training process is started by trainer.run(), which in turn calls run_train() and run_val().

The run_train() function
def run_train(self):
    while self.epoch < self.max_epochs:
        dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
        barrier()  # synchronize all processes in distributed training, making sure
                   # every process has finished the current step before moving on
        outs = self.train_epoch(dataloader)  # train for one epoch
        self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

        # log train to text file.
        if self.distributed_rank == 0:  # again, only the main process writes to disk
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "train_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(outs) + "\n")

        # Save checkpoint before validating
        self.save_checkpoint(self.epoch + 1)

        del dataloader  # drop the dataloader
        gc.collect()    # and run garbage collection to free memory

        # Run val, not running on last epoch since will run after the
        # loop anyway
        if self.is_intermediate_val_epoch(self.epoch):
            self.run_val()

        if self.distributed_rank == 0:
            self.best_meter_values.update(self._get_trainer_state("train"))
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "best_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(self.best_meter_values) + "\n")

        self.epoch += 1
    # epoch was incremented in the loop but the val step runs out of the loop
    self.epoch -= 1
This function relies on two main calls:

self.train_dataset.get_loader(epoch=int(self.epoch))
self.train_epoch(dataloader)

We will walk through each of them below.
(二) The dataloader part
(1) The data section of the config
data:
  train:
    _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
    phases_per_epoch: ${scratch.phases_per_epoch}
    batch_sizes:
      - ${scratch.train_batch_size}
    datasets:
      - _target_: training.dataset.utils.RepeatFactorWrapper
        dataset:
          _target_: training.dataset.utils.ConcatDataset
          datasets:
            - _target_: training.dataset.vos_dataset.VOSDataset
              transforms: ${vos.train_transforms}
              training: true
              video_dataset:
                _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                img_folder: ${dataset.img_folder}
                gt_folder: ${dataset.gt_folder}
                file_list_txt: ${dataset.file_list_txt}
              sampler:
                _target_: training.dataset.vos_sampler.RandomUniformSampler
                num_frames: ${scratch.num_frames}
                max_num_objects: ${scratch.max_num_objects}
              multiplier: ${dataset.multiplier}
    shuffle: True
    num_workers: ${scratch.num_train_workers}
    pin_memory: True
    drop_last: True
    collate_fn:
      _target_: training.utils.data_utils.collate_fn
      _partial_: true
      dict_key: all
From this config we can see that self.train_dataset.get_loader(epoch=int(self.epoch)) is a call on a TorchTrainMixedDataset instance. Next, let's look at its implementation.
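As a quick aside on how the config turns into objects: each _target_ entry names a class and the nested keys become its constructor arguments, so the tree is built bottom-up (VOSDataset → ConcatDataset → RepeatFactorWrapper → TorchTrainMixedDataset). Below is a minimal sketch of this instantiation; the config path and config name are placeholders for illustration, not the actual SAM2 file layout:

from hydra import compose, initialize
from hydra.utils import instantiate

# Hypothetical config location/name, used only to illustrate the mechanism.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="sam2_training")

# Hydra resolves the ${...} interpolations, then instantiates each _target_.
train_dataset = instantiate(cfg.data.train)      # -> TorchTrainMixedDataset
dataloader = train_dataset.get_loader(epoch=0)   # -> mixed dataloader over all datasets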
(2) TorchTrainMixedDataset
class TorchTrainMixedDataset:
    def __init__(
        self,
        datasets: List[Dataset],  # per the config above, a list of training.dataset.utils.RepeatFactorWrapper
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
        """
        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to its length.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)
        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob
    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        # Check whether the dataset has an `epoch` attribute or a `set_epoch` method
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)
    def get_loader(self, epoch) -> Iterable:
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Split each data epoch into multiple phases.
                # Major epoch that loops over the entire dataset
                # len(main_epoch) == phases_per_epoch * len(epoch)
                main_epoch = epoch // self.phases_per_epoch
                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch
                # Start of a new data-epoch, or the job is resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # Set seed for dataset epoch.
                    # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)
                    # Separate random generator for subset sampling
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )
                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                # In SAM2 phases_per_epoch is set to 1, so no splitting is needed
                self._set_dataset_epoch(dataset, epoch)
            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)
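To make the phases_per_epoch branch concrete, here is a small self-contained sketch of the chunking step with made-up sizes (10 samples, 2 phases per epoch): each phase trains on a disjoint half of a seeded permutation, and because the generator is seeded with main_epoch, the same split is reproduced if the job resumes mid-epoch.

import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10))  # toy dataset, sizes made up for illustration
phases_per_epoch = 2

for epoch in range(4):                       # "epoch" here counts phases, as in get_loader
    main_epoch = epoch // phases_per_epoch   # index of the full pass over the data
    local_phase = epoch % phases_per_epoch   # which chunk of that pass to use
    g = torch.Generator()
    g.manual_seed(main_epoch)                # same seed -> same permutation after resume
    chunks = torch.chunk(torch.randperm(len(dataset), generator=g), phases_per_epoch)
    subset = Subset(dataset, chunks[local_phase])
    print(f"epoch={epoch} main_epoch={main_epoch} phase={local_phase} "
          f"indices={chunks[local_phase].tolist()} len={len(subset)}")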
Note that get_loader() builds one DataLoader per dataset and wraps them all in a MixedDataLoader, a class for sampling data across multiple dataloaders. It lets you draw batches from the different dataloaders according to the specified probabilities, which is how training on a mixture of datasets is implemented. Since SAM2 only passes a single RepeatFactorWrapper dataset as the datasets input, the MixedDataLoader effectively degenerates into a plain DataLoader over that one RepeatFactorWrapper.
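For intuition, the probability-based mixing can be sketched as below. This is a simplified stand-in, not the real MixedDataLoader (which additionally handles exhausted iterators, epoch-length bookkeeping, and distributed details); the class name MixedLoaderSketch is hypothetical:

import torch

class MixedLoaderSketch:
    """Draw each batch from one of several dataloaders, chosen according to dataset_prob."""

    def __init__(self, dataloaders, dataset_prob):
        self.dataloaders = dataloaders
        self.dataset_prob = torch.as_tensor(dataset_prob, dtype=torch.float)

    def __len__(self):
        return sum(len(dl) for dl in self.dataloaders)

    def __iter__(self):
        iterators = [iter(dl) for dl in self.dataloaders]
        for _ in range(len(self)):
            # Pick a dataloader index with probability proportional to dataset_prob.
            d_idx = torch.multinomial(self.dataset_prob, num_samples=1).item()
            try:
                yield next(iterators[d_idx])
            except StopIteration:
                # The chosen dataloader ran out of batches; the real implementation
                # re-normalizes the probabilities over the remaining dataloaders.
                return

With a single dataloader, as in the SAM2 config, d_idx is always 0 and the sketch simply iterates over that one dataloader, matching the degenerate behaviour described above.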
(3) RepeatFactorWrapper
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a repeat_factors member to indicate the per-image factor.
    Set it to uniformly ones to disable repeat factor sampling
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Construct a list of indices in which we repeat images as specified
        indices = [