# Source code for mmpose.datasets.datasets.mesh.mesh_mix_dataset

from abc import ABCMeta

import numpy as np
from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler

from mmpose.datasets.builder import DATASETS
from .mesh_base_dataset import MeshBaseDataset


@DATASETS.register_module()
class MeshMixDataset(Dataset, metaclass=ABCMeta):
    """Mix Dataset for 3D human mesh estimation.

    The dataset combines data from multiple datasets (MeshBaseDataset) and
    samples the data from the different datasets with the provided
    proportions. The dataset loads raw features and applies the specified
    transforms to return a dict containing the image tensors and other
    information.

    Args:
        configs (list): List of configs for multiple datasets. Each config
            is unpacked as keyword arguments to ``MeshBaseDataset``.
        partition (list): Sample proportion of multiple datasets. The length
            of partition must be the same as that of configs. Its elements
            must be non-negative and do not need to sum to one.

    Example:
        >>> from mmpose.datasets import MeshMixDataset
        >>> data_cfg = dict(
        >>>     image_size=[256, 256],
        >>>     iuv_size=[64, 64],
        >>>     num_joints=24,
        >>>     use_IUV=True,
        >>>     uv_type='BF')
        >>>
        >>> mix_dataset = MeshMixDataset(
        >>>     configs=[
        >>>         dict(
        >>>             ann_file='tests/data/h36m/test_h36m.npz',
        >>>             img_prefix='tests/data/h36m',
        >>>             data_cfg=data_cfg,
        >>>             pipeline=[]),
        >>>         dict(
        >>>             ann_file='tests/data/h36m/test_h36m.npz',
        >>>             img_prefix='tests/data/h36m',
        >>>             data_cfg=data_cfg,
        >>>             pipeline=[]),
        >>>     ],
        >>>     partition=[0.6, 0.4])
    """

    def __init__(self, configs, partition):
        """Load data from multiple datasets.

        Builds one ``MeshBaseDataset`` per config, concatenates them, and
        prepares a weighted sampler over the concatenated indices.
        """
        assert min(partition) >= 0
        # zip() below would silently truncate on a length mismatch, leaving
        # the weight vector shorter than the concatenated dataset so that
        # trailing datasets could never be drawn. Fail loudly instead.
        assert len(partition) == len(configs), (
            'The length of partition should be the same as that of configs, '
            f'but got {len(partition)} and {len(configs)}.')

        datasets = [MeshBaseDataset(**cfg) for cfg in configs]
        self.dataset = ConcatDataset(datasets)
        # The reported length is the size of the largest constituent dataset.
        self.length = max(len(ds) for ds in datasets)

        # Every sample of a dataset gets weight p / len(ds), so the total
        # weight of each dataset equals its entry in ``partition``.
        weights = [
            np.ones(len(ds)) * p / len(ds)
            for (p, ds) in zip(partition, datasets)
        ]
        weights = np.concatenate(weights, axis=0)
        # num_samples=1: each iteration over the sampler yields one index
        # into the concatenated dataset.
        self.sampler = WeightedRandomSampler(weights, 1)

    def __len__(self):
        """Get the size of the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Given index, sample the data from multiple datasets with the given
        proportion.

        Note:
            ``idx`` is ignored; the returned item is chosen by drawing a
            fresh index from the weighted sampler on every call.
        """
        # Draw a single index without materializing the sampler into a list.
        idx_new = next(iter(self.sampler))
        return self.dataset[idx_new]