dgs.models.alpha.alpha.BaseAlphaModule

class dgs.models.alpha.alpha.BaseAlphaModule(*args: Any, **kwargs: Any)

Given a state as input, compute and return the weight of the alpha gate.
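The state-in, weight-out flow described above can be sketched in plain Python. This is a hypothetical illustration, not the dgs implementation. The `ToyState` class, its `features` field, the linear-plus-sigmoid gate, and the assumption that `forward(s)` simply chains `get_data(s)` into `sub_forward(data)` are all assumptions made to show how these three methods could relate.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyState:
    """Hypothetical stand-in for a dgs state; it only holds a feature vector."""
    features: list[float]


class ToyAlphaModule:
    """Hypothetical alpha module mapping a state to a gate weight in (0, 1)."""

    def __init__(self, weights: list[float], bias: float = 0.0) -> None:
        self.weights = weights
        self.bias = bias

    def get_data(self, s: ToyState) -> list[float]:
        # Extract the model input from the state (here: the raw features).
        return s.features

    def sub_forward(self, data: list[float]) -> float:
        # Linear score squashed by a sigmoid so the weight lies in (0, 1).
        score = sum(w * x for w, x in zip(self.weights, data)) + self.bias
        return 1.0 / (1.0 + math.exp(-score))

    def forward(self, s: ToyState) -> float:
        # Assumed pipeline: state -> data -> alpha weight.
        return self.sub_forward(self.get_data(s))
```

With all-zero weights the score is 0, so `ToyAlphaModule([0.0, 0.0]).forward(ToyState([1.0, 2.0]))` returns 0.5. Splitting `get_data` from `sub_forward` lets a caller that already holds the extracted data skip the state-handling step.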

Params

Optional Params

weight (FilePath):

Local or absolute path to the pretrained weights of the model. Can be left empty.
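The constructor receives a `config` dict and a `path` list of strings. A plausible reading, stated here as an assumption, is that `path` locates this module's sub-configuration inside the full config, where the optional `weight` entry would live. The helper `get_sub_config`, the keys `alpha`, `visual`, and `pose`, and the file name below are all hypothetical.

```python
from typing import Any


def get_sub_config(config: dict[str, Any], path: list[str]) -> dict[str, Any]:
    """Hypothetical helper: walk `config` along `path` and return the nested dict."""
    node: Any = config
    for key in path:
        node = node[key]  # raises KeyError if a key on the path is missing
    return node


config = {
    "alpha": {
        "visual": {
            # Optional parameter: path to pretrained weights.
            "weight": "./weights/alpha_visual.pth",
        },
        "pose": {},  # optional parameter left empty
    }
}

visual_params = get_sub_config(config, ["alpha", "visual"])
pose_params = get_sub_config(config, ["alpha", "pose"])
# Optional params are read with a default instead of indexing directly.
visual_weight = visual_params.get("weight")  # "./weights/alpha_visual.pth"
pose_weight = pose_params.get("weight")  # None
```

Reading the optional key with `.get` keeps a module without pretrained weights valid, matching the note that `weight` can be left empty.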

__init__(config: dict[str, any], path: list[str])

Methods

configure_torch_module(module[, train])

Set compute mode and send model to the device or multiple parallel devices if applicable.

forward(s)

get_data(s)

Given a state, return the data which is input into the model.

load_weights()

Load the weights of the model from the given file path.

sub_forward(data)

Function called when this module is invoked from within a combined alpha module.

terminate()

Terminate this module and all of its submodules.

validate_params(validations[, attrib_name])

Given per-key validations, validate this module's parameters.
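Per-key validation can be illustrated with a small standalone function. The actual validation vocabulary accepted by `validate_params` is not described here, so this sketch substitutes plain callables as hypothetical checks; the function below, its `optional` argument, and the error types it raises are assumptions, not the dgs API.

```python
from typing import Any, Callable

Validation = Callable[[Any], bool]


def validate_params(
    params: dict[str, Any],
    validations: dict[str, list[Validation]],
    optional: frozenset[str] = frozenset(),
) -> None:
    """Hypothetical per-key validator: every check listed for a key must pass."""
    for key, checks in validations.items():
        if key not in params:
            if key in optional:
                continue  # optional parameter left empty: nothing to check
            raise KeyError(f"missing required parameter '{key}'")
        for check in checks:
            if not check(params[key]):
                raise ValueError(f"parameter '{key}' failed a validation")
```

For example, `validate_params({"weight": "a.pth"}, {"weight": [lambda v: isinstance(v, str)]})` returns silently, while passing `{"weight": 3}` raises `ValueError`. Keeping the checks in a dict keyed by parameter name makes it easy to attach several independent conditions to one parameter.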

Attributes

device

Get the device of this module.

is_training

Get whether this module is set to training mode.

module_name

Get the name of the module.

module_type

name

Get the name of the module.

name_safe

Get the escaped name of the module, in which spaces and underscores are replaced so that it can be used in file paths.

precision

Get the (floating point) precision used in multiple parts of this module.

model