import math
import warnings
import cv2
import mmcv
import numpy as np
from mmcv.image import imwrite
from mmcv.visualization.image import imshow
from .. import builder
from ..registry import POSENETS
from .base import BasePose
try:
    from mmcv.runner import auto_fp16
except ImportError:
    # `mmcv.runner` does not provide `auto_fp16` in this environment; fall
    # back to the copy shipped with mmpose and tell the user how to upgrade.
    warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0. '
                  'Please install mmcv>=1.1.4')
    from mmpose.core import auto_fp16
@POSENETS.register_module()
class TopDown(BasePose):
    """Top-down pose detectors.

    Args:
        backbone (dict): Backbone modules to extract feature.
        neck (dict): Neck module applied to the backbone output.
            Default: None.
        keypoint_head (dict): Keypoint head to process feature.
            Default: None.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained models.
        loss_pose (None): Deprecated arguments. Please use
            `loss_keypoint` for heads instead.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 keypoint_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_pose=None):
        super().__init__()
        self.fp16_enabled = False

        self.backbone = builder.build_backbone(backbone)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # `neck` and `keypoint_head` are only set as attributes when
        # configured; `with_neck` / `with_keypoint` rely on their presence.
        if neck is not None:
            self.neck = builder.build_neck(neck)

        if keypoint_head is not None:
            # NOTE: the caller's `keypoint_head` dict is updated in place.
            keypoint_head['train_cfg'] = train_cfg
            keypoint_head['test_cfg'] = test_cfg

            # Backward compatibility: forward the deprecated `loss_pose`
            # to the head unless the head already defines its own loss.
            if 'loss_keypoint' not in keypoint_head and loss_pose is not None:
                warnings.warn(
                    '`loss_pose` for TopDown is deprecated, '
                    'use `loss_keypoint` for heads instead. See '
                    'https://github.com/open-mmlab/mmpose/pull/382'
                    ' for more information.', DeprecationWarning)
                keypoint_head['loss_keypoint'] = loss_pose

            self.keypoint_head = builder.build_head(keypoint_head)

        self.init_weights(pretrained=pretrained)

    @property
    def with_neck(self):
        """Check if has neck."""
        return hasattr(self, 'neck')

    @property
    def with_keypoint(self):
        """Check if has keypoint_head."""
        return hasattr(self, 'keypoint_head')

    def init_weights(self, pretrained=None):
        """Weight initialization for model.

        Args:
            pretrained (str, optional): Passed to the backbone's
                ``init_weights``. The neck and the keypoint head (when
                present) are initialized without arguments.
        """
        self.backbone.init_weights(pretrained)
        if self.with_neck:
            self.neck.init_weights()
        if self.with_keypoint:
            self.keypoint_head.init_weights()

    @auto_fp16(apply_to=('img', ))
    def forward(self,
                img,
                target=None,
                target_weight=None,
                img_metas=None,
                return_loss=True,
                return_heatmap=False,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        In both modes `img` is a single batched tensor and `img_metas` has
        one entry per image: `forward_test` checks
        ``img.size(0) == len(img_metas)``.

        Note:
            batch_size: N
            num_keypoints: K
            num_img_channel: C (Default: 3)
            img height: imgH
            img width: imgW
            heatmaps height: H
            heatmaps width: W

        Args:
            img (torch.Tensor[NxCximgHximgW]): Input images.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]): Weights across
                different joint types.
            img_metas (list(dict)): Information about data augmentation
                By default this includes:
                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            return_loss (bool): Option to `return loss`. `return loss=True`
                for training, `return loss=False` for validation & test.
            return_heatmap (bool) : Option to return heatmap.

        Returns:
            dict: if `return_loss` is true, the losses (and accuracy)
                returned by `forward_train`. Otherwise the result dict of
                `forward_test`: whatever the keypoint head's ``decode``
                produces, plus the key ``'output_heatmap'``.
        """
        if return_loss:
            return self.forward_train(img, target, target_weight, img_metas,
                                      **kwargs)
        return self.forward_test(
            img, img_metas, return_heatmap=return_heatmap, **kwargs)

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training.

        Runs backbone -> neck (if any) -> keypoint head (if any) and
        collects the head's loss and accuracy dicts.

        Returns:
            dict: Merged output of the head's ``get_loss`` and
                ``get_accuracy``; empty when there is no keypoint head.
        """
        output = self.backbone(img)
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)

        losses = {}
        if self.with_keypoint:
            keypoint_losses = self.keypoint_head.get_loss(
                output, target, target_weight)
            losses.update(keypoint_losses)
            keypoint_accuracy = self.keypoint_head.get_accuracy(
                output, target, target_weight)
            losses.update(keypoint_accuracy)

        return losses

    def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
        """Defines the computation performed at every call when testing.

        Args:
            img (torch.Tensor): Batched input images; the last two
                dimensions are read as (height, width).
            img_metas (list(dict)): One meta dict per image. For batches
                larger than one, ``'bbox_id'`` must be present in the first
                entry. ``img_metas[0]['flip_pairs']`` is read when flip
                testing is enabled and a keypoint head exists.
            return_heatmap (bool): If False, ``result['output_heatmap']``
                is set to None.

        Returns:
            dict: Output of the keypoint head's ``decode`` plus the key
                ``'output_heatmap'``; empty when there is no keypoint head.
        """
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            assert 'bbox_id' in img_metas[0]

        # `test_cfg` defaults to None in the constructor; treat that as an
        # empty config instead of failing on `None.get(...)`.
        test_cfg = {} if self.test_cfg is None else self.test_cfg

        result = {}

        features = self.backbone(img)
        if self.with_neck:
            features = self.neck(features)
        if self.with_keypoint:
            output_heatmap = self.keypoint_head.inference_model(
                features, flip_pairs=None)

        # Flip testing is on unless the config explicitly disables it.
        if test_cfg.get('flip_test', True):
            # Flip along the last (width) axis and average both predictions.
            img_flipped = img.flip(3)
            features_flipped = self.backbone(img_flipped)
            if self.with_neck:
                features_flipped = self.neck(features_flipped)
            if self.with_keypoint:
                output_flipped_heatmap = self.keypoint_head.inference_model(
                    features_flipped, img_metas[0]['flip_pairs'])
                output_heatmap = (output_heatmap +
                                  output_flipped_heatmap) * 0.5

        if self.with_keypoint:
            keypoint_result = self.keypoint_head.decode(
                img_metas, output_heatmap, img_size=[img_width, img_height])
            result.update(keypoint_result)

            if not return_heatmap:
                output_heatmap = None

            result['output_heatmap'] = output_heatmap

        return result

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            Tensor: Output heatmaps.
        """
        output = self.backbone(img)
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)
        return output

    @staticmethod
    def _blend_overlay(img, overlay, alpha):
        """Blend `overlay` into `img` in place, clamping `alpha` to [0, 1]."""
        alpha = max(0, min(1, alpha))
        cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, dst=img)

    @staticmethod
    def _draw_keypoints(img, kpts, pose_kpt_color, kpt_score_thr, radius,
                        show_keypoint_weight):
        """Draw a filled circle on `img` for every keypoint above threshold."""
        assert len(pose_kpt_color) == len(kpts)
        for kid, kpt in enumerate(kpts):
            center = (int(kpt[0]), int(kpt[1]))
            kpt_score = kpt[2]
            if not kpt_score > kpt_score_thr:
                continue

            r, g, b = pose_kpt_color[kid]
            color = (int(r), int(g), int(b))
            if show_keypoint_weight:
                # Draw on a copy, then blend it in so that the opacity of
                # the circle follows the keypoint score.
                img_copy = img.copy()
                cv2.circle(img_copy, center, radius, color, -1)
                TopDown._blend_overlay(img, img_copy, kpt_score)
            else:
                cv2.circle(img, center, radius, color, -1)

    @staticmethod
    def _draw_limbs(img, kpts, skeleton, pose_limb_color, kpt_score_thr,
                    thickness, show_keypoint_weight, img_w, img_h):
        """Draw the skeleton links whose two endpoints are both drawable."""
        assert len(pose_limb_color) == len(skeleton)
        for sk_id, sk in enumerate(skeleton):
            # Skeleton entries are shifted down by one before they are used
            # as row indices into `kpts`.
            start, end = sk[0] - 1, sk[1] - 1
            pos1 = (int(kpts[start, 0]), int(kpts[start, 1]))
            pos2 = (int(kpts[end, 0]), int(kpts[end, 1]))

            # Both endpoints must lie strictly inside the image and exceed
            # the score threshold.
            drawable = (0 < pos1[0] < img_w and 0 < pos1[1] < img_h
                        and 0 < pos2[0] < img_w and 0 < pos2[1] < img_h
                        and kpts[start, 2] > kpt_score_thr
                        and kpts[end, 2] > kpt_score_thr)
            if not drawable:
                continue

            r, g, b = pose_limb_color[sk_id]
            color = (int(r), int(g), int(b))
            if not show_keypoint_weight:
                cv2.line(img, pos1, pos2, color, thickness=thickness)
                continue

            # Weighted mode: render the limb as a thin filled ellipse on a
            # copy and blend it in with the mean score of both endpoints.
            img_copy = img.copy()
            xs = (pos1[0], pos2[0])
            ys = (pos1[1], pos2[1])
            mid_x = np.mean(xs)
            mid_y = np.mean(ys)
            length = ((ys[0] - ys[1])**2 + (xs[0] - xs[1])**2)**0.5
            angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
            stickwidth = 2
            polygon = cv2.ellipse2Poly(
                (int(mid_x), int(mid_y)),
                (int(length / 2), int(stickwidth)), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(img_copy, polygon, color)
            TopDown._blend_overlay(
                img, img_copy, 0.5 * (kpts[start, 2] + kpts[end, 2]))

    def show_result(self,
                    img,
                    result,
                    skeleton=None,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    pose_kpt_color=None,
                    pose_limb_color=None,
                    text_color=(255, 0, 0),
                    radius=4,
                    thickness=1,
                    font_scale=0.5,
                    win_name='',
                    show=False,
                    show_keypoint_weight=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or ndarray): The image to be displayed; it is loaded
                with `mmcv.imread` and drawn on a copy.
            result (list[dict]): The results to draw over `img`. Each item
                must provide the keys ``'bbox'`` and ``'keypoints'``.
            skeleton (list[list]): The connection of keypoints. Each index
                is decremented by one before it is used to index the
                keypoints.
            kpt_score_thr (float, optional): Minimum score of keypoints
                to be shown. Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            pose_kpt_color (np.array[Nx3]): Color of N keypoints.
                If None, do not draw keypoints.
            pose_limb_color (np.array[Mx3]): Color of M limbs.
                If None, do not draw limbs.
            text_color (str or tuple or :obj:`Color`): Color of texts.
                Not used by this method.
            radius (int): Radius of circles.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts. Not used by this
                method.
            win_name (str): The window name.
            show (bool): Whether to show the image. Default: False.
            show_keypoint_weight (bool): Whether to change the transparency
                using the predicted confidence scores of keypoints.
            wait_time (int): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            ndarray: Visualized image. It is always returned, also when
                `show` or `out_file` is given.
        """
        img = mmcv.imread(img)
        img = img.copy()
        img_h, img_w, _ = img.shape

        bbox_result = []
        pose_result = []
        for res in result:
            bbox_result.append(res['bbox'])
            pose_result.append(res['keypoints'])

        if bbox_result:
            bboxes = np.vstack(bbox_result)
            # Draw bounding boxes onto `img` without opening a window or
            # writing a file; that is handled at the end of this method.
            mmcv.imshow_bboxes(
                img,
                bboxes,
                colors=bbox_color,
                top_k=-1,
                thickness=thickness,
                show=False,
                win_name=win_name,
                wait_time=wait_time,
                out_file=None)

        for kpts in pose_result:
            if pose_kpt_color is not None:
                self._draw_keypoints(img, kpts, pose_kpt_color,
                                     kpt_score_thr, radius,
                                     show_keypoint_weight)

            if skeleton is not None and pose_limb_color is not None:
                self._draw_limbs(img, kpts, skeleton, pose_limb_color,
                                 kpt_score_thr, thickness,
                                 show_keypoint_weight, img_w, img_h)

        if show:
            imshow(img, win_name, wait_time)

        if out_file is not None:
            imwrite(img, out_file)

        return img