"""
Utilities for validating recurring data types.
"""
import os
from collections.abc import Iterable, Sized
from typing import Union
import torch as t
from torchvision import tv_tensors as tvte
from dgs.utils.constants import PROJECT_ROOT
from dgs.utils.exceptions import InvalidPathException, ValidationException
from dgs.utils.files import is_dir, is_file, is_project_dir, mkdir_if_missing, to_abspath
from dgs.utils.types import FilePath, FilePaths, Heatmap, Image, Images, Validator
VALIDATIONS: dict[str, Validator] = {
"optional": (lambda _x, _d: True),
# types
"None": (lambda x, _: x is None),
"not None": (lambda x, _: x is not None),
"number": (lambda x, _: isinstance(x, (int, float))),
"callable": (lambda x, _: callable(x)),
"iterable": (lambda x, _: isinstance(x, Iterable)),
"sized": (lambda x, _: isinstance(x, Sized)),
"instance": isinstance,
"isinstance": isinstance,
# number
"gt": (lambda x, d: isinstance(d, int | float) and x > d),
"gte": (lambda x, d: isinstance(d, int | float) and x >= d),
"lt": (lambda x, d: isinstance(d, int | float) and x < d),
"lte": (lambda x, d: isinstance(d, int | float) and x <= d),
"between": (lambda x, d: isinstance(x, int | float) and isinstance(d, tuple) and len(d) == 2 and d[0] < x < d[1]),
"within": (lambda x, d: isinstance(x, int | float) and isinstance(d, tuple) and len(d) == 2 and d[0] <= x <= d[1]),
"outside": (
lambda x, d: isinstance(x, int | float) and isinstance(d, tuple) and len(d) == 2 and x < d[0] or x > d[1]
),
# lists and other iterables
"len": (lambda x, d: hasattr(x, "__len__") and len(x) == d),
"shorter": (lambda x, d: hasattr(x, "__len__") and len(x) < d),
"longer": (lambda x, d: hasattr(x, "__len__") and len(x) > d),
"shorter eq": (lambda x, d: hasattr(x, "__len__") and len(x) <= d),
"longer eq": (lambda x, d: hasattr(x, "__len__") and len(x) >= d),
"in": (lambda x, d: hasattr(d, "__contains__") and x in d),
"not in": (lambda x, d: hasattr(x, "__contains__") and x not in d),
"contains": (lambda x, d: hasattr(x, "__contains__") and d in x),
"not contains": (lambda x, d: hasattr(x, "__contains__") and d not in x),
# string
"startswith": (lambda x, d: isinstance(x, str) and (isinstance(d, str) or bool(str(d))) and x.startswith(d)),
"endswith": (lambda x, d: isinstance(x, str) and (isinstance(d, str) or bool(str(d))) and x.endswith(d)),
# file and folder
"file exists": (
lambda x, f: isinstance(x, str)
and (
VALIDATIONS["file exists absolute"](x, None)
or VALIDATIONS["file exists in project"](x, None)
or VALIDATIONS["file exists in folder"](x, f)
)
),
"file exists absolute": (lambda x, _: isinstance(x, str) and os.path.isfile(x)),
"file exists in project": (lambda x, _: isinstance(x, str) and os.path.isfile(os.path.join(PROJECT_ROOT, x))),
"file exists in folder": (
lambda x, f: isinstance(x, str) and isinstance(f, str) and os.path.isfile(os.path.join(f, x))
),
"folder exists": (
lambda x, b: isinstance(x, str)
and (VALIDATIONS["folder exists absolute"](x, b) or VALIDATIONS["folder exists in project"](x, b))
),
"folder exists absolute": (
lambda x, b: isinstance(x, str) and (is_dir(x) if not b else mkdir_if_missing(x) and True)
),
"folder exists in project": (
lambda x, b: isinstance(x, str)
and (is_project_dir(x) if not b else is_project_dir(x) or mkdir_if_missing(x) and True)
),
"folder exists in folder": (
lambda x, f: isinstance(x, str) and isinstance(f, str) and os.path.isdir(os.path.join(f, x))
),
# complicated recurring validations
"NodePath": (
lambda x, d: (isinstance(x, str) and x in d)
or (isinstance(x, list) and x[0] in d and ((len(x) == 1) or (VALIDATIONS["NodePath"](x[1:], d[x[0]]))))
),
"NodePaths": (lambda x, d: isinstance(x, list) and all(VALIDATIONS["NodePath"](x_i, d) for x_i in x)),
# logical operators, including nested validations
"eq": (lambda x, d: x == d),
"neq": (lambda x, d: x != d),
"not": (lambda x, d: not VALIDATIONS["all"](x, d)),
"forall": (
lambda x, data: (
VALIDATIONS["iterable"](x, None)
and (
all(VALIDATIONS[data[0]](x_i, data[1]) for x_i in x)
if isinstance(data, tuple)
else (
all(VALIDATIONS[data](x_i, None) for x_i in x)
if isinstance(data, str)
else (
all(isinstance(x_i, data) for x_i in x)
if isinstance(data, type)
else (
all(VALIDATIONS["all"](x_i, d_i) for d_i in data for x_i in x)
if isinstance(data, list)
else False
)
)
)
)
)
),
"all": (
lambda x, data: (
(len(data) == 2 and VALIDATIONS[data[0]](x, data[1]))
if isinstance(data, tuple)
else (
VALIDATIONS[data](x, None)
if isinstance(data, str)
else (
isinstance(x, data)
if isinstance(data, type)
else (
(len(data) and all(VALIDATIONS["all"](x, sub_item) for sub_item in data))
if isinstance(data, list)
else False
)
)
)
)
),
"any": (
lambda x, data: (
(len(data) == 2 and VALIDATIONS[data[0]](x, data[1]))
if isinstance(data, tuple)
else (
VALIDATIONS[data](x, None)
if isinstance(data, str)
else (
isinstance(x, data)
if isinstance(data, type)
else (
(len(data) and any(VALIDATIONS["any"](x, sub_item) for sub_item in data))
if isinstance(data, list)
else False
)
)
)
)
),
"xor": (
lambda x, d: isinstance(d, list)
and len(d) == 2
and bool(VALIDATIONS["all"](x, d[0])) != bool(VALIDATIONS["all"](x, d[1]))
),
}
[docs]
def validate_bboxes(
    bboxes: tvte.BoundingBoxes,
    length: int = None,
    dims: Union[int, None] = 2,
    box_format: Union[tvte.BoundingBoxFormat, None] = None,
) -> tvte.BoundingBoxes:
    """Check a torchvision bounding-box tensor and hand it back as a torchvision bounding-box tensor.

    Args:
        bboxes: `tv_tensor.BoundingBoxes` object with an arbitrary shape, most likely ``[B x 4]``.
        length: The number of items or batch-size the tensor should have.
            Default `None` does not validate the length.
        dims: Number of dimensions bboxes should have.
            Use None to not force any number of dimensions.
            Defaults to two dimensions with the bounding box dimensions as ``[B x 4]``.
        box_format: If present, validates whether the tv_tensors.BoundingBoxFormat matches the one of bboxes.
            Default None, and therefore no validation of the format.

    Returns:
        Bounding boxes as `tv_tensor.BoundingBoxes` object with exactly `dims` dimensions.

    Raises:
        TypeError: If the `bboxes` input is not a `tv_tensor.BoundingBoxes` object.
        ValueError: If the `bboxes` have the wrong format or the wrong number of dimensions.
        ValidationException: If `dims` is None and the number of boxes does not match `length`.
    """
    if not isinstance(bboxes, tvte.BoundingBoxes):
        raise TypeError(f"Bounding boxes should be torch tensor or tv_tensor Bounding Boxes but is {type(bboxes)}")

    format_is_constrained = box_format is not None
    if format_is_constrained and box_format != bboxes.format:
        raise ValueError(f"Bounding boxes are expected to be in format {box_format} but are in format {bboxes.format}")

    # keep a handle on the incoming object, its metadata is re-attached to the result below
    reference = bboxes

    if dims is None:
        # without a dimension constraint, only the length can be checked
        if length is not None and len(bboxes) != length:
            raise ValidationException(f"Bounding box length is expected to be {length} but got {len(bboxes)}")
        checked = bboxes
    else:
        checked = validate_dimensions(tensor=bboxes, dims=dims, length=length)

    return tvte.wrap(checked, like=reference)
