import atexit
import fnmatch
import json
import os
import numpy as np
import sys
import re
import random
import shutil
import string
import time
import traceback
import warnings
import torch
from trixi.experiment.experiment import Experiment
from trixi.logger import CombinedLogger, PytorchExperimentLogger, PytorchVisdomLogger
from trixi.logger.tensorboard.pytorchtensorboardlogger import PytorchTensorboardLogger
from trixi.util import Config, ResultElement, ResultLogDict, SourcePacker, name_and_iter_to_filename
from trixi.util.config import update_from_sys_argv
from trixi.util.pytorchutils import set_seed
from trixi.util.util import is_picklable
logger_lookup_dict = dict(
visdom=PytorchVisdomLogger,
tensorboard=PytorchTensorboardLogger,
)
try:
from trixi.logger.message.slackmessagelogger import SlackMessageLogger
logger_lookup_dict["slack"] = SlackMessageLogger
except:
pass
try:
from trixi.logger import TelegramMessageLogger
logger_lookup_dict["telegram"] = TelegramMessageLogger
except:
pass
[docs]class PytorchExperiment(Experiment):
"""
A PytorchExperiment extends the basic
functionality of the :class:`.Experiment` class with
convenience features for PyTorch (and general logging) such as creating a folder structure,
saving, plotting results and checkpointing your experiment.
The basic life cycle of a PytorchExperiment is the same as
:class:`.Experiment`::
setup()
prepare()
for epoch in n_epochs:
train()
validate()
end()
where the distinction between the first two is that between them
PytorchExperiment will automatically restore checkpoints and save the
:attr:`_config_raw` in :meth:`._setup_internal`. Please see below for more
information on this.
To get your own experiment simply inherit from the PytorchExperiment and
overwrite the :meth:`.setup`, :meth:`.prepare`, :meth:`.train`,
:meth:`.validate` method (or you can use the `very` experimental decorator
:func:`.experimentify` to convert your class into a experiment).
Then you can run your own experiment by calling the :meth:`.run` method.
Internally PytorchExperiment will provide a number of member variables which
you can access.
- n_epochs
Number of epochs.
- exp_name
Name of your experiment.
- config
The (initialized) :class:`.Config` of your experiment. You can
access the uninitialized one via :attr:`_config_raw`.
- result
A dict in which you can store your result values. If a
:class:`.PytorchExperimentLogger` is used, results will be a
:class:`.ResultLogDict` that directly automatically writes to a file
and also stores the N last entries for each key for quick access
(e.g. to quickly get the running mean).
- elog (if base_dir is given)
A :class:`.PytorchExperimentLogger` instance which can log your
results to a given folder. Will automatically be created if a base_dir is available.
- loggers
Contains all loggers you provide, including the experiment logger, accessible by the
names you provide.
- clog
A :class:`.CombinedLogger` instance which logs to all loggers with
different frequencies (specified with the last entry in the tuple you provide for each
logger where 1 means every time and N means every Nth time,
e.g. if you only want to send stuff to Visdom every 10th time).
The most important attribute is certainly :attr:`.config`, which is the
initialized :class:`.Config` for the experiment. To understand how it needs
to be structured to allow for automatic instantiation of types, please refer
to its documentation. If you decide not to use this functionality,
:attr:`config` and :attr:`_config_raw` are identical. **Beware however that
by default the Pytorchexperiment only saves the raw config** after
:meth:`.setup`. If you modify :attr:`config` during setup, make sure
to implement :meth:`._setup_internal` yourself should you want the modified
config to be saved::
def _setup_internal(self):
super(YourExperiment, self)._setup_internal() # calls .prepare_resume()
self.elog.save_config(self.config, "config")
Args:
config (dict or Config): Configures your experiment. If :attr:`name`,
:attr:`n_epochs`, :attr:`seed`, :attr:`base_dir` are given in the
config, it will automatically
overwrite the other args/kwargs with the values from the config.
In addition (defined by :attr:`parse_config_sys_argv`) the config
automatically parses the argv arguments and updates its values if a
key matches a console argument.
name (str):
The name of the PytorchExperiment.
n_epochs (int): The number of epochs (number of times the training
cycle will be executed).
seed (int): A random seed (which will set the random, numpy and
torch seed).
base_dir (str): A base directory in which the experiment result folder
will be created. A :class:`.PytorchExperimentLogger` instance will be created if
this is given.
globs: The :func:`globals` of the script which is run. This is necessary
to get and save the executed files in the experiment folder.
resume (str or PytorchExperiment): Another PytorchExperiment or path to
the result dir from another PytorchExperiment from which it will
load the PyTorch modules and other member variables and resume
the experiment.
ignore_resume_config (bool): If :obj:`True` it will not resume with the
config from the resume experiment but take the current/own config.
resume_save_types (list or tuple): A list which can define which values
to restore when resuming. Choices are:
- "model" <-- Pytorch models
- "optimizer" <-- Optimizers
- "simple" <-- Simple python variables (basic types and lists/tuples
- "th_vars" <-- torch tensors/variables
- "results" <-- The result dict
resume_reset_epochs (bool): Set epoch to zero if you resume an existing experiment.
parse_sys_argv (bool): Parsing the console arguments (argv) to get a
:attr:`config path` and/or :attr:`resume_path`.
parse_config_sys_argv (bool): Parse argv to update the config
(if the keys match).
checkpoint_to_cpu (bool): When checkpointing, transfer all tensors to
the CPU beforehand.
save_checkpoint_every_epoch (int): Determines after how many epochs a
checkpoint is stored.
explogger_kwargs (dict): Keyword arguments for :attr:`elog`
instantiation.
explogger_freq (int): The frequency x (meaning one in x) with which
the :attr:`clog` will call the :attr:`elog`.
loggers (dict): Specify additional loggers.
Entries should have one of these formats::
"name": "identifier" (will default to a frequency of 10)
"name": ("identifier"(, kwargs, frequency)) (last two are optional)
"identifier" is one of "telegram", "tensorboard", "visdom", "slack".
append_rnd_to_name (bool): If :obj:`True`, will append a random six
digit string to the experiment name.
save_checkpoints_default (bool): By default save the current and the last checkpoint or not.
"""
def __init__(self,
config=None,
name=None,
n_epochs=None,
seed=None,
base_dir=None,
globs=None,
resume=None,
ignore_resume_config=False,
resume_save_types=("model", "optimizer", "simple", "th_vars", "results"),
resume_reset_epochs=True,
parse_sys_argv=False,
checkpoint_to_cpu=True,
save_checkpoint_every_epoch=1,
explogger_kwargs=None,
explogger_freq=1,
loggers=None,
append_rnd_to_name=False,
default_save_types=("model", "optimizer", "simple", "th_vars", "results"),
save_checkpoints_default=True):
# super(PytorchExperiment, self).__init__()
Experiment.__init__(self)
# check for command line inputs for config_path and resume_path,
# will be prioritized over config and resume!
config_path_from_argv = None
if parse_sys_argv:
config_path_from_argv, resume_path_from_argv = get_vars_from_sys_argv()
if resume_path_from_argv:
resume = resume_path_from_argv
# construct _config_raw
if config_path_from_argv is None:
self._config_raw = self._config_raw_from_input(config, name, n_epochs, seed, append_rnd_to_name)
else:
self._config_raw = Config(file_=config_path_from_argv)
update_from_sys_argv(self._config_raw)
# set a few experiment attributes
self.n_epochs = self._config_raw["n_epochs"]
self._seed = self._config_raw['seed']
set_seed(self._seed)
self.exp_name = self._config_raw["name"]
self._checkpoint_to_cpu = checkpoint_to_cpu
self._save_checkpoint_every_epoch = save_checkpoint_every_epoch
self._default_save_types = default_save_types
self._save_checkpoint_default = save_checkpoints_default
self.results = dict()
# get base_dir from _config_raw or store there
if base_dir is not None:
self._config_raw["base_dir"] = base_dir
base_dir = self._config_raw["base_dir"]
# Construct experiment logger (automatically activated if base_dir is there)
self.loggers = {}
logger_list = []
if base_dir is not None:
if explogger_kwargs is None:
explogger_kwargs = {}
self.elog = PytorchExperimentLogger(base_dir=base_dir,
exp_name=self.exp_name,
**explogger_kwargs)
if explogger_freq is not None and explogger_freq > 0:
logger_list.append((self.elog, explogger_freq))
self.results = ResultLogDict("results-log.json", base_dir=self.elog.result_dir)
else:
self.elog = None
# Construct other loggers
if loggers is not None:
for logger_name, logger_cfg in loggers.items():
_logger, log_freq = self._make_logger(logger_name, logger_cfg)
self.loggers[logger_name] = _logger
if log_freq is not None and log_freq > 0:
logger_list.append((_logger, log_freq))
self.clog = CombinedLogger(*logger_list)
# Set resume attributes and update _config_raw,
# actual resuming is done automatically after setup in _setup_internal
self._resume_path = None
self._resume_save_types = resume_save_types
self._ignore_resume_config = ignore_resume_config
self._resume_reset_epochs = resume_reset_epochs
if resume is not None:
if isinstance(resume, str):
if resume == "last":
if base_dir is None:
raise ValueError("resume='last' requires base_dir.")
self._resume_path = os.path.join(base_dir, sorted(os.listdir(base_dir))[-1])
else:
self._resume_path = resume
elif isinstance(resume, PytorchExperiment):
self._resume_path = resume.elog.base_dir
if self._resume_path is not None and not self._ignore_resume_config:
self._config_raw.update(Config(file_=os.path.join(self._resume_path, "config", "config.json")),
ignore=list(map(lambda x: re.sub("^-+", "", x), sys.argv)))
# Save everything we need to reproduce experiment
if globs is not None and self.elog is not None:
zip_name = os.path.join(self.elog.save_dir, "sources.zip")
SourcePacker.zip_sources(globs, zip_name)
# Init objects in config
self.config = Config.init_objects(self._config_raw)
atexit.register(self.at_exit_func)
def _config_raw_from_input(self,
config=None,
name=None,
n_epochs=None,
seed=None,
append_rnd_to_name=False):
"""Construct _config_raw from input."""
_config_raw = None
if isinstance(config, str):
_config_raw = Config(file_=config)
elif isinstance(config, (Config, dict)):
_config_raw = Config(config=config)
else:
_config_raw = Config()
if n_epochs is None and _config_raw.get("n_epochs") is not None:
n_epochs = _config_raw["n_epochs"]
elif n_epochs is None and _config_raw.get("n_epochs") is None:
n_epochs = 0
_config_raw["n_epochs"] = n_epochs
if seed is None and _config_raw.get('seed') is not None:
seed = _config_raw['seed']
elif seed is None and _config_raw.get('seed') is None:
random_data = os.urandom(4)
seed = int.from_bytes(random_data, byteorder="big")
_config_raw['seed'] = seed
if name is None and _config_raw.get("name") is not None:
name = _config_raw["name"]
elif name is None and _config_raw.get("name") is None:
name = "experiment"
if append_rnd_to_name:
rnd_str = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(5))
name += "_" + rnd_str
_config_raw["name"] = name
return _config_raw
def _make_logger(self, logger_name, logger_cfg):
if isinstance(logger_cfg, (list, tuple)):
log_type = logger_cfg[0]
log_params = logger_cfg[1] if len(logger_cfg) > 1 else {}
log_freq = logger_cfg[2] if len(logger_cfg) > 2 else 10
else:
assert isinstance(logger_cfg, str), "The specified logger has to either be a string or a list with " \
"name, parameters, clog_frequency"
log_type = logger_cfg
log_params = {}
log_freq = 10
if "exp_name" not in log_params:
log_params["exp_name"] = self.exp_name
if log_type == "tensorboard":
if "target_dir" not in log_params or log_params["target_dir"] is None:
if self.elog is not None:
log_params["target_dir"] = os.path.join(self.elog.save_dir, "tensorboard")
else:
raise AttributeError("TensorboardLogger requires a target_dir or an ExperimentLogger instance.")
elif self.elog is not None:
log_params["target_dir"] = os.path.join(log_params["target_dir"], self.elog.folder_name)
log_type = logger_lookup_dict[log_type]
_logger = log_type(**log_params)
return _logger, log_freq
@property
def vlog(self):
if "visdom" in self.loggers:
return self.loggers["visdom"]
elif "v" in self.loggers:
return self.loggers["v"]
else:
return None
@property
def tlog(self):
if "telegram" in self.loggers:
return self.loggers["telegram"]
elif "t" in self.loggers:
return self.loggers["t"]
else:
return None
@property
def txlog(self):
return self.tblog
@property
def tblog(self):
if "tensorboard" in self.loggers:
return self.loggers["tensorboard"]
if "tensorboardx" in self.loggers:
return self.loggers["tensorboardx"]
elif "tx" in self.loggers:
return self.loggers["tx"]
else:
return None
@property
def slog(self):
if "slack" in self.loggers:
return self.loggers["slack"]
elif "s" in self.loggers:
return self.loggers["s"]
else:
return None
[docs] def process_err(self, e):
if self.elog is not None:
self.elog.text_logger.log_to("\n".join(traceback.format_tb(e.__traceback__)), "err")
self.elog.text_logger.log_to(repr(e), "err")
raise e
[docs] def update_attributes(self, var_dict, ignore=()):
"""
Updates the member attributes with the attributes given in the var_dict
Args:
var_dict (dict): dict in which the update values stored. If a key matches a member attribute name
the member attribute will be updated
ignore (list or tuple): iterable of keys to ignore
"""
for key, val in var_dict.items():
if key == "results":
self.results.load(val)
continue
if key in ignore:
continue
if hasattr(self, key):
setattr(self, key, val)
[docs] def get_pytorch_modules(self, from_config=True):
"""
Returns all torch.nn.Modules stored in the experiment in a dict (even child dicts are stored).
Args:
from_config (bool): Also get modules that are stored in the :attr:`.config` attribute.
Returns:
dict: Dictionary of PyTorch modules
"""
def parse_torchmodules_recursive(input, output):
if isinstance(input, dict):
for key, value in input.items():
if isinstance(value, dict):
parse_torchmodules_recursive(value, output)
elif isinstance(value, torch.nn.Module):
output[key] = value
pyth_modules = dict()
parse_torchmodules_recursive(self.__dict__, pyth_modules)
# for key, val in self.__dict__.items():
# if isinstance(val, torch.nn.Module):
# pyth_modules[key] = val
if from_config:
for key, val in self.config.items():
if isinstance(val, torch.nn.Module):
if type(key) == str:
key = "config." + key
pyth_modules[key] = val
return pyth_modules
[docs] def get_pytorch_optimizers(self, from_config=True):
"""
Returns all torch.optim.Optimizers stored in the experiment in a dict.
Args:
from_config (bool): Also get optimizers that are stored in the :attr:`.config`
attribute.
Returns:
dict: Dictionary of PyTorch optimizers
"""
pyth_optimizers = dict()
for key, val in self.__dict__.items():
if isinstance(val, torch.optim.Optimizer):
pyth_optimizers[key] = val
if from_config:
for key, val in self.config.items():
if isinstance(val, torch.optim.Optimizer):
if type(key) == str:
key = "config." + key
pyth_optimizers[key] = val
return pyth_optimizers
[docs] def get_simple_variables(self, ignore=()):
"""
Returns all standard variables in the experiment in a dict.
Specifically, this looks for types :class:`int`, :class:`float`, :class:`bytes`,
:class:`bool`, :class:`str`, :class:`set`, :class:`list`, :class:`tuple`.
Args:
ignore (list or tuple): Iterable of names which will be ignored
Returns:
dict: Dictionary of variables
"""
simple_vars = dict()
for key, val in self.__dict__.items():
if key in ignore:
continue
if isinstance(val, (int, float, bytes, bool, str, set, list, tuple)):
if is_picklable(val):
simple_vars[key] = val
return simple_vars
[docs] def get_pytorch_tensors(self, ignore=()):
"""
Returns all torch.tensors in the experiment in a dict.
Args:
ignore (list or tuple): Iterable of names which will be ignored
Returns:
dict: Dictionary of PyTorch tensor
"""
pytorch_vars = dict()
for key, val in self.__dict__.items():
if key in ignore:
continue
if torch.is_tensor(val):
pytorch_vars[key] = val
return pytorch_vars
[docs] def get_pytorch_variables(self, ignore=()):
"""Same as :meth:`.get_pytorch_tensors`."""
return self.get_pytorch_tensors(ignore)
[docs] def save_results(self, name="results.json"):
"""
Saves the result dict as a json file in the result dir of the experiment logger.
Args:
name (str): The name of the json file in which the results are written.
"""
if self.elog is None:
return
with open(os.path.join(self.elog.result_dir, name), "w") as file_:
json.dump(self.results, file_, indent=4)
[docs] def save_pytorch_models(self):
"""Saves all torch.nn.Modules as model files in the experiment checkpoint folder."""
if self.elog is None:
return
pyth_modules = self.get_pytorch_modules()
for key, val in pyth_modules.items():
self.elog.save_model(val, key)
[docs] def load_pytorch_models(self):
"""Loads all model files from the experiment checkpoint folder."""
if self.elog is None:
return
pyth_modules = self.get_pytorch_modules()
for key, val in pyth_modules.items():
self.elog.load_model(val, key)
[docs] def log_simple_vars(self):
"""
Logs all simple python member variables as a json file in the experiment log folder.
The file will be names 'simple_vars.json'.
"""
if self.elog is None:
return
simple_vars = self.get_simple_variables()
with open(os.path.join(self.elog.log_dir, "simple_vars.json"), "w") as file_:
json.dump(simple_vars, file_)
[docs] def load_simple_vars(self):
"""
Restores all simple python member variables from the 'simple_vars.json' file in the log
folder.
"""
if self.elog is None:
return
simple_vars = {}
with open(os.path.join(self.elog.log_dir, "simple_vars.json"), "r") as file_:
simple_vars = json.load(file_)
self.update_attributes(simple_vars)
[docs] def save_checkpoint(self,
name="checkpoint",
save_types=("model", "optimizer", "simple", "th_vars", "results"),
n_iter=None,
iter_format="{:05d}",
prefix=False):
"""
Saves a current model checkpoint from the experiment.
Args:
name (str): The name of the checkpoint file
save_types (list or tuple): What kind of member variables should be stored? Choices are:
"model" <-- Pytorch models,
"optimizer" <-- Optimizers,
"simple" <-- Simple python variables (basic types and lists/tuples),
"th_vars" <-- torch tensors,
"results" <-- The result dict
n_iter (int): Number of iterations. Together with the name, defined by the iter_format,
a file name will be created.
iter_format (str): Defines how the name and the n_iter will be combined.
prefix (bool): If True, the formatted n_iter will be prepended, otherwise appended.
"""
if self.elog is None:
return
model_dict = {}
optimizer_dict = {}
simple_dict = {}
th_vars_dict = {}
results_dict = {}
if "model" in save_types:
model_dict = self.get_pytorch_modules()
if "optimizer" in save_types:
optimizer_dict = self.get_pytorch_optimizers()
if "simple" in save_types:
simple_dict = self.get_simple_variables()
if "th_vars" in save_types:
th_vars_dict = self.get_pytorch_variables()
if "results" in save_types:
results_dict = {"results": self.results}
checkpoint_dict = {**model_dict, **optimizer_dict, **simple_dict, **th_vars_dict, **results_dict}
self.elog.save_checkpoint(name=name, n_iter=n_iter, iter_format=iter_format, prefix=prefix,
move_to_cpu=self._checkpoint_to_cpu, **checkpoint_dict)
[docs] def load_checkpoint(self,
name="checkpoint",
save_types=("model", "optimizer", "simple", "th_vars", "results"),
n_iter=None,
iter_format="{:05d}",
prefix=False,
path=None):
"""
Loads a checkpoint and restores the experiment.
Make sure you have your torch stuff already on the right devices beforehand,
otherwise this could lead to errors e.g. when making a optimizer step
(and for some reason the Adam states are not already on the GPU:
https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244/3 )
Args:
name (str): The name of the checkpoint file
save_types (list or tuple): What kind of member variables should be loaded? Choices are:
"model" <-- Pytorch models,
"optimizer" <-- Optimizers,
"simple" <-- Simple python variables (basic types and lists/tuples),
"th_vars" <-- torch tensors,
"results" <-- The result dict
n_iter (int): Number of iterations. Together with the name, defined by the iter_format,
a file name will be created and searched for.
iter_format (str): Defines how the name and the n_iter will be combined.
prefix (bool): If True, the formatted n_iter will be prepended, otherwise appended.
path (str): If no path is given then it will take the current experiment dir and formatted
name, otherwise it will simply use the path and the formatted name to define the
checkpoint file.
"""
if self.elog is None:
return
model_dict = {}
optimizer_dict = {}
simple_dict = {}
th_vars_dict = {}
results_dict = {}
if "model" in save_types:
model_dict = self.get_pytorch_modules()
if "optimizer" in save_types:
optimizer_dict = self.get_pytorch_optimizers()
if "simple" in save_types:
simple_dict = self.get_simple_variables()
if "th_vars" in save_types:
th_vars_dict = self.get_pytorch_variables()
if "results" in save_types:
results_dict = {"results": self.results}
checkpoint_dict = {**model_dict, **optimizer_dict, **simple_dict, **th_vars_dict, **results_dict}
if n_iter is not None:
name = name_and_iter_to_filename(name,
n_iter,
".pth.tar",
iter_format=iter_format,
prefix=prefix)
if path is None:
restore_dict = self.elog.load_checkpoint(name=name, **checkpoint_dict)
else:
checkpoint_path = os.path.join(path, name)
if checkpoint_path.endswith(os.sep):
checkpoint_path = os.path.dirname(checkpoint_path)
restore_dict = self.elog.load_checkpoint_static(checkpoint_file=checkpoint_path, **checkpoint_dict)
self.update_attributes(restore_dict)
def _end_internal(self):
"""Ends the experiment and stores the final results/checkpoint"""
if isinstance(self.results, ResultLogDict):
self.results.close()
self.save_results()
self.save_end_checkpoint()
self._save_exp_config()
self.print("Experiment ended. Checkpoints stored =)")
def _end_test_internal(self):
"""Ends the experiment after test and stores the final results and config"""
self.save_results()
self._save_exp_config()
self.print("Testing ended. Results stored =)")
[docs] def at_exit_func(self):
"""
Stores the results and checkpoint at the end (if not already stored).
This method is also called if an error occurs.
"""
if self._exp_state not in ("Ended", "Tested"):
if isinstance(self.results, ResultLogDict):
self.results.print_to_file("]")
self.save_checkpoint(name="checkpoint_exit-" + self._exp_state, save_types=self._default_save_types)
self.save_results()
self._save_exp_config()
self.print("Experiment exited. Checkpoints stored =)")
time.sleep(2) # allow checkpoint saving to finish. We need a better solution for this :D
def _setup_internal(self):
self.prepare_resume()
if self.elog is not None:
self.elog.save_config(self._config_raw, "config")
self._save_exp_config()
def _start_internal(self):
self._save_exp_config()
[docs] def prepare_resume(self):
"""Tries to resume the experiment by using the defined resume path or PytorchExperiment."""
checkpoint_file = ""
base_dir = ""
reset_epochs = self._resume_reset_epochs
if self._resume_path is not None:
if isinstance(self._resume_path, str):
if self._resume_path.endswith(".pth.tar"):
checkpoint_file = self._resume_path
base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
elif self._resume_path.endswith("checkpoint") or self._resume_path.endswith("checkpoint/"):
checkpoint_file = get_last_file(self._resume_path)
base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
elif "checkpoint" in os.listdir(self._resume_path) and "config" in os.listdir(self._resume_path):
checkpoint_file = get_last_file(self._resume_path)
base_dir = self._resume_path
else:
warnings.warn("You have not selected a valid experiment folder, will search all sub folders",
UserWarning)
if self.elog is not None:
self.elog.text_logger.log_to("You have not selected a valid experiment folder, will search all "
"sub folders", "warnings")
checkpoint_file = get_last_file(self._resume_path)
base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
# if base_dir:
# if not self._ignore_resume_config:
# load_config = Config()
# load_config.load(os.path.join(base_dir, "config/config.json"))
# self._config_raw = load_config
# self.config = Config.init_objects(self._config_raw)
# self.print("Loaded existing config from:", base_dir)
# if self.n_epochs is None:
# self.n_epochs = self._config_raw.get("n_epochs")
if checkpoint_file:
self.load_checkpoint(name="", path=checkpoint_file, save_types=self._resume_save_types)
self._resume_path = checkpoint_file
shutil.copyfile(checkpoint_file, os.path.join(self.elog.checkpoint_dir, "0_checkpoint.pth.tar"))
self.print("Loaded existing checkpoint from:", checkpoint_file)
self._resume_reset_epochs = reset_epochs
if self._resume_reset_epochs:
self._epoch_idx = 0
def _end_epoch_internal(self, epoch):
self.save_results()
if self._save_checkpoint_every_epoch is not None and self._save_checkpoint_every_epoch > 0 and epoch % \
self._save_checkpoint_every_epoch == 0:
self.save_temp_checkpoint()
self._save_exp_config()
def _save_exp_config(self):
if self.elog is not None:
cur_time = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time()))
self.elog.save_config(Config(**{'name': self.exp_name,
'time': self._time_start,
'state': self._exp_state,
'current_time': cur_time,
'epoch': self._epoch_idx
}),
"exp")
[docs] def save_temp_checkpoint(self):
"""Saves the current checkpoint as checkpoint_current."""
if self._save_checkpoint_default:
self.save_checkpoint(name="checkpoint_current", save_types=self._default_save_types)
[docs] def save_end_checkpoint(self):
"""Saves the current checkpoint as checkpoint_last."""
if self._save_checkpoint_default:
self.save_checkpoint(name="checkpoint_last", save_types=self._default_save_types)
[docs] def add_result(self, value, name, counter=None, tag=None, label=None, plot_result=True, plot_running_mean=False):
"""
Saves a results and add it to the result dict, this is similar to results[key] = val,
but in addition also logs the value to the combined logger
(it also stores in the results-logs file).
**This should be your preferred method to log your numeric values**
Args:
value: The value of your variable
name (str): The name/key of your variable
counter (int or float): A counter which can be seen as the x-axis of your value.
Normally you would just use the current epoch for this.
tag (str): A label/tag which can group similar values and will plot values with the same
label in the same plot
label: deprecated label
plot_result (bool): By default True, will also log all your values to the combined
logger (with show_value).
"""
if label is not None:
warnings.warn("label in add_result is deprecated, please use tag instead")
if tag is None:
tag = label
tag_name = tag
if tag_name is None:
tag_name = name
r_elem = ResultElement(data=value, label=tag_name, epoch=self._epoch_idx, counter=counter)
self.results[name] = r_elem
if plot_result:
if tag is None:
legend = False
else:
legend = True
if plot_running_mean:
value = np.mean(self.results.running_mean_dict[name])
self.clog.show_value(value=value, name=name, tag=tag, counter=counter, show_legend=legend)
[docs] def get_result(self, name):
"""
Similar to result[key] this will return the values in the results dictionary with the given
name/key.
Args:
name (str): the name/key for which a value is stored.
Returns:
The value with the key 'name' in the results dict.
"""
return self.results.get(name)
[docs] def add_result_without_epoch(self, val, name):
"""
A faster method to store your results, has less overhead and does not call the combined
logger. Will only store to the results dictionary.
Args:
val: the value you want to add.
name (str): the name/key of your value.
"""
self.results[name] = val
[docs] def get_result_without_epoch(self, name):
"""
Similar to result[key] this will return the values in result with the given name/key.
Args:
name (str): the name/ key for which a value is stores.
Returns:
The value with the key 'name' in the results dict.
"""
return self.results.get(name)
[docs] def print(self, *args):
"""
Calls 'print' on the experiment logger or uses builtin 'print' if former is not
available.
"""
if self.elog is None:
print(*args)
else:
self.elog.print(*args)
[docs]def get_last_file(dir_, name=None):
"""
Returns the most recently created file in the folder which matches the name supplied
Args:
dir_: The base directory to start the search in
name: The name pattern to match with the files
Returns:
str: the path to the most recent file
"""
if name is None:
name = "*checkpoint*.pth.tar"
dir_files = []
for root, dirs, files in os.walk(dir_):
for filename in fnmatch.filter(files, name):
# if 'last' in filename:
# return os.path.join(root, filename)
checkpoint_file = os.path.join(root, filename)
dir_files.append(checkpoint_file)
if len(dir_files) == 0:
return ""
last_file = max(dir_files, key=os.path.getctime)
return last_file
[docs]def get_vars_from_sys_argv():
"""
Parses the command line args (argv) and looks for --config_path and --resume_path and returns them if found.
Returns:
tuple: a tuple of (config_path, resume_path ) , None if it is not found
"""
import sys
import argparse
if len(sys.argv) > 1:
parser = argparse.ArgumentParser()
# parse just config keys
parser.add_argument("config_path", type=str)
parser.add_argument("resume_path", type=str)
# parse args
param, unknown = parser.parse_known_args()
if len(unknown) > 0:
warnings.warn("Called with unknown arguments: %s" % unknown, RuntimeWarning)
# update dict
return param.config_path, param.resume_path
[docs]def experimentify(setup_fn="setup", train_fn="train", validate_fn="validate", end_fn="end", test_fn="test", **decoargs):
"""
Experimental decorator which monkey patches your class into a PytorchExperiment.
You can then call run on your new :class:`.PytorchExperiment` class.
Args:
setup_fn: The name of your setup() function
train_fn: The name of your train() function
validate_fn: The name of your validate() function
end_fn: The name of your end() function
test_fn: The name of your test() function
"""
def wrap(cls):
### Initilaize both Classes (as original class)
prev_init = cls.__init__
def new_init(*args, **kwargs):
prev_init(*args, **kwargs)
kwargs.update(decoargs)
PytorchExperiment.__init__(*args, **kwargs)
cls.__init__ = new_init
### Set new Experiment methods
if not hasattr(cls, "setup") and hasattr(cls, setup_fn):
setattr(cls, "setup", getattr(cls, setup_fn))
elif hasattr(cls, "setup") and setup_fn != "setup":
warnings.warn("Found already exisiting setup function in class, so will use the exisiting one")
if not hasattr(cls, "train") and hasattr(cls, train_fn):
setattr(cls, "train", getattr(cls, train_fn))
elif hasattr(cls, "train") and setup_fn != "train":
warnings.warn("Found already exisiting train function in class, so will use the exisiting one")
if not hasattr(cls, "validate") and hasattr(cls, validate_fn):
setattr(cls, "validate", getattr(cls, validate_fn))
elif hasattr(cls, "validate") and setup_fn != "validate":
warnings.warn("Found already exisiting validate function in class, so will use the exisiting one")
if not hasattr(cls, "end") and hasattr(cls, end_fn):
setattr(cls, "end", getattr(cls, end_fn))
elif hasattr(cls, "end") and end_fn != "end":
warnings.warn("Found already exisiting end function in class, so will use the exisiting one")
if not hasattr(cls, "test") and hasattr(cls, test_fn):
setattr(cls, "test", getattr(cls, test_fn))
elif hasattr(cls, "test") and test_fn != "test":
warnings.warn("Found already exisiting test function in class, so will use the exisiting one")
### Copy methods from PytorchExperiment into the original class
for elem in dir(PytorchExperiment):
if not hasattr(cls, elem):
trans_fn = getattr(PytorchExperiment, elem)
setattr(cls, elem, trans_fn)
return cls
return wrap