# Source code for trixi.experiment.experiment

import time


class Experiment(object):
    """
    An abstract Experiment which can be run for a number of epochs.

    The basic life cycle of an experiment is::

        setup()
        prepare()

        while epoch < n_epochs:
            train()
            validate()
            epoch += 1

        end()

    If you want to use another criterion than number of epochs, e.g. stopping
    based on validation loss, you can implement that in your validation method
    and just call .stop() at some point to break the loop. Just set your
    n_epochs to a high number or np.inf.

    The reason there is both :meth:`.setup` and :meth:`.prepare` is that
    internally there is also a :meth:`._setup_internal` method for hidden magic
    in classes that inherit from this. For example, the
    :class:`trixi.experiment.pytorchexperiment.PytorchExperiment` uses this to
    restore checkpoints. Think of :meth:`.setup` as an :meth:`.__init__` that
    is only called when the Experiment is actually asked to do anything. Then
    use :meth:`.prepare` to modify the fully instantiated Experiment if you
    need to.

    To write a new Experiment simply inherit the Experiment class and overwrite
    the methods. You can then start your Experiment by calling :meth:`.run`.

    In addition the Experiment also has a test function. If you call the
    :meth:`.run_test` method it will call the :meth:`.test` and
    :meth:`.end_test` methods internally (and if you give the parameter
    ``setup=True`` in run_test it will again call :meth:`.setup` first).

    Each Experiment also has its current state in :attr:`_exp_state`, its start
    time in :attr:`_time_start`, its end time in :attr:`_time_end` and the
    current epoch index in :attr:`_epoch_idx`.

    Args:
        n_epochs (int): The number of epochs in the Experiment (how often the
            train and validate method will be called)

    """

    def __init__(self, n_epochs=0):
        self.n_epochs = n_epochs
        # One of "Preparing", "Started", "Trained", "Ended", "Testing",
        # "Tested" or "Error", depending on how far run()/run_test() got.
        self._exp_state = "Preparing"
        # Formatted timestamps; empty string means "not set yet".
        self._time_start = ""
        self._time_end = ""
        self._epoch_idx = 0
        # Set via stop(); checked once per epoch by the loop in run().
        self.__stop = False

    @staticmethod
    def _timestamp():
        """Return the current local time formatted as ``yy-mm-dd_HH:MM:SS``."""
        return time.strftime("%y-%m-%d_%H:%M:%S", time.localtime())

    def run(self, setup=True):
        """
        This method runs the Experiment. It runs through the basic lifecycle of
        an Experiment::

            setup()
            prepare()

            while epoch < n_epochs:
                train()
                validate()
                epoch += 1

            end()

        Args:
            setup (bool): If True, :meth:`.setup` and the internal setup hook
                are executed before :meth:`.prepare`.

        Any exception raised during the life cycle sets the state to
        ``"Error"``, records the end time and is handed to
        :meth:`.process_err` (which re-raises by default).

        """
        try:
            self._time_start = self._timestamp()
            self._time_end = ""

            if setup:
                self.setup()
                self._setup_internal()

            # NOTE(review): prepare() is called regardless of `setup` here;
            # confirm this matches the intended nesting under `if setup:`.
            self.prepare()

            self._exp_state = "Started"
            self._start_internal()
            print("Experiment started.")

            # Reset the flag so a previously stopped Experiment can run again.
            self.__stop = False
            # The epoch index is not reset, so the loop continues from
            # whatever value _epoch_idx currently holds.
            while self._epoch_idx < self.n_epochs and not self.__stop:
                self.train(epoch=self._epoch_idx)
                self.validate(epoch=self._epoch_idx)
                self._end_epoch_internal(epoch=self._epoch_idx)
                self._epoch_idx += 1

            self._exp_state = "Trained"
            print("Training complete.")

            self.end()
            self._end_internal()
            self._exp_state = "Ended"
            print("Experiment ended.")

            self._time_end = self._timestamp()

        except Exception as e:
            self._exp_state = "Error"
            self._time_end = self._timestamp()
            self.process_err(e)

    def run_test(self, setup=True):
        """
        This method runs the test part of the Experiment.

        The test consists of an optional setup, followed by :meth:`.prepare`,
        :meth:`.test` and :meth:`.end_test`.

        Args:
            setup (bool): If True, :meth:`.setup` and the internal setup hook
                are executed first, similar to the run method.

        Any exception raised sets the state to ``"Error"`` and is handed to
        :meth:`.process_err` (which re-raises by default).

        """
        try:
            if setup:
                self.setup()
                self._setup_internal()

            # NOTE(review): prepare() is called regardless of `setup` here;
            # earlier documentation suggested it was tied to `setup` as well.
            # Confirm the intended nesting.
            self.prepare()

            self._exp_state = "Testing"
            print("Start test.")

            self.test()
            self.end_test()
            self._end_test_internal()

            self._exp_state = "Tested"
            print("Testing complete.")

        except Exception as e:
            self._exp_state = "Error"
            self.process_err(e)

    @property
    def epoch(self):
        """Convenience access property for self._epoch_idx"""
        return self._epoch_idx

    def setup(self):
        """Is called at the beginning of each Experiment run to setup the basic components needed for a run."""
        pass

    def train(self, epoch):
        """
        The training part of the Experiment, it is called once for each epoch

        Args:
            epoch (int): The current epoch the train method is called in

        """
        pass

    def validate(self, epoch):
        """
        The evaluation/validation part of the Experiment, it is called once for
        each epoch (after the training part)

        Args:
            epoch (int): The current epoch the validate method is called in

        """
        pass

    def test(self):
        """The testing part of the Experiment"""
        pass

    def stop(self):
        """If called the Experiment will stop after that epoch and not continue training"""
        self.__stop = True

    def process_err(self, e):
        """
        This method is called if an error occurs during the execution of an
        experiment. Will just raise by default.

        Args:
            e (Exception): The exception which was raised during the experiment
                life cycle

        """
        raise e

    def _setup_internal(self):
        """Internal hook called right after :meth:`.setup`; no-op by default."""
        pass

    def _end_epoch_internal(self, epoch):
        """Internal hook called after :meth:`.validate` in each epoch; no-op by default."""
        pass

    def end(self):
        """Is called at the end of each experiment"""
        pass

    def end_test(self):
        """Is called at the end of each experiment test"""
        pass

    def prepare(self):
        """This method is called directly before the experiment training starts"""
        pass

    def _start_internal(self):
        """Internal hook called once the state is ``"Started"``; no-op by default."""
        pass

    def _end_internal(self):
        """Internal hook called right after :meth:`.end`; no-op by default."""
        pass

    def _end_test_internal(self):
        """Internal hook called right after :meth:`.end_test`; no-op by default."""
        pass