A Close Reading of the FCOS Code in mmdetection

1 Overall training workflow

The main core code is as follows:

1.1 train.py

# 1. Load the parameters from the config file
cfg = Config.fromfile(args.config)

# 2. Instantiate the runner
runner = Runner.from_cfg(cfg)

# 3. Start training
runner.train()
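
For reference, these three steps can also be driven directly from a Python session. A minimal sketch, assuming the FCOS config path shipped with mmdetection 3.x (the exact file name may differ between versions):

from mmengine.config import Config
from mmengine.runner import Runner

# Assumed config path; adjust to the FCOS config in your mmdetection checkout.
cfg = Config.fromfile('configs/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco.py')
cfg.work_dir = './work_dirs/fcos_demo'  # Runner needs a work_dir for logs and checkpoints

runner = Runner.from_cfg(cfg)  # same as step 2 in train.py
runner.train()                 # same as step 3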

1.2 mmengine/runner/runner.py -- class Runner()

# 1. Build the train_loop
self._train_loop = self.build_train_loop(
            self._train_loop)

# 2. Run the training loop
model = self.train_loop.run()

Here self.train_loop is an instance of EpochBasedTrainLoop; see Part 2 on the Runner class for a detailed introduction.
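
Which loop gets built is decided by the train_cfg entry of the config. A typical epoch-based setting looks like the following (the values mirror the common 1x schedule and are only an illustration):

# train_cfg in the config file; Runner passes it to build_train_loop().
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)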

1.3 mmengine/runner/loops.py -- class EpochBasedTrainLoop(BaseLoop)

# 1. self.run()
  def run(self) -> torch.nn.Module:
      """Launch training."""
      self.runner.call_hook('before_train')

      while self._epoch < self._max_epochs and not self.stop_training:
          self.run_epoch()

          self._decide_current_val_interval()
          if (self.runner.val_loop is not None
                  and self._epoch >= self.val_begin
                  and (self._epoch % self.val_interval == 0
                       or self._epoch == self._max_epochs)):
              self.runner.val_loop.run()

      self.runner.call_hook('after_train')
      return self.runner.model

# 2. self.run_epoch()
  def run_epoch(self) -> None:
      """Iterate one epoch."""
      self.runner.call_hook('before_train_epoch')
      self.runner.model.train()
      for idx, data_batch in enumerate(self.dataloader):
          self.run_iter(idx, data_batch)

      self.runner.call_hook('after_train_epoch')
      self._epoch += 1

# 3. self.run_iter()
  def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
      self.runner.call_hook(
          'before_train_iter', batch_idx=idx, data_batch=data_batch)
      # Enable gradient accumulation mode and avoid unnecessary gradient
      # synchronization during gradient accumulation process.
      # outputs should be a dict of loss.
      outputs = self.runner.model.train_step(
          data_batch, optim_wrapper=self.runner.optim_wrapper)

      self.runner.call_hook(
          'after_train_iter',
          batch_idx=idx,
          data_batch=data_batch,
          outputs=outputs)
      self._iter += 1

The core call here is self.runner.model.train_step(), which will be analyzed in Section 3 (Runner.model).
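
For orientation before Section 3, mmengine's BaseModel.train_step roughly performs the steps below. This is a simplified sketch, not a verbatim copy of the implementation:

def train_step(self, data, optim_wrapper):
    """Simplified sketch of BaseModel.train_step in mmengine."""
    # optim_context() handles AMP and skips gradient sync during accumulation.
    with optim_wrapper.optim_context(self):
        data = self.data_preprocessor(data, True)        # move to device, normalize, batch
        losses = self._run_forward(data, mode='loss')    # forward pass, returns a dict of losses
    parsed_losses, log_vars = self.parse_losses(losses)  # sum the loss dict, collect log values
    optim_wrapper.update_params(parsed_losses)           # backward + optimizer step
    return log_vars                                      # the `outputs` seen in run_iter()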

2 The Runner class (runner.py)

Some core attributes and methods of the Runner class

1. self._work_dir: the directory where results are saved

self._work_dir = osp.abspath(work_dir)

2. self.cfg

if cfg is not None:
    if isinstance(cfg, Config):
        self.cfg = copy.deepcopy(cfg)
    elif isinstance(cfg, dict):
        self.cfg = Config(cfg)
else:
    self.cfg = Config(dict())
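
In short, Runner accepts either a ready-made Config object or a plain dict; a small illustration (the values are hypothetical):

from mmengine.config import Config

cfg_obj = Config.fromfile('configs/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco.py')  # deep-copied as-is
cfg_dict = dict(work_dir='./work_dirs/demo')  # wrapped as Config(cfg_dict)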

3. self._train_loop

# 1. In self.__init__():
self._train_loop = train_cfg

# 2. In self.train():
self._train_loop = self.build_train_loop(self._train_loop)

# 3. self.build_train_loop():
def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop:
    if isinstance(loop, BaseLoop):
        return loop
    elif not isinstance(loop, dict):
        raise TypeError(
            f'train_loop should be a Loop object or dict, but got {loop}')

    loop_cfg = copy.deepcopy(loop)

    if 'type' in loop_cfg and 'by_epoch' in loop_cfg:
        raise RuntimeError(
            'Only one of `type` or `by_epoch` can exist in `loop_cfg`.')

    if 'type' in loop_cfg:
        loop = LOOPS.build(
            loop_cfg,
            default_args=dict(
                runner=self, dataloader=self._train_dataloader))
    else:
        by_epoch = loop_cfg.pop('by_epoch')
        if by_epoch:
            loop = EpochBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
        else:
            loop = IterBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
    return loop  # type: ignore
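
To make the two branches concrete, here are hypothetical train_cfg dicts that would take the registry path and the by_epoch path respectively:

# Registry branch: 'type' is present, so LOOPS.build() constructs the loop.
train_cfg_registry = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)

# Fallback branch: no 'type', so `by_epoch` selects EpochBasedTrainLoop or IterBasedTrainLoop.
train_cfg_by_epoch = dict(by_epoch=True, max_epochs=12)
train_cfg_by_iter = dict(by_epoch=False, max_iters=90000)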
