dgs.utils.torchtools

Tools for handling recurring torch tasks. Mostly taken from the torchreid package.

Module Functions

close_all_layers(model)

Closes / freezes all layers in the given model, e.g., for evaluation.
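
As a rough illustration of what freezing every layer usually involves in plain PyTorch, consider the sketch below. It is not the dgs implementation; the helper name freeze_all_sketch is hypothetical, and the call to eval() is an assumption about what "closing" includes.

```python
from torch import nn


def freeze_all_sketch(model: nn.Module) -> None:
    """Hypothetical helper: switch to eval mode and disable all gradients."""
    model.eval()  # assumption: closing also disables dropout / BatchNorm updates
    for param in model.parameters():
        param.requires_grad = False


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
freeze_all_sketch(model)
```

After the call, no parameter of the model receives gradients, so an optimizer would leave all weights untouched.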

close_specified_layers(model, close_layers)

Closes / freezes the specified layers in the given model during training while keeping all other layers unchanged.
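
A minimal sketch of selective freezing, assuming that layers are addressed by the attribute names of the model's direct children. The helper freeze_named_children_sketch and the TinyNet class are hypothetical and only serve as an illustration.

```python
from typing import Sequence

from torch import nn


def freeze_named_children_sketch(model: nn.Module, names: Sequence[str]) -> None:
    """Hypothetical helper: freeze only the direct children listed in `names`."""
    children = dict(model.named_children())
    for name in names:
        if name not in children:
            raise KeyError(f"Model has no child layer named '{name}'.")
        layer = children[name]
        layer.eval()  # stops dropout and BatchNorm statistic updates in this layer
        for param in layer.parameters():
            param.requires_grad = False


class TinyNet(nn.Module):
    """Hypothetical two-layer model used only for this example."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.backbone(x))


net = TinyNet()
freeze_named_children_sketch(net, ["backbone"])
```

Here only backbone is frozen; head keeps its training mode and its gradients.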

configure_torch_module(orig_cls, *_args, ...)

Class decorator for classes that inherit from both torch.nn.Module and BaseModule. After the original class has been initialized, the decorator calls BaseModule.configure_torch_model on the new instance.
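
The general pattern of a class decorator that runs a hook after __init__ can be sketched as follows. This is an assumption-laden illustration rather than the dgs code: ConfigurableBase is a hypothetical stand-in for BaseModule, configure_after_init_sketch is a hypothetical decorator name, and the body of the hook is invented for the example.

```python
import functools

from torch import nn


class ConfigurableBase:
    """Hypothetical stand-in for BaseModule; only the hook method is modelled."""

    def configure_torch_model(self) -> None:
        # Assumption: the real hook would configure device and train/eval mode.
        self.configured = True


def configure_after_init_sketch(cls):
    """Class decorator: call `configure_torch_model` once `__init__` has finished."""
    orig_init = cls.__init__

    @functools.wraps(orig_init)
    def new_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self.configure_torch_model()

    cls.__init__ = new_init
    return cls


@configure_after_init_sketch
class MyModel(ConfigurableBase, nn.Module):
    def __init__(self, dim: int) -> None:
        nn.Module.__init__(self)
        self.fc = nn.Linear(dim, dim)


m = MyModel(3)
```

Wrapping __init__ keeps the configuration step out of every subclass constructor while still guaranteeing that it runs after all submodules exist.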

get_model_from_module(module)

Given either a torch module or an instance of BaseModule, return a torch module.

init_instance_params(instance)

Given a single module instance, initialize its parameters.

init_model_params(module)

Given a torch module, initialize the model parameters using some default weights.
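
One common set of default weights, similar in spirit to what torchreid-style code uses, is sketched below. The concrete choices (Kaiming initialization for convolutions, constants for BatchNorm, a small normal distribution for linear layers) are assumptions for illustration, and init_params_sketch is a hypothetical name.

```python
import torch
from torch import nn


def init_params_sketch(model: nn.Module) -> None:
    """Hypothetical helper: apply per-layer-type default initializations."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # Affine parameters are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.Flatten(),
    nn.Linear(4, 2),
)
init_params_sketch(model)
```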

load_checkpoint(fpath[, device])

Load a given checkpoint.

load_pretrained_weights(model, weight_path)

Loads pretrained weights into the model.
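
Loading pretrained weights often means copying only those entries whose names and shapes match the target model. The sketch below shows that idea on an in-memory state dict; it is an assumption about the general technique, not the behavior of load_pretrained_weights, and load_matching_weights_sketch is a hypothetical helper.

```python
import torch
from torch import nn


def load_matching_weights_sketch(model: nn.Module, state_dict: dict) -> list:
    """Hypothetical helper: copy entries whose name and shape match the model.

    Returns the list of keys that were skipped.
    """
    model_dict = model.state_dict()
    matched, skipped = {}, []
    for key, value in state_dict.items():
        # Checkpoints saved from nn.DataParallel prefix every key with "module.".
        if key.startswith("module."):
            key = key[len("module."):]
        if key in model_dict and model_dict[key].shape == value.shape:
            matched[key] = value
        else:
            skipped.append(key)
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return skipped


source = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))
target = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5))  # different output size
skipped = load_matching_weights_sketch(target, source.state_dict())
```

In this example the first layer is copied, while the weight and bias of the second layer are skipped because their shapes differ.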

open_all_layers(model)

Opens all layers in this model for training.

open_specified_layers(model, open_layers[, ...])

Opens the specified layers in the given model for training while keeping all other layers unchanged or frozen.
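
The "unchanged or frozen" choice can be pictured with a small sketch. The flag name freeze_others, the helper open_named_children_sketch, and the TinyNet class are hypothetical; the real function's additional parameters are not shown in this summary.

```python
from typing import Sequence

from torch import nn


def open_named_children_sketch(
    model: nn.Module, names: Sequence[str], freeze_others: bool = True
) -> None:
    """Hypothetical helper: make the listed children trainable.

    With `freeze_others=True` every other child is frozen; otherwise it is
    left exactly as it was.
    """
    for name, child in model.named_children():
        if name in names:
            child.train()
            for param in child.parameters():
                param.requires_grad = True
        elif freeze_others:
            child.eval()
            for param in child.parameters():
                param.requires_grad = False


class TinyNet(nn.Module):
    """Hypothetical two-layer model used only for this example."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.backbone(x))


net = TinyNet()
open_named_children_sketch(net, ["head"])
```

This is the typical fine-tuning setup: only the head is trained while the backbone keeps its existing weights.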

resume_from_checkpoint(fpath, model[, ...])

Resumes training from a checkpoint.

save_checkpoint(state, save_dir, *[, ...])

Save a given checkpoint.
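
Saving and loading a checkpoint usually comes down to a torch.save / torch.load round trip. The sketch below uses only these two PyTorch calls; the checkpoint layout (keys "state_dict" and "epoch") and the file name are assumptions chosen for the example, not the format used by save_checkpoint or load_checkpoint.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
# Hypothetical checkpoint layout: model weights plus the last finished epoch.
state = {"state_dict": model.state_dict(), "epoch": 3}

with tempfile.TemporaryDirectory() as save_dir:
    fpath = os.path.join(save_dir, "checkpoint_sketch.pth")  # hypothetical file name
    torch.save(state, fpath)
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load(fpath, map_location="cpu")

restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["state_dict"])
```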

set_bn_to_eval(module)

Sets BatchNorm layers to eval mode.
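
A function of this shape is typically passed to torch.nn.Module.apply, which visits every submodule. The following is a minimal sketch under that assumption; bn_to_eval_sketch is a hypothetical name and not the dgs function.

```python
from torch import nn


def bn_to_eval_sketch(module: nn.Module) -> None:
    """Hypothetical helper: put a module into eval mode if it is a BatchNorm layer."""
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.eval()


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU())
model.train()
model.apply(bn_to_eval_sketch)  # apply() calls the function on every submodule
```

Keeping BatchNorm in eval mode during fine-tuning prevents the running mean and variance from drifting when batches are small.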

torch_memory_analysis(f[, file_name, max_events])

A decorator for torch memory analysis using torch.cuda.memory._record_memory_history().
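
The idea behind such a decorator can be sketched as follows. The sketch relies on the private PyTorch functions torch.cuda.memory._record_memory_history and torch.cuda.memory._dump_snapshot, which may change between releases. The decorator-factory form, the name record_cuda_memory_sketch, and the default values are assumptions and do not reproduce the signature of torch_memory_analysis.

```python
import functools

import torch


def record_cuda_memory_sketch(file_name="memory_snapshot.pickle", max_events=100_000):
    """Hypothetical decorator factory: record CUDA allocations while `f` runs."""

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if not torch.cuda.is_available():
                # Nothing to record without CUDA; just run the function.
                return f(*args, **kwargs)
            torch.cuda.memory._record_memory_history(max_entries=max_events)
            try:
                return f(*args, **kwargs)
            finally:
                # Write the snapshot even if `f` raised, then stop recording.
                torch.cuda.memory._dump_snapshot(file_name)
                torch.cuda.memory._record_memory_history(enabled=None)

        return wrapper

    return decorator


@record_cuda_memory_sketch()
def square_sum(n: int) -> float:
    x = torch.arange(n, dtype=torch.float32)
    return float((x * x).sum())


result = square_sum(4)
```

On a machine with CUDA, the resulting snapshot file can be inspected with PyTorch's memory visualization tooling; on a CPU-only machine the decorated function simply runs unchanged.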

torch_memory_analysis_win(f[, file_name, ...])

A decorator for torch memory analysis using torch.cuda.memory._record_memory_history_legacy() that works on Windows machines.