dgs.models.alpha.alpha.BaseAlphaModule
- class dgs.models.alpha.alpha.BaseAlphaModule(*args: Any, **kwargs: Any)
Given a state as input, compute and return the weight of the alpha gate.
Params
Optional Params
- weight (FilePath):
Local or absolute path to the pretrained weights of the model. Can be left empty.
- __init__(config: dict[str, any], path: list[str])
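The `config`/`path` pair in `__init__` suggests that a module reads its own parameters, such as the optional `weight`, from a sub-dictionary of a larger nested configuration, and that `validate_params` then checks them key by key. The following is a minimal plain-Python sketch under those assumptions only. The config layout, the helper `get_sub_config`, the `VALIDATIONS` and `OPTIONAL_KEYS` tables, and this standalone `validate_params` function are all hypothetical and are not the dgs API.

```python
from typing import Any, Callable

# Hypothetical nested configuration; all keys and values are illustrative.
CONFIG: dict[str, Any] = {
    "device": "cpu",
    "alpha": {
        "visual": {"weight": "./weights/visual_alpha.pth"},
        "pose": {},  # "weight" omitted because it is optional
    },
}

# Hypothetical per-key validations: each key maps to a list of predicates.
Check = Callable[[Any], bool]
VALIDATIONS: dict[str, list[Check]] = {
    "weight": [lambda v: isinstance(v, str)],
}
OPTIONAL_KEYS: set[str] = {"weight"}


def get_sub_config(config: dict[str, Any], path: list[str]) -> dict[str, Any]:
    """Follow `path` through the nested dict and return the module's own parameters."""
    node: Any = config
    for key in path:
        node = node[key]
    return node


def validate_params(
    params: dict[str, Any],
    validations: dict[str, list[Check]],
    optional: set[str],
) -> None:
    """Raise if a required key is missing or a present value fails any of its checks."""
    for key, checks in validations.items():
        if key not in params:
            if key in optional:
                continue  # optional parameters may be left out
            raise KeyError(f"missing required parameter {key!r}")
        for check in checks:
            if not check(params[key]):
                raise ValueError(f"invalid value for {key!r}: {params[key]!r}")


visual = get_sub_config(CONFIG, ["alpha", "visual"])
validate_params(visual, VALIDATIONS, OPTIONAL_KEYS)
print(visual.get("weight"))  # ./weights/visual_alpha.pth

pose = get_sub_config(CONFIG, ["alpha", "pose"])
validate_params(pose, VALIDATIONS, OPTIONAL_KEYS)
print(pose.get("weight"))  # None: no pretrained weights to load
```

Keeping the per-key checks in a table rather than hard-coding them lets each subclass declare only the keys it cares about, which is one plausible reason for a `validations` argument.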
Methods
- configure_torch_module(module[, train]): Set the compute mode and send the model to the device, or to multiple parallel devices if applicable.
- forward(s)
- get_data(s): Given a state, return the data that is input into the model.
- Load the weights of the model from the given file path.
- sub_forward(data): Function to call when the module is called from within a combined alpha module.
- Terminate this module and all of its submodules.
- validate_params(validations[, attrib_name]): Given per-key validations, validate this module's parameters.
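The split between `get_data`, `sub_forward` and `forward` can be illustrated with a minimal pure-Python sketch. It assumes that `forward(s)` amounts to `sub_forward(get_data(s))`, and that a combined alpha module extracts the data once and then chains the `sub_forward` calls of its submodules. Neither assumption is confirmed by this page. The classes `ToyAlphaModule`, `LinearAlpha`, `SigmoidAlpha` and `CombinedAlpha`, the dict used as a state and its `"score"` key are all hypothetical. The sketch also leaves out the PyTorch side (device placement, training mode) that `configure_torch_module` handles.

```python
import math
from typing import Any

State = dict[str, Any]  # hypothetical stand-in for the library's state object


class ToyAlphaModule:
    """Hypothetical base class mirroring the get_data / sub_forward / forward split."""

    def __init__(self, data_key: str = "score") -> None:
        self.data_key = data_key  # hypothetical: which state entry holds the input

    def get_data(self, s: State) -> float:
        # Given a state, return the data that is fed into the model.
        return float(s[self.data_key])

    def sub_forward(self, data: float) -> float:
        raise NotImplementedError

    def forward(self, s: State) -> float:
        # Assumed composition: extract the input, then run the computation on it.
        return self.sub_forward(self.get_data(s))


class LinearAlpha(ToyAlphaModule):
    def __init__(self, a: float, b: float, data_key: str = "score") -> None:
        super().__init__(data_key)
        self.a, self.b = a, b

    def sub_forward(self, data: float) -> float:
        return self.a * data + self.b


class SigmoidAlpha(ToyAlphaModule):
    def sub_forward(self, data: float) -> float:
        # Squash into (0, 1) so the result can act as a gate weight.
        return 1.0 / (1.0 + math.exp(-data))


class CombinedAlpha(ToyAlphaModule):
    """Hypothetical combined module: extracts data once, then chains sub_forward calls."""

    def __init__(self, modules: list[ToyAlphaModule]) -> None:
        super().__init__(modules[0].data_key)
        self.modules = modules

    def sub_forward(self, data: float) -> float:
        for m in self.modules:
            data = m.sub_forward(data)  # bypasses each submodule's get_data
        return data


print(SigmoidAlpha().forward({"score": 0.0}))  # 0.5
combined = CombinedAlpha([LinearAlpha(a=2.0, b=1.0), SigmoidAlpha()])
print(combined.forward({"score": -0.5}))  # 0.5, since 2.0 * -0.5 + 1.0 == 0.0
```

Under this reading, `sub_forward` exists so that a combined module does not repeat the state-to-data extraction for every submodule it wraps.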
Attributes
- Get the device of this module.
- Get whether this module is set to training mode.
- Get the name of the module.
- Get the name of the module.
- Get the escaped name of the module, made usable in file paths by replacing spaces and underscores.
- Get the (floating-point) precision used in multiple parts of this module.