Source code for nussl.evaluation.report_card

import pandas as pd
import json
import termtables
import numpy as np
import os
import textwrap

def truncate(values, decs=2):
    return np.trunc(values*10**decs)/(10**decs)

[docs]def aggregate_score_files(json_files, aggregator=np.nanmedian): """ Takes a list of json files output by an Evaluation method in nussl and aggregates all the metrics into a Pandas dataframe. Sample output: .. code-block:: none SDR SIR SAR drums oracle0.json 9.086025 15.025801 10.362709 random0.json -6.539877 -6.087538 3.508338 oracle1.json 9.591432 14.335700 11.365882 random1.json -1.358840 -0.993666 9.577297 bass oracle0.json 7.936720 12.843092 9.631929 random0.json -4.190299 -3.730649 5.802003 oracle1.json 8.581090 12.513445 10.831370 random1.json 0.365171 0.697621 11.693103 other oracle0.json 2.024207 6.133359 4.158805 random0.json -9.857085 -9.481909 0.965199 oracle1.json 3.961383 6.861785 7.085745 random1.json -4.042277 -3.707997 7.260934 vocals oracle0.json 12.169686 16.650161 14.085037 random0.json -2.440166 -1.884026 6.760966 oracle1.json 12.409913 16.248470 14.725983 random1.json 1.609577 1.958037 12.738970 Args: json_files (list): List of JSON files that will be parsed for metrics. aggregator ([type], optional): How to aggregate results within a single track. Defaults to np.median. Returns: pd.DataFrame: Pandas dataframe containing the aggregated metrics. """ metrics = {} for json_file in json_files: with open(json_file, 'r') as f: data = json.load(f) json_key = os.path.basename(json_file) for name in data: if name not in ['combination', 'permutation']: if name not in metrics: metrics[name] = {} if json_key not in metrics[name]: metrics[name][json_key] = {} for key in data[name]: _data = aggregator(data[name][key]) metrics[name][json_key][key] = _data df = pd.concat({ k: pd.DataFrame(v).T for k, v in metrics.items() }, axis=0, names=['source', 'file']) df.reset_index(inplace=True) return df
def _get_mean_and_std(df, decs=2): """ Gets the mean and standard deviation of each metric in the pandas DataFrame and returns it as a list of strings. """ excluded_columns = ['source', 'file'] metrics = [x for x in list(df.columns) if x not in excluded_columns] metrics.insert(0, '#') means = [ f'{truncate(m, decs=decs):{4+decs}.{decs}f}' for m in np.array(df.mean()).T ] stds = [ f'{truncate(s, decs=decs):{3+decs}.{decs}f}' for s in np.array(df.std()).T ] data = [f'{m} +/- {s}' for m, s in zip(means, stds)] data.insert(0, df.shape[0]) return metrics, data def _get_medians(df, decs=2): """ Gets the median of each metric in the pandas DataFrame and returns it as a list of strings. """ excluded_columns = ['source', 'file'] metrics = [x for x in list(df.columns) if x not in excluded_columns] metrics.insert(0, '#') data = [ f'{truncate(m, decs=decs):{4+decs}.{decs}f}' for m in np.array(df.median()).T ] data.insert(0, df.shape[0]) return metrics, data def _format_title(title, length, marker=" "): pad = (length - len(title)) // 2 pad = ''.join([marker for _ in range(pad)]) border = pad + title + pad if len(title) % 2: border = border + marker return border def _get_report_card(df, func, report_each_source=True, decs=2): """ Gets a report card for a DataFrame using a specific function. """ labels, data = func(df, decs=decs) data.insert(0, 'OVERALL') data = [data] if report_each_source: for name in np.unique(df['source']): _df = df[df['source'] == name] _, _data = func(_df, decs=decs) _data.insert(0, name.upper()) data.append(_data) # transposing data so each column has the source and its metrics data = list(map(list, zip(*data))) header = data.pop(0) header.insert(0, 'METRIC') for i in range(1, len(header)): header[i] = _format_title(header[i], 16) for l, d in zip(labels, data): d.insert(0, l) alignment = ["c" for _ in header] alignment[0] = "l" alignment = ''.join(alignment) report_card = termtables.to_string( data, header=header, padding=(0, 1), alignment=alignment) return report_card def report_card(df, notes=None, report_each_source=True, decs=2): """ Given a Pandas dataframe, usually the output of ``aggregate_score_files``, returns a string that looks like this:: .. code-block:: none MEAN +/- STD OF METRICS ┌─────────┬──────────────────┬──────────────────┬──────────────────┐ │ METRIC │ OVERALL │ S1 │ S2 │ ╞═════════╪══════════════════╪══════════════════╪══════════════════╡ │ # │ 6000 │ 3000 │ 3000 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SDR │ 11.2 +/- 3.8 │ 12.5 +/- 3.5 │ 9.8 +/- 3.5 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SIR │ 22.7 +/- 5.4 │ 22.9 +/- 5.0 │ 22.6 +/- 5.7 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SAR │ 11.6 +/- 3.6 │ 13.0 +/- 3.3 │ 10.1 +/- 3.3 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SD-SDR │ 10.8 +/- 4.0 │ 12.2 +/- 3.8 │ 9.3 +/- 3.7 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SNR │ 11.6 +/- 3.3 │ 12.9 +/- 3.1 │ 10.3 +/- 3.0 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SRR │ 22.8 +/- 6.5 │ 25.5 +/- 6.3 │ 20.0 +/- 5.6 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SDRi │ 11.2 +/- 3.7 │ 10.0 +/- 3.4 │ 12.3 +/- 3.6 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SD-SDRi │ 10.8 +/- 3.8 │ 9.7 +/- 3.6 │ 11.8 +/- 3.7 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SNRi │ 11.6 +/- 3.3 │ 10.3 +/- 3.0 │ 12.9 +/- 3.1 │ └─────────┴──────────────────┴──────────────────┴──────────────────┘ MEDIAN OF METRICS ┌─────────┬──────────────────┬──────────────────┬──────────────────┐ │ METRIC │ OVERALL │ S1 │ S2 │ ╞═════════╪══════════════════╪══════════════════╪══════════════════╡ │ # │ 6000 │ 3000 │ 3000 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SDR │ 11.6 │ 13.1 │ 10.4 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SIR │ 23.6 │ 23.6 │ 23.6 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SAR │ 12.0 │ 13.5 │ 10.6 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SD-SDR │ 11.3 │ 12.9 │ 10.0 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SNR │ 11.9 │ 13.3 │ 10.7 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SRR │ 23.4 │ 26.5 │ 20.6 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SI-SDRi │ 11.6 │ 10.5 │ 12.9 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SD-SDRi │ 11.3 │ 10.3 │ 12.5 │ ├─────────┼──────────────────┼──────────────────┼──────────────────┤ │ SNRi │ 11.9 │ 10.7 │ 13.3 │ └─────────┴──────────────────┴──────────────────┴──────────────────┘ NOTES Uses scale-invariant BSSEval metrics. Evaluated on WSJ0-2Mix at 8000 Hz sample rate. Args: df (pandas.DataFrame): DataFrame containing the metrics computed during evaluation. notes (str, optional): Any additional notes you want to be printed at the bottom of the report card. Defaults to None. report_each_source (bool, optional): Whether or not to report the metrics for each individual source type. Defaults to True. Returns: str: A report card for your experiment. """ mean_report_card = _get_report_card( df, _get_mean_and_std, report_each_source=report_each_source, decs=decs) median_report_card = _get_report_card( df, _get_medians, report_each_source=report_each_source, decs=decs) line_break = mean_report_card.index('\n') report_card = ( f"{_format_title('', line_break)}\n" f"{_format_title(' MEAN +/- STD OF METRICS ', line_break)}\n" f"{_format_title('', line_break)}\n" f"{mean_report_card}\n" f"{_format_title('', line_break)}\n" f"{_format_title(' MEDIAN OF METRICS ', line_break)}\n" f"{_format_title('', line_break)}\n" f"{median_report_card}\n" ) if notes is not None: notes = '\n'.join(textwrap.wrap(notes, line_break)) report_card += ( f"{_format_title('', line_break)}\n" f"{_format_title(' NOTES ', line_break)}\n" f"{_format_title('', line_break)}\n" f"{notes}" ) return report_card