[docs]
def validate_dimensions(tensor: t.Tensor, dims: int, *_, length: int = None) -> t.Tensor:
    """Given a tensor, make sure it has the correct number of dimensions.

    If the tensor has too many dimensions, all dimensions of size one are squeezed away.
    If it has too few, new leading dimensions of size one are added.

    Notes:
        Squeezing and unsqueezing are done in place (``squeeze_`` / ``unsqueeze_``),
        therefore a `torch.Tensor` passed as `tensor` is modified as a side effect.
        Inputs that are not tensors are converted to a new tensor first.

    Args:
        tensor: Any `torch.tensor` or other object that can be converted to one.
        dims: Number of dimensions the tensor should have.
        length: The number of items or batch-size the tensor should have.
            Default `None` does not validate the length.

    Returns:
        A `torch.tensor` with the correct number of dimensions.

    Raises:
        TypeError: If the `tensor` input is not a `torch.tensor` or cannot be cast to one.
        ValueError: If the `tensor` has more than `dims` dimensions and cannot be squeezed down far enough.
        ValidationException: If `length` is given and the first dimension of the result has a different size.
    """
    if not isinstance(tensor, t.Tensor):
        try:
            tensor = t.tensor(tensor)
        # torch reports many unconvertible inputs as RuntimeError ("Could not infer dtype of ..."),
        # which has to be translated as well to honor the documented TypeError.
        except (TypeError, ValueError, RuntimeError) as e:
            raise TypeError(
                f"The input should be a torch tensor or a type that can be converted to one. "
                f"But `tensor` is {type(tensor)}"
            ) from e

    if tensor.ndim > dims:
        # removes every dimension of size one, which may leave fewer than `dims` dimensions
        tensor.squeeze_()
        if tensor.ndim > dims:
            raise ValueError(
                f"The length of tensor.shape should be {dims} but shape is {tensor.shape}. "
                f"Unsqueezing did not work."
            )

    # pad with leading dimensions, this also restores dimensions removed by squeezing too much
    while tensor.ndim < dims:
        tensor.unsqueeze_(0)

    if length is not None and length != len(tensor):
        raise ValidationException(f"length is expected to be {length} but got {len(tensor)}")

    return tensor
