dgs.models.combine.dynamic.DynamicAlphaCombine¶
- class dgs.models.combine.dynamic.DynamicAlphaCombine(*args: Any, **kwargs: Any)[source]¶
Use inputs and multiple alpha modules to weight the similarity matrices.
Notes
The models for computing the per-similarity alpha values have to be set manually after initialization.
Given \(N\) inputs to the alpha module (e.g. the visual embeddings of \(N\) images, or \(N\) differently sized inputs like the bbox, pose, and visual embedding of a single crop), compute the alpha weights for the similarity matrices. Then use \(\alpha_i\) to compute the weighted sum of all the similarity matrices \(S_i\) as \(S = \sum_N \alpha_i \cdot S_i\). Every \(\alpha_i\) can either be a single float value in range \([0, 1]\) or a (float-) tensor of the same shape as \(S_i\), again with values in \([0, 1]\).
- Params:
alpha_modules (list[NodePath]) – A list containing paths to multiple BaseAlphaModule's.
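For illustration only, here is a minimal sketch of the weighted sum described above in plain PyTorch. The tensor contents and shapes are made up, and the actual module obtains its \(\alpha_i\) from the configured alpha modules instead of drawing them at random:
>>> import torch
>>> s1, s2, s3 = (torch.rand(2, 4) for _ in range(3))  # three S_i in [0, 1], each of shape [D x T]
>>> alpha = torch.softmax(torch.rand(3), dim=0)  # one alpha_i per matrix, summing to 1
>>> combined = sum(a_i * s_i for a_i, s_i in zip(alpha, (s1, s2, s3)))  # S = sum_i alpha_i * S_i, shape [D x T]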
Methods
- __init__(*args, **kwargs)¶
- configure_torch_module(module: torch.nn.Module, train: bool | None = None) → torch.nn.Module¶
Set compute mode and send model to the device or multiple parallel devices if applicable.
- Parameters:
module – The torch module instance to configure.
train – Whether to train or eval this module, defaults to the value set in the base config.
- Returns:
The module on the specified device or in parallel.
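A hedged usage sketch: self stands for any configured module derived from the base module, and the small alpha_head network is a made-up example of a submodule that should be moved to the configured device and set to evaluation mode:
>>> import torch
>>> alpha_head = torch.nn.Sequential(torch.nn.Linear(128, 1), torch.nn.Softmax(dim=-1))
>>> alpha_head = self.configure_torch_module(alpha_head, train=False)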
- forward(*tensors: torch.Tensor, s: State | None = None, **_kwargs) → torch.Tensor [source]¶
The forward call of this module combines an arbitrary number of similarity matrices using an importance weight \(\alpha\).
\(\alpha_i\) describes how important the similarity \(s_i\) is. The sum of all \(\alpha_i\) should be 1 by definition, given that the last layer is a softmax layer. \(\alpha\) is computed using the respective BaseAlphaModule and the given State. All tensors should be on the same device and should have the same shape.
- Parameters:
tensors – A tuple of tensors describing similarities between the detections and tracks. All S similarity matrices of this iterable should have values in range [0, 1], be of the same shape [D x T], and be on the same device. If tensors is a single tensor, it should have the shape [S x D x T]. S can be any number of similarity matrices greater than 0, even though only values greater than 1 really make sense.
s – A State containing the batched input data for the alpha models. The state should be on the same device as tensors.
- Returns:
The weighted similarity matrix as tensor of shape [D x T].
- Return type:
torch.Tensor
- Raises:
ValueError – If alpha or the matrices have invalid shapes.
RuntimeError – If one of the tensors is not on the correct device.
TypeError – If one of the tensors or one of the alpha inputs is not of type torch.Tensor.
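The alpha weights themselves come from the configured BaseAlphaModule's and the given State; as a rough sketch of the shape contract only ([S x D x T] in, [D x T] out), here with made-up scalar weights:
>>> import torch
>>> stacked = torch.rand(2, 3, 5)  # S = 2 similarity matrices for D = 3 detections and T = 5 tracks
>>> alpha = torch.softmax(torch.rand(2), dim=0)  # one weight per matrix, summing to 1
>>> weighted = torch.einsum("s,sdt->dt", alpha, stacked)  # weighted sum over the S dimension
>>> tuple(weighted.shape)
(3, 5)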
- terminate() → None [source]¶
Terminate this module and all of its submodules.
If nothing has to be done, just pass. This is used for terminating parallel execution and threads in specific models.
- validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') → None¶
Given per key validations, validate this module’s parameters.
Throws exceptions on invalid or nonexistent params.
- Parameters:
attrib_name – Name of the attribute to validate. Should be “params”, and “config” only for the base class.
validations – Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.
- The value for the validation can have multiple types:
- A lambda function or other type of callable
- A string as reference to a predefined validation function with one argument
- None for existence
- A tuple with a string as reference to a predefined validation function with one additional argument
It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation.
Lists and other iterables can be validated using “forall” running the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.
Example
This example is an excerpt of the validation for the BaseModule-configuration.
>>> validations = {
...     "device": [
...         str,
...         ("any", [("in", ["cuda", "cpu"]), ("instance", torch.device)]),
...     ],
...     "print_prio": [("in", PRINT_PRIORITY)],
...     "callable": (lambda value: value == 1),
... }
And within the class __init__() call:
>>> self.validate_params(validations)
- Raises:
InvalidParameterException – If one of the parameters is invalid.
ValidationException – If the validation list is invalid or contains an unknown validation.
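Building on the example above, a further hedged sketch of a nested “forall” validation; the parameter name alpha_modules is taken from this module's configuration, while treating its entries as plain strings is an assumption made for illustration:
>>> validations = {
...     "alpha_modules": [
...         list,  # plain type check, analogous to `str` above
...         ("forall", ("instance", str)),  # assumes every entry is given as a plain string
...     ],
... }
>>> self.validate_params(validations)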
Attributes
Get the device of this module.
Get whether this module is set to training-mode.
Get the name of the module.
Get the name of the module.
Get the escaped name of the module usable in filepaths by replacing spaces and underscores.
Get the (floating point) precision used in multiple parts of this module.
The model that computes the alpha values from given inputs.