dgs.models.loader.register_module

dgs.models.loader.register_module(name, new_module, module_type: str) → None

Register a new module.

Parameters:
  • name – The name under which to register the new module.

  • new_module – The class of the new module to register.

  • module_type – The type of module to register new_module as. Has to be one of the values in MODULE_TYPES.

Raises:

ValueError – If module_type is invalid, i.e. not in MODULE_TYPES.

Examples:

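import torch
from torch import nn
from dgs.models.loader import register_module

class CustomNNLLoss(nn.Module):  # assumed base class; substitute the loss base class your project uses
    def __init__(self, *args, **kwargs):
        super().__init__()
        ...

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        return ...

# The keyword is module_type, matching the signature above.
register_module(name="CustomNNLLoss", new_module=CustomNNLLoss, module_type="loss")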