Source code for nussl.datasets.transforms

import os
import shutil
import logging
import random
from collections import OrderedDict

import torch
import zarr
import numcodecs
import numpy as np
from sklearn.preprocessing import OneHotEncoder

from .. import utils

# Disable blosc threading; it can deadlock when you're running
# multiple training threads or processes.
if hasattr(numcodecs, 'blosc'):
    numcodecs.blosc.use_threads = False

def compute_ideal_binary_mask(source_magnitudes):
    """
    Computes the ideal binary mask from an array of source magnitudes with
    sources on the last axis: each time-frequency bin is assigned to the
    loudest source, exact ties are split evenly, and shares at or below
    0.5 are zeroed out.
    """
    ibm = (
        source_magnitudes == np.max(source_magnitudes, axis=-1, keepdims=True)
    ).astype(float)

    ibm = ibm / np.sum(ibm, axis=-1, keepdims=True)
    ibm[ibm <= .5] = 0
    return ibm
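
# A minimal sketch (not part of the library) showing compute_ideal_binary_mask
# on random data. The shape convention (freq, time, channels, sources) matches
# what the transforms below produce; the numbers here are arbitrary.
def _example_compute_ideal_binary_mask():
    source_magnitudes = np.abs(np.random.randn(513, 400, 1, 4))
    ibm = compute_ideal_binary_mask(source_magnitudes)
    # Each time-frequency bin is assigned to the loudest source.
    assert ibm.shape == source_magnitudes.shape
    return ibm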


# Keys that correspond to the time-frequency representations after being passed through
# the transforms here.
time_frequency_keys = ['mix_magnitude', 'source_magnitudes', 'ideal_binary_mask', 'weights']

class SumSources(object):
    """
    Sums sources together. Looks for sources in ``data[self.source_key]``. If
    a source belongs to a group, it is popped from ``data[self.source_key]`` and
    summed with the other sources in the group. If there is a corresponding
    group_name in group_names, it is named that in ``data[self.source_key]``. If
    group_names are not given, then the names are constructed using the keys
    in each group (e.g. `drums+bass+other`).

    If using Scaper datasets, then there may be multiple sources with the same
    label but different counts. The Scaper dataset hook organizes the source
    dictionary as follows:

    .. code-block:: none

        data['sources'] = {
            '{label}::{count}': AudioSignal,
            '{label}::{count}': AudioSignal,
            ...
        }

    SumSources sums by source label, so the ``::count`` will be ignored and only
    the label part will be used when grouping sources.

    Example:
        >>> tfm = transforms.SumSources(
        >>>     groupings=[['drums', 'bass', 'other']],
        >>>     group_names=['accompaniment'],
        >>> )
        >>> # data['sources'] is a dict containing keys:
        >>> #   ['vocals', 'drums', 'bass', 'other']
        >>> data = tfm(data)
        >>> # data['sources'] is now a dict containing keys:
        >>> #   ['vocals', 'accompaniment']

    Args:
        groupings (list): a list of lists telling how to group the sources.
        group_names (list, optional): A list containing the names of each group, or None.
            Defaults to None.
        source_key (str, optional): The key to look for in the data containing the dict of
            source AudioSignals. Defaults to 'sources'.

    Raises:
        TransformException: if groupings is not a list
        TransformException: if group_names is not None but
            len(groupings) != len(group_names)

    Returns:
        data: modified dictionary with summed sources
    """

    def __init__(self, groupings, group_names=None, source_key='sources'):
        if not isinstance(groupings, list):
            raise TransformException(
                f"groupings must be a list, got {type(groupings)}!")

        if group_names:
            if len(group_names) != len(groupings):
                raise TransformException(
                    f"group_names and groupings must be the same length, or "
                    f"group_names can be None! Got {len(group_names)} for "
                    f"len(group_names) and {len(groupings)} for len(groupings)."
                )

        self.groupings = groupings
        self.source_key = source_key
        if group_names is None:
            group_names = ['+'.join(groupings[i]) for i in range(len(groupings))]
        self.group_names = group_names

    def __call__(self, data):
        if self.source_key not in data:
            raise TransformException(
                f"Expected {self.source_key} in dictionary "
                f"passed to this Transform!"
            )
        sources = data[self.source_key]
        source_keys = [(k.split('::')[0], k) for k in list(sources.keys())]

        for i, group in enumerate(self.groupings):
            combined = []
            group_name = self.group_names[i]
            for key1 in group:
                for key2 in source_keys:
                    if key2[0] == key1:
                        combined.append(sources[key2[1]])
                        sources.pop(key2[1])
            sources[group_name] = sum(combined)
            sources[group_name].path_to_input_file = group_name

        data[self.source_key] = sources
        if 'metadata' in data:
            if 'labels' in data['metadata']:
                data['metadata']['labels'].extend(self.group_names)

        return data

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"groupings = {self.groupings}, "
            f"group_names = {self.group_names}, "
            f"source_key = {self.source_key}"
            f")"
        )
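
# A usage sketch for SumSources (not part of the library). It assumes nussl's
# AudioSignal can be built from a raw array via `audio_data_array` and
# `sample_rate`; adjust to however your signals are constructed.
def _example_sum_sources():
    from nussl import AudioSignal
    make = lambda: AudioSignal(
        audio_data_array=np.random.randn(1, 44100), sample_rate=44100)
    data = {'sources': {
        'vocals': make(), 'drums': make(), 'bass': make(), 'other': make()}}
    tfm = SumSources(
        groupings=[['drums', 'bass', 'other']], group_names=['accompaniment'])
    data = tfm(data)
    # data['sources'] now has keys ['vocals', 'accompaniment'].
    return data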
class LabelsToOneHot(object):
    """
    Takes a data dictionary with sources and their keys and converts the keys
    to a one-hot numpy array using the list in data['metadata']['labels']
    to figure out which index goes where.
    """

    def __init__(self, source_key='sources'):
        self.source_key = source_key

    def __call__(self, data):
        if 'metadata' not in data:
            raise TransformException(
                f"Expected metadata in data, got {list(data.keys())}")
        if 'labels' not in data['metadata']:
            raise TransformException(
                f"Expected labels in data['metadata'], got "
                f"{list(data['metadata'].keys())}")

        enc = OneHotEncoder(categories=[data['metadata']['labels']])
        sources = data[self.source_key]
        source_keys = [k.split('::')[0] for k in list(sources.keys())]
        source_labels = [[l] for l in sorted(source_keys)]

        one_hot_labels = enc.fit_transform(source_labels)
        data['one_hot_labels'] = one_hot_labels.toarray()

        return data
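
# A usage sketch for LabelsToOneHot (not part of the library). Only the source
# keys and data['metadata']['labels'] matter here, so plain placeholder values
# stand in for AudioSignals.
def _example_labels_to_one_hot():
    data = {
        'sources': {'vocals::0': None, 'drums::0': None},
        'metadata': {'labels': ['bass', 'drums', 'other', 'vocals']},
    }
    data = LabelsToOneHot()(data)
    # data['one_hot_labels'] is a (2, 4) array; row order follows the sorted
    # source labels ('drums', 'vocals'), column order follows the label list.
    return data['one_hot_labels']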
class MagnitudeSpectrumApproximation(object):
    """
    Takes a dictionary and looks for two special keys, defined by the
    arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`.
    The values of these keys are used to calculate the magnitude spectrum
    approximation [1]. The input dictionary is modified to have additional
    keys:

    - mix_magnitude: The magnitude spectrogram of the mixture audio signal.
    - source_magnitudes: The magnitude spectrograms of each source spectrogram.
    - ideal_binary_mask: The ideal binary assignments for each time-frequency bin.

    ``data[self.source_key]`` points to a dictionary containing the source names
    in the keys and the corresponding AudioSignal in the values. The sources are
    sorted in alphabetical order by key, and their spectrograms are stacked along
    the last axis in that order. ``data[self.source_key]`` then points to an
    OrderedDict instead, where the keys are in the same order as in
    ``data['source_magnitudes']`` and ``data['ideal_binary_mask']``.

    This transform uses the STFTParams that are attached to the AudioSignal objects
    contained in ``data[mix_key]`` and ``data[source_key]``.

    [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux.
        "Phase-sensitive and recognition-boosted speech separation using
        deep recurrent neural networks." In 2015 IEEE International Conference
        on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE,
        2015.

    Args:
        mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
            Defaults to 'mix'.
        source_key (str, optional): The key to look for in the data containing the dict of
            source AudioSignals. Defaults to 'sources'.

    Raises:
        TransformException: if the expected keys are not in the dictionary, an
            Exception is raised.

    Returns:
        data: Modified version of the input dictionary.
    """

    def __init__(self, mix_key='mix', source_key='sources'):
        self.mix_key = mix_key
        self.source_key = source_key

    def __call__(self, data):
        if self.mix_key not in data:
            raise TransformException(
                f"Expected {self.mix_key} in dictionary "
                f"passed to this Transform! Got {list(data.keys())}."
            )

        mixture = data[self.mix_key]
        mixture.stft()
        mix_magnitude = mixture.magnitude_spectrogram_data
        data['mix_magnitude'] = mix_magnitude

        if self.source_key not in data:
            return data

        _sources = data[self.source_key]
        source_names = sorted(list(_sources.keys()))

        sources = OrderedDict()
        for key in source_names:
            sources[key] = _sources[key]
        data[self.source_key] = sources

        source_magnitudes = []
        for key in source_names:
            s = sources[key]
            s.stft()
            source_magnitudes.append(s.magnitude_spectrogram_data)

        source_magnitudes = np.stack(source_magnitudes, axis=-1)

        data['ideal_binary_mask'] = compute_ideal_binary_mask(source_magnitudes)
        data['source_magnitudes'] = source_magnitudes

        return data

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"mix_key = {self.mix_key}, "
            f"source_key = {self.source_key}"
            f")"
        )
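
# A usage sketch for MagnitudeSpectrumApproximation (not part of the library).
# As above, it assumes AudioSignals can be built from raw arrays; one second of
# noise at 44.1 kHz keeps it self-contained.
def _example_msa():
    from nussl import AudioSignal
    make = lambda: AudioSignal(
        audio_data_array=np.random.randn(1, 44100), sample_rate=44100)
    sources = {'s0': make(), 's1': make()}
    data = {'mix': sources['s0'] + sources['s1'], 'sources': sources}
    data = MagnitudeSpectrumApproximation()(data)
    # data now has 'mix_magnitude' (freq, time, ch), plus 'source_magnitudes'
    # and 'ideal_binary_mask' (freq, time, ch, num_sources).
    return data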
class MagnitudeWeights(object):
    """
    Applying time-frequency weights to the deep clustering objective results in
    a huge performance boost. This transform looks for 'mix_magnitude', which is
    output by either MagnitudeSpectrumApproximation or
    PhaseSensitiveSpectrumApproximation, and uses it to compute the weights.

    [1] Wang, Zhong-Qiu, Jonathan Le Roux, and John R. Hershey. "Alternative
        objective functions for deep clustering." 2018 IEEE International
        Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.

    Args:
        mix_key (str, optional): The key to look for in data for the mixture
            AudioSignal, used as a fallback if 'mix_magnitude' is absent.
            Defaults to 'mix'.
        mix_magnitude_key (str): Which key to look for the mix_magnitude data in.
            Defaults to 'mix_magnitude'.
    """

    def __init__(self, mix_key='mix', mix_magnitude_key='mix_magnitude'):
        self.mix_magnitude_key = mix_magnitude_key
        self.mix_key = mix_key

    def __call__(self, data):
        if self.mix_magnitude_key not in data and self.mix_key not in data:
            raise TransformException(
                f"Expected {self.mix_magnitude_key} or {self.mix_key} in dictionary "
                f"passed to this Transform! Got {list(data.keys())}. "
                "Either MagnitudeSpectrumApproximation or "
                "PhaseSensitiveSpectrumApproximation should be called "
                "on the data dict prior to this transform. "
            )
        elif self.mix_magnitude_key not in data:
            data[self.mix_magnitude_key] = np.abs(data[self.mix_key].stft())

        magnitude_spectrogram = data[self.mix_magnitude_key]
        weights = magnitude_spectrogram / (np.sum(magnitude_spectrogram) + 1e-6)
        weights *= (
            magnitude_spectrogram.shape[0] * magnitude_spectrogram.shape[1]
        )
        data['weights'] = np.sqrt(weights)
        return data
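
# A usage sketch for MagnitudeWeights (not part of the library). Since
# 'mix_magnitude' is already present, no AudioSignal is needed; the weights
# come out with the same shape as the magnitude spectrogram.
def _example_magnitude_weights():
    data = {'mix_magnitude': np.abs(np.random.randn(513, 400, 1))}
    data = MagnitudeWeights()(data)
    assert data['weights'].shape == (513, 400, 1)
    return data['weights']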
class PhaseSensitiveSpectrumApproximation(object):
    """
    Takes a dictionary and looks for two special keys, defined by the
    arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`.
    The values of these keys are used to calculate the phase sensitive spectrum
    approximation [1]. The input dictionary is modified to have additional
    keys:

    - mix_magnitude: The magnitude spectrogram of the mixture audio signal.
    - source_magnitudes: The magnitude spectrograms of each source spectrogram.
    - ideal_binary_mask: The ideal binary assignments for each time-frequency bin.

    ``data[self.source_key]`` points to a dictionary containing the source names
    in the keys and the corresponding AudioSignal in the values. The sources are
    sorted in alphabetical order by key, and their spectrograms are stacked along
    the last axis in that order. ``data[self.source_key]`` then points to an
    OrderedDict instead, where the keys are in the same order as in
    ``data['source_magnitudes']`` and ``data['ideal_binary_mask']``.

    This transform uses the STFTParams that are attached to the AudioSignal objects
    contained in ``data[mix_key]`` and ``data[source_key]``.

    [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux.
        "Phase-sensitive and recognition-boosted speech separation using
        deep recurrent neural networks." In 2015 IEEE International Conference
        on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE,
        2015.

    Args:
        mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
            Defaults to 'mix'.
        source_key (str, optional): The key to look for in the data containing the dict of
            source AudioSignals. Defaults to 'sources'.
        range_min (float, optional): The lower end to use when truncating the source
            magnitudes in the phase sensitive spectrum approximation. Defaults to 0.0
            (construct non-negative masks). Use -np.inf for untruncated source
            magnitudes.
        range_max (float, optional): The higher end of the truncated spectrum. This
            gets multiplied by the magnitude of the mixture. Use 1.0 to truncate the
            source magnitudes to ``max(source_magnitudes, mix_magnitude)``. Use np.inf
            for untruncated source magnitudes (best performance for an oracle mask but
            may be beyond what a neural network is capable of masking). Defaults to 1.0.

    Raises:
        TransformException: if the expected keys are not in the dictionary, an
            Exception is raised.

    Returns:
        data: Modified version of the input dictionary.
    """

    def __init__(self, mix_key='mix', source_key='sources',
                 range_min=0.0, range_max=1.0):
        self.mix_key = mix_key
        self.source_key = source_key
        self.range_min = range_min
        self.range_max = range_max

    def __call__(self, data):
        if self.mix_key not in data:
            raise TransformException(
                f"Expected {self.mix_key} in dictionary "
                f"passed to this Transform! Got {list(data.keys())}."
            )

        mixture = data[self.mix_key]
        mix_stft = mixture.stft()
        mix_magnitude = np.abs(mix_stft)
        mix_angle = np.angle(mix_stft)
        data['mix_magnitude'] = mix_magnitude

        if self.source_key not in data:
            return data

        _sources = data[self.source_key]
        source_names = sorted(list(_sources.keys()))

        sources = OrderedDict()
        for key in source_names:
            sources[key] = _sources[key]
        data[self.source_key] = sources

        source_angles = []
        source_magnitudes = []
        for key in source_names:
            s = sources[key]
            _stft = s.stft()
            source_magnitudes.append(np.abs(_stft))
            source_angles.append(np.angle(_stft))

        source_magnitudes = np.stack(source_magnitudes, axis=-1)
        source_angles = np.stack(source_angles, axis=-1)

        range_min = self.range_min
        range_max = self.range_max * mix_magnitude[..., None]

        # Section 3.1: https://arxiv.org/pdf/1909.08494.pdf
        source_magnitudes = np.minimum(
            np.maximum(
                source_magnitudes * np.cos(source_angles - mix_angle[..., None]),
                range_min
            ),
            range_max
        )

        data['ideal_binary_mask'] = compute_ideal_binary_mask(source_magnitudes)
        data['source_magnitudes'] = source_magnitudes

        return data

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"mix_key = {self.mix_key}, "
            f"source_key = {self.source_key}"
            f")"
        )
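
# The core of the phase sensitive target above, sketched on plain arrays (not
# part of the library): project each source magnitude onto the mixture phase,
# then truncate into [range_min, range_max * |mix|].
def _example_psa_truncation():
    mix = np.random.randn(513, 400, 1) + 1j * np.random.randn(513, 400, 1)
    srcs = np.random.randn(513, 400, 1, 2) + 1j * np.random.randn(513, 400, 1, 2)
    projected = np.abs(srcs) * np.cos(np.angle(srcs) - np.angle(mix)[..., None])
    truncated = np.minimum(
        np.maximum(projected, 0.0),        # range_min = 0.0
        1.0 * np.abs(mix)[..., None])      # range_max = 1.0
    return truncated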
class IndexSources(object):
    """
    Takes in a dictionary containing Torch tensors or numpy arrays and extracts
    the indexed sources from the set key (usually either `source_magnitudes` or
    `ideal_binary_mask`). Can be used to train single-source separation models
    (e.g. mix goes in, vocals come out).

    You need to know which slice of the source magnitudes or ideal binary mask
    arrays to extract. The order of the sources in the source magnitudes array
    will be in alphabetical order according to their source labels. For example,
    if source magnitudes has shape `(257, 400, 1, 4)`, and the data is from
    MUSDB, then the four possible source labels are bass, drums, other, and
    vocals. The data in source magnitudes is in alphabetical order, so:

    .. code-block:: python

        # source_magnitudes is an array returned by either MagnitudeSpectrumApproximation
        # or PhaseSensitiveSpectrumApproximation
        source_magnitudes[..., 0] # bass spectrogram
        source_magnitudes[..., 1] # drums spectrogram
        source_magnitudes[..., 2] # other spectrogram
        source_magnitudes[..., 3] # vocals spectrogram

        # ideal_binary_mask is an array returned by either MagnitudeSpectrumApproximation
        # or PhaseSensitiveSpectrumApproximation
        ideal_binary_mask[..., 0] # bass ibm mask
        ideal_binary_mask[..., 1] # drums ibm mask
        ideal_binary_mask[..., 2] # other ibm mask
        ideal_binary_mask[..., 3] # vocals ibm mask

    You can apply this transform to either the `source_magnitudes` or the
    `ideal_binary_mask` or both.

    Args:
        target_key (str): The key in the data dictionary to index (e.g.
            'source_magnitudes' or 'ideal_binary_mask').
        index (int): Which slice along the last axis to extract.
    """

    def __init__(self, target_key, index):
        self.target_key = target_key
        self.index = index

    def __call__(self, data):
        if self.target_key not in data:
            raise TransformException(
                f"Expected {self.target_key} in dictionary, got {list(data.keys())}")
        if self.index >= data[self.target_key].shape[-1]:
            raise TransformException(
                f"Shape of data[{self.target_key}] is {data[self.target_key].shape} "
                f"but index = {self.index} is out of the bounds of the last axis.")
        data[self.target_key] = data[self.target_key][..., self.index, None]
        return data
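
# A usage sketch for IndexSources (not part of the library): pull out the
# vocals slice (alphabetically last of the four MUSDB labels) while keeping a
# singleton source axis.
def _example_index_sources():
    data = {'source_magnitudes': np.random.rand(257, 400, 1, 4)}
    data = IndexSources('source_magnitudes', 3)(data)
    assert data['source_magnitudes'].shape == (257, 400, 1, 1)
    return data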
class GetExcerpt(object):
    """
    Takes in a dictionary containing Torch tensors or numpy arrays and extracts
    an excerpt from each tensor corresponding to a spectral representation of a
    specified length in frames. Can be used to get L-length spectrograms from
    mixture and source spectrograms. If the data is shorter than the specified
    length, it is padded to the specified length. If it is longer, a random
    offset between ``(0, data_length - specified_length)`` is chosen.

    This transform assumes that it is applied AFTER ToSeparationModel, so the
    time dimension of each item is on axis=0 (hence the default ``time_dim=0``).

    Args:
        excerpt_length (int): Specified length of transformed data in frames.
        time_dim (int): Which dimension time is on (excerpts are taken along
            this axis). Defaults to 0.
        tf_keys (list, optional): Which keys in the data dictionary to take
            excerpts from. Defaults to the module-level ``time_frequency_keys``.
    """

    def __init__(self, excerpt_length, time_dim=0, tf_keys=None):
        self.excerpt_length = excerpt_length
        self.time_dim = time_dim
        self.time_frequency_keys = tf_keys if tf_keys else time_frequency_keys

    @staticmethod
    def _validate(data, key):
        is_tensor = torch.is_tensor(data[key])
        is_array = isinstance(data[key], np.ndarray)
        if not is_tensor and not is_array:
            raise TransformException(
                f"data[{key}] was not a torch Tensor or a numpy array!")
        return is_tensor, is_array

    def _get_offset(self, data, key):
        self._validate(data, key)
        data_length = data[key].shape[self.time_dim]

        if data_length >= self.excerpt_length:
            offset = random.randint(0, data_length - self.excerpt_length)
        else:
            offset = 0
        pad_amount = max(0, self.excerpt_length - data_length)
        return offset, pad_amount

    def _construct_pad_func_tuple(self, shape, pad_amount, is_tensor):
        if is_tensor:
            pad_func = torch.nn.functional.pad
            pad_tuple = [0 for _ in range(2 * len(shape))]
            pad_tuple[2 * self.time_dim] = pad_amount
            pad_tuple = pad_tuple[::-1]
        else:
            pad_func = np.pad
            pad_tuple = [(0, 0) for _ in range(len(shape))]
            pad_tuple[self.time_dim] = (0, pad_amount)
        return pad_func, pad_tuple

    def __call__(self, data):
        offset, pad_amount = self._get_offset(
            data, self.time_frequency_keys[0])

        for key in data:
            if key in self.time_frequency_keys:
                is_tensor, is_array = self._validate(data, key)
                if pad_amount > 0:
                    pad_func, pad_tuple = self._construct_pad_func_tuple(
                        data[key].shape, pad_amount, is_tensor)
                    data[key] = pad_func(data[key], pad_tuple)
                data[key] = utils._slice_along_dim(
                    data[key], self.time_dim, offset, offset + self.excerpt_length)

        return data
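
# A usage sketch for GetExcerpt (not part of the library). The array is laid
# out time-first, as it would be after ToSeparationModel; a 400-frame excerpt
# is cut (or padded) along axis 0.
def _example_get_excerpt():
    data = {'mix_magnitude': np.random.rand(1000, 513, 1)}
    data = GetExcerpt(400)(data)
    assert data['mix_magnitude'].shape == (400, 513, 1)
    return data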
class Cache(object):
    """
    The Cache transform can be placed within a Compose transform. The data
    dictionary coming into this transform will be saved to the specified
    location using ``zarr``. Then instead of computing all of the transforms
    before the cache, one can simply read from the cache. The transforms after
    this will then be applied to the data dictionary that is read from the
    cache. A typical pipeline might look like this:

    .. code-block:: python

        dataset = datasets.Scaper('path/to/scaper/folder')
        tfm = transforms.Compose([
            transforms.PhaseSensitiveSpectrumApproximation(),
            transforms.ToSeparationModel(),
            transforms.Cache('~/.nussl/cache/tag', overwrite=True),
            transforms.GetExcerpt(400)
        ])
        dataset[0] # first time will write to cache then apply GetExcerpt
        dataset.cache_populated = True # switches to reading from cache
        dataset[0] # second time will read from cache then apply GetExcerpt
        dataset[1] # will error out as it wasn't written to the cache!

        dataset.cache_populated = False
        for i in range(len(dataset)):
            dataset[i] # every item will get written to cache
        dataset.cache_populated = True
        dataset[1] # now it exists

        # next time around
        dataset = datasets.Scaper('path/to/scaper/folder')
        tfm = transforms.Compose([
            transforms.PhaseSensitiveSpectrumApproximation(),
            transforms.ToSeparationModel(),
            transforms.Cache('~/.nussl/cache/tag', overwrite=False),
            transforms.GetExcerpt(400)
        ])
        dataset.cache_populated = True
        dataset[0] # will read from cache, which still exists from last time

    Args:
        location (str): Path to the zarr store used as the cache.
        cache_size (int, optional): Number of items the cache can hold.
            Defaults to 1.
        overwrite (bool, optional): If True, any existing cache at ``location``
            is cleared and the cache is opened for writing; if False, an
            existing cache is opened read-only. Defaults to False.
    """

    def __init__(self, location, cache_size=1, overwrite=False):
        self.location = location
        self.cache_size = cache_size
        self.cache = None
        self.overwrite = overwrite

    @property
    def info(self):
        return self.cache.info

    @property
    def overwrite(self):
        return self._overwrite

    @overwrite.setter
    def overwrite(self, value):
        self._overwrite = value
        self._clear_cache(self.location)
        self._open_cache(self.location)

    def _clear_cache(self, location):
        if os.path.exists(location):
            if self.overwrite:
                logging.info(
                    f"Cache {location} exists and overwrite = True, clearing cache.")
                shutil.rmtree(location, ignore_errors=True)

    def _open_cache(self, location):
        if self.overwrite:
            self.cache = zarr.open(location, mode='w',
                                   shape=(self.cache_size,), chunks=(1,),
                                   dtype=object, object_codec=numcodecs.Pickle(),
                                   synchronizer=zarr.ThreadSynchronizer())
        else:
            if os.path.exists(location):
                self.cache = zarr.open(location, mode='r',
                                       object_codec=numcodecs.Pickle(),
                                       synchronizer=zarr.ThreadSynchronizer())

    def __call__(self, data):
        if 'index' not in data:
            raise TransformException(
                f"Expected 'index' in dictionary, got {list(data.keys())}")
        index = data['index']
        if self.overwrite:
            self.cache[index] = data
        data = self.cache[index]
        if not isinstance(data, dict):
            raise TransformException(
                f"Reading from cache resulted in not a dictionary! "
                f"Maybe you haven't written to index {index} yet in "
                f"the cache?")
        return data
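
# A usage sketch for Cache (not part of the library): write one data dict to a
# temporary zarr store, then read it back. The 'index' key decides which slot
# in the store each item occupies.
def _example_cache():
    import tempfile
    location = os.path.join(tempfile.mkdtemp(), 'cache')
    data = {'index': 0, 'mix_magnitude': np.random.rand(513, 400, 1)}
    writer = Cache(location, cache_size=1, overwrite=True)
    data = writer(data)          # writes to slot 0, then reads it back
    reader = Cache(location, cache_size=1, overwrite=False)
    return reader({'index': 0})  # reads slot 0 from the existing store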
class GetAudio(object):
    """
    Extracts the audio from each signal in `mix_key` and `source_key`.
    These will be at new keys, called `mix_audio` and `source_audio`.
    Can be used for training end-to-end models.

    Args:
        mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
            Defaults to 'mix'.
        source_key (str, optional): The key to look for in the data containing the dict of
            source AudioSignals. Defaults to 'sources'.
    """

    def __init__(self, mix_key='mix', source_key='sources'):
        self.mix_key = mix_key
        self.source_key = source_key

    def __call__(self, data):
        if self.mix_key not in data:
            raise TransformException(
                f"Expected {self.mix_key} in dictionary "
                f"passed to this Transform! Got {list(data.keys())}."
            )
        mix = data[self.mix_key]
        data['mix_audio'] = mix.audio_data

        if self.source_key not in data:
            return data

        _sources = data[self.source_key]
        source_names = sorted(list(_sources.keys()))

        source_audio = []
        for key in source_names:
            source_audio.append(_sources[key].audio_data)
        # sources on last axis
        source_audio = np.stack(source_audio, axis=-1)
        data['source_audio'] = source_audio

        return data
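
# A usage sketch for GetAudio (not part of the library), again assuming
# AudioSignal accepts `audio_data_array` and `sample_rate`.
def _example_get_audio():
    from nussl import AudioSignal
    make = lambda: AudioSignal(
        audio_data_array=np.random.randn(1, 44100), sample_rate=44100)
    sources = {'a': make(), 'b': make()}
    data = {'mix': sources['a'] + sources['b'], 'sources': sources}
    data = GetAudio()(data)
    # data['mix_audio'] is (ch, samples); data['source_audio'] is
    # (ch, samples, num_sources), with sources stacked alphabetically by key.
    return data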
class ToSeparationModel(object):
    """
    Takes in a dictionary containing objects and removes any objects that cannot
    be passed to SeparationModel (e.g. not a numpy array or torch Tensor).
    If these objects are passed to SeparationModel, then an error will occur.
    This class should be the last one in your list of transforms, if you're
    using this dataset in a DataLoader object for training a network.

    If the keys correspond to numpy arrays, they are converted to tensors using
    ``torch.from_numpy``. Finally, the dimensions corresponding to time and
    frequency are swapped for all the keys in swap_tf_dims, as this is how
    SeparationModel expects it.

    Example:

    .. code-block:: none

        data = {
            # 2ch spectrogram for mixture
            'mix_magnitude': torch.randn(513, 400, 2),
            # 2ch spectrogram for each source
            'source_magnitudes': torch.randn(513, 400, 2, 4),
            'mix': AudioSignal()
        }
        tfm = transforms.ToSeparationModel()
        data = tfm(data)

        data['mix_magnitude'].shape # (400, 513, 2)
        data['source_magnitudes'].shape # (400, 513, 2, 4)
        'mix' in data.keys() # False

    If this class isn't in your transforms list for the dataset, but you are
    using it in the Trainer class, then it is added automatically as the
    last transform.
    """

    def __init__(self, swap_tf_dims=None):
        self.swap_tf_dims = swap_tf_dims if swap_tf_dims else time_frequency_keys

    def __call__(self, data):
        keys = list(data.keys())
        for key in keys:
            if key != 'index':
                is_array = isinstance(data[key], np.ndarray)
                if is_array:
                    data[key] = torch.from_numpy(data[key])
                if not torch.is_tensor(data[key]):
                    data.pop(key)
                if key in self.swap_tf_dims:
                    data[key] = data[key].transpose(1, 0)
        return data

    def __repr__(self):
        return f"{self.__class__.__name__}()"
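
# A usage sketch for ToSeparationModel (not part of the library): numpy arrays
# become torch tensors with time first, and anything that isn't a tensor
# (here, the string under 'path') is dropped.
def _example_to_separation_model():
    data = {
        'mix_magnitude': np.random.rand(513, 400, 1),
        'path': '/some/file.wav',
        'index': 0,
    }
    data = ToSeparationModel()(data)
    assert torch.is_tensor(data['mix_magnitude'])
    assert data['mix_magnitude'].shape == (400, 513, 1)
    assert 'path' not in data
    return data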
class Compose(object):
    """Composes several transforms together. Inspired by torchvision implementation.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.MagnitudeSpectrumApproximation(),
        >>>     transforms.ToSeparationModel(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for t in self.transforms:
            data = t(data)
            if not isinstance(data, dict):
                raise TransformException(
                    "The output of every transform must be a dictionary!")
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
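
# A usage sketch for Compose (not part of the library), chaining two of the
# array-only transforms above so it runs without any audio files.
def _example_compose():
    tfm = Compose([
        MagnitudeWeights(),
        ToSeparationModel(),
    ])
    data = tfm({'mix_magnitude': np.random.rand(513, 400, 1), 'index': 0})
    # 'weights' and 'mix_magnitude' are now time-first torch tensors.
    return data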
class TransformException(Exception):
    """
    Exception class for errors when working with transforms in nussl.
    """
    pass