Source code for dgs.models.alpha.combined
"""
An alpha module combining other alpha modules.
"""
from typing import Any

import torch as t

from dgs.models.alpha.alpha import BaseAlphaModule
from dgs.models.loader import module_loader
from dgs.utils.config import insert_into_config
from dgs.utils.state import get_ds_data_getter, State
from dgs.utils.types import Config, DataGetter, NodePath, Validations

sequential_combined_validations: Validations = {
    "paths": [list, ("longer eq", 1), ("forall", ("any", [str, dict, ("all", [list, ("forall", str)])]))],
    "name": [str],
    # optional
}
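
# Example of a ``paths`` value accepted by the validations above; the node path
# and layer parameters are illustrative only, not taken from a real config:
#   ["ReLU", {"Linear": {"in_features": 4, "out_features": 1}}, ["combined_alpha", "fc"]]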

class SequentialCombinedAlpha(BaseAlphaModule):
    """An alpha module that sequentially combines multiple other :class:`BaseAlphaModule` s.

    First load the input data from the :class:`State` using `name`.
    Then pass the resulting :class:`.Tensor` through the submodules in order,
    inserting each result into the forward call of the respective next model.

    Params
    ------
    paths (list[str, dict, NodePath]):
        A list where every entry is either a :class:`NodePath` pointing to the configuration of a
        :class:`~BaseAlphaModule`, the name of a module class from `torch.nn` (e.g. 'Flatten', 'ReLU', ...),
        or a dict mapping the name of a single `torch.nn` module class to its keyword arguments
        (e.g. ``{"Linear": {"in_features": 4, "out_features": 1}}``).
        The submodules do not need to set the "name" property,
        because every layer after the first receives the result returned by the previous layer.
    name (str):
        The name of the attribute or getter function used to retrieve the input data from the state.

    Optional Params
    ---------------
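
    Example
    -------
    A hypothetical configuration sketch; the node names, the getter name, and all
    layer parameters below are illustrative and not taken from a real configuration::

        combined_alpha:
            name: "bbox"
            paths:
                - ["combined_alpha", "fc"]
                - "ReLU"
                - {"Linear": {"in_features": 4, "out_features": 1}}
                - "Sigmoid"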
"""

    model: t.nn.Sequential

    def __init__(self, config: Config, path: NodePath):
        super().__init__(config=config, path=path)

        self.validate_params(sequential_combined_validations)

        self.data_getter: DataGetter = get_ds_data_getter(self.params["name"])

        # instantiate all submodules in order
        modules: list[BaseAlphaModule] = []
        for sub_path in self.params["paths"]:
            if isinstance(sub_path, list):
                # Set the name of the submodule to an empty string.
                # This can be done for the first module too, because the data_getter is already set.
                new_cfg = insert_into_config(path=sub_path, value={"name": ""}, original=config, copy=True)
                modules.append(module_loader(config=new_cfg, module_type="alpha", key=sub_path))
            elif isinstance(sub_path, str):
                try:
                    modules.append(getattr(t.nn, sub_path)())
                except AttributeError as e:
                    raise AttributeError(f"Tried to load non-existent torch module '{sub_path}'.") from e
            elif isinstance(sub_path, dict):
                if len(sub_path) > 1:
                    raise ValueError(f"Expected submodule config to be a dict with a single key, got: {sub_path}")
                k, v = next(iter(sub_path.items()))
                if not isinstance(v, dict):
                    raise NotImplementedError(f"Expected submodule parameters to be a dict, got: {v}")
                try:
                    modules.append(getattr(t.nn, k)(**v))
                except AttributeError as e:
                    raise AttributeError(f"Tried to load non-existent torch module '{sub_path}'.") from e
            else:
                raise NotImplementedError(f"Expected list, str, or dict, got: {sub_path}")

        self.register_module(name="model", module=self.configure_torch_module(t.nn.Sequential(*modules)))

        self.load_weights()

    def forward(self, s: State) -> t.Tensor:
        """Forward call of the sequential model; calls every layer with the output of the previous layer.

        Works for :class:`BaseAlphaModule` s as well as arbitrary models from `torch.nn`.
        """
        inpt = self.get_data(s)
        for sub_model in self.model:
            inpt = sub_model(inpt)
        return inpt

    def get_data(self, s: State) -> tuple[Any, ...]:
        return self.data_getter(s)
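
# Usage sketch (hypothetical): assumes ``cfg`` is a loaded config containing a node
# like the docstring example under the key "combined_alpha", and that ``s`` is a
# :class:`State` holding the data referenced by the "name" parameter.
#
#     module = SequentialCombinedAlpha(config=cfg, path=["combined_alpha"])
#     alpha = module.forward(s)  # -> t.Tensor of alpha weights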