dgs.models.combine.dynamic.AlphaCombine

class dgs.models.combine.dynamic.AlphaCombine(*args: Any, **kwargs: Any)

Compute the weighted sum of multiple similarity matrices using the given alpha weights.

More precisely, given multiple similarity matrices (tensors), each of shape [N x T], and one alpha value per similarity matrix, compute the weighted sum of all similarity matrices. The module ensures that \(\sum_i \alpha_i = 1\).
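Written out as a formula, the combination looks as follows. The symbols \(K\) (number of similarity matrices), \(S_i\) (the i-th input similarity matrix) and \(S\) (the combined result) are introduced here for illustration only:

\[
S = \sum_{i=1}^{K} \alpha_i \, S_i, \qquad \sum_{i=1}^{K} \alpha_i = 1, \qquad S, S_i \in \mathbb{R}^{N \times T}
\]

Because the weights sum to 1, every entry of \(S\) is a convex combination of the corresponding entries of the inputs whenever all \(\alpha_i \ge 0\).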

Params

Optional Params

__init__(*args, **kwargs)

Methods

configure_torch_module(module[, train])

Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.

forward(*tensors[, alpha])

The forward call of this module combines an arbitrary number of similarity matrices using an importance weight \(\alpha\).
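The following framework-free sketch illustrates the combination that `forward` describes. It is written in plain Python rather than torch so that it runs anywhere. The names `alpha_combine` and `softmax` are hypothetical and are not part of the dgs API. Normalizing the raw alphas with a softmax is also an assumption: the class exposes a `softmax` attribute, but this page does not say how that attribute is used.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]  # a similarity matrix of shape [N x T]


def softmax(values: Sequence[float]) -> List[float]:
    """Numerically stable softmax: the outputs are positive and sum to 1."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def alpha_combine(*matrices: Matrix, alpha: Sequence[float]) -> Matrix:
    """Hypothetical helper: weighted sum of K similarity matrices of equal shape [N x T]."""
    if not matrices:
        raise ValueError("at least one similarity matrix is required")
    if len(matrices) != len(alpha):
        raise ValueError("expected one alpha value per similarity matrix")
    n, t = len(matrices[0]), len(matrices[0][0])
    if any(len(m) != n or any(len(row) != t for row in m) for m in matrices):
        raise ValueError("all similarity matrices must have the same shape [N x T]")
    # Assumption: the raw alphas are normalized with a softmax so that sum(weights) == 1.
    weights = softmax(alpha)
    # Entry-wise weighted sum over all K matrices.
    return [
        [sum(w * m[r][c] for w, m in zip(weights, matrices)) for c in range(t)]
        for r in range(n)
    ]


s1 = [[1.0, 0.0], [0.0, 1.0]]
s2 = [[0.0, 1.0], [1.0, 0.0]]
print(alpha_combine(s1, s2, alpha=[0.0, 0.0]))  # equal raw alphas -> weights 0.5 each
# → [[0.5, 0.5], [0.5, 0.5]]
```

Normalizing with a softmax, rather than dividing by the sum, keeps the weights valid even if the raw alphas are negative. This is one way to satisfy \(\sum_i \alpha_i = 1\), not necessarily the one the module uses.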

terminate()

Terminate this module and all of its submodules.

validate_params(validations[, attrib_name])

Validate this module's parameters using the given per-key validations.

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module, usable in file paths, by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

softmax