dgs.models.combine.dynamic.DynamicAlphaCombine

class dgs.models.combine.dynamic.DynamicAlphaCombine(*args: Any, **kwargs: Any)

Use inputs and multiple alpha modules to weight the similarity matrices.

Notes

The models that compute the per-similarity alpha values have to be set manually after initialization.

Given N inputs to the alpha modules (e.g., the visual embeddings of N images, or N differently sized inputs such as the bbox, pose, and visual embedding of a single crop), compute the alpha weights for the similarity matrices. Then use the \(\alpha_i\) to compute the weighted sum of all similarity matrices \(S_i\) as \(S = \sum_{i=1}^{N} \alpha_i \cdot S_i\).

Each \(\alpha_i\) can be either a single float value in \([0, 1]\) or a float tensor of the same shape as \(S_i\), again with values in \([0, 1]\).
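The combination rule can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: `weighted_sum` is a hypothetical helper, and it only shows how scalar and element-wise \(\alpha_i\) fit the same formula through broadcasting.

```python
from __future__ import annotations

import torch


def weighted_sum(
    similarities: list[torch.Tensor], alphas: list[float | torch.Tensor]
) -> torch.Tensor:
    """Hypothetical helper computing S = sum_i alpha_i * S_i."""
    if len(similarities) != len(alphas):
        raise ValueError("Expected exactly one alpha per similarity matrix.")
    result = torch.zeros_like(similarities[0])
    for s_i, a_i in zip(similarities, alphas):
        # A float scales the whole matrix; a same-shaped tensor weights each entry.
        result = result + a_i * s_i
    return result


# Two [D x T] = [2 x 3] similarity matrices with values in [0, 1].
s_1 = torch.tensor([[1.0, 0.0, 0.5], [0.2, 0.8, 0.4]])
s_2 = torch.tensor([[0.0, 1.0, 0.5], [0.6, 0.2, 0.4]])

# Scalar weights.
combined = weighted_sum([s_1, s_2], [0.25, 0.75])

# Element-wise weights; at every entry the two alphas add up to 1.
a_1 = torch.full_like(s_1, 0.5)
a_1[0, 0] = 1.0
combined_elementwise = weighted_sum([s_1, s_2], [a_1, 1.0 - a_1])
```

When the alphas sum to 1 at every entry, S is a convex combination of values in [0, 1] and therefore also stays in [0, 1].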

Params:

alpha_modules (list[NodePath]) – A list containing paths to multiple BaseAlphaModule’s.

Methods

__init__(*args, **kwargs)
configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module

Set the compute mode (train or eval) and move the module to the device, or to multiple parallel devices if applicable.

Parameters:
  • module – The torch module instance to configure.

  • train – Whether to put this module in train or eval mode. Defaults to the value set in the base config.

Returns:

The configured module, on the specified device or running in parallel on multiple devices.
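For the single-device case, the two steps described above map onto standard `torch.nn.Module` calls. The sketch below is a hypothetical stand-in, not the library's method: it requires `train` explicitly instead of falling back to a base config, and it omits the multi-device case entirely, since which parallel wrapper the library uses is not stated here.

```python
from __future__ import annotations

import torch
from torch import nn


def configure_module_sketch(
    module: nn.Module, train: bool, device: str | torch.device = "cpu"
) -> nn.Module:
    """Hypothetical single-device stand-in for configure_torch_module."""
    # train(True) enables training mode, train(False) is the same as eval();
    # the flag is propagated to every submodule.
    module.train(train)
    # Move parameters and buffers to the target device; nn.Module.to() works in place
    # and also returns the module, so the result can be chained or reassigned.
    return module.to(device)


model = configure_module_sketch(nn.Sequential(nn.Linear(4, 2), nn.Dropout(0.5)), train=False)
```

Setting the mode matters for layers such as `nn.Dropout` or `nn.BatchNorm*`, which behave differently during training and evaluation.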

forward(*tensors: torch.Tensor, s: State | None = None, **_kwargs) → torch.Tensor

The forward call of this module combines an arbitrary number of similarity matrices using an importance weight \(\alpha\).

\(\alpha_i\) describes how important the similarity \(S_i\) is. If the last layer of the alpha model is a softmax layer, all \(\alpha_i\) sum to 1 by construction. \(\alpha\) is computed using the respective BaseAlphaModule and the given State.

All tensors should be on the same device and should have the same shape.
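To illustrate why a softmax output layer makes the weights sum to 1, here is a hedged sketch. `HypotheticalAlphaHead` is not a library class: it assumes one feature vector of size F per similarity matrix, maps each to a single logit, and normalizes the S logits with `torch.softmax`. The library's BaseAlphaModule subclasses may be structured differently.

```python
import torch
from torch import nn


class HypotheticalAlphaHead(nn.Module):
    """Maps one feature vector per similarity to a logit, then normalizes over S."""

    def __init__(self, in_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: [S x F] -> logits: [S]
        logits = self.linear(features).squeeze(-1)
        # Softmax over the S similarities: positive weights that sum to 1.
        return torch.softmax(logits, dim=0)


torch.manual_seed(0)
S, D, T, F = 3, 4, 5, 8
similarities = torch.rand(S, D, T)  # S stacked [D x T] matrices, values in [0, 1)
alpha = HypotheticalAlphaHead(F)(torch.randn(S, F))  # shape [S]

# Weighted sum over the S axis: [S] -> [S x 1 x 1] broadcasts against [S x D x T].
combined = (alpha.view(S, 1, 1) * similarities).sum(dim=0)  # shape [D x T]
```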

Parameters:
  • tensors – A tuple of tensors describing the similarities between the detections and tracks. All S similarity matrices in this tuple should have values in [0, 1], have the same shape [D x T], and be on the same device. If tensors is a single tensor, it should have the shape [S x D x T]. S can be any positive number of similarity matrices, although only S > 1 is really meaningful.

  • s – A State containing the batched input data for the alpha models. The state should be on the same device as tensors.
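The two accepted layouts for `tensors` can be normalized to one [S x D x T] tensor as sketched below. The helper name `to_stacked` is hypothetical, and treating any single 3-D tensor as already stacked is an assumption based on the parameter description above.

```python
import torch


def to_stacked(*tensors: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: normalize forward() inputs to one [S x D x T] tensor."""
    if len(tensors) == 1 and tensors[0].ndim == 3:
        # A single 3-D tensor is assumed to already be stacked as [S x D x T].
        return tensors[0]
    # Otherwise every argument is one [D x T] matrix; torch.stack adds the S axis
    # and raises a RuntimeError if the shapes differ.
    return torch.stack(tensors, dim=0)


a, b = torch.rand(4, 5), torch.rand(4, 5)
from_tuple = to_stacked(a, b)  # [2 x 4 x 5]
from_single = to_stacked(torch.stack([a, b]))  # already [2 x 4 x 5]
```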

Returns:

The weighted similarity matrix as tensor of shape [D x T].

Return type:

torch.Tensor

Raises:
  • ValueError – If alpha or the matrices have invalid shapes.

  • RuntimeError – If one of the tensors is not on the correct device.

  • TypeError – If one of the tensors or one of the alpha inputs is not of type torch.Tensor.
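The documented exceptions suggest a set of input checks along the following lines. This is a sketch only: `check_inputs` is hypothetical, the order of the checks is an assumption, and so are the accepted alpha layouts ([S] for one weight per matrix, [S x D x T] for per-entry weights).

```python
from __future__ import annotations

import torch


def check_inputs(tensors: tuple, alpha: object, device: torch.device) -> None:
    """Hypothetical pre-checks mirroring the exceptions documented for forward()."""
    if len(tensors) == 0:
        raise ValueError("Expected at least one similarity matrix.")
    for t in (*tensors, alpha):
        if not isinstance(t, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(t).__name__}.")
        # Note: torch.device("cuda") and torch.device("cuda:0") compare unequal.
        if t.device != device:
            raise RuntimeError(f"Tensor is on {t.device}, expected {device}.")
    shape = tensors[0].shape
    if any(t.shape != shape for t in tensors):
        raise ValueError("All similarity matrices must share the shape [D x T].")
    # Assumed alpha layouts: one weight per matrix [S], or one per entry [S x D x T].
    n = len(tensors)
    if alpha.shape not in ((n,), (n, *shape)):
        raise ValueError(f"Invalid alpha shape {tuple(alpha.shape)}.")
```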

terminate() → None

Terminate this module and all of its submodules.

Modules with nothing to clean up can simply pass. It is used to terminate parallel execution and threads in specific models.
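A recursive terminate() can be sketched with `nn.Module.modules()`, which yields the module itself and every nested submodule exactly once. `TerminableSketch` is hypothetical: it owns a background thread purely so there is something to shut down, and the library's own modules may release entirely different resources.

```python
import threading

from torch import nn


class TerminableSketch(nn.Module):
    """Hypothetical module owning a background thread that terminate() shuts down."""

    def __init__(self) -> None:
        super().__init__()
        self._stop_event = threading.Event()
        # The worker just blocks until the stop event is set.
        self._worker = threading.Thread(target=self._stop_event.wait, daemon=True)
        self._worker.start()

    def _release(self) -> None:
        # Release only this module's own resources.
        self._stop_event.set()
        self._worker.join(timeout=1.0)

    def terminate(self) -> None:
        # modules() covers this module and all nested submodules, so each one
        # is released once; submodules of other types are skipped.
        for module in self.modules():
            if isinstance(module, TerminableSketch):
                module._release()


parent = TerminableSketch()
parent.child = TerminableSketch()  # registered as a submodule by nn.Module
parent.terminate()
```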

validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None

Validate this module’s parameters against the given per-key validations.

Raises an exception for invalid or nonexistent parameters.

Parameters:
  • attrib_name – Name of the attribute to validate. Should be “params”; “config” is used only by the base class.

  • validations

    Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.

    The value for the validation can have multiple types:
    • A lambda function or other type of callable

    • A string as reference to a predefined validation function with one argument

    • None, to check only that the parameter exists

    • A tuple with a string as reference to a predefined validation function with one additional argument

    • Nested validations are possible, but every nested validation then has to be a tuple or a tuple of tuples. For convenience, “any”, “all”, “not”, “eq”, “neq”, and “xor” are implemented. Their data can be a tuple containing other tuples or validations, or a single validation.

    • Lists and other iterables can be validated using “forall”, which runs the given validations on every item of the input. Its data can be a single validation or a tuple of (nested) validations.
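The rules listed above can be pictured as a small recursive checker. The sketch below is standalone and hypothetical, not the library's code: the names `check`, `validate_params`, `SIMPLE_CHECKS`, and `ARG_CHECKS`, the predefined checks "not_empty" and "number", the exception types, and the choice to treat every listed key as required are all assumptions. “eq” and “neq” are treated here as plain comparisons.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical predefined checks with one argument, referenced by name.
SIMPLE_CHECKS: dict[str, Callable[[Any], bool]] = {
    "not_empty": lambda value: len(value) > 0,
    "number": lambda value: isinstance(value, (int, float)) and not isinstance(value, bool),
}

# Hypothetical checks written as (name, data), taking one additional argument.
ARG_CHECKS: dict[str, Callable[[Any, Any], bool]] = {
    "in": lambda value, data: value in data,
    "instance": lambda value, data: isinstance(value, data),
    "eq": lambda value, data: value == data,
    "neq": lambda value, data: value != data,
}


def _as_list(data: Any) -> list:
    # A tuple starting with a string is one validation; other lists/tuples hold several.
    if isinstance(data, tuple) and data and isinstance(data[0], str):
        return [data]
    return list(data) if isinstance(data, (list, tuple)) else [data]


def check(value: Any, validation: Any) -> bool:
    if validation is None:  # existence only; presence is checked by the caller
        return True
    if isinstance(validation, type):  # must precede callable(): types are callable
        return isinstance(value, validation)
    if callable(validation):
        return bool(validation(value))
    if isinstance(validation, str):
        return SIMPLE_CHECKS[validation](value)
    name, data = validation
    if name == "forall":  # apply the nested validations to every item
        return all(check(item, sub) for item in value for sub in _as_list(data))
    if name in ("any", "all", "not", "xor"):
        results = [check(value, sub) for sub in _as_list(data)]
        if name == "any":
            return any(results)
        if name == "all":
            return all(results)
        if name == "not":
            return not all(results)
        return sum(results) == 1  # "xor": exactly one nested check holds
    return ARG_CHECKS[name](value, data)


def validate_params(params: dict[str, Any], validations: dict[str, list]) -> None:
    """Hypothetical standalone version: every listed key is treated as required."""
    for key, checks in validations.items():
        if key not in params:
            raise KeyError(f"Missing parameter {key!r}.")
        for validation in checks:
            if not check(params[key], validation):
                raise ValueError(f"Parameter {key!r} failed validation {validation!r}.")


validations = {
    "device": [str, ("any", [("in", ["cuda", "cpu"]), ("eq", "cuda:0")])],
    "sizes": [("forall", ["number", ("neq", 0)])],
    "name": [None, "not_empty"],
}
validate_params({"device": "cpu", "sizes": [1, 2.5], "name": "alpha"}, validations)
```

Checking `type` before `callable` is the one ordering that matters here, since classes such as `str` are themselves callable.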

Example

This example is an excerpt of the validation for the BaseModule configuration.

>>> validations = {
...     "device": [
...         str,
...         ("any",
...             [
...                 ("in", ["cuda", "cpu"]),
...                 ("instance", torch.device),
...             ]
...         ),
...     ],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": [lambda value: value == 1],
... }

And within the class’s __init__():

>>> self.validate_params(validations)

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get an escaped version of the module name that is usable in file paths, with spaces and underscores replaced.

precision

Get the (floating point) precision used in multiple parts of this module.

alpha_models

The models that compute the alpha values from the given inputs.

softmax