# Source code for dgs.models.combine.static
"""
Implementation of modules that use static alpha values to combine the similarities.
"""
import torch as t
from dgs.models.combine.combine import CombineSimilaritiesModule
from dgs.utils.torchtools import configure_torch_module
from dgs.utils.types import Config, NodePath, Validations
# Validation rules applied to the ``alpha`` parameter of the static combine module.
static_alpha_validation: Validations = {
    "alpha": [
        # the weights have to be given as a list ...
        list,
        # ... with a lower bound on its length; combining a single model is not actually needed
        ("longer eq", 1),
        # ... whose entries are all floats within the bounds (0.0, 1.0)
        ("forall", [float, ("within", (0.0, 1.0))]),
        # ... and whose entries add up to 1, up to a tolerance of 1e-6
        lambda weights: abs(sum(weights) - 1.0) < 1e-6,
    ],
}
# [docs]
@configure_torch_module
class StaticAlphaCombine(CombineSimilaritiesModule):
    """
    Weight two or more similarity matrices using constant (float) values for alpha.

    The weights are read once from the module parameters, stored as the buffer ``alpha_const``,
    and applied to the similarity tensors in :meth:`forward`.

    Params
    ------

    alpha (list[float]):
        A list containing the constant weights for the different similarities.
        The weights should be probabilities and therefore sum to one and lie within [0..1].
    """

    def __init__(self, config: Config, path: NodePath):
        """Validate the ``alpha`` parameter and store it as a one-dimensional buffer.

        Args:
            config: The configuration, forwarded to the parent class.
            path: The path to this module's parameters, forwarded to the parent class.

        Raises:
            ValueError: If the absolute values of ``alpha`` do not sum to one.
        """
        super().__init__(config, path)

        self.validate_params(static_alpha_validation)

        # Flatten the configured values, so the weights are always a one-dimensional tensor.
        alpha = t.tensor(self.params["alpha"], dtype=self.precision).reshape(-1)
        self.register_buffer("alpha_const", alpha)
        self.len_alpha: int = len(alpha)

        a_sum = t.sum(t.abs(alpha))
        # ``torch.allclose`` requires both operands to share a dtype. Build the reference value
        # from ``a_sum`` itself, so the check also works if ``self.precision`` is not the default dtype.
        if not t.allclose(a_sum, t.ones_like(a_sum)):  # pragma: no cover  # redundant
            raise ValueError(f"alpha should sum to 1.0, but got {a_sum:.8f}")

    def forward(self, *tensors, **_kwargs) -> t.Tensor:
        """Given alpha from the configuration file and args of the same length,
        multiply each alpha with each matrix and compute the sum.

        If exactly one tensor is given while ``alpha`` has more than one entry,
        this tensor is treated as already stacked along its first dimension.
        In every other case the given tensors are stacked along a new first dimension.

        Args:
            tensors (tuple[torch.Tensor, ...]): A number of similarity tensors.
                Should have the same length as `alpha`.
                All the tensors should have the same size.
            _kwargs: Additional keyword arguments, which are ignored.

        Returns:
            The weighted similarity matrix after it has been passed through ``self.softmax``.
            The weighting itself is computed in float32.

        Raises:
            ValueError: If the ``tensors`` argument has the wrong shape
            TypeError: If the ``tensors`` argument contains an object that is not a `torch.tensor`.
            NotImplementedError: If ``tensors`` is not a tuple.
        """
        # Defensive guard; values collected through ``*tensors`` always arrive as a tuple.
        if not isinstance(tensors, tuple):
            raise NotImplementedError(
                f"Unknown type for tensors, expected tuple of torch.Tensor but got {type(tensors)}"
            )
        if any(not isinstance(tensor, t.Tensor) for tensor in tensors):
            raise TypeError("All the values in args should be tensors.")

        if len(tensors) > 1 and any(tensor.shape != tensors[0].shape for tensor in tensors):
            raise ValueError("The shapes of every tensor should match.")

        if len(tensors) == 1 and self.len_alpha != 1:
            # given a single already stacked tensor or a single valued alpha
            stacked = tensors[0]
        else:
            stacked = t.stack(tensors)

        if self.len_alpha != 1 and len(stacked) != self.len_alpha:
            raise ValueError(
                f"The length of the tensors {len(stacked)} should equal the length of alpha {self.len_alpha}"
            )

        stacked = stacked.float()
        # ``torch.tensordot`` needs matching dtypes. The buffer was created with ``self.precision``,
        # which does not have to be float32, so cast the weights to the dtype of the similarities.
        weights = self.alpha_const.to(dtype=stacked.dtype)

        # Contract the alpha dimension with the first (stacked) dimension of the similarities.
        return self.softmax(t.tensordot(weights, stacked, dims=1))

    def terminate(self) -> None:  # pragma: no cover
        """Remove the alpha related attributes from this module.

        Only attributes that exist are deleted.
        ``alpha`` is not assigned anywhere in this class, so an unconditional ``del self.alpha``
        could raise an ``AttributeError`` before the remaining attributes are removed.
        Guarding each deletion also makes it safe to call this method more than once.
        """
        for name in ("alpha", "alpha_const", "len_alpha"):
            if hasattr(self, name):
                delattr(self, name)