# Source code for mmpose.models.detectors.pose_lifter

import warnings

from .. import builder
from ..registry import POSENETS
from .base import BasePose

try:
    from mmcv.runner import auto_fp16
except ImportError:
    # Fall back to mmpose's own copy of the decorator for old mmcv versions.
    # The two literals are concatenated, so an explicit separator is needed
    # to keep the sentences from running together.
    warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0. '
                  'Please install mmcv>=1.1.4')
    from mmpose.core import auto_fp16


@POSENETS.register_module()
class PoseLifter(BasePose):
    """Pose lifter that lifts 2D pose to 3D pose.

    The model is a backbone, followed by an optional neck and an optional
    keypoint head, each built from its config through ``builder``.

    Args:
        backbone: Config passed to ``builder.build_backbone``.
        neck: Config passed to ``builder.build_neck``. Defaults to None
            (no neck).
        keypoint_head (dict): Config passed to ``builder.build_head`` after
            ``train_cfg`` / ``test_cfg`` are injected into a copy of it.
            Defaults to None (no keypoint head).
        train_cfg: Training config, stored and forwarded to the head.
            Defaults to None.
        test_cfg: Testing config, stored and forwarded to the head.
            Defaults to None.
        pretrained: Forwarded unchanged to ``self.backbone.init_weights``.
            Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 keypoint_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()
        self.fp16_enabled = False

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if keypoint_head is not None:
            # Inject the cfgs into a shallow copy rather than writing into the
            # caller's dict, so a config reused to build several models is
            # left untouched.
            keypoint_head = dict(
                keypoint_head, train_cfg=train_cfg, test_cfg=test_cfg)
            self.keypoint_head = builder.build_head(keypoint_head)

        self.init_weights(pretrained=pretrained)

    @property
    def with_neck(self):
        """bool: Whether the model has a neck."""
        return hasattr(self, 'neck')

    @property
    def with_keypoint(self):
        """bool: Whether the model has a keypoint head."""
        return hasattr(self, 'keypoint_head')

    def init_weights(self, pretrained=None):
        """Weight initialization for model.

        Args:
            pretrained: Forwarded to ``self.backbone.init_weights``; the neck
                and keypoint head are initialized without it.
        """
        self.backbone.init_weights(pretrained)
        if self.with_neck:
            self.neck.init_weights()
        if self.with_keypoint:
            self.keypoint_head.init_weights()

    def _extract_features(self, input):
        """Run the backbone and, if present, the neck on ``input``."""
        features = self.backbone(input)
        if self.with_neck:
            features = self.neck(features)
        return features

    @auto_fp16(apply_to=('input', ))
    def forward(self,
                input,
                target=None,
                target_weight=None,
                metas=None,
                return_loss=True,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note:
            batch_size: N
            num_input_keypoints: Ki
            input_keypoint_dim: Ci
            input_sequence_len: Ti
            num_output_keypoints: Ko
            output_keypoint_dim: Co
            output_sequence_len: To

        Args:
            input (torch.Tensor[NxKixCixTi]): Input keypoint coordinates.
            target (torch.Tensor[NxKoxCoxTo]): Output keypoint coordinates.
                Defaults to None.
            target_weight (torch.Tensor[NxKox1]): Weights across different
                joint types. Defaults to None.
            metas (list(dict)): Information about data augmentation. Must
                have one entry per sample in ``input``.
            return_loss (bool): Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.

        Returns:
            dict: If `return_loss` is true, the losses (and accuracies) from
                ``forward_train``; otherwise the decoded results from
                ``forward_test``.
        """
        if return_loss:
            return self.forward_train(input, target, target_weight, metas,
                                      **kwargs)
        return self.forward_test(input, metas, **kwargs)

    def forward_train(self, input, target, target_weight, metas, **kwargs):
        """Defines the computation performed at every call when training.

        Returns:
            dict: Merged output of ``keypoint_head.get_loss`` and
                ``keypoint_head.get_accuracy``; empty if the model has no
                keypoint head.
        """
        assert input.size(0) == len(metas)

        features = self._extract_features(input)

        losses = {}
        if self.with_keypoint:
            output = self.keypoint_head(features)
            keypoint_losses = self.keypoint_head.get_loss(
                output, target, target_weight)
            keypoint_accuracy = self.keypoint_head.get_accuracy(
                output, target, target_weight, metas)
            losses.update(keypoint_losses)
            losses.update(keypoint_accuracy)
        return losses

    def forward_test(self, input, metas, **kwargs):
        """Defines the computation performed at every call when testing.

        Returns:
            dict: Output of ``keypoint_head.decode``; empty if the model has
                no keypoint head.
        """
        assert input.size(0) == len(metas)

        results = {}
        features = self._extract_features(input)
        if self.with_keypoint:
            output = self.keypoint_head.inference_model(features)
            results.update(self.keypoint_head.decode(metas, output))
        return results

    def forward_dummy(self, input):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            input (torch.Tensor): Input pose

        Returns:
            Tensor: Keypoint head output, or the backbone/neck features when
                the model has no keypoint head (previously this case raised
                UnboundLocalError).
        """
        output = self._extract_features(input)
        if self.with_keypoint:
            output = self.keypoint_head(output)
        return output

    def show_result(self, **kwargs):
        """Visualization hook required by ``BasePose``; currently a no-op."""
        pass