
Source code for mmpose.models.heads.heatmap_heads.cid_head

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_conv_layer
from mmengine.model import BaseModule, ModuleDict
from mmengine.structures import InstanceData, PixelData
from torch import Tensor

from mmpose.models.utils.tta import flip_heatmaps
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
                                 OptSampleList, Predictions)
from ..base_head import BaseHead

def smooth_heatmaps(heatmaps: Tensor, blur_kernel_size: int) -> Tensor:
    """Smooth the heatmaps by blurring and averaging.

    The heatmaps are blurred with a stride-1 average pooling and the result
    is averaged with the un-blurred input.

    Args:
        heatmaps (Tensor): The heatmaps to smooth.
        blur_kernel_size (int): The kernel size for blurring the heatmaps.
            The padding is ``(blur_kernel_size - 1) // 2``, so only an odd
            kernel size keeps the spatial size unchanged, which the final
            averaging with the input relies on.

    Returns:
        Tensor: The smoothed heatmaps.
    """
    padding = (blur_kernel_size - 1) // 2
    blurred = torch.nn.functional.avg_pool2d(heatmaps, blur_kernel_size, 1,
                                             padding)
    return (heatmaps + blurred) / 2.0

class TruncSigmoid(nn.Sigmoid):
    """A sigmoid activation function that truncates the output to the given
    range.

    Args:
        min (float, optional): The minimum value to clamp the output to.
            Defaults to 0.0
        max (float, optional): The maximum value to clamp the output to.
            Defaults to 1.0
    """

    def __init__(self, min: float = 0.0, max: float = 1.0):
        super(TruncSigmoid, self).__init__()
        self.min = min
        self.max = max

    def forward(self, input: Tensor) -> Tensor:
        """Computes the truncated sigmoid activation of the input tensor."""
        return torch.sigmoid(input).clamp(min=self.min, max=self.max)

