dgs.utils.torchtools.load_checkpoint

dgs.utils.torchtools.load_checkpoint(fpath, device: torch.types.Device | str | None = None) -> dict

Load a given checkpoint.

A UnicodeDecodeError raised while loading is handled, so checkpoints saved with Python 2 can be read in Python 3.
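The usual cause of this error is that Python 2 pickles `str` objects as raw bytes, and Python 3 decodes those bytes as ASCII by default. That fails on any byte >= 0x80. A common remedy is to retry with `encoding="latin1"`, which maps every byte to a code point. The sketch below shows that retry using the standard `pickle` module only. The helper name `load_pickle_with_fallback` is hypothetical and is not necessarily how `load_checkpoint` is implemented. `torch.load` forwards extra keyword arguments such as `encoding` to the unpickler, so the same idea carries over to checkpoint files.

```python
import pickle
from typing import Any


def load_pickle_with_fallback(fpath: str) -> Any:
    """Hypothetical helper: unpickle ``fpath``, retrying with latin-1 on decode errors."""
    with open(fpath, "rb") as f:
        try:
            # Python 3 default: Python 2 byte strings are decoded as strict ASCII.
            return pickle.load(f)
        except UnicodeDecodeError:
            # latin-1 maps every byte 0x00-0xFF to a code point, so this decode
            # cannot fail; rewind and unpickle the whole file again.
            f.seek(0)
            return pickle.load(f, encoding="latin1")
```

For example, the Python-2-style pickle `b"\x80\x02U\x01\xe9."` (a one-byte `str` containing 0xE9) fails under the default decoding. With the fallback it loads as `"é"`.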

Parameters:
  • fpath (str) – Path to the checkpoint file.

  • device (torch.device | str, optional) – If not None, load all tensors onto this device. If None, tries to load them onto CUDA.
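The `device` rule above amounts to choosing the `map_location` passed to `torch.load`. The following is a minimal sketch under stated assumptions. The helper `resolve_map_location` is hypothetical. The fallback to `"cpu"` when no GPU is present is an assumption, because this page only says that `None` "tries to load to CUDA". Whether CUDA is available is passed in as a flag, which keeps the sketch free of a torch dependency.

```python
from typing import Any


def resolve_map_location(device: Any, cuda_available: bool) -> Any:
    """Hypothetical helper: pick the ``map_location`` argument for ``torch.load``."""
    if device is not None:
        # An explicit device (a str such as "cpu" or "cuda:0", or a torch.device)
        # is used as-is.
        return device
    # device=None: prefer CUDA. Falling back to CPU without a GPU is an assumption.
    return "cuda" if cuda_available else "cpu"
```

In real code, the flag would typically come from `torch.cuda.is_available()`, and the result would be passed as `torch.load(fpath, map_location=...)`.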

Returns:
  dict – The loaded checkpoint.

Examples

>>> from dgs.utils.torchtools import load_checkpoint
>>> fpath = 'log/my_model/model.pth.tar-10'
>>> checkpoint = load_checkpoint(fpath)