"""
Tools for handling recurring torch tasks. Mostly taken from the `torchreid package
<https://kaiyangzhou.github.io/deep-person-reid/_modules/torchreid/utils/torchtools.html#load_pretrained_weights>`_
"""
import os
import pickle
import shutil
import warnings
from collections import OrderedDict
from copy import deepcopy
from functools import partial
from typing import Any, TypeVar, Union
import torch as t
from torch import nn, optim
from torch.nn import Module as TorchModule
from dgs.models.module import BaseModule
from dgs.utils.files import mkdir_if_missing
from dgs.utils.types import Device, FilePath
BaseMod = TypeVar("BaseMod", bound=BaseModule)
TorchMod = TypeVar("TorchMod", bound=TorchModule)
def get_model_from_module(module: Union[TorchMod, BaseMod]) -> TorchMod:
"""Given either a torch module or an instance of BaseModule, return a torch module.
    Within a BaseModule, this function searches for a 'model' attribute first, then a 'module' attribute.
Args:
module: The module containing or being a torch module.
Returns:
An instance of a torch module.
Raises:
ValueError if a torch module cannot be found.
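    Examples:
        A minimal sketch; ``my_base_module`` is a hypothetical :class:`BaseModule`
        instance with a ``model`` attribute:
        >>> from dgs.utils.torchtools import get_model_from_module
        >>> torch_model = get_model_from_module(my_base_module)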
"""
if isinstance(module, nn.DataParallel):
module = module.module
if isinstance(module, BaseModule):
if hasattr(module, "model"):
module = module.model
elif hasattr(module, "module"):
module = module.module
elif not isinstance(module, nn.Module):
raise ValueError(
f"model {module.__class__.__name__} is a BaseModule but there is no 'model' attribute "
f"and the model is not a subclass of nn.Module."
)
return module
def save_checkpoint(
    state: dict[str, Any],
save_dir: FilePath,
*,
is_best: bool = False,
remove_module_from_keys: bool = False,
prepend: str = "",
verbose: bool = True,
) -> None:
"""Save a given checkpoint.
Args:
        state: The state dictionary. See examples.
        save_dir: The directory to save the checkpoint to.
        is_best: If True, this checkpoint will additionally be copied and named
            ``model-best.pth.tar``. Default is False.
        remove_module_from_keys: Whether to strip a leading ``'module.'`` or ``'model.'``
            prefix from the keys of the module's state dict. Default is False.
        prepend: A string to prepend to the filename.
        verbose: Whether to print a confirmation when the checkpoint has been created. Default is True.
Examples:
>>> state = {
>>> 'model': model.state_dict(),
>>> 'epoch': 10,
>>> 'rank1': 0.5,
>>> 'optimizer': optimizer.state_dict()
>>> }
>>> save_checkpoint(state, 'log/my_model')
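
        To additionally copy the checkpoint to ``model-best.pth.tar``, pass ``is_best=True``:
        >>> save_checkpoint(state, 'log/my_model', is_best=True)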
"""
mkdir_if_missing(save_dir)
    # optionally strip a leading 'module.' or 'model.' prefix from the state dict's keys
    if remove_module_from_keys:
        state_dict = deepcopy(state["module"])
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("module."):
k = k[7:]
elif k.startswith("model."):
k = k[6:]
new_state_dict[k] = v
state["module"] = new_state_dict
# save
epoch = int(state["epoch"])
if len(prepend) > 0 and not prepend.endswith("_"):
prepend += "_"
fpath = os.path.normpath(os.path.join(save_dir, f"./{prepend}epoch{epoch:0>3}.pth"))
t.save(obj=state, f=fpath)
if verbose:
print(f"Checkpoint saved to '{fpath}'")
if is_best:
shutil.copy(fpath, os.path.join(os.path.dirname(fpath), "model-best.pth.tar"))
if verbose:
print("Saved best model as model-best.pth.tar")
def load_checkpoint(fpath: FilePath, device: Device = None) -> dict:
"""Load a given checkpoint.
    A ``UnicodeDecodeError`` is handled gracefully, meaning that files
    saved with Python 2 can be read from Python 3.
Args:
        fpath: The path to the checkpoint file.
        device: If not None, load all tensors onto this device.
            If None, tensors are loaded onto CUDA if available, otherwise onto the CPU.
Returns:
        dict: The loaded checkpoint.
Examples:
>>> from dgs.utils.torchtools import load_checkpoint
>>> fpath = 'log/my_model/model.pth.tar-10'
>>> checkpoint = load_checkpoint(fpath)
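
        Load all tensors onto the CPU explicitly:
        >>> checkpoint = load_checkpoint(fpath, device="cpu")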
"""
if fpath is None:
raise ValueError("File path is None")
fpath = os.path.abspath(os.path.expanduser(fpath))
if not os.path.exists(fpath):
raise FileNotFoundError(f"File is not found at '{fpath}'")
map_location = device if device is not None else "cuda" if t.cuda.is_available() else "cpu"
try:
checkpoint = t.load(fpath, map_location=map_location)
except UnicodeDecodeError:
        # legacy fallback: files pickled with Python 2 need latin1 encoding
        pickle.load = partial(pickle.load, encoding="latin1")
        pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
        checkpoint = t.load(fpath, pickle_module=pickle, map_location=map_location)
except Exception:
print(f"Unable to load checkpoint from '{fpath}'")
raise
return checkpoint
def load_pretrained_weights(
model: TorchMod, weight_path: FilePath, device: Device = None, verbose: bool = False
) -> None:
"""Loads pretrianed weights to model.
Features:
- Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing 'module.' or 'model.' prefixes.
Args:
model: A torch module.
weight_path: path to pretrained weights.
device: Device to load weights to.
        verbose: Whether to print non-warning messages.
Examples:
>>> from dgs.utils.torchtools import load_pretrained_weights
>>> weight_path = 'log/my_model/model-best.pth.tar'
>>> load_pretrained_weights(model, weight_path)
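
        Load the weights onto the CPU explicitly, e.g., when no GPU is available:
        >>> load_pretrained_weights(model, weight_path, device="cpu")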
"""
checkpoint = load_checkpoint(weight_path, device=device)
if "model" in checkpoint:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
elif k.startswith("module.") and k[7:] in model_dict and model_dict[k[7:]].size() == v.size():
new_state_dict[k[7:]] = v
matched_layers.append(k[7:])
elif k.startswith("model.") and k[6:] in model_dict and model_dict[k[6:]].size() == v.size():
new_state_dict[k[6:]] = v
matched_layers.append(k[6:])
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
f"The pretrained weights '{weight_path}' cannot be loaded, "
f"please check the key names manually "
f"(** ignored and continue **)"
)
else:
if verbose:
print(f"Successfully loaded pretrained weights from '{weight_path}'")
if len(discarded_layers) > 0:
print(f"** The following layers are discarded due to unmatched keys or layer size: {discarded_layers}")
def resume_from_checkpoint(
fpath: FilePath,
model: Union[TorchMod, BaseMod],
optimizer: optim.Optimizer = None,
scheduler: optim.lr_scheduler.LRScheduler = None,
verbose: bool = False,
) -> int:
"""Resumes training from a checkpoint.
    This will load (1) the model weights, (2) the ``state_dict`` of the optimizer if
    ``optimizer`` is not None, and (3) the ``state_dict`` of the scheduler if ``scheduler`` is not None.
Args:
        fpath: The path to checkpoint. Can be a relative or absolute path.
model: The model that is currently trained.
optimizer: An Optimizer.
scheduler: A single LRScheduler.
verbose: Whether to print additional debug information.
Returns:
int: start_epoch.
Examples:
>>> from dgs.utils.torchtools import resume_from_checkpoint
>>> fpath = 'log/my_model/model.pth.tar-10'
>>> start_epoch = resume_from_checkpoint(
>>> fpath, model, optimizer, scheduler
>>> )
"""
model = get_model_from_module(module=model)
if verbose:
print(f"Loading checkpoint from '{fpath}'")
load_pretrained_weights(model=model, weight_path=fpath, verbose=verbose)
if verbose:
print("Loaded model weights")
checkpoint = load_checkpoint(fpath)
if optimizer is not None and "optimizer" in checkpoint.keys():
optimizer.load_state_dict(checkpoint["optimizer"])
if verbose:
print("Loaded optimizer")
if scheduler is not None and "scheduler" in checkpoint.keys():
scheduler.load_state_dict(checkpoint["scheduler"])
if verbose:
print("Loaded scheduler")
return checkpoint["epoch"]
def set_bn_to_eval(module: Union[TorchMod, BaseMod]) -> None:
"""Sets BatchNorm layers to eval mode.
Args:
module: A torch module.
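    Examples:
        Because the check only inspects the module that is passed in, this function is
        typically applied to every submodule via :meth:`torch.nn.Module.apply`:
        >>> from dgs.utils.torchtools import set_bn_to_eval
        >>> model.apply(set_bn_to_eval)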
"""
# 1. no update for running mean and var
# 2. scale and shift parameters are still trainable
module = get_model_from_module(module=module)
classname = module.__class__.__name__
    if "BatchNorm" in classname:
module.eval()
def open_specified_layers(
model: Union[TorchMod, BaseMod], open_layers: str | list[str], freeze_others: bool = True, verbose: bool = False
) -> None:
"""Opens the specified layers in the given model for training while keeping all other layers unchanged or frozen.
Args:
        model: A torch module or a BaseModule containing a torch module as attribute 'model' or 'module'.
open_layers: Name or names of the layers to open for training.
freeze_others: Whether to freeze all the other modules that are not present in ``open_layers``.
verbose: Whether to print some debugging information.
Examples:
In the first example, open only the classifier-layer and freeze the rest of the model.
Then, in the second example using the same model,
open the two fc-layers while keeping the previously opened classifier open.
In the third one open the fc- and classifier-layers and freeze everything else.
>>> from dgs.utils.torchtools import open_specified_layers
>>> open_specified_layers(model, open_layers='classifier')
>>> open_specified_layers(model, open_layers=['fc1', 'fc2'], freeze_others=False)
>>> open_specified_layers(other_model, open_layers=['fc', 'classifier'])
Raises:
ValueError if a value in open_layers is not an attribute of the model.
"""
# pylint: disable=too-many-branches
model = get_model_from_module(module=model)
if isinstance(open_layers, str):
open_layers = [open_layers]
for layer in open_layers:
if not hasattr(model, layer):
raise ValueError(
f"{layer} is not an attribute of the model {model.__class__.__name__}, "
f"please provide the correct name or model."
)
    nof_opened, nof_frozen, still_open, still_closed = 0, 0, 0, 0
sub_module: TorchMod
for name, sub_module in model.named_children():
if name in open_layers:
sub_module.train()
sub_module.requires_grad_()
for p in sub_module.parameters():
p.requires_grad = True
nof_opened += 1
elif freeze_others:
sub_module.eval()
sub_module.requires_grad_(False)
for p in sub_module.parameters():
p.requires_grad = False
            nof_frozen += 1
elif sub_module.training:
still_open += 1
else:
still_closed += 1
    if verbose:
        if freeze_others:
            print(f"Opened {nof_opened} layers. Froze {nof_frozen} layers.")
        else:
            print(f"Opened {nof_opened} layers. Layers still open: {still_open}. Layers still closed: {still_closed}.")
def open_all_layers(model: Union[TorchMod, BaseMod]) -> None:
"""Opens all layers in this model for training.
Args:
model: A torch module.
Examples:
>>> from dgs.utils.torchtools import open_all_layers
>>> open_all_layers(model)
"""
def open_module(m: TorchMod) -> None:
if hasattr(m, "requires_grad"):
m.requires_grad = True
if hasattr(m, "train"):
m.train()
model: TorchMod = get_model_from_module(module=model)
model.train()
model.requires_grad_()
model.apply(open_module)
for p in model.parameters():
p.requires_grad = True
def close_specified_layers(
model: Union[TorchMod, BaseMod], close_layers: str | list[str], open_others: bool = False, verbose: bool = False
) -> None:
"""Close / Freeze the specified layers in the given model for training while keeping all other layers unchanged.
Args:
        model: A torch module or a BaseModule containing a torch module as attribute 'model' or 'module'.
close_layers: Name or names of the layers to close for evaluation.
open_others: Whether to open all layers not present in ``close_layers``.
verbose: Whether to print some debugging information.
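    Examples:
        Freeze only the classifier while leaving the rest of the model untouched, then
        freeze a hypothetical 'backbone' layer and open everything else:
        >>> from dgs.utils.torchtools import close_specified_layers
        >>> close_specified_layers(model, close_layers='classifier')
        >>> close_specified_layers(model, close_layers='backbone', open_others=True)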
Raises:
ValueError if a value in close_layers is not an attribute of the model.
"""
# pylint: disable=too-many-branches
model = get_model_from_module(module=model)
if isinstance(close_layers, str):
close_layers = [close_layers]
for layer in close_layers:
if not hasattr(model, layer):
raise ValueError(
f"{layer} is not an attribute of the model {model.__class__.__name__}, "
f"please provide the correct name or model."
)
nof_closed, nof_opened, still_closed, still_open = 0, 0, 0, 0
sub_module: TorchMod
for name, sub_module in model.named_children():
if name in close_layers:
sub_module.eval()
sub_module.requires_grad_(False)
for p in sub_module.parameters():
p.requires_grad = False
nof_closed += 1
elif open_others:
sub_module.train()
sub_module.requires_grad_()
for p in sub_module.parameters():
p.requires_grad = True
nof_opened += 1
elif sub_module.training:
still_open += 1
else:
still_closed += 1
if verbose:
if open_others:
print(f"Closed {nof_closed} layers. Opened {nof_opened} layers.")
else:
print(f"Closed {nof_closed} layers. Still open: {still_open}, kept closed: {still_closed}")
def close_all_layers(model: Union[TorchMod, BaseMod]) -> None:
"""Closes / Freezes all layers in this model, e.g., for evaluation.
Args:
model: A torch module.
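    Examples:
        >>> from dgs.utils.torchtools import close_all_layers
        >>> close_all_layers(model)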
"""
def freeze_module(m: TorchMod) -> None:
if hasattr(m, "requires_grad"):
m.requires_grad = False
if hasattr(m, "eval"):
m.eval()
model: TorchMod = get_model_from_module(module=model)
model.eval()
model.requires_grad_(False)
model.apply(freeze_module)
for p in model.parameters():
p.requires_grad = False
def init_model_params(module: TorchMod) -> None:
"""Given a torch module, initialize the model parameters using some default weights."""
model: TorchMod = get_model_from_module(module)
for instance in model.modules():
init_instance_params(instance=instance)
def init_instance_params(instance: nn.Module) -> None:
"""Given a module instance, initialize a single instance."""
if isinstance(instance, nn.Conv2d):
nn.init.kaiming_normal_(instance.weight, mode="fan_out", nonlinearity="relu")
if instance.bias is not None:
nn.init.constant_(instance.bias, 0)
elif isinstance(instance, nn.BatchNorm2d):
nn.init.constant_(instance.weight, 1)
nn.init.constant_(instance.bias, 0)
elif isinstance(instance, nn.BatchNorm1d):
nn.init.constant_(instance.weight, 1)
nn.init.constant_(instance.bias, 0)
elif isinstance(instance, nn.InstanceNorm2d):
nn.init.constant_(instance.weight, 1)
nn.init.constant_(instance.bias, 0)
elif isinstance(instance, nn.Linear):
nn.init.normal_(instance.weight, 0, 0.01)
if instance.bias is not None:
nn.init.constant_(instance.bias, 0)
elif isinstance(instance, nn.ConvTranspose2d):
nn.init.normal_(instance.weight, std=0.001)
for name, _ in instance.named_parameters():
if name in ["bias"]:
nn.init.constant_(instance.bias, 0)
def torch_memory_analysis(
f: callable, file_name: FilePath = "./memory_snapshot.pickle", max_events: int = 100_000
) -> callable: # pragma: no cover
"""A decorator for torch memory analysis using :func:`torch.cuda.memory._record_memory_history`."""
# pylint: disable=protected-access
    def wrapper(*args, **kwargs):
        """The wrapped function."""
        try:
            # start memory recording
            t.cuda.memory._record_memory_history(max_entries=max_events)
            # call the original function and keep its return value
            return f(*args, **kwargs)
        finally:
            t.cuda.memory._dump_snapshot(file_name)
            # stop recording memory
            t.cuda.memory._record_memory_history(enabled=None)
            print(f"Saved torch memory snapshot to: '{file_name}'")
    return wrapper
def torch_memory_analysis_win(
f: callable, file_name: FilePath = "./memory_snapshot.pickle", max_events: int = 100_000
) -> callable: # pragma: no cover
"""A decorator for torch memory analysis using :func:`torch.cuda.memory._record_memory_history_legacy`
that works on Windows machines."""
# pylint: disable=protected-access
    def wrapper(*args, **kwargs):
        """The wrapped function."""
        try:
            # start memory recording
            t.cuda.memory._record_memory_history_legacy(
                enabled=True, trace_alloc_max_entries=max_events, trace_alloc_record_context=True
            )
            # call the original function and keep its return value
            return f(*args, **kwargs)
        finally:
            t.cuda.memory._dump_snapshot(file_name)
            # stop recording memory
            t.cuda.memory._record_memory_history_legacy(enabled=False)
            print(f"Saved torch memory snapshot to: '{file_name}'")
    return wrapper