# Source code for dgs.models.alpha.alpha

"""
Base class for modules that predict alpha values given a :class:`State`.
"""

from abc import abstractmethod
from typing import Any

import torch as t

from dgs.models.modules.named import NamedModule
from dgs.utils.state import State
from dgs.utils.torchtools import init_model_params, load_pretrained_weights
from dgs.utils.types import Config, NodePath, Validations

# Validation rules for the parameters of an alpha module, keyed by parameter name.
# Passed to ``validate_params`` in ``BaseAlphaModule.__init__``.
alpha_validations: Validations = {
    # optional parameter: path to the pretrained weights of the model
    # NOTE(review): "./weights/" is presumably a base directory used by the
    # "file exists" check - confirm against the validation implementation.
    "weight": [
        "optional",
        ("file exists", "./weights/"),
    ],
}


class BaseAlphaModule(NamedModule, t.nn.Module):
    """Given a state as input, compute and return the weight of the alpha gate.

    Subclasses are expected to implement :meth:`forward` and :meth:`get_data`
    and to provide the underlying network as :attr:`model`.

    Params
    ------

    Optional Params
    ---------------

    weight (FilePath):
        Local or absolute path to the pretrained weights of the model.
        Can be left empty.
    """

    # The wrapped network. Only annotated here, never assigned by this class;
    # ``sub_forward`` tolerates a missing or ``None`` model.
    model: t.nn.Module

    def __init__(self, config: Config, path: NodePath):
        """Initialize both base classes and validate the alpha parameters.

        Args:
            config: The configuration, forwarded to :class:`NamedModule`.
            path: The location of this module's parameters within ``config``,
                forwarded to :class:`NamedModule`.
        """
        # Both bases are initialized explicitly because their signatures differ.
        NamedModule.__init__(self, config=config, path=path)
        t.nn.Module.__init__(self)

        self.validate_params(alpha_validations)

    @property
    def module_type(self) -> str:
        """The type of this module, always ``"alpha"``."""
        return "alpha"

    def __call__(self, *args, **kwargs) -> Any:
        """The call function uses :func:`sub_forward` and not :func:`forward`.

        This way, the sequential layers can just be called later on.

        Note:
            Because this override does not delegate to ``torch.nn.Module.__call__``,
            calling the module does not invoke :meth:`forward`;
            use ``module.forward(state)`` explicitly for that.
        """
        return self.sub_forward(*args, **kwargs)

    @abstractmethod
    def forward(self, s: State) -> t.Tensor:
        """Compute the alpha value(s) for the given state.

        Args:
            s: The state to compute alpha for.

        Returns:
            The alpha weight(s) as a tensor.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def sub_forward(self, data: t.Tensor) -> t.Tensor:
        """Function to call when module is called from within a combined alpha module.

        Args:
            data: The input that is passed to :attr:`model`.

        Returns:
            The output of ``self.model(data)``, or ``data`` unchanged if this module
            has no model (attribute missing or set to ``None``).
        """
        model = getattr(self, "model", None)
        if model is None:
            # Without a model this module acts as the identity.
            return data
        return model(data)

    @abstractmethod
    def get_data(self, s: State) -> Any:
        """Given a state, return the data which is input into the model.

        Args:
            s: The state to extract the data from.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def load_weights(self) -> None:
        """Load the weights of the model from the given file path.

        If no weights are given, i.e. the ``weight`` parameter is missing or left
        empty (``None`` or an empty string), initialize the model instead.

        Expects :attr:`model` to be set; it is accessed unconditionally.
        """
        weight_path = self.params.get("weight")
        if weight_path:
            load_pretrained_weights(model=self.model, weight_path=weight_path)
        else:
            # An empty value is treated like a missing key, because the parameter
            # is documented as "can be left empty".
            init_model_params(self.model)