import torch
from ..base import SeparationBase, DeepMixin, SeparationException


class DeepAudioEstimation(SeparationBase, DeepMixin):
"""
Separates an audio signal using a model that produces separated sources directly
in the waveform domain. It expects that the model outputs a dictionary where one
of the keys is 'audio'. This uses the `DeepMixin` class to load the model
and set the audio signal's parameters to be appropriate for the model.
Args:
input_audio_signal: (AudioSignal`) An AudioSignal object containing the
mixture to be separated.
model_path (str, optional): Path to the model that will be used. Can be None,
so that you can initialize a class and load the model later.
Defaults to None.
device (str, optional): Device to put the model on. Defaults to 'cpu'.
**kwargs (dict): Keyword arguments for MaskSeparationBase.
"""

    def __init__(self, input_audio_signal, model_path=None, device='cpu', **kwargs):
        if model_path is not None:
            self.load_model(model_path, device=device)
        super().__init__(input_audio_signal, **kwargs)
        self.model_output = None

        # audio channel dimension in an audio model
        self.channel_dim = 1

    def forward(self):
        input_data = self._get_input_data_for_model()
        with torch.no_grad():
            output = self.model(input_data)
            if 'audio' not in output:
                raise SeparationException(
                    "This model is not a deep audio estimation model! "
                    "Did not find 'audio' key in output dictionary.")
            audio = output['audio']
            # if the model was trained on mono audio, the mixture's channels
            # were moved into the batch dimension on the way in, so swap the
            # batch and channel dims back before dropping the batch dim
            if self.metadata['num_channels'] == 1:
                audio = audio.transpose(0, self.channel_dim)
            audio = audio.squeeze(0)
            audio = audio.cpu().data.numpy()
        self.model_output = output
        return audio
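
    # After `forward`, `self.model_output` holds the model's full output
    # dictionary, and the returned array is laid out roughly as
    # (num_channels, num_samples, num_sources); `make_audio_signals` below
    # assumes the estimated sources live on the last axis.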

    def run(self, audio=None):
        if audio is None:
            audio = self.forward()
        self.audio = audio
        return self.audio
def make_audio_signals(self):
estimates = []
for i in range(self.audio.shape[-1]):
_estimate = self.audio_signal.make_copy_with_audio_data(
self.audio[..., i])
estimates.append(_estimate)
return estimates
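

# A minimal usage sketch, assuming the package is importable as `nussl` and a
# trained end-to-end checkpoint is available; 'mixture.wav' and
# 'checkpoints/model.pth' are hypothetical paths:
#
#     import nussl
#
#     mix = nussl.AudioSignal('mixture.wav')
#     separator = nussl.separation.deep.DeepAudioEstimation(
#         mix, model_path='checkpoints/model.pth', device='cpu')
#     estimates = separator()  # run() + make_audio_signals() via __call__
#     for i, est in enumerate(estimates):
#         est.write_audio_to_file(f'source{i}.wav')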