class BaseDataset(Dataset):
    """
    The BaseDataset class is the starting point for all dataset hooks
    in nussl. To subclass BaseDataset, you only have to implement two
    functions:

    - ``get_items``: a function that is passed the folder and generates a
      list of items that will be processed by the next function. The
      number of items in the list will dictate len(dataset). Must return
      a list.
    - ``process_item``: this function processes a single item in the list
      generated by get_items. Must return a dictionary.

    After process_item is called, a set of Transforms can be applied to the
    output of process_item. If no transforms are defined (``self.transform = None``),
    then the output of process_item is returned by self[i]. For implemented
    Transforms, see nussl.datasets.transforms. For example,
    PhaseSpectrumApproximation will add three new keys to the output dictionary
    of process_item:

    - mix_magnitude: the magnitude spectrogram of the mixture
    - source_magnitudes: the magnitude spectrogram of each source
    - ideal_binary_mask: the ideal binary mask for each source

    The transforms are applied in sequence using transforms.Compose.
    Not all sequences of transforms will be valid (e.g. if you pop a key in
    one transform but a later transform operates on that key, you will get
    an error).

    For examples of subclassing, see ``nussl.datasets.hooks``.

    Args:
        folder (str): location that should be processed to produce the list of files.
        transform (transforms.* object, optional): A transform to apply to the output
          of ``self.process_item``. If using transforms.Compose, each transform will
          be applied in sequence. Defaults to None.
        sample_rate (int, optional): Sample rate to use for each audio file. If None,
          the sample rate of the first loaded audio file is adopted. What happens
          when a file's sample rate differs is controlled by ``strict_sample_rate``.
          Defaults to None.
        stft_params (STFTParams, optional): STFTParams object defining window_length,
          hop_length, and window_type that will be set for each AudioSignal object.
          If None, the STFT parameters of the first loaded AudioSignal are adopted.
          Defaults to None.
        num_channels (int, optional): Number of channels to make each AudioSignal
          object conform to. Signals with more channels keep only their first
          ``num_channels`` channels. If an audio signal in your dataset has fewer
          channels than ``num_channels``, a warning is raised, as the behavior in
          this case is undefined. If None, the channel count of the first loaded
          audio file is adopted. Defaults to None.
        strict_sample_rate (bool, optional): Whether to raise an error if an audio
          file's sample rate differs from ``sample_rate``. If False, such a file is
          resampled to ``sample_rate`` on the fly instead. Defaults to True.
        cache_populated (bool, optional): Whether the ``transforms.Cache`` found in
          ``transform`` has already been filled. If True, ``self[i]`` skips
          ``process_item`` and every transform before the Cache, and instead runs
          ``{'index': i}`` through the Cache and the transforms after it. Forced to
          False when ``transform`` contains no Cache. Defaults to False.

    Raises:
        DataSetException: Exceptions are raised if the output of the implemented
          functions by the subclass don't match the specification.
    """

    def __init__(self, folder, transform=None, sample_rate=None, stft_params=None,
                 num_channels=None, strict_sample_rate=True, cache_populated=False):
        self.folder = folder
        self.items = self.get_items(self.folder)
        # Validate immediately: the cache_populated setter below calls len(self),
        # which would otherwise fail with an opaque error on a non-list result.
        if not isinstance(self.items, list):
            raise DataSetException("Output of self.get_items must be a list!")

        self.transform = transform
        # Must come after self.transform and self.items are set: the setter
        # inspects the transforms and sizes the cache with len(self).
        self.cache_populated = cache_populated

        self.stft_params = stft_params
        self.sample_rate = sample_rate
        self.num_channels = num_channels
        self.strict_sample_rate = strict_sample_rate

        # getting one item in order to set up parameters for audio
        # signals if necessary, if there are any items
        if self.items:
            self.process_item(self.items[0])

    @property
    def cache_populated(self):
        """bool: Whether ``self[i]`` reads from an already-populated Cache transform."""
        return self._cache_populated

    @cache_populated.setter
    def cache_populated(self, value):
        # Everything from the first tfm.Cache onward (the Cache included) is
        # collected into self.post_cache_transforms, which __getitem__ runs
        # instead of the full pipeline once the cache is populated.
        self.post_cache_transforms = []
        cache_transform = None

        transforms = (
            self.transform.transforms
            if isinstance(self.transform, tfm.Compose)
            else [self.transform])

        found_cache_transform = False
        for t in transforms:
            if isinstance(t, tfm.Cache):
                found_cache_transform = True
                cache_transform = t
            if found_cache_transform:
                self.post_cache_transforms.append(t)

        if not found_cache_transform:
            # there is no cache transform, so there is nothing to read from
            self._cache_populated = False
        else:
            self._cache_populated = value
            cache_transform.cache_size = len(self)
            # overwrite the cache only while it has not been populated yet
            cache_transform.overwrite = not value

        self.post_cache_transforms = tfm.Compose(
            self.post_cache_transforms)

    def get_items(self, folder):
        """
        This function must be implemented by whatever class inherits BaseDataset.
        It should return a list of items in the given folder, each of which is
        processed by process_item in some way to produce mixes, sources, class
        labels, etc.

        Args:
            folder (str): location that should be processed to produce the list of files.

        Returns:
            list: list of items that should be processed

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def __len__(self):
        """
        Gets the length of the dataset (the number of items that will be processed).

        Returns:
            int: Length of the dataset (``len(self.items)``).
        """
        return len(self.items)

    def __getitem__(self, i):
        """
        Processes a single item in ``self.items`` using ``self.process_item``.

        If ``self.transform`` is defined, the item's index is stored under the
        ``'index'`` key and the dictionary is passed through that transform. To
        apply several transforms that depend on each other, compose them with
        transforms.Compose (which applies them in sequence) and pass the result
        as ``transform``.

        If ``self.cache_populated`` is True, ``process_item`` is skipped and
        ``{'index': i}`` is passed through ``self.post_cache_transforms`` (the
        Cache transform and every transform after it) instead.

        Args:
            i (int): Index of the dataset to return. Indexes ``self.items``.

        Returns:
            dict: Dictionary with keys and values corresponding to the processed
            item after being put through the set of transforms (if any are
            defined).

        Raises:
            DataSetException: if ``process_item`` does not return a dictionary.
            TransformException: if ``self.transform`` does not return a dictionary.
        """
        if self.cache_populated:
            data = {'index': i}
            data = self.post_cache_transforms(data)
        else:
            data = self.process_item(self.items[i])

            if not isinstance(data, dict):
                raise DataSetException(
                    "The output of process_item must be a dictionary!")

            if self.transform:
                data['index'] = i
                data = self.transform(data)

                if not isinstance(data, dict):
                    raise tfm.TransformException(
                        "The output of transform must be a dictionary!")

        return data

    def process_item(self, item):
        """Each file returned by get_items is processed by this function. For example,
        if each file is a json file containing the paths to the mixture and sources,
        then this function should parse the json file and load the mixture and sources
        and return them.

        Exact behavior of this functionality is determined by implementation by subclass.

        Args:
            item (object): the item that will be processed by this function. Input depends
              on implementation of ``self.get_items``.

        Returns:
            This should return a dictionary that gets processed by the transforms.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def _load_audio_file(self, path_to_audio_file):
        """
        Loads audio file at given path. Uses AudioSignal to load the audio data
        from disk, then conforms it to this dataset via ``_setup_audio_signal``.

        Args:
            path_to_audio_file: relative or absolute path to file to load

        Returns:
            AudioSignal: loaded AudioSignal object of path_to_audio_file
        """
        audio_signal = AudioSignal(path_to_audio_file)
        self._setup_audio_signal(audio_signal)
        return audio_signal

    def _load_audio_from_array(self, audio_data, sample_rate=None):
        """
        Loads the audio data into an AudioSignal object with the appropriate
        sample rate, then conforms it to this dataset via ``_setup_audio_signal``.

        Args:
            audio_data (np.ndarray): numpy array containing the samples containing
              the audio data.
            sample_rate (int): the sample rate of ``audio_data``. If None (or
              otherwise falsy), ``self.sample_rate`` is passed to AudioSignal
              instead. Defaults to None.

        Returns:
            AudioSignal: loaded AudioSignal object of audio_data
        """
        sample_rate = sample_rate if sample_rate else self.sample_rate
        audio_signal = AudioSignal(
            audio_data_array=audio_data, sample_rate=sample_rate)
        self._setup_audio_signal(audio_signal)
        return audio_signal

    def _setup_audio_signal(self, audio_signal):
        """
        You will want every item from a dataset to be uniform in sample rate, STFT
        parameters, and number of channels. This function takes an audio signal
        object loaded by the dataset and uses it to set the sample rate, STFT parameters,
        and the number of channels. If ``self.sample_rate``, ``self.stft_params``, and
        ``self.num_channels`` are set at construction time of the dataset, then the
        opposite happens - attributes of the AudioSignal object are set to the desired
        values.

        Args:
            audio_signal (AudioSignal): AudioSignal object to query to set the parameters
              of this dataset or to set the parameters of, according to what is in the
              dataset. Modified in place.

        Raises:
            DataSetException: if ``self.sample_rate`` is set, differs from the
              signal's sample rate, and ``self.strict_sample_rate`` is True.
        """
        if self.sample_rate and self.sample_rate != audio_signal.sample_rate:
            if self.strict_sample_rate:
                raise DataSetException(
                    f"All audio files should have been the same sample rate already "
                    f"because self.strict_sample_rate = True. Please resample or "
                    f"turn set self.strict_sample_rate = False"
                )
            audio_signal.resample(self.sample_rate)
        else:
            self.sample_rate = audio_signal.sample_rate

        # set audio signal attributes to requested values, if they exist
        if self.stft_params:
            audio_signal.stft_params = self.stft_params
        else:
            self.stft_params = audio_signal.stft_params

        if self.num_channels:
            if audio_signal.num_channels > self.num_channels:
                # pick the first ``self.num_channels`` channels
                audio_signal.audio_data = audio_signal.audio_data[:self.num_channels]
            elif audio_signal.num_channels < self.num_channels:
                warnings.warn(
                    f"AudioSignal had {audio_signal.num_channels} channels "
                    f"but self.num_channels = {self.num_channels}. Unsure "
                    f"of what to do, so warning. You might want to make sure "
                    f"your dataset is uniform!"
                )
        else:
            self.num_channels = audio_signal.num_channels
class DataSetException(Exception):
    """Signals an error encountered while working with a nussl data set,
    e.g. a dataset hook whose output does not match the expected contract."""