import numpy as np
import norbert
from ..base import MaskSeparationBase, SeparationException
class WienerFilter(MaskSeparationBase):
    """
    Implements a multichannel Wiener filter that is computed by using some
    source estimates. When using the estimates produced by IdealRatioMask or
    IdealBinaryMask, this is one of the upper baselines.

    The magnitudes of the estimates' STFTs are passed to ``norbert.wiener``
    together with the mixture STFT. Each filtered source is then converted
    back into a mask on the mixture by dividing its magnitude by the mixture
    magnitude.

    Args:
        input_audio_signal (AudioSignal): Signal to separate.
        estimates (list): Non-empty list of audio signal objects that
          correspond to the estimates.
        iterations (int): Number of iterations for expectation-maximization
          in the Wiener filter. Defaults to 1.
        mask_type (str, optional): Mask type. Defaults to 'soft'.
        mask_threshold (float, optional): Forwarded to
          ``MaskSeparationBase``. Defaults to 0.5. ``run`` does not use it:
          binary masks select, per bin, the source(s) with the largest
          Wiener mask.
        kwargs (dict): Additional keyword arguments to `norbert.wiener`.

    Raises:
        SeparationException: If ``estimates`` is not a list or is empty.
    """
    def __init__(self, input_audio_signal, estimates, iterations=1, mask_type='soft',
                 mask_threshold=.5, **kwargs):
        if not isinstance(estimates, list):
            raise SeparationException("estimates must be a list!")
        # run() stacks the estimates into one array; with zero estimates
        # np.stack fails with an unrelated-looking error, so reject early.
        if not estimates:
            raise SeparationException("estimates must not be empty!")
        self.estimates = estimates
        self.iterations = iterations
        self.kwargs = kwargs
        super().__init__(
            input_audio_signal=input_audio_signal,
            mask_type=mask_type,
            mask_threshold=mask_threshold)

    def run(self):
        """
        Runs the Wiener filter on the mixture and converts its output to masks.

        Returns:
            list: One mask object, built with ``self.mask_type``, per
            estimate, in the same order as ``self.estimates``. For soft masks
            the data is |filtered source| / |mixture|. For binary masks it is
            True where that source's ratio equals the maximum over all
            sources (every tied source is marked). The list is also stored
            in ``self.result_masks``.
        """
        # Sources are stacked on a new last axis. The first two axes are then
        # swapped for norbert, which expects frames first.
        # NOTE(review): assumes AudioSignal.stft() is (freq, time, channel).
        source_magnitudes = np.stack(
            [np.abs(e.stft()) for e in self.estimates], axis=-1)
        source_magnitudes = np.transpose(source_magnitudes, (1, 0, 2, 3))
        mix_stft = np.transpose(self.audio_signal.stft(), (1, 0, 2))

        enhanced = norbert.wiener(
            source_magnitudes, mix_stft,
            iterations=self.iterations, **self.kwargs)

        # Floor the denominator so silent mixture bins do not divide by zero.
        masks = np.abs(enhanced) / np.maximum(1e-7, np.abs(mix_stft[..., None]))
        # Undo the axis swap so the masks line up with the mixture STFT layout.
        masks = np.transpose(masks, (1, 0, 2, 3))

        is_binary = self.mask_type == self.MASKS['binary']
        # Loop-invariant per-bin maximum over sources; only needed for binary.
        loudest = np.max(masks, axis=-1) if is_binary else None

        self.result_masks = []
        for i in range(masks.shape[-1]):
            source_mask = masks[..., i]
            if is_binary:
                source_mask = source_mask == loudest
            self.result_masks.append(self.mask_type(source_mask))

        return self.result_masks