Shortcuts

Source code for mmpose.datasets.datasets.body3d.body3d_h36m_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict, defaultdict

import mmcv
import numpy as np
from mmcv import Config

from mmpose.core.evaluation import keypoint_mpjpe
from mmpose.datasets.datasets.base import Kpt3dSviewKpt2dDataset
from ...builder import DATASETS


@DATASETS.register_module()
class Body3DH36MDataset(Kpt3dSviewKpt2dDataset):
    """Human3.6M dataset for 3D human pose estimation.

    "Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human
    Sensing in Natural Environments", TPAMI`2014.
    More details can be found in the `paper
    <http://vision.imar.ro/human3.6m/pami-h36m.pdf>`__.

    Human3.6M keypoint indexes::

        0: 'root (pelvis)',
        1: 'right_hip',
        2: 'right_knee',
        3: 'right_foot',
        4: 'left_hip',
        5: 'left_knee',
        6: 'left_foot',
        7: 'spine',
        8: 'thorax',
        9: 'neck_base',
        10: 'head',
        11: 'left_shoulder',
        12: 'left_elbow',
        13: 'left_wrist',
        14: 'right_shoulder',
        15: 'right_elbow',
        16: 'right_wrist'

    Args:
        ann_file (str): Path to the annotation file.
        img_prefix (str): Path to a directory where images are held.
        data_cfg (dict): Data configuration. Besides the options consumed by
            the base class, this dataset reads ``joint_2d_src`` (one of
            ``SUPPORTED_JOINT_2D_SRC``, default ``'gt'``),
            ``joint_2d_det_file``, ``need_camera_param`` (default False),
            ``camera_param_file``, ``actions`` and ``subjects`` (both
            default ``'_all_'``).
        pipeline (list[dict | callable]): A sequence of data transforms.
        dataset_info (DatasetInfo): A class containing all dataset info.
            If None, it is loaded from ``configs/_base_/datasets/h36m.py``
            and a DeprecationWarning is emitted.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
    """

    JOINT_NAMES = [
        'Root', 'RHip', 'RKnee', 'RFoot', 'LHip', 'LKnee', 'LFoot', 'Spine',
        'Thorax', 'NeckBase', 'Head', 'LShoulder', 'LElbow', 'LWrist',
        'RShoulder', 'RElbow', 'RWrist'
    ]

    # 2D joint source options:
    # "gt": from the annotation file
    # "detection": from a detection result file of 2D keypoint
    # "pipeline": will be generated by the pipeline
    SUPPORTED_JOINT_2D_SRC = {'gt', 'detection', 'pipeline'}

    # metric
    ALLOWED_METRICS = {'mpjpe', 'p-mpjpe', 'n-mpjpe'}

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 dataset_info=None,
                 test_mode=False):

        if dataset_info is None:
            warnings.warn(
                'dataset_info is missing. '
                'Check https://github.com/open-mmlab/mmpose/pull/663 '
                'for details.', DeprecationWarning)
            # Fall back to the default H36M dataset description, resolved
            # relative to the current working directory.
            cfg = Config.fromfile('configs/_base_/datasets/h36m.py')
            dataset_info = cfg._cfg_dict['dataset_info']

        super().__init__(
            ann_file,
            img_prefix,
            data_cfg,
            pipeline,
            dataset_info=dataset_info,
            test_mode=test_mode)

    def load_config(self, data_cfg):
        """Read H36M-specific options from ``data_cfg`` on top of the base
        configuration.

        Raises:
            ValueError: If ``joint_2d_src`` is not one of
                ``SUPPORTED_JOINT_2D_SRC``.
        """
        super().load_config(data_cfg)
        # h36m specific attributes
        self.joint_2d_src = data_cfg.get('joint_2d_src', 'gt')
        if self.joint_2d_src not in self.SUPPORTED_JOINT_2D_SRC:
            raise ValueError(
                f'Unsupported joint_2d_src "{self.joint_2d_src}". '
                f'Supported options are {self.SUPPORTED_JOINT_2D_SRC}')

        self.joint_2d_det_file = data_cfg.get('joint_2d_det_file', None)

        self.need_camera_param = data_cfg.get('need_camera_param', False)
        if self.need_camera_param:
            assert 'camera_param_file' in data_cfg, (
                'camera_param_file is required when need_camera_param is '
                'True')
            self.camera_param = self._load_camera_param(
                data_cfg['camera_param_file'])

        # h36m specific annotation info
        ann_info = {}
        ann_info['use_different_joint_weights'] = False
        # action filter ('_all_' disables filtering)
        actions = data_cfg.get('actions', '_all_')
        self.actions = set(
            actions if isinstance(actions, (list, tuple)) else [actions])

        # subject filter ('_all_' disables filtering)
        subjects = data_cfg.get('subjects', '_all_')
        self.subjects = set(
            subjects if isinstance(subjects, (list, tuple)) else [subjects])

        self.ann_info.update(ann_info)

    def load_annotations(self):
        """Load annotations via the base class and select the 2D joint
        source according to ``self.joint_2d_src``.

        Returns:
            dict: The annotation dict from the base class, with
            ``'joints_2d'`` replaced by detection results when
            ``joint_2d_src == 'detection'``.

        Raises:
            ValueError: If ``joint_2d_src`` is ``'detection'`` but no
                ``joint_2d_det_file`` was configured.
        """
        data_info = super().load_annotations()

        # get 2D joints
        if self.joint_2d_src == 'gt':
            # Keep the 2D joints loaded from the annotation file (this also
            # surfaces a KeyError early if they are missing).
            data_info['joints_2d'] = data_info['joints_2d']
        elif self.joint_2d_src == 'detection':
            if self.joint_2d_det_file is None:
                raise ValueError(
                    'joint_2d_det_file must be specified when joint_2d_src '
                    'is "detection".')
            data_info['joints_2d'] = self._load_joint_2d_detection(
                self.joint_2d_det_file)
            # Detections must align one-to-one with the 3D annotations and
            # carry (x, y, score) per joint.
            assert data_info['joints_2d'].shape[0] == data_info[
                'joints_3d'].shape[0]
            assert data_info['joints_2d'].shape[2] == 3
        elif self.joint_2d_src == 'pipeline':
            # joint_2d will be generated in the pipeline
            pass
        else:
            raise NotImplementedError(
                f'Unhandled joint_2d_src option {self.joint_2d_src}')

        return data_info

    @staticmethod
    def _parse_h36m_imgname(imgname):
        """Parse imgname to get information of subject, action and camera.

        A typical h36m image filename is like:
        S1_Directions_1.54138969_000001.jpg

        Returns:
            tuple[str, str, str]: ``(subject, action, camera)``, e.g.
            ``('S1', 'Directions_1', '54138969')`` for the example above.
        """
        subj, rest = osp.basename(imgname).split('_', 1)
        action, rest = rest.split('.', 1)
        camera, rest = rest.split('_', 1)

        return subj, action, camera

    def build_sample_indices(self):
        """Split original videos into sequences and build frame indices.

        This method overrides the default one in the base class. Frames are
        grouped by ``(subject, action, camera)`` after applying the action and
        subject filters, then cut into sequences of ``seq_len`` frames spaced
        ``seq_frame_interval`` apart (optionally padded so every frame gets a
        sample). Finally a random contiguous ``subset`` fraction is kept.

        Returns:
            list[list[int]]: Frame indices of each sample sequence.
        """

        # Group frames into videos. Assume that self.data_info is
        # chronological.
        video_frames = defaultdict(list)
        for idx, imgname in enumerate(self.data_info['imgnames']):
            subj, action, camera = self._parse_h36m_imgname(imgname)

            if '_all_' not in self.actions and action not in self.actions:
                continue

            if '_all_' not in self.subjects and subj not in self.subjects:
                continue

            video_frames[(subj, action, camera)].append(idx)

        # build sample indices
        sample_indices = []
        _len = (self.seq_len - 1) * self.seq_frame_interval + 1
        _step = self.seq_frame_interval
        for _, _indices in sorted(video_frames.items()):
            n_frame = len(_indices)

            if self.temporal_padding:
                # Pad the sequence so that every frame in the sequence will be
                # predicted.
                if self.causal:
                    frames_left = self.seq_len - 1
                    frames_right = 0
                else:
                    frames_left = (self.seq_len - 1) // 2
                    frames_right = frames_left
                for i in range(n_frame):
                    # Number of out-of-range positions on each side, filled by
                    # repeating the first / last frame of the video.
                    pad_left = max(0, frames_left - i // _step)
                    pad_right = max(0,
                                    frames_right - (n_frame - 1 - i) // _step)
                    # Start/end are chosen so the slice stays on the same
                    # stride grid as frame i.
                    start = max(i % _step, i - frames_left * _step)
                    end = min(n_frame - (n_frame - 1 - i) % _step,
                              i + frames_right * _step + 1)
                    sample_indices.append([_indices[0]] * pad_left +
                                          _indices[start:end:_step] +
                                          [_indices[-1]] * pad_right)
            else:
                seqs_from_video = [
                    _indices[i:(i + _len):_step]
                    for i in range(0, n_frame - _len + 1)
                ]
                sample_indices.extend(seqs_from_video)

        # reduce dataset size if self.subset < 1
        assert 0 < self.subset <= 1
        subset_size = int(len(sample_indices) * self.subset)
        start = np.random.randint(0, len(sample_indices) - subset_size + 1)
        end = start + subset_size

        return sample_indices[start:end]

    def _load_joint_2d_detection(self, det_file):
        """Load 2D joint detection results from a ``.npy`` file as
        float32."""
        joints_2d = np.load(det_file).astype(np.float32)

        return joints_2d

    def evaluate(self,
                 outputs,
                 res_folder,
                 metric='mpjpe',
                 logger=None,
                 **kwargs):
        """Dump predictions to ``res_folder`` and compute 3D pose metrics.

        Args:
            outputs (list[dict]): Model outputs. Each dict provides
                ``'preds'`` (per-sample predicted 3D keypoints, indexed by
                batch position) and ``'target_image_paths'`` (image names of
                the target frames, mapped to annotation ids through
                ``self.name2id``).
            res_folder (str): Directory in which ``result_keypoints.json`` is
                written.
            metric (str | list[str]): Metric(s) to compute, from
                ``ALLOWED_METRICS``. Default: ``'mpjpe'``.
            logger: Not used by this implementation.

        Returns:
            OrderedDict: Overall error per metric (e.g. ``'MPJPE'``) followed
            by per-action-category errors (e.g. ``'MPJPE_Directions'``).

        Raises:
            ValueError: If an unsupported metric is requested.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        for _metric in metrics:
            if _metric not in self.ALLOWED_METRICS:
                raise ValueError(
                    f'Unsupported metric "{_metric}" for human3.6 dataset. '
                    f'Supported metrics are {self.ALLOWED_METRICS}')

        res_file = osp.join(res_folder, 'result_keypoints.json')
        kpts = []
        for output in outputs:
            preds = output['preds']
            image_paths = output['target_image_paths']
            batch_size = len(image_paths)
            for i in range(batch_size):
                target_id = self.name2id[image_paths[i]]
                kpts.append({
                    'keypoints': preds[i],
                    'target_id': target_id,
                })

        mmcv.dump(kpts, res_file)

        name_value_tuples = []
        for _metric in metrics:
            if _metric == 'mpjpe':
                _nv_tuples = self._report_mpjpe(kpts)
            elif _metric == 'p-mpjpe':
                _nv_tuples = self._report_mpjpe(kpts, mode='p-mpjpe')
            elif _metric == 'n-mpjpe':
                _nv_tuples = self._report_mpjpe(kpts, mode='n-mpjpe')
            else:
                raise NotImplementedError
            name_value_tuples.extend(_nv_tuples)

        return OrderedDict(name_value_tuples)

    def _report_mpjpe(self, keypoint_results, mode='mpjpe'):
        """Calculate mean per joint position error (MPJPE) or its variants
        like P-MPJPE or N-MPJPE.

        Args:
            keypoint_results (list): Keypoint predictions. See
                'Body3DH36MDataset.evaluate' for details.
            mode (str): Specify mpjpe variants. Supported options are:

                - ``'mpjpe'``: Standard MPJPE.
                - ``'p-mpjpe'``: MPJPE after aligning prediction to groundtruth
                    via a rigid transformation (scale, rotation and
                    translation).
                - ``'n-mpjpe'``: MPJPE after aligning prediction to groundtruth
                    in scale only.

        Returns:
            list[tuple[str, float]]: The overall error, followed by one entry
            per action category. The same alignment is used for all entries.

        Raises:
            ValueError: If ``mode`` is not a supported variant.
        """

        preds = []
        gts = []
        masks = []
        action_category_indices = defaultdict(list)
        for idx, result in enumerate(keypoint_results):
            pred = result['keypoints']
            target_id = result['target_id']
            # The first 3 channels of joints_3d are coordinates; the remainder
            # (expected to be a single channel) is used as the visibility.
            gt, gt_visible = np.split(
                self.data_info['joints_3d'][target_id], [3], axis=-1)
            preds.append(pred)
            gts.append(gt)
            masks.append(gt_visible)

            # e.g. 'Directions_1' -> 'Directions'
            action = self._parse_h36m_imgname(
                self.data_info['imgnames'][target_id])[1]
            action_category = action.split('_')[0]
            action_category_indices[action_category].append(idx)

        preds = np.stack(preds)
        gts = np.stack(gts)
        masks = np.stack(masks).squeeze(-1) > 0

        err_name = mode.upper()
        if mode == 'mpjpe':
            alignment = 'none'
        elif mode == 'p-mpjpe':
            alignment = 'procrustes'
        elif mode == 'n-mpjpe':
            alignment = 'scale'
        else:
            raise ValueError(f'Invalid mode: {mode}')

        error = keypoint_mpjpe(preds, gts, masks, alignment)
        name_value_tuples = [(err_name, error)]

        for action_category, indices in action_category_indices.items():
            # Use the same alignment as the overall metric so that e.g.
            # 'P-MPJPE_<action>' really is a Procrustes-aligned error.
            _error = keypoint_mpjpe(preds[indices], gts[indices],
                                    masks[indices], alignment)
            name_value_tuples.append((f'{err_name}_{action_category}', _error))

        return name_value_tuples

    def _load_camera_param(self, camera_param_file):
        """Load camera parameters from file."""
        return mmcv.load(camera_param_file)

    def get_camera_param(self, imgname):
        """Get camera parameters of a frame by its image name.

        Parameters are looked up by the ``(subject, camera)`` key parsed from
        ``imgname``; requires ``need_camera_param=True`` in ``data_cfg``.
        """
        assert hasattr(self, 'camera_param')
        subj, _, camera = self._parse_h36m_imgname(imgname)
        return self.camera_param[(subj, camera)]
Read the Docs v: stable
Versions
latest
stable
cn_doc
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.