Source code for nussl.core.efz_utils

"""
The *nussl* External File Zoo (EFZ) is a server that houses all files that are too large to
bundle with *nussl* when distributing it through ``pip`` or GitHub. These types of files include
audio examples, benchmark files for tests, and trained neural network models.

*nussl* has built-in utilities for accessing the EFZ through its API. The functions in this
module let you see which files are available on the EFZ and download the ones you need.
"""

import warnings
import json
import os
import sys
import hashlib

from six.moves.urllib_parse import urljoin
from six.moves.urllib.error import HTTPError
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen, Request
from six.moves.urllib.request import urlretrieve

from . import constants


def get_available_audio_files():
    """
    Fetches the metadata of every audio file hosted on the nussl External File Zoo (EFZ)
    server (http://nussl.ci.northwestern.edu/).

    Every element of the returned list is a dict shaped like this:

    .. code-block:: python

        {
            u'file_length_seconds': 5.00390022675737,
            u'visible': True,
            u'file_name': u'K0140.wav',
            u'date_modified': u'2018-06-01',
            u'file_hash': u'f0d8d3c8d199d3790b0e42d1e5df50a6801f928d10f533149ed0babe61b5d7b5',
            u'file_size_bytes': 441388,
            u'file_description': u'Acoustic piano playing middle C.',
            u'audio_attributes': u'piano, middle C',
            u'file_size': u'431.0KiB',
            u'date_added': u'2018-06-01'
        }

    See Also:
        * :func:`print_available_audio_files`, prints a list of the audio files to the console.
        * :func:`download_audio_file` to download an audio file from the EFZ.

    Returns:
        (list): Dicts holding the metadata of each audio file available on the nussl External
        File Zoo (EFZ) server (http://nussl.ci.northwestern.edu/).

    Raises:
        NoConnectivityError: If the metadata cannot be retrieved from the EFZ server.
    """
    # No try block here: _download_all_metadata() raises its own errors.
    audio_metadata_url = constants.NUSSL_EFZ_AUDIO_METADATA_URL
    return _download_all_metadata(audio_metadata_url)
def get_available_trained_models():
    """
    Returns a list of dicts containing metadata of the available trained models on the nussl
    External File Zoo (EFZ) server (http://nussl.ci.northwestern.edu/).

    Each entry in the list is in the following format:

    .. code-block:: python

        {
            u'for_class': u'DeepClustering',
            u'visible': True,
            u'file_name': u'deep_clustering_vocals_44k_long.model',
            u'date_modified': u'2018-06-01',
            u'file_hash': u'e09034c2cb43a293ece0b121f113b8e4e1c5a247331c71f40cb9ca38227ccc2c',
            u'file_size_bytes': 94543355,
            u'file_description': u'Deep clustering for vocal separation trained on augmented DSD100.',
            u'file_size': u'90.2MiB',
            u'date_added': u'2018-06-01'
        }

    Notes:
        Most of the entries in the dictionary are self-explanatory, but note the ``for_class``
        entry. The ``for_class`` entry specifies which `nussl` separation class the given model
        will work with. Usually, `nussl` separation classes that require a model will default to
        retrieving a model from the EFZ server (if one is not already on the user's machine),
        but sometimes it is desirable to use a model other than the default one provided. In
        this case, the ``for_class`` entry tells the user which class the model is valid for.
        Additionally, trying to load a model into a class that it is not explicitly labeled for
        will raise an exception, so don't do it.

    See Also:
        * :func:`print_available_trained_models`, prints a list of the trained models to
          the console.
        * :func:`download_trained_model` to download a trained model from the EFZ.

    Returns:
        (list): A list of dicts containing metadata of the available trained models on the
        nussl External File Zoo (EFZ) server (http://nussl.ci.northwestern.edu/).

    Raises:
        NoConnectivityError: If the metadata cannot be retrieved from the EFZ server.
    """
    # _download_all_metadata() raises its own errors, so no try block around it
    return _download_all_metadata(constants.NUSSL_EFZ_MODEL_METADATA_URL)