[docs]
def validate_filepath(file_paths: Union[FilePath, Iterable[FilePath], FilePaths], length: int = None) -> FilePaths:
    """Check that one or multiple file paths exist and return them as a tuple of paths.

    Lists and tuples are validated item by item.
    Every other input is converted to a string and treated as a single path.

    Args:
        file_paths: Path to the file as a string or a file object, or a list or tuple of those.
        length: The length a :class:`FilePaths` object should have.
            Except for a length of 1 not applicable for :class:`FilePath`.

    Returns:
        FilePaths: The validated file paths, as returned by ``to_abspath``.

    Raises:
        InvalidPathException: If at least one of the paths in `file_paths` does not exist.
        ValidationException: If `length` is given and does not match the number of paths.
    """
    length_given = length is not None

    if isinstance(file_paths, (list, tuple)):
        if length_given and len(file_paths) != length:
            raise ValidationException(f"Expected {length} paths but got {len(file_paths)}.")
        collected = []
        for single_path in file_paths:
            # every recursive call yields a tuple, of which only the first entry is kept
            collected.append(validate_filepath(single_path)[0])
        return tuple(collected)

    if length_given and length != 1 and isinstance(file_paths, str):
        raise ValidationException(f"Expected {length} paths but got a single path {file_paths}.")

    path_string = str(file_paths)
    if not is_file(path_string):
        raise InvalidPathException(filepath=path_string)

    return (to_abspath(filepath=path_string),)
[docs]
def validate_heatmaps(
    heatmaps: Union[t.Tensor, Heatmap], length: int = None, dims: Union[int, None] = 4, nof_joints: int = None
) -> Heatmap:
    """Validate a given tensor of heatmaps, whether it has the correct format and shape.

    Args:
        heatmaps: tensor-like object
        length: The number of items or batch-size the tensor should have.
            Default `None` does not validate the length.
        dims: Number of dimensions heatmaps should have.
            Use None to not force any number of dimensions.
            Defaults to four dimensions with the heatmap dimensions as ``[B x J x w x h]``.
        nof_joints: The number of joints the heatmap should have (``J``).
            Default None does not validate the number of joints at all.

    Returns:
        Heatmap: The validated heatmaps as tensor with the correct type.

    Raises:
        TypeError: If the `heatmaps` input is neither a Heatmap nor a torch tensor.
        ValueError: If `nof_joints` is given and the third-to-last dimension is missing or has a different size,
            or if the `heatmaps` have too many dimensions.
        ValidationException: If `length` is given and the heatmaps have a different length.
    """
    if not isinstance(heatmaps, (Heatmap, t.Tensor)):
        raise TypeError(f"heatmaps should be a Heatmap or torch tensor but are {type(heatmaps)}.")

    if nof_joints is not None:
        # the joints are expected in the third-to-last dimension, which has to exist before it is read
        if heatmaps.ndim < 3:
            raise ValueError(
                f"The number of joints should be {nof_joints} but the heatmaps have only {heatmaps.ndim} "
                f"dimensions and therefore no joint dimension. Shape: {heatmaps.shape}."
            )
        if heatmaps.shape[-3] != nof_joints:
            raise ValueError(f"The number of joints should be {nof_joints} but is {heatmaps.shape[-3]}.")

    if dims is not None:
        heatmaps = validate_dimensions(tensor=heatmaps, dims=dims, length=length)
    elif length is not None and len(heatmaps) != length:
        raise ValidationException(f"Heatmap length is expected to be {length} but got {len(heatmaps)}")

    return tvte.Mask(heatmaps)
