{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Training deep models in *nussl*\n", "==============================\n", "\n", "*nussl* has a tightly integrated deep learning pipeline for computer audition,\n", "with a focus on source separation. This pipeline includes:\n", "\n", "- Existing source separation architectures (Deep Clustering, Mask Inference, etc),\n", "- Building blocks for creating new architectures (Recurrent Stacks, Embedding spaces, Mask Layers,\n", " Mel Projection Layers, etc),\n", "- Handling data and common data sets (WSJ, MUSDB, etc),\n", "- Training architectures via an easy to use API powered by [PyTorch Ignite](\n", "https://pytorch.org/ignite/index.html),\n", "- Evaluating model performance (SDR, SI-SDR, etc),\n", "- Using the models on new audio signals for inference,\n", "- Storing and distributing trained models via the [External File Zoo](\n", "http://nussl.ci.northwestern.edu/).\n", "\n", "This tutorial will walk you through *nussl*'s model training capabilities on a simple\n", "synthetic dataset for illustration purposes. While *nussl* has support for a broad variety of\n", "models, we will focus on straight-forward mask inference networks." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Do our imports and setup for this tutorial.\n", "import os\n", "import json\n", "import logging\n", "import copy\n", "import tempfile\n", "import glob\n", "import time\n", "import shutil\n", "from concurrent.futures import ThreadPoolExecutor\n", "\n", "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import tqdm\n", "\n", "import nussl\n", "\n", "start_time = time.time()\n", "\n", "# seed this notebook\n", "# (this seeds python's random, np.random, and torch.random)\n", "nussl.utils.seed(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SeparationModel\n", "---------------\n", "\n", "At the heart of *nussl*'s deep learning pipeline is the SeparationModel class.\n", "SeparationModel takes in a description of the model architecture and instantiates it.\n", "Model architectures are described via a dictionary. A model architecture has three\n", "parts: the building blocks, or *modules*, how the building blocks are wired together,\n", "and the outputs of the model.\n", "\n", "### Modules ###\n", "\n", "Let's take a look at how a simple architecture is described. This model will be a single\n", "linear layer that estimates the spectra for 3 sources for every frame in the STFT." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# define the building blocks\n", "num_features = 129 # number of frequency bins in STFT\n", "num_sources = 3 # how many sources to estimate\n", "mask_activation = 'sigmoid' # activation function for masks\n", "num_audio_channels = 1 # number of audio channels\n", "\n", "modules = {\n", "    'mix_magnitude': {},\n", "    'my_log_spec': {\n", "        'class': 'AmplitudeToDB'\n", "    },\n", "    'my_norm': {\n", "        'class': 'BatchNorm',\n", "    },\n", "    'my_mask': {\n", "        'class': 'Embedding',\n", "        'args': {\n", "            'num_features': num_features,\n", "            'hidden_size': num_features,\n", "            'embedding_size': num_sources,\n", "            'activation': mask_activation,\n", "            'num_audio_channels': num_audio_channels,\n", "            'dim_to_embed': [2, 3] # embed the frequency dimension (2) for all audio channels (3)\n", "        }\n", "    },\n", "    'my_estimates': {\n", "        'class': 'Mask',\n", "    },\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The lines above define the building blocks, or *modules*, of the SeparationModel. \n", "There are five building blocks:\n", "\n", "- `mix_magnitude`, the input to the model (this key is not user-definable),\n", "- `my_log_spec`, a \"layer\" that converts the spectrogram to dB space,\n", "- `my_norm`, a BatchNorm normalization layer,\n", "- `my_mask`, which outputs the resultant mask, and\n", "- `my_estimates`, which applies the mask to the mixture to produce the source estimates. \n", "\n", "Each module in the dictionary has a key and a\n", "value. The key tells SeparationModel the user-definable name of that layer in our architecture.\n", "For example, `my_log_spec` will be the name of a building block. The value is\n", "also a dictionary with two keys: `class` and `args`. `class` tells SeparationModel\n", "which class implements this module. `args` tells SeparationModel what the \n", "arguments to the class should be when instantiating it. 
Finally, if the dictionary\n", "that the key points to is empty, then it is assumed to be something that comes from\n", "the input dictionary to the model. Note that we haven't fully defined the model yet! We still\n", "need to determine how these modules are put together.\n", "\n", "So where does the code for each of these classes live? The code for these modules\n", "is in `nussl.ml.modules`. The existing modules in *nussl* are as follows:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "lines_to_end_of_cell_marker": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nussl.ml.modules contents:\n", "--------------------------\n", "AmplitudeToDB\n", "BatchNorm\n", "Concatenate\n", "ConvolutionalStack2D\n", "DualPath\n", "DualPathBlock\n", "Embedding\n", "Expand\n", "FilterBank\n", "GaussianMixtureTorch\n", "InstanceNorm\n", "LayerNorm\n", "LearnedFilterBank\n", "Mask\n", "MelProjection\n", "RecurrentStack\n", "STFT\n", "ShiftAndScale\n", "Split\n", "blocks\n", "filter_bank\n" ] } ], "source": [ "def print_existing_modules():\n", " excluded = ['checkpoint', 'librosa', 'nn', 'np', 'torch', 'warnings']\n", " print('nussl.ml.modules contents:')\n", " print('--------------------------')\n", " existing_modules = [x for x in dir(nussl.ml.modules) if\n", " x not in excluded and not x.startswith('__')]\n", " print('\\n'.join(existing_modules))\n", "\n", "\n", "print_existing_modules()" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 2 }, "source": [ "Descriptions of each of these modules and their arguments can be found in the API docs.\n", "In the model we have described above, we have used: \n", "\n", "1. `AmplitudeToDB` to compute log-magnitude spectrograms from the input `mix_magnitude`.\n", "2. `BatchNorm` to normalize each spectrogram input by the mean and standard\n", " deviation of all the data (one mean/std for the entire spectrogram, not per feature).\n", "3. 
`Embedding` to embed each 129-dimensional frame into 3*129-dimensional space with a\n", " sigmoid activation.\n", "4. `Mask` to take the output of the embedding and element-wise multiply it by the input\n", " `mix_magnitude` to generate source estimates.\n", " \n", "### Connections ###\n", "\n", "Now we have to define the next part of SeparationModel - how the modules are wired together.\n", "We do this by defining the `connections` of the model." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "# define the topology\n", "connections = [\n", "    ['my_log_spec', ['mix_magnitude', ]],\n", "    ['my_norm', ['my_log_spec', ]],\n", "    ['my_mask', ['my_norm', ]],\n", "    ['my_estimates', ['my_mask', 'mix_magnitude']]\n", "]" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 2 }, "source": [ "`connections` is a list of lists. Each item of `connections` has two elements. The first\n", "element contains the name of our module (defined in `modules`). The second element\n", "contains the arguments that will go into the module defined in the first element.\n", "\n", "So for example, `my_log_spec`, which corresponds to the `AmplitudeToDB`\n", "class, takes in `mix_magnitude`. In the forward pass, `mix_magnitude` corresponds to\n", "the data in the input dictionary. The output of `my_log_spec` (a\n", "log-magnitude spectrogram) is passed to the module named `my_norm` (a `BatchNorm`\n", "layer). This output is then passed to the `my_mask` module, which\n", "constructs the masks using an `Embedding` class. Finally, the source estimates\n", "are constructed by passing both `my_mask` and `mix_magnitude` to the `my_estimates`\n", "module, which uses a `Mask` class.\n", "\n", "Complex forward passes can be defined via these connections. Connections can be\n", "even more detailed. Modules can take in keyword arguments by making the second\n", "element a dictionary. 
If a module outputs a dictionary, then specific outputs\n", "can be referenced in the connections via `module_name:key_in_dictionary`. For\n", "example, `nussl.ml.modules.GaussianMixtureTorch` (which is a differentiable\n", "GMM unfolded on some input data) outputs a dictionary with\n", "the following keys: `resp, log_prob, means, covariance, prior`. If this module\n", "were named `gmm`, then these outputs could be used in the second element via\n", "`gmm:means`, `gmm:resp`, `gmm:covariance`, etc.\n", "\n", "### Output and forward pass ###\n", "\n", "Next, models have to actually output some data to be used later on. Let's have\n", "this model output the keys for `my_estimates` and `my_mask` (as defined in our `modules` dict, above) by doing this:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "# define the outputs\n", "output = ['my_estimates', 'my_mask']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can use these outputs directly or you can use them as a part of a \n", "larger deep learning pipeline. SeparationModel can be, for example, a\n", "first step before you do something more complicated with the output\n", "that doesn't fit cleanly into how SeparationModels are built.\n", "\n", "### Putting it all together ###\n", "\n", "Finally, let's put it all together in one config dictionary. The dictionary\n", "must have the following keys to be valid: `modules`, `connections`, and \n", "`output`. If any of these keys is missing, then SeparationModel will raise\n", "an error." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"modules\": {\n", " \"mix_magnitude\": {},\n", " \"my_log_spec\": {\n", " \"class\": \"AmplitudeToDB\"\n", " },\n", " \"my_norm\": {\n", " \"class\": \"BatchNorm\"\n", " },\n", " \"my_mask\": {\n", " \"class\": \"Embedding\",\n", " \"args\": {\n", " \"num_features\": 129,\n", " \"hidden_size\": 129,\n", " \"embedding_size\": 3,\n", " \"activation\": \"sigmoid\",\n", " \"num_audio_channels\": 1,\n", " \"dim_to_embed\": [\n", " 2,\n", " 3\n", " ]\n", " }\n", " },\n", " \"my_estimates\": {\n", " \"class\": \"Mask\"\n", " }\n", " },\n", " \"connections\": [\n", " [\n", " \"my_log_spec\",\n", " [\n", " \"mix_magnitude\"\n", " ]\n", " ],\n", " [\n", " \"my_norm\",\n", " [\n", " \"my_log_spec\"\n", " ]\n", " ],\n", " [\n", " \"my_mask\",\n", " [\n", " \"my_norm\"\n", " ]\n", " ],\n", " [\n", " \"my_estimates\",\n", " [\n", " \"my_mask\",\n", " \"mix_magnitude\"\n", " ]\n", " ]\n", " ],\n", " \"output\": [\n", " \"my_estimates\",\n", " \"my_mask\"\n", " ]\n", "}\n" ] } ], "source": [ "# put it all together\n", "config = {\n", " 'modules': modules,\n", " 'connections': connections,\n", " 'output': output\n", "}\n", "\n", "print(json.dumps(config, indent=2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's load this config into SeparationModel and print the model\n", "architecture:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SeparationModel(\n", " (layers): ModuleDict(\n", " (my_estimates): Mask()\n", " (my_log_spec): AmplitudeToDB()\n", " (my_mask): Embedding(\n", " (linear): Linear(in_features=129, out_features=387, bias=True)\n", " )\n", " (my_norm): BatchNorm(\n", " (batch_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", ")\n", "Number of 
parameters: 50312\n" ] } ], "source": [ "model = nussl.ml.SeparationModel(config)\n", "print(model)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 2 }, "source": [ "Now let's put some random data through it, with the expected size." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { "ename": "ValueError", "evalue": "Not all keys present in data! Needs mix_magnitude", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# so: batch size is 1, 400 frames, 129 frequencies, and 1 audio channel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmix_magnitude\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m129\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmix_magnitude\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/.conda/envs/nussl-refactor/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Dropbox/research/nussl_refactor/nussl/ml/networks/separation_model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
104\u001b[0m raise ValueError(\n\u001b[0;32m--> 105\u001b[0;31m f'Not all keys present in data! Needs {\", \".join(self.input)}')\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: Not all keys present in data! Needs mix_magnitude" ] } ], "source": [ "# The expected shape is: (batch_size, n_frames, n_frequencies, n_channels)\n", "# so: batch size is 1, 400 frames, 129 frequencies, and 1 audio channel\n", "mix_magnitude = torch.rand(1, 400, 129, 1)\n", "model(mix_magnitude)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 2 }, "source": [ "Uh oh! Putting in the data directly resulted in an error. This is because \n", "SeparationModel expects a *dictionary*. The dictionary must contain all of the\n", "input keys that were defined. Here that is `mix_magnitude`. So let's try \n", "again:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "mix_magnitude = torch.rand(1, 400, 129, 1)\n", "data = {'mix_magnitude': mix_magnitude}\n", "output = model(data)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 2 }, "source": [ "Now we have passed the data through the model. Note a few things here:\n", "\n", "1. The tensor passed through the model had the following shape:\n", " `(n_batch, sequence_length, num_frequencies, num_audio_channels)`. This is\n", " different from how STFTs for an AudioSignal are shaped. Those are shaped as:\n", " `(num_frequencies, sequence_length, num_audio_channels)`. We added a batch\n", " dimension here, and the frequency and sequence length dimensions\n", " were swapped. 
This is because recurrent networks are a popular way to process\n", " spectrograms, and these expect the sequence length to come right after the\n", " batch dimension (and operate more efficiently that way).\n", "2. The key in the dictionary had to match what we put in the configuration\n", " before.\n", "3. We embedded *both* the channel dimension (3) and the frequency dimension (2)\n", " when building up the configuration.\n", "\n", "Now let's take a look at what's in the output!" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { "data": { "text/plain": [ "dict_keys(['my_estimates', 'my_mask'])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.keys()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are two keys as expected: `my_estimates` and `my_mask`. They both have the\n", "same shape as `mix_magnitude` with one additional dimension:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 400, 129, 1, 3]), torch.Size([1, 400, 129, 1, 3]))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output['my_estimates'].shape, output['my_mask'].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last dimension has size 3, which is the number of sources we're trying to\n", "separate. Let's look at the first source." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "i = 0\n", "plt.figure(figsize=(5, 5))\n", "plt.imshow(output['my_estimates'][0, ..., 0, i].T.cpu().data.numpy())\n", "plt.title(\"Source\")\n", "plt.show()\n", "\n", "plt.figure(figsize=(5, 5))\n", "plt.imshow(output['my_mask'][0, ..., 0, i].T.cpu().data.numpy())\n", "plt.title(\"Mask\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Not much to look at! " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving and loading a model ###\n", "\n", "Now let's save this model and load it back up." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dict_keys(['state_dict', 'config', 'nussl_version'])\n", "SeparationModel(\n", " (layers): ModuleDict(\n", " (my_estimates): Mask()\n", " (my_log_spec): AmplitudeToDB()\n", " (my_mask): Embedding(\n", " (linear): Linear(in_features=129, out_features=387, bias=True)\n", " )\n", " (my_norm): BatchNorm(\n", " (batch_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", ")\n", "Number of parameters: 50312\n" ] } ], "source": [ "with tempfile.NamedTemporaryFile(suffix='.pth', delete=True) as f:\n", " loc = model.save(f.name)\n", " reloaded_dict = torch.load(f.name)\n", "\n", " print(reloaded_dict.keys())\n", "\n", " new_model = nussl.ml.SeparationModel(reloaded_dict['config'])\n", " new_model.load_state_dict(reloaded_dict['state_dict'])\n", "\n", " print(new_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When models are saved, both the config AND the weights are saved. Both of these can be easily\n", "loaded back into a new SeparationModel object." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Custom modules\n", "--------------\n", "\n", "There's also straightforward support for *custom* modules that don't \n", "exist in *nussl* but rather exist in the end-user code. These can be\n", "registered with SeparationModel easily. Let's build a custom module\n", "and register it with a copy of our existing model. Let's make this \n", "module a lambda, which takes in some arbitrary function and runs \n", "it on the input. We'll call it LambdaLayer:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape is torch.Size([1, 400, 129, 1])\n" ] } ], "source": [ "class LambdaLayer(torch.nn.Module):\n", "    def __init__(self, func):\n", "        # initialize the Module machinery before setting attributes\n", "        super().__init__()\n", "        self.func = func\n", "\n", "    def forward(self, data):\n", "        return self.func(data)\n", "\n", "\n", "def print_shape(x):\n", "    print(f'Shape is {x.shape}')\n", "\n", "\n", "lamb = LambdaLayer(print_shape)\n", "output = lamb(mix_magnitude)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's put it into a copy of our model and update the connections so that it\n", "prints the shape after every layer." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Copy our previous modules and add our new Lambda class\n", "new_modules = copy.deepcopy(modules)\n", "new_modules['lambda'] = {\n", " 'class': 'LambdaLayer',\n", " 'args': {\n", " 'func': print_shape\n", " }\n", "}\n", "\n", "new_connections = [\n", " ['my_log_spec', ['mix_magnitude', ]],\n", " ['lambda', ['mix_magnitude', ]],\n", " ['lambda', ['my_log_spec', ]],\n", " ['my_norm', ['my_log_spec', ]],\n", " ['lambda', ['my_norm', ]],\n", " ['my_mask', ['my_norm', ]],\n", " ['lambda', ['my_mask', ]],\n", " ['my_estimates', ['my_mask', 'mix_magnitude']],\n", " ['lambda', ['my_estimates', ]]\n", "]\n", "\n", "new_config = {\n", " 'modules': new_modules,\n", " 'connections': new_connections,\n", " 'output': ['my_estimates', 'my_mask']\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But right now, SeparationModel doesn't know about our LambdaLayer class! So,\n", "let's make it aware by registering the module with nussl:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nussl.ml.modules contents:\n", "--------------------------\n", "AmplitudeToDB\n", "BatchNorm\n", "Concatenate\n", "ConvolutionalStack2D\n", "DualPath\n", "DualPathBlock\n", "Embedding\n", "Expand\n", "FilterBank\n", "GaussianMixtureTorch\n", "InstanceNorm\n", "LambdaLayer\n", "LayerNorm\n", "LearnedFilterBank\n", "Mask\n", "MelProjection\n", "RecurrentStack\n", "STFT\n", "ShiftAndScale\n", "Split\n", "blocks\n", "filter_bank\n" ] } ], "source": [ "nussl.ml.register_module(LambdaLayer)\n", "print_existing_modules()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now LambdaLayer is a registered module! 
Let's build the SeparationModel and\n", "put some data through it:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape is torch.Size([1, 400, 129, 1])\n", "Shape is torch.Size([1, 400, 129, 1])\n", "Shape is torch.Size([1, 400, 129, 1])\n", "Shape is torch.Size([1, 400, 129, 1, 3])\n", "Shape is torch.Size([1, 400, 129, 1, 3])\n" ] } ], "source": [ "verbose_model = nussl.ml.SeparationModel(new_config)\n", "output = verbose_model(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see the outputs of the Lambda layer recurring after each connection.\n", "(**Note**: because we used a non-serializable argument (the function, `func`)\n", "to the LambdaLayer, this model won't save without special handling!)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alright, now let's see how to use some actual audio data with our model..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Handling data\n", "-------------\n", "\n", "As described in the datasets tutorial, the heart of *nussl* data handling\n", "is BaseDataset and its associated subclasses. We built a simple one in that\n", "tutorial that just produced random sine waves. 
Let's grab it again:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def make_sine_wave(freq, sample_rate, duration):\n", "    dt = 1 / sample_rate\n", "    x = np.arange(0.0, duration, dt)\n", "    x = np.sin(2 * np.pi * freq * x)\n", "    return x\n", "\n", "\n", "class SineWaves(nussl.datasets.BaseDataset):\n", "    def __init__(self, *args, num_sources=3, num_frequencies=20, **kwargs):\n", "        self.num_sources = num_sources\n", "        self.frequencies = np.random.choice(\n", "            np.arange(110, 4000, 100), num_frequencies,\n", "            replace=False)\n", "\n", "        super().__init__(*args, **kwargs)\n", "\n", "    def get_items(self, folder):\n", "        # ignore folder and return a list\n", "        # 100 items in this dataset\n", "        items = list(range(100))\n", "        return items\n", "\n", "    def process_item(self, item):\n", "        # we're ignoring ``items`` and making\n", "        # sums of random sine waves\n", "        sources = {}\n", "        freqs = np.random.choice(\n", "            self.frequencies, self.num_sources,\n", "            replace=False)\n", "        for i in range(self.num_sources):\n", "            freq = freqs[i]\n", "            _data = make_sine_wave(freq, self.sample_rate, 2)\n", "            # this is a helper function in BaseDataset for\n", "            # making an audio signal from data\n", "            signal = self._load_audio_from_array(_data)\n", "            signal.path_to_input_file = f'{item}.wav'\n", "            sources[f'sine{i}'] = signal * 1 / self.num_sources\n", "\n", "        mix = sum(sources.values())\n", "\n", "        metadata = {\n", "            'frequencies': freqs\n", "        }\n", "\n", "        output = {\n", "            'mix': mix,\n", "            'sources': sources,\n", "            'metadata': metadata\n", "        }\n", "        return output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a reminder, this dataset makes random mixtures of sine waves with fundamental frequencies\n", "between 110 Hz and 4000 Hz. Let's now set it up with appropriate STFT parameters that result\n", "in 129 frequencies in the spectrogram." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'frequencies': array([1610, 310, 1210])}\n" ] } ], "source": [ "nussl.utils.seed(0) # make sure this does the same thing each time\n", "\n", "# We're not reading data, so we can 'ignore' the folder\n", "folder = 'ignored'\n", "\n", "stft_params = nussl.STFTParams(window_length=256, hop_length=64)\n", "\n", "sine_wave_dataset = SineWaves(\n", " folder, sample_rate=8000, stft_params=stft_params\n", ")\n", "\n", "item = sine_wave_dataset[0]\n", "\n", "\n", "def visualize_and_embed(sources, y_axis='mel'):\n", " plt.figure(figsize=(10, 4))\n", " plt.subplot(111)\n", " nussl.utils.visualize_sources_as_masks(\n", " sources, db_cutoff=-60, y_axis=y_axis)\n", " plt.tight_layout()\n", " plt.show()\n", "\n", " nussl.play_utils.multitrack(sources, ext='.wav')\n", "\n", "\n", "visualize_and_embed(item['sources'])\n", "print(item['metadata'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the shape of the `mix` stft:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(129, 251, 1)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item['mix'].stft().shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great! There's 129 frequencies and 251 frames and 1 audio channel. To put it into our\n", "model though, we need the STFT in the right shape, and we also need some training data.\n", "Let's use some of *nussl*'s transforms to do this. Specifically, we'll use the\n", "`PhaseSensitiveSpectrumApproximation` and the `ToSeparationModel` transforms. We'll \n", "also use the `MagnitudeWeights` transform in case we want to use deep clustering loss\n", "functions." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['index', 'mix_magnitude', 'ideal_binary_mask', 'source_magnitudes', 'weights'])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "folder = 'ignored'\n", "stft_params = nussl.STFTParams(window_length=256, hop_length=64)\n", "tfm = nussl.datasets.transforms.Compose([\n", " nussl.datasets.transforms.PhaseSensitiveSpectrumApproximation(),\n", " nussl.datasets.transforms.MagnitudeWeights(),\n", " nussl.datasets.transforms.ToSeparationModel()\n", "])\n", "\n", "sine_wave_dataset = SineWaves(\n", " folder, sample_rate=8000, stft_params=stft_params,\n", " transform=tfm\n", ")\n", "\n", "# Let's inspect the 0th item from the dataset\n", "item = sine_wave_dataset[0]\n", "item.keys()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now the item has all the keys that SeparationModel needs. The `ToSeparationModel` transform set everything up for us: it set up the dictionary from `SineWaves.process_item()` exactly as we needed it. It swapped the frequency and sequence length dimension appropriately, and made them all torch Tensors:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([251, 129, 1])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item['mix_magnitude'].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We still need to add a batch dimension and make everything have float type\n", "though. 
So let's do that for each key whose value is a torch Tensor:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 251, 129, 1])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "for key in item:\n", "    if torch.is_tensor(item[key]):\n", "        item[key] = item[key].unsqueeze(0).float()\n", "\n", "item['mix_magnitude'].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can pass this through our model:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = model(item)\n", "\n", "i = 0\n", "plt.figure(figsize=(5, 5))\n", "plt.imshow(\n", " output['my_estimates'][0, ..., 0, i].T.cpu().data.numpy(),\n", " origin='lower')\n", "plt.title(\"Source\")\n", "plt.show()\n", "\n", "plt.figure(figsize=(5, 5))\n", "plt.imshow(\n", " output['my_mask'][0, ..., 0, i].T.cpu().data.numpy(),\n", " origin='lower')\n", "plt.title(\"Mask\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've now seen how to use *nussl* transforms, datasets, and SeparationModel\n", "together to make a forward pass. But so far our model does nothing practical; let's see how to train the model so it actually does something." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Closures and loss functions\n", "---------------------------\n", "\n", "*nussl* trains models via *closures*, which define the forward and backward passes for a\n", "model on a single batch. Closures use *loss functions* within them, which compute the \n", "loss on a single batch. There are a bunch of common loss functions already in *nussl*." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nussl.ml.train.loss contents:\n", "-----------------------------\n", "CombinationInvariantLoss\n", "DeepClusteringLoss\n", "KLDivLoss\n", "L1Loss\n", "MSELoss\n", "PermutationInvariantLoss\n", "SISDRLoss\n", "WhitenedKMeansLoss\n" ] } ], "source": [ "def print_existing_losses():\n", " excluded = ['nn', 'torch', 'combinations', 'permutations']\n", " print('nussl.ml.train.loss contents:')\n", " print('-----------------------------')\n", " existing_losses = [x for x in dir(nussl.ml.train.loss) if\n", " x not in excluded and not x.startswith('__')]\n", " print('\\n'.join(existing_losses))\n", "\n", "\n", "print_existing_losses()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to standard loss functions for spectrograms, like L1 Loss and MSE, there is also an SDR loss for time series audio, as well as permutation invariant versions of\n", "these losses for training things like speaker separation networks. See the API docs for more details on all of these loss functions. A closure uses these loss functions in a simple way. 
For example, here is the code for training a model with a closure:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from nussl.ml.train.closures import Closure\n", "from nussl.ml.train import BackwardsEvents\n", "\n", "\n", "class TrainClosure(Closure):\n", " \"\"\"\n", " This closure takes an optimization step on a SeparationModel object given a\n", " loss.\n", " \n", " Args:\n", " loss_dictionary (dict): Dictionary containing loss functions and specification.\n", " optimizer (torch Optimizer): Optimizer to use to train the model.\n", " model (SeparationModel): The model to be trained.\n", " \"\"\"\n", "\n", " def __init__(self, loss_dictionary, optimizer, model):\n", " super().__init__(loss_dictionary)\n", " self.optimizer = optimizer\n", " self.model = model\n", "\n", " def __call__(self, engine, data):\n", " self.model.train()\n", " self.optimizer.zero_grad()\n", "\n", " output = self.model(data)\n", "\n", " loss_ = self.compute_loss(output, data)\n", " loss_['loss'].backward()\n", " engine.fire_event(BackwardsEvents.BACKWARDS_COMPLETED)\n", " self.optimizer.step()\n", " loss_ = {key: loss_[key].item() for key in loss_}\n", "\n", " return loss_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, this closure takes some data and puts it through the model, then calls\n", "`self.compute_loss` on the result, fires an event on the ignite `engine`, and then steps the optimizer on the loss. This is a standard PyTorch training loop. The magic here is happening in `self.compute_loss`, which comes from the\n", "parent class `Closure`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loss dictionary ###\n", "\n", "The parent class `Closure` takes a loss dictionary which defines the losses that get \n", "computed on the output of the model. 
The loss dictionary has the following format:\n", "\n", " loss_dictionary = {\n", " 'LossClassName': {\n", " 'weight': [how much to weight the loss in the sum, defaults to 1],\n", " 'keys': [key mapping items in dictionary to arguments to loss],\n", " 'args': [any positional arguments to the loss class],\n", " 'kwargs': [keyword arguments to the loss class],\n", " }\n", " }\n", " \n", "For example, one possible loss could be:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "loss_dictionary = {\n", " 'DeepClusteringLoss': {\n", " 'weight': .2,\n", " },\n", " 'PermutationInvariantLoss': {\n", " 'weight': .8,\n", " 'args': ['L1Loss']\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will apply the deep clustering loss and a permutation-invariant L1 loss to the output\n", "of the model. So, how does the model know what to compare? Each loss function is a \n", "class in *nussl*, and each class has an attribute called `DEFAULT_KEYS`. This attribute\n", "tells the Closure how to use the forward pass of the loss function. For example, this is\n", "the code for the L1 Loss:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class L1Loss(nn.L1Loss):\n", " DEFAULT_KEYS = {'estimates': 'input', 'source_magnitudes': 'target'}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[L1Loss](https://pytorch.org/docs/stable/nn.html?highlight=l1%20loss#torch.nn.L1Loss) \n", "is defined in PyTorch and has the following example for its forward pass:\n", "\n", " >>> loss = nn.L1Loss()\n", " >>> input = torch.randn(3, 5, requires_grad=True)\n", " >>> target = torch.randn(3, 5)\n", " >>> output = loss(input, target)\n", " >>> output.backward()\n", " \n", "The arguments to the function are `input` and `target`. 
So the mapping from the dictionary\n", "provided by our dataset and model jointly is to use `my_estimates` (like we defined above) as the input and \n", "`source_magnitudes` (what we are trying to match) as the target. This results in \n", "the `DEFAULT_KEYS` you see above. Alternatively, you can pass the mapping between\n", "the dictionary and the arguments to the loss function directly into the loss dictionary\n", "like so:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "loss_dictionary = {\n", " 'L1Loss': {\n", " 'weight': 1.0,\n", " 'keys': {\n", " 'my_estimates': 'input',\n", " 'source_magnitudes': 'target',\n", " }\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great, now let's use this loss dictionary in a Closure and see what happens." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(L1Loss(),\n", " 1.0,\n", " {'my_estimates': 'input', 'source_magnitudes': 'target'},\n", " 'L1Loss')]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "closure = nussl.ml.train.closures.Closure(loss_dictionary)\n", "closure.losses" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The closure was instantiated with the losses. Calling `closure.compute_loss` results\n", "in the following:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "L1Loss tensor(0.0037, grad_fn=)\n", "loss tensor(0.0037, grad_fn=)\n" ] } ], "source": [ "output = model(item)\n", "loss_output = closure.compute_loss(output, item)\n", "for key, val in loss_output.items():\n", " print(key, val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output is a dictionary with the `loss` item corresponding to the total\n", "(summed) loss and the other keys corresponding to the individual losses." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom loss functions ###\n", "\n", "Loss functions can be registered with the Closure in the same way that\n", "modules are registered with SeparationModel:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nussl.ml.train.loss contents:\n", "-----------------------------\n", "CombinationInvariantLoss\n", "DeepClusteringLoss\n", "KLDivLoss\n", "L1Loss\n", "MSELoss\n", "MeanDifference\n", "PermutationInvariantLoss\n", "SISDRLoss\n", "WhitenedKMeansLoss\n" ] } ], "source": [ "class MeanDifference(torch.nn.Module):\n", " DEFAULT_KEYS = {'my_estimates': 'input', 'source_magnitudes': 'target'}\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def forward(self, input, target):\n", " return torch.abs(input.mean() - target.mean())\n", "\n", "\n", "nussl.ml.register_loss(MeanDifference)\n", "print_existing_losses()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now this loss can be used in a closure:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MeanDifference tensor(0.0012, grad_fn=)\n", "loss tensor(0.0012, grad_fn=)\n" ] } ], "source": [ "new_loss_dictionary = {\n", " 'MeanDifference': {}\n", "}\n", "\n", "new_closure = nussl.ml.train.closures.Closure(new_loss_dictionary)\n", "new_closure.losses\n", "\n", "output = model(item)\n", "loss_output = new_closure.compute_loss(output, item)\n", "for key, val in loss_output.items():\n", " print(key, val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optimizing the model ###\n", "\n", "We now have a loss. We can then put it backwards through the model and\n", "take a step forward on the model with an optimizer. 
Let's define\n", "an optimizer (we'll use Adam), and then use it to take a step on\n", "the model:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'L1Loss': tensor(0.0037, grad_fn=), 'loss': tensor(0.0037, grad_fn=)}\n" ] } ], "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=.001)\n", "\n", "optimizer.zero_grad()\n", "output = model(item)\n", "loss_output = closure.compute_loss(output, item)\n", "loss_output['loss'].backward()\n", "optimizer.step()\n", "print(loss_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cool, we did a single step. Instead of manually defining this all above, we can \n", "instead use the TrainClosure from *nussl*." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "train_closure = nussl.ml.train.closures.TrainClosure(\n", " loss_dictionary, optimizer, model\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `__call__` function of the closure takes an `engine` as well as the batch data. \n", "Since we don't currently have an `engine` object (more on that below), let's just pass `None`.\n", "We can run this on a batch:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'L1Loss': 0.003581754630431533, 'loss': 0.003581754630431533}" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_closure(None, item)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can run this a bunch of times and watch the loss go down." 
] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "loss_history = []\n", "n_iter = 100\n", "\n", "for i in range(n_iter):\n", " loss_output = train_closure(None, item)\n", " loss_history.append(loss_output['loss'])" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(loss_history)\n", "plt.title('Train loss')\n", "plt.xlabel('Iteration')\n", "plt.ylabel('Loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that there is also a `ValidationClosure` which does not take\n", "an optimization step but only computes the loss. \n", "\n", "Let's look at the model output now!" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = model(item)\n", "\n", "for i in range(output['my_estimates'].shape[-1]):\n", " plt.figure(figsize=(10, 5))\n", " plt.subplot(121)\n", " plt.imshow(\n", " output['my_estimates'][0, ..., 0, i].T.cpu().data.numpy(),\n", " origin='lower')\n", " plt.title(\"Source\")\n", "\n", " plt.subplot(122)\n", " plt.imshow(\n", " output['my_mask'][0, ..., 0, i].T.cpu().data.numpy(),\n", " origin='lower')\n", " plt.title(\"Mask\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hey! That looks a lot better! We've now overfit the model to a single item in the dataset. Now, let's do it at scale by using PyTorch Ignite engines with the functionality in `nussl.ml.train`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ignite Engines\n", "--------------\n", "\n", "*nussl* uses PyTorch Ignite to power its training functionality.\n", "At the heart of Ignite is the *Engine* object. An Engine contains a lot\n", "of functionality for iterating through a dataset and feeding data to a model.\n", "What makes Ignite so desirable is that we can define all of the things we\n", "need to train a model ahead of time, and then the Ignite engine will run the code\n", "to train the model for us. This saves us a lot of time writing boilerplate\n", "code for training. *nussl* also provides a lot of boilerplate code for\n", "training source separation models, specifically.\n", "\n", "To use Ignite with *nussl*, the only thing we need to define is a *closure*. \n", "A closure defines a pass through the model for a single batch. The rest of\n", "the details, such as queueing up data, are taken care of by\n", "`torch.utils.data.DataLoader` and the engine object. 
All of the state\n", "regarding a training run, such as the epoch number, the loss history, etc,\n", "is kept in the engine's state at `engine.state`.\n", "\n", "*nussl* provides a helper function to build a standard engine with a lot\n", "of nice functionality like keeping track of \n", "loss history, preparing the batches properly, and setting up the \n", "train and validation closures. This function is `create_train_and_validation_engines()`.\n", "\n", "It's also possible to attach handlers to an Engine for further \n", "functionality. These handlers make use of the engine's state. *nussl* \n", "comes with several of these:\n", "\n", "1. `add_validate_and_checkpoint`: Adds a pass on the validation data and \n", " checkpoints the model based on the validation loss to either `best`\n", " (if this was the lowest validation loss model) or `latest`.\n", "2. `add_stdout_handler`: Prints some handy information after each epoch.\n", "3. `add_tensorboard_handler`: Logs loss data to tensorboard.\n", "\n", "See the API documentation for further details on these handlers.\n", "\n", "### Putting it all together ###\n", "\n", "Let's put this all together. Let's build the dataset, model and\n", "optimizer, train and validation closures, and engines. Let's also\n", "use the GPU if it's available." 
] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "# define everything as before\n", "modules = {\n", " 'mix_magnitude': {},\n", " 'log_spec': {\n", " 'class': 'AmplitudeToDB'\n", " },\n", " 'norm': {\n", " 'class': 'BatchNorm',\n", " },\n", " 'mask': {\n", " 'class': 'Embedding',\n", " 'args': {\n", " 'num_features': num_features,\n", " 'hidden_size': num_features,\n", " 'embedding_size': num_sources,\n", " 'activation': mask_activation,\n", " 'num_audio_channels': num_audio_channels,\n", " 'dim_to_embed': [2, 3] # embed the frequency dimension (2) for all audio channels (3)\n", " }\n", " },\n", " 'estimates': {\n", " 'class': 'Mask',\n", " },\n", "}\n", "\n", "connections = [\n", " ['log_spec', ['mix_magnitude', ]],\n", " ['norm', ['log_spec', ]],\n", " ['mask', ['norm', ]],\n", " ['estimates', ['mask', 'mix_magnitude']]\n", "]\n", "\n", "# define the outputs\n", "output = ['estimates', 'mask']\n", "\n", "config = {\n", " 'modules': modules,\n", " 'connections': connections,\n", " 'output': output\n", "}" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 5\n", "LEARNING_RATE = 1e-3\n", "OUTPUT_FOLDER = os.path.expanduser('~/.nussl/tutorial/sinewave')\n", "RESULTS_DIR = os.path.join(OUTPUT_FOLDER, 'results')\n", "NUM_WORKERS = 2\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "shutil.rmtree(os.path.join(RESULTS_DIR), ignore_errors=True)\n", "\n", "os.makedirs(RESULTS_DIR, exist_ok=True)\n", "os.makedirs(OUTPUT_FOLDER, exist_ok=True)\n", "\n", "# adjust logging so we see output of the handlers\n", "logger = logging.getLogger()\n", "logger.setLevel(logging.INFO)\n", "\n", "# Put together data\n", "stft_params = nussl.STFTParams(window_length=256, hop_length=64)\n", "tfm = nussl.datasets.transforms.Compose([\n", " nussl.datasets.transforms.PhaseSensitiveSpectrumApproximation(),\n", " nussl.datasets.transforms.MagnitudeWeights(),\n", " 
nussl.datasets.transforms.ToSeparationModel()\n", "])\n", "sine_wave_dataset = SineWaves(\n", " 'ignored', sample_rate=8000, stft_params=stft_params,\n", " transform=tfm\n", ")\n", "dataloader = torch.utils.data.DataLoader(\n", " sine_wave_dataset, batch_size=BATCH_SIZE\n", ")\n", "\n", "# Build our simple model\n", "model = nussl.ml.SeparationModel(config).to(DEVICE)\n", "\n", "# Build an optimizer\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n", "\n", "# Set up loss functions and closure\n", "# We'll use permutation invariant loss since we don't\n", "# care what order the sine waves get output in, just that\n", "# they are different.\n", "loss_dictionary = {\n", " 'PermutationInvariantLoss': {\n", " 'weight': 1.0,\n", " 'args': ['L1Loss']\n", " }\n", "}\n", "\n", "train_closure = nussl.ml.train.closures.TrainClosure(\n", " loss_dictionary, optimizer, model\n", ")\n", "val_closure = nussl.ml.train.closures.ValidationClosure(\n", " loss_dictionary, model\n", ")\n", "\n", "# Build the engine and add handlers\n", "train_engine, val_engine = nussl.ml.train.create_train_and_validation_engines(\n", " train_closure, val_closure, device=DEVICE\n", ")\n", "nussl.ml.train.add_validate_and_checkpoint(\n", " OUTPUT_FOLDER, model, optimizer, sine_wave_dataset, train_engine,\n", " val_data=dataloader, validator=val_engine\n", ")\n", "nussl.ml.train.add_stdout_handler(train_engine, val_engine)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cool! We built an engine! (Note the distinction between using the original dataset\n", "object and using the dataloader object.) \n", "\n", "Now to train it, all we have to do is `run`\n", "the engine. Since our SineWaves dataset makes mixes \"on the fly\" (i.e., every time\n", "we get an `item`, the dataset will return a mix of random sine waves), it is\n", "impossible to loop through the whole dataset, and therefore there is no concept\n", "of an epoch. 
In this case, we will instead define an arbitrary `epoch_length`\n", "of 1000 and pass that value to `train_engine`. After one epoch, the validation\n", "will be run and everything will get printed by the `stdout` handler. \n", "\n", "Let's see it run:" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:\n", "\n", "EPOCH SUMMARY \n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n", "- Epoch number: 0001 / 0001 \n", "- Training loss: 0.001197 \n", "- Validation loss: 0.000683 \n", "- Epoch took: 00:02:11 \n", "- Time since start: 00:02:11 \n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n", "Saving to /home/pseetharaman/.nussl/tutorial/sinewave/checkpoints/best.model.pth. \n", "Output @ /home/pseetharaman/.nussl/tutorial/sinewave \n", "\n", "INFO:ignite.engine.engine.Engine:Engine run complete. Time taken 00:02:11\n" ] }, { "data": { "text/plain": [ "State:\n", "\titeration: 1000\n", "\tepoch: 1\n", "\tepoch_length: 1000\n", "\tmax_epochs: 1\n", "\toutput: \n", "\tbatch: \n", "\tmetrics: \n", "\tdataloader: \n", "\tseed: 12\n", "\tepoch_history: \n", "\titer_history: \n", "\tpast_iter_history: \n", "\tsaved_model_path: /home/pseetharaman/.nussl/tutorial/sinewave/checkpoints/best.model.pth\n", "\toutput_folder: /home/pseetharaman/.nussl/tutorial/sinewave" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_engine.run(dataloader, epoch_length=1000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check out the loss over each iteration in the single epoch\n", "by examining the state:" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(train_engine.state.iter_history['loss'])\n", "plt.xlabel('Iteration')\n", "plt.ylabel('Loss')\n", "plt.title('Train Loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's also see what got saved in the output folder:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[01;34m/home/pseetharaman/.nussl/tutorial/sinewave\u001b[00m\r\n", "├── \u001b[01;34mcheckpoints\u001b[00m\r\n", "│   ├── best.model.pth\r\n", "│   ├── best.optimizer.pth\r\n", "│   ├── latest.model.pth\r\n", "│   └── latest.optimizer.pth\r\n", "└── \u001b[01;34mresults\u001b[00m\r\n", "\r\n", "2 directories, 4 files\r\n" ] } ], "source": [ "!tree {OUTPUT_FOLDER}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So the models and optimizers got saved! Let's load back one of these\n", "models and see what's in it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What's in a model?\n", "------------------\n", "\n", "After we're finished training the model, it will be saved by our \n", "`add_validate_and_checkpoint` handler. What gets saved in our model? Let's see:" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dict_keys(['state_dict', 'config', 'metadata', 'nussl_version'])\n" ] } ], "source": [ "saved_model = torch.load(train_engine.state.saved_model_path)\n", "print(saved_model.keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, there's the `state_dict` containing the weights of\n", "the trained model, and the `config` containing the configuration of the model.\n", "There is also a `metadata` key in the saved model. Let's check out the metadata..." 
] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dict_keys(['stft_params', 'sample_rate', 'num_channels', 'folder', 'transforms', 'trainer.state_dict', 'trainer.state.epoch_history'])\n" ] } ], "source": [ "print(saved_model['metadata'].keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There's a whole bunch of stuff related to training, like the folder \n", "it was trained on, the state dictionary of the engine used to train the \n", "model, the loss history for each epoch (not each iteration - that's too big).\n", "\n", "There are also keys that are related to the parameters of the AudioSignal. \n", "Namely, `stft_params`, `sample_rate`, and `num_channels`. These \n", "are used by *nussl* to prepare an AudioSignal object to be put into a\n", "deep learning based separation algorithm. There's also a `transforms`\n", "key - this is used by *nussl* to construct the input dictionary at\n", "inference time on an AudioSignal so that the data going into the model\n", "matches how it was given during training time. 
Let's look at each of these:" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "stft_params: STFTParams(window_length=256, hop_length=64, window_type=None)\n", "sample_rate: 8000\n", "num_channels: 1\n", "folder: ignored\n", "transforms: Compose(\n", " PhaseSensitiveSpectrumApproximation(mix_key = mix, source_key = sources)\n", " \n", " ToSeparationModel()\n", ")\n", "trainer.state_dict: {'epoch': 1, 'epoch_length': 1000, 'max_epochs': 1, 'output': {'PermutationInvariantLoss': 0.000967394735198468, 'loss': 0.000967394735198468}, 'metrics': {}, 'seed': 12}\n", "trainer.state.epoch_history: {'validation/PermutationInvariantLoss': [0.000682936332304962], 'validation/loss': [0.000682936332304962], 'train/PermutationInvariantLoss': [0.0011968749410734745], 'train/loss': [0.0011968749410734745]}\n" ] } ], "source": [ "for key in saved_model['metadata']:\n", " print(f\"{key}: {saved_model['metadata'][key]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "**Importantly**, everything saved with the model makes training it *entirely reproducible*. We have everything we need to recreate another model exactly like this if we need to.\n", "\n", "Now that we've trained our toy model, let's move on to actually using and evaluating it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using and evaluating a trained model\n", "------------------------------------\n", "\n", "In this tutorial, we built a very simple deep mask estimation network. There is a \n", "corresponding separation algorithm in *nussl* for using \n", "deep mask estimation networks. Let's build our dataset\n", "again, this time *without* transforms, so we have access to\n", "the actual AudioSignal objects. Then let's instantiate the\n", "separation algorithm and use it to separate an item from the \n", "dataset." 
] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tt_dataset = SineWaves(\n", " 'ignored', sample_rate=8000\n", ")\n", "tt_dataset.frequencies = sine_wave_dataset.frequencies\n", "\n", "item = tt_dataset[0] # <-- item['mix'] is an AudioSignal object\n", "\n", "MODEL_PATH = os.path.join(OUTPUT_FOLDER, 'checkpoints/best.model.pth')\n", "\n", "separator = nussl.separation.deep.DeepMaskEstimation(\n", " item['mix'], model_path=MODEL_PATH\n", ")\n", "estimates = separator()\n", "\n", "visualize_and_embed(estimates)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation in parallel ###\n", "\n", "We'll usually want to run many mixtures through the model, separate,\n", "and get evaluation metrics like SDR, SIR, and SAR. We can do that with\n", "the following bit of code:" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/pseetharaman/Dropbox/research/nussl_refactor/nussl/separation/base/separation_base.py:71: UserWarning: input_audio_signal has no data!\n", " warnings.warn('input_audio_signal has no data!')\n", "/home/pseetharaman/Dropbox/research/nussl_refactor/nussl/core/audio_signal.py:445: UserWarning: Initializing STFT with data that is non-complex. This might lead to weird results!\n", " warnings.warn('Initializing STFT with data that is non-complex. 
'\n", " 96%|█████████▌| 96/100 [00:04<00:00, 17.45it/s]/home/pseetharaman/Dropbox/research/nussl_refactor/nussl/evaluation/bss_eval.py:33: RuntimeWarning: divide by zero encountered in log10\n", " srr = -10 * np.log10((1 - (1/alpha)) ** 2)\n", "100%|██████████| 100/100 [00:05<00:00, 19.65it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n", " MEAN +/- STD OF METRICS \n", " \n", "┌─────────┬──────────────────┬──────────────────┬──────────────────┬──────────────────┐\n", "│ METRIC │ OVERALL │ SINE1 │ SINE2 │ SINE3 │\n", "╞═════════╪══════════════════╪══════════════════╪══════════════════╪══════════════════╡\n", "│ # │ 300 │ 100 │ 100 │ 100 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SDR │ 14.83 +/- 15.89 │ 15.44 +/- 14.83 │ 14.20 +/- 17.41 │ 14.85 +/- 15.44 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SIR │ 25.39 +/- 19.97 │ 25.79 +/- 18.48 │ 24.89 +/- 21.90 │ 25.50 +/- 19.57 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SAR │ 19.46 +/- 14.75 │ 19.27 +/- 14.17 │ 19.26 +/- 15.35 │ 19.84 +/- 14.86 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SD-SDR │ 8.68 +/- 21.89 │ 8.65 +/- 21.83 │ 8.38 +/- 22.88 │ 9.02 +/- 21.15 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SNR │ 15.13 +/- 10.91 │ 15.23 +/- 10.46 │ 15.23 +/- 11.41 │ 14.92 +/- 10.96 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SRR │ 20.46 +/- 29.52 │ 20.31 +/- 29.74 │ 20.62 +/- 30.97 │ 20.45 +/- 28.06 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SDRi │ 17.84 +/- 15.89 │ 18.45 +/- 14.83 │ 17.21 +/- 17.41 │ 17.86 +/- 15.44 │\n", 
"├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SD-SDRi │ 11.69 +/- 21.89 │ 11.66 +/- 21.83 │ 11.39 +/- 22.88 │ 12.03 +/- 21.15 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SNRi │ 18.14 +/- 10.91 │ 18.24 +/- 10.46 │ 18.24 +/- 11.41 │ 17.93 +/- 10.96 │\n", "└─────────┴──────────────────┴──────────────────┴──────────────────┴──────────────────┘\n", " \n", " MEDIAN OF METRICS \n", " \n", "┌─────────┬──────────────────┬──────────────────┬──────────────────┬──────────────────┐\n", "│ METRIC │ OVERALL │ SINE1 │ SINE2 │ SINE3 │\n", "╞═════════╪══════════════════╪══════════════════╪══════════════════╪══════════════════╡\n", "│ # │ 300 │ 100 │ 100 │ 100 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SDR │ 19.40 │ 19.73 │ 19.16 │ 19.16 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SIR │ 28.23 │ 28.29 │ 28.53 │ 27.48 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SAR │ 22.07 │ 21.56 │ 22.41 │ 22.11 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SD-SDR │ 16.27 │ 17.13 │ 16.12 │ 15.66 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SNR │ 16.87 │ 17.68 │ 16.57 │ 16.05 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SRR │ 26.03 │ 26.08 │ 26.32 │ 24.06 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SI-SDRi │ 22.42 │ 22.74 │ 22.17 │ 22.17 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SD-SDRi │ 19.28 │ 20.14 │ 19.13 │ 18.67 │\n", "├─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┤\n", "│ SNRi │ 19.88 │ 
20.70 │ 19.59 │ 19.06 │\n", "└─────────┴──────────────────┴──────────────────┴──────────────────┴──────────────────┘\n", " \n", " NOTES \n", " \n", "Testing on sine waves\n" ] } ], "source": [ "# make a separator with an empty audio signal initially\n", "# this one will live on gpu (if one exists) and be used in a \n", "# threadpool for speed\n", "dme = nussl.separation.deep.DeepMaskEstimation(\n", " nussl.AudioSignal(), model_path=MODEL_PATH, device=DEVICE\n", ")\n", "\n", "\n", "def forward_on_gpu(audio_signal):\n", " # set the audio signal of the object to this item's mix\n", " dme.audio_signal = audio_signal\n", " masks = dme.forward()\n", " return masks\n", "\n", "\n", "def separate_and_evaluate(item, masks):\n", " separator = nussl.separation.deep.DeepMaskEstimation(item['mix'])\n", " estimates = separator(masks)\n", "\n", " evaluator = nussl.evaluation.BSSEvalScale(\n", " list(item['sources'].values()), estimates, \n", " compute_permutation=True,\n", " source_labels=['sine1', 'sine2', 'sine3']\n", " )\n", " scores = evaluator.evaluate()\n", " output_path = os.path.join(\n", " RESULTS_DIR, f\"{item['mix'].file_name}.json\"\n", " )\n", " with open(output_path, 'w') as f:\n", " json.dump(scores, f)\n", " \n", "pool = ThreadPoolExecutor(max_workers=NUM_WORKERS)\n", "for i, item in enumerate(tqdm.tqdm(tt_dataset)):\n", " masks = forward_on_gpu(item['mix'])\n", " if i == 0:\n", " separate_and_evaluate(item, masks)\n", " else:\n", " pool.submit(separate_and_evaluate, item, masks)\n", "pool.shutdown(wait=True)\n", "\n", "json_files = glob.glob(f\"{RESULTS_DIR}/*.json\")\n", "df = nussl.evaluation.aggregate_score_files(json_files)\n", "report_card = nussl.evaluation.report_card(\n", " df, notes=\"Testing on sine waves\", report_each_source=True)\n", "print(report_card)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We parallelized the evaluation across 2 workers, kept two copies of\n", "the separator, one of which lives on the GPU, and the other 
which\n", "lives on the CPU. The GPU copy runs the forward pass in the main thread\n", "and hands the resulting masks to the CPU separator, which computes the\n", "estimates and evaluates the metrics in the thread pool. After we're done,\n", "we aggregate all the results (each of which was saved to a JSON file)\n", "using `nussl.evaluation.aggregate_score_files` and then use the\n", "nussl report card at `nussl.evaluation.report_card` to view the results.\n", "We also now have the results as a pandas DataFrame:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>source</th><th>file</th><th>SI-SDR</th><th>SI-SIR</th><th>SI-SAR</th><th>SD-SDR</th><th>SNR</th><th>SRR</th><th>SI-SDRi</th><th>SD-SDRi</th><th>SNRi</th></tr>\n",
"</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>sine3</td><td>46.wav.json</td><td>-20.328155</td><td>-8.004939</td><td>-20.066033</td><td>-40.275110</td><td>0.039984</td><td>-40.230923</td><td>-17.317855</td><td>-37.264810</td><td>3.050284</td></tr>\n",
"<tr><th>1</th><td>sine3</td><td>95.wav.json</td><td>28.667780</td><td>47.823142</td><td>28.720856</td><td>25.991488</td><td>26.282122</td><td>29.363641</td><td>31.678080</td><td>29.001788</td><td>29.292422</td></tr>\n",
"<tr><th>2</th><td>sine3</td><td>10.wav.json</td><td>33.812458</td><td>51.036624</td><td>33.895541</td><td>33.172820</td><td>33.243078</td><td>41.807216</td><td>36.822758</td><td>36.183120</td><td>36.253378</td></tr>\n",
"<tr><th>3</th><td>sine3</td><td>83.wav.json</td><td>9.767297</td><td>9.882167</td><td>25.600388</td><td>9.746169</td><td>9.940989</td><td>32.886040</td><td>12.777597</td><td>12.756469</td><td>12.951289</td></tr>\n",
"<tr><th>4</th><td>sine3</td><td>22.wav.json</td><td>6.323363</td><td>23.420462</td><td>6.408937</td><td>-19.179230</td><td>0.894722</td><td>-19.166980</td><td>9.333663</td><td>-16.168930</td><td>3.905022</td></tr>\n",
"<tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"<tr><th>295</th><td>sine1</td><td>0.wav.json</td><td>20.901103</td><td>36.092488</td><td>21.034547</td><td>20.883094</td><td>20.933421</td><td>44.715190</td><td>23.911403</td><td>23.893394</td><td>23.943721</td></tr>\n",
"<tr><th>296</th><td>sine1</td><td>63.wav.json</td><td>18.457849</td><td>34.462413</td><td>18.568215</td><td>17.429336</td><td>17.949669</td><td>24.189192</td><td>21.468149</td><td>20.439636</td><td>20.959969</td></tr>\n",
"<tr><th>297</th><td>sine1</td><td>92.wav.json</td><td>2.070333</td><td>2.341293</td><td>14.253886</td><td>-7.646593</td><td>2.669284</td><td>-7.156396</td><td>5.080632</td><td>-4.636293</td><td>5.679584</td></tr>\n",
"<tr><th>298</th><td>sine1</td><td>37.wav.json</td><td>-16.620665</td><td>12.170543</td><td>-16.614925</td><td>-49.367983</td><td>0.027190</td><td>-49.365675</td><td>-13.610365</td><td>-46.357683</td><td>3.037490</td></tr>\n",
"<tr><th>299</th><td>sine1</td><td>43.wav.json</td><td>19.687632</td><td>20.537881</td><td>27.188207</td><td>19.680721</td><td>19.716579</td><td>47.666603</td><td>22.697932</td><td>22.691021</td><td>22.726879</td></tr>\n",
"</tbody>\n", "</table>\n",
"<p>300 rows × 11 columns</p>\n",
"</div>
" ], "text/plain": [ " source file SI-SDR SI-SIR SI-SAR SD-SDR \\\n", "0 sine3 46.wav.json -20.328155 -8.004939 -20.066033 -40.275110 \n", "1 sine3 95.wav.json 28.667780 47.823142 28.720856 25.991488 \n", "2 sine3 10.wav.json 33.812458 51.036624 33.895541 33.172820 \n", "3 sine3 83.wav.json 9.767297 9.882167 25.600388 9.746169 \n", "4 sine3 22.wav.json 6.323363 23.420462 6.408937 -19.179230 \n", ".. ... ... ... ... ... ... \n", "295 sine1 0.wav.json 20.901103 36.092488 21.034547 20.883094 \n", "296 sine1 63.wav.json 18.457849 34.462413 18.568215 17.429336 \n", "297 sine1 92.wav.json 2.070333 2.341293 14.253886 -7.646593 \n", "298 sine1 37.wav.json -16.620665 12.170543 -16.614925 -49.367983 \n", "299 sine1 43.wav.json 19.687632 20.537881 27.188207 19.680721 \n", "\n", " SNR SRR SI-SDRi SD-SDRi SNRi \n", "0 0.039984 -40.230923 -17.317855 -37.264810 3.050284 \n", "1 26.282122 29.363641 31.678080 29.001788 29.292422 \n", "2 33.243078 41.807216 36.822758 36.183120 36.253378 \n", "3 9.940989 32.886040 12.777597 12.756469 12.951289 \n", "4 0.894722 -19.166980 9.333663 -16.168930 3.905022 \n", ".. ... ... ... ... ... 
\n", "295 20.933421 44.715190 23.911403 23.893394 23.943721 \n", "296 17.949669 24.189192 21.468149 20.439636 20.959969 \n", "297 2.669284 -7.156396 5.080632 -4.636293 5.679584 \n", "298 0.027190 -49.365675 -13.610365 -46.357683 3.037490 \n", "299 19.716579 47.666603 22.697932 22.691021 22.726879 \n", "\n", "[300 rows x 11 columns]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can look at the structure of the output folder again.\n", "There are now 100 entries under results, one for each\n", "item in `tt_dataset`:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[01;34m/home/pseetharaman/.nussl/tutorial/sinewave\u001b[00m\r\n", "├── \u001b[01;34mcheckpoints\u001b[00m\r\n", "│   ├── best.model.pth\r\n", "│   ├── best.optimizer.pth\r\n", "│   ├── latest.model.pth\r\n", "│   └── latest.optimizer.pth\r\n", "└── \u001b[01;34mresults\u001b[00m [100 entries exceeds filelimit, not opening dir]\r\n", "\r\n", "2 directories, 4 files\r\n" ] } ], "source": [ "!tree --filelimit 20 {OUTPUT_FOLDER}" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time taken: 169.6379 seconds\n" ] } ], "source": [ "end_time = time.time()\n", "time_taken = end_time - start_time\n", "print(f'Time taken: {time_taken:.4f} seconds')" ] } ], "metadata": { "jupytext": { "formats": "ipynb,py" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }