1 Overall Training Flow
The core code is as follows:
1.1 train.py
# 1. Load parameters from the config file
cfg = Config.fromfile(args.config)
# 2. Instantiate the runner
runner = Runner.from_cfg(cfg)
# 3. Start training
runner.train()
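For context, a runnable version of such an entry script usually also parses command-line arguments. Below is a minimal, self-contained sketch; the argparse handling and the optional --work-dir override are assumptions, not part of the excerpt above.
import argparse

from mmengine.config import Config
from mmengine.runner import Runner


def parse_args():
    parser = argparse.ArgumentParser(description='Train a model with MMEngine')
    parser.add_argument('config', help='path to the training config file')
    parser.add_argument('--work-dir', help='directory for logs and checkpoints')
    return parser.parse_args()


def main():
    args = parse_args()
    # 1. Load parameters from the config file
    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        # hypothetical CLI override of the config's work_dir
        cfg.work_dir = args.work_dir
    # 2. Instantiate the runner from the full config
    runner = Runner.from_cfg(cfg)
    # 3. Start training
    runner.train()


if __name__ == '__main__':
    main()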
1.2 mmengine/runner/runner.py -- class Runner()
# 1. Build the train_loop
self._train_loop = self.build_train_loop(
    self._train_loop)
# 2. Run the training loop
model = self.train_loop.run()
Here self.train_loop is an instance of EpochBasedTrainLoop; see Section 2 (the Runner class) for a detailed introduction.
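For reference, the train_cfg that produces an EpochBasedTrainLoop typically looks like the fragment below; the concrete numbers are assumptions, not values taken from this excerpt.
# In the config file: train by epoch for 12 epochs, validating every epoch.
# Runner stores this dict in self._train_loop and only builds the actual
# loop object once runner.train() is called.
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)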
1.3 mmengine/runner/loops.py -- class EpochBasedTrainLoop(BaseLoop)
# 1. self.run()
def run(self) -> torch.nn.Module:
    """Launch training."""
    self.runner.call_hook('before_train')

    while self._epoch < self._max_epochs and not self.stop_training:
        self.run_epoch()

        self._decide_current_val_interval()
        if (self.runner.val_loop is not None
                and self._epoch >= self.val_begin
                and (self._epoch % self.val_interval == 0
                     or self._epoch == self._max_epochs)):
            self.runner.val_loop.run()

    self.runner.call_hook('after_train')
    return self.runner.model
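# Side note (an illustration, not part of loops.py): 'before_train' and
# 'after_train' above are hook points. A custom hook can attach to them;
# the class name and the printed messages below are assumptions.
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class TimerHook(Hook):
    """Hypothetical hook reacting to the hook points called in run()."""

    def before_train(self, runner) -> None:
        print(f'start training for {runner.max_epochs} epochs')

    def after_train_epoch(self, runner) -> None:
        print(f'finished epoch {runner.epoch}')

    def after_train(self, runner) -> None:
        print('training finished')
# It would be enabled from the config, e.g. custom_hooks = [dict(type='TimerHook')].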
# 2. self.run_epoch()
def run_epoch(self) -> None:
    """Iterate one epoch."""
    self.runner.call_hook('before_train_epoch')
    self.runner.model.train()
    for idx, data_batch in enumerate(self.dataloader):
        self.run_iter(idx, data_batch)

    self.runner.call_hook('after_train_epoch')
    self._epoch += 1
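# Side note (an illustration, not part of loops.py): the self.dataloader
# iterated in run_epoch() is built by the Runner from the train_dataloader
# field of the config. A minimal fragment might look like the following;
# the dataset type and the numbers are assumptions.
train_dataloader = dict(
    batch_size=32,
    num_workers=4,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(type='MyDataset', data_root='data/'))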
# 3. self.run_iter()
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
    self.runner.call_hook(
        'before_train_iter', batch_idx=idx, data_batch=data_batch)
    # Enable gradient accumulation mode and avoid unnecessary gradient
    # synchronization during gradient accumulation process.
    # outputs should be a dict of loss.
    outputs = self.runner.model.train_step(
        data_batch, optim_wrapper=self.runner.optim_wrapper)

    self.runner.call_hook(
        'after_train_iter',
        batch_idx=idx,
        data_batch=data_batch,
        outputs=outputs)
    self._iter += 1
The key call here is self.runner.model.train_step(), which will be analyzed in Section 3 (Runner.model).
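As a preview of that analysis, a condensed sketch of BaseModel.train_step is shown below (simplified; treat the details as approximate until Section 3).
def train_step(self, data, optim_wrapper):
    # Preprocess the batch and run the forward pass in 'loss' mode inside
    # the optimizer wrapper's context (which covers AMP and gradient
    # accumulation).
    with optim_wrapper.optim_context(self):
        data = self.data_preprocessor(data, True)
        losses = self._run_forward(data, mode='loss')
    # Sum the loss dict into a single scalar, then backward + optimizer step.
    parsed_losses, log_vars = self.parse_losses(losses)
    optim_wrapper.update_params(parsed_losses)
    return log_vars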
2 The Runner Class (runner.py)
Some core attributes and methods of the Runner class:
1. self._work_dir -- the directory where results are saved
self._work_dir = osp.abspath(work_dir)
2. self.cfg
if cfg is not None:
    if isinstance(cfg, Config):
        self.cfg = copy.deepcopy(cfg)
    elif isinstance(cfg, dict):
        self.cfg = Config(cfg)
else:
    self.cfg = Config(dict())
3. self._train_loop
# 1. self.__init__():
self._train_loop = train_cfg
# 2. self.train():
self._train_loop = self.build_train_loop(self._train_loop)
# 3. self.build_train_loop():
def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop:
    if isinstance(loop, BaseLoop):
        return loop
    elif not isinstance(loop, dict):
        raise TypeError(
            f'train_loop should be a Loop object or dict, but got {loop}')

    loop_cfg = copy.deepcopy(loop)

    if 'type' in loop_cfg and 'by_epoch' in loop_cfg:
        raise RuntimeError(
            'Only one of `type` or `by_epoch` can exist in `loop_cfg`.')

    if 'type' in loop_cfg:
        loop = LOOPS.build(
            loop_cfg,
            default_args=dict(
                runner=self, dataloader=self._train_dataloader))
    else:
        by_epoch = loop_cfg.pop('by_epoch')
        if by_epoch:
            loop = EpochBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
        else:
            loop = IterBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
    return loop  # type: ignore
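In addition to the type-based train_cfg shown in Section 1.2, build_train_loop also accepts a by_epoch flag (mutually exclusive with type, as the RuntimeError above enforces). A sketch with assumed numbers:
# by_epoch=True builds EpochBasedTrainLoop directly;
# by_epoch=False builds IterBasedTrainLoop instead.
train_cfg = dict(by_epoch=False, max_iters=10000, val_interval=1000)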