# Source code for nussl.separation.base.deep_mixin

import torch

from ...ml import SeparationModel
from ...datasets import transforms as tfm

# Transform types that `DeepMixin._get_transforms` filters out of the
# transforms stored in a loaded model's metadata. Kept as a tuple so it can be
# passed directly as the second argument of `isinstance`.
# NOTE(review): presumably these only matter while training (excerpting,
# loss weights, source summing, caching) -- confirm against the transforms'
# own documentation.
OMITTED_TRANSFORMS = (
    tfm.GetExcerpt,
    tfm.MagnitudeWeights,
    tfm.SumSources,
    tfm.Cache
)


class DeepMixin:
    """
    Mixin that lets a separation class load a saved ``SeparationModel`` and
    build the input data dictionary for it.

    The methods here read ``self.audio_signal`` and ``self.channel_dim``,
    neither of which is set in this class; the class this is mixed into has
    to provide them. ``load_model`` sets ``self.device``, ``self.model``,
    ``self.config``, ``self.metadata`` and ``self.transform``.
    """

    def load_model(self, model_path, device='cpu'):
        """
        Loads the model at specified path `model_path`. Uses GPU only if
        one is requested and CUDA is available.

        Args:
            model_path (str): path to model saved as SeparationModel.
            device (str or torch.Device): device to put the model on.
              Defaults to 'cpu'. Whenever CUDA is not available the model is
              put on 'cpu' regardless of what is passed here.

        Returns:
            model (SeparationModel): Loaded model, nn.Module, in eval mode.
            metadata (dict): metadata associated with model, used for making
            the input data into the model. Empty if the saved file has no
            'metadata' entry.

        Raises:
            KeyError: if the saved file has no 'config' or 'state_dict'
              entry.
        """
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code. Only load model files that come from a trusted source.
        model_dict = torch.load(model_path, map_location='cpu')
        config = model_dict['config']
        model = SeparationModel(config)
        model.load_state_dict(model_dict['state_dict'])

        # Fall back to the CPU whenever CUDA is not usable on this machine.
        device = device if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model = model.to(device).eval()
        self.config = config
        self.metadata = model_dict.get('metadata', {})
        # Metadata is optional, so 'transforms' may be missing as well;
        # _get_transforms maps a missing (None) transform to None.
        self.transform = self._get_transforms(
            self.metadata.get('transforms'))
        return self.model, self.metadata

    @staticmethod
    def _get_transforms(loaded_tfm):
        """
        Look through the loaded transforms and omits any that are in
        `OMITTED_TRANSFORMS`.

        Args:
            loaded_tfm (Transform): A Transform from
              `nussl.datasets.transforms`, or None.

        Returns:
            Transform: If the transform was a Compose, this returns a new
            Compose that omits the transforms listed in
            `OMITTED_TRANSFORMS`. A single transform is returned unchanged,
            unless it is itself one of `OMITTED_TRANSFORMS`, in which case
            None is returned. None is passed through as None.
        """
        if isinstance(loaded_tfm, tfm.Compose):
            kept = [
                _tfm for _tfm in loaded_tfm.transforms
                if not isinstance(_tfm, OMITTED_TRANSFORMS)
            ]
            return tfm.Compose(kept)
        if isinstance(loaded_tfm, OMITTED_TRANSFORMS):
            return None
        return loaded_tfm

    def _get_input_data_for_model(self):
        """
        Sets up the audio signal with the appropriate STFT parameters and
        runs it through the transform found in the metadata.

        Reads 'sample_rate', 'stft_params' and 'num_channels' from
        `self.metadata`. The audio signal is resampled in place when the
        metadata sample rate is not None and differs from the signal's.
        The result is also stored on `self.input_data`.

        Returns:
            dict: Data dictionary to pass into the model. Every tensor in it
            gets a leading batch dimension, is moved to `self.device` and
            cast to float.
        """
        target_rate = self.metadata['sample_rate']
        if (target_rate is not None
                and self.audio_signal.sample_rate != target_rate):
            self.audio_signal.resample(target_rate)

        self.audio_signal.stft_params = self.metadata['stft_params']
        self.audio_signal.stft()

        data = {'mix': self.audio_signal}
        # _get_transforms returns None when there is nothing left to apply;
        # in that case the data dictionary is passed on untouched.
        if self.transform is not None:
            data = self.transform(data)

        single_channel = self.metadata['num_channels'] == 1
        for key in data:
            if not torch.is_tensor(data[key]):
                continue
            tensor = data[key].unsqueeze(0).to(self.device).float()
            if single_channel:
                # then each channel is processed indep: swap the channel
                # dimension with the batch dimension.
                tensor = tensor.transpose(0, self.channel_dim)
            data[key] = tensor

        self.input_data = data
        return self.input_data