def get_available_benchmark_files():
    """
    Returns a list of dicts containing metadata of the available benchmark files for tests on
    the nussl External File Zoo (EFZ) server (http://nussl.ci.northwestern.edu/).

    Each entry in the list is in the following format:

    .. code-block:: python

        {
            u'for_class': u'DuetUnitTests',
            u'visible': True,
            u'file_name': u'benchmark_atn_bins.npy',
            u'date_modified': u'2018-06-19',
            u'file_hash': u'cf7fef6f4ea9af3dbde8b9880602eeaf72507b6c78f04097c5e79d34404a8a1f',
            u'file_size_bytes': 488,
            u'file_description': u'Attenuation bins numpy array for DUET benchmark test.',
            u'file_size': u'488.0B',
            u'date_added': u'2018-06-19'
        }

    Notes:
        Most of the entries in the dictionary are self-explanatory, but note the ``for_class``
        entry. The ``for_class`` entry specifies which `nussl` benchmark class will load the
        corresponding benchmark file. Make sure these match exactly when writing tests!

    See Also:
        * :func:`print_available_benchmark_files`, prints a list of the benchmark files to
          the console.
        * :func:`download_benchmark_file` to download a benchmark file from the EFZ.

    Returns:
        (list): A list of dicts containing metadata of the available benchmark files on the
        nussl External File Zoo (EFZ) server (http://nussl.ci.northwestern.edu/).

    Raises:
        NoConnectivityError: If the metadata cannot be retrieved from the EFZ server.
    """
    # _download_all_metadata() raises its own errors, so no try block around it
    return _download_all_metadata(constants.NUSSL_EFZ_BENCHMARK_METADATA_URL)
def _download_all_metadata(url):
    """
    Downloads the json file that contains all of the metadata for a specific file type (read:
    audio files, benchmark files, or trained models) that is on the EFZ server. This is
    retrieved from one of following three URLs (which are stored in nussl.constants):
    NUSSL_EFZ_AUDIO_METADATA_URL, NUSSL_EFZ_BENCHMARK_METADATA_URL, or
    NUSSL_EFZ_MODEL_METADATA_URL.

    Args:
        url (str): URL for the EFZ server that has metadata. One of these three:
            NUSSL_EFZ_AUDIO_METADATA_URL, NUSSL_EFZ_BENCHMARK_METADATA_URL, or
            NUSSL_EFZ_MODEL_METADATA_URL.

    Returns:
        (list): List of dicts with metadata for the desired file type.

    Raises:
        NoConnectivityError: If the request fails or its body cannot be decoded as JSON. The
            underlying exception is attached as ``__cause__``. KeyboardInterrupt and
            SystemExit are not intercepted.
    """
    request = Request(url)

    # Make sure to get the newest data
    request.add_header('Pragma', 'no-cache')
    request.add_header('Cache-Control', 'max-age=0')

    try:
        # ``with`` guarantees the HTTP response/socket is closed.
        with urlopen(request) as response:
            return json.loads(response.read())
    except Exception as err:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        raise NoConnectivityError("Can't connect to internet") from err


def _download_metadata_for_file(file_name, file_type):
    """
    Downloads the metadata entry for a specific file (:param:`file_name`) on the EFZ server.

    Args:
        file_name (str): File name as specified on the EFZ server.
        file_type (str): 'Type' of file, either 'audio', 'model', or 'benchmark'.

    Returns:
        (dict) Metadata entry for the specified file.

    Raises:
        MetadataError: If :param:`file_type` is not one of the known types, or if no entry with
            a matching ``file_name`` exists in the metadata for that type.
        NoConnectivityError: If the metadata cannot be retrieved from the EFZ server.
    """
    metadata_urls = {
        'audio': constants.NUSSL_EFZ_AUDIO_METADATA_URL,
        'benchmark': constants.NUSSL_EFZ_BENCHMARK_METADATA_URL,
        'model': constants.NUSSL_EFZ_MODEL_METADATA_URL,
    }

    try:
        metadata_url = metadata_urls[file_type]
    except KeyError:
        raise MetadataError(f'Cannot find metadata of type {file_type}.') from None

    metadata = _download_all_metadata(metadata_url)

    for file_metadata in metadata:
        if file_metadata.get('file_name') == file_name:
            return file_metadata

    # Report the URL that was actually searched (not always the audio one).
    raise MetadataError(
        f'No matching metadata for file {file_name}'
        f' at url {metadata_url}!'
    )
def download_audio_file(audio_file_name, local_folder=None, verbose=True):
    """
    Downloads the specified audio file from the `nussl` External File Zoo (EFZ) server.

    The downloaded file is stored in :param:`local_folder` if a folder is provided (the folder
    is created if it does not exist). If no folder is provided, the file is stored in
    ``~/.nussl/audio/`` (with ``~`` expanded).

    If the requested file already exists at that location and its SHA-256 hash matches the
    precomputed hash from the EFZ server metadata, the file is not downloaded again.

    Args:
        audio_file_name: (str) Name of the audio file to attempt to download.
        local_folder: (str) Path to local folder in which to download the file.
            If no folder is provided, `nussl` will store the file in ``~/.nussl/audio/``.
        verbose (bool): If ``True`` prints the status of the download to the console.

    Returns:
        (String) Full path to the requested file (whether downloaded or not).

    Raises:
        NoConnectivityError: If the EFZ metadata cannot be retrieved.
        MetadataError: If the EFZ has no metadata entry for :param:`audio_file_name`.
        FailedDownloadError: If the file itself cannot be fetched.
        MismatchedHashError: If the downloaded file's hash does not match the EFZ metadata.

    Example:
        >>> import nussl
        >>> piano_path = nussl.efz_utils.download_audio_file('K0140.wav')
        >>> piano_signal = nussl.AudioSignal(piano_path)
    """
    file_metadata = _download_metadata_for_file(audio_file_name, 'audio')
    file_url = urljoin(constants.NUSSL_EFZ_AUDIO_URL, audio_file_name)

    return _download_file(audio_file_name, file_url, local_folder, 'audio',
                          file_hash=file_metadata['file_hash'], verbose=verbose)
