
Source code for mmpose.models.heads.ae_multi_stage_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_upsample_layer, constant_init,
                      normal_init)

from mmpose.models.builder import build_loss
from ..builder import HEADS


@HEADS.register_module()
class AEMultiStageHead(nn.Module):
    """Associative embedding multi-stage head.

    paper ref: Alejandro Newell et al. "Associative Embedding: End-to-End
    Learning for Joint Detection and Grouping".

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_stages (int): Number of stages. Default: 1.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
            If num_deconv_layers > 0, the length should equal
            num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes.
            If num_deconv_layers > 0, the length should equal
            num_deconv_layers.
        extra (dict): Config for extra layers, e.g. the kernel size of
            the final conv layer (``final_conv_kernel``). Default: None.
        loss_keypoint (dict): Config for loss. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_stages=1,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 loss_keypoint=None):
        super().__init__()

        self.loss = build_loss(loss_keypoint)

        self.in_channels = in_channels
        self.num_stages = num_stages

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        # build multi-stage deconv layers
        self.multi_deconv_layers = nn.ModuleList([])
        for _ in range(self.num_stages):
            if num_deconv_layers > 0:
                deconv_layers = self._make_deconv_layer(
                    num_deconv_layers,
                    num_deconv_filters,
                    num_deconv_kernels,
                )
            elif num_deconv_layers == 0:
                deconv_layers = nn.Identity()
            else:
                raise ValueError(
                    f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
            self.multi_deconv_layers.append(deconv_layers)

        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 for Identity mapping.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        # build multi-stage final layers
        self.multi_final_layers = nn.ModuleList([])
        for i in range(self.num_stages):
            if identity_final_layer:
                final_layer = nn.Identity()
            else:
                final_layer = build_conv_layer(
                    cfg=dict(type='Conv2d'),
                    in_channels=num_deconv_filters[-1]
                    if num_deconv_layers > 0 else in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding)
            self.multi_final_layers.append(final_layer)
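
    # Editorial note: with the defaults above, each stage gets three
    # stride-2 deconv layers (256 channels, kernel 4) followed by a 1x1
    # conv (extra=None), i.e. an 8x upsampling head per stage.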

    def get_loss(self, output, targets, masks, joints):
        """Calculate bottom-up keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (List(torch.Tensor[NxKxHxW])): Output heatmaps.
            targets (List(List(torch.Tensor[NxKxHxW]))): Multi-stage and
                multi-scale target heatmaps.
            masks (List(List(torch.Tensor[NxHxW]))): Masks of multi-stage
                and multi-scale target heatmaps.
            joints (List(List(torch.Tensor[NxMxKx2]))): Joints of
                multi-stage and multi-scale target heatmaps for AE loss.
        """
        losses = dict()

        # Flatten list:
        # [stage_1_scale_1, stage_1_scale_2, ..., stage_1_scale_m,
        #  ...
        #  stage_n_scale_1, stage_n_scale_2, ..., stage_n_scale_m]
        targets = [target for _targets in targets for target in _targets]
        masks = [mask for _masks in masks for mask in _masks]
        joints = [joint for _joints in joints for joint in _joints]

        heatmaps_losses, push_losses, pull_losses = self.loss(
            output, targets, masks, joints)

        for idx in range(len(targets)):
            if heatmaps_losses[idx] is not None:
                heatmaps_loss = heatmaps_losses[idx].mean(dim=0)
                if 'heatmap_loss' not in losses:
                    losses['heatmap_loss'] = heatmaps_loss
                else:
                    losses['heatmap_loss'] += heatmaps_loss
            if push_losses[idx] is not None:
                push_loss = push_losses[idx].mean(dim=0)
                if 'push_loss' not in losses:
                    losses['push_loss'] = push_loss
                else:
                    losses['push_loss'] += push_loss
            if pull_losses[idx] is not None:
                pull_loss = pull_losses[idx].mean(dim=0)
                if 'pull_loss' not in losses:
                    losses['pull_loss'] = pull_loss
                else:
                    losses['pull_loss'] += pull_loss

        return losses
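
    # Editorial note: the accumulation above sums every stage and every
    # scale into a single scalar per term, so `losses` contains at most
    # three keys: 'heatmap_loss', 'push_loss' and 'pull_loss'.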

    def forward(self, x):
        """Forward function.

        Returns:
            out (list[Tensor]): a list of heatmaps from multiple stages.
        """
        out = []
        assert isinstance(x, list)
        for i in range(self.num_stages):
            y = self.multi_deconv_layers[i](x[i])
            y = self.multi_final_layers[i](y)
            out.append(y)
        return out

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers."""
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            self.in_channels = planes

        return nn.Sequential(*layers)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding
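
    # Editorial note: with stride 2, a ConvTranspose2d produces
    # H_out = (H_in - 1) * 2 - 2 * padding + kernel + output_padding,
    # so every supported kernel size above exactly doubles the resolution:
    #   kernel 4: (H - 1) * 2 - 2 + 4 + 0 = 2 * H
    #   kernel 3: (H - 1) * 2 - 2 + 3 + 1 = 2 * H
    #   kernel 2: (H - 1) * 2 - 0 + 2 + 0 = 2 * H
    # Hence num_deconv_layers layers upsample by 2 ** num_deconv_layers.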

    def init_weights(self):
        """Initialize model weights."""
        for _, m in self.multi_deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for m in self.multi_final_layers.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
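
Below is a minimal usage sketch, an editorial addition rather than part of the
source file. The concrete values are illustrative assumptions, not taken from
an official config: the loss config assumes the MultiLossFactory loss shipped
in mmpose.models.losses, and the channel, joint, and stage numbers are
placeholders.

# Minimal usage sketch (editorial addition; all values are assumptions).
import torch

from mmpose.models.heads import AEMultiStageHead

head = AEMultiStageHead(
    in_channels=256,
    out_channels=34,  # e.g. 17 keypoint heatmaps + 17 tag maps
    num_stages=2,
    num_deconv_layers=0,  # backbone already outputs the target resolution
    extra=dict(final_conv_kernel=1),
    loss_keypoint=dict(
        type='MultiLossFactory',
        num_joints=17,
        num_stages=2,
        ae_loss_type='exp',
        with_ae_loss=[True, True],
        push_loss_factor=[0.001, 0.001],
        pull_loss_factor=[0.001, 0.001],
        with_heatmaps_loss=[True, True],
        heatmaps_loss_factor=[1.0, 1.0]))
head.init_weights()

# One feature map per stage, e.g. from a stacked hourglass backbone.
feats = [torch.randn(2, 256, 128, 128) for _ in range(2)]
outputs = head(feats)  # list of two tensors, each of shape [2, 34, 128, 128]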