dgs.utils.torchtools.load_pretrained_weights

dgs.utils.torchtools.load_pretrained_weights(model: TorchMod, weight_path: str, device: torch.types.Device | str | None = None, verbose: bool = False) → None

Loads pretrained weights into a model.

Features:
  • Incompatible layers (mismatched name or tensor shape) are skipped.

  • Automatically handles checkpoint keys that carry the ‘module.’ prefix, such as those saved from a model wrapped in nn.DataParallel.
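The matching rules above can be illustrated with a small, framework-agnostic sketch. It is not the dgs implementation: the helper name `filter_compatible_weights` is hypothetical, and it assumes that handling ‘module.’ keys means stripping a leading `module.` prefix. Values only need a `.shape` attribute, so PyTorch tensors work directly.

```python
from typing import Any, Dict, List, Mapping, Tuple


def filter_compatible_weights(
    model_state: Mapping[str, Any],
    checkpoint_state: Mapping[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: return checkpoint entries that fit the model,
    plus the original checkpoint keys that were skipped.

    Values only need a ``.shape`` attribute (e.g. torch tensors).
    """
    matched: Dict[str, Any] = {}
    discarded: List[str] = []
    for key, value in checkpoint_state.items():
        # Assumption: a leading "module." (added by nn.DataParallel) is stripped.
        name = key[len("module."):] if key.startswith("module.") else key
        target = model_state.get(name)
        # Keep the entry only if the model has a parameter with this name and shape.
        if target is not None and tuple(target.shape) == tuple(value.shape):
            matched[name] = value
        else:
            discarded.append(key)
    return matched, discarded
```

With PyTorch, `model_state` would be `model.state_dict()` and `checkpoint_state` the mapping returned by `torch.load(weight_path, map_location=device)`. Updating a copy of `model.state_dict()` with `matched` and passing it to `model.load_state_dict(...)` then loads only the compatible layers, while `discarded` lists what a verbose loader could report.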

Parameters:
  • model – A torch module.

  • weight_path – Path to the pretrained weights file.

  • device – Device to load the weights onto.

  • verbose – Whether to print non-warning messages.

Examples

>>> from dgs.utils.torchtools import load_pretrained_weights
>>> model = MyModel()  # hypothetical: any torch.nn.Module
>>> weight_path = 'log/my_model/model-best.pth.tar'
>>> load_pretrained_weights(model, weight_path)