[docs]
def validate_ids(ids: Union[int, t.Tensor], length: int = None) -> t.Tensor:
    """Validate a given tensor or single integer value.

    Notes:
        A tensor passed as `ids` is squeezed in place, which modifies the shape of the caller's tensor.

    Args:
        ids: A single integer or an arbitrary whole-numbered torch tensor to check.
        length: The number of items or batch-size the tensor should have.
            Default `None` does not validate the length.

    Returns:
        torch.Tensor: Torch integer tensor of dtype ``long`` with one dimension.

    Raises:
        TypeError: If `ids` is neither an integer nor a `torch.Tensor`, or if the tensor is floating point or complex.
        ValueError: If `ids` has more than one dimension after squeezing.
        ValidationException: If `length` is given and the number of IDs is different.
    """
    if isinstance(ids, int):
        # Create the tensor as 64-bit right away: the result is a long tensor anyway,
        # and a 32-bit intermediate overflows for IDs greater than 2**31 - 1.
        ids = t.tensor([ids], dtype=t.long)

    if not isinstance(ids, t.Tensor) or ids.is_floating_point() or ids.is_complex():
        raise TypeError(f"The input should be an integer or an whole numbered torch.Tensor but is {type(ids)}")

    ids.squeeze_()

    if ids.ndim == 0:
        # a single value was squeezed down to a scalar, restore the one dimension
        ids.unsqueeze_(-1)
    elif ids.ndim != 1:
        raise ValueError(f"IDs should have only one dimension, but shape is {ids.shape}")

    if length is not None and ids.size(0) != length:
        raise ValidationException(f"IDs length is expected to be {length} but got {ids.size(0)}")

    return ids.long()
[docs]
def validate_image(images: Union[Image, t.Tensor], length: int = None, dims: Union[int, None] = 4) -> Image:
    """Given one single image or a stacked batch images, validate them and return a torchvision-tensor image.

    Args:
        images: torch tensor or tv_tensor.Image object with a dtype of either float32 or uint8.
        length: The number of items or batch-size the tensor should have.
            Default `None` does not validate the length.
        dims: Number of dimensions img should have.
            Use None to not force any number of dimensions.
            Defaults to four dimensions with the image dimensions as ``[B x C x H x W]``.

    Returns:
        Image: The images as `tv_tensor.Image` object with exactly `dims` dimensions.

    Raises:
        TypeError: If `images` is not a Tensor or has a dtype other than float32 or uint8.
        ValueError: If the `images` have fewer than three dimensions or the size of the channel dimension is wrong.
        ValidationException: If `length` is given and the images have a different length.
    """
    # tv_tensor images are tensors themselves, therefore one check covers both accepted types
    if not isinstance(images, t.Tensor):
        raise TypeError(f"Image should be torch tensor or tv_tensor Image but is {type(images)}.")
    if images.dtype not in (t.float32, t.uint8):
        raise TypeError(f"Image should have a dtype of float32 or uint8 but has dtype {images.dtype}.")

    if dims is not None:
        images = validate_dimensions(tensor=images, dims=dims, length=length)
    elif length is not None and len(images) != length:
        raise ValidationException(f"Image length is expected to be {length} but got {len(images)}")

    if images.ndim < 3:
        raise ValueError(f"Image should have at least 3 dimensions. Shape: {images.shape}.")

    # the channels are expected in the third-to-last dimension: 1 (depth), 3 (RGB), or 4 (RGBA)
    if images.shape[-3] not in [1, 3, 4]:
        raise ValueError(
            f"Image should either be RGB, RGBA or depth. But a dimensionality {images.shape[-3]} is unknown."
        )

    return tvte.Image(images)
