dgs.utils.state.State.cast_joint_weight
- State.cast_joint_weight(dtype: torch.dtype = torch.float32, decimals: int = 0, overwrite: bool = False) → torch.Tensor
Cast and return the joint weight as a tensor.
The joint weight might have an arbitrary tensor type; this function returns it as one specific dtype.
For example, the visibility might be stored as a boolean value or as a floating-point model certainty.
Note
Keep in mind that torch.round() is not meaningfully differentiable: PyTorch defines its gradient as zero, so no useful gradient flows back through the rounding step. See https://discuss.pytorch.org/t/torch-round-gradient/28628/4 for more information.
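The following minimal sketch, using plain PyTorch only (no dgs code), shows what the note means in practice: after rounding, the gradient reaching the input is zero, whereas an un-rounded path passes the gradient through unchanged.

```python
import torch

x = torch.tensor([0.4, 0.6, 1.7], requires_grad=True)
# Round to the nearest integer, then reduce to a scalar so backward() can run.
torch.round(x).sum().backward()
print(x.grad)  # tensor([0., 0., 0.])

y = torch.tensor([0.4, 0.6, 1.7], requires_grad=True)
# Same reduction without rounding: each element receives a gradient of 1.
y.sum().backward()
print(y.grad)  # tensor([1., 1., 1.])
```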
- Parameters:
dtype – The new torch dtype of the tensor. Default: torch.float32.
decimals – Number of decimals to round floating-point values to before type casting. Default: 0 (round to the nearest integer). If decimals is set to -1 (minus one), no rounding is applied and the tensor is only cast. But keep in mind that when information is compressed, e.g., when casting from float to an integer or boolean dtype, plain type casting might not give the intended result. In PyTorch, casting a float to an integer dtype truncates toward zero, so 0.9 becomes 0, while casting directly to bool maps every non-zero value, such as 0.1, to True.
overwrite – Whether to overwrite self.joint_weight with the cast tensor. Default: False.
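To illustrate how `decimals` interacts with the cast, here is a sketch of the round-then-cast step as described above. `cast_weight` is a hypothetical standalone helper written for this example, not the dgs implementation; it assumes that `decimals == -1` means "cast only" and that any other value is passed to `torch.round`.

```python
import torch


def cast_weight(
    weight: torch.Tensor,
    dtype: torch.dtype = torch.float32,
    decimals: int = 0,
) -> torch.Tensor:
    """Hypothetical stand-in for the documented round-then-cast behaviour.

    decimals == -1 skips rounding; any other value rounds floating-point
    input to that many decimals before casting.
    """
    if decimals != -1 and weight.is_floating_point():
        # Only floating-point values carry decimals worth rounding;
        # bool or integer input is cast directly.
        weight = torch.round(weight, decimals=decimals)
    return weight.to(dtype=dtype)


certainty = torch.tensor([0.0, 0.4, 0.9])
print(cast_weight(certainty, torch.bool))                # rounded first
print(cast_weight(certainty, torch.bool, decimals=-1))   # cast only
print(cast_weight(certainty, torch.int64, decimals=-1))  # truncates toward zero

visibility = torch.tensor([True, False])
print(cast_weight(visibility))  # bool -> float32, no rounding involved
```

With the default `decimals=0`, the certainty 0.4 rounds to 0 and becomes False, while a cast-only conversion to bool turns it into True; a cast-only conversion to an integer dtype truncates 0.9 to 0.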
- Returns:
A type-cast version of the joint weight.
If overwrite is True, self.joint_weight is replaced by the cast tensor, and the returned tensor is that same tensor, including its computational graph.
If overwrite is False, self.joint_weight is left unchanged, and the returned tensor is a detached clone of the cast joint weight.
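The difference between the two `overwrite` modes can be sketched with a toy class. `ToyState` is hypothetical and only mimics the attribute name `joint_weight`; it is not the dgs `State` class, and the rounding step is omitted so the example focuses on how the result relates to the stored tensor and the autograd graph.

```python
import torch


class ToyState:
    """Hypothetical container with a joint_weight attribute; not the dgs State class."""

    def __init__(self, joint_weight: torch.Tensor) -> None:
        self.joint_weight = joint_weight

    def cast_joint_weight(
        self, dtype: torch.dtype = torch.float32, overwrite: bool = False
    ) -> torch.Tensor:
        # Rounding is omitted here to keep the focus on the overwrite flag.
        cast = self.joint_weight.to(dtype=dtype)
        if overwrite:
            # Store the cast tensor and return it as-is, autograd graph included.
            self.joint_weight = cast
            return self.joint_weight
        # Leave self.joint_weight untouched; return an independent, graph-free copy.
        return cast.detach().clone()


raw = torch.tensor([0.2, 0.7], requires_grad=True)
state = ToyState(raw * 2)  # joint_weight now has a grad_fn

copy = state.cast_joint_weight(dtype=torch.float64)
print(copy.requires_grad, state.joint_weight.dtype)  # False torch.float32

shared = state.cast_joint_weight(dtype=torch.float64, overwrite=True)
print(shared is state.joint_weight, shared.requires_grad)  # True True
```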