dgs.utils.state.State¶
- class dgs.utils.state.State(*args, bbox: torchvision.tv_tensors.BoundingBoxes, validate: bool = True, **kwargs)[source]¶
Class for storing one or multiple samples of data as a ‘State’.
Batch Size¶
Even if the batch size of a State is 1, or even 0, the dimension containing the batch size must always be present.
Validation¶
By default, this object validates all new inputs. If you validate elsewhere, use an existing dataset, or want to skip validation for performance reasons, validation can be turned off.
Additional Values¶
The State can be given additional values during initialization, or at any time using the given setters or the item-access call. Additionally, the object can compute or load further values.
All args and keyword args can be accessed using the State's properties. The underlying dict structure ('self.data') can also be used directly, but this bypasses validation and the on-the-fly computation of additional values, so make sure to handle those yourself if needed.
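To illustrate the two access paths described above, here is a minimal, hypothetical sketch (class name and validation rule invented for illustration; the real State stores torch tensors and performs far richer validation). Property access can validate and compute lazily, while raw dict access bypasses both:

```python
# Hypothetical sketch of "property access vs. raw self.data access".
class MiniState:
    def __init__(self, validate: bool = True, **kwargs):
        self.validate = validate
        self.data = dict(kwargs)  # the underlying dict structure

    @property
    def keypoints(self):
        # Property access: may validate or compute values on the fly.
        if "keypoints" not in self.data:
            raise KeyError("keypoints were never set nor computed")
        value = self.data["keypoints"]
        if self.validate and not isinstance(value, list):
            raise TypeError("keypoints must be list-like in this sketch")
        return value


s = MiniState(keypoints=[[0.1, 0.2], [0.3, 0.4]])
assert s.keypoints == s.data["keypoints"]  # same values, different guarantees
```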
- keypoints (torch.Tensor): The key points for this bounding box as a torch tensor in global coordinates.
  Shape: [B x J x 2|3]
- filepath (FilePaths): The respective filepath(s) of every image.
  Length: B
- person_id (torch.Tensor): The person ID, only required for training and validation.
  Shape: [B]
- class_id (torch.Tensor): The class ID, only required for training and validation.
  Shape: [B]
- device (Device): The torch device to use. If the device is not given, the device of bbox is used as the default.
- heatmap (torch.Tensor): The heatmap of this bounding box. Currently not used.
  Shape: [B x J x h x w]
- image (Images): A list of length B containing the original image(s) as tv_tensors.Image objects, each of shape [1 x C x H x W].
- image_crop (Image): The content of the original image cropped using the bbox.
  Shape: [B x C x h x w]
- joint_weight (torch.Tensor): Some kind of joint or key-point confidence, e.g., the joint confidence score (JCS) of AlphaPose or the joint visibility of PoseTrack21.
  Shape: [B x J x 1]
- keypoints_local (torch.Tensor): The key points for this bounding box as a torch tensor in local coordinates.
  Shape: [B x J x 2|3]
- param bbox:
  The bounding box(es) as torchvision bounding boxes in global coordinates.
  Shape: [B x 4]
- type bbox:
  tv_tensors.BoundingBoxes
- param kwargs:
  Additional keyword arguments as shown in the 'Additional Values' section.
Methods
- __init__(*args, bbox: torchvision.tv_tensors.BoundingBoxes, validate: bool = True, **kwargs) None [source]¶
- cast_joint_weight(dtype: torch.dtype = torch.float32, decimals: int = 0, overwrite: bool = False) torch.Tensor [source]¶
Cast and return the joint weight as tensor.
The weight might have an arbitrary dtype; this function allows getting one specific variant.
E.g., the visibility might be a boolean value or a model certainty.
Note
Keep in mind that torch.round() is not differentiable and does not allow backpropagation through it. See https://discuss.pytorch.org/t/torch-round-gradient/28628/4 for more information.
- Parameters:
dtype – The new torch dtype of the tensor. Default torch.float32.
decimals – Number of decimals to round floats to before type casting. Default: 0 (round to integer). When decimals is set to -1 (minus one), there is only type casting and no rounding at all. Keep in mind that casting can compress information; e.g., when casting from float to bool or int, a value like 0.9 may not end up as True (or 1) unless it is rounded first.
overwrite – Whether self.joint_weight will be overwritten or not.
- Returns:
A type-cast version of the tensor.
If overwrite is True, the returned tensor will be the same as self.joint_weight, including the computational graph.
If overwrite is False, the returned tensor will be a detached and cloned instance of self.joint_weight.
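The interplay between decimals and casting can be sketched in plain Python (a hypothetical, simplified stand-in; the real method operates on torch tensors and dtypes):

```python
# Hypothetical sketch of the rounding-then-casting behaviour of
# cast_joint_weight: round to `decimals` places before casting to int;
# decimals == -1 means cast only, no rounding (so 0.9 truncates to 0).
def cast_weight_sketch(weights, decimals: int = 0):
    if decimals == -1:
        return [int(w) for w in weights]
    return [int(round(w, decimals)) for w in weights]


assert cast_weight_sketch([0.9, 0.4]) == [1, 0]      # rounded first
assert cast_weight_sketch([0.9, 0.4], -1) == [0, 0]  # cast only, truncates
```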
- clean(keys: list[str] | str | None = None) State [source]¶
Given a state, remove one or more keys to free up memory.
- Parameters:
keys – The names of the keys to remove. If a key is not present in self.data, it is ignored. If keys is None, the default keys ["image", "image_crop"] are removed. If keys is "all", all keys that contain tensors are removed, except for the bounding box.
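The key-removal logic above can be sketched on a plain dict (a hypothetical simplification; the real method checks which values are tensors before removing them under "all"):

```python
# Hypothetical sketch of State.clean() semantics on a plain dict.
def clean_sketch(data: dict, keys=None) -> dict:
    if keys is None:
        keys = ["image", "image_crop"]            # documented default
    elif keys == "all":
        keys = [k for k in data if k != "bbox"]   # keep the bounding box
    elif isinstance(keys, str):
        keys = [keys]
    for k in keys:
        data.pop(k, None)  # keys missing from data are ignored
    return data


state = {"bbox": "B", "image": "I", "image_crop": "C", "keypoints": "K"}
assert clean_sketch(dict(state)) == {"bbox": "B", "keypoints": "K"}
assert clean_sketch(dict(state), "all") == {"bbox": "B"}
```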
- clear() None. Remove all items from D. ¶
- copy() State [source]¶
Obtain a copy of this State. No validation is performed; either it was done already or it is not wanted.
- draw(save_path: str, show_kp: bool = True, show_skeleton: bool = True, show_bbox: bool = True, **kwargs) None ¶
Draw the bboxes and key points of this State on the first image.
This method uses torchvision to draw the information of this State on the first image in self.image. The drawing of key points, the respective connectivity / skeleton, and the bounding boxes can each be disabled. Additionally, many keyword arguments can be set; see the docstring of show_image_with_additional() for more information.
Notes
In the case that B is 0, this method can still draw an empty image if an image or filepath is set. This works iff validation is False. The PoseTrack21_Image dataset uses this trick to draw the images that aren't annotated.
- draw_individually(save_path: str | tuple[str, ...], **kwargs) None ¶
Split the state and draw the detections of the image(s) independently.
- Parameters:
save_path – Directory path to save the images to.
- extract(idx: int) State [source]¶
Extract the i-th State from a batch B of states.
- Parameters:
idx – The index of the State to retrieve. It is expected that \(-B \leq idx < B\).
- Returns:
The extracted State.
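The index convention \(-B \leq idx < B\) follows Python-style negative indexing; a hypothetical sketch of how such an index resolves to a batch position (the function name and error message are invented for illustration):

```python
# Hypothetical sketch of the extract() index convention: -B <= idx < B,
# with negative indices counting from the end of the batch.
def resolve_index(idx: int, batch_size: int) -> int:
    if not (-batch_size <= idx < batch_size):
        raise IndexError(f"idx {idx} out of range for batch size {batch_size}")
    return idx % batch_size  # maps -1 to batch_size - 1, etc.


assert resolve_index(0, 3) == 0
assert resolve_index(-1, 3) == 2  # last element of the batch
```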
- classmethod fromkeys(iterable, value=None)¶
- get(k[, d]) D[k] if k in D, else d. d defaults to None. ¶
- items() a set-like object providing a view on D's items ¶
- keypoints_and_weights_from_paths(paths: tuple[str, ...], save_weights: bool = True) torch.Tensor [source]¶
Given a tuple of paths, load the (local) key-points and weights from these paths. Does change self.joint_weight, but does not change self.keypoints or self.keypoints_local.
- Parameters:
paths – A tuple of paths to the .pt files containing the key-points and weights.
save_weights – Whether to save the weights if they were provided.
- Returns:
The (local) key-points as Tensor.
- Raises:
ValueError – If the number of paths does not match the batch size.
FileExistsError – If one of the paths does not exist.
- keys() a set-like object providing a view on D's keys ¶
- load_image(store: bool = False) list[torchvision.tv_tensors.Image | torch.Tensor] [source]¶
Load the images using the filepaths of this object. Does nothing if the images are already present.
- load_image_crop(store: bool = False, **kwargs) torchvision.tv_tensors.Image | torch.Tensor [source]¶
Load the image crops using the crop_paths of this object. Does nothing if the crops are already present.
- Keyword Arguments:
crop_size – The size of the image crops. Default: DEF_VAL.images.crop_size.
- pop(k[, d]) v, remove specified key and return the corresponding value. ¶
If key is not found, d is returned if given, otherwise KeyError is raised.
- popitem() (k, v), remove and return some (key, value) pair ¶
as a 2-tuple; but raise KeyError if D is empty.
- setdefault(k[, d]) D.get(k,d), also set D[k]=d if k not in D ¶
- split() list[State] [source]¶
Given a batched State object, split it into a list of single State objects.
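Splitting a batched mapping into per-sample mappings can be sketched on plain lists (a hypothetical simplification; the real method slices torch tensors and returns State objects):

```python
# Hypothetical sketch of State.split(): one output dict per batch entry,
# each holding the i-th element of every batched value.
def split_sketch(data: dict, batch_size: int) -> list[dict]:
    return [{key: values[i] for key, values in data.items()}
            for i in range(batch_size)]


batched = {"bbox": [[0, 0, 1, 1], [2, 2, 3, 3]], "person_id": [7, 9]}
singles = split_sketch(batched, batch_size=2)
assert singles[1] == {"bbox": [2, 2, 3, 3], "person_id": 9}
```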
- update([E, ]**F) None. Update D from mapping/iterable E and F. ¶
If E present and has a .keys() method, does: for k in E: D[k] = E[k] If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v In either case, this is followed by: for k, v in F.items(): D[k] = v
- values() an object providing a view on D's values ¶
Attributes
Get the batch size.
Get the number of joints in every skeleton.
Get this State's bounding-box.
Get the bounding box coordinates in relation to the width and height of the full image.
Get the class-ID of the bounding-boxes.
Get the path to the image crops.
Get the device of this State.
If the filepath data has a single entry, return the filepath as a string; otherwise return the list.
Get the original image(s) of this State.
Get the image crop(s) of this State.
Get the dimensionality of the joints.
Get the weight of the joints.
Get the key-points.
Get the local key-points.
Get the global key points in coordinates relative to the full image size.
Get the ID of the person shown within each bounding-box.
Get the ID of the tracks associated with the respective bounding-boxes.
Whether to validate the inputs into this state.
All the data in this state as a dict.