Source code for dgs.models.combine.combine
"""
Implementation of modules that combine two or more similarity matrices.
Obtain similarity matrices as a result of one or multiple
:class:`~dgs.models.similarity.similarity.SimilarityModule` s.
"""
from abc import abstractmethod
import torch as t
from torch import nn
from dgs.models.module import enable_keyboard_interrupt
from dgs.models.modules.named import NamedModule
from dgs.utils.config import DEF_VAL
from dgs.utils.types import Config, NodePath, Validations
# Validation rules for the parameters of a combine module; every entry listed here is optional.
combine_validations: Validations = {"softmax": ["optional", bool]}
[docs]
class CombineSimilaritiesModule(NamedModule, nn.Module):
    """Given two or more similarity matrices, combine them into a single similarity matrix.

    This class only sets up the shared (optional) softmax layer.
    The actual combination has to be implemented in :meth:`forward` by the subclasses.

    Params
    ------

    Optional Params
    ---------------

    softmax (bool, optional):
        Whether to compute the softmax along the last dimension of the resulting weighted similarity matrix.
        Default ``DEF_VAL.combine.softmax``.
    """

    # Either an empty ``nn.Sequential`` or one containing a single ``nn.Softmax(dim=-1)``, see ``__init__``.
    softmax: nn.Sequential

    def __init__(self, config: Config, path: NodePath):
        """Initialize both base classes, validate the parameters, and register the softmax module.

        Args:
            config: The configuration, passed on to :class:`NamedModule`.
            path: The node path, passed on to :class:`NamedModule`
                (presumably the location of this module's parameters within ``config`` - confirm in NamedModule).
        """
        # Both bases have different signatures, therefore they are initialized explicitly.
        NamedModule.__init__(self, config=config, path=path)
        nn.Module.__init__(self)

        self.validate_params(combine_validations)

        # An empty nn.Sequential returns its input unchanged,
        # so ``self.softmax(...)`` can be called no matter whether the softmax is enabled.
        softmax = nn.Sequential()
        if self.params.get("softmax", DEF_VAL["combine"]["softmax"]):
            softmax.append(nn.Softmax(dim=-1))
        self.register_module(name="softmax", module=self.configure_torch_module(softmax))

    @property
    def module_type(self) -> str:
        """The type of this module, always ``"combine"``."""
        return "combine"

    def __call__(self, *args, **kwargs) -> t.Tensor:  # pragma: no cover
        """Call :meth:`forward` with the given arguments and return its result.

        Note:
            This replaces ``nn.Module.__call__`` and calls ``forward`` directly,
            so hooks registered on this module via the torch hook API are not executed.
        """
        return self.forward(*args, **kwargs)

    @abstractmethod
    @enable_keyboard_interrupt
    def forward(self, *args, **kwargs) -> t.Tensor:  # pragma: no cover
        """Combine the given similarity matrices into a single one.

        Has to be implemented by every subclass.

        Raises:
            NotImplementedError: Always, when called on this base implementation.
        """
        raise NotImplementedError