import os
import logging
from enum import Enum
import copy
import time
from ignite.engine import Events, Engine
from ignite.handlers import Timer
from ignite.contrib.handlers import ProgressBar
from ignite.metrics import RunningAverage
from torch.utils.tensorboard import SummaryWriter
import torch
from torch import nn
import numpy as np
from nussl import datasets
class ValidationEvents(Enum):
    """
    Custom ignite events fired around each validation run.

    The trainer registers these (``create_train_and_validation_engines``) and
    ``add_validate_and_checkpoint`` fires VALIDATION_STARTED before running the
    validator and VALIDATION_COMPLETED after checkpointing, so other handlers
    (stdout/tensorboard logging) can hook validation boundaries.
    """
    VALIDATION_STARTED = 'validation_started'
    VALIDATION_COMPLETED = 'validation_completed'
class BackwardsEvents(Enum):
    """
    Custom ignite events related to the backward pass.

    NOTE(review): the original docstring said "Events based on validation
    running" — a copy-paste from ``ValidationEvents``. This enum is registered
    on the trainer so a training closure can fire BACKWARDS_COMPLETED after
    calling ``loss.backward()``; no firing site is visible in this file.
    """
    BACKWARDS_COMPLETED = 'backwards_completed'
def cache_dataset(dataset):
    """
    Iterates once over an entire dataset so that, if a
    ``nussl.datasets.transforms.Cache`` transform is present in
    ``dataset.transform``, every item gets written into the cache. If there is
    no caching transform, or ``dataset.cache_populated = True`` already, this
    simply walks the dataset and does nothing.

    This function can also take a `torch.util.data.DataLoader` object wrapped
    around a `nussl.datasets.BaseDataset` object.

    Args:
        dataset (nussl.datasets.BaseDataset): Must be a subclass of
            `nussl.datasets.BaseDataset`.
    """
    def _noop(engine, batch):
        # Iterating is the whole point: any Cache transform populates the
        # cache as a side effect of fetching each item.
        pass

    cache_engine = Engine(_noop)
    ProgressBar().attach(cache_engine)
    cache_engine.run(dataset)
    dataset.cache_populated = True
def create_train_and_validation_engines(train_func, val_func=None, device='cpu'):
    """
    Helper function for creating an ignite Engine object with helpful defaults.
    This sets up an Engine that has four handlers attached to it:

    - prepare_batch: before a batch is passed to train_func or val_func, this
      function runs, moving every item in the batch (which is a dictionary) to
      the appropriate device ('cpu' or 'cuda').
    - book_keeping: sets up some dictionaries that are used for bookkeeping so one
      can easily track the epoch and iteration losses for both training and
      validation.
    - add_to_iter_history: records the iteration, epoch, and past iteration losses
      into the dictionaries set up by book_keeping.
    - clear_iter_history: resets the current iteration history of losses after
      each epoch; past_iter_history keeps a flat record of every iteration's
      output across epochs.

    Args:
        train_func (func): Function that provides the closure for training for
          a single batch.
        val_func (func, optional): Function that provides the closure for
          validating a single batch. Defaults to None.
        device (str, optional): Device to move tensors to. Defaults to 'cpu'.
    """
    # Set up engines for training and validation
    trainer = Engine(train_func)
    trainer.register_events(*ValidationEvents)
    trainer.register_events(*BackwardsEvents)

    validator = None if val_func is None else Engine(val_func)

    # Before a batch starts, the items should be float and moved to the
    # correct device, for both training and validation. Checks to make
    # sure "cuda" is available if user requested cuda.
    device = device if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    def prepare_batch(engine):
        batch = engine.state.batch
        for key in batch:
            if torch.is_tensor(batch[key]):
                batch[key] = batch[key].float().to(device)
        engine.state.batch = batch

    # Set up stuff for bookkeeping as training progresses.
    def book_keeping(engine):
        engine.state.epoch_history = {}
        engine.state.iter_history = {}
        engine.state.past_iter_history = {}

    def add_to_iter_history(engine):
        for key in engine.state.output:
            if key not in engine.state.iter_history:
                engine.state.iter_history[key] = []
            if key not in engine.state.past_iter_history:
                engine.state.past_iter_history[key] = []
            engine.state.iter_history[key].append(
                engine.state.output[key]
            )
            # BUGFIX: this previously appended the (mutable) iter_history list
            # object itself, so past_iter_history[key] accumulated N aliases
            # of the same growing list each epoch. Append the scalar output
            # value so past_iter_history is a flat record of every iteration.
            engine.state.past_iter_history[key].append(
                engine.state.output[key]
            )

    def clear_iter_history(engine):
        engine.state.iter_history = {}

    # Both engines get the identical set of handlers; register them in one
    # place instead of repeating each add_event_handler call twice.
    registrations = (
        (Events.ITERATION_STARTED, prepare_batch),
        (Events.STARTED, book_keeping),
        (Events.ITERATION_COMPLETED, add_to_iter_history),
        (Events.EPOCH_STARTED, clear_iter_history),
    )
    for event, handler in registrations:
        trainer.add_event_handler(event, handler)
        if validator is not None:
            validator.add_event_handler(event, handler)

    return trainer, validator
def add_validate_and_checkpoint(output_folder, model, optimizer, train_data, trainer,
                                val_data=None, validator=None):
    """
    This adds the following handler to the trainer:

    - validate_and_checkpoint: this runs the validator on the validation dataset
      (``val_data``) using a defined validation process function ``val_func``.
      These are optional. If these are not provided, then no validator is run
      and the model is simply checkpointed. The model is always saved to
      ``{output_folder}/checkpoints/latest.model.pth``. If the model is also the
      one with the lowest validation loss, then it is *also* saved to
      ``{output_folder}/checkpoints/best.model.pth``. This is attached to
      ``Events.EPOCH_COMPLETED`` on the trainer. After completion, it fires a
      ``ValidationEvents.VALIDATION_COMPLETED`` event.

    Args:
        output_folder (str): Folder where ``checkpoints/`` is created and
          model/optimizer state is written.
        model (torch.nn.Module): Model that is being trained (typically a SeparationModel).
        optimizer (torch.optim.Optimizer): Optimizer being used to train.
        train_data (BaseDataset): dataset that is being used to train the model. This is to
          save additional metadata information alongside the model checkpoint such as the
          STFTParams, dataset folder, length, list of transforms, etc.
        trainer (ignite.Engine): Engine for trainer
        validator (ignite.Engine, optional): Engine for validation.
          Defaults to None.
        val_data (torch.utils.data.Dataset, optional): The validation data.
          Defaults to None.
    """
    # When the trainer finishes an epoch, it should validate and save
    # the model.
    @trainer.on(Events.EPOCH_COMPLETED)
    def validate_and_checkpoint(trainer):
        trainer.fire_event(ValidationEvents.VALIDATION_STARTED)
        # Without a validation loss to compare, every epoch counts as best.
        is_best = True

        if validator is not None:
            validator.run(val_data)
            # Collapse each per-iteration validation metric into a per-epoch
            # mean on the trainer's history.
            for key in validator.state.iter_history:
                _key = f"validation/{key}"
                if _key not in trainer.state.epoch_history:
                    trainer.state.epoch_history[_key] = []
                trainer.state.epoch_history[_key].append(np.mean(
                    validator.state.iter_history[key]
                ))

        if 'validation/loss' in trainer.state.epoch_history:
            cur = trainer.state.epoch_history['validation/loss'][-1]
            is_best = cur == min(trainer.state.epoch_history['validation/loss'])

        # Mirror the per-epoch aggregation for training metrics.
        for key in trainer.state.iter_history:
            _key = f"train/{key}"
            if _key not in trainer.state.epoch_history:
                trainer.state.epoch_history[_key] = []
            trainer.state.epoch_history[_key].append(np.mean(
                trainer.state.iter_history[key]
            ))

        output_paths = [os.path.join(
            output_folder, 'checkpoints', 'latest.model.pth')]
        if is_best:
            output_paths.append(os.path.join(
                output_folder, 'checkpoints', 'best.model.pth'
            ))

        # Drop Cache transforms from the metadata that gets saved with the
        # checkpoint — the cache refers to local files and isn't portable.
        _transform = copy.deepcopy(train_data.transform)
        if isinstance(_transform, datasets.transforms.Compose):
            # BUGFIX: the original removed items from the list *while
            # iterating over it*, which skips the element following each
            # removal (two consecutive Cache transforms left one behind).
            # Rebuild the list with a filter instead.
            _transform.transforms = [
                t for t in _transform.transforms
                if not isinstance(t, datasets.transforms.Cache)
            ]

        metadata = {
            'stft_params': train_data.stft_params,
            'sample_rate': train_data.sample_rate,
            'num_channels': train_data.num_channels,
            'folder': train_data.folder,
            'transforms': _transform,
            'trainer.state_dict': {
                'epoch': trainer.state.epoch,
                'epoch_length': trainer.state.epoch_length,
                'max_epochs': trainer.state.max_epochs,
                'output': trainer.state.output,
                'metrics': trainer.state.metrics,
                'seed': trainer.state.seed,
            },
            'trainer.state.epoch_history': trainer.state.epoch_history,
        }

        # Hoisted out of the save loop: the directory only needs creating
        # once, and unwrapping DataParallel doesn't depend on the path.
        os.makedirs(os.path.join(
            output_folder, 'checkpoints'), exist_ok=True)
        _model = model.module if isinstance(model, nn.DataParallel) else model

        for _path in output_paths:
            _model.save(_path, {'metadata': metadata})
            torch.save(optimizer.state_dict(),
                       _path.replace('model.pth', 'optimizer.pth'))

        # Points at best.model.pth when this epoch is the best, otherwise
        # latest.model.pth (last element of output_paths).
        trainer.state.saved_model_path = output_paths[-1]
        trainer.state.output_folder = output_folder
        trainer.fire_event(ValidationEvents.VALIDATION_COMPLETED)
def add_stdout_handler(trainer, validator=None):
    """
    This adds the following handler to the trainer engine, and also sets up
    Timers:

    - log_epoch_to_stdout: This logs the results of a model after it has trained
      for a single epoch on both the training and validation set. The output typically
      looks like this:

      .. code-block:: none

          EPOCH SUMMARY
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          - Epoch number: 0010 / 0010
          - Training loss: 0.583591
          - Validation loss: 0.137209
          - Epoch took: 00:00:03
          - Time since start: 00:00:32
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          Saving to test.
          Output @ tests/local/trainer

    Args:
        trainer (ignite.Engine): Engine for trainer
        validator (ignite.Engine, optional): Engine for validation.
          Defaults to None.
    """
    # Set up timers for overall time taken and each epoch
    overall_timer = Timer(average=False)
    overall_timer.attach(trainer,
                         start=Events.STARTED, pause=Events.COMPLETED)

    # The epoch timer pauses at VALIDATION_COMPLETED, so the reported epoch
    # time includes validation + checkpointing.
    epoch_timer = Timer(average=False)
    epoch_timer.attach(
        trainer, start=Events.EPOCH_STARTED,
        pause=ValidationEvents.VALIDATION_COMPLETED
    )

    @trainer.on(ValidationEvents.VALIDATION_COMPLETED)
    def log_epoch_to_stdout(trainer):
        epoch_time = epoch_timer.value()
        epoch_time = time.strftime(
            "%H:%M:%S", time.gmtime(epoch_time))

        overall_time = overall_timer.value()
        overall_time = time.strftime(
            "%H:%M:%S", time.gmtime(overall_time))

        epoch_number = trainer.state.epoch
        total_epochs = trainer.state.max_epochs

        # BUGFIX: was a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit. Only a missing or empty
        # 'validation/loss' history should fall back to 'N/A'.
        try:
            validation_loss = (
                f"{trainer.state.epoch_history['validation/loss'][-1]:04f}")
        except (KeyError, IndexError):
            validation_loss = 'N/A'

        train_loss = trainer.state.epoch_history['train/loss'][-1]
        saved_model_path = trainer.state.saved_model_path

        logging_str = (
            f"\n\n"
            f"EPOCH SUMMARY \n"
            f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
            f"- Epoch number: {epoch_number:04d} / {total_epochs:04d} \n"
            f"- Training loss: {train_loss:04f} \n"
            f"- Validation loss: {validation_loss} \n"
            f"- Epoch took: {epoch_time} \n"
            f"- Time since start: {overall_time} \n"
            f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
            f"Saving to {saved_model_path}. \n"
            f"Output @ {trainer.state.output_folder} \n"
        )
        logging.info(logging_str)
def add_progress_bar_handler(*engines):
    """
    Adds a progress bar to each engine. Keeps track of a running
    average of the loss as well.

    Usage::

    .. code-block:: python

        tr_engine, val_engine = ...
        add_progress_bar_handler(tr_engine, val_engine)
    """
    for engine in engines:
        # The running average reads the 'loss' key of each iteration's
        # output dict; the bar displays it as 'avg_loss'.
        running_avg = RunningAverage(
            output_transform=lambda output: output['loss'])
        running_avg.attach(engine, 'avg_loss')
        ProgressBar().attach(engine, ['avg_loss'])
def add_tensorboard_handler(tensorboard_folder, engine, every_iteration=False):
    """
    Every key in engine.state.epoch_history[-1] is logged to TensorBoard.

    Args:
        tensorboard_folder (str): Where the tensorboard logs should go.
        engine (ignite.Engine): The engine to log.
        every_iteration (bool, optional): Whether to also log the values at every
          iteration.
    """
    # BUGFIX: the original built a brand-new SummaryWriter on *every* event
    # (each epoch and, optionally, each iteration), never closing them —
    # leaking file handles and scattering one event file per call. Create the
    # writer lazily, once, and share it between both handlers. Lazy creation
    # preserves the original behavior of not touching tensorboard_folder
    # until the first event actually fires.
    writer = None

    def _get_writer():
        nonlocal writer
        if writer is None:
            writer = SummaryWriter(tensorboard_folder)
        return writer

    @engine.on(ValidationEvents.VALIDATION_COMPLETED)
    def log_to_tensorboard(engine):
        w = _get_writer()
        for key in engine.state.epoch_history:
            w.add_scalar(
                key, engine.state.epoch_history[key][-1], engine.state.epoch)

    if every_iteration:
        @engine.on(Events.ITERATION_COMPLETED)
        def log_iteration_to_tensorboard(engine):
            w = _get_writer()
            for key in engine.state.iter_history:
                w.add_scalar(
                    key, engine.state.iter_history[key][-1], engine.state.iteration)