dgs.utils.validation.validate_heatmaps

dgs.utils.validation.validate_heatmaps(heatmaps: torch.Tensor | torchvision.tv_tensors.Mask, length: int | None = None, dims: int | None = 4, nof_joints: int | None = None) → torchvision.tv_tensors.Mask | torch.Tensor

Validate that a given tensor of heatmaps has the correct format and shape.

Parameters:
  • heatmaps – The heatmaps to validate, given as a tensor-like object.

  • length – The number of items (the batch size B) the tensor should have. The default None skips the length check.

  • dims – The number of dimensions the heatmaps should have. Use None to skip the dimension check. Defaults to four dimensions, with the heatmap shape [B x J x w x h].

  • nof_joints – The number of joints (J) the heatmaps should have. The default None skips the joint check.

Returns:

The validated heatmaps as a tensor of the correct type.

Return type:

Heatmap

Raises:
  • TypeError – If heatmaps is not a tensor and cannot be cast to one.

  • ValueError – If the shape of the heatmaps does not match the requested dims, length, or nof_joints.
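The parameter checks and exceptions above can be pictured with the following minimal sketch. It is not the dgs source: validate_heatmaps_sketch is a hypothetical name, it returns a plain torch.Tensor instead of a tv_tensors.Mask, and it assumes that the joint axis J sits third from the end ([... x J x w x h]) and that every failed shape check raises ValueError.

```python
import torch


def validate_heatmaps_sketch(heatmaps, length=None, dims=4, nof_joints=None):
    """Hypothetical sketch of the documented checks, not the dgs implementation."""
    # Cast tensor-like input (nested lists, NumPy arrays, ...) to a tensor.
    if not isinstance(heatmaps, torch.Tensor):
        try:
            heatmaps = torch.as_tensor(heatmaps)
        except (TypeError, ValueError, RuntimeError) as err:
            raise TypeError(f"heatmaps cannot be cast to a tensor: {err}") from err

    # dims=None skips the dimensionality check.
    if dims is not None and heatmaps.ndim != dims:
        raise ValueError(f"Expected {dims} dimensions, got {heatmaps.ndim}.")

    # length=None skips the batch-size check; B is the first axis.
    if length is not None and (heatmaps.ndim == 0 or heatmaps.shape[0] != length):
        raise ValueError(f"Expected length {length}, got shape {tuple(heatmaps.shape)}.")

    # nof_joints=None skips the joint check; J is ASSUMED to be the third-last axis.
    if nof_joints is not None and (heatmaps.ndim < 3 or heatmaps.shape[-3] != nof_joints):
        raise ValueError(f"Expected {nof_joints} joints, got shape {tuple(heatmaps.shape)}.")

    return heatmaps


if __name__ == "__main__":
    hm = validate_heatmaps_sketch(torch.zeros(2, 17, 64, 48), length=2, nof_joints=17)
    print(hm.shape)  # torch.Size([2, 17, 64, 48])
    try:
        validate_heatmaps_sketch(torch.zeros(17, 64, 48))  # only three dimensions
    except ValueError as err:
        print("ValueError:", err)
```

Passing dims=None, as in validate_heatmaps_sketch(torch.zeros(17, 64, 48), dims=None, nof_joints=17), accepts a single unbatched set of heatmaps while still checking the joint count.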