Shortcuts

Source code for mmpose.models.losses.heatmap_loss

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmpose.registry import MODELS


@MODELS.register_module()
class KeypointMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``.
            NOTE(review): this flag is stored but not consulted by
            :meth:`_get_mask`; ``target_weights`` are applied whenever they
            are passed to :meth:`forward`.
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        _mask = self._get_mask(target, target_weights, mask)
        if _mask is None:
            loss = F.mse_loss(output, target)
        else:
            _loss = F.mse_loss(output, target, reduction='none')
            # The mean is taken over all elements, masked-out ones included
            # (they contribute zeros), not only over the valid elements.
            loss = (_loss * _mask).mean()

        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Each available source (spatial ``mask``, ``target_weights`` and the
        empty-channel check) is combined multiplicatively.

        Returns:
            Tensor: The mask, broadcastable to ``target`` (e.g. shape
            (B, K, *) or (B, 1, H, W)), or ``None`` if no mask is needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} v.s.'
                    f'{target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # Append trailing singleton dims so [B, K] broadcasts over H, W.
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.view(target_weights.shape +
                                        (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # One flag per (B, K) channel: True iff the channel has any
            # non-zero pixel. Reducing only over the flattened spatial dim
            # (dim=2) keeps the per-channel granularity.
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.view(_mask.shape + (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        return mask
@MODELS.register_module()
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_channels: C
            - heatmaps height: H
            - heatmaps width: W
            - num_keypoints: K
            Here, C = 3 * K

        Args:
            output (Tensor): The output feature maps with shape [B, C, H, W].
            target (Tensor): The target feature maps with shape [B, C, H, W].
            target_weights (Tensor): The target weights of different keypoints,
                with shape [B, K].

        Returns:
            Tensor: The calculated loss, averaged over keypoints and scaled
            by ``loss_weight``.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        # Flatten spatial dims to [B, C, H*W]. Selecting a channel with
        # ``[:, c]`` always yields [B, H*W]; a bare ``squeeze()`` would also
        # drop the batch axis when B == 1 or the spatial axis when H*W == 1,
        # and the latter makes the (B, 1) weights broadcast to (B, B).
        heatmaps_pred = output.reshape((batch_size, num_channels, -1))
        heatmaps_gt = target.reshape((batch_size, num_channels, -1))

        loss = 0.
        # Channels are laid out per keypoint as (heatmap, offset_x, offset_y).
        num_joints = num_channels // 3
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[:, idx * 3]
            heatmap_gt = heatmaps_gt[:, idx * 3]
            offset_x_pred = heatmaps_pred[:, idx * 3 + 1]
            offset_x_gt = heatmaps_gt[:, idx * 3 + 1]
            offset_y_pred = heatmaps_pred[:, idx * 3 + 2]
            offset_y_gt = heatmaps_gt[:, idx * 3 + 2]
            if self.use_target_weight:
                target_weight = target_weights[:, idx, None]
                heatmap_pred = heatmap_pred * target_weight
                heatmap_gt = heatmap_gt * target_weight
            # classification loss
            loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            # regression loss: offsets are only supervised where the
            # (possibly weighted) GT heatmap is non-zero
            loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                         heatmap_gt * offset_x_gt)
            loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                         heatmap_gt * offset_y_gt)
        return loss / num_joints * self.loss_weight
@MODELS.register_module()
class KeypointOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        topk (int): Only top k joint losses are kept. Defaults to 8
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 topk: int = 8,
                 loss_weight: float = 1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, losses: Tensor) -> Tensor:
        """Online hard keypoint mining.

        Note:
            - batch_size: B
            - num_keypoints: K

        Args:
            losses (Tensor): The losses with shape [B, K]

        Returns:
            Tensor: The mean over the batch of each sample's average loss
            across its ``topk`` hardest (largest-loss) keypoints.
        """
        # Row-wise top-k over keypoints; the returned values are exactly
        # the entries that gathering with the top-k indices would select.
        topk_losses, _ = torch.topk(
            losses, k=self.topk, dim=1, sorted=False)
        per_sample = topk_losses.sum(dim=1) / self.topk
        return per_sample.mean()

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W].
            target (Tensor): The target heatmaps with shape [B, K, H, W].
            target_weights (Tensor): The target weights of different keypoints,
                with shape [B, K]. Only read when ``use_target_weight``.

        Returns:
            Tensor: The calculated loss.

        Raises:
            ValueError: If ``topk`` exceeds the number of keypoints.
        """
        num_keypoints = output.size(1)
        if num_keypoints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not be '
                             f'larger than num_keypoints ({num_keypoints}).')

        if self.use_target_weight:
            # [B, K] -> [B, K, 1, 1] so each keypoint weight scales its map.
            weights = target_weights[:, :, None, None]
            pixel_losses = self.criterion(output * weights, target * weights)
        else:
            pixel_losses = self.criterion(output, target)

        # Per-keypoint MSE: average over the spatial dims -> [B, K].
        losses = pixel_losses.mean(dim=(2, 3))
        return self._ohkm(losses) * self.loss_weight
@MODELS.register_module()
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
    Alignment via Heatmap Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to weight the loss by target
            weights. Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of wingloss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: Scalar mean of the element-wise adaptive wing loss.
        """
        delta = (target - pred).abs()
        # Element-wise exponent (alpha - y) shared by every term below.
        exponent = self.alpha - target
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        # A and C make the linear branch continuous with the log branch
        # at |delta| == theta.
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * (torch.pow(
            ratio, exponent - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)

        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 + torch.pow(delta / self.epsilon, exponent)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[N, K, H, W]): Output heatmaps.
            target (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weights (torch.Tensor[N, K] or [N, K, H, W], optional):
                Weights across different joint types. Required when
                ``use_target_weight`` is ``True``.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        if self.use_target_weight:
            assert target_weights is not None, (
                'target_weights must be provided when use_target_weight '
                'is True')
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # Append trailing singleton dims so [N, K] broadcasts over H, W.
            ndim_pad = target.ndim - target_weights.ndim
            target_weights = target_weights.view(target_weights.shape +
                                                 (1, ) * ndim_pad)
            loss = self.criterion(output * target_weights,
                                  target * target_weights)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
@MODELS.register_module()
class FocalHeatmapLoss(KeypointMSELoss):
    """A class for calculating the modified focal loss for heatmap prediction.

    This loss function is exactly the same as the one used in CornerNet. It
    runs faster and costs a little bit more memory.

    `CornerNet: Detecting Objects as Paired Keypoints
    arXiv: <https://arxiv.org/abs/1808.01244>`_.

    Arguments:
        alpha (int): The alpha parameter in the focal loss equation.
        beta (int): The beta parameter in the focal loss equation.
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 alpha: int = 2,
                 beta: int = 4,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.0):
        super().__init__(use_target_weight, skip_empty_channel, loss_weight)
        self.alpha = alpha
        self.beta = beta

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None) -> Tensor:
        """Calculate the modified focal loss for heatmap prediction.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps width: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss. It is normalized by the (weighted)
            number of positives, or un-normalized if there are none.
        """
        _mask = self._get_mask(target, target_weights, mask)

        # Positives are pixels with target exactly 1; everything below 1
        # counts as a (down-weighted) negative.
        pos_inds = target.eq(1).float()
        neg_inds = target.lt(1).float()

        if _mask is not None:
            pos_inds = pos_inds * _mask
            neg_inds = neg_inds * _mask

        neg_weights = torch.pow(1 - target, self.beta)

        # Terms with zero weight should contribute 0, but ``log`` of an
        # exactly saturated prediction (0 for the positive term, 1 for the
        # negative term) is -inf, and -inf * 0 = NaN, which would poison the
        # loss and its gradients. Feed a neutral value (log(1) == 0) into
        # ``log`` only where the term is weighted by zero anyway; weighted
        # positions see the original ``output`` unchanged.
        safe_pos_output = torch.where(pos_inds != 0, output,
                                      torch.ones_like(output))
        safe_neg_output = torch.where(neg_weights * neg_inds != 0, output,
                                      torch.zeros_like(output))

        pos_loss = torch.log(safe_pos_output) * torch.pow(
            1 - output, self.alpha) * pos_inds
        neg_loss = torch.log(1 - safe_neg_output) * torch.pow(
            output, self.alpha) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        if num_pos == 0:
            loss = -neg_loss.sum()
        else:
            loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
        return loss * self.loss_weight
Read the Docs v: fix-doc
Versions
latest
1.x
v0.14.0
fix-doc
cn_doc
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.