Shortcuts

Source code for mmpose.apis.train

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                         get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
from mmpose.datasets import build_dataloader, build_dataset
from mmpose.utils import get_root_logger

# Prefer the Fp16OptimizerHook exported by mmcv.runner. If the installed mmcv
# does not provide it (ImportError), emit a DeprecationWarning and fall back
# to the copy shipped in mmpose.core so that fp16 training keeps working.
try:
    from mmcv.runner import Fp16OptimizerHook
except ImportError:
    warnings.warn(
        'Fp16OptimizerHook from mmpose will be deprecated from '
        'v0.15.0. Please install mmcv>=1.1.4', DeprecationWarning)
    from mmpose.core import Fp16OptimizerHook

def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized,
    and then broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, Optional): The seed. Default to None. Any value other
            than None (including 0) is returned unchanged.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    # An explicitly provided seed is used as-is; `is not None` keeps 0 valid.
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        # Single process: nothing to synchronise.
        return seed

    # Rank 0 holds the drawn seed; every other rank allocates a placeholder
    # tensor that is overwritten by the broadcast from rank 0.
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Builds the training data loaders, wraps the model for (distributed)
    data-parallel execution, creates an ``EpochBasedRunner``, registers the
    training / evaluation hooks and finally runs the workflow. Nothing is
    returned.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset | list[Dataset] | tuple[Dataset]): Train dataset(s).
            A single dataset is wrapped into a one-element list.
        cfg (dict): The config dict for training. It is accessed both via
            attributes (e.g. ``cfg.log_level``) and via ``cfg.get(...)``.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # step 1: give default values and override (if exist) from cfg.data
    loader_cfg = dict(
        seed=cfg.get('seed'),
        drop_last=False,
        dist=distributed,
        num_gpus=len(cfg.gpu_ids))
    if torch.__version__ == 'parrots':
        # extra defaults that only apply when running on parrots
        loader_cfg.update(prefetch_num=2, pin_memory=False)
    loader_cfg.update({
        key: cfg.data[key]
        for key in (
            'samples_per_gpu',
            'workers_per_gpu',
            'shuffle',
            'seed',
            'drop_last',
            'prefetch_num',
            'pin_memory',
            'persistent_workers',
        ) if key in cfg.data
    })

    # step 2: cfg.data.train_dataloader has highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # determine whether use adversarial training process or not
    use_adversarial_train = cfg.get('use_adversarial_train', False)

    # put model on gpus
    if distributed:
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        find_unused_parameters = cfg.get('find_unused_parameters', True)

        if use_adversarial_train:
            # Use DistributedDataParallelWrapper for adversarial training
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        if digit_version(mmcv.__version__) >= digit_version(
                '1.4.4') or torch.cuda.is_available():
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        else:
            # Old MMCV without CUDA: the model is left unwrapped and the
            # user is only warned.
            warnings.warn(
                'We recommend to use MMCV >= 1.4.4 for CPU training. '
                'See https://github.com/open-mmlab/mmpose/pull/1157 for '
                'details.')

    # build runner
    optimizer = build_optimizers(model, cfg.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    if use_adversarial_train:
        # The optimizer step process is included in the train_step function
        # of the model, so the runner should NOT include optimizer hook.
        optimizer_config = None
    else:
        # fp16 setting
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
        elif distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config

    # `custom_hooks` is the current key; `custom_hooks_config` is still read
    # as a deprecated fallback.
    custom_hooks_cfg = cfg.get('custom_hooks', None)
    if custom_hooks_cfg is None:
        custom_hooks_cfg = cfg.get('custom_hooks_config', None)
        if custom_hooks_cfg is not None:
            warnings.warn(
                '"custom_hooks_config" is deprecated, please use '
                '"custom_hooks" instead.', DeprecationWarning)

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=custom_hooks_cfg)
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        dataloader_setting = dict(
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            drop_last=False,
            shuffle=False)
        # cfg.data.val_dataloader overrides the defaults above
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # resuming takes precedence over merely loading weights
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
Read the Docs v: latest
Versions
latest
1.x
v0.14.0
fix-doc
cn_doc
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.