Shortcuts

Source code for mmpose.models.losses.classfication_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Args:
        use_target_weight (bool): If True, ``forward`` multiplies the
            element-wise loss by ``target_weight`` before averaging.
            Default: False.
        loss_weight (float): Scalar the final loss is multiplied by.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        # Functional form is used so that ``reduction`` can be chosen per
        # call in ``forward``.
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification. It is passed
                unchanged to ``F.binary_cross_entropy``, so it must already
                hold probabilities rather than raw logits.
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Only read when
                ``use_target_weight`` is True; ignored otherwise.

        Returns:
            torch.Tensor: Scalar loss multiplied by ``loss_weight``.

        Raises:
            AssertionError: If ``use_target_weight`` is True and
                ``target_weight`` is None.
        """
        if self.use_target_weight:
            assert target_weight is not None, (
                'target_weight must be provided when use_target_weight=True')
            # Keep the per-element loss so each entry can be weighted.
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                # Per-sample weights [N] -> [N, 1] so they broadcast over
                # the K labels.
                target_weight = target_weight[:, None]
            # Plain mean over every element; the result is not normalised
            # by the sum of the weights.
            loss = (loss * target_weight).mean()
        else:
            # Default reduction of the criterion (mean over all elements).
            loss = self.criterion(output, target)

        return loss * self.loss_weight
Read the Docs v: latest
Versions
latest
1.x
v0.14.0
fix-doc
cn_doc
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.