class IIAModule(BaseModule):
    """Instance Information Abstraction module introduced in `CID`. This module
    extracts the feature representation vectors for each instance.

    Args:
        in_channels (int): Number of channels in the input feature tensor
        out_channels (int): Number of channels of the output heatmaps
        clamp_delta (float, optional): A small value that prevents the sigmoid
            activation from becoming saturated. Defaults to 1e-4.
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        clamp_delta: float = 1e-4,
        init_cfg: OptConfigType = None,
    ):
        super().__init__(init_cfg=init_cfg)

        # NOTE(review): the layer config was lost in the extracted source and
        # is reconstructed as a 1x1 Conv2d mapping ``in_channels`` to
        # ``out_channels`` - confirm against the upstream file.
        self.keypoint_root_conv = build_conv_layer(
            dict(
                type='Conv2d',
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1))
        self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta)

    def forward(self, feats: Tensor):
        """Predict heatmaps from the feature map.

        Args:
            feats (Tensor): Input feature map.

        Returns:
            Tensor: Heatmaps after the truncated sigmoid activation.
        """
        heatmaps = self.keypoint_root_conv(feats)
        heatmaps = self.sigmoid(heatmaps)
        return heatmaps

    def _sample_feats(self, feats: Tensor, indices: Tensor) -> Tensor:
        """Extract feature vectors at the specified indices from the input
        feature map.

        Args:
            feats (Tensor): Input feature map.
            indices (Tensor): Indices of the feature vectors to extract, as a
                ``torch.long`` tensor with 2 or 3 columns. With 3 columns each
                row is ``(batch_index, x, y)``. With 2 columns each row is
                ``(x, y)`` and the location is sampled from every sample in
                the batch.

        Returns:
            Tensor: Extracted feature vectors.

        Raises:
            ValueError: If ``indices`` has neither 2 nor 3 columns.
        """
        assert indices.dtype == torch.long
        num_index_channels = indices.shape[1]
        if num_index_channels == 3:
            b, w, h = indices.unbind(-1)
            instance_feats = feats[b, :, h, w]
        elif num_index_channels == 2:
            w, h = indices.unbind(-1)
            # (B, C, N) -> (B, N, C) -> (B * N, C)
            instance_feats = feats[:, :, h, w]
            instance_feats = instance_feats.permute(0, 2, 1)
            instance_feats = instance_feats.reshape(-1,
                                                    instance_feats.shape[-1])
        else:
            raise ValueError(f'`indices` should have 2 or 3 channels, '
                             f'but got {num_index_channels}')
        return instance_feats

    def _hierarchical_pool(self, heatmaps: Tensor) -> Tensor:
        """Conduct max pooling on the input heatmaps with different kernel size
        according to the input size.

        Args:
            heatmaps (Tensor): Input heatmaps.

        Returns:
            Tensor: Result of hierarchical pooling.
        """
        map_size = (heatmaps.shape[-1] + heatmaps.shape[-2]) / 2.0
        # Larger maps use a larger window; the padding keeps the output size
        # equal to the input size in every case.
        if map_size > 300:
            maxm = torch.nn.functional.max_pool2d(heatmaps, 7, 1, 3)
        elif map_size > 200:
            maxm = torch.nn.functional.max_pool2d(heatmaps, 5, 1, 2)
        else:
            maxm = torch.nn.functional.max_pool2d(heatmaps, 3, 1, 1)
        return maxm

    def forward_train(self, feats: Tensor, instance_coords: Tensor,
                      instance_imgids: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward pass during training.

        Args:
            feats (Tensor): Input feature tensor.
            instance_coords (Tensor): Coordinates of the instance roots.
            instance_imgids (Tensor): Sample indices of each instances
                in the batch.

        Returns:
            Tuple[Tensor, Tensor]: Extracted feature vectors and heatmaps
                for the instances.
        """
        heatmaps = self.forward(feats)
        # Prepend the sample index so that every row is (batch_index, x, y)
        indices = torch.cat((instance_imgids[:, None], instance_coords), dim=1)
        instance_feats = self._sample_feats(feats, indices)

        return instance_feats, heatmaps

    def forward_test(
        self, feats: Tensor, test_cfg: Dict
    ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Forward pass during testing.

        Args:
            feats (Tensor): Input feature tensor.
            test_cfg (Dict): Testing configuration, including:
                - blur_kernel_size (int, optional): Kernel size for blurring
                    the heatmaps. Defaults to 3.
                - max_instances (int, optional): Maximum number of instances
                    to extract. Defaults to 30.
                - score_threshold (float, optional): Minimum score for
                    extracting an instance. Defaults to 0.01.
                - flip_test (bool, optional): Whether to compute the average
                    of the heatmaps across the batch dimension.
                    Defaults to False.

        Returns:
            A tuple of Tensor including extracted feature vectors,
            coordinates, and scores of the instances. Any of these can be
            empty Tensor if no instances are extracted.
        """
        blur_kernel_size = test_cfg.get('blur_kernel_size', 3)
        max_instances = test_cfg.get('max_instances', 30)
        score_threshold = test_cfg.get('score_threshold', 0.01)
        W = feats.shape[-1]

        # compute heatmaps; only the last channel is used for locating
        # instances
        heatmaps = self.forward(feats)[:, -1:]
        if test_cfg.get('flip_test', False):
            heatmaps = heatmaps.mean(dim=0, keepdim=True)
        smoothed_heatmaps = smooth_heatmaps(heatmaps, blur_kernel_size)

        # decode heatmaps: keep only local maxima, then take the top scores
        maximums = self._hierarchical_pool(smoothed_heatmaps)
        maximums = torch.eq(maximums, smoothed_heatmaps).float()
        maximums = (smoothed_heatmaps * maximums).reshape(-1)
        # ``topk`` raises if k exceeds the number of candidates, which
        # happens for small feature maps
        num_candidates = min(max_instances, maximums.numel())
        scores, pos_ind = maximums.topk(num_candidates, dim=0)
        keep = scores > score_threshold
        scores, pos_ind = scores[keep], pos_ind[keep]

        # sample feature vectors from feature map; the flat index is
        # converted to (x, y)
        instance_coords = torch.stack((pos_ind % W, pos_ind // W), dim=1)
        instance_feats = self._sample_feats(feats, instance_coords)

        return instance_feats, instance_coords, scores

class ChannelAttention(nn.Module):
    """Channel-wise attention module introduced in `CID`.

    Args:
        in_channels (int): The number of channels of the input instance
            vectors.
        out_channels (int): The number of channels of the transformed instance
            vectors.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super(ChannelAttention, self).__init__()
        self.atn = nn.Linear(in_channels, out_channels)

    def forward(self, global_feats: Tensor, instance_feats: Tensor) -> Tensor:
        """Applies attention to the channel dimension of the input tensor.

        Args:
            global_feats (Tensor): Feature maps to modulate.
            instance_feats (Tensor): Instance vectors. They are transformed by
                a linear layer and broadcast over the two trailing dimensions
                of ``global_feats``.

        Returns:
            Tensor: ``global_feats`` scaled channel-wise by the transformed
            instance vectors.
        """
        channel_weights = self.atn(instance_feats)[:, :, None, None]
        return global_feats * channel_weights

class SpatialAttention(nn.Module):
    """Spatial-wise attention module introduced in `CID`.

    Args:
        in_channels (int): The number of channels of the input instance
            vectors.
        out_channels (int): The number of channels of the transformed instance
            vectors.
    """

    def __init__(self, in_channels, out_channels):
        super(SpatialAttention, self).__init__()
        self.atn = nn.Linear(in_channels, out_channels)
        self.feat_stride = 4
        # 3 input channels: the channel-summed response plus the (x, y)
        # relative coordinate maps
        self.conv = nn.Conv2d(3, 1, 5, 1, 2)

    def _get_pixel_coords(self, heatmap_size: Tuple, device: str = 'cpu'):
        """Get pixel coordinates for each element in the heatmap.

        Args:
            heatmap_size (tuple): Size of the heatmap in (W, H) format.
            device (str): Device to put the resulting tensor on.

        Returns:
            Tensor of shape (num_pixels, 2) containing the (x, y) coordinates
            of the pixel centers (integer index + 0.5), ordered row by row.
        """
        w, h = heatmap_size
        # Broadcasting builds the same grid as ``torch.meshgrid`` with 'ij'
        # indexing, directly on the target device.
        xs = torch.arange(w, dtype=torch.float32, device=device)
        ys = torch.arange(h, dtype=torch.float32, device=device)
        grid_x = xs.view(1, w).expand(h, w)
        grid_y = ys.view(h, 1).expand(h, w)
        pixel_coords = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)
        return pixel_coords + 0.5

    def forward(self, global_feats: Tensor, instance_feats: Tensor,
                instance_coords: Tensor) -> Tensor:
        """Perform spatial attention.

        Args:
            global_feats (Tensor): Tensor containing the global features.
            instance_feats (Tensor): Tensor containing the instance feature
                vectors.
            instance_coords (Tensor): Tensor containing the root coordinates
                of the instances.

        Returns:
            Tensor containing the modulated global features.
        """
        B, C, H, W = global_feats.size()

        # Response of every location to the transformed instance vector
        instance_feats = self.atn(instance_feats).reshape(B, C, 1, 1)
        feats = global_feats * instance_feats.expand_as(global_feats)
        fsum = torch.sum(feats, dim=1, keepdim=True)

        # Offset from every pixel center to the instance root, scaled by 1/32
        pixel_coords = self._get_pixel_coords((W, H), feats.device)
        relative_coords = instance_coords.reshape(
            -1, 1, 2) - pixel_coords.reshape(1, -1, 2)
        relative_coords = relative_coords.permute(0, 2, 1) / 32.0
        relative_coords = relative_coords.reshape(B, 2, H, W)

        input_feats = torch.cat((fsum, relative_coords), dim=1)
        mask = self.conv(input_feats).sigmoid()
        return global_feats * mask

