# Source code for mmpose.models.keypoint_heads.multilabel_classification_head

import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from mmpose.models.builder import build_loss
from ..registry import HEADS


@HEADS.register_module()
class MultilabelClassificationHead(nn.Module):
    """Multi-label classification head.

    A stack of fully-connected layers followed by an element-wise sigmoid,
    so every label gets its own independent score in ``(0, 1)``.

    Paper ref: Gyeongsik Moon. "InterHand2.6M: A Dataset and Baseline for
    3D Interacting Hand Pose Estimation from a Single RGB Image".

    Args:
        in_channels (int): Number of input channels
        num_labels (int): Number of labels
        hidden_dims (list|tuple): Number of hidden dimension of FC layers.
        loss_classification (dict): Config for classification loss.
            Default: None.
        train_cfg (dict|None): Training config, stored on the head.
            ``None`` is replaced by an empty dict. Default: None.
        test_cfg (dict|None): Testing config, stored on the head.
            ``None`` is replaced by an empty dict. Default: None.
    """

    def __init__(self,
                 in_channels=2048,
                 num_labels=2,
                 hidden_dims=(512, ),
                 loss_classification=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        self.loss = build_loss(loss_classification)

        self.in_channels = in_channels
        self.num_labels = num_labels
        # Historical misspelling, kept as an alias so that existing code
        # reading ``head.num_labesl`` keeps working.
        self.num_labesl = num_labels

        self.train_cfg = {} if train_cfg is None else train_cfg
        self.test_cfg = {} if test_cfg is None else test_cfg

        feature_dims = [in_channels, *hidden_dims, num_labels]
        # No ReLU after the last layer: ``forward`` applies a sigmoid, and
        # sigmoid(ReLU(x)) >= 0.5 would make a negative prediction
        # impossible.
        self.fc = self._make_linear_layers(feature_dims, relu_final=False)

    def _make_linear_layers(self, feat_dims, relu_final=True):
        """Make linear layers.

        Args:
            feat_dims (list[int]): Layer widths; consecutive entries are the
                input/output sizes of each ``nn.Linear``.
            relu_final (bool): Whether a ReLU also follows the last linear
                layer. Every earlier linear layer is always followed by one.

        Returns:
            nn.Sequential: The stacked layers.
        """
        layers = []
        last_idx = len(feat_dims) - 2
        for i, (dim_in, dim_out) in enumerate(
                zip(feat_dims[:-1], feat_dims[1:])):
            layers.append(nn.Linear(dim_in, dim_out))
            if i < last_idx or relu_final:
                layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input features; the last dimension is fed to
                the first linear layer.

        Returns:
            torch.Tensor: Sigmoid of the FC output, one score per label.
        """
        labels = torch.sigmoid(self.fc(x))
        return labels

    def get_loss(self, output, target, target_weight):
        """Calculate the classification loss.

        Note:
            batch_size: N
            number labels: L

        Args:
            output (torch.Tensor[N, L]): Predicted label scores.
            target (torch.Tensor[N, L]): Target labels. Must be 2-D.
            target_weight (torch.Tensor[N, L]): Weights passed through to
                the loss. Must be 2-D.

        Returns:
            dict: ``{'classification_loss': loss}`` where ``loss`` is
            whatever the configured loss module returns.
        """
        losses = dict()
        assert not isinstance(self.loss, nn.Sequential)
        assert target.dim() == 2 and target_weight.dim() == 2
        losses['classification_loss'] = self.loss(output, target,
                                                  target_weight)
        return losses

    def get_accuracy(self, output, target, target_weight):
        """Calculate accuracy for classification.

        A sample counts as correct only when every one of its labels lies on
        the same side of 0.5 as its target.

        Note:
            batch size: N
            number labels: L

        Args:
            output (torch.Tensor[N, L]): Predicted label scores.
            target (torch.Tensor[N, L]): Target labels
                (presumably 0/1 - confirm against the dataset).
            target_weight (torch.Tensor[N, L]): Weights across different
                labels.

        Returns:
            dict: ``{'acc_classification': acc}``.
        """
        accuracy = dict()
        # only calculate accuracy on the samples with ground truth labels,
        # i.e. samples whose weights are all positive
        valid = (target_weight > 0).min(dim=1)[0]
        output, target = output[valid], target[valid]

        if output.shape[0] == 0:
            # when no samples with gt labels, set acc to 0.
            acc = output.new_zeros(1)
        else:
            # the product is positive iff prediction and target agree
            # about which side of 0.5 the label is on
            agreement = (output - 0.5) * (target - 0.5)
            acc = (agreement.min(dim=1)[0] > 0).float().mean()
        accuracy['acc_classification'] = acc
        return accuracy

    def inference_model(self, x, flip_pairs=None):
        """Inference function.

        Returns:
            output_labels (np.ndarray): Output labels.

        Args:
            x (torch.Tensor[NxC]): Input features vector.
            flip_pairs (None | list[tuple()]):
                Pairs of labels which are mirrored. When given, the two
                label columns of each pair are swapped in the result.
        """
        labels = self.forward(x).detach().cpu().numpy()

        if flip_pairs is not None:
            # read from the untouched array so overlapping pairs do not
            # see already-swapped values
            labels_flipped_back = labels.copy()
            for left, right in flip_pairs:
                labels_flipped_back[:, left, ...] = labels[:, right, ...]
                labels_flipped_back[:, right, ...] = labels[:, left, ...]
            return labels_flipped_back

        return labels

    def decode(self, img_metas, output, **kwargs):
        """Wrap the predicted labels into the result dict.

        No post-processing is done; ``img_metas`` and extra keyword
        arguments are accepted but unused.

        Args:
            img_metas (list(dict)): Information about data augmentation
                By default this includes:
                - "image_file": path to the image file
            output (np.ndarray[N, L]): model predicted labels.

        Returns:
            dict: ``{'labels': output}``.
        """
        return dict(labels=output)

    def init_weights(self):
        """Initialize every linear layer with N(0, 0.01) weights, zero bias."""
        for m in self.fc.modules():
            if isinstance(m, nn.Linear):
                normal_init(m, mean=0, std=0.01, bias=0)