def download_trained_model(model_name, local_folder=None, verbose=True):
    """
    Downloads the specified trained model from the `nussl` External File Zoo (EFZ) server.

    The downloaded file is stored in :param:`local_folder` if a folder is provided (the folder
    is created if it does not exist). If no folder is provided, the file is stored in
    ``~/.nussl/models/`` (with ``~`` expanded).

    If the requested file already exists at that location and its SHA-256 hash matches the
    precomputed hash from the EFZ server metadata, the file is not downloaded again.

    Args:
        model_name: (str) Name of the trained model to attempt to download.
        local_folder: (str) Path to local folder in which to download the file.
            If no folder is provided, `nussl` will store the file in ``~/.nussl/models/``.
        verbose (bool): If ``True`` prints the status of the download to the console.

    Returns:
        (String) Full path to the requested file (whether downloaded or not).

    Raises:
        NoConnectivityError: If the EFZ metadata cannot be retrieved.
        MetadataError: If the EFZ has no metadata entry for :param:`model_name`.
        FailedDownloadError: If the file itself cannot be fetched.
        MismatchedHashError: If the downloaded file's hash does not match the EFZ metadata.

    Example:
        >>> import nussl
        >>> model_path = nussl.efz_utils.download_trained_model('deep_clustering_model.h5')
        >>> signal = nussl.AudioSignal()
        >>> separator = nussl.DeepClustering(signal, model_path=model_path)
    """
    file_metadata = _download_metadata_for_file(model_name, 'model')
    file_url = urljoin(constants.NUSSL_EFZ_MODELS_URL, model_name)

    return _download_file(model_name, file_url, local_folder, 'models',
                          file_hash=file_metadata['file_hash'], verbose=verbose)
def download_benchmark_file(benchmark_name, local_folder=None, verbose=True):
    """
    Downloads the specified benchmark file from the `nussl` External File Zoo (EFZ) server.

    The downloaded file is stored in :param:`local_folder` if a folder is provided (the folder
    is created if it does not exist). If no folder is provided, the file is stored in
    ``~/.nussl/benchmarks/`` (with ``~`` expanded).

    If the requested file already exists at that location and its SHA-256 hash matches the
    precomputed hash from the EFZ server metadata, the file is not downloaded again.

    Args:
        benchmark_name: (str) Name of the benchmark file to attempt to download.
        local_folder: (str) Path to local folder in which to download the file.
            If no folder is provided, `nussl` will store the file in ``~/.nussl/benchmarks/``.
        verbose (bool): If ``True`` prints the status of the download to the console.

    Returns:
        (String) Full path to the requested file (whether downloaded or not).

    Raises:
        NoConnectivityError: If the EFZ metadata cannot be retrieved.
        MetadataError: If the EFZ has no metadata entry for :param:`benchmark_name`.
        FailedDownloadError: If the file itself cannot be fetched.
        MismatchedHashError: If the downloaded file's hash does not match the EFZ metadata.

    Example:
        >>> import nussl
        >>> import numpy as np
        >>> sym_atn_path = nussl.efz_utils.download_benchmark_file('benchmark_sym_atn.npy')
        >>> sym_atn = np.load(sym_atn_path)
    """
    file_metadata = _download_metadata_for_file(benchmark_name, 'benchmark')
    file_url = urljoin(constants.NUSSL_EFZ_BENCHMARKS_URL, benchmark_name)

    return _download_file(benchmark_name, file_url, local_folder, 'benchmarks',
                          file_hash=file_metadata['file_hash'], verbose=verbose)