class GFDModule(BaseModule):
    """Global Feature Decoupling module introduced in `CID`. This module
    extracts the decoupled heatmaps for each instance.

    Args:
        in_channels (int): Number of channels in the input feature map
        out_channels (int): Number of channels of the output heatmaps
            for each instance
        gfd_channels (int): Number of channels in the transformed feature map
        clamp_delta (float, optional): A small value that prevents the sigmoid
            activation from becoming saturated. Defaults to 1e-4.
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        gfd_channels: int,
        clamp_delta: float = 1e-4,
        init_cfg: OptConfigType = None,
    ):
        super().__init__(init_cfg=init_cfg)

        # NOTE(review): except for ``in_channels=gfd_channels * 2`` of
        # ``fuse_attention``, the conv configs below were lost in the
        # extracted source and are reconstructed as 1x1 Conv2d layers -
        # confirm against the upstream file.
        self.conv_down = build_conv_layer(
            dict(
                type='Conv2d',
                in_channels=in_channels,
                out_channels=gfd_channels,
                kernel_size=1))

        self.channel_attention = ChannelAttention(in_channels, gfd_channels)
        self.spatial_attention = SpatialAttention(in_channels, gfd_channels)
        self.fuse_attention = build_conv_layer(
            dict(
                type='Conv2d',
                in_channels=gfd_channels * 2,
                out_channels=gfd_channels,
                kernel_size=1))
        self.heatmap_conv = build_conv_layer(
            dict(
                type='Conv2d',
                in_channels=gfd_channels,
                out_channels=out_channels,
                kernel_size=1))
        self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta)

    def forward(
        self,
        feats: Tensor,
        instance_feats: Tensor,
        instance_coords: Tensor,
        instance_imgids: Tensor,
    ) -> Tensor:
        """Extract decoupled heatmaps for each instance.

        Args:
            feats (Tensor): Input feature maps.
            instance_feats (Tensor): Tensor containing the instance feature
                vectors.
            instance_coords (Tensor): Tensor containing the root coordinates
                of the instances.
            instance_imgids (Tensor): Sample indices of each instances
                in the batch.

        Returns:
            A tensor containing decoupled heatmaps.
        """

        global_feats = self.conv_down(feats)
        # Gather one feature map per instance from its source sample
        global_feats = global_feats[instance_imgids]

        channel_feats = self.channel_attention(global_feats, instance_feats)
        spatial_feats = self.spatial_attention(global_feats, instance_feats,
                                               instance_coords)
        cond_instance_feats = torch.cat((channel_feats, spatial_feats), dim=1)

        cond_instance_feats = self.fuse_attention(cond_instance_feats)
        cond_instance_feats = torch.nn.functional.relu(cond_instance_feats)
        cond_instance_feats = self.heatmap_conv(cond_instance_feats)
        heatmaps = self.sigmoid(cond_instance_feats)

        return heatmaps

@MODELS.register_module()
class CIDHead(BaseHead):
    """Contextual Instance Decoupling head introduced in `Contextual Instance
    Decoupling for Robust Multi-Person Pose Estimation (CID)`_ by Wang et al
    (2022). The head is composed of an Instance Information Abstraction (IIA)
    module and a Global Feature Decoupling (GFD) module.

    Args:
        in_channels (int | Sequence[int]): Number of channels in the input
            feature map
        gfd_channels (int): Number of filters in GFD module
        num_keypoints (int): Number of keypoints
        prior_prob (float): Probability ``p`` from which the initial bias
            ``-log((1 - p) / p)`` of the heatmap convolutions is derived.
            Defaults to 0.01
        input_transform (str): Transformation of input features which should
            be one of the following options:

                - ``'resize_concat'``: Resize multiple feature maps specified
                    by ``input_index`` to the same size as the first one and
                    concat these feature maps
                - ``'select'``: Select feature map(s) specified by
                    ``input_index``. Multiple selected features will be
                    bundled into a tuple

            Defaults to ``'select'``
        input_index (int | Sequence[int]): The feature map index used in the
            input transformation. See also ``input_transform``. Defaults to -1
        align_corners (bool): `align_corners` argument of
            :func:`torch.nn.functional.interpolate` used in the input
            transformation. Defaults to ``False``
        coupled_heatmap_loss (Config): Config of the loss for coupled heatmaps.
            Defaults to use :class:`FocalHeatmapLoss`
        decoupled_heatmap_loss (Config): Config of the loss for decoupled
            heatmaps. Defaults to use :class:`FocalHeatmapLoss`
        contrastive_loss (Config): Config of the contrastive loss for
            representation vectors of instances. Defaults to use
            :class:`InfoNCELoss`
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Defaults to ``None``
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings

    .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_
    Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_
    CVPR_2022_paper.html
    """
    _version = 2

    def __init__(self,
                 in_channels: Union[int, Sequence[int]],
                 gfd_channels: int,
                 num_keypoints: int,
                 prior_prob: float = 0.01,
                 input_transform: str = 'select',
                 input_index: Union[int, Sequence[int]] = -1,
                 align_corners: bool = False,
                 coupled_heatmap_loss: OptConfigType = dict(
                     type='FocalHeatmapLoss'),
                 decoupled_heatmap_loss: OptConfigType = dict(
                     type='FocalHeatmapLoss'),
                 contrastive_loss: OptConfigType = dict(type='InfoNCELoss'),
                 decoder: OptConfigType = None,
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.num_keypoints = num_keypoints
        self.align_corners = align_corners
        self.input_transform = input_transform
        self.input_index = input_index
        # NOTE(review): the builder call was lost in the extracted source;
        # reconstructed from the ``KEYPOINT_CODECS`` import - confirm.
        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

        # Get model input channels according to feature
        in_channels = self._get_in_channels()
        if isinstance(in_channels, list):
            raise ValueError(
                f'{self.__class__.__name__} does not support selecting '
                'multiple input features.')

        # The sub-module configs extend ``init_cfg`` by list concatenation,
        # so accept a single config dict as well as a list of configs.
        if isinstance(init_cfg, (list, tuple)):
            base_init_cfg = list(init_cfg)
        else:
            base_init_cfg = [init_cfg]

        # build sub-modules
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.iia_module = IIAModule(
            in_channels,
            num_keypoints + 1,
            init_cfg=base_init_cfg + [
                dict(
                    type='Normal',
                    layer=['Conv2d', 'Linear'],
                    std=0.001,
                    override=dict(
                        name='keypoint_root_conv',
                        type='Normal',
                        std=0.001,
                        bias=bias_value))
            ])
        self.gfd_module = GFDModule(
            in_channels,
            num_keypoints,
            gfd_channels,
            init_cfg=base_init_cfg + [
                dict(
                    type='Normal',
                    layer=['Conv2d', 'Linear'],
                    std=0.001,
                    override=dict(
                        name='heatmap_conv',
                        type='Normal',
                        std=0.001,
                        bias=bias_value))
            ])

        # build losses
        # NOTE(review): the dict entries were lost in the extracted source;
        # the keys are the ones read in :meth:`loss` and the builders are
        # reconstructed from the ``MODELS`` import - confirm.
        self.loss_module = ModuleDict(
            dict(
                heatmap_coupled=MODELS.build(coupled_heatmap_loss),
                heatmap_decoupled=MODELS.build(decoupled_heatmap_loss),
                contrastive=MODELS.build(contrastive_loss),
            ))

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    @property
    def default_init_cfg(self):
        """list[dict]: Initialization config used when ``init_cfg`` is not
        given."""
        init_cfg = [
            dict(type='Normal', layer=['Conv2d', 'Linear'], std=0.001),
            dict(type='Constant', layer='BatchNorm2d', val=1)
        ]
        return init_cfg

    def forward(self, feats: Tuple[Tensor]) -> Tensor:
        """Forward the network. The input is multi scale feature maps and the
        output is the heatmap.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: output heatmap.
        """
        feats = self._transform_inputs(feats)
        # An empty config makes the IIA module use its default test settings
        instance_info = self.iia_module.forward_test(feats, {})
        instance_feats, instance_coords, instance_scores = instance_info
        instance_imgids = torch.zeros(
            instance_coords.size(0), dtype=torch.long, device=feats.device)
        instance_heatmaps = self.gfd_module(feats, instance_feats,
                                            instance_coords, instance_imgids)

        return instance_heatmaps

    def predict(self,
                feats: Features,
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from features.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict): The runtime config for testing process. Defaults
                to {}

        Returns:
            Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
            ``test_cfg['output_heatmaps']==True``, return both pose and heatmap
            prediction; otherwise only return the pose prediction.

            The pose prediction is a list of ``InstanceData``, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)

            The heatmap prediction is a list of ``PixelData``, each contains
            the following fields:

                - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
        """
        metainfo = batch_data_samples[0].metainfo
        flip_test = test_cfg.get('flip_test', False)

        if flip_test:
            assert isinstance(feats, list) and len(feats) == 2

            # Stack the original and the flipped-back features along the
            # batch dimension: index 0 is the original, index 1 the flipped
            feats_flipped = flip_heatmaps(
                self._transform_inputs(feats[1]), shift_heatmap=False)
            feats = torch.cat(
                (self._transform_inputs(feats[0]), feats_flipped))
        else:
            feats = self._transform_inputs(feats)

        instance_info = self.iia_module.forward_test(feats, test_cfg)
        instance_feats, instance_coords, instance_scores = instance_info
        if len(instance_coords) > 0:
            instance_imgids = torch.zeros(
                instance_coords.size(0), dtype=torch.long, device=feats.device)
            if flip_test:
                # Every instance is decoded once per view
                instance_coords = torch.cat((instance_coords, instance_coords))
                instance_imgids = torch.cat(
                    (instance_imgids, instance_imgids + 1))
            instance_heatmaps = self.gfd_module(feats, instance_feats,
                                                instance_coords,
                                                instance_imgids)
            if flip_test:
                flip_indices = metainfo['flip_indices']
                instance_heatmaps, instance_heatmaps_flip = torch.chunk(
                    instance_heatmaps, 2, dim=0)
                # Reorder the keypoint channels of the flipped view
                instance_heatmaps_flip = \
                    instance_heatmaps_flip[:, flip_indices, :, :]
                instance_heatmaps = (instance_heatmaps +
                                     instance_heatmaps_flip) / 2.0
            instance_heatmaps = smooth_heatmaps(
                instance_heatmaps, test_cfg.get('blur_kernel_size', 3))

            preds = self.decode((instance_heatmaps, instance_scores[:, None]))
            # NOTE(review): this call was lost in the extracted source;
            # ``preds.keypoints`` below requires a single ``InstanceData``.
            preds = InstanceData.cat(preds)
            # Shift the keypoints by half a heatmap cell, in input pixels
            preds.keypoints[..., 0] += metainfo['input_size'][
                0] / instance_heatmaps.shape[-1] / 2.0
            preds.keypoints[..., 1] += metainfo['input_size'][
                1] / instance_heatmaps.shape[-2] / 2.0
            preds = [preds]

        else:
            preds = [
                InstanceData(
                    keypoints=np.empty((0, self.num_keypoints, 2)),
                    keypoint_scores=np.empty((0, self.num_keypoints)))
            ]
            instance_heatmaps = torch.empty(0, self.num_keypoints,
                                            *feats.shape[-2:])

        if test_cfg.get('output_heatmaps', False):
            pred_fields = [
                PixelData(
                    heatmaps=instance_heatmaps.reshape(
                        -1, *instance_heatmaps.shape[-2:]))
            ]
            return preds, pred_fields
        else:
            return preds

    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: ConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Tuple[Tensor]): The multi-stage features
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            train_cfg (dict): The runtime config for training process.
                Defaults to {}. ``max_train_instances`` (int, default -1)
                limits the number of instances fed to the GFD module; a
                non-positive value disables the limit.

        Returns:
            dict: A dictionary of losses.
        """

        # load targets
        gt_heatmaps, gt_instance_coords, keypoint_weights = [], [], []
        heatmap_mask = []
        instance_imgids, gt_instance_heatmaps = [], []
        for i, d in enumerate(batch_data_samples):
            gt_heatmaps.append(d.gt_fields.heatmaps)
            gt_instance_coords.append(d.gt_instance_labels.instance_coords)
            keypoint_weights.append(d.gt_instance_labels.keypoint_weights)
            instance_imgids.append(
                torch.full((len(d.gt_instance_labels.instance_coords), ),
                           i,
                           dtype=torch.long))

            instance_heatmaps = d.gt_fields.instance_heatmaps.reshape(
                -1, self.num_keypoints,
                *d.gt_fields.instance_heatmaps.shape[1:])
            gt_instance_heatmaps.append(instance_heatmaps)

            if 'heatmap_mask' in d.gt_fields:
                heatmap_mask.append(d.gt_fields.heatmap_mask)

        gt_heatmaps = torch.stack(gt_heatmaps)
        heatmap_mask = torch.stack(heatmap_mask) if heatmap_mask else None

        gt_instance_coords = torch.cat(gt_instance_coords, dim=0)
        gt_instance_heatmaps = torch.cat(gt_instance_heatmaps, dim=0)
        keypoint_weights = torch.cat(keypoint_weights, dim=0)
        # NOTE(review): the right-hand side was lost in the extracted source;
        # the ids are built on CPU above and index device tensors below, so
        # they are moved to the device of the targets - confirm.
        instance_imgids = torch.cat(instance_imgids).to(gt_heatmaps.device)

        # feed-forward
        feats = self._transform_inputs(feats)
        pred_instance_feats, pred_heatmaps = self.iia_module.forward_train(
            feats, gt_instance_coords, instance_imgids)

        # compute contrastive loss
        contrastive_loss = 0
        for i in range(len(batch_data_samples)):
            pred_instance_feat = pred_instance_feats[instance_imgids == i]
            contrastive_loss += self.loss_module['contrastive'](
                pred_instance_feat)
        contrastive_loss = contrastive_loss / max(1, len(instance_imgids))

        # limit the number of instances
        max_train_instances = train_cfg.get('max_train_instances', -1)
        if (max_train_instances > 0
                and len(instance_imgids) > max_train_instances):
            selected_indices = torch.randperm(
                len(instance_imgids),
                device=gt_heatmaps.device,
                dtype=torch.long)[:max_train_instances]
            gt_instance_coords = gt_instance_coords[selected_indices]
            keypoint_weights = keypoint_weights[selected_indices]
            gt_instance_heatmaps = gt_instance_heatmaps[selected_indices]
            instance_imgids = instance_imgids[selected_indices]
            pred_instance_feats = pred_instance_feats[selected_indices]

        # calculate the decoupled heatmaps for each instance
        pred_instance_heatmaps = self.gfd_module(feats, pred_instance_feats,
                                                 gt_instance_coords,
                                                 instance_imgids)

        # calculate losses
        losses = {
            'loss/heatmap_coupled':
            self.loss_module['heatmap_coupled'](pred_heatmaps, gt_heatmaps,
                                                None, heatmap_mask)
        }
        # The instance-level losses are only reported when the batch contains
        # at least one instance
        if len(instance_imgids) > 0:
            losses.update({
                'loss/heatmap_decoupled':
                self.loss_module['heatmap_decoupled'](pred_instance_heatmaps,
                                                      gt_instance_heatmaps,
                                                      keypoint_weights),
                'loss/contrastive':
                contrastive_loss
            })

        return losses

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function to convert old-version state dict of
        :class:`CIDHead` (before MMPose v1.0.0) to a compatible format of
        :class:`CIDHead`.

        The hook will be automatically registered during initialization.

        Only keys under ``prefix`` are converted, and keys that already live
        under ``iia_module.`` or ``gfd_module.`` are left untouched so that a
        converted state dict without version metadata is not renamed twice.
        """

        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # (old name, new name) pairs, applied in this order
        legacy_names = (
            ('keypoint_center_conv', 'iia_module.keypoint_root_conv'),
            ('conv_down', 'gfd_module.conv_down'),
            ('c_attn', 'gfd_module.channel_attention'),
            ('s_attn', 'gfd_module.spatial_attention'),
            ('fuse_attn', 'gfd_module.fuse_attention'),
            ('heatmap_conv', 'gfd_module.heatmap_conv'),
        )

        # convert old-version state dict
        for old_key in list(state_dict.keys()):
            if not old_key.startswith(prefix):
                continue
            suffix = old_key[len(prefix):]
            if suffix.startswith(('iia_module.', 'gfd_module.')):
                continue

            new_suffix = suffix
            for old_name, new_name in legacy_names:
                if old_name in new_suffix:
                    new_suffix = new_suffix.replace(old_name, new_name)

            if new_suffix != suffix:
                state_dict[prefix + new_suffix] = state_dict.pop(old_key)
Read the Docs v: fix-doc
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.