"""
Default Datasets for pose-based data.
TorchreidPoseDataset and TorchreidPoseDataManager are custom torchreid classes for loading pose-based data.
"""
import warnings
from copy import deepcopy
from typing import Any, Callable, Type, Union

import torch as t
import torchvision.transforms.v2 as tvt
from torch.utils.data import DataLoader as TorchDataLoader, Dataset as TorchDataset

from dgs.utils.types import FilePath
from dgs.utils.utils import HidePrint

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="Cython evaluation.*is unavailable", category=UserWarning)
    try:
        # If torchreid is installed using `./dependencies/torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.data import Dataset as TorchreidDataset

        # noinspection PyUnresolvedReferences
        from torchreid.data.datamanager import DataManager as TorchreidDM

        # noinspection PyUnresolvedReferences
        from torchreid.data.sampler import build_train_sampler
    except ModuleNotFoundError:
        # if torchreid is installed using `pip install torchreid`
        # noinspection PyUnresolvedReferences
        from torchreid.reid.data import Dataset as TorchreidDataset

        # noinspection PyUnresolvedReferences
        from torchreid.reid.data.datamanager import DataManager as TorchreidDM

        # noinspection PyUnresolvedReferences
        from torchreid.reid.data.sampler import build_train_sampler
class TorchreidPoseDataset(TorchreidDataset):
    """Custom torchreid Dataset for pose-based data.

    Every entry of ``self.data`` is unpacked as ``(pose_path, pid, camid, dsetid)``,
    where ``pose_path`` is a file that can be read with :func:`torch.load`.
    """

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load a single pose sample together with its metadata.

        Args:
            index: Position of the sample within ``self.data``.

        Returns:
            A dict with the keys ``"img"`` (the object loaded from ``pose_path``),
            ``"pid"``, ``"camid"``, and ``"dsetid"``.
        """
        pose_path, pid, camid, dsetid = self.data[index]
        # NOTE: torch.load unpickles the file, only load pose files from trusted sources.
        # NOTE(review): no transform is applied to the loaded pose here - confirm this is intended.
        pose = t.load(pose_path)
        return {"img": pose, "pid": pid, "camid": camid, "dsetid": dsetid}

    def show_summary(self) -> None:
        """Print the number of ids, poses, and cameras of the train, query, and gallery subsets."""
        num_train_pids = self.get_num_pids(self.train)
        num_train_cams = self.get_num_cams(self.train)
        num_query_pids = self.get_num_pids(self.query)
        num_query_cams = self.get_num_cams(self.query)
        num_gallery_pids = self.get_num_pids(self.gallery)
        num_gallery_cams = self.get_num_cams(self.gallery)

        print(f" => Loaded {self.__class__.__name__}")
        print(" ----------------------------------------")
        print(" subset | # ids | # poses | # cameras")
        print(f" train | {num_train_pids:5d} | {len(self.train):8d} | {num_train_cams:9d}")
        print(f" query | {num_query_pids:5d} | {len(self.query):8d} | {num_query_cams:9d}")
        print(f" gallery | {num_gallery_pids:5d} | {len(self.gallery):8d} | {num_gallery_cams:9d}")
        print(" ----------------------------------------")
class TorchreidPoseDataManager(TorchreidDM):
    """Custom torchreid DataManager for pose-based data.

    Args:
        root: Root path to the directory containing all the datasets.
        sources: The type(s) of source pose dataset(s).
            A single type is accepted and handled like a one-element list.
        **kwargs: Additional keyword arguments, see Other Parameters below.
            The defaults merged with ``kwargs`` are stored in ``params``,
            and every entry of ``params`` is forwarded as keyword argument to each dataset constructor.

    Other Parameters:
        combineall (bool):
            Combine train, query and gallery in a dataset for training.
            Default is False.
        targets (Type[TorchreidPoseDataset] | list[Type[TorchreidPoseDataset]]):
            The type(s) of target dataset(s).
            A single type is accepted and handled like a one-element list.
            If not given, it equals to ``sources``.
        transforms (list[str | Callable]):
            One or multiple transformations applied to model training.
            Default is ``["random_flip"]``.
            Note that this value is only stored in ``params``, it is not passed to the base DataManager.
        train_sampler (str):
            Name of the torchreid sampler used by the train loader.
            Default is "RandomSampler".
        train_sampler_t (str):
            Name of the torchreid sampler for the target train loader.
            Not part of ``default_kwargs`` and not read by this class itself.
        use_gpu (bool):
            Whether to use the gpu.
            Default is True.
        batch_size_train (int):
            The number of images in a training batch.
            Default is 32.
        batch_size_test (int):
            The number of images in a test batch.
            Default is 32.
        num_instances (int):
            The number of instances per identity in a batch.
            Default is 4.
        num_cams (int):
            The number of cameras to sample in a batch (when using ``RandomDomainSampler``).
            Default is 1.
        num_datasets (int):
            The number of datasets to sample in a batch (when using ``RandomDatasetSampler``).
            Default is 1.
        verbose (bool):
            Print more debug information, including a summary after loading.
            Default is False.
        workers (int):
            Number of workers for the torch DataLoader.
            As long as no multi-GPU context is available, this value should not be changed.
            Default is 0.
        drop_last_test (bool):
            Whether the query and gallery loaders drop their last incomplete batch.
            Default is False.

    Notes:
        The original image-based transforms are overwritten to support key-points as input.
    """

    data_type: str = "pose"
    """Is used within torchreid."""

    default_kwargs: dict[str, Any] = {
        "combineall": False,
        "targets": None,
        "transforms": ["random_flip"],
        "train_sampler": "RandomSampler",
        "use_gpu": True,
        "batch_size_train": 32,
        "batch_size_test": 32,
        "num_instances": 4,
        "num_cams": 1,
        "num_datasets": 1,
        "verbose": False,
        "workers": 0,
    }
    """A dict of default keyword arguments.

    This dictionary is used to set default kwargs without passing hundreds of Arguments to `__init__()`.
    """

    params: dict[str, Any]
    """The parameters of this module."""

    def __init__(
        self, root: FilePath, sources: Type[TorchreidPoseDataset] | list[Type[TorchreidPoseDataset]], **kwargs
    ) -> None:
        # Deep-copy the defaults, so mutable values (e.g. the "transforms" list)
        # are never shared between instances or leaked back into the class attribute.
        self.params = deepcopy(self.default_kwargs)
        self.params.update(kwargs)
        self.root = root

        # load_train() and load_test() iterate over self.sources / self.targets,
        # therefore a single dataset type is wrapped into a list before handing it to the base class.
        sources = self._as_list(sources)
        targets = self._as_list(self.params["targets"])

        # block printing of transforms
        with HidePrint():
            super().__init__(sources=sources, targets=targets, use_gpu=self.params["use_gpu"])

        # The base class has now built torchreid's image-based transforms.
        # They are still handed to the datasets, but they are not wanted for pose data.
        self.train_set, self.train_loader = self.load_train()
        self._num_train_pids = self.train_set.num_train_pids
        self._num_train_cams = self.train_set.num_train_cams
        self.test_loader, self.test_dataset = self.load_test()

        if self.params["verbose"]:
            self.show_summary()

    @staticmethod
    def _as_list(value: Any) -> Any:
        """Wrap a single value into a list, return ``None``, lists, and tuples unchanged."""
        if value is None or isinstance(value, (list, tuple)):
            return value
        return [value]

    def load_train(self) -> tuple[TorchDataset, TorchDataLoader]:
        """Load the train Dataset and DataLoader as torch objects.

        Returns:
            The combination of the train datasets of all ``self.sources``
            and a DataLoader that draws from it using the configured torchreid sampler.

        Raises:
            ValueError: If no source dataset is given.
        """
        print("=> Loading train (source) dataset")

        if not self.sources:
            raise ValueError("At least one source dataset is required to load the train data.")

        # sum(list[Dataset]) relies on the addition implemented by the torchreid Dataset
        # NOTE(review): only the train datasets receive instance="key_points",
        # the query / gallery datasets in load_test() do not - confirm this asymmetry is intended.
        # noinspection PyTypeChecker
        train_set: Union[TorchreidPoseDataset, TorchDataset] = sum(
            dataset_type(
                root=self.root, mode="train", transform=self.transform_tr, instance="key_points", **self.params
            )
            for dataset_type in self.sources
        )

        train_loader = TorchDataLoader(
            train_set,
            sampler=build_train_sampler(
                train_set.train,
                self.params["train_sampler"],
                batch_size=self.params["batch_size_train"],
                num_instances=self.params["num_instances"],
                num_cams=self.params["num_cams"],
                num_datasets=self.params["num_datasets"],
            ),
            batch_size=self.params["batch_size_train"],
            # the sampler decides the order, shuffle has to stay False
            shuffle=False,
            num_workers=self.params["workers"],  # as long as there is no multi GPU support this has to be zero
            pin_memory=self.use_gpu,
            drop_last=True,
        )
        return train_set, train_loader

    def _build_test_loader(self, dataset: TorchDataset) -> TorchDataLoader:
        """Wrap a query or gallery dataset into a non-shuffling test DataLoader."""
        return TorchDataLoader(
            dataset,
            batch_size=self.params["batch_size_test"],
            shuffle=False,
            num_workers=self.params["workers"],
            pin_memory=self.use_gpu,
            drop_last=self.params.get("drop_last_test", False),
        )

    def load_test(
        self,
    ) -> tuple[dict[Type[TorchreidPoseDataset], dict[str, Any]], dict[Type[TorchreidPoseDataset], dict[str, Any]]]:
        """Load the test Datasets and DataLoaders as torch objects.

        Returns:
            Two dicts keyed by the entries of ``self.targets``.
            The first maps to ``{"query": DataLoader, "gallery": DataLoader}``,
            the second to ``{"query": query_set.query, "gallery": gallery_set.gallery}``.
        """
        print("=> Loading test (target) dataset")

        test_loader: dict[Type[TorchreidPoseDataset], dict[str, Any]] = {}
        test_dataset: dict[Type[TorchreidPoseDataset], dict[str, Any]] = {}

        for dataset_type in self.targets:
            query_set: Union[TorchreidPoseDataset, TorchDataset] = dataset_type(
                root=self.root, mode="query", transform=self.transform_te, **self.params
            )
            query_loader = self._build_test_loader(query_set)

            gallery_set: Union[TorchreidPoseDataset, TorchDataset] = dataset_type(
                root=self.root, mode="gallery", transform=self.transform_te, **self.params
            )
            gallery_loader = self._build_test_loader(gallery_set)

            test_loader[dataset_type] = {"query": query_loader, "gallery": gallery_loader}
            test_dataset[dataset_type] = {"query": query_set.query, "gallery": gallery_set.gallery}

        return test_loader, test_dataset

    def show_summary(self) -> None:
        """Print the sources, targets, and the train id / image / camera counts of this DataManager."""
        print("\n")
        print(" **************** Summary ****************")
        print(f" source : {self.sources}")
        print(f" # source datasets : {len(self.sources)}")
        print(f" # source ids : {self.num_train_pids}")
        print(f" # source images : {len(self.train_set)}")
        print(f" # source cameras : {self.num_train_cams}")
        print(f" target : {self.targets}")
        print(" *****************************************")
        print("\n")