def _download_file(file_name, url, local_folder, cache_subdir, file_hash=None, cache_dir=None, verbose=True): """ Downloads the specified file from the Heavily inspired by and lovingly adapted from keras' `get_file` function: https://github.com/fchollet/keras/blob/afbd5d34a3bdbb0916d558f96af197af1e92ce70/keras/utils/data_utils.py#L109 Args: file_name: (String) name of the file located on the server url: (String) url of the file local_folder: (String) alternate folder in which to download the file cache_subdir: (String) subdirectory of folder in which to download flie file_hash: (String) expected hash of downloaded file cache_dir: Returns: (String) local path to downloaded file """ if local_folder not in [None, '']: # local folder provided, let's create it if it doesn't exist and use it as datadir os.makedirs(os.path.expanduser(local_folder), exist_ok=True) datadir = os.path.expanduser(local_folder) else: if cache_dir is None: cache_dir = os.path.expanduser(os.path.join('~', '.nussl')) datadir_base = os.path.expanduser(cache_dir) datadir = os.path.join(datadir_base, cache_subdir) os.makedirs(datadir, exist_ok=True) file_path = os.path.join(datadir, file_name) download = False if os.path.exists(file_path): if file_hash is not None: # compare the provided hash with the hash of the file currently at file_path current_hash = _hash_file(file_path) # if the hashes are equal, we already have the file we need, so don't download if file_hash != current_hash: if verbose: warnings.warn( f'Hash for {file_path} does not match known hash. ' f' Downloading {file_name} from servers...' 
) download = True elif verbose: print(f'Matching file found at {file_path}, skipping download.') else: download = True else: download = True if download: if verbose: print(f'Saving file at {file_path}\nDownloading {file_name} from {url}') def _dl_progress(count, block_size, total_size): percent = int(count * block_size * 100 / total_size) if percent <= 100: sys.stdout.write(f'\r{file_name}...{percent}%') sys.stdout.flush() try: try: reporthook = _dl_progress if verbose else None urlretrieve(url, file_path, reporthook) if verbose: print() # print a new line after the progress is done. except HTTPError as e: raise FailedDownloadError(f'URL fetch failure on {url}: {e.code} -- {e.msg}') except URLError as e: raise FailedDownloadError(f'URL fetch failure on {url}: {e.errno} -- {e.reason}') except (Exception, KeyboardInterrupt) as e: if os.path.exists(file_path): os.remove(file_path) raise e # check hash of received file to see if it matches the provided hash if file_hash is not None: download_hash = _hash_file(file_path) if file_hash != download_hash: # the downloaded file is not what it should be. Get rid of it. os.remove(file_path) raise MismatchedHashError( f'Deleted downloaded file ({file_path}) because of a hash mismatch.' ) return file_path else: return file_path def _hash_directory(directory, ext=None): """ Calculates the hash of every child file in the given directory using python's built-in SHA256 function (using `os.walk()`, which also searches subdirectories recursively). If :param:`ext` is specified, this will only look at files with extension provided. This function is used to verify the integrity of data sets for use with nussl. Pretty much just makes sure that when we loop through/look at a directory, we understand the structure because the organization of the data set directories for different data sets are all unique and thus need to be hard coded by each generator function (below). If we get a hash mismatch we can throw an error easily. 
Args: directory (str): Directory within which file hashes get calculated. Searches recursively. ext (str): If provided, this function will only calculate the hash on files with the given extension. Returns: (str): String containing only hexadecimal digits of the has of the contents of the given directory. """ hash_list = [] for path, sub_dirs, files in os.walk(directory): if ext is None: hash_list.extend([_hash_file(os.path.join(path, f)) for f in files if os.path.isfile(os.path.join(path, f))]) else: hash_list.extend([_hash_file(os.path.join(path, f)) for f in files if os.path.isfile(os.path.join(path, f)) if os.path.splitext(f)[1] == ext]) hasher = hashlib.sha256() for hash_val in sorted(hash_list): # Sort this list so we're platform agnostic hasher.update(hash_val.encode('utf-8')) return hasher.hexdigest() def _hash_file(file_path, chunk_size=65535): """ Args: file_path: System path to the file to be hashed chunk_size: size of chunks Returns: file_hash: the SHA256 hashed string in hex """ hasher = hashlib.sha256() with open(file_path, 'rb') as fpath_file: for chunk in iter(lambda: fpath_file.read(chunk_size), b''): hasher.update(chunk) return hasher.hexdigest() ######################################## # Error Classes ########################################
class NoConnectivityError(Exception):
    """
    Raised when the EFZ server cannot be reached, i.e. there is no usable internet connection.
    """
class FailedDownloadError(Exception):
    """
    Raised when fetching a file from its URL fails.
    """
class MismatchedHashError(Exception):
    """
    Exception class for when a computed hash does not match a pre-computed hash.
    """
class MetadataError(Exception):
    """
    Raised for problems with EFZ metadata, such as an unknown file type or a missing entry.
    """