Shortcuts

mmpose.models.losses.heatmap_loss 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ..builder import LOSSES


@LOSSES.register_module()
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss.

    paper ref: 'Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to multiply both the predicted and
            the target heatmaps by ``target_weight`` before the loss is
            computed. Different joint types may have different target
            weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        # Hyper-parameters are stored as Python floats so that integer
        # arguments (e.g. omega=14) do not lead to integer arithmetic.
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of wingloss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: Scalar tensor, the mean of the element-wise loss.
        """
        delta = (target - pred).abs()

        # Per-element exponent (alpha - y); it is shared by every term below,
        # so compute it once.
        exponent = self.alpha - target
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        # A is the slope of the logarithmic piece evaluated at delta == theta
        # and C is the offset that makes the linear piece meet the
        # logarithmic piece at delta == theta.
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * (torch.pow(
            ratio, exponent - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)

        losses = torch.where(
            delta < self.theta,
            # Logarithmic piece for small errors.
            self.omega *
            torch.log(1 + torch.pow(delta / self.epsilon, exponent)),
            # Linear piece for large errors.
            A * delta - C)

        return torch.mean(losses)

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[NxKxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1], optional):
                Weights across different joint types. Only read when
                ``use_target_weight`` is True. Trailing singleton dimensions
                are appended until its rank matches ``output``.
                Default: None.

        Returns:
            torch.Tensor: Scalar loss scaled by ``loss_weight``.

        Raises:
            ValueError: If ``use_target_weight`` is True but
                ``target_weight`` is None.
        """
        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight must be provided when '
                    'use_target_weight=True')
            # Append singleton dims so the weights broadcast over the
            # spatial axes (NxKx1 -> NxKx1x1 for 4-D heatmaps).
            weight = target_weight
            while weight.dim() < output.dim():
                weight = weight.unsqueeze(-1)
            loss = self.criterion(output * weight, target * weight)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
Read the Docs v: stable
Versions
latest
stable
cn_doc
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.