EpochBasedRunner in mmdetection
Reading notes on EpochBasedRunner
This class has many methods; here I only record the ones I currently use. (Note that EpochBasedRunner inherits from BaseRunner, so several of the methods below are actually defined in BaseRunner.)
I. First, the initialization method used to instantiate this class:
"""
Args:
model (:obj:`torch.nn.Module`): The model to be run.
batch_processor (callable): A callable method that processes a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
optimizer (in most cases) or a dict of optimizers (in models that
requires more than one optimizer, e.g., GAN).
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None. (The default value is just for backward
compatibility)
meta (dict | None): A dict that records some important information such as
environment info and seed, which will be logged in logger hook.
Defaults to None.
max_epochs (int, optional): Total training epochs.
max_iters (int, optional): Total training iterations.
"""
def __init__(self,
model,
batch_processor=None,
optimizer=None,
work_dir=None,
logger=None,
meta=None,
max_iters=None,
max_epochs=None):
if batch_processor is not None:
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
f'but got {type(batch_processor)}')
warnings.warn('batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.')
# raise an error if `batch_processor` is not None and
# `model.train_step()` exists.
if is_module_wrapper(model):
_model = model.module
else:
_model = model
if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
raise RuntimeError(
'batch_processor and model.train_step()/model.val_step() '
'cannot be both available.')
else:
assert hasattr(model, 'train_step')
# check the type of `optimizer`
if isinstance(optimizer, dict):
for name, optim in optimizer.items():
if not isinstance(optim, Optimizer):
raise TypeError(
f'optimizer must be a dict of torch.optim.Optimizers, '
f'but optimizer["{name}"] is a {type(optim)}')
elif not isinstance(optimizer, Optimizer) and optimizer is not None:
raise TypeError(
f'optimizer must be a torch.optim.Optimizer object '
f'or dict or None, but got {type(optimizer)}')
# check the type of `logger`
if not isinstance(logger, logging.Logger):
raise TypeError(f'logger must be a logging.Logger object, '
f'but got {type(logger)}')
# check the type of `meta`
if meta is not None and not isinstance(meta, dict):
raise TypeError(
f'meta must be a dict or None, but got {type(meta)}')
self.model = model
self.batch_processor = batch_processor
self.optimizer = optimizer
self.logger = logger
self.meta = meta
# create work_dir
if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir)
mmcv.mkdir_or_exist(self.work_dir)
elif work_dir is None:
self.work_dir = None
else:
raise TypeError('"work_dir" must be a str or None')
# get model name from the model class
if hasattr(self.model, 'module'):
self._model_name = self.model.module.__class__.__name__
else:
self._model_name = self.model.__class__.__name__
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode = None
self._hooks = []
self._epoch = 0
self._iter = 0
self._inner_iter = 0
if max_epochs is not None and max_iters is not None:
raise ValueError(
'Only one of `max_epochs` or `max_iters` can be set.')
self._max_epochs = max_epochs
self._max_iters = max_iters
# TODO: Redesign LogBuffer, it is not flexible and elegant enough
self.log_buffer = LogBuffer()
The parameters passed when instantiating BaseRunner are documented in the docstring above. A few points deserve attention:
- The model must implement a train_step() method;
- batch_processor defaults to None, and it has always been None in my own usage, so this parameter is not considered here for now;
- optimizer may also be a dict, meaning BaseRunner allows multiple optimizers to optimize parameters in different parts of the network;
- max_epochs and max_iters cannot both be set; supplying one of them is enough.
Now let's look at what the initialization method does:
- Check whether the passed-in arguments are valid;
- Save the relevant arguments as attributes of the BaseRunner instance;
- Create the working directory and gather runtime information (rank, world size, timestamp);
- Create attributes that track training state (mode, hooks, epoch and iteration counters). A minimal instantiation is sketched below.
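As a quick illustration, a minimal instantiation might look like the following sketch. The ToyModel, the work_dir path, and the logger setup are placeholders of my own, not mmdetection code:

import logging
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner

class ToyModel(nn.Module):
    """Hypothetical model exposing the train_step() that the runner requires."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        loss = self.fc(data_batch['x']).mean()
        # train_step must return a dict (see run_iter in section III)
        return dict(loss=loss,
                    log_vars=dict(loss=loss.item()),
                    num_samples=len(data_batch['x']))

model = ToyModel()
runner = EpochBasedRunner(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
    work_dir='./work_dir',       # created via mmcv.mkdir_or_exist
    logger=logging.getLogger(),  # must be a logging.Logger
    max_epochs=12)               # set max_epochs OR max_iters, never both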
II. Registering the training hooks
Generally, after instantiating a BaseRunner we call its register_training_hooks() method, which in turn invokes the seven hook-registration methods of BaseRunner. For details on mmdetection's Hook mechanism you can look it up yourself; personally, I find its working mechanism similar to interrupts on a microcontroller.
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
"""Register default and custom hooks for training.
Default and custom hooks include:
+----------------------+-------------------------+
| Hooks | Priority |
+======================+=========================+
| LrUpdaterHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
The arguments passed here are all dicts defined in the config files. For example, schedule_1x.py defines:
lr_config = dict(policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 21])
The purpose of each key-value pair will be explained when we actually use it; a sketch of how these dicts are wired in follows.
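As a hedged illustration, this is roughly how such config dicts are handed to the runner; the dict contents below are illustrative, loosely following mmdetection's schedule_1x.py / default_runtime.py defaults:

runner.register_training_hooks(
    lr_config=dict(policy='step', warmup='linear', warmup_iters=500,
                   warmup_ratio=0.001, step=[16, 21]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))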
Below, the registration process for each type of hook is introduced one by one.
1. register_lr_hook
def register_lr_hook(self, lr_config):
if lr_config is None:
return
elif isinstance(lr_config, dict):
assert 'policy' in lr_config
policy_type = lr_config.pop('policy')
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'LrUpdaterHook'
lr_config['type'] = hook_type
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook, priority='VERY_HIGH')
This method first inspects the argument: if nothing is passed, no Hook for managing the lr is registered; if it is a dict, a Hook is registered based on the information in the dict (in fact, a corresponding object is instantiated from the dict); for the mmcv.build_from_cfg method, please look it up yourself. If an argument is passed but is not a dict, it is assumed to already be a constructed Hook.
With the argument I passed during my experiments, lr_config = dict(policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[16, 21]), this amounts to instantiating an object of the class StepLrUpdaterHook with the dict items as arguments; the StepLrUpdaterHook class lives in mmcv/runner/hooks/lr_updater.py. The name derivation is traced below.
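To make the derivation concrete, this sketch replays what register_lr_hook does with that dict (variable names mirror the source):

lr_config = dict(policy='step', warmup='linear', warmup_iters=500,
                 warmup_ratio=0.001, step=[16, 21])
policy_type = lr_config.pop('policy')              # 'step'
if policy_type == policy_type.lower():             # all lowercase, so capitalize
    policy_type = policy_type.title()              # 'Step'
lr_config['type'] = policy_type + 'LrUpdaterHook'  # 'StepLrUpdaterHook'
# mmcv.build_from_cfg(lr_config, HOOKS) then looks up 'StepLrUpdaterHook' in the
# HOOKS registry and calls it with the remaining keys as keyword arguments:
# StepLrUpdaterHook(warmup='linear', warmup_iters=500,
#                   warmup_ratio=0.001, step=[16, 21])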
Finally, self.register_hook is called to store the newly built Hook in the class's _hooks attribute for later use.
def register_hook(self, hook, priority='NORMAL'):
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
priority (See :class:`Priority` for details of priorities).
For hooks with the same priority, they will be triggered in the same
order as they are registered.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int or str or :obj:`Priority`): Hook priority.
Lower value means higher priority.
"""
assert isinstance(hook, Hook)
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
A brief word on register_hook: the hook parameter is the Hook to be stored, and priority is this Hook's priority; doesn't this look more and more like microcontroller interrupts? The priority simply determines where the Hook is placed inside self._hooks; Hooks closer to the front of the list are executed earlier during training (a lower priority value means a higher priority).
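As a quick illustration (my own example, with placeholder hook instances and the numeric priorities from the table above), registering in an arbitrary order still yields a list sorted ascending by priority value, with ties keeping registration order:

runner.register_hook(checkpoint_hook, priority='NORMAL')  # 50
runner.register_hook(lr_hook, priority='VERY_HIGH')       # 10
runner.register_hook(logger_hook, priority='VERY_LOW')    # 90
# runner._hooks is now [lr_hook (10), checkpoint_hook (50), logger_hook (90)]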
2. register_momentum_hook
The same as register_lr_hook.
def register_momentum_hook(self, momentum_config):
if momentum_config is None:
return
if isinstance(momentum_config, dict):
assert 'policy' in momentum_config
policy_type = momentum_config.pop('policy')
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'MomentumUpdaterHook'
momentum_config['type'] = hook_type
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook, priority='HIGH')
3. register_optimizer_hook
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook, priority='ABOVE_NORMAL')
4. register_checkpoint_hook
def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook, priority='NORMAL')
5. register_timer_hook
def register_timer_hook(self, timer_config):
if timer_config is None:
return
if isinstance(timer_config, dict):
timer_config_ = copy.deepcopy(timer_config)
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook, priority='LOW')
6. register_logger_hooks
def register_logger_hooks(self, log_config):
if log_config is None:
return
log_interval = log_config['interval']
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
7. register_custom_hooks
def register_custom_hooks(self, custom_config):
if custom_config is None:
return
if not isinstance(custom_config, list):
custom_config = [custom_config]
for item in custom_config:
if isinstance(item, dict):
self.register_hook_from_cfg(item)
else:
self.register_hook(item, priority='NORMAL')
In other words, every Hook is a class with its own functionality; what these classes share is that they all provide the methods before_run, after_run, before_epoch, after_epoch, before_iter, and after_iter.
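As a hedged sketch of what such a class looks like (a toy example of my own, not from mmdetection), a custom Hook only needs to subclass mmcv's Hook and override the callbacks it cares about:

from mmcv.runner import HOOKS, Hook

@HOOKS.register_module()
class PrintLossHook(Hook):
    """Toy hook that logs the latest loss every `interval` iterations."""

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        # runner.outputs is populated by run_iter() (see section III)
        if self.every_n_iters(runner, self.interval):
            runner.logger.info(
                f"iter {runner.iter}: {runner.outputs['log_vars']}")

It could then be registered by passing custom_hooks_config=dict(type='PrintLossHook', interval=100) to register_training_hooks.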
III. Training the model
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_epochs is not None:
warnings.warn('setting max_epochs in run is deprecated, please set max_epochs in runner_config', DeprecationWarning)
self._max_epochs = max_epochs
assert self._max_epochs is not None, ('max_epochs must be specified during instantiation')
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s', get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s', self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow, self._max_epochs)
self.call_hook('before_run')
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
- The parameters are as described in the docstring;
- Upon entering the method, the formats of the arguments are checked first;
- Then self._max_iters is set based on the value of self._max_epochs;
- Some information is logged;
- call_hook is invoked with 'before_run'; this iterates over the Hooks we registered earlier and executes every Hook's before_run method;
- The training loop starts; its job is to decide, from the workflow, which method should be executed at the moment, self.train or self.val. epoch_runner = getattr(self, mode), where mode is 'train' or 'val' (a usage sketch follows this list);
- Finally, self.call_hook('after_run') is executed.
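Putting it together, a minimal launch might look like this sketch; train_loader and val_loader are placeholder DataLoaders of my own, and the workflow mirrors the docstring example (2 training epochs, then 1 validation epoch, repeated):

runner.run(
    data_loaders=[train_loader, val_loader],
    workflow=[('train', 2), ('val', 1)])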
Let me expand on points 5 and 7 above: what calling self.call_hook('before_run') and self.call_hook('after_run') actually does depends on each of the Hooks we just registered and on how each Hook implements the corresponding method.
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
call_hook iterates over all the Hooks we registered and calls a given method on each of them. For example, self.call_hook('before_run') calls the before_run method of every Hook.
Finally, there are three more methods: train and val, which run selects in point 6 above, and run_iter, which both of them call.
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()" and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
These three methods form a fairly standard training loop, except that most of the work is encapsulated in Hooks; to understand what each step of the script is doing, you still have to consult the individual methods of each Hook. The callback order, reconstructed from the code above, is sketched below.
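For reference, here is the hook callback sequence that one training run produces, reconstructed from run()/train()/run_iter() above (assuming no batch_processor and default hook configurations):

# before_run
#   before_train_epoch            (once per epoch)
#     before_train_iter           (once per batch)
#       model.train_step(...)     -> outputs dict; log_vars go to log_buffer
#     after_train_iter            (OptimizerHook does loss.backward() and
#                                  optimizer.step() here by default)
#   after_train_epoch             (CheckpointHook saves a checkpoint here
#                                  when by_epoch=True, its default)
# after_run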