Source code for nussl.separation.benchmark.ideal_ratio_mask

from ..base import MaskSeparationBase, SeparationException
from ...datasets import transforms


class IdealRatioMask(MaskSeparationBase):
    """
    Ideal ratio mask (IRM) separation built from the known ground truth
    sources. This is one of the upper baselines.

    Each mask is a source's magnitude divided by the sum of all source
    magnitudes, where the magnitudes come from the transform selected by
    ``approach``.

    Args:
        input_audio_signal (AudioSignal): Signal to separate.
        sources (list or dict): Audio signal objects that correspond to the
          sources. A list is converted to a dict keyed by list position.
        approach (str): Either 'psa' (phase sensitive spectrum approximation)
          or 'msa' (magnitude spectrum approximation). Generally 'psa' does
          better.
        mask_type (str, optional): Mask type. Defaults to 'soft'.
        mask_threshold (float, optional): Masking threshold. Defaults to 0.5.
        kwargs (dict): Extra keyword arguments are passed to the transform
          classes at initialization.

    Raises:
        SeparationException: If ``sources`` is neither a list nor a dict, or
          if ``approach`` is not 'psa' or 'msa'.
    """

    def __init__(self, input_audio_signal, sources, approach='psa',
                 mask_type='soft', mask_threshold=.5, **kwargs):
        # Normalize sources to a dict; lists are keyed by their position.
        if isinstance(sources, list):
            sources = dict(enumerate(sources))
        if not isinstance(sources, dict):
            raise SeparationException("sources must be a list or a dict!")

        self.sources = sources
        self.approach = approach

        # Choose the transform class first, then build it with the extra
        # keyword arguments supplied by the caller.
        if approach == 'psa':
            transform_cls = transforms.PhaseSensitiveSpectrumApproximation
        elif approach == 'msa':
            transform_cls = transforms.MagnitudeSpectrumApproximation
        else:
            raise SeparationException(f'Unknown approach: {self.approach}')
        self.tfm = transform_cls(**kwargs)

        super().__init__(
            input_audio_signal=input_audio_signal,
            mask_type=mask_type,
            mask_threshold=mask_threshold)

    def run(self):
        """
        Compute one ratio mask per slice of the transform's
        ``'source_magnitudes'`` output along its last axis.

        Returns:
            list: The mask objects, which are also stored on
              ``self.result_masks``.
        """
        # The transform receives the mixture and the ground truth sources.
        transformed = self.tfm({
            'mix': self.audio_signal,
            'sources': self.sources,
        })

        magnitudes = transformed['source_magnitudes']
        # Small constant keeps the division finite where the summed
        # magnitudes are zero.
        denominator = magnitudes.sum(axis=-1, keepdims=True) + 1e-8
        ratios = magnitudes / denominator

        # Presumably the last axis indexes the sources - TODO confirm
        # against the transform's output layout.
        self.result_masks = [
            self.mask_type(ratios[..., idx])
            for idx in range(ratios.shape[-1])
        ]
        return self.result_masks