dgs.utils.state.State

class dgs.utils.state.State(*args, bbox: torchvision.tv_tensors.BoundingBoxes, validate: bool = True, **kwargs)[source]

Class for storing one or multiple samples of data as a ‘State’.

Batch Size

Even if the batch size of a State is 1, or even zero, the dimension containing the batch size must always be present.

Validation

By default, this object validates all new inputs. If you validate elsewhere, use an existing dataset, or you don’t want validation for performance reasons, validation can be turned off.

Additional Values

The State may be given additional values during initialization, or at any time later using the respective setters or the get_item call. Additionally, the object can compute or load further values on demand.

All args and keyword args can be accessed through the State's properties. Additionally, the underlying dict structure ('self.data') can be used directly, but this bypasses both validation and the on-the-fly computation of additional values, so make sure to handle those yourself if needed.
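The difference between validated property access and raw `self.data` access can be sketched with a toy stand-in. This is a hypothetical simplification, not the actual State class; `MiniState` and its `person_id` validation rule are illustrative only:

```python
from collections import UserDict


class MiniState(UserDict):
    """Toy analogue of State: property access validates on write,
    while raw ``self.data`` access bypasses validation entirely."""

    @property
    def person_id(self):
        return self.data["person_id"]

    @person_id.setter
    def person_id(self, value):
        # the property setter validates its input before storing it
        if not all(isinstance(v, int) and v >= 0 for v in value):
            raise ValueError("person ids must be non-negative ints")
        self.data["person_id"] = value


s = MiniState()
s.person_id = [1, 2, 3]       # validated write via the property
s.data["extra"] = "anything"  # raw dict write, no validation performed
```

Writing `s.data["person_id"] = [-1]` would likewise succeed silently, which is exactly why the raw-dict path should only be used when validation was done elsewhere.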

keypoints (torch.Tensor)

The key points for this bounding box as torch tensor in global coordinates.

Shape [B x J x 2|3]

filepath (FilePaths)

The respective filepath(s) of every image.

Length B.

person_id (torch.Tensor)

The person id, only required for training and validation.

Shape [B].

class_id (torch.Tensor)

The class id, only required for training and validation.

Shape [B].

device (Device)

The torch device to use. If the device is not given, the device of bbox is used as the default.

heatmap (torch.Tensor)

The heatmap of this bounding box. Currently not used.

Shape [B x J x h x w].

image (Images)

A list containing the original image(s) as tv_tensors.Image objects.

A list of length B containing images of shape [1 x C x H x W].

image_crop (Image)

The content of the original image cropped using the bbox.

Shape [B x C x h x w]

joint_weight (torch.Tensor)

Some kind of joint- or key-point confidence. E.g., the joint confidence score (JCS) of AlphaPose or the joint visibility of PoseTrack21.

Shape [B x J x 1]

keypoints_local (torch.Tensor)

The key points for this bounding box as torch tensor in local coordinates.

Shape [B x J x 2|3]

Parameters:
  • bbox (tv_tensors.BoundingBoxes) – The bounding box(es) as a torchvision BoundingBoxes object in global coordinates. Shape [B x 4].

  • kwargs – Additional keyword arguments as shown in the 'Additional Values' section.

Methods

__init__(*args, bbox: torchvision.tv_tensors.BoundingBoxes, validate: bool = True, **kwargs) → None[source]
cast_joint_weight(dtype: torch.dtype = torch.float32, decimals: int = 0, overwrite: bool = False) → torch.Tensor[source]

Cast and return the joint weight as tensor.

The weight might have an arbitrary tensor type; this function allows getting one specific variant.

E.g., the visibility might be a boolean value or a model certainty.

Note

Keep in mind that torch.round() is not differentiable and therefore does not allow backpropagation through it. See https://discuss.pytorch.org/t/torch-round-gradient/28628/4 for more information.

Parameters:
  • dtype – The new torch dtype of the tensor. Default torch.float32.

  • decimals – Number of decimals to round floats to before type casting. Default 0 (round to integer). When decimals is set to -1 (minus one), the values are type-cast without any rounding. Keep in mind that casting compresses information: e.g., when casting from float to bool, a confidence of 0.4 casts directly to True, whereas rounding it first yields False.

  • overwrite – Whether self.joint_weight will be overwritten or not.

Returns:

A type-cast version of the tensor.

If overwrite is True, the returned tensor will be the same as self.joint_weight, including the computational graph.

If overwrite is False, the returned tensor will be a detached and cloned instance of self.joint_weight.

clean(keys: list[str] | str | None = None) → State[source]

Given a state, remove one or more keys to free up memory.

Parameters:

keys – The name of the keys to remove. If a key is not present in self.data, the key is ignored. If keys is None, the default keys ["image", "image_crop"] are removed. If keys is “all”, all keys that contain tensors are removed except for the bounding box.
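The default and missing-key behaviour can be sketched on a plain dict. This is a hypothetical simplification that omits the "all" branch and any tensor handling:

```python
def clean(data: dict, keys=None) -> dict:
    """Sketch of State.clean() key handling on a plain dict."""
    if keys is None:
        keys = ["image", "image_crop"]  # the documented default keys
    elif isinstance(keys, str):
        keys = [keys]
    for key in keys:
        data.pop(key, None)  # keys not present in the dict are ignored
    return data


state = {"bbox": "...", "image": "...", "image_crop": "...", "heatmap": "..."}
clean(state)             # removes "image" and "image_crop"
clean(state, "missing")  # absent key, silently ignored
```

The bounding box is never removed, which matches the rule that even keys="all" leaves the bbox untouched.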

clear() → None.  Remove all items from D.
copy() → State[source]

Obtain a copy of this state. No validation is performed; either it was done already or it is not wanted.

draw(save_path: str, show_kp: bool = True, show_skeleton: bool = True, show_bbox: bool = True, **kwargs) → None

Draw the bboxes and key points of this State on the first image.

This method uses torchvision to draw the information of this State on the first image in self.image. The drawing of key points, the respective connectivity / skeleton, and the bounding boxes can be disabled. Additionally, many keyword arguments can be set, see the docstring for show_image_with_additional() for more information.

Notes

In the case that B is 0, this method can still draw an empty image if an image or filepath is set. This only works if validation is disabled. The PoseTrack21_Image dataset uses this trick to draw images that aren't annotated.

draw_individually(save_path: str | tuple[str, ...], **kwargs) → None

Split the state and draw the detections of the image(s) independently.

Parameters:

save_path – Directory path to save the images to.

extract(idx: int) → State[source]

Extract the i-th State from a batch B of states.

Parameters:

idx – The index of the State to retrieve. It is expected that \(-B \le idx < B\).

Returns:

The extracted State.
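The index convention matches regular Python negative indexing. A minimal sketch on a plain list of placeholder states (the range check mirrors \(-B \le idx < B\)):

```python
B = 4
batch = ["s0", "s1", "s2", "s3"]  # stand-ins for B single States


def extract(idx: int) -> str:
    """Sketch of State.extract() index handling."""
    # valid range: -B <= idx < B, exactly like regular Python indexing
    if not (-B <= idx < B):
        raise IndexError(f"index {idx} out of range for batch size {B}")
    return batch[idx]


extract(0)   # -> "s0", the first State in the batch
extract(-1)  # -> "s3", the last State in the batch
```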

classmethod fromkeys(iterable, value=None)
get(k[, d]) → D[k] if k in D, else d.  d defaults to None.
items() → a set-like object providing a view on D's items
keypoints_and_weights_from_paths(paths: tuple[str, ...], save_weights: bool = True) → torch.Tensor[source]

Given a tuple of paths, load the (local) key points and weights from these paths. This may change self.joint_weight, but it changes neither self.keypoints nor self.keypoints_local.

Parameters:
  • paths – A tuple of paths to the .pt files containing the key-points and weights.

  • save_weights – Whether to save the weights if they were provided.

Returns:

The (local) key-points as Tensor.

Raises:
  • ValueError – If the number of paths does not match the batch size.

  • FileExistsError – If one of the paths does not exist.

keys() → a set-like object providing a view on D's keys
load_image(store: bool = False) → list[torchvision.tv_tensors.Image | torch.Tensor][source]

Load the images using the filepaths of this object. Does nothing if the images are already present.

load_image_crop(store: bool = False, **kwargs) → torchvision.tv_tensors.Image | torch.Tensor[source]

Load the image crops using the crop_paths of this object. Does nothing if the crops are already present.

Keyword Arguments:

crop_size – The size of the image crops. Default DEF_VAL.images.crop_size.

pop(k[, d]) → v, remove specified key and return the corresponding value.

If key is not found, d is returned if given, otherwise KeyError is raised.

popitem() → (k, v), remove and return some (key, value) pair

as a 2-tuple; but raise KeyError if D is empty.

setdefault(k[, d]) → D.get(k,d), also set D[k]=d if k not in D
split() → list[State][source]

Given a batched State object, split it into a list of single State objects.

to(*args, **kwargs) → State[source]

Override torch.Tensor.to() for the whole object.

update([E, ]**F) → None.  Update D from mapping/iterable E and F.

If E is present and has a .keys() method, does: for k in E: D[k] = E[k]
If E is present and lacks a .keys() method, does: for (k, v) in E: D[k] = v
In either case, this is followed by: for k, v in F.items(): D[k] = v
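These are the standard dict.update() semantics, which the State inherits and which can be demonstrated directly:

```python
d = {"a": 1}

d.update({"b": 2})    # E has .keys():   for k in E: D[k] = E[k]
d.update([("c", 3)])  # E lacks .keys(): for (k, v) in E: D[k] = v
d.update(a=10)        # finally:         for k, v in F.items(): D[k] = v

# d is now {"a": 10, "b": 2, "c": 3}
```

Note that keyword arguments (F) are applied last, so they overwrite any key also supplied via E.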

values() → an object providing a view on D's values

Attributes

B

Get the batch size.

J

Get the number of joints in every skeleton.

bbox

Get this State's bounding-box.

bbox_relative

Get the bounding box coordinates in relation to the width and height of the full image.
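Assuming boxes in XYWH format, relative coordinates are each component divided by the image width or height. A minimal numeric sketch with hypothetical values (not the actual implementation):

```python
# hypothetical example: one XYWH box on a 1920x1080 image
W, H = 1920, 1080
x, y, w, h = 960.0, 540.0, 192.0, 108.0

# normalise every component by the corresponding image dimension
rel = (x / W, y / H, w / W, h / H)
# rel == (0.5, 0.5, 0.1, 0.1)
```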

class_id

Get the class-ID of the bounding-boxes.

crop_path

Get the path to the image crops.

device

Get the device of this State.

filepath

If the filepath data has a single entry, return the filepath as a string; otherwise return the list.

image

Get the original image(s) of this State.

image_crop

Get the image crop(s) of this State.

joint_dim

Get the dimensionality of the joints.

joint_weight

Get the weight of the joints.

keypoints

Get the key-points.

keypoints_local

Get the local key-points.

keypoints_relative

Get the global key points in coordinates relative to the full image size.

person_id

Get the ID of the respective person shown on the bounding-box.

track_id

Get the ID of the tracks associated with the respective bounding-boxes.

validate

Whether to validate the inputs into this state.

data

All the data in this state as a dict.