Source code for dgs.models.dataset

r"""
Modules for loading data, including datasets and data loaders.
The modules combine my custom BaseModule with a regular torch Dataset.
Additionally, I implemented a Dataset for the |PT21|_ dataset that can be loaded within |torchreid|_.
"""

import os
from glob import glob
from typing import Type, Union

from torch.utils.data import ConcatDataset as TConcatDataset
from tqdm import tqdm

from dgs.utils.config import get_sub_config, insert_into_config
from dgs.utils.loader import get_instance, register_instance
from dgs.utils.types import Config, NodePath
from .alphapose import AlphaPoseLoader
from .dataset import BaseDataset
from .keypoint_rcnn import KeypointRCNNBackbone, KeypointRCNNImageBackbone, KeypointRCNNVideoBackbone
from .MOT import MOTImage, MOTImageHistory
from .posetrack21 import PoseTrack21_BBox, PoseTrack21_Image, PoseTrack21_ImageHistory

__all__ = ["DATASETS", "get_dataset", "register_dataset", "get_multi_dataset"]

DATASETS: dict[str, Type[BaseDataset]] = {
    "PoseTrack21_BBox": PoseTrack21_BBox,
    "PT21_BBox": PoseTrack21_BBox,  # alias
    "PoseTrack21_Image": PoseTrack21_Image,
    "PT21_Image": PoseTrack21_Image,  # alias
    "PoseTrack21_ImageHistory": PoseTrack21_ImageHistory,
    "PT21_ImageHistory": PoseTrack21_ImageHistory,  # alias
    "AlphaPoseLoader": AlphaPoseLoader,
    "KeypointRCNNBackbone": KeypointRCNNBackbone,  # shouldn't be used directly, only as wrapper
    "KeypointRCNNImageBackbone": KeypointRCNNImageBackbone,
    "KeypointRCNNVideoBackbone": KeypointRCNNVideoBackbone,
    "MOTImage": MOTImage,
    "MOTI": MOTImage,  # alias
    "MOTImageHistory": MOTImageHistory,
    "MOTIH": MOTImageHistory,  # alias
}
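
# Example (illustrative only, not part of the original module): the registry
# maps both the full names and their aliases to the same class object, so
# either spelling can be used in a configuration file.
#
# >>> DATASETS["PT21_BBox"] is DATASETS["PoseTrack21_BBox"]
# True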


def get_dataset(name: str) -> Type[BaseDataset]:
    """Given the name of one dataset, return the respective dataset class."""
    return get_instance(instance=name, instances=DATASETS, inst_class=BaseDataset)
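
# Example (a minimal sketch, not part of the module): resolve a dataset class
# by name or alias; instantiating it is left to the caller or to
# get_multi_dataset() below.
#
# >>> ds_cls = get_dataset("MOTI")  # alias of "MOTImage"
# >>> issubclass(ds_cls, BaseDataset)
# True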


def register_dataset(name: str, new_ds: Type[BaseDataset]) -> None:
    """Register a new dataset module in :data:`DATASETS`, to be able to use it from configuration files."""
    register_instance(name=name, instance=new_ds, instances=DATASETS, inst_class=BaseDataset)
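
# Example (hedged sketch): registering a custom dataset so that it can be
# referenced by name in configuration files. 'MyDataset' is a hypothetical
# BaseDataset subclass defined by the user.
#
# >>> register_dataset("MyDataset", MyDataset)
# >>> get_dataset("MyDataset") is MyDataset
# True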


def get_multi_dataset(
    config: Config, path: NodePath, ds_name: str, concat: bool = True
) -> Union[TConcatDataset[BaseDataset], list[BaseDataset]]:
    """Create a concatenated dataset from the given configuration and path.

    Args:
        config: The overall configuration for the tracker.
        path: The path to the dataset-specific parameters.
        ds_name: The name of the dataset type to create, as one of the keys of :data:`DATASETS`.
        concat: Whether to concatenate the list of datasets in the end. Default: True.
    """
    # get the dataset type once, so it can be instantiated faster
    ds_type = get_dataset(name=ds_name)

    sub_cfg = get_sub_config(config=config, path=path)
    if "paths" not in sub_cfg:
        raise ValueError(f"No paths given in config. Got: {sub_cfg}")

    # resolve 'paths' into a list of data paths
    paths = sub_cfg["paths"]
    if isinstance(paths, (list, tuple)):
        pass
    elif isinstance(paths, str) and "*" in paths:
        paths = glob(paths)
    elif isinstance(paths, str) and os.path.exists(paths):
        paths = [paths]
    elif isinstance(paths, str) and os.path.exists(dir_path := str(os.path.join(sub_cfg["dataset_path"], paths))):
        paths = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
    else:
        raise ValueError(f"The given 'paths' ({paths}) is neither an iterable, a string, nor a valid file path.")
    assert len(paths) > 0, f"No paths found with paths: {paths}"

    # for every data path, insert it into a copy of the config and initialize one dataset
    datasets = []
    for data_path in tqdm(paths, desc="Loading datasets", leave=False):
        ds_cfg = insert_into_config(
            path=path,
            value={
                "data_path": os.path.normpath(data_path),
                # removeprefix() drops an optional "Concat_" prefix;
                # lstrip() would wrongly strip any of the characters "Concat_"
                "module_name": str(ds_name).removeprefix("Concat_"),
            },
            original=config,
            copy=True,
        )
        datasets.append(ds_type(config=ds_cfg, path=path))

    if concat:
        return TConcatDataset(datasets)
    return datasets
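
# Example (a minimal sketch; 'cfg' stands for a hypothetical, fully-formed
# tracker Config whose ["dataset"]["paths"] entry is a glob pattern). Every
# matched file becomes one dataset; with the default concat=True they are
# merged into a single torch ConcatDataset:
#
# >>> cfg["dataset"]["paths"] = "./data/PT21/posetrack_data/*.json"
# >>> ds = get_multi_dataset(config=cfg, path=["dataset"], ds_name="PT21_Image")
# >>> isinstance(ds, TConcatDataset)
# True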