# Source code for mmpose.models.keypoint_heads.heatmap_3d_head

import numpy as np
import torch.nn as nn

from mmpose.core.post_processing import flip_back, transform_preds
from ..registry import HEADS
from .top_down_simple_head import TopDownSimpleHead


def _get_max_preds_3d(heatmaps):
    """Get keypoint predictions from 3D score maps.

    Note:
        batch size: N
        num keypoints: K
        heatmap depth size: D
        heatmap height: H
        heatmap width: W

    Args:
        heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.

    Returns:
        tuple: A tuple containing aggregated results.
        - preds (np.ndarray[N, K, 3]): Predicted keypoint location.
        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
    """
    assert isinstance(heatmaps, np.ndarray), \
        ('heatmaps should be numpy.ndarray')
    assert heatmaps.ndim == 5, 'heatmaps should be 5-ndim'

    N, K, D, H, W = heatmaps.shape
    heatmaps_reshaped = heatmaps.reshape((N, K, -1))
    idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
    maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))

    preds = np.zeros((N, K, 3), dtype=np.float32)
    _idx = idx[..., 0]
    preds[:, :, 2], _idx = _idx // (H * W), _idx % (H * W)
    preds[:, :, 1], _idx = _idx // W, _idx % W
    preds[:, :, 0] = _idx

    preds = np.where(np.tile(maxvals, (1, 1, 3)) > 0.0, preds, -1)
    return preds, maxvals


@HEADS.register_module()
class HeatMap3DHead(TopDownSimpleHead):
    """3D heatmap head of paper ref: Gyeongsik Moon. "InterHand2.6M: A
    Dataset and Baseline for 3D Interacting Hand Pose Estimation from a
    Single RGB Image".

    HeatMap3DHead is a variant of TopDownSimpleHead, and is composed of
    (>=0) number of deconv layers and a simple conv2d layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        depth_size (int): Number of depth discretization size. Default: 64.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means no deconv
            layers.
        num_deconv_filters (list|tuple): Number of filters.
            If num_deconv_layers > 0, the length of num_deconv_filters
            should equal num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes.
        in_index (int|Sequence[int]): Input feature index. Default: 0.
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resize to the
                same size as first one and than concat together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundle into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_keypoint (dict): Config for keypoint loss. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 depth_size=64,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 loss_keypoint=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__(in_channels, out_channels, num_deconv_layers,
                         num_deconv_filters, num_deconv_kernels, extra,
                         in_index, input_transform, align_corners,
                         loss_keypoint, train_cfg, test_cfg)

        # The 2D output channels are folded into a depth axis in forward(),
        # so they must split evenly into `depth_size` slices per keypoint.
        assert out_channels % depth_size == 0
        self.depth_size = depth_size

    def get_loss(self, output, target, target_weight):
        """Calculate 3D heatmap loss.

        Note:
            batch size: N
            num keypoints: K
            heatmaps depth size: D
            heatmaps height: H
            heatmaps weight: W

        Args:
            output (torch.Tensor[NxKxDxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxDxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]):
                Weights across different joint types.

        Returns:
            dict: A dict with the single key 'heatmap_loss'.
        """
        losses = dict()
        assert not isinstance(self.loss, nn.Sequential)
        assert target.dim() == 5 and target_weight.dim() == 3
        losses['heatmap_loss'] = self.loss(output, target, target_weight)
        return losses

    def get_accuracy(self, output, target, target_weight):
        # The accuracy of interhand keypoints is not well defined,
        # so no accuracy metric is reported.
        return {}

    def forward(self, x):
        """Forward function.

        Runs the parent 2D head, then folds the channel axis into
        (keypoint, depth) to produce 3D heatmaps.
        """
        x = super().forward(x)
        N, C, H, W = x.shape
        # reshape the 2D heatmap (N, C, H, W) to a 3D heatmap
        # (N, C // depth_size, depth_size, H, W)
        x = x.reshape(N, C // self.depth_size, self.depth_size, H, W)
        return x

    def decode(self, img_metas, output, **kwargs):
        """Decode keypoints from heatmaps.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:
                - "image_file: path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            output (np.ndarray[N, K, D, H, W]): model predicted 3D heatmaps.

        Returns:
            dict: Decoded results with keys 'preds', 'boxes',
            'image_paths' and 'bbox_ids'.
        """
        batch_size = len(img_metas)
        N, K, D, H, W = output.shape

        # Collect per-sample bbox ids only when the metas carry them.
        if 'bbox_id' in img_metas[0]:
            bbox_ids = []
        else:
            bbox_ids = None

        center = np.zeros((batch_size, 2), dtype=np.float32)
        scale = np.zeros((batch_size, 2), dtype=np.float32)
        image_paths = []
        score = np.ones(batch_size, dtype=np.float32)
        for i in range(batch_size):
            center[i, :] = img_metas[i]['center']
            scale[i, :] = img_metas[i]['scale']
            image_paths.append(img_metas[i]['image_file'])

            if 'bbox_score' in img_metas[i]:
                score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1)
            if bbox_ids is not None:
                bbox_ids.append(img_metas[i]['bbox_id'])

        preds, maxvals = _get_max_preds_3d(output)

        # Transform the (x, y) peak coordinates back to the image space.
        for i in range(N):
            preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i],
                                              scale[i], [W, H])

        all_preds = np.zeros((batch_size, preds.shape[1], 3),
                             dtype=np.float32)
        all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
        # NOTE(review): only (x, y) and the confidence are exported; the
        # decoded depth coordinate preds[:, :, 2] is discarded here —
        # confirm downstream consumers do not expect the z value.
        all_preds[:, :, 0:2] = preds[:, :, 0:2]
        all_preds[:, :, 2:3] = maxvals
        all_boxes[:, 0:2] = center[:, 0:2]
        all_boxes[:, 2:4] = scale[:, 0:2]
        # scale is defined as: bbox_size / 200.0,
        # so we need multiply 200.0 to get bbox size
        all_boxes[:, 4] = np.prod(scale * 200.0, axis=1)
        all_boxes[:, 5] = score

        result = {}
        result['preds'] = all_preds
        result['boxes'] = all_boxes
        result['image_paths'] = image_paths
        result['bbox_ids'] = bbox_ids
        return result

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output_heatmap (np.ndarray): Output heatmaps.

        Args:
            x (torch.Tensor[NxKxHxW]): Input features.
            flip_pairs (None | list[tuple()):
                Pairs of keypoints which are mirrored.
        """
        output = self.forward(x)

        if flip_pairs is not None:
            N, K, D, H, W = output.shape
            # reshape 3D heatmap to 2D heatmap so flip_back can process it
            output = output.reshape(N, K * D, H, W)
            # 2D heatmap flip
            output_heatmap = flip_back(
                output.detach().cpu().numpy(),
                flip_pairs,
                target_type=self.target_type)
            # reshape back to 3D heatmap
            output_heatmap = output_heatmap.reshape(N, K, D, H, W)
            # feature is not aligned, shift flipped heatmap for higher accuracy
            if self.test_cfg.get('shift_heatmap', False):
                output_heatmap[..., 1:] = output_heatmap[..., :-1]
        else:
            output_heatmap = output.detach().cpu().numpy()
        return output_heatmap