Source code for mmpose.datasets.datasets.top_down.topdown_h36m_dataset

import os
from collections import OrderedDict

import json_tricks as json
import numpy as np
from xtcocotools.coco import COCO

from mmpose.core.evaluation.top_down_eval import (keypoint_epe,
                                                  keypoint_pck_accuracy)
from ...builder import DATASETS
from .topdown_base_dataset import TopDownBaseDataset


@DATASETS.register_module()
class TopDownH36MDataset(TopDownBaseDataset):
    """Human3.6M dataset for top-down 2D pose estimation.

    "Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human
    Sensing in Natural Environments", TPAMI'2014.
    More details can be found in the `paper
    <http://vision.imar.ro/human3.6m/pami-h36m.pdf>`__.

    Human3.6M keypoint indexes::

        0: 'root (pelvis)',
        1: 'right_hip',
        2: 'right_knee',
        3: 'right_foot',
        4: 'left_hip',
        5: 'left_knee',
        6: 'left_foot',
        7: 'spine',
        8: 'thorax',
        9: 'neck_base',
        10: 'head',
        11: 'left_shoulder',
        12: 'left_elbow',
        13: 'left_wrist',
        14: 'right_shoulder',
        15: 'right_elbow',
        16: 'right_wrist'

    Args:
        ann_file (str): Path to the annotation file.
        img_prefix (str): Path to a directory where images are held.
            Default: None.
        data_cfg (dict): Config dict of the dataset (input image size,
            heatmap size, number of joints, etc.).
        pipeline (list[dict | callable]): A sequence of data transforms.
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 test_mode=False):
        super().__init__(
            ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)

        assert self.ann_info['num_joints'] == 17
        # left/right joint pairs that are swapped under horizontal flipping
        self.ann_info['flip_pairs'] = [[1, 4], [2, 5], [3, 6], [11, 14],
                                       [12, 15], [13, 16]]
        self.ann_info['upper_body_ids'] = (0, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                                           16)
        self.ann_info['lower_body_ids'] = (1, 2, 3, 4, 5, 6)
        self.ann_info['use_different_joint_weights'] = False
        self.ann_info['joint_weights'] = np.ones(
            (self.ann_info['num_joints'], 1), dtype=np.float32)

        self.coco = COCO(ann_file)

        self.img_ids = self.coco.getImgIds()
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'h36m'

        self.db = self._get_db()

        print(f'=> num_images: {self.num_images}')
        print(f'=> load {len(self.db)} samples')

    @staticmethod
    def _get_mapping_id_name(imgs):
        """
        Args:
            imgs (dict): dict of image info.

        Returns:
            tuple: Image name & id mapping dicts.

            - id2name (dict): Mapping image id to name.
            - name2id (dict): Mapping image name to id.
        """
        id2name = {}
        name2id = {}
        for image_id, image in imgs.items():
            file_name = image['file_name']
            id2name[image_id] = file_name
            name2id[file_name] = image_id

        return id2name, name2id

    def _xywh2cs(self, x, y, w, h, padding=1.25):
        """This encodes a bbox (x, y, w, h) into (center, scale).

        Args:
            x, y, w, h (float): Left, top, width and height of the bbox.
            padding (float): Bbox padding factor used to include context.
                Default: 1.25.

        Returns:
            tuple: Center and scale of the bbox.

            - center (np.ndarray[float32](2,)): Center of the bbox (x, y).
            - scale (np.ndarray[float32](2,)): Scale of the bbox w & h.
        """
        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
            'image_size'][1]
        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)

        # during training, randomly shift the center for augmentation
        if (not self.test_mode) and np.random.rand() < 0.3:
            center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]

        # expand the shorter side so the bbox matches the input aspect ratio
        if w > aspect_ratio * h:
            h = w * 1.0 / aspect_ratio
        elif w < aspect_ratio * h:
            w = h * aspect_ratio

        # pixel std is 200.0
        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
        # padding to include a proper amount of context
        scale = scale * padding

        return center, scale

    def _get_db(self):
        """Load dataset."""
        gt_db = []
        bbox_id = 0
        num_joints = self.ann_info['num_joints']
        for img_id in self.img_ids:

            ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
            objs = self.coco.loadAnns(ann_ids)

            for obj in objs:
                if max(obj['keypoints']) == 0:
                    continue
                joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
                joints_3d_visible = np.zeros((num_joints, 3),
                                             dtype=np.float32)

                keypoints = np.array(obj['keypoints']).reshape(-1, 3)
                joints_3d[:, :2] = keypoints[:, :2]
                joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])

                # use the 1.25-padded bbox as the model input
                center, scale = self._xywh2cs(*obj['bbox'][:4])

                image_file = os.path.join(self.img_prefix,
                                          self.id2name[img_id])
                gt_db.append({
                    'image_file': image_file,
                    'center': center,
                    'scale': scale,
                    'rotation': 0,
                    'joints_3d': joints_3d,
                    'joints_3d_visible': joints_3d_visible,
                    'dataset': self.dataset_name,
                    'bbox': obj['bbox'],
                    'bbox_score': 1,
                    'bbox_id': bbox_id
                })
                bbox_id = bbox_id + 1
        gt_db = sorted(gt_db, key=lambda x: x['bbox_id'])

        return gt_db
    def evaluate(self, outputs, res_folder, metric, **kwargs):
        """Evaluate Human3.6M 2D keypoint results.

        The pose prediction results will be saved in
        `${res_folder}/result_keypoints.json`.

        Note:
            batch_size: N
            num_keypoints: K
            heatmap height: H
            heatmap width: W

        Args:
            outputs (list[dict]): Outputs containing the following items.

                - preds (np.ndarray[N,K,3]): The first two dimensions are
                  coordinates, the third is the keypoint score.
                - boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],
                  scale[1], area, score]
                - image_paths (list[str]): For example,
                  ['data/coco/val2017/000000393226.jpg']
                - heatmap (np.ndarray[N, K, H, W]): model output heatmap
                - bbox_ids (list[int]): Bbox ids in the batch.
            res_folder (str): Path of directory to save the results.
            metric (str | list[str]): Metric to be performed.
                Options: 'PCK', 'EPE'.

        Returns:
            dict: Evaluation results for evaluation metric.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['PCK', 'EPE']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        res_file = os.path.join(res_folder, 'result_keypoints.json')

        kpts = []
        for output in outputs:
            preds = output['preds']
            boxes = output['boxes']
            image_paths = output['image_paths']
            bbox_ids = output['bbox_ids']

            batch_size = len(image_paths)
            for i in range(batch_size):
                image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
                kpts.append({
                    'keypoints': preds[i].tolist(),
                    'center': boxes[i][0:2].tolist(),
                    'scale': boxes[i][2:4].tolist(),
                    'area': float(boxes[i][4]),
                    'score': float(boxes[i][5]),
                    'image_id': image_id,
                    'bbox_id': bbox_ids[i]
                })
        kpts = self._sort_and_unique_bboxes(kpts)

        self._write_keypoint_results(kpts, res_file)
        info_str = self._report_metric(res_file, metrics)
        name_value = OrderedDict(info_str)

        return name_value
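    # A minimal sketch of one entry of `outputs` as consumed by `evaluate`
    # above (hypothetical values and image path, not from the original file).
    # Each dict holds one batch of N predictions:
    #
    #   output = {
    #       'preds': np.zeros((1, 17, 3)),  # x, y, score per keypoint
    #       'boxes': np.array([[140., 130., 0.75, 1.0, 30000., 1.0]]),
    #                # center x/y, scale w/h, area, bbox score
    #       'image_paths': ['data/h36m/images/S9_Directions.jpg'],
    #       'bbox_ids': [0],
    #   }
    #
    # `image_paths` entries must start with `img_prefix`, so that stripping
    # the prefix yields a file name known to `name2id`.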
    def _report_metric(self, res_file, metrics, pck_thr=0.05):
        """Keypoint evaluation.

        Args:
            res_file (str): Json file storing prediction results.
            metrics (str | list[str]): Metric to be performed.
                Options: 'PCK', 'EPE'.
            pck_thr (float): PCK threshold. Default: 0.05.

        Returns:
            list: Evaluation results for evaluation metric.
        """
        info_str = []

        with open(res_file, 'r') as fin:
            preds = json.load(fin)
        assert len(preds) == len(self.db)

        outputs = []
        gts = []
        masks = []
        threshold_bbox = []

        for pred, item in zip(preds, self.db):
            outputs.append(np.array(pred['keypoints'])[:, :-1])
            gts.append(np.array(item['joints_3d'])[:, :-1])
            masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)
            if 'PCK' in metrics:
                # normalize PCK by the longer side of the bbox
                bbox = np.array(item['bbox'])
                bbox_thr = np.max(bbox[2:])
                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))

        outputs = np.array(outputs)
        gts = np.array(gts)
        masks = np.array(masks)
        threshold_bbox = np.array(threshold_bbox)

        if 'PCK' in metrics:
            _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr,
                                              threshold_bbox)
            info_str.append(('PCK', pck))

        if 'EPE' in metrics:
            info_str.append(('EPE', keypoint_epe(outputs, gts, masks)))

        return info_str

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """Sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts

    @staticmethod
    def _write_keypoint_results(keypoints, res_file):
        """Write results into a json file."""
        with open(res_file, 'w') as f:
            json.dump(keypoints, f, sort_keys=True, indent=4)
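# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). The annotation
# and image paths below are hypothetical, and `data_cfg` is abbreviated: the
# base class also expects the other standard top-down keys (e.g.
# `num_output_channels`, `dataset_channel`, `inference_channel`).
#
#   data_cfg = dict(
#       image_size=[192, 256],
#       heatmap_size=[48, 64],
#       num_joints=17,
#       ...)
#   dataset = TopDownH36MDataset(
#       ann_file='data/h36m/annotations/h36m_test.json',  # hypothetical
#       img_prefix='data/h36m/images/',                   # hypothetical
#       data_cfg=data_cfg,
#       pipeline=[],
#       test_mode=True)
#   results = dataset.evaluate(
#       outputs, res_folder='work_dirs/eval', metric=['PCK', 'EPE'])
#   # -> OrderedDict([('PCK', ...), ('EPE', ...)])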