{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluating ideal binary mask on WHAM!\n", "=====================================\n", "\n", "This recipe evaluates an oracle ideal binary mask on the `mix_clean`\n", "and `min` subset in the WHAM dataset. This recipe is annotated\n", "as a notebook for documentation but can be run directly\n", "as a script in `docs/recipes/ideal_binary_mask.py`.\n", "\n", "Imports\n", "-------" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from nussl import datasets, separation, evaluation\n", "import os\n", "import multiprocessing\n", "from concurrent.futures import ThreadPoolExecutor\n", "import logging\n", "import json\n", "import tqdm\n", "import glob\n", "import numpy as np\n", "import termtables\n", "\n", "# set up logging\n", "logger = logging.getLogger()\n", "logger.setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Setting up\n", "----------\n", "\n", "Make sure to point `WHAM_ROOT` to where you've actually\n", "built and saved the WHAM dataset."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
 "WHAM_ROOT = '/home/data/wham/'\n",
 "# Use a quarter of the available cores, but always at least one worker:\n",
 "# ThreadPoolExecutor raises ValueError when max_workers <= 0, which the\n",
 "# plain integer division would produce on machines with fewer than 4 cores.\n",
 "NUM_WORKERS = max(1, multiprocessing.cpu_count() // 4)\n",
 "OUTPUT_DIR = os.path.expanduser('~/.nussl/recipes/ideal_binary_mask/')\n",
 "RESULTS_DIR = os.path.join(OUTPUT_DIR, 'results')\n",
 "os.makedirs(RESULTS_DIR, exist_ok=True)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluation\n", "----------" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3000/3000 [02:03<00:00, 24.34it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "┌────────────────────┬────────────────────┬────────────────────┐\n", "│ │ OVERALL (N = 6000) │ │\n", "╞════════════════════╪════════════════════╪════════════════════╡\n", "│ SAR │ SDR │ SIR │\n", "├────────────────────┼────────────────────┼────────────────────┤\n", "│ 13.663734911779562 │ 13.477634987731774 │ 28.693360844254492 │\n", "└────────────────────┴────────────────────┴────────────────────┘\n" ] } ], "source": [
 "test_dataset = datasets.WHAM(WHAM_ROOT, sample_rate=8000, split='tt')\n",
 "\n",
 "def separate_and_evaluate(item):\n",
 "    \"\"\"Separate one dataset item with an ideal binary mask and save its scores.\n",
 "\n",
 "    Builds an ``IdealBinaryMask`` separator from ``item['mix']`` and\n",
 "    ``item['sources']``, evaluates the resulting estimates against the\n",
 "    source signals with ``BSSEvalScale`` and dumps the scores as JSON to\n",
 "    ``RESULTS_DIR/<mix file name>.json``.\n",
 "\n",
 "    Args:\n",
 "        item: A dataset item providing a ``'mix'`` entry (with a\n",
 "            ``file_name`` attribute) and a ``'sources'`` mapping.\n",
 "    \"\"\"\n",
 "    separator = separation.benchmark.IdealBinaryMask(\n",
 "        item['mix'], item['sources'], mask_type='binary')\n",
 "    estimates = separator()\n",
 "\n",
 "    evaluator = evaluation.BSSEvalScale(\n",
 "        list(item['sources'].values()), estimates, compute_permutation=True)\n",
 "    scores = evaluator.evaluate()\n",
 "    output_path = os.path.join(RESULTS_DIR, f\"{item['mix'].file_name}.json\")\n",
 "    with open(output_path, 'w') as f:\n",
 "        json.dump(scores, f)\n",
 "\n",
 "\n",
 "# (file name, future) pairs, kept so that worker failures can be reported.\n",
 "pending = []\n",
 "# Leaving the with-block waits for all submitted work and shuts the pool\n",
 "# down, even if the loop itself raises.\n",
 "with ThreadPoolExecutor(max_workers=NUM_WORKERS) as pool:\n",
 "    for i, item in enumerate(tqdm.tqdm(test_dataset)):\n",
 "        if i == 0:\n",
 "            # The first item runs in the calling thread, so any exception\n",
 "            # it raises propagates immediately.\n",
 "            separate_and_evaluate(item)\n",
 "        else:\n",
 "            future = pool.submit(separate_and_evaluate, item)\n",
 "            pending.append((item['mix'].file_name, future))\n",
 "\n",
 "# An exception raised inside a worker is only stored on its future; log\n",
 "# each one so that missing result files do not go unnoticed.\n",
 "for file_name, future in pending:\n",
 "    error = future.exception()\n",
 "    if error is not None:\n",
 "        logger.error('Evaluation failed for %s: %r', file_name, error)\n",
 "\n",
 "json_files = glob.glob(f\"{RESULTS_DIR}/*.json\")\n",
 "df = evaluation.aggregate_score_files(json_files)\n",
 "\n",
 "# Computed once and reused for the table below.\n",
 "overall = df.mean()\n",
 "headers = [\"\", f\"OVERALL (N = {df.shape[0]})\", \"\"]\n",
 "# NOTE(review): these labels are hard-coded and assumed to line up with the\n",
 "# order of the values in df.mean() -- confirm against the aggregated columns.\n",
 "metrics = [\"SAR\", \"SDR\", \"SIR\"]\n",
 "\n",
 "data = [metrics, np.array(overall)]\n",
 "termtables.print(data, header=headers, padding=(0, 1), alignment=\"ccc\")"
] } ], "metadata": { "jupytext": { "formats": "ipynb,py:light", "notebook_metadata_filter": "nbsphinx" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "nbsphinx": { "execute": "never" } }, "nbformat": 4, "nbformat_minor": 4 }