dgs.models.loader.register_module
- dgs.models.loader.register_module(name, new_module, module_type: str) -> None
Register a new module.
- Parameters:
name – The name under which to register the new module.
new_module – The class of the new module to register (the class itself, not an instance).
module_type – The type of module instance to register. Has to be in MODULE_TYPES.
- Raises:
ValueError – If module_type is invalid, i.e. not in MODULE_TYPES.
Examples:
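```python
import torch
from torch import nn

from dgs.models.loader import register_module


class CustomNNLLoss(nn.Module):
    """A custom negative log-likelihood loss."""

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Placeholder computation: delegate to PyTorch's built-in NLL loss.
        return nn.functional.nll_loss(input, target)


# Register the class (not an instance) under the "loss" module type.
register_module(name="CustomNNLLoss", new_module=CustomNNLLoss, module_type="loss")
```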