{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluating ideal ratio mask on WHAM!\n", "=====================================\n", "\n", "This recipe evaluates an oracle ideal ratio mask on the `mix_clean`\n", "and `min` subset in the WHAM dataset. This recipe is annotated \n", "as a notebook for documentation but can be run directly\n", "as a script in `docs/recipes/ideal_ratio_mask.py`.\n", "\n", "We evaluate three approaches to constructing the ideal ratio mask:\n", "\n", "- Magnitude spectrum approximation\n", "- Phase sensitive spectrum approximation\n", "- Truncated phase sensitive spectrum approximation\n", "\n", "Imports\n", "----------" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from nussl import datasets, separation, evaluation\n", "import os\n", "import multiprocessing\n", "from concurrent.futures import ThreadPoolExecutor\n", "import logging\n", "import json\n", "import tqdm\n", "import glob\n", "import numpy as np\n", "import termtables\n", "\n", "# set up logging\n", "logger = logging.getLogger()\n", "logger.setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Setting up\n", "----------\n", "\n", "Make sure to point `WHAM_ROOT` to where you've actually\n", "built and saved the WHAM dataset. There are a few different\n", "ways to use ideal ratio masks, so we're going to set those\n", "up in a dictionary." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
 "WHAM_ROOT = '/home/data/wham/'\n",
 "# Use a quarter of the available cores, but never fewer than one:\n",
 "# ThreadPoolExecutor rejects max_workers=0, which a bare `// 4` yields\n",
 "# on machines with fewer than four cores.\n",
 "NUM_WORKERS = max(1, multiprocessing.cpu_count() // 4)\n",
 "OUTPUT_DIR = os.path.expanduser('~/.nussl/recipes/ideal_ratio_mask/')\n",
 "# Display name -> settings for one way of building the ideal ratio mask.\n",
 "# 'approach' and 'kwargs' are handed to the separator in the evaluation\n",
 "# cell; 'dir' names the sub-directory of RESULTS_DIR that receives the\n",
 "# per-file score JSONs for that approach.\n",
 "APPROACHES = {\n",
 "    'Phase-sensitive spectrum approx.': {\n",
 "        'kwargs': {\n",
 "            'range_min': -np.inf, 'range_max': np.inf\n",
 "        },\n",
 "        'approach': 'psa',\n",
 "        'dir': 'psa'\n",
 "    },\n",
 "    'Truncated phase-sensitive approx.': {\n",
 "        'kwargs': {\n",
 "            'range_min': 0.0, 'range_max': 1.0\n",
 "        },\n",
 "        'approach': 'psa',\n",
 "        'dir': 'tpsa'\n",
 "    },\n",
 "    'Magnitude spectrum approximation': {\n",
 "        'kwargs': {},\n",
 "        'approach': 'msa',\n",
 "        'dir': 'msa'\n",
 "    }\n",
 "}\n",
 "\n",
 "RESULTS_DIR = os.path.join(OUTPUT_DIR, 'results')\n",
 "# Create one results directory per approach up front so the evaluation\n",
 "# workers can write their JSON files without checking for it.\n",
 "for key, val in APPROACHES.items():\n",
 "    _dir = os.path.join(RESULTS_DIR, val['dir'])\n",
 "    os.makedirs(_dir, exist_ok=True)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluation\n", "----------" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3000/3000 [02:02<00:00, 24.54it/s]\n", " 0%| | 11/3000 [00:00<00:29, 102.56it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------\n", "PHASE-SENSITIVE SPECTRUM APPROX.\n", "--------------------------------\n", "┌────────────────────┬────────────────────┬────────────────────┐\n", "│ │ OVERALL (N = 6000) │ │\n", "╞════════════════════╪════════════════════╪════════════════════╡\n", "│ SAR │ SDR │ SIR │\n", "├────────────────────┼────────────────────┼────────────────────┤\n", "│ 16.757149470647175 │ 16.433001412848633 │ 28.393683596452078 │\n", "└────────────────────┴────────────────────┴────────────────────┘\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3000/3000 [02:04<00:00, 24.15it/s]\n", " 0%| | 11/3000 [00:00<00:28, 105.40it/s]" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "---------------------------------\n", "TRUNCATED PHASE-SENSITIVE APPROX.\n", "---------------------------------\n", "┌────────────────────┬────────────────────┬────────────────────┐\n", "│ │ OVERALL (N = 6000) │ │\n", "╞════════════════════╪════════════════════╪════════════════════╡\n", "│ SAR │ SDR │ SIR │\n", "├────────────────────┼────────────────────┼────────────────────┤\n", "│ 15.243919243117174 │ 14.648945592443148 │ 23.894627310434977 │\n", "└────────────────────┴────────────────────┴────────────────────┘\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3000/3000 [02:00<00:00, 24.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------\n", "MAGNITUDE SPECTRUM APPROXIMATION\n", "--------------------------------\n", "┌────────────────────┬────────────────────┬────────────────────┐\n", "│ │ OVERALL (N = 6000) │ │\n", "╞════════════════════╪════════════════════╪════════════════════╡\n", "│ SAR │ SDR │ SIR │\n", "├────────────────────┼────────────────────┼────────────────────┤\n", "│ 13.677899166842302 │ 12.694045978486537 │ 19.854502292474113 │\n", "└────────────────────┴────────────────────┴────────────────────┘\n" ] } ], "source": [
 "test_dataset = datasets.WHAM(WHAM_ROOT, sample_rate=8000, split='tt')\n",
 "\n",
 "for key, val in APPROACHES.items():\n",
 "    def separate_and_evaluate(item):\n",
 "        \"\"\"Separate one mixture with an ideal ratio mask and save its scores.\n",
 "\n",
 "        Uses the settings of the approach currently being evaluated\n",
 "        (`val`): the estimates come from `IdealRatioMask`, are scored\n",
 "        with `BSSEvalScale`, and the scores are written as JSON to\n",
 "        `RESULTS_DIR/<approach dir>/<mixture file name>.json`.\n",
 "        \"\"\"\n",
 "        output_path = os.path.join(\n",
 "            RESULTS_DIR, val['dir'], f\"{item['mix'].file_name}.json\")\n",
 "        separator = separation.benchmark.IdealRatioMask(\n",
 "            item['mix'], item['sources'], approach=val['approach'],\n",
 "            mask_type='soft', **val['kwargs'])\n",
 "        estimates = separator()\n",
 "\n",
 "        evaluator = evaluation.BSSEvalScale(\n",
 "            list(item['sources'].values()), estimates, compute_permutation=True)\n",
 "        scores = evaluator.evaluate()\n",
 "        with open(output_path, 'w') as f:\n",
 "            json.dump(scores, f)\n",
 "\n",
 "    # Every submitted job is finished before the next approach starts\n",
 "    # (the `with` block waits for the pool on exit), so the closure over\n",
 "    # `val` in separate_and_evaluate never sees a later iteration's value.\n",
 "    pending = []\n",
 "    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as pool:\n",
 "        for i, item in enumerate(tqdm.tqdm(test_dataset)):\n",
 "            if i == 0:\n",
 "                # The first item runs in the calling thread, so an error\n",
 "                # here is raised immediately instead of inside the pool.\n",
 "                separate_and_evaluate(item)\n",
 "            else:\n",
 "                future = pool.submit(separate_and_evaluate, item)\n",
 "                pending.append((item['mix'].file_name, future))\n",
 "\n",
 "    # An exception raised in a worker is only stored on its future. Report\n",
 "    # it here; otherwise the failure passes unnoticed and the item is\n",
 "    # simply missing from the aggregate computed below.\n",
 "    for file_name, future in pending:\n",
 "        error = future.exception()\n",
 "        if error is not None:\n",
 "            logging.error(\n",
 "                '%s: evaluation failed for %s', key, file_name, exc_info=error)\n",
 "\n",
 "    # Aggregates every JSON file found in this approach's directory,\n",
 "    # including any left there by earlier runs.\n",
 "    json_files = glob.glob(f\"{RESULTS_DIR}/{val['dir']}/*.json\")\n",
 "    df = evaluation.aggregate_score_files(json_files)\n",
 "\n",
 "    overall = df.mean()\n",
 "    rule = '-' * len(key)\n",
 "    print(rule)\n",
 "    print(key.upper())\n",
 "    print(rule)\n",
 "    headers = [\"\", f\"OVERALL (N = {df.shape[0]})\", \"\"]\n",
 "    # NOTE(review): the labels are hard-coded and assumed to match the\n",
 "    # order of the values in `overall` -- confirm against df's columns.\n",
 "    metrics = [\"SAR\", \"SDR\", \"SIR\"]\n",
 "    data = [metrics, np.array(overall).T]\n",
 "    termtables.print(data, header=headers, padding=(0, 1), alignment=\"ccc\")"
 ] } ], "metadata": { "jupytext": { "formats": "ipynb,py:light", "notebook_metadata_filter": "nbsphinx" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "nbsphinx": { "execute": "never" } }, "nbformat": 4, "nbformat_minor": 4 }