"""
The :class:`BinaryMask` class is for creating a time-frequency mask with binary values. Like all
:class:`separation.masks.mask_base.MaskBase` objects, :class:`BinaryMask` is initialized with a 2D or 3D numpy array
containing the mask data. The data type (numpy.dtype) of the initial mask can be either bool, int, or float.
The mask is stored as a 3-dimensional boolean-valued numpy array.
Ideally, the input mask array has a bool data type. If its data type is int, all values must be either 0 or 1.
If its data type is float, all values must be within 1e-2 of either 0 or 1. If the array does not meet one of
these conditions, :class:`BinaryMask` will raise an exception.
:class:`BinaryMask` (like :class:`separation.masks.soft_mask.SoftMask`) is one of the return types for the :func:`run()`
methods of :class:`separation.mask_separation_base.MaskSeparationBase`-derived objects (this covers most of the
separation methods in `nussl`).
See Also:
* :class:`separation.masks.mask_base.MaskBase`: The base class for BinaryMask and SoftMask
* :class:`separation.masks.soft_mask.SoftMask`: Similar to BinaryMask, but instead of taking boolean values,
takes floats in the range ``[0.0, 1.0]``.
* :class:`separation.mask_separation_base.MaskSeparationBase`: Base class for all mask-based separation methods
in `nussl`.
Examples:
Initializing a mask from a numpy array...
.. code-block:: python
:linenos:
import nussl
import numpy as np
# load a file
signal = nussl.AudioSignal('path/to/file.wav')
stft = signal.stft()
# Make a random binary mask with the same shape as the stft with dtype == bool
rand_bool_mask = np.random.randint(2, size=stft.shape).astype('bool')
bin_mask_bool = nussl.BinaryMask(rand_bool_mask)
# Make a random binary mask with the same shape as the stft with dtype == int
rand_int_mask = np.random.randint(2, size=stft.shape)
bin_mask_int = nussl.BinaryMask(rand_int_mask)
# Make a random binary mask with the same shape as the stft with dtype == float
rand_float_mask = np.random.randint(2, size=stft.shape).astype('float')
bin_mask_float = nussl.BinaryMask(rand_float_mask)
:class:`separation.mask_separation_base.MaskSeparationBase`-derived methods return
:class:`separation.masks.mask_base.MaskBase` masks, like so...
.. code-block:: python
:linenos:
import nussl
# load a file
signal = nussl.AudioSignal('path/to/file.wav')
repet = nussl.Repet(signal, mask_type=nussl.BinaryMask) # You have to specify that you want Binary Masks back
assert isinstance(repet, nussl.MaskSeparationBase) # Repet is a MaskSeparationBase-derived class
[background_mask, foreground_mask] = repet.run() # MaskSeparationBase-derived classes return MaskBase objects
assert isinstance(foreground_mask, nussl.BinaryMask) # this is True
assert isinstance(background_mask, nussl.BinaryMask) # this is True
"""
import numpy as np
from . import MaskBase
class BinaryMask(MaskBase):
    """
    Class for creating a Binary Mask to apply to a time-frequency representation of
    the audio.

    Args:
        input_mask (:obj:`np.ndarray`): 2- or 3-D :obj:`np.array` that represents the mask.
            Its dtype must be bool, integer (values 0 or 1 only), or float (every value
            within 1e-2 of 0 or 1).
        mask_shape: Forwarded unchanged to :class:`MaskBase`'s initializer together with
            ``input_mask``.
    """

    def __init__(self, input_mask=None, mask_shape=None):
        super(BinaryMask, self).__init__(input_mask, mask_shape)

    @staticmethod
    def _validate_mask(mask_):
        """
        Checks that ``mask_`` holds binary data and returns it as a boolean array.

        Args:
            mask_ (:obj:`np.ndarray`): Candidate mask data.

        Returns:
            :obj:`np.ndarray` with dtype bool. A bool input is returned as-is (not copied);
            integer and float inputs are converted to a new bool array.

        Raises:
            AssertionError: If ``mask_`` is not a numpy array.
            ValueError: If ``mask_`` has an integer dtype with values other than 0 or 1,
                a float dtype with values not within 1e-2 of 0 or 1, or any other dtype.
        """
        # Raise explicitly rather than using ``assert`` so the check is not stripped under
        # ``python -O``; the AssertionError type is kept for callers that already expect it.
        if not isinstance(mask_, np.ndarray):
            raise AssertionError('Mask must be a numpy array!')

        # Classify via the scalar-type hierarchy: ``np.bool`` no longer exists in
        # NumPy 1.24+, and comparing ``dtype.kind`` against ``np.typecodes`` misses
        # unsigned ints (kind 'u').
        if np.issubdtype(mask_.dtype, np.bool_):
            # Already boolean: nothing to check or convert.
            return mask_

        if np.issubdtype(mask_.dtype, np.integer):
            # Covers signed and unsigned ints; an empty array trivially passes.
            if not np.all(np.logical_or(mask_ == 0, mask_ == 1)):
                raise ValueError('Found values in mask that are not 0 or 1. Mask must be binary!')
            return mask_.astype('bool')

        if np.issubdtype(mask_.dtype, np.floating):
            tol = 1e-2
            near_zero = np.isclose(mask_, 0, atol=tol)
            near_one = np.isclose(mask_, 1, atol=tol)
            # If we have a float array, ensure that all values are close to 1 or 0
            if not np.all(np.logical_or(near_zero, near_one)):
                raise ValueError('All mask values must be close to 0 or 1!')
            # Every value is now near 0 or near 1, so "near 1" is the correct boolean.
            # ``astype('bool')`` would wrongly map small non-zero values such as 0.005 to True.
            return near_one

        # The documented contract only admits bool, int, or float masks.
        raise ValueError('Mask must have a bool, integer, or float dtype, not {}!'.format(mask_.dtype))

    def mask_as_ints(self, channel=None):
        """
        Returns this :class:`BinaryMask` as a numpy array of ints of 0's and 1's.

        Args:
            channel (optional): If given, only the data returned by ``self.get_channel(channel)``
                is converted; otherwise the whole ``self.mask`` is converted.

        Returns:
            numpy :obj:`ndarray` of this :obj:`BinaryMask` represented as ints instead of bools.
        """
        source = self.mask if channel is None else self.get_channel(channel)
        return source.astype(int)

    def invert_mask(self):
        """
        Makes a new :class:`BinaryMask` object with a logical not applied to flip the values in
        this :class:`BinaryMask` object. This object's mask is left unchanged.

        Returns:
            A new :class:`BinaryMask` object that has all of the boolean values flipped.
        """
        return BinaryMask(np.logical_not(self.mask))