Source code for nussl.separation.benchmark.wiener_filter

import numpy as np
import norbert

from ..base import MaskSeparationBase, SeparationException


class WienerFilter(MaskSeparationBase):
    """
    Implements a multichannel Wiener filter that is computed by using some
    source estimates. When using the estimates produced by IdealRatioMask or
    IdealBinaryMask, this is one of the upper baselines.

    Args:
        input_audio_signal (AudioSignal): Signal to separate.
        estimates (list): List of audio signal objects that correspond to the
            estimates.
        iterations (int): Number of iterations for expectation-maximization in
            Wiener filter.
        mask_type (str, optional): Mask type. Defaults to 'soft'.
        mask_threshold (float, optional): Threshold for masking binary.
            Defaults to 0.5.
        kwargs (dict): Additional keyword arguments to `norbert.wiener`.

    Raises:
        SeparationException: If ``estimates`` is not a list.
    """

    def __init__(self, input_audio_signal, estimates, iterations=1,
                 mask_type='soft', mask_threshold=.5, **kwargs):
        if not isinstance(estimates, list):
            raise SeparationException("estimates must be a list!")

        # These are assigned ahead of the base-class constructor, exactly as
        # in the original ordering.
        # NOTE(review): presumably the base __init__ may rely on them being
        # set already -- confirm before reordering.
        self.estimates = estimates
        self.iterations = iterations
        self.kwargs = kwargs

        super().__init__(
            input_audio_signal=input_audio_signal,
            mask_type=mask_type,
            mask_threshold=mask_threshold)

    def run(self):
        """
        Runs the Wiener filter on the mixture and converts the filtered
        sources into one mask per estimate.

        The magnitudes of the estimates' STFTs and the mixture STFT are passed
        to `norbert.wiener`. Each mask is the magnitude of a filtered source
        divided by the magnitude of the mixture. For binary masks, a bin is
        set wherever that source's ratio equals the maximum ratio across all
        sources.

        Returns:
            list: The masks, in the same order as ``estimates``; each one is
            built by calling ``self.mask_type`` on the mask data. The same
            list is also stored on ``self.result_masks``.
        """
        # Stack the estimates' magnitude spectrograms along a new last axis,
        # giving one slice per source.
        source_magnitudes = np.stack(
            [np.abs(e.stft()) for e in self.estimates], axis=-1)

        # Swap the first two axes of both arrays before handing them to
        # norbert.
        # NOTE(review): assumes stft() is laid out (freq, time, channel) and
        # norbert.wiener wants frames first -- confirm against both APIs.
        source_magnitudes = np.transpose(source_magnitudes, (1, 0, 2, 3))
        mix_stft = np.transpose(self.audio_signal.stft(), (1, 0, 2))

        enhanced = norbert.wiener(
            source_magnitudes, mix_stft,
            iterations=self.iterations, **self.kwargs)

        # Ratio of each filtered source to the mixture; the 1e-7 floor avoids
        # dividing by zero in silent bins.
        _masks = (
            np.abs(enhanced) / np.maximum(1e-7, np.abs(mix_stft[..., None])))
        # Undo the earlier axis swap so masks line up with the original STFT.
        _masks = np.transpose(_masks, (1, 0, 2, 3))

        # Both values below are the same for every source, so compute them
        # once instead of once per loop iteration.
        is_binary = self.mask_type == self.MASKS['binary']
        max_over_sources = np.max(_masks, axis=-1) if is_binary else None

        self.result_masks = []
        for i in range(_masks.shape[-1]):
            mask_data = _masks[..., i]
            if is_binary:
                # Ties are kept: every source that reaches the maximum in a
                # bin gets that bin set.
                mask_data = mask_data == max_over_sources
            self.result_masks.append(self.mask_type(mask_data))

        return self.result_masks