class SeparationModel(nn.Module):
"""
SeparationModel takes a configuration file or dictionary that describes the model
structure, which is some combination of MelProjection, Embedding, RecurrentStack,
ConvolutionalStack, and other modules found in ``nussl.ml.networks.modules``.
References:
Hershey, J. R., Chen, Z., Le Roux, J., & Watanabe, S. (2016, March).
Deep clustering: Discriminative embeddings for segmentation and separation.
In Acoustics, Speech and Signal Processing (ICASSP),
2016 IEEE International Conference on (pp. 31-35). IEEE.
Luo, Y., Chen, Z., Hershey, J. R., Le Roux, J., & Mesgarani, N. (2017, March).
Deep clustering and conventional networks for music separation: Stronger together.
In Acoustics, Speech and Signal Processing (ICASSP),
2017 IEEE International Conference on (pp. 61-65). IEEE.
Args:
config: (str, dict) Either a config dictionary that defines the model and its
connections, or the path to a json file containing the dictionary. If the
latter, the path will be loaded and used.
See also:
ml.register_module to register your custom modules with SeparationModel.
Examples:
>>> config = nussl.ml.networks.builders.build_recurrent_dpcl(
>>> num_features=512, hidden_size=300, num_layers=3, bidirectional=True,
>>> dropout=0.3, embedding_size=20,
>>> embedding_activation=['sigmoid', 'unit_norm'])
>>>
>>> model = SeparationModel(config)
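
        A config can also be written by hand as a plain dictionary with
        ``modules``, ``connections``, and ``output`` keys. The sketch below is
        purely illustrative (the module names, arguments, and input key are
        assumptions, not a recommended architecture); note that a module
        without a ``class`` entry is treated as an input to the model:

        >>> config = {
        >>>     'modules': {
        >>>         'mix_magnitude': {},
        >>>         'log_spectrogram': {'class': 'AmplitudeToDB'},
        >>>         'recurrent_stack': {
        >>>             'class': 'RecurrentStack',
        >>>             'args': {'num_features': 512, 'hidden_size': 300,
        >>>                      'num_layers': 2, 'bidirectional': True,
        >>>                      'dropout': 0.3}},
        >>>         'embedding': {
        >>>             'class': 'Embedding',
        >>>             'args': {'num_features': 512, 'hidden_size': 600,
        >>>                      'embedding_size': 20,
        >>>                      'activation': ['sigmoid', 'unit_norm']}},
        >>>     },
        >>>     'connections': [
        >>>         ['log_spectrogram', ['mix_magnitude']],
        >>>         ['recurrent_stack', ['log_spectrogram']],
        >>>         ['embedding', ['recurrent_stack']],
        >>>     ],
        >>>     'output': ['embedding'],
        >>> }
        >>> model = SeparationModel(config)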
"""

    def __init__(self, config, verbose=False):
        super(SeparationModel, self).__init__()
        # Accept a config dict, a path to a JSON file, or a JSON string.
        if isinstance(config, str):
            if os.path.exists(config):
                with open(config, 'r') as f:
                    config = json.load(f)
            else:
                config = json.loads(config)

        self._validate_config(config)

        module_dict = {}
        self.input = {}
        for module_key in config['modules']:
            module = config['modules'][module_key]
            if 'class' in module:
                # Instantiate the module, looking its class up in
                # nussl.ml.networks.modules first and falling back to torch.nn.
                if module['class'] in dir(modules):
                    class_func = getattr(modules, module['class'])
                else:
                    class_func = getattr(nn, module['class'])
                if 'args' not in module:
                    module['args'] = {}
                module_dict[module_key] = class_func(**module['args'])
            else:
                # Modules without a 'class' entry are inputs to the model.
                self.input[module_key] = module_key

        self.layers = nn.ModuleDict(module_dict)
        self.connections = config['connections']
        self.output_keys = config['output']
        self.config = config
        self.verbose = verbose

    @staticmethod
    def _validate_config(config):
        """Checks that the config has exactly the keys 'connections',
        'modules', and 'output', with the expected types."""
        expected_keys = ['connections', 'modules', 'output']
        got_keys = sorted(list(config.keys()))

        if got_keys != expected_keys:
            raise ValueError(
                f"Expected keys {expected_keys}, got {got_keys}")

        if not isinstance(config['modules'], dict):
            raise ValueError("config['modules'] must be a dict!")

        if not isinstance(config['connections'], list):
            raise ValueError("config['connections'] must be a list!")

        if not isinstance(config['output'], list):
            raise ValueError("config['output'] must be a list!")

    def forward(self, data):
        """
        Runs the data through each connection in ``self.connections``, in
        order, and collects the outputs named in ``self.output_keys``.

        Args:
            data: (dict) a dictionary containing the input data for the model.
                Its keys should match the keys in ``self.input``.

        Returns:
            (dict) dictionary mapping each key in ``self.output_keys`` to the
            corresponding output produced by the network.
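
        Example:
            A minimal sketch of a forward pass, assuming the model was built
            with a single input module named ``'mix_magnitude'`` and an output
            key ``'embedding'`` (both names are illustrative):

            >>> data = {'mix_magnitude': mix_magnitude}  # a torch.Tensor
            >>> output = model(data)
            >>> embedding = output['embedding']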
"""
        if not all(name in list(data) for name in list(self.input)):
            raise ValueError(
                f'Not all keys present in data! Needs {", ".join(self.input)}')

        output = {}

        for connection in self.connections:
            layer = self.layers[connection[0]]
            input_data = []
            kwargs = {}

            if len(connection) == 2:
                for c in connection[1]:
                    if isinstance(c, dict):
                        # Dictionary entries map keyword argument names to a
                        # previously computed output, an entry in the input
                        # data, or a literal value.
                        for key, val in c.items():
                            if val in output:
                                kwargs[key] = output[val]
                            elif val in data:
                                kwargs[key] = data[val]
                            else:
                                kwargs[key] = val
                    else:
                        # String entries are passed positionally, taken from a
                        # previous output if available, otherwise from the data.
                        input_data.append(output[c] if c in output else data[c])

            _output = layer(*input_data, **kwargs)

            # Flatten the layer's output into the running output dictionary,
            # naming entries 'module_key:key' for dicts, 'module_key:i' for
            # tuples, and 'module_key' otherwise.
            added_keys = []
            if isinstance(_output, dict):
                for k in _output:
                    _key = f'{connection[0]}:{k}'
                    output[_key] = _output[k]
                    added_keys.append(_key)
            elif isinstance(_output, tuple):
                for i, val in enumerate(_output):
                    _key = f'{connection[0]}:{i}'
                    output[_key] = val
                    added_keys.append(_key)
            else:
                _key = connection[0]
                output[_key] = _output
                added_keys.append(_key)

            if self.verbose:
                input_shapes = []
                for d in input_data:
                    if torch.is_tensor(d):
                        input_shapes.append(d.shape)
                input_desc = ", ".join(map(str, input_shapes))
                output_desc = ", ".join(
                    [f'{k}: {output[k].shape}' for k in added_keys])
                print(
                    f"{connection[1]} -> {connection[0]} \n"
                    f"\tTook inputs: {input_desc} \n"
                    f"\tProduced {output_desc} "
                )

        return {o: output[o] for o in self.output_keys}

    def save(self, location, metadata=None):
        """
        Saves a SeparationModel to a location as a dictionary containing the
        weights and the model configuration.

        Args:
            location: (str) Where you want the model saved, as a path.
            metadata: (dict) Any additional data to store alongside the weights
                and config (e.g. training settings). Optional.

        Returns:
            (str): where the model was saved.
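
        Example:
            A saved model can be restored by hand (a sketch; the path is
            illustrative). The config is stored as a JSON string, which
            ``SeparationModel`` accepts directly:

            >>> checkpoint = torch.load('checkpoints/model.pth')
            >>> model = SeparationModel(checkpoint['config'])
            >>> model.load_state_dict(checkpoint['state_dict'])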
"""
        save_dict = {
            'state_dict': self.state_dict(),
            'config': json.dumps(self.config)
        }

        metadata = metadata if metadata else {}
        metadata['nussl_version'] = __version__
        save_dict = {**save_dict, **metadata}

        torch.save(save_dict, location)
        return location

    def __repr__(self):
        output = super().__repr__()
        # Count only trainable parameters.
        num_parameters = 0
        for p in self.parameters():
            if p.requires_grad:
                num_parameters += p.numel()
        output += '\nNumber of parameters: %d' % num_parameters
        return output