[docs]
def validate_images(images: list[Union[Image, t.Tensor]]) -> Images:
    """Validate every image of a list on its own and return the validated images as a list.

    Args:
        images: A list or tuple containing :class:`~torch.Tensor` or :class:`.tv_tensor.Image` objects.

    Returns:
        The images as a list containing :class:`.tv_tensor.Image` objects, each with exactly 4 dimensions.

    Raises:
        TypeError: If `images` is neither a list nor a tuple.
    """
    if isinstance(images, (list, tuple)):
        validated = []
        for single_image in images:
            # every entry has to be one image with a batch size of one: ``[1 x C x H x W]``
            validated.append(validate_image(single_image, length=1, dims=4))
        return validated

    raise TypeError(f"Expected images to be a list, got {type(images)}.")
[docs]
def validate_key_points(
    key_points: t.Tensor,
    length: int = None,
    dims: Union[int, None] = 3,
    nof_joints: int = None,
    joint_dim: int = None,
) -> t.Tensor:
    """Given a tensor of key points, validate them and return them as torch tensor of the correct shape.

    Args:
        key_points: One `torch.tensor` containing the key points.
        length: The number of items or batch-size the tensor should have.
            Default `None` does not validate the length.
        dims: The number of dimensions `key_points` should have.
            Use `None` to not force any number of dimensions.
            Defaults to three dimensions with the key point dimensions as ``[B x J x 2|3]``.
        nof_joints: The number of joints ``key_points`` should have (``J``).
            Default `None` does not validate the number of joints at all.
        joint_dim: The dimensionality the joint dimension should have (``2|3``).
            Default `None` does not validate the dimensionality additionally to being two or three.

    Returns:
        torch.Tensor: The key points as a single `torch.tensor` with exactly the requested number of dimensions like
        ``[... x nof_joints x joint_dim]``.

    Raises:
        TypeError: If the key point input is not a Tensor.
        ValueError: If the key points or joints have the wrong dimensionality,
            including tensors with too few dimensions to contain the validated dimension.
        ValidationException: If `length` is given and the key points have a different length.
    """
    if not isinstance(key_points, t.Tensor):
        raise TypeError(f"Key points should be torch tensor but is {type(key_points)}.")

    # a scalar tensor has no last dimension that could be read below
    if key_points.ndim < 1:
        raise ValueError(f"Key points should have at least one dimension, but shape is {key_points.shape}.")

    if joint_dim is None and not 2 <= key_points.shape[-1] <= 3:
        raise ValueError(
            f"By default, the key points should be two- or three-dimensional, "
            f"but they have a shape of {key_points.shape[-1]}"
        )
    if joint_dim is not None and key_points.shape[-1] != joint_dim:
        raise ValueError(f"The dimensionality of the joints should be {joint_dim} but is {key_points.shape[-1]}.")

    if nof_joints is not None:
        # the joints are expected in the second-to-last dimension, which has to exist before it is read
        if key_points.ndim < 2:
            raise ValueError(
                f"The number of joints should be {nof_joints} but the key points have no joint dimension. "
                f"Shape: {key_points.shape}."
            )
        if key_points.shape[-2] != nof_joints:
            raise ValueError(f"The number of joints should be {nof_joints} but is {key_points.shape[-2]}.")

    if dims is not None:
        key_points = validate_dimensions(tensor=key_points, dims=dims, length=length)
    elif length is not None and len(key_points) != length:
        raise ValidationException(f"Key-point length is expected to be {length} but got {len(key_points)}")

    return key_points
[docs]
def validate_value(value: object, data: object, validation: Union[str, type, None]) -> bool:
    """Check a single value against a given predefined validation, possibly given some additional data.

    Args:
        value: The value to validate.
        data: Possibly additional data needed for validation, is ignored otherwise.
        validation: The name of the validation to perform.
            A type validates whether `value` is an instance of that type,
            and `None` validates whether `value` is `None`.

    Returns:
        bool: Whether the given `value` is valid given the `validation` and possibly more `data`.

    Raises:
        KeyError: If the given `validation` does not exist.
        ValidationException: If the validation itself raised an exception, e.g., due to unsuitable `data`.
    """
    # shortcuts which do not need a lookup in the registry of validations
    if isinstance(validation, type):
        return isinstance(value, validation)
    if validation is None:
        return value is None

    if validation not in VALIDATIONS:
        raise KeyError(f"Validation '{validation}' does not exist.")

    try:
        return VALIDATIONS[validation](value, data)
    # The validations are arbitrary callables, therefore every failure is deliberately caught
    # and re-raised as a single exception type that keeps the original cause attached.
    except Exception as e:
        raise ValidationException(
            f"Could not validate value {value} with data {data} for validation {validation}"
        ) from e