Source code for nussl.separation.spatial.projet

import copy

import numpy as np
import torch

from .. import SeparationBase, SeparationException
from ... import AudioSignal


class Projet(SeparationBase):
    """
    Implements the PROJET algorithm for spatial audio separation using
    projections. This implementation uses PyTorch to speed up computation
    considerably. PROJET does the following steps:

    1. Project the complex stereo STFT onto multiple angles and delay via
       projection and delay matrix transformations.
    2. Initialize the parameters of the system to "remix" these projections
       along with PSDs of the sources such that they try to reconstruct the
       original stereo mixture.
    3. Find the optimal parameters via multiplicative update rules for P and
       for Q.
    4. Use the discovered parameters to isolate the sources via spatial cues.

    This implementation considers BOTH panning and delays when isolating
    sources. PROJET is not a masking based method, it estimates the sources
    directly by projecting the complex STFT.

    Args:
        input_audio_signal (AudioSignal): Audio signal to separate.
        num_sources (int): Number of sources to separate.
        estimates (list of AudioSignal): Initial estimates for the separated
          sources if available. Their magnitude STFTs are stacked into
          ``self.P``. So one could (for example) run FT2D on a signal and then
          refine the estimates using PROJET. Defaults to None (``self.P`` is
          None and P is randomly initialized).
          NOTE(review): in the code of this class ``self.P`` is stored but
          never read; ``initialize_parameters`` always draws a random P.
          Confirm whether the estimates are meant to be consumed there.
        num_iterations (int, optional): Number of iterations to do for the
          update rules for P and Q. Defaults to 50.
        maximum_delay_in_samples (int, optional): Maximum delay in samples
          that you are willing to consider in the projection matrices.
          Defaults to 20.
        location_set_panning (int, optional): How many locations in panning
          you are willing to consider. Defaults to 30.
        location_set_delay (int, optional): How many delays you are willing
          to consider. Defaults to 17.
        projection_set_panning (int, optional): How many projections you are
          willing to use in panning-space. Defaults to 10.
        projection_set_delay (int, optional): How many delays you are willing
          to project the mixture onto in panning-space. Defaults to 9.
        beta (int, optional): Beta in beta divergence. See Table 1 in [1].
          Defaults to 1.
        alpha (int, optional): Power to raise each power spectral density
          estimate of each source to. Defaults to 1.
        device (str, optional): Device to use when performing update rules.
          'cuda' will be fastest, if available. Defaults to 'cpu'.

    Raises:
        SeparationException: If ``estimates`` is given and its length differs
          from ``num_sources``.

    References:
        [1] Fitzgerald, Derry, Antoine Liutkus, and Roland Badeau.
        "Projection-based demixing of spatial audio." IEEE/ACM Transactions
        on Audio, Speech, and Language Processing 24.9 (2016): 1560-1572.

        [2] Fitzgerald, Derry, Antoine Liutkus, and Roland Badeau.
        "Projet—spatial audio separation using projections." 2016 IEEE
        International Conference on Acoustics, Speech and Signal Processing
        (ICASSP). IEEE, 2016.
    """

    def __init__(self, input_audio_signal, num_sources, estimates=None,
                 num_iterations=50, maximum_delay_in_samples=20,
                 location_set_panning=30, location_set_delay=17,
                 projection_set_panning=10, projection_set_delay=9,
                 beta=1, alpha=1, device='cpu'):
        self.num_sources = num_sources
        self.alpha = alpha
        self.beta = beta
        self.num_iterations = num_iterations
        # Small constant added to denominators to keep ratios finite.
        self.eps = 1e-8
        self.device = device

        self.projection_set_panning = projection_set_panning
        self.projection_set_delay = projection_set_delay
        self.location_set_panning = location_set_panning
        self.location_set_delay = location_set_delay
        self.maximum_delay_in_samples = maximum_delay_in_samples

        self.projection_set = None
        self.inverse_projection_set = None
        self.reconstructions = None

        # The hyper-parameters above are assigned first on purpose:
        # presumably the base initializer triggers
        # ``_preprocess_audio_signal``, which reads them -- confirm against
        # SeparationBase.
        super().__init__(input_audio_signal=input_audio_signal)

        if estimates is None:
            self.P = None
        else:
            if len(estimates) != self.num_sources:
                raise SeparationException(
                    "Number of estimates must be equal to num_sources!")
            psds = []
            for e in estimates:
                # Work on a copy so the caller's estimate is left untouched.
                _e = copy.deepcopy(e)
                _e.to_mono()
                _e.stft_params = self.audio_signal.stft_params
                psds.append(np.abs(_e.stft()))
            # One magnitude STFT per source, stacked on a new trailing axis.
            self.P = np.stack(psds, axis=-1)

    def _preprocess_audio_signal(self):
        """
        Computes the mixture STFT and everything that depends only on it:
        the location set, the projection set, the projected mixture
        (``self.V`` / ``self.complex_projections``) and ``self.K``.
        """
        self.stft = self.audio_signal.stft()
        F, T, I = self.stft.shape

        # Fall back to the CPU whenever CUDA is not available.
        self.device = 'cpu' if not torch.cuda.is_available() else self.device

        # Candidate source locations: pannings in [0, pi/2].
        pannings = np.linspace(0, np.pi / 2, self.location_set_panning)
        delays = np.linspace(
            -self.maximum_delay_in_samples,
            self.maximum_delay_in_samples,
            self.location_set_delay
        )
        self.location_set = self.create_panning_delay_set(pannings, delays, F, I)

        # Projections: pannings in [-pi/2, 0].
        pannings = np.linspace(-np.pi / 2, 0, self.projection_set_panning)
        delays = np.linspace(
            -self.maximum_delay_in_samples,
            self.maximum_delay_in_samples,
            self.projection_set_delay
        )
        self.projection_set = self.create_panning_delay_set(pannings, delays, F, I)
        self.inverse_projection_set = np.linalg.pinv(self.projection_set)

        self.V, self.complex_projections = self.create_projections()
        self.K = self.create_k_matrix()

    @staticmethod
    def create_panning_delay_set(pannings, delays, F, I):
        """
        Builds a complex array of panning/delay vectors.

        Args:
            pannings (np.ndarray): Panning angles (radians).
            delays (np.ndarray): Delays; each one scales a linear phase ramp
              over the F rows.
            F (int): Size of the first axis of the result.
            I (int): Size of the last axis of the result. Only indices 0 and
              1 are written, so it must be at least 2.

        Returns:
            np.ndarray: Complex array of shape
            ``(F, len(pannings), len(delays), I)``.
        """
        panning_delay_set = np.zeros(
            (F, len(pannings), len(delays), I), dtype='complex'
        )
        # Neither of these depends on the delay, so compute them once.
        ramp = np.linspace(0, 0.5, F)
        first_channel = np.outer(np.ones(F), np.cos(pannings))
        sin_pannings = np.sin(pannings)

        for i, delay in enumerate(delays):
            phase_change = np.exp(-1j * 2 * np.pi * ramp * delay)
            panning_delay_set[:, :, i, 0] = first_channel
            panning_delay_set[:, :, i, 1] = np.outer(phase_change, sin_pannings)
        return panning_delay_set

    def create_projections(self):
        """
        Projects the mixture STFT onto every vector of the projection set.

        Returns:
            tuple: ``(V, complex_projections)`` where ``complex_projections``
            is the complex numpy array of shape
            ``(F,) + projection_set.shape[1:-1] + (T,)`` and ``V`` is
            ``abs(complex_projections) ** alpha`` as a tensor on
            ``self.device``.
        """
        F = self.stft.shape[0]
        inner_dim = self.projection_set.shape[-1]
        V = (
            self.projection_set.reshape(F, -1, inner_dim)
            @ self.stft.reshape(F, -1, inner_dim).transpose(0, 2, 1)
        )
        V = V.reshape((F,) + self.projection_set.shape[1:-1] + (-1,))
        return self._convert_to_tensor(np.abs(V) ** self.alpha), V

    def _convert_to_tensor(self, data):
        """Wraps a numpy array as a torch tensor and moves it to ``self.device``."""
        tensor = torch.from_numpy(data)
        return tensor.to(self.device)

    @staticmethod
    def _convert_to_numpy(data):
        """Moves a torch tensor to the CPU and returns it as a numpy array."""
        array = data.cpu().data.numpy()
        return array

    def create_k_matrix(self):
        """
        Projects every vector of the location set onto every vector of the
        projection set.

        Returns:
            torch.Tensor: ``abs(.) ** alpha`` of those projections, shaped
            ``(F,) + projection_set.shape[1:-1] + location_set.shape[1:-1]``.
        """
        F = self.stft.shape[0]
        inner_dim = self.location_set.shape[-1]
        K = (
            self.projection_set.reshape(F, -1, inner_dim)
            @ self.location_set.reshape(F, -1, inner_dim).transpose(0, 2, 1)
        )
        K = np.abs(K.reshape(
            (F,) + self.projection_set.shape[1:-1]
            + self.location_set.shape[1:-1]
        )) ** self.alpha
        return self._convert_to_tensor(K)

    def initialize_parameters(self):
        """
        Draws random non-negative starting values for P and Q.

        Returns:
            tuple: ``(P, Q)`` tensors on ``self.device``. P has shape
            ``(F, T, num_sources)`` and Q has shape
            ``location_set.shape[1:3] + (num_sources,)``.
        """
        # NOTE(review): ``self.P`` (built from ``estimates`` in ``__init__``)
        # is not consulted here -- confirm whether it should seed P.
        F, T, _ = self.stft.shape
        P = np.abs(np.random.randn(F, T, self.num_sources))
        Q = np.abs(np.random.randn(*self.location_set.shape[1:3], self.num_sources))
        return self._convert_to_tensor(P), self._convert_to_tensor(Q)

    def _update_sigma(self, P, Q, KQ):
        """
        Evaluates the model of the projected mixture,
        ``sigma = sum_j KQ[..., j] * P[..., j]``.

        Args:
            P (torch.Tensor): Shape ``(F, T, J)``.
            Q (torch.Tensor): Unused; kept for interface compatibility.
            KQ (torch.Tensor): Shape ``(F, M1, M2, J)`` as returned by
              ``_get_kq``.

        Returns:
            torch.Tensor: Shape ``(F, M1, M2, T)``.
        """
        F = self.stft.shape[0]
        inner_dim = KQ.shape[-1]
        sigma = KQ.reshape(F, -1, inner_dim) @ P.transpose(2, 1)
        sigma = sigma.reshape(
            P.shape[0], KQ.shape[1], KQ.shape[2], P.shape[1])
        return sigma

    def _update_P(self, P, sigma, KQ):
        """
        Multiplicative update for P.

        Returns:
            torch.Tensor: ``P`` scaled element-wise by the ratio
            ``(KQ^T (sigma^(beta-2) * V) + eps) / (KQ^T sigma^(beta-1) + eps)``.
        """
        F = self.stft.shape[0]
        temps = [
            (sigma ** (self.beta - 2)) * self.V,
            sigma ** (self.beta - 1)
        ]
        inner_dim = KQ.shape[1] * KQ.shape[2]
        # Hoisted: the same transposed KQ is used for numerator and denominator.
        KQ_t = KQ.reshape(F, inner_dim, -1).transpose(2, 1)
        P_num_denom = [
            self.eps + (KQ_t @ temp.reshape(F, inner_dim, -1))
            for temp in temps
        ]
        P_update = (P_num_denom[0] / P_num_denom[1]).transpose(2, 1)
        return P * P_update

    def _update_Q(self, P, sigma, Q):
        """
        Multiplicative update for Q.

        Numerator and denominator are accumulated over the first (F) axis.
        ``eps`` is added to both, as in ``_update_P``, so that a zero
        denominator cannot put NaN/inf into Q.

        Returns:
            torch.Tensor: Updated Q, same shape as the input ``Q``.
        """
        F = self.stft.shape[0]
        temps = [
            (sigma ** (self.beta - 2)) * self.V,
            sigma ** (self.beta - 1)
        ]
        inner_dim = self.K.shape[1] * self.K.shape[2]
        K_t = self.K.reshape(F, inner_dim, -1).transpose(2, 1)
        Q_num_denom = [
            K_t @ (temp.reshape(F, -1, P.shape[1]) @ P)
            for temp in temps
        ]
        Q_num_denom = [
            self.eps + x.reshape(F, *Q.shape).sum(dim=0)
            for x in Q_num_denom
        ]
        Q_update = Q_num_denom[0] / Q_num_denom[1]
        return Q * Q_update

    def _get_kq(self, Q):
        """
        Contracts ``self.K`` with ``Q`` over the two location axes.

        Returns:
            torch.Tensor: Shape ``(F, K.shape[1], K.shape[2], Q.shape[-1])``.
        """
        F = self.stft.shape[0]
        inner_dim = Q.shape[0] * Q.shape[1]
        KQ = self.K.reshape(-1, inner_dim) @ Q.reshape(inner_dim, -1)
        KQ = KQ.reshape(
            F, self.K.shape[1], self.K.shape[2], Q.shape[-1])
        return KQ

    def _update(self, P, Q):
        """Runs one iteration: update P, recompute sigma, then update Q."""
        KQ = self._get_kq(Q)
        sigma = self._update_sigma(P, Q, KQ)
        P = self._update_P(P, sigma, KQ)
        # sigma depends on P, so refresh it before touching Q.
        sigma = self._update_sigma(P, Q, KQ)
        Q = self._update_Q(P, sigma, Q)
        return P, Q

    def run(self):
        """
        Fits P and Q, then reconstructs each source from the projected
        mixture.

        Stores the reconstructions with axes 1 and 2 swapped in
        ``self.reconstructions`` (the layout ``make_audio_signals`` indexes).

        Returns:
            np.ndarray: The reconstructions *before* that axis swap.
        """
        P, Q = self.initialize_parameters()
        for _ in range(self.num_iterations):
            P, Q = self._update(P, Q)

        KQ = self._get_kq(Q)
        KQ = KQ.reshape(KQ.shape[0], -1, 1, KQ.shape[-1])
        # Contribution of source j to each projection is KQ_j * P_j; summed
        # over j this is exactly the sigma of ``_update_sigma``. (This used
        # to divide by P, which contradicts that model and yields inf/NaN
        # wherever P is zero.)
        sigma_j = KQ * P[:, None, ...]
        # Normalise so the per-source weights sum to (almost) one.
        sigma_j = sigma_j / (self.eps + sigma_j.sum(dim=-1)[..., None])
        sigma_j = self._convert_to_numpy(sigma_j)

        # Flatten the panning and delay axes into one. Using -1 gives the
        # same result on the first call and is a no-op afterwards, so
        # ``run`` can be called more than once.
        self.projection_set = self.projection_set.reshape(
            self.projection_set.shape[0],
            -1,
            self.projection_set.shape[-1]
        )
        self.inverse_projection_set = np.linalg.pinv(self.projection_set)

        # Weight the complex projections of the mixture per source ...
        cf_j = (
            (self.projection_set @ self.stft.transpose(0, 2, 1))[..., None]
            * sigma_j
        )
        shape = cf_j.shape
        # ... and map them back with the pseudo-inverse.
        reconstructions = (
            self.inverse_projection_set
            @ cf_j.reshape(cf_j.shape[0], cf_j.shape[1], -1)
        )
        reconstructions = reconstructions.reshape(
            shape[0], self.stft.shape[-1], -1, shape[-1]
        )
        self.reconstructions = np.swapaxes(reconstructions, 1, 2)
        return reconstructions

    def make_audio_signals(self):
        """
        Turns every slice along the last axis of ``self.reconstructions``
        into an audio signal.

        Returns:
            list: One copy of ``self.audio_signal`` per source, made with
            ``make_copy_with_stft_data`` and with ``istft()`` applied.
        """
        estimates = []
        for j in range(self.reconstructions.shape[-1]):
            estimate_stft = self.reconstructions[..., j]
            estimate = self.audio_signal.make_copy_with_stft_data(estimate_stft)
            estimate.istft()
            estimates.append(estimate)
        return estimates