# Source code for mmpose.models.losses.classfication_loss

import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Wraps :func:`torch.nn.functional.binary_cross_entropy`. When
    ``use_target_weight`` is enabled, the element-wise loss is multiplied by a
    caller-supplied weight tensor before averaging. The final loss is always
    scaled by ``loss_weight``.

    Args:
        use_target_weight (bool): If True, weight the element-wise loss by
            ``target_weight`` before taking the mean. Default: False.
        loss_weight (float): Scalar multiplier applied to the final loss.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            batch_size: N
            num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification. Passed
                straight to ``F.binary_cross_entropy``, which expects
                probabilities rather than logits (no sigmoid is applied here).
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N], optional):
                Weights across different labels. Only used (and then required)
                when ``use_target_weight`` is True. Default: None.

        Returns:
            torch.Tensor: The mean loss multiplied by ``loss_weight``.

        Raises:
            ValueError: If ``use_target_weight`` is True but no
                ``target_weight`` was given.
        """
        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight must be provided when '
                    'use_target_weight=True')
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                # Per-sample weights of shape [N] become [N, 1] so they
                # broadcast across the K label columns of the loss.
                target_weight = target_weight[:, None]
            loss = (loss * target_weight).mean()
        else:
            # Default 'mean' reduction of binary_cross_entropy.
            loss = self.criterion(output, target)
        return loss * self.loss_weight