# Source code for mmpose.models.losses.heatmap_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from ..builder import LOSSES
@LOSSES.register_module()
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss.

    Paper ref: 'Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float): Exponent hyper-parameter; the effective exponent per
            pixel is ``alpha - target``, so it must satisfy ``alpha > 2``
            for the loss to stay wing-shaped at target values in [0, 1].
        omega (float): Scale hyper-parameter of the nonlinear (log) part.
        epsilon (float): Normalization hyper-parameter of the error.
        theta (float): Error threshold switching between the nonlinear
            (log) regime and the linear regime.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of adaptive wing loss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: Scalar mean loss over all heatmap elements.
        """
        delta = (target - pred).abs()
        # A and C are chosen so the piecewise loss is continuous and smooth
        # at |delta| == theta (Eq. 8 of the paper). The exponent is adaptive
        # per pixel: alpha - target.
        A = self.omega * (
            1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
        ) * (self.alpha - target) * (torch.pow(
            self.theta / self.epsilon,
            self.alpha - target - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, self.alpha - target))

        # Nonlinear (log) part for small errors, linear part for large ones.
        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 +
                      torch.pow(delta / self.epsilon, self.alpha - target)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self, output, target, target_weight):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[NxKxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]):
                Weights across different joint types.

        Returns:
            torch.Tensor: Scalar loss, scaled by ``loss_weight``.
        """
        if self.use_target_weight:
            # target_weight: NxKx1 -> NxKx1x1 after unsqueeze, then
            # broadcasts over the HxW spatial dims of the heatmaps.
            loss = self.criterion(output * target_weight.unsqueeze(-1),
                                  target * target_weight.unsqueeze(-1))
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
@LOSSES.register_module()
class FocalHeatmapLoss(nn.Module):
    """Modified focal loss for heatmaps, as used in CornerNet.

    Paper ref: 'CornerNet: Detecting Objects as Paired Keypoints'
    Law & Deng, ECCV'2018.

    Args:
        alpha (int): Focusing exponent applied to the prediction term.
            Default: 2.
        beta (int): Exponent down-weighting negative pixels near a
            ground-truth peak (where gt is close to 1). Default: 4.
    """

    def __init__(self, alpha=2, beta=4):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred, gt, mask=None):
        """Compute the modified focal loss.

        Exactly the same as CornerNet; runs faster and costs a little bit
        more memory.

        Note:
            ``pred`` is expected to lie strictly in (0, 1); values of
            exactly 0 or 1 make ``torch.log`` produce -inf.

        Args:
            pred (torch.Tensor[NxCxHxW]): Predicted heatmaps.
            gt (torch.Tensor[NxCxHxW]): Target heatmaps; pixels equal to 1
                are positives, everything below 1 is a negative.
            mask (torch.Tensor, optional): Multiplicative mask selecting
                which pixels contribute to the loss. Default: ``None``
                (all pixels contribute).

        Returns:
            torch.Tensor: Scalar loss, normalized by the number of
            positive pixels when any exist.
        """
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()

        if mask is not None:
            pos_inds = pos_inds * mask
            neg_inds = neg_inds * mask

        # Negatives near a ground-truth peak (gt close to 1) are softly
        # down-weighted instead of being treated as hard negatives.
        neg_weights = torch.pow(1 - gt, self.beta)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, self.alpha) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(
            pred, self.alpha) * neg_weights * neg_inds

        # pos_inds is already float; no extra cast needed before summing.
        num_pos = pos_inds.sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        # Avoid dividing by zero when the batch contains no positives.
        if num_pos == 0:
            return -neg_loss
        return -(pos_loss + neg_loss) / num_pos