dgs.models.alpha.combined.SequentialCombinedAlpha

class dgs.models.alpha.combined.SequentialCombinedAlpha(*args: Any, **kwargs: Any)[source]

An alpha module that sequentially combines multiple other BaseAlphaModule’s. The input data is first loaded from the State using name; the resulting Tensor is then passed through the forward call of each subsequent model in turn.

Params:
  • paths (list[str | NodePath]) – A list whose entries are either NodePath’s pointing to the configuration of a BaseAlphaModule or the name of a module from torch.nn (e.g. ‘Flatten’, ‘ReLU’, …). The submodules do not need to set the “name” property, because every layer after the first uses the result returned by the previous layer. See the configuration sketch below.

  • name (str) – The name of the attribute or getter function used to retrieve the input data from the state.

Important Inherited Params:

weight (FilePath) – Local or absolute path to the pretrained weights of the model. Can be left empty.
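For illustration, a minimal sketch of the params of this module. The submodule keys “alpha_fc1” and “alpha_fc2” and the State getter “image_crop” are hypothetical and only illustrate the shape of the configuration:

>>> params = {
        "name": "image_crop",      # hypothetical attribute / getter on the State
        "paths": [
            ["alpha_fc1"],         # NodePath to the config of a BaseAlphaModule
            "ReLU",                # name of a layer from torch.nn
            ["alpha_fc2"],
            "Sigmoid",
        ],
    }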

Methods

__init__(config: dict[str, any], path: list[str])[source]
configure_torch_module(module: torch.nn.Module, train: bool | None = None) torch.nn.Module

Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.

Parameters:
  • module – The torch module instance to configure.

  • train – Whether to train or eval this module, defaults to the value set in the base config.

Returns:

The module on the specified device or in parallel.
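As a rough sketch of the intended behavior (a simplified approximation, not the library’s actual implementation; self.device and self.is_training stand in for the configured values):

>>> module.train(mode=self.is_training if train is None else train)
>>> module.to(device=self.device)  # may additionally be wrapped for multiple parallel devices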

forward(s: State) torch.Tensor[source]

The forward call of this sequential model feeds the output of each layer into the next one. Works for BaseAlphaModule’s as well as arbitrary modules from torch.nn.
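A simplified sketch of the idea (it assumes the input is a single Tensor and that the configured sub-layers are stored in self.model, see the attributes below):

>>> x, = self.get_data(s)       # retrieve the input from the State using 'name'
>>> for layer in self.model:    # self.model holds the configured sub-layers
        x = layer(x)            # each layer receives the output of the previous one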

get_data(s: State) tuple[any, ...][source]

Given a state, return the data which is input into the model.
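In essence, the configured name is looked up on the State. A rough sketch, under the assumption that the parameter is accessible as self.params["name"]:

>>> data = getattr(s, self.params["name"])     # attribute or getter function of the State
>>> data = data() if callable(data) else data  # call it if it is a getter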

load_weights() None

Load the weights of the model from the given file path. If no weights are given, initialize the model.

sub_forward(data: torch.Tensor) torch.Tensor

Function to call when module is called from within a combined alpha module.

terminate() None

Terminate this module and all of its submodules.

If nothing has to be done, this simply passes. It is used to terminate parallel execution and threads in specific models.

validate_params(validations: dict[str, list[str | type | tuple[str, any] | Callable[[any, any], bool]]], attrib_name: str = 'params') None

Given per key validations, validate this module’s parameters.

Throws exceptions on invalid or nonexistent params.

Parameters:
  • attrib_name – Name of the attribute to validate; should be “params” for all modules and is only “config” for the base class.

  • validations

    Dictionary with the name of the parameter as key and a list of validations as value. Every validation in this list has to be true for the validation to be successful.

    The value for the validation can have multiple types:
    • A lambda function or other type of callable

    • A string as reference to a predefined validation function with one argument

    • None for existence

    • A tuple with a string as reference to a predefined validation function with one additional argument

    • It is possible to write nested validations, but then every nested validation has to be a tuple, or a tuple of tuples. For convenience, there are implementations for “any”, “all”, “not”, “eq”, “neq”, and “xor”. Those can have data which is a tuple containing other tuples or validations, or a single validation.

    • Lists and other iterables can be validated using “forall” running the given validations for every item in the input. A single validation or a tuple of (nested) validations is accepted as data.

Example

This example is an excerpt of the validation for the BaseModule-configuration.

>>> validations = {
        "device": [
            str,
            ("any", [
                ("in", ["cuda", "cpu"]),
                ("instance", torch.device),
            ]),
        ],
        "print_prio": [("in", PRINT_PRIORITY)],
        "callable": [(lambda value: value == 1)],
    }

And within the class __init__() call:

>>> self.validate_params(validations)
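Nested validations, e.g. using “forall” to check every entry of a list, might look roughly like this (an illustrative sketch for a hypothetical “paths” parameter, reusing the predefined “instance” validation from above):

>>> validations = {
        "paths": [
            list,
            ("forall", ("any", [("instance", str), ("instance", list)])),
        ],
    }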

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training-mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module usable in filepaths by replacing spaces and underscores.

precision

Get the (floating point) precision used in multiple parts of this module.

model