Source code for dgs.models.alpha.fully_connected

"""
A class for alpha modules with one or multiple fully connected layers.
"""

import torch as t

from dgs.models.alpha.alpha import BaseAlphaModule
from dgs.utils.config import DEF_VAL
from dgs.utils.nn import fc_linear
from dgs.utils.state import get_ds_data_getter, State
from dgs.utils.types import Config, DataGetter, NodePath, Validations

fc_validations: Validations = {
    "name": [str],
    "hidden_layers": [list, ("forall", int)],
    "bias": [("any", [bool, ("all", [list, ("forall", bool)])])],
    # optional
    "act_func": [
        "optional",
        ("any", [("isinstance", str), "None", ("isinstance", t.nn.Module), ("all", [list, ("forall", str)])]),
    ],
}


[docs] class FullyConnectedAlpha(BaseAlphaModule): """An alpha module consisting of ``L - 1`` fully connected layers. Each layer can have a custom bias and activation function. Params ------ name (str): The name of the attribute or getter function used to retrieve the input data from the state. hidden_layers (list[int]): The sizes of each of the hidden layers, including the size of the data. Has length ``L``. bias (Union[bool, list[bool]): Whether each of the respective layers should have values for the bias. Has length ``L - 1``. Optional Params --------------- act_func (list[Union[str, None, nn.Module]]): A list containing the activation functions placed after each of the layers. Has length ``L - 1``. Default ``DEF_VAL.alpha.act_func``. """
[docs] def __init__(self, config: Config, path: NodePath): super().__init__(config=config, path=path) self.validate_params(fc_validations) self.data_getter: DataGetter = get_ds_data_getter(self.params["name"]) model = fc_linear( hidden_layers=self.params["hidden_layers"], bias=self.params["bias"], act_func=self.params.get("act_func", DEF_VAL["alpha"]["act_func"]), ) self.register_module(name="model", module=self.configure_torch_module(model)) self.load_weights()
[docs] def forward(self, s: State) -> t.Tensor: return self.model(self.get_data(s))
[docs] def get_data(self, s: State) -> t.Tensor: return self.data_getter(s)