# Source code for mmpose.models.detectors.top_down

import math
import warnings

import cv2
import mmcv
import numpy as np
from mmcv.image import imwrite
from mmcv.visualization.image import imshow

from .. import builder
from ..registry import POSENETS
from .base import BasePose

try:
    from mmcv.runner import auto_fp16
except ImportError:
    # Older mmcv releases do not provide ``auto_fp16``; fall back to the copy
    # bundled with mmpose (slated for removal, per the message below).
    warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0. '
                  'Please install mmcv>=1.1.4')
    from mmpose.core import auto_fp16


@POSENETS.register_module()
class TopDown(BasePose):
    """Top-down pose detectors.

    Args:
        backbone (dict): Backbone modules to extract feature.
        neck (dict): Optional neck module applied to the backbone output.
            Default: None.
        keypoint_head (dict): Keypoint head to process feature.
            Default: None.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained models.
        loss_pose (None): Deprecated arguments. Please use
            `loss_keypoint` for heads instead.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 keypoint_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_pose=None):
        super().__init__()
        self.fp16_enabled = False

        self.backbone = builder.build_backbone(backbone)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if keypoint_head is not None:
            # The head is built with the detector-level train/test configs.
            # NOTE(review): this writes into the caller's ``keypoint_head``
            # dict in place.
            keypoint_head['train_cfg'] = train_cfg
            keypoint_head['test_cfg'] = test_cfg

            # Backward compatibility: forward the deprecated ``loss_pose``
            # to the head only when the head config has no loss of its own.
            if 'loss_keypoint' not in keypoint_head and loss_pose is not None:
                warnings.warn(
                    '`loss_pose` for TopDown is deprecated, '
                    'use `loss_keypoint` for heads instead. See '
                    'https://github.com/open-mmlab/mmpose/pull/382'
                    ' for more information.', DeprecationWarning)
                keypoint_head['loss_keypoint'] = loss_pose

            self.keypoint_head = builder.build_head(keypoint_head)

        self.init_weights(pretrained=pretrained)

    @property
    def with_neck(self):
        """Check if has neck."""
        return hasattr(self, 'neck')

    @property
    def with_keypoint(self):
        """Check if has keypoint_head."""
        return hasattr(self, 'keypoint_head')

    def init_weights(self, pretrained=None):
        """Weight initialization for model.

        Args:
            pretrained (str, optional): Path to pretrained weights; passed to
                the backbone only. Neck and head use their own defaults.
        """
        self.backbone.init_weights(pretrained)
        if self.with_neck:
            self.neck.init_weights()
        if self.with_keypoint:
            self.keypoint_head.init_weights()

    @auto_fp16(apply_to=('img', ))
    def forward(self,
                img,
                target=None,
                target_weight=None,
                img_metas=None,
                return_loss=True,
                return_heatmap=False,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        In both modes ``img`` is a single batched Tensor and ``img_metas`` is
        a list with one dict per image (``forward_test`` indexes ``img`` as a
        Tensor and checks ``img.size(0) == len(img_metas)``).

        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C (Default: 3)
            img height: imgH
            img width: imgW
            heatmaps height: H
            heatmaps width: W

        Args:
            img (torch.Tensor[NxCximgHximgW]): Input images.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]): Weights across
                different joint types.
            img_metas (list(dict)): Information about data augmentation.
                By default this includes:
                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
                At test time ``forward_test`` additionally reads
                "bbox_id" (when N > 1) and "flip_pairs" (when flip test
                is enabled).
            return_loss (bool): Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.
            return_heatmap (bool) : Option to return heatmap.

        Returns:
            dict: If `return_loss` is true, the losses (and accuracy) from
                ``forward_train``. Otherwise the test result dict from
                ``forward_test`` (decoded poses and, optionally, heatmaps).
        """
        if return_loss:
            return self.forward_train(img, target, target_weight, img_metas,
                                      **kwargs)
        return self.forward_test(
            img, img_metas, return_heatmap=return_heatmap, **kwargs)

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training.

        Returns:
            dict: Keypoint losses and accuracy reported by the keypoint head;
                empty when the model has no keypoint head.
        """
        output = self.backbone(img)
        if self.with_neck:
            output = self.neck(output)

        losses = dict()
        if self.with_keypoint:
            output = self.keypoint_head(output)
            keypoint_losses = self.keypoint_head.get_loss(
                output, target, target_weight)
            losses.update(keypoint_losses)
            keypoint_accuracy = self.keypoint_head.get_accuracy(
                output, target, target_weight)
            losses.update(keypoint_accuracy)

        return losses

    def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
        """Defines the computation performed at every call when testing.

        Args:
            img (torch.Tensor[NxCximgHximgW]): Input images.
            img_metas (list(dict)): One dict per image. It must contain
                'bbox_id' when N > 1, and the first entry must contain
                'flip_pairs' when flip test is enabled.
            return_heatmap (bool): Whether to keep the predicted heatmap in
                the result.

        Returns:
            dict: Decoded keypoint results plus 'output_heatmap' (None unless
                ``return_heatmap`` is True). Empty when the model has no
                keypoint head.
        """
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            assert 'bbox_id' in img_metas[0]

        result = {}

        features = self.backbone(img)
        if self.with_neck:
            features = self.neck(features)

        if self.with_keypoint:
            output_heatmap = self.keypoint_head.inference_model(
                features, flip_pairs=None)

            # ``test_cfg`` defaults to None in ``__init__``; treat that as an
            # empty config (flip test on by default) instead of crashing.
            test_cfg = self.test_cfg or {}
            if test_cfg.get('flip_test', True):
                # Average the prediction with that of the horizontally
                # flipped image (dim 3 is image width).
                img_flipped = img.flip(3)
                features_flipped = self.backbone(img_flipped)
                if self.with_neck:
                    features_flipped = self.neck(features_flipped)
                output_flipped_heatmap = self.keypoint_head.inference_model(
                    features_flipped, img_metas[0]['flip_pairs'])
                output_heatmap = (output_heatmap +
                                  output_flipped_heatmap) * 0.5

            keypoint_result = self.keypoint_head.decode(
                img_metas, output_heatmap, img_size=[img_width, img_height])
            result.update(keypoint_result)

            if not return_heatmap:
                output_heatmap = None

            result['output_heatmap'] = output_heatmap

        return result

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            Tensor: Output heatmaps.
        """
        output = self.backbone(img)
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)
        return output

    def show_result(self,
                    img,
                    result,
                    skeleton=None,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    pose_kpt_color=None,
                    pose_limb_color=None,
                    text_color=(255, 0, 0),
                    radius=4,
                    thickness=1,
                    font_scale=0.5,
                    win_name='',
                    show=False,
                    show_keypoint_weight=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or np.ndarray): The image, or a path to it. It is loaded
                with ``mmcv.imread`` and copied before drawing.
            result (list[dict]): The results to draw over `img`. Each dict
                must provide 'bbox' and 'keypoints'. The keypoints are an
                array whose rows are (x, y, score).
            skeleton (list[list]): The connection of keypoints, given as
                pairs of 1-based keypoint indices.
            kpt_score_thr (float, optional): Minimum score of keypoints
                to be shown. Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
                If None, do not draw keypoints.
            pose_limb_color (np.array[Mx3]): Color of M limbs.
                If None, do not draw limbs.
            text_color (str or tuple or :obj:`Color`): Color of texts.
                Currently unused.
            radius (int): Radius of circles.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts. Currently unused.
            win_name (str): The window name.
            show (bool): Whether to show the image. Default: False.
            show_keypoint_weight (bool): Whether to change the transparency
                using the predicted confidence scores of keypoints.
            wait_time (int): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            np.ndarray: The visualized image. It is returned regardless of
                ``show`` and ``out_file``.
        """
        img = mmcv.imread(img)
        img = img.copy()
        img_h, img_w, _ = img.shape

        bbox_result = [res['bbox'] for res in result]
        pose_result = [res['keypoints'] for res in result]

        if len(bbox_result) > 0:
            bboxes = np.vstack(bbox_result)
            # draw bounding boxes
            mmcv.imshow_bboxes(
                img,
                bboxes,
                colors=bbox_color,
                top_k=-1,
                thickness=thickness,
                show=False,
                win_name=win_name,
                wait_time=wait_time,
                out_file=None)

            for kpts in pose_result:
                # draw each point on image
                if pose_kpt_color is not None:
                    assert len(pose_kpt_color) == len(kpts)
                    for kid, kpt in enumerate(kpts):
                        center = (int(kpt[0]), int(kpt[1]))
                        kpt_score = kpt[2]
                        if kpt_score > kpt_score_thr:
                            r, g, b = pose_kpt_color[kid]
                            color = (int(r), int(g), int(b))
                            if show_keypoint_weight:
                                # Draw on a copy, then blend it back with
                                # opacity equal to the clipped score.
                                img_copy = img.copy()
                                cv2.circle(img_copy, center, radius, color,
                                           -1)
                                transparency = max(0, min(1, kpt_score))
                                cv2.addWeighted(
                                    img_copy,
                                    transparency,
                                    img,
                                    1 - transparency,
                                    0,
                                    dst=img)
                            else:
                                cv2.circle(img, center, radius, color, -1)

                # draw limbs
                if skeleton is not None and pose_limb_color is not None:
                    assert len(pose_limb_color) == len(skeleton)
                    for sk_id, sk in enumerate(skeleton):
                        # Skeleton entries are 1-based keypoint indices.
                        idx1, idx2 = sk[0] - 1, sk[1] - 1
                        pos1 = (int(kpts[idx1, 0]), int(kpts[idx1, 1]))
                        pos2 = (int(kpts[idx2, 0]), int(kpts[idx2, 1]))
                        both_inside = (0 < pos1[0] < img_w
                                       and 0 < pos1[1] < img_h
                                       and 0 < pos2[0] < img_w
                                       and 0 < pos2[1] < img_h)
                        if (both_inside and kpts[idx1, 2] > kpt_score_thr
                                and kpts[idx2, 2] > kpt_score_thr):
                            r, g, b = pose_limb_color[sk_id]
                            color = (int(r), int(g), int(b))
                            if show_keypoint_weight:
                                # Draw the limb as a thin filled ellipse on a
                                # copy, then blend it back with opacity equal
                                # to the mean score of its two endpoints.
                                img_copy = img.copy()
                                xs = (pos1[0], pos2[0])
                                ys = (pos1[1], pos2[1])
                                mean_x = np.mean(xs)
                                mean_y = np.mean(ys)
                                length = ((ys[0] - ys[1])**2 +
                                          (xs[0] - xs[1])**2)**0.5
                                angle = math.degrees(
                                    math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
                                stickwidth = 2
                                polygon = cv2.ellipse2Poly(
                                    (int(mean_x), int(mean_y)),
                                    (int(length / 2), int(stickwidth)),
                                    int(angle), 0, 360, 1)
                                cv2.fillConvexPoly(img_copy, polygon, color)
                                transparency = max(
                                    0,
                                    min(
                                        1, 0.5 *
                                        (kpts[idx1, 2] + kpts[idx2, 2])))
                                cv2.addWeighted(
                                    img_copy,
                                    transparency,
                                    img,
                                    1 - transparency,
                                    0,
                                    dst=img)
                            else:
                                cv2.line(
                                    img, pos1, pos2, color,
                                    thickness=thickness)

        if show:
            imshow(img, win_name, wait_time)

        if out_file is not None:
            imwrite(img, out_file)

        return img