dgs.models.similarity.pose_similarity.ObjectKeypointSimilarity¶
- class dgs.models.similarity.pose_similarity.ObjectKeypointSimilarity(*args: Any, **kwargs: Any)[source]¶
Compute the object key-point similarity (OKS) between two batches of poses / States.
- Params:
format (str) – The key-point format, e.g., ‘coco’, ‘coco-whole’, …; has to be one of OKS_SIGMAS.keys().
- Optional Params:
keypoint_dim (int, optional) – The dimensionality of the key points, i.e., whether 2D or 3D data is expected. Default: DEF_VAL.similarity.oks.kp_dim.
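For orientation, the snippet below sketches what the corresponding parameter block could look like. It is a minimal sketch assuming a plain dict-style configuration; the dict name and surrounding wiring are illustrative, not the library’s exact API.
>>> # Hypothetical parameter block for ObjectKeypointSimilarity (keys follow the Params list above).
>>> oks_params = {
...     "format": "coco",   # must be one of OKS_SIGMAS.keys()
...     "keypoint_dim": 2,  # optional; defaults to DEF_VAL.similarity.oks.kp_dim
... }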
Methods
- configure_torch_module(module: torch.nn.Module, train: bool | None = None) torch.nn.Module ¶
Set compute mode and send model to the device or multiple parallel devices if applicable.
- Parameters:
module – The torch module instance to configure.
train – Whether to train or eval this module, defaults to the value set in the base config.
- Returns:
The module on the specified device or in parallel.
- forward(data: State, target: State) torch.Tensor [source]¶
Compute the object key-point similarity between a ground truth label and detected key points.
For every detection there has to be a corresponding ground-truth label, i.e., the batch sizes of data and target have to match.
Notes
Compute the key-point similarity \(\mathtt{ks}_i\) for every joint between every detection and the respective ground-truth annotation.
\[\mathtt{ks}_i = \exp\left(-\dfrac{d_i^2}{2 s^2 k_i^2}\right)\]
The key-point similarity \(\mathtt{OKS}\) is then computed as the weighted sum using the key-point visibilities as weights.
\[\mathtt{OKS} = \dfrac{\sum_i \mathtt{ks}_i \cdot \delta (v_i > 0)}{\sum_i \delta (v_i > 0)}\]
With:
- \(d_i\) the Euclidean distance between the ground-truth and the detected key point
- \(k_i\) the per-key-point constant, computed as \(k_i = 2 \cdot \sigma_i\)
- \(v_i\) the visibility of the key point, with
0 = unlabeled
1 = labeled but not visible
2 = labeled and visible
- \(s\) the scale of the ground-truth object, with \(s^2\) being the object’s segmented area
A worked numeric sketch of these equations is given after this method’s description.
- Parameters:
data – The State containing the detected / predicted key points (see get_data()).
target – The State containing the ground-truth key points and their visibilities (see get_target()).
- Returns:
A (Float)Tensor of shape [N x T] with values in [0..1]. If requested, the softmax is computed along the -1 dimension, resulting in probability distributions for each value of the input data.
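As a worked illustration of the equations in the Notes above (not the module’s internal implementation), the following self-contained sketch computes the OKS for a single detection / ground-truth pair; the function and variable names are hypothetical.
>>> import torch
>>>
>>> def oks_sketch(pred_kp, gt_kp, vis, area, sigmas) -> float:
...     """Illustrative OKS for one detection / ground-truth pair.
...
...     pred_kp, gt_kp: [J x 2] key points, vis: [J] visibilities (0, 1, or 2),
...     area: scalar segmented area (s^2), sigmas: [J] per-joint constants (sigma_i).
...     """
...     d2 = ((pred_kp - gt_kp) ** 2).sum(dim=-1)   # squared Euclidean distances d_i^2
...     k2 = (2.0 * sigmas) ** 2                    # k_i^2 with k_i = 2 * sigma_i
...     ks = torch.exp(-d2 / (2.0 * area * k2))     # ks_i = exp(-d_i^2 / (2 s^2 k_i^2))
...     labeled = vis > 0                           # delta(v_i > 0): ignore unlabeled joints
...     return (ks[labeled].sum() / labeled.sum().clamp(min=1)).item()
>>>
>>> # Toy example with three joints; the third joint is unlabeled and therefore ignored.
>>> pred = torch.tensor([[10.0, 10.0], [20.0, 22.0], [31.0, 30.0]])
>>> gt = torch.tensor([[10.0, 11.0], [20.0, 20.0], [30.0, 30.0]])
>>> vis = torch.tensor([2, 2, 0])
>>> area = torch.tensor(100.0)                      # s^2, e.g. the segmented / box area
>>> sigmas = torch.tensor([0.025, 0.025, 0.025])
>>> oks_sketch(pred, gt, vis, area, sigmas)         # a value in [0..1]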
- get_data(ds: State) torch.Tensor [source]¶
Given a State, compute the detected / predicted key points with shape [B1 x J x 2|3] and the areas of the respective ground-truth bounding boxes with shape [B1].
- get_target(ds: State) tuple[torch.Tensor, torch.Tensor] [source]¶
Given a State, obtain the ground-truth key points and the key-point visibilities. Both are tensors; the key points are a FloatTensor of shape [B2 x J x 2|3] and the visibility is a BoolTensor of shape [B2 x J].
- get_train_data(ds: State) any ¶
A custom function to get special data for training purposes. If “train_key” is not given, uses the regular get_data() function of this module.
- terminate() None ¶
Terminate this module and all of its submodules.
If nothing has to be done, just pass. This is used for terminating parallel execution and threads in specific models.
- validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') None ¶
Given per-key validations, validate this module’s parameters.
Throws exceptions on invalid or nonexistent params.
- Parameters:
attrib_name – Name of the attribute to validate; should be “params”, and “config” only for the base class.
validations – Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.
The value for a validation can have multiple types:
- A lambda function or another type of callable
- A string as a reference to a predefined validation function with one argument
- None for existence
- A tuple with a string as a reference to a predefined validation function with one additional argument
It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation.
Lists and other iterables can be validated using “forall”, which runs the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.
Example
This example is an excerpt of the validation for the BaseModule configuration.
>>> validations = {
...     "device": [str, ("any", [("in", ["cuda", "cpu"]), ("instance", torch.device)])],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": (lambda value: value == 1),
... }
And within the class’s __init__() call:
>>> self.validate_params(validations)
- Raises:
InvalidParameterException – If one of the parameters is invalid.
ValidationException – If the validation list is invalid or contains an unknown validation.
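For this class specifically, a validations dictionary covering the required parameter from the Params list above could look roughly like the following. This is a hedged sketch (the optional keypoint_dim and its default handling are omitted), and the validations dgs actually applies may differ.
>>> oks_validations = {
...     "format": [str, ("in", list(OKS_SIGMAS.keys()))],
... }
>>> self.validate_params(oks_validations)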
Attributes
Get the device of this module.
Get whether this module is set to training-mode.
Get the name of the module.
Get the name of the module.
Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
Get the (floating point) precision used in multiple parts of this module.