Source code for dgs.models.alpha.alpha
"""
Base class for modules that predict alpha values given a :class:`.State`.
"""
from abc import abstractmethod
from typing import Any
import torch as t
from dgs.models.modules.named import NamedModule
from dgs.utils.state import State
from dgs.utils.torchtools import init_model_params, load_pretrained_weights
from dgs.utils.types import Config, NodePath, Validations
alpha_validations: Validations = {
    # optional
    "weight": ["optional", ("file exists", "./weights/")],
}
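# Illustrative note (not part of the original source): under these validations a config
# node for an alpha module may optionally carry a "weight" entry pointing at pretrained
# weights, e.g. {"alpha": {"weight": "./weights/alpha_fc.pth"}}; the node name "alpha"
# and the file name are hypothetical placeholders.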
class BaseAlphaModule(NamedModule, t.nn.Module):
    """Given a state as input, compute and return the weight of the alpha gate.
    Optional Params
    ---------------
    weight (FilePath):
        Local or absolute path to the pretrained weights of the model.
        Can be left empty.
    """
    model: t.nn.Module
    def __init__(self, config: Config, path: NodePath):
        NamedModule.__init__(self, config=config, path=path)
        t.nn.Module.__init__(self)
        self.validate_params(alpha_validations) 
    @property
    def module_type(self) -> str:
        return "alpha"
    def __call__(self, *args, **kwargs) -> Any:
        """Calling this module invokes :func:`sub_forward` and not :func:`forward`.
        This way, the sequential layers can simply be called later on.
        """
        return self.sub_forward(*args, **kwargs)
    @abstractmethod
    def forward(self, s: State) -> t.Tensor:
        """Given a :class:`.State`, compute and return the weight of the alpha gate as a tensor."""
        raise NotImplementedError
    def sub_forward(self, data: t.Tensor) -> t.Tensor:
        """Forward function to call when this module is called from within a combined alpha module.
        If no inner model is set, the data is returned unchanged.
        """
        if not hasattr(self, "model") or self.model is None:
            return data
        return self.model(data) 
    @abstractmethod
    def get_data(self, s: State) -> Any:
        """Given a state, return the data which is input into the model."""
        raise NotImplementedError 
    def load_weights(self) -> None:
        """Load the weights of the model from the given file path. If no weights are given, initialize the model."""
        if "weight" in self.params:
            fp = self.params.get("weight")
            load_pretrained_weights(model=self.model, weight_path=fp)
        else:
            init_model_params(self.model)
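

# Illustrative sketch (not part of the original module): a minimal concrete subclass
# showing how the abstract interface above might be implemented. The single linear
# layer, its input size, and the assumption that the State exposes an embedding under
# the key "embedding" are hypothetical and for demonstration only.
class FullyConnectedAlphaExample(BaseAlphaModule):
    """Example alpha module mapping a fixed-size embedding to a single gate weight."""

    def __init__(self, config: Config, path: NodePath):
        super().__init__(config=config, path=path)
        # hypothetical architecture: one fully connected layer followed by a sigmoid,
        # so the resulting alpha weight lies in (0, 1)
        self.model = t.nn.Sequential(t.nn.Linear(256, 1), t.nn.Sigmoid())
        self.load_weights()

    def forward(self, s: State) -> t.Tensor:
        # extract the model input from the State, then run it through the inner model
        return self.sub_forward(self.get_data(s))

    def get_data(self, s: State) -> t.Tensor:
        # hypothetical accessor; the real State API may name or store this differently
        return s["embedding"]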