dgs.utils.validation.validate_dimensions

dgs.utils.validation.validate_dimensions(tensor: torch.Tensor, dims: int, *_, length: int | None = None) → torch.Tensor

Given a tensor, make sure it has the correct number of dimensions.

Parameters:
  • tensor – Any torch.Tensor, or another object that can be converted to one.

  • dims – Number of dimensions the tensor should have.

  • length – The number of items (the batch size) the tensor should have. If None (the default), the length is not validated.

Returns:

A torch.Tensor with the correct number of dimensions.

Raises:
  • TypeError – If the tensor input is not a torch.Tensor and cannot be cast to one.

  • ValueError – If the tensor has more than dims dimensions and cannot be reduced to dims by squeezing.
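
Example:

The contract described above can be illustrated with a minimal sketch. This is not the dgs source. The name validate_dimensions_sketch is hypothetical, and the sketch assumes that surplus dimensions are removed by squeezing size-1 axes, that missing dimensions are added at the front, that length is compared against the first dimension after reshaping, and that a length mismatch raises ValueError. None of these details are confirmed by the documentation.

```python
from __future__ import annotations

import torch


def validate_dimensions_sketch(
    tensor, dims: int, *_, length: int | None = None
) -> torch.Tensor:
    """Hypothetical stand-in for validate_dimensions, written from its docs."""
    # Convert non-tensor inputs; report failures as TypeError, as documented.
    if not isinstance(tensor, torch.Tensor):
        try:
            tensor = torch.as_tensor(tensor)
        except (TypeError, ValueError, RuntimeError) as err:
            raise TypeError(
                f"Cannot convert {type(tensor).__name__} to a torch.Tensor."
            ) from err

    # Too many dimensions: drop all size-1 axes, then re-check (assumption).
    if tensor.ndim > dims:
        tensor = tensor.squeeze()
        if tensor.ndim > dims:
            raise ValueError(
                f"Expected {dims} dimensions, got shape {tuple(tensor.shape)}."
            )

    # Too few dimensions: prepend size-1 axes (assumption).
    while tensor.ndim < dims:
        tensor = tensor.unsqueeze(0)

    # Optional length check on the first dimension (assumed behaviour).
    if length is not None and (tensor.ndim == 0 or tensor.shape[0] != length):
        raise ValueError(f"Expected length {length}, got shape {tuple(tensor.shape)}.")

    return tensor


# A 1-D tensor is expanded to 2-D by prepending an axis.
expanded = validate_dimensions_sketch(torch.ones(3), 2)

# A 3-D tensor with size-1 axes is reduced and then padded back to 2-D.
reduced = validate_dimensions_sketch(torch.ones(1, 1, 4), 2)

# A nested list is converted to a tensor first.
converted = validate_dimensions_sketch([[1, 2], [3, 4]], 2, length=2)
```

Under these assumptions, a tensor of shape (2, 3, 4) cannot be brought down to two dimensions because it has no size-1 axes to squeeze, so the sketch raises ValueError for it. An input such as a plain string cannot be converted and raises TypeError.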