Datasets

Base class

class nussl.datasets.BaseDataset(folder, transform=None, sample_rate=None, stft_params=None, num_channels=None, strict_sample_rate=True, cache_populated=False)

The BaseDataset class is the starting point for all dataset hooks in nussl. To subclass BaseDataset, you only have to implement two functions:

  • get_items: a function that is passed the folder and generates a list of items that will be processed by the next function. The number of items in the list will dictate len(dataset). Must return a list.

  • process_item: this function processes a single item in the list generated by get_items. Must return a dictionary.
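
The contract can be illustrated without nussl installed. The sketch below uses a hypothetical stand-in, ToyBaseDataset, that mimics the behaviour described here: len(dataset) comes from the list returned by get_items, and dataset[i] runs process_item and then the optional transform. It is not nussl's implementation; in real code you would subclass nussl.datasets.BaseDataset instead.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyBaseDataset:
    """Hypothetical stand-in for the BaseDataset contract (not nussl code)."""

    def __init__(self, folder: str, transform: Optional[Callable] = None):
        self.folder = folder
        self.transform = transform
        items = self.get_items(folder)
        if not isinstance(items, list):
            raise TypeError("get_items must return a list")
        self.items = items

    def get_items(self, folder: str) -> List[Any]:
        raise NotImplementedError

    def process_item(self, item: Any) -> Dict[str, Any]:
        raise NotImplementedError

    def __len__(self) -> int:
        # The number of items returned by get_items dictates len(dataset).
        return len(self.items)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        output = self.process_item(self.items[i])
        if not isinstance(output, dict):
            raise TypeError("process_item must return a dictionary")
        if self.transform is not None:
            output = self.transform(output)
        return output


class NumberedMixtures(ToyBaseDataset):
    """Hypothetical subclass: each item is a fake file path."""

    def get_items(self, folder):
        return [f"{folder}/mix_{i}.wav" for i in range(3)]

    def process_item(self, item):
        return {"mix": item, "metadata": {"labels": ["a", "b"]}}


dataset = NumberedMixtures("data")
print(len(dataset))       # 3
print(dataset[1]["mix"])  # data/mix_1.wav
```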

After process_item is called, a set of Transforms can be applied to its output. If no transforms are defined (transform=None), then the output of process_item is returned by self[i] as-is. For the implemented Transforms, see nussl.datasets.transforms. For example, PhaseSensitiveSpectrumApproximation will add three new keys to the output dictionary of process_item:

  • mix_magnitude: the magnitude spectrogram of the mixture

  • source_magnitudes: the magnitude spectrogram of each source

  • ideal_binary_mask: the ideal binary mask for each source

The transforms are applied in sequence using transforms.Compose. Not all sequences of transforms will be valid (e.g. if you pop a key in one transform but a later transform operates on that key, you will get an error).
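
The ordering constraint can be seen with a minimal Compose-like helper. The compose function and the two toy transforms below are hypothetical and only mimic the sequencing behaviour described above: each transform receives the output of the previous one, so a key popped early is missing for every later transform.

```python
from typing import Callable, Dict, List


def compose(transforms: List[Callable[[Dict], Dict]]) -> Callable[[Dict], Dict]:
    """Apply each transform in order, feeding each output to the next one."""
    def apply(data: Dict) -> Dict:
        for tfm in transforms:
            data = tfm(data)
        return data
    return apply


def drop_mix(data: Dict) -> Dict:
    # Pops the 'mix' key, so later transforms can no longer see it.
    data = dict(data)
    data.pop("mix")
    return data


def mix_length(data: Dict) -> Dict:
    # Reads the 'mix' key; fails with KeyError if an earlier transform removed it.
    data = dict(data)
    data["mix_length"] = len(data["mix"])
    return data


valid = compose([mix_length, drop_mix])
print(valid({"mix": [0.0, 0.1, 0.2]}))  # {'mix_length': 3}

invalid = compose([drop_mix, mix_length])
try:
    invalid({"mix": [0.0, 0.1, 0.2]})
except KeyError as err:
    print("invalid ordering:", err)
```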

For examples of subclassing, see nussl.datasets.hooks.

Parameters
  • folder (str) – location that should be processed to produce the list of files

  • transform (transforms.* object, optional) – A transform to apply to the output of self.process_item. If using transforms.Compose, each transform will be applied in sequence. Defaults to None.

  • sample_rate (int, optional) – Sample rate to use for each audio file. If an audio file's sample rate doesn’t match, it will be resampled on the fly. If None, uses the default sample rate. Defaults to None.

  • stft_params (STFTParams, optional) – STFTParams object defining window_length, hop_length, and window_type that will be set for each AudioSignal object. Defaults to None (32ms window length, 8ms hop, ‘hann’ window).

  • num_channels (int, optional) – Number of channels to make each AudioSignal object conform to. If an audio signal in your dataset has fewer channels than num_channels, a warning is raised, as the behavior in this case is undefined. Defaults to None.

  • strict_sample_rate (bool, optional) – Whether to raise an error if

Raises

DataSetException – raised if the output of the functions implemented by the subclass doesn’t match the specification.

get_items(folder)

This function must be implemented by whatever class inherits BaseDataset. It should return a list of items in the given folder, each of which is processed by process_item in some way to produce mixes, sources, class labels, etc.

Parameters

folder (str) – location that should be processed to produce the list of files.

Returns

list of items that should be processed

Return type

list

process_item(item)

Each file returned by get_items is processed by this function. For example, if each file is a JSON file containing the paths to the mixture and sources, then this function should parse the JSON file, load the mixture and sources, and return them.

The exact behavior is determined by the subclass implementation.

Parameters

item (object) – the item that will be processed by this function. Input depends on implementation of self.get_items.

Returns

This should return a dictionary that gets processed by the transforms.

MUSDB18

class nussl.datasets.MUSDB18(folder=None, is_wav=False, download=False, subsets=None, split=None, **kwargs)

Hook for MUSDB18. Uses the musdb.DB object to access the dataset. If download=True, then the 7s snippets of each track are downloaded to self.folder. If no folder is given, then the tracks are downloaded to ~/.nussl/musdb18.

Getting an item from this dataset with no transforms returns the following dictionary:

{
    'mix': [AudioSignal object containing mix audio],
    'sources': {
        'bass': [AudioSignal object containing bass],
        'drums': [AudioSignal object containing drums],
        'other': [AudioSignal object containing other],
        'vocals': [AudioSignal object containing vocals],
    },
    'metadata': {
        'labels': ['bass', 'drums', 'other', 'vocals']
    }
}

Parameters
  • folder (str, optional) – Location that should be processed to produce the list of files. Defaults to None.

  • is_wav (bool, optional) – Expect a subfolder of WAV files for each source instead of stems. Defaults to False.

  • download (bool, optional) – Download the sample version of MUSDB18, which includes 7s excerpts. Defaults to False.

  • subsets (list, optional) – Select a musdb subset: train or test. Defaults to [‘train’, ‘test’] (all tracks).

  • split (str, optional) – When the train subset is loaded, split selects the train/validation split. split=’train’ loads the training split, split=’valid’ loads the validation split, and split=None applies no splitting. Defaults to None.

  • **kwargs – Any additional arguments that are passed up to BaseDataset (see nussl.datasets.BaseDataset).

get_items(folder)

This function must be implemented by whatever class inherits BaseDataset. It should return a list of items in the given folder, each of which is processed by process_item in some way to produce mixes, sources, class labels, etc.

Parameters

folder (str) – location that should be processed to produce the list of files.

Returns

list of items that should be processed

Return type

list

process_item(item)

Each file returned by get_items is processed by this function. For example, if each file is a JSON file containing the paths to the mixture and sources, then this function should parse the JSON file, load the mixture and sources, and return them.

The exact behavior is determined by the subclass implementation.

Parameters

item (object) – the item that will be processed by this function. Input depends on implementation of self.get_items.

Returns

This should return a dictionary that gets processed by the transforms.

WHAM

class nussl.datasets.WHAM(root, mix_folder='mix_clean', mode='min', split='tr', sample_rate=8000, **kwargs)

Hook for the WHAM dataset. Essentially subclasses MixSourceFolder but with presets that are helpful for WHAM, which has the following directory structure:

[wav8k, wav16k]/
  [min, max]/
    [tr, cv, tt]/
      mix_both/
      mix_clean/
      mix_single/
      noise/
      s1/
      s2/
wham_noise/
  tr/
  cv/
  tt/
  metadata/
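
As a rough illustration of how the constructor presets map onto this tree, the hypothetical helper wham_paths below builds the mixture and source folder paths from root, sample_rate, mode, split and mix_folder. The mapping 8000 to wav8k and 16000 to wav16k, and the use of s1/ and s2/ as the sources, are assumptions read off the directory layout, not taken from nussl's source.

```python
import os


def wham_paths(root, mix_folder="mix_clean", mode="min", split="tr",
               sample_rate=8000):
    """Hypothetical helper: locate WHAM folders for a given preset."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if split not in ("tr", "cv", "tt"):
        raise ValueError("split must be 'tr', 'cv' or 'tt'")
    # Assumed mapping from sample rate to the top-level folder name;
    # any other rate raises KeyError.
    rate_folder = {8000: "wav8k", 16000: "wav16k"}[sample_rate]
    base = os.path.join(root, rate_folder, mode, split)
    return {
        "mix": os.path.join(base, mix_folder),
        # Assumed: the clean speaker signals in s1/ and s2/ are the sources.
        "sources": [os.path.join(base, "s1"), os.path.join(base, "s2")],
    }


print(wham_paths("/data/wham", split="cv")["mix"])
```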

Scaper

class nussl.datasets.Scaper(folder, transform=None, sample_rate=None, stft_params=None, num_channels=None, strict_sample_rate=True, cache_populated=False)

Source separation datasets can be generated using Scaper, a library for automatic soundscape generation. Datasets that are generated with Scaper can be fed into this class easily. Scaper generates a large list of JAMS files which specify the parameters of the soundscape. If the soundscape is generated with save_isolated_events=True, then the audio corresponding to each event in the soundscape will be saved as well.

Below is an example of using Scaper to generate a small dataset of 10 mixtures with 2 sources each. The generated dataset can then be immediately loaded into an instance of nussl.datasets.Scaper for integration into a training or evaluation pipeline.

The sources are output in a dictionary that looks like this:

data['sources'] = {
    '{label}::{count}': AudioSignal,
    '{label}::{count}': AudioSignal,
    ...
}

For example:

data['sources'] = {
    'siren::0': AudioSignal,
    'siren::1': AudioSignal,
    'car_horn::0': AudioSignal,
    ...
}
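
The key format can be reproduced from an ordered list of event labels: each label gets a counter that increments every time the same label appears again. The helper make_source_keys below is a hypothetical sketch that only mimics the naming scheme shown above; the real hook reads the events from the JAMS file.

```python
from collections import defaultdict


def make_source_keys(event_labels):
    """Hypothetical helper: build '{label}::{count}' keys in event order."""
    counts = defaultdict(int)
    keys = []
    for label in event_labels:
        # The count is how many events with this label have been seen so far.
        keys.append(f"{label}::{counts[label]}")
        counts[label] += 1
    return keys


print(make_source_keys(["siren", "siren", "car_horn"]))
# ['siren::0', 'siren::1', 'car_horn::0']
```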

Getting an item from this dataset with no transforms returns the following dictionary:

{
    'mix': [AudioSignal object containing mix audio],
    'sources': {
        '[label0::count]': [AudioSignal object containing label0 audio],
        '[label1::count]': [AudioSignal object containing label1 audio],
        '[label2::count]': [AudioSignal object containing label2 audio],
        '[label3::count]': [AudioSignal object containing label3 audio],
        ...
    },
    'metadata': {
        'jams': [the content of the jams file used to generate the soundscape],
        'labels': ['label0', 'label1', 'label2', 'label3']
    }
}

Example of generating a Scaper dataset and then loading it with nussl:

>>> import os
>>> import nussl
>>> import scaper
>>> n_sources = 2
>>> n_mixtures = 10
>>> duration = 3
>>> ref_db = -40
>>> fg_path = '/path/to/foreground/'
>>> output_dir = '/output/path'
>>> for i in range(n_mixtures):
...     sc = scaper.Scaper(
...         duration, fg_path, fg_path, random_state=i)
...     sc.ref_db = ref_db
...     sc.sr = 16000
...     for j in range(n_sources):
...         sc.add_event(
...             label=('choose', []),
...             source_file=('choose', []),
...             source_time=('const', 0),
...             event_time=('const', 0),
...             event_duration=('const', duration),
...             snr=('const', 0),
...             pitch_shift=None,
...             time_stretch=None
...         )
...     audio_path = os.path.join(output_dir, f'{i}.wav')
...     jams_path = os.path.join(output_dir, f'{i}.jams')
...     sc.generate(audio_path, jams_path, save_isolated_events=True)
>>> dataset = nussl.datasets.Scaper(output_dir)
>>> dataset[0] # contains mix, sources, and metadata corresponding to 0.jams.

Raises

DataSetException – if Scaper dataset wasn’t saved with isolated event audio.

get_items(folder)

This function must be implemented by whatever class inherits BaseDataset. It should return a list of items in the given folder, each of which is processed by process_item in some way to produce mixes, sources, class labels, etc.

Parameters

folder (str) – location that should be processed to produce the list of files.

Returns

list of items that should be processed

Return type

list

process_item(item)

Each file returned by get_items is processed by this function. For example, if each file is a JSON file containing the paths to the mixture and sources, then this function should parse the JSON file, load the mixture and sources, and return them.

The exact behavior is determined by the subclass implementation.

Parameters

item (object) – the item that will be processed by this function. Input depends on implementation of self.get_items.

Returns

This should return a dictionary that gets processed by the transforms.

MixSourceFolder

class nussl.datasets.MixSourceFolder(folder, mix_folder='mix', source_folders=None, sample_rate=None, ext=None, **kwargs)

This dataset expects your data to be formatted in the following way:

data/
    mix/
        [file0].wav
        [file1].wav
        [file2].wav
        ...
    [label0]/
        [file0].wav
        [file1].wav
        [file2].wav
        ...
    [label1]/
        [file0].wav
        [file1].wav
        [file2].wav
        ...
    [label2]/
        [file0].wav
        [file1].wav
        [file2].wav
        ...
    ...

Note that the filenames must match between the mix folder and each source folder. The source folder names can be whatever you want. Given a file in the self.mix_folder folder, this dataset looks up the files with the same name in the source folders; these are the source audio files. The sum of the sources should equal the mixture. Each source is labeled according to the name of the folder it comes from.
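
The lookup rule can be sketched with the standard library. match_mix_and_sources is a hypothetical helper, not part of nussl: it lists the mixtures with an allowed extension in mix_folder and, when source_folders is None, treats every other subfolder as a source folder, following the parameter descriptions below.

```python
import os


def match_mix_and_sources(folder, mix_folder="mix", source_folders=None,
                          ext=(".wav", ".flac", ".mp3")):
    """Hypothetical helper: pair each mixture with same-named source files."""
    if source_folders is None:
        # Every subfolder other than the mix folder is a source folder.
        source_folders = sorted(
            d for d in os.listdir(folder)
            if d != mix_folder and os.path.isdir(os.path.join(folder, d))
        )
    mix_dir = os.path.join(folder, mix_folder)
    items = []
    for name in sorted(os.listdir(mix_dir)):
        if os.path.splitext(name)[1].lower() not in ext:
            continue  # skip files without an allowed audio extension
        # Each source is labeled by the name of the folder it comes from.
        sources = {
            label: os.path.join(folder, label, name) for label in source_folders
        }
        missing = [p for p in sources.values() if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(f"no matching source files for {name}: {missing}")
        items.append({"mix": os.path.join(mix_dir, name), "sources": sources})
    return items
```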

Getting an item from this dataset with no transforms returns the following dictionary:

{
    'mix': [AudioSignal object containing mix audio],
    'sources': {
        '[label0]': [AudioSignal object containing label0 audio],
        '[label1]': [AudioSignal object containing label1 audio],
        '[label2]': [AudioSignal object containing label2 audio],
        '[label3]': [AudioSignal object containing label3 audio],
        ...
    },
    'metadata': {
        'labels': ['label0', 'label1', 'label2', 'label3']
    }
}

Parameters
  • folder (str, optional) – Location that should be processed to produce the list of files. Defaults to None.

  • mix_folder (str, optional) – Folder to look in for mixtures. Defaults to ‘mix’.

  • source_folders (list, optional) – List of folders to look in for sources. Path is defined relative to folder. If None, all folders other than mix_folder are treated as the source folders. Defaults to None.

  • ext (list, optional) – Audio extensions to look for in mix_folder. Defaults to [‘.wav’, ‘.flac’, ‘.mp3’].

  • **kwargs – Any additional arguments that are passed up to BaseDataset (see nussl.datasets.BaseDataset).

get_items(folder)

This function must be implemented by whatever class inherits BaseDataset. It should return a list of items in the given folder, each of which is processed by process_item in some way to produce mixes, sources, class labels, etc.

Parameters

folder (str) – location that should be processed to produce the list of files.

Returns

list of items that should be processed

Return type

list

process_item(item)

Each file returned by get_items is processed by this function. For example, if each file is a JSON file containing the paths to the mixture and sources, then this function should parse the JSON file, load the mixture and sources, and return them.

The exact behavior is determined by the subclass implementation.

Parameters

item (object) – the item that will be processed by this function. Input depends on implementation of self.get_items.

Returns

This should return a dictionary that gets processed by the transforms.

Data transforms

Classes

Cache(location[, cache_size, overwrite])

The Cache transform can be placed within a Compose transform.

Compose(transforms)

Composes several transforms together.

GetAudio([mix_key, source_key])

Extracts the audio from each signal in mix_key and source_key.

GetExcerpt(excerpt_length[, time_dim, tf_keys])

Takes in a dictionary containing Torch tensors or numpy arrays and extracts an excerpt from each tensor corresponding to a spectral representation of a specified length in frames.

IndexSources(target_key, index)

Takes in a dictionary containing Torch tensors or numpy arrays and extracts the indexed sources from the set key (usually either source_magnitudes or ideal_binary_mask).

LabelsToOneHot([source_key])

Takes a data dictionary with sources and their keys and converts the keys to a one-hot numpy array using the list in data[‘metadata’][‘labels’] to figure out which index goes where.

MagnitudeSpectrumApproximation([mix_key, …])

Takes a dictionary and looks for two special keys, defined by the arguments mix_key and source_key.

MagnitudeWeights([mix_key, mix_magnitude_key])

Applying time-frequency weights to the deep clustering objective results in a huge performance boost.

PhaseSensitiveSpectrumApproximation([…])

Takes a dictionary and looks for two special keys, defined by the arguments mix_key and source_key.

SumSources(groupings[, group_names, source_key])

Sums sources together.

ToSeparationModel([swap_tf_dims])

Takes in a dictionary containing objects and removes any objects that cannot be passed to SeparationModel.

Exceptions

TransformException

Exception class for errors when working with transforms in nussl.

class nussl.datasets.transforms.Cache(location, cache_size=1, overwrite=False)

The Cache transform can be placed within a Compose transform. The data dictionary coming into this transform will be saved to the specified location using zarr. Then instead of computing all of the transforms before the cache, one can simply read from the cache. The transforms after this will then be applied to the data dictionary that is read from the cache. A typical pipeline might look like this:

from nussl import datasets
from nussl.datasets import transforms

tfm = transforms.Compose([
    transforms.PhaseSensitiveSpectrumApproximation(),
    transforms.ToSeparationModel(),
    transforms.Cache('~/.nussl/cache/tag', overwrite=True),
    transforms.GetExcerpt(400)  # excerpt length in frames (example value)
])
dataset = datasets.Scaper('path/to/scaper/folder', transform=tfm)
dataset[0] # first time will write to cache then apply GetExcerpt
dataset.cache_populated = True # switches to reading from cache
dataset[0] # second time will read from cache then apply GetExcerpt
dataset[1] # will error out as it wasn't written to the cache!

dataset.cache_populated = False
for i in range(len(dataset)):
    dataset[i] # every item will get written to cache
dataset.cache_populated = True
dataset[1] # now it exists

tfm = transforms.Compose([ # next time around
    transforms.PhaseSensitiveSpectrumApproximation(),
    transforms.ToSeparationModel(),
    transforms.Cache('~/.nussl/cache/tag', overwrite=False),
    transforms.GetExcerpt(400)
])
dataset = datasets.Scaper('path/to/scaper/folder', transform=tfm)
dataset.cache_populated = True
dataset[0] # will read from cache, which still exists from last time
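
The caching logic can be mimicked in a few lines to make the cache_populated switch concrete. CachedPipeline below is hypothetical: it keeps results in an in-memory dict instead of writing them with zarr, and only illustrates that once the cache is populated, the transforms before the cache are skipped while those after it still run.

```python
from typing import Callable, Dict, List


class CachedPipeline:
    """Hypothetical in-memory stand-in for Compose([*before, Cache, *after])."""

    def __init__(self, before: List[Callable], after: List[Callable]):
        self.before = before
        self.after = after
        self.store: Dict[int, dict] = {}
        self.cache_populated = False
        self.calls_before = 0  # counts how often the expensive part runs

    def __call__(self, index: int, data: dict) -> dict:
        if self.cache_populated:
            # Read the cached dictionary; raises KeyError if it was never written.
            data = dict(self.store[index])
        else:
            for tfm in self.before:
                data = tfm(data)
            self.calls_before += 1
            self.store[index] = dict(data)
        # Transforms after the cache always run (e.g. a random excerpt).
        for tfm in self.after:
            data = tfm(data)
        return data


pipe = CachedPipeline(before=[lambda d: {**d, "x": d["x"] * 2}],
                      after=[lambda d: {**d, "x": d["x"] + 1}])
print(pipe(0, {"x": 3}))    # {'x': 7}: computed and written to the cache
pipe.cache_populated = True
print(pipe(0, {"x": 100}))  # {'x': 7}: read from the cache, input ignored
```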

class nussl.datasets.transforms.Compose(transforms)

Composes several transforms together. Inspired by the torchvision implementation.

Parameters

transforms (list of Transform objects) – list of transforms to compose.

Example

>>> transforms.Compose([
...     transforms.MagnitudeSpectrumApproximation(),
...     transforms.ToSeparationModel(),
... ])

class nussl.datasets.transforms.GetAudio(mix_key='mix', source_key='sources')

Extracts the audio from each signal in mix_key and source_key. These will be at new keys, called mix_audio and source_audio. Can be used for training end-to-end models.

Parameters
  • mix_key (str, optional) – The key to look for in data for the mixture AudioSignal. Defaults to ‘mix’.

  • source_key (str, optional) – The key to look for in the data containing the dict of source AudioSignals. Defaults to ‘sources’.

class nussl.datasets.transforms.GetExcerpt(excerpt_length, time_dim=0, tf_keys=None)

Takes in a dictionary containing Torch tensors or numpy arrays and, from each tensor holding a spectral representation, extracts an excerpt of a specified length in frames. Can be used to get L-length spectrograms from mixture and source spectrograms. If the data is shorter than the specified length, it is padded to that length. If it is longer, a random offset between 0 and data_length - specified_length is chosen. This function assumes that it is being passed data AFTER ToSeparationModel, so the time dimension is on axis 0 (the default time_dim).

Parameters
  • excerpt_length (int) – Specified length of transformed data in frames.

  • time_dim (int) – Which dimension time is on (excerpts are taken along this axis). Defaults to 0.

  • tf_keys (list, optional) – Which keys in the data dictionary to take excerpts from.
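
The crop-or-pad behaviour can be sketched with NumPy. get_excerpt is a hypothetical function, not nussl's implementation. It assumes zero padding at the end of the time axis and a single random offset shared by all keys, so that mixture and source excerpts stay aligned.

```python
import numpy as np


def get_excerpt(data, excerpt_length, time_dim=0, tf_keys=("mix_magnitude",),
                rng=None):
    """Hypothetical sketch: crop or pad every tf_keys array along time_dim."""
    rng = np.random.default_rng() if rng is None else rng
    data = dict(data)
    data_length = data[tf_keys[0]].shape[time_dim]
    # One offset in [0, data_length - excerpt_length], shared by all keys.
    offset = int(rng.integers(0, max(data_length - excerpt_length, 0) + 1))
    for key in tf_keys:
        array = data[key]
        if array.shape[time_dim] < excerpt_length:
            # Zero-pad at the end of the time axis up to excerpt_length.
            pad = [(0, 0)] * array.ndim
            pad[time_dim] = (0, excerpt_length - array.shape[time_dim])
            array = np.pad(array, pad)
        else:
            index = [slice(None)] * array.ndim
            index[time_dim] = slice(offset, offset + excerpt_length)
            array = array[tuple(index)]
        data[key] = array
    return data
```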

class nussl.datasets.transforms.IndexSources(target_key, index)

Takes in a dictionary containing Torch tensors or numpy arrays and extracts the indexed sources from the set key (usually either source_magnitudes or ideal_binary_mask). Can be used to train single-source separation models (e.g. mix goes in, vocals come out).

You need to know which slice of the source magnitudes or ideal binary mask arrays to extract. The order of the sources in the source magnitudes array will be in alphabetical order according to their source labels.

For example, if source magnitudes has shape (257, 400, 1, 4), and the data is from MUSDB, then the four possible source labels are bass, drums, other, and vocals. The data in source magnitudes is in alphabetical order, so:

# source_magnitudes is an array returned by either MagnitudeSpectrumApproximation
# or PhaseSensitiveSpectrumApproximation
source_magnitudes[..., 0] # bass spectrogram
source_magnitudes[..., 1] # drums spectrogram
source_magnitudes[..., 2] # other spectrogram
source_magnitudes[..., 3] # vocals spectrogram

# ideal_binary_mask is an array returned by either MagnitudeSpectrumApproximation
# or PhaseSensitiveSpectrumApproximation
ideal_binary_mask[..., 0] # bass ibm mask
ideal_binary_mask[..., 1] # drums ibm mask
ideal_binary_mask[..., 2] # other ibm mask
ideal_binary_mask[..., 3] # vocals ibm mask

You can apply this transform to either the source_magnitudes or the ideal_binary_mask or both.
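
The alphabetical ordering and indexing can be reproduced with NumPy. The sketch below stacks hypothetical per-source spectrograms along the last axis in sorted label order and then keeps one index, which is the kind of slice IndexSources is described as extracting. Keeping a trailing source axis of size 1 is an assumption here, not confirmed from nussl's code.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-source magnitudes, shape (frequency, time, channels).
sources = {label: rng.random((257, 400, 1))
           for label in ["vocals", "drums", "bass", "other"]}

labels = sorted(sources)  # ['bass', 'drums', 'other', 'vocals']
source_magnitudes = np.stack([sources[k] for k in labels], axis=-1)
print(source_magnitudes.shape)  # (257, 400, 1, 4)

# Keep only the vocals slice: index 3 in alphabetical order.
index = labels.index("vocals")
vocals = source_magnitudes[..., index:index + 1]  # keeps a source axis of size 1
print(vocals.shape)  # (257, 400, 1, 1)
```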

class nussl.datasets.transforms.LabelsToOneHot(source_key='sources')

Takes a data dictionary with sources and their keys and converts the keys to a one-hot numpy array using the list in data[‘metadata’][‘labels’] to figure out which index goes where.
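
A minimal NumPy sketch of the conversion, assuming each source key is a label listed in data['metadata']['labels']. labels_to_one_hot is hypothetical and does not reproduce how nussl stores the result in the dictionary; stripping a Scaper-style '::count' suffix is also an assumption.

```python
import numpy as np


def labels_to_one_hot(data, source_key="sources"):
    """Hypothetical sketch: map each source key to a one-hot vector."""
    labels = data["metadata"]["labels"]
    one_hots = {}
    for key in data[source_key]:
        # Assumption: Scaper-style keys '{label}::{count}' use only the label.
        label = key.split("::")[0]
        vector = np.zeros(len(labels))
        vector[labels.index(label)] = 1.0
        one_hots[key] = vector
    return one_hots


data = {
    "sources": {"drums": None, "vocals": None},
    "metadata": {"labels": ["bass", "drums", "other", "vocals"]},
}
print(labels_to_one_hot(data)["vocals"])  # [0. 0. 0. 1.]
```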

class nussl.datasets.transforms.MagnitudeSpectrumApproximation(mix_key='mix', source_key='sources')

Takes a dictionary and looks for two special keys, defined by the arguments mix_key and source_key. These default to mix and sources. The values of these keys are used to calculate the magnitude spectrum approximation [1]. The input dictionary is modified to have additional keys:

  • mix_magnitude: The magnitude spectrogram of the mixture audio signal.

  • source_magnitudes: The magnitude spectrogram of each source, stacked along the last axis.

  • ideal_binary_mask: The ideal binary mask, assigning each time-frequency bin to a source.

data[self.source_key] points to a dictionary containing the source names in the keys and the corresponding AudioSignal in the values. The keys are sorted in alphabetical order, and the sources are stacked in that order. data[self.source_key] then points to an OrderedDict instead, where the keys are in the same order as in data['source_magnitudes'] and data['ideal_binary_mask'].

This transform uses the STFTParams that are attached to the AudioSignal objects contained in data[mix_key] and data[source_key].

[1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux. “Phase-sensitive and recognition-boosted speech separation using deep recurrent neural networks.” In 2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE, 2015.
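
The quantities this transform adds can be sketched with NumPy on complex STFT arrays. The random arrays stand in for the STFTs of the mixture and sources. Defining the ideal binary mask as the per-bin argmax over source magnitudes is a common convention that is assumed here, not confirmed from nussl's code.

```python
import numpy as np

rng = np.random.default_rng(0)
labels = ["vocals", "bass"]
# Hypothetical complex STFTs, shape (frequency, time, channels).
source_stfts = {
    k: rng.normal(size=(257, 100, 1)) + 1j * rng.normal(size=(257, 100, 1))
    for k in labels
}
mix_stft = sum(source_stfts.values())  # the mixture is the sum of the sources

mix_magnitude = np.abs(mix_stft)
# Sources are stacked on a new last axis in alphabetical order of their keys.
ordered = sorted(source_stfts)  # ['bass', 'vocals']
source_magnitudes = np.stack([np.abs(source_stfts[k]) for k in ordered], axis=-1)

# Ideal binary mask: 1 for the loudest source in each time-frequency bin.
loudest = source_magnitudes.argmax(axis=-1)
ideal_binary_mask = np.stack(
    [(loudest == i) for i in range(len(ordered))], axis=-1
).astype(float)
```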

Parameters
  • mix_key (str, optional) – The key to look for in data for the mixture AudioSignal. Defaults to ‘mix’.

  • source_key (str, optional) – The key to look for in the data containing the dict of source AudioSignals. Defaults to ‘sources’.

Raises

TransformException – if the expected keys are not in the dictionary, an Exception is raised.

Returns

Modified version of the input dictionary.

Return type

dict

class nussl.datasets.transforms.MagnitudeWeights(mix_key='mix', mix_magnitude_key='mix_magnitude')

Applying time-frequency weights to the deep clustering objective results in a large performance improvement [1]. This transform looks for ‘mix_magnitude’, which is output by either MagnitudeSpectrumApproximation or PhaseSensitiveSpectrumApproximation, and uses it as the time-frequency weights.

[1] Wang, Zhong-Qiu, Jonathan Le Roux, and John R. Hershey. “Alternative objective functions for deep clustering.” 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.

Parameters

mix_magnitude_key (str) – Which key to look for the mix_magnitude data in.

class nussl.datasets.transforms.PhaseSensitiveSpectrumApproximation(mix_key='mix', source_key='sources', range_min=0.0, range_max=1.0)

Takes a dictionary and looks for two special keys, defined by the arguments mix_key and source_key. These default to mix and sources. The values of these keys are used to calculate the phase sensitive spectrum approximation [1]. The input dictionary is modified to have additional keys:

  • mix_magnitude: The magnitude spectrogram of the mixture audio signal.

  • source_magnitudes: The magnitude spectrogram of each source, stacked along the last axis.

  • ideal_binary_mask: The ideal binary mask, assigning each time-frequency bin to a source.

data[self.source_key] points to a dictionary containing the source names in the keys and the corresponding AudioSignal in the values. The keys are sorted in alphabetical order, and the sources are stacked in that order. data[self.source_key] then points to an OrderedDict instead, where the keys are in the same order as in data['source_magnitudes'] and data['ideal_binary_mask'].

This transform uses the STFTParams that are attached to the AudioSignal objects contained in data[mix_key] and data[source_key].

[1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux. “Phase-sensitive and recognition-boosted speech separation using deep recurrent neural networks.” In 2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE, 2015.

Parameters
  • mix_key (str, optional) – The key to look for in data for the mixture AudioSignal. Defaults to ‘mix’.

  • source_key (str, optional) – The key to look for in the data containing the dict of source AudioSignals. Defaults to ‘sources’.

  • range_min (float, optional) – The lower end to use when truncating the source magnitudes in the phase sensitive spectrum approximation. Defaults to 0.0 (construct non-negative masks). Use -np.inf for untruncated source magnitudes.

  • range_max (float, optional) – The upper end of the truncated spectrum. This gets multiplied by the magnitude of the mixture. Use 1.0 to truncate the source magnitudes to min(source_magnitudes, mix_magnitude). Use np.inf for untruncated source magnitudes (best performance for an oracle mask, but possibly beyond what a neural network can learn to mask). Defaults to 1.0.
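
With Y the mixture STFT and S a source STFT, the truncated phase-sensitive target is |S| cos(angle(S) - angle(Y)), clipped to [range_min * |Y|, range_max * |Y|]. The NumPy function below sketches that formula as described by the two parameters. It is not nussl's code, and the layout (sources on an extra last axis) is an assumption.

```python
import numpy as np


def phase_sensitive_targets(mix_stft, source_stfts, range_min=0.0, range_max=1.0):
    """Sketch: truncated phase-sensitive source magnitudes.

    source_stfts has the shape of mix_stft plus one extra last axis for sources.
    """
    mix_magnitude = np.abs(mix_stft)[..., None]
    # Project each source onto the mixture phase: |S| * cos(angle(S) - angle(Y)).
    phase_diff = np.angle(source_stfts) - np.angle(mix_stft)[..., None]
    psa = np.abs(source_stfts) * np.cos(phase_diff)
    # Infinite limits disable truncation (and avoid inf * 0 = nan).
    low = -np.inf if np.isneginf(range_min) else range_min * mix_magnitude
    high = np.inf if np.isposinf(range_max) else range_max * mix_magnitude
    return np.clip(psa, low, high)
```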

Raises

TransformException – if the expected keys are not in the dictionary, an Exception is raised.

Returns

Modified version of the input dictionary.

Return type

dict

class nussl.datasets.transforms.SumSources(groupings, group_names=None, source_key='sources')

Sums sources together. Looks for sources in data[self.source_key]. If a source belongs to a group, it is popped from data[self.source_key] and summed with the other sources in the group. If group_names contains a name for the group, the summed source is stored in data[self.source_key] under that name. If group_names is not given, the names are constructed from the keys in each group (e.g. drums+bass+other).

If using Scaper datasets, then there may be multiple sources with the same label but different counts. The Scaper dataset hook organizes the source dictionary as follows:

data['sources'] = {
    '{label}::{count}': AudioSignal,
    '{label}::{count}': AudioSignal,
    ...
}

SumSources sums by source label, so the ::count will be ignored and only the label part will be used when grouping sources.
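
The grouping rule, including the ::count handling, can be sketched in plain Python. sum_sources is hypothetical and sums NumPy arrays where nussl would sum AudioSignal objects; sources that belong to no group are left untouched, as in the example below.

```python
import numpy as np


def sum_sources(sources, groupings, group_names=None):
    """Hypothetical sketch of SumSources on a dict of NumPy arrays."""
    sources = dict(sources)
    for g, group in enumerate(groupings):
        # Match on the label only, ignoring any '::count' suffix.
        members = [k for k in sources if k.split("::")[0] in group]
        if not members:
            continue
        name = group_names[g] if group_names is not None else "+".join(group)
        # Pop every member of the group and store their sum under the group name.
        sources[name] = sum(sources.pop(k) for k in members)
    return sources
```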

Example

>>> tfm = transforms.SumSources(
...     groupings=[['drums', 'bass', 'other']],
...     group_names=['accompaniment'],
... )
>>> # data['sources'] is a dict containing keys:
>>> #   ['vocals', 'drums', 'bass', 'other']
>>> data = tfm(data)
>>> # data['sources'] is now a dict containing keys:
>>> #   ['vocals', 'accompaniment']

Parameters
  • groupings (list) – a list of lists specifying which sources to sum together.

  • group_names (list, optional) – A list containing the names of each group, or None. Defaults to None.

  • source_key (str, optional) – The key to look for in the data containing the dict of source AudioSignals. Defaults to ‘sources’.

Raises
Returns

modified dictionary with summed sources

Return type

dict

class nussl.datasets.transforms.ToSeparationModel(swap_tf_dims=None)

Takes in a dictionary containing objects and removes any objects that cannot be passed to SeparationModel (e.g. not a numpy array or torch Tensor). If these objects are passed to SeparationModel, then an error will occur. This class should be the last one in your list of transforms, if you’re using this dataset in a DataLoader object for training a network. If the keys correspond to numpy arrays, they are converted to tensors using torch.from_numpy. Finally, the dimensions corresponding to time and frequency are swapped for all the keys in swap_tf_dims, as this is how SeparationModel expects it.

Example:

import torch
from nussl import AudioSignal
from nussl.datasets import transforms

data = {
    # 2ch spectrogram for mixture
    'mix_magnitude': torch.randn(513, 400, 2),
    # 2ch spectrogram for each source
    'source_magnitudes': torch.randn(513, 400, 2, 4),
    'mix': AudioSignal()
}

tfm = transforms.ToSeparationModel()
data = tfm(data)

data['mix_magnitude'].shape # (400, 513, 2)
data['source_magnitudes'].shape # (400, 513, 2, 4)
'mix' in data.keys() # False

If this class isn’t in your transforms list for the dataset, but you are using it in the Trainer class, then it is added automatically as the last transform.

exception nussl.datasets.transforms.TransformException

Exception class for errors when working with transforms in nussl.