dgs.models.similarity.pose_similarity.ObjectKeypointSimilarity

class dgs.models.similarity.pose_similarity.ObjectKeypointSimilarity(*args: Any, **kwargs: Any)

Compute the object key-point similarity (OKS) between two batches of poses / States.
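For reference, the standard COCO definition of OKS between one predicted pose and one ground-truth pose with J key points is shown below. This class presumably follows it, with the per-key-point constants σ_j taken from OKS_SIGMAS for the configured format and the object scale s² taken from the ground-truth bounding-box area (see get_area). That mapping onto this class is an assumption, not confirmed by this page. Here d_j is the Euclidean distance between the predicted and ground-truth key point j, and v_j is the ground-truth visibility of key point j.

```latex
\mathrm{OKS} = \frac{\sum_{j=1}^{J} \exp\!\left(-\dfrac{d_j^{2}}{2\, s^{2} k_j^{2}}\right) \delta(v_j > 0)}{\sum_{j=1}^{J} \delta(v_j > 0)},
\qquad k_j = 2\sigma_j
```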

Params

format (str):

The key-point format, e.g., ‘coco’, ‘coco-whole’, …; it has to be one of OKS_SIGMAS.keys().
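The format parameter selects the per-key-point sigmas used by the OKS formula. The sketch below shows such a lookup with an up-front check against the allowed formats. OKS_SIGMAS here is a hypothetical stand-in for the library's mapping: only the 17 COCO body constants (as published in the COCO evaluation code) are filled in, and get_sigmas is a hypothetical helper, not part of this class.

```python
# Hypothetical stand-in for the library's OKS_SIGMAS mapping.
# Only the 17-key-point COCO constants are included here.
OKS_SIGMAS: dict[str, list[float]] = {
    "coco": [
        0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
        0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089,
    ],
}


def get_sigmas(fmt: str) -> list[float]:
    """Return the per-key-point sigmas for a format, rejecting unknown formats."""
    if fmt not in OKS_SIGMAS:
        raise ValueError(
            f"Unknown key-point format {fmt!r}, expected one of {sorted(OKS_SIGMAS)}"
        )
    return OKS_SIGMAS[fmt]


sigmas = get_sigmas("coco")  # 17 values, one per COCO key point
```

Failing early on an unknown format keeps a typo in the configuration from surfacing later as a shape mismatch between the sigmas and the key points.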

Optional Params

keypoint_dim (int, optional):

The dimensionality of the key points, i.e., whether 2D or 3D key points are expected. Default: DEF_VAL.similarity.oks.kp_dim.

__init__(config: dict[str, any], path: list[str])

Methods

configure_torch_module(module[, train])

Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.

forward(data, target)

Compute the object key-point similarity between the ground-truth labels and the detected key points.

get_area(ds)

Given a State, compute the area of the bounding box.

get_data(ds)

Given a State, compute the detected / predicted key points with shape [B1 x J x 2|3] and the areas of the respective ground-truth bounding boxes with shape [B1].

get_target(ds)

Given a State, obtain the ground-truth key points and the key-point visibility.

get_train_data(ds)

A custom function to get special data for training purposes.

terminate()

Terminate this module and all of its submodules.

validate_params(validations[, attrib_name])

Given per-key validations, validate this module's parameters.
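Read together, the methods suggest that forward compares the output of get_data (key points [B1 x J x 2|3] and areas [B1]) against the output of get_target (ground-truth key points and visibility). The pure-Python sketch below assumes forward returns a [B1 x B2] matrix of pairwise OKS values computed with the COCO-style formula. The function name pairwise_oks, its argument layout, and returning 0.0 when no key point is visible are all hypothetical choices; the real method presumably operates on torch tensors.

```python
import math


def pairwise_oks(
    pred: list[list[list[float]]],  # [B1 x J x D] detected key points, D = 2 or 3
    areas: list[float],             # [B1] bounding-box areas, used as s^2
    gt: list[list[list[float]]],    # [B2 x J x D] ground-truth key points
    vis: list[list[float]],         # [B2 x J] ground-truth visibility
    sigmas: list[float],            # [J] per-key-point sigmas
) -> list[list[float]]:
    """Return a [B1 x B2] matrix of OKS values (hypothetical sketch)."""
    eps = 1e-12  # guards against division by a zero area
    result = []
    for pose, area in zip(pred, areas):
        row = []
        for target, v in zip(gt, vis):
            total, n_visible = 0.0, 0
            for p, t, v_j, sigma in zip(pose, target, v, sigmas):
                if v_j <= 0:
                    continue  # invisible key points do not contribute
                # Squared Euclidean distance; works for 2D and 3D key points.
                d2 = sum((a - b) ** 2 for a, b in zip(p, t))
                k2 = (2.0 * sigma) ** 2  # k_j = 2 * sigma_j
                total += math.exp(-d2 / (2.0 * (area + eps) * k2))
                n_visible += 1
            # Assumed convention: no visible key points yields a similarity of 0.
            row.append(total / n_visible if n_visible > 0 else 0.0)
        result.append(row)
    return result
```

For example, pairwise_oks([[[0.0, 0.0]]], [100.0], [[[1.0, 0.0]]], [[1]], [0.1]) gives a 1 x 1 matrix whose single entry is exp(-1/8), since d² = 1 and 2 · s² · k² = 2 · 100 · 0.04 = 8.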

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module, usable in file paths, by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

softmax