Handling data in nussl
nussl comes with a bunch of different useful dataset hooks, along with a handy base class for datasets. Let’s examine what the base class looks like first.
BaseDataset
The BaseDataset is an abstract class that has a few useful functions for organizing your data. If you instantiate it directly, however, it raises a NotImplementedError:
[1]:
import nussl
import numpy as np
import matplotlib.pyplot as plt
import time
start_time = time.time()
folder = 'ignored'
base_dataset = nussl.datasets.BaseDataset(folder)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-1-ee86740f50fb> in <module>
7
8 folder = 'ignored'
----> 9 base_dataset = nussl.datasets.BaseDataset(folder)
~/Dropbox/research/nussl_refactor/nussl/datasets/base_dataset.py in __init__(self, folder, transform, sample_rate, stft_params, num_channels, strict_sample_rate, cache_populated)
67 num_channels=None, strict_sample_rate=True, cache_populated=False):
68 self.folder = folder
---> 69 self.items = self.get_items(self.folder)
70 self.transform = transform
71
~/Dropbox/research/nussl_refactor/nussl/datasets/base_dataset.py in get_items(self, folder)
131 list: list of items that should be processed
132 """
--> 133 raise NotImplementedError()
134
135 def __len__(self):
NotImplementedError:
For the dataset to work, two functions must be implemented:
self.get_items: A function that grabs all the items that you will need to process.
self.process_item: A function that processes a single item.
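The contract above can be sketched without nussl at all. The following is a minimal illustration of the pattern, assuming a hypothetical stand-in class named ToyBaseDataset; it is not nussl's implementation, and the real BaseDataset takes more arguments and does more work per item.

```python
class ToyBaseDataset:
    """Hypothetical stand-in for an abstract dataset base class (not nussl's code)."""

    def __init__(self, folder):
        self.folder = folder
        # The subclass decides what counts as an "item".
        self.items = self.get_items(folder)

    def get_items(self, folder):
        # Subclasses must override this to list the items to process.
        raise NotImplementedError()

    def process_item(self, item):
        # Subclasses must override this to turn one item into an output.
        raise NotImplementedError()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        # Indexing the dataset processes the i-th item on demand.
        return self.process_item(self.items[i])


class Squares(ToyBaseDataset):
    """Toy subclass: items are integers, outputs are their squares."""

    def get_items(self, folder):
        return list(range(5))

    def process_item(self, item):
        return {'value': item ** 2}


dataset = Squares('ignored')
third = dataset[2]
```

Because the constructor calls get_items, instantiating the toy base class directly fails in the same way as the traceback above, while the subclass works as soon as both functions are defined.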
Let’s build a dataset that returns sums of sine waves at random frequencies.
[2]:
def make_sine_wave(freq, sample_rate, duration):
    dt = 1 / sample_rate
    x = np.arange(0.0, duration, dt)
    x = np.sin(2 * np.pi * freq * x)
    return x

class SineWaves(nussl.datasets.BaseDataset):
    def get_items(self, folder):
        # ignore folder and return a list
        # 100 items in this dataset
        items = list(range(100))
        return items

    def process_item(self, item):
        # we're ignoring items and making
        # sums of random sine waves
        sources = {}
        freqs = []
        for i in range(3):
            freq = np.random.randint(110, 1000)
            freqs.append(freq)
            _data = make_sine_wave(freq, self.sample_rate, 2)
            # this is a helper function in BaseDataset for
            # making an audio signal from data
            signal = self._load_audio_from_array(_data)
            sources[f'sine{i}'] = signal * 1/3
        mix = sum(sources.values())
        metadata = {
            'frequencies': freqs
        }
        output = {
            'mix': mix,
            'sources': sources,
            'metadata': metadata
        }
        return output
The primary thing to note here is the format of what the process_item function returns. It is a dictionary and must always be a dictionary. The dictionary contains three keys: mix, sources, and metadata. sources is likewise a dictionary rather than a list. The values of sources sum to mix.
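As a rough illustration of that format, here is a small checker. It is a sketch under stated assumptions: check_item is a hypothetical helper, not part of nussl, and plain lists of floats stand in for the AudioSignal objects so the example runs on its own.

```python
import math


def check_item(item, tol=1e-9):
    """Hypothetical helper: verify the mix/sources/metadata layout."""
    missing = {'mix', 'sources', 'metadata'} - set(item)
    if missing:
        raise KeyError(f'missing keys: {sorted(missing)}')
    if not isinstance(item['sources'], dict):
        raise TypeError('sources must be a dictionary')
    # Add the sources sample by sample and compare the result with the mix.
    summed = [sum(samples) for samples in zip(*item['sources'].values())]
    return len(summed) == len(item['mix']) and all(
        math.isclose(s, m, abs_tol=tol) for s, m in zip(summed, item['mix'])
    )


# Plain lists stand in for audio data here (an assumption of this sketch).
sources = {'sine0': [0.1, 0.2], 'sine1': [0.3, -0.2]}
item = {'mix': [0.4, 0.0], 'sources': sources, 'metadata': {}}
ok = check_item(item)
```

A check along these lines can be handy when writing a new process_item, since a mix that does not equal the sum of its sources is easy to produce by accident.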
Great, now let’s use this dataset.
[3]:
folder = 'ignored'
sine_wave_dataset = SineWaves(folder, sample_rate=44100)
item = sine_wave_dataset[0]
item
[3]:
{'mix': <nussl.core.audio_signal.AudioSignal at 0x7fe6d592c690>,
'sources': {'sine0': <nussl.core.audio_signal.AudioSignal at 0x7fe6d592c2d0>,
'sine1': <nussl.core.audio_signal.AudioSignal at 0x7fe6d592c5d0>,
'sine2': <nussl.core.audio_signal.AudioSignal at 0x7fe6d592c310>},
'metadata': {'frequencies': [161, 536, 693]}}
We can see that getting an item from the dataset resulted in a dictionary containing AudioSignal objects! And the exact frequencies for each sine tone were saved in the metadata. Now, let’s listen and visualize:
[4]:
def visualize_and_embed(sources, y_axis='mel'):
    plt.figure(figsize=(10, 4))
    plt.subplot(111)
    nussl.utils.visualize_sources_as_masks(
        sources, db_cutoff=-60, y_axis=y_axis)
    plt.tight_layout()
    plt.show()
    nussl.play_utils.multitrack(sources, ext='.wav')

visualize_and_embed(item['sources'])

The STFT parameters were inferred the first time we used the dataset, based on the audio signal’s sample rate and the defaults in nussl. To enforce specific STFT parameters, we can do the following:
[5]:
folder = 'ignored'
stft_params = nussl.STFTParams(window_length=256, hop_length=64)
sine_wave_dataset = SineWaves(folder, sample_rate=44100, stft_params=stft_params)
item = sine_wave_dataset[0]
visualize_and_embed(item['sources'])
print('STFT shape:', item['mix'].stft().shape)

STFT shape: (129, 1380, 1)
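The first dimension of that shape can be checked by hand. The sketch below assumes the FFT size equals the window length, so a real-valued signal yields window_length // 2 + 1 frequency bins. num_freq_bins is a hypothetical helper, not a nussl function. The exact number of frames depends on how the signal is padded, which this sketch does not model, so it only computes a rough lower bound.

```python
def num_freq_bins(window_length):
    # One-sided spectrum of a real signal: bins from 0 Hz up to Nyquist.
    return window_length // 2 + 1


sample_rate = 44100
duration = 2          # seconds, as in the SineWaves dataset
window_length = 256
hop_length = 64

num_samples = sample_rate * duration
bins = num_freq_bins(window_length)
# Lower bound on the frame count; padding adds a few frames on top.
min_frames = num_samples // hop_length
```

Here bins matches the 129 in the printed shape, and min_frames comes out slightly below the 1380 frames reported above, which is consistent with padding at the signal edges.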
Cool! Now let’s look at some of the built-in dataset hooks that ship with nussl.
MUSDB18
MUSDB18 is a dataset for music source separation research. The full dataset is available here, but if you don’t have it, 7-second clips of each track will be downloaded automatically. In nussl, these get downloaded to ~/.nussl/musdb18. Let’s set up a MUSDB18 dataset object and visualize/listen to an item from the dataset:
[6]:
musdb = nussl.datasets.MUSDB18(download=True)
i = 40  # or get a random track like this: np.random.randint(len(musdb))
item = musdb[i]
mix = item['mix']
sources = item['sources']
visualize_and_embed(sources)
