# Source code for muda.base
#!/usr/bin/env python
'''Base module components.'''
import numpy as np
import copy
from collections import OrderedDict
import itertools
import six
import inspect
__all__ = ['BaseTransformer', 'Pipeline', 'Union']
class BaseTransformer(object):
    '''The base class for all transformation objects.

    This class implements a single transformation (history)
    and some various niceties.

    Subclasses provide `states(jam)`, which yields the state objects that
    drive `transform`, and may register per-namespace annotation deformers
    through `_register`.  Optional `audio(mudabox, state)` and
    `metadata(file_metadata, state)` methods are called when present.
    '''

    def __init__(self):
        # Maps a namespace query to the *name* of the method that deforms
        # matching annotations; insertion order is the application order.
        self.dispatch = OrderedDict()

    # This bit gleefully stolen from sklearn.base
    @classmethod
    def _get_param_names(cls):
        '''Get the list of parameter names for the object.

        Returns
        -------
        names : list of str
            Argument names of `cls.__init__`, without the leading instance
            argument, sorted alphabetically.

        Raises
        ------
        RuntimeError
            If `cls.__init__` accepts `*args`.
        '''
        init = cls.__init__

        # `inspect.getargspec` was removed in Python 3.11.  Prefer
        # `getfullargspec` and only fall back where it does not exist
        # (Python 2).  Both return (args, varargs, ...) as leading fields.
        argspec = getattr(inspect, 'getfullargspec', None)
        if argspec is None:
            argspec = inspect.getargspec

        args, varargs = argspec(init)[:2]

        if varargs is not None:
            raise RuntimeError('BaseTransformer objects cannot have varargs')

        # Drop the instance argument and return the rest in sorted order
        return sorted(args[1:])

    def get_params(self, deep=True):
        '''Get the parameters for this object.  Returns as a dict.

        Parameters
        ----------
        deep : bool
            Recurse on nested objects

        Returns
        -------
        params : dict
            A dictionary with two entries: `__class__` (this object's class)
            and `params` (a dict keyed by constructor argument name).
            Attributes that are missing on the instance are reported
            as `None`.
        '''
        out = dict(__class__=self.__class__,
                   params=dict())

        for key in self._get_param_names():
            value = getattr(self, key, None)

            if deep and hasattr(value, 'get_params'):
                # Nested parameterized object: store its class together
                # with whatever its own get_params() reports.
                deep_items = value.get_params().items()
                out['params'][key] = dict(__class__=value.__class__)
                out['params'][key].update((k, val) for k, val in deep_items)
            else:
                out['params'][key] = value

        return out

    def __repr__(self):
        '''Pretty-print this object'''
        class_name = self.__class__.__name__
        return '{:s}({:s})'.format(class_name,
                                   _pprint(self.get_params(deep=False)['params'],
                                           offset=len(class_name),),)

    def states(self, jam):
        '''Generate the states for transforming `jam`.

        Must be implemented by subclasses.
        '''
        raise NotImplementedError

    def _register(self, pattern, function):
        '''Register `function` as the deformer for namespace query `pattern`.

        Only the function's name is stored; it is looked up on `self`
        again when the transformation is applied.
        '''
        self.dispatch[pattern] = function.__name__

    def _transform(self, jam, state):
        '''Apply the transformation to audio and annotations.

        The input jam is copied and modified, and the modified copy
        is returned.  The input jam is left untouched.

        Parameters
        ----------
        jam : jams.JAMS
            A single jam object to modify

        state
            One of the objects produced by `self.states(jam)`; it is
            recorded in the history and passed to every deformer.

        Returns
        -------
        jam_w : jams.JAMS
            A deep copy of `jam` after transformation

        Raises
        ------
        RuntimeError
            If `jam.sandbox` has no `muda` attribute.

        See also
        --------
        core.load_jam_audio
        '''
        if not hasattr(jam.sandbox, 'muda'):
            raise RuntimeError('No muda state found in jams sandbox.')

        # We'll need a working copy of this object for modification purposes
        jam_w = copy.deepcopy(jam)

        # Push our reconstructor onto the history stack
        jam_w.sandbox.muda['history'].append({'transformer': self.__serialize__,
                                              'state': state})

        # Audio and metadata deformation are optional hooks
        if hasattr(self, 'audio'):
            self.audio(jam_w.sandbox.muda, state)

        if hasattr(self, 'metadata'):
            self.metadata(jam_w.file_metadata, state)

        # Walk over the list of deformers
        for query, function_name in self.dispatch.items():
            function = getattr(self, function_name)
            for matched_annotation in jam_w.search(namespace=query):
                function(matched_annotation, state)

        return jam_w

    def transform(self, jam):
        '''Iterative transformation generator

        Applies the deformation to an input jams object.

        This generates a sequence of deformed output JAMS,
        one for each state produced by `self.states(jam)`.

        Parameters
        ----------
        jam : jams.JAMS
            The jam to transform

        Examples
        --------
        >>> for jam_out in deformer.transform(jam_in):
        ...     process(jam_out)
        '''
        for state in self.states(jam):
            yield self._transform(jam, state)

    @property
    def __serialize__(self):
        '''Serializer

        Same as `get_params()`, except that the top-level `__class__`
        entry is replaced by the class name.
        '''
        data = self.get_params()
        data['__class__'] = data['__class__'].__name__
        return data
class Pipeline(object):
    '''Wrapper which allows multiple BaseDeformer objects to be chained together

    A given JAMS object will be transformed sequentially by
    each stage of the pipeline.

    The pipeline induces a graph over transformers

    Attributes
    ----------
    steps : argument array
        steps[i] is a tuple of `(name, Transformer)`

    Examples
    --------
    >>> P = muda.deformers.PitchShift(semitones=5)
    >>> T = muda.deformers.TimeStretch(speed=1.25)
    >>> Pipe = muda.Pipeline(steps=[('Pitch:maj3', P), ('Speed:1.25x', T)])
    >>> output_jams = list(Pipe.transform(jam_in))

    See Also
    --------
    Union
    '''

    def __init__(self, steps=None):
        '''Build a pipeline from `(name, transformer)` pairs.

        Parameters
        ----------
        steps : iterable of (name, BaseTransformer)
            The stages of the pipeline, in application order.

        Raises
        ------
        TypeError
            If `steps` is None, or any transformer is not
            a `BaseTransformer`.
        ValueError
            If `steps` is empty, or the names are not unique.
        '''
        # Fail with a readable message instead of an opaque zip/unpacking
        # error; the exception types match what these inputs raised before.
        if steps is None:
            raise TypeError('steps must be a sequence of '
                            '(name, transformer) pairs, not None')

        steps = list(steps)
        if not steps:
            raise ValueError('steps must contain at least one '
                             '(name, transformer) pair')

        names, transformers = zip(*steps)

        if len(set(names)) != len(steps):
            raise ValueError("Names provided are not unique: "
                             " {}".format(names,))

        for t in transformers:
            if not isinstance(t, BaseTransformer):
                # `{!r}` rather than `{:s}`: the latter fails to format
                # most non-string objects and would mask this error.
                raise TypeError('{!r} is not a BaseTransformer'.format(t))

        # shallow copy of steps
        self.steps = list(zip(names, transformers))

    def get_params(self):
        '''Get the parameters for this object.  Returns as a dict.

        Returns
        -------
        params : dict
            `__class__` holds this object's class; `params['steps']` is a
            list of `[name, step.get_params(deep=True)]` entries, one per
            pipeline stage, in order.
        '''
        out = {}
        out['__class__'] = self.__class__
        out['params'] = dict(steps=[])

        for name, step in self.steps:
            out['params']['steps'].append([name, step.get_params(deep=True)])

        return out

    def __repr__(self):
        '''Pretty-print the object'''
        class_name = self.__class__.__name__
        return '{:s}({:s})'.format(class_name,
                                   _pprint(self.get_params(),
                                           offset=len(class_name),),)

    def __recursive_transform(self, jam, steps):
        '''A recursive transformation pipeline

        Every output of the first step is fed through the remaining
        steps; with no steps left, `jam` itself is yielded.
        '''
        if len(steps) > 0:
            head_transformer = steps[0][1]
            for t_jam in head_transformer.transform(jam):
                for q in self.__recursive_transform(t_jam, steps[1:]):
                    yield q
        else:
            yield jam

    def transform(self, jam):
        '''Apply the sequence of transformations to a single jam object.

        Parameters
        ----------
        jam : jams.JAMS
            The jam object to transform

        Yields
        ------
        jam_out : jams.JAMS
            The jam objects produced by the transformation sequence
        '''
        for output in self.__recursive_transform(jam, self.steps):
            yield output
class Union(object):
    '''Wrapper which allows multiple BaseDeformer objects to be combined
    for round-robin sampling.

    A given JAMS object will be transformed sequentially by
    each element of the union, in round-robin fashion.

    This is similar to `Pipeline`, except the deformers are independent
    of one another in a Union, rather than applied sequentially.

    Attributes
    ----------
    steps : argument array
        steps[i] is a tuple of `(name, Transformer)`

    Examples
    --------
    >>> P = muda.deformers.PitchShift(semitones=5)
    >>> T = muda.deformers.TimeStretch(speed=1.25)
    >>> union = muda.Union(steps=[('Pitch:maj3', P), ('Speed:1.25x', T)])
    >>> output_jams = list(union.transform(jam_in))

    See Also
    --------
    Pipeline
    '''

    def __init__(self, steps=None):
        '''Build a union from `(name, transformer)` pairs.

        Parameters
        ----------
        steps : iterable of (name, BaseTransformer)
            The members of the union, in sampling order.

        Raises
        ------
        TypeError
            If `steps` is None, or any transformer is not
            a `BaseTransformer`.
        ValueError
            If `steps` is empty, or the names are not unique.
        '''
        # Fail with a readable message instead of an opaque zip/unpacking
        # error; the exception types match what these inputs raised before.
        if steps is None:
            raise TypeError('steps must be a sequence of '
                            '(name, transformer) pairs, not None')

        steps = list(steps)
        if not steps:
            raise ValueError('steps must contain at least one '
                             '(name, transformer) pair')

        names, transformers = zip(*steps)

        if len(set(names)) != len(steps):
            raise ValueError("Names provided are not unique: "
                             " {}".format(names,))

        for t in transformers:
            if not isinstance(t, BaseTransformer):
                # `{!r}` rather than `{:s}`: the latter fails to format
                # most non-string objects and would mask this error.
                raise TypeError('{!r} is not a BaseTransformer'.format(t))

        # shallow copy of steps
        self.steps = list(zip(names, transformers))

    def get_params(self):
        '''Get the parameters for this object.  Returns as a dict.

        Returns
        -------
        params : dict
            `__class__` holds this object's class; `params['steps']` is a
            list of `[name, step.get_params(deep=True)]` entries, one per
            member of the union, in order.
        '''
        out = {}
        out['__class__'] = self.__class__
        out['params'] = dict(steps=[])

        for name, step in self.steps:
            out['params']['steps'].append([name, step.get_params(deep=True)])

        return out

    def __repr__(self):
        '''Pretty-print the object'''
        class_name = self.__class__.__name__
        return '{:s}({:s})'.format(class_name,
                                   _pprint(self.get_params(),
                                           offset=len(class_name),),)

    def __serial_transform(self, jam, steps):
        '''A serial transformation union

        Takes one output from each member in turn; members that run out
        are dropped while the remaining ones keep their relative order.
        '''
        # The builtin next() works on Python 2.6+ and Python 3 alike,
        # so no version switch is needed to advance the iterators.
        active = [iter(D.transform(jam)) for (name, D) in steps]

        while active:
            survivors = []
            for stream in active:
                try:
                    next_jam = next(stream)
                except StopIteration:
                    # This member is exhausted: leave it out of later rounds
                    continue
                survivors.append(stream)
                yield next_jam
            active = survivors

    def transform(self, jam):
        '''Apply the sequence of transformations to a single jam object.

        Parameters
        ----------
        jam : jams.JAMS
            The jam object to transform

        Yields
        ------
        jam_out : jams.JAMS
            The jam objects produced by each member of the union
        '''
        for output in self.__serial_transform(jam, self.steps):
            yield output
###
# Borrowed from scikit-learn 0.18
def _pprint(params, offset=0, printer=repr):
    """Pretty print the dictionary 'params'

    Entries are rendered as `key=value`, sorted by key and joined with
    `', '`; a new line is started when a line would reach 75 characters
    or when an entry itself contains a newline.

    Parameters
    ----------
    params: dict
        The dictionary to pretty print

    offset: int
        The offset in characters to add at the begin of each line.

    printer:
        The function to convert entries to strings, typically
        the builtin str or repr

    Returns
    -------
    lines : str
        The formatted parameter string, without trailing spaces on
        any line.
    """
    # Do a multi-line justified repr:
    options = np.get_printoptions()
    np.set_printoptions(precision=5, threshold=64, edgeitems=2)

    # The print options are process-global: restore them even when
    # `printer` raises, so a failing repr cannot leak our settings.
    try:
        params_list = list()
        this_line_length = offset
        line_sep = ',\n' + (1 + offset // 2) * ' '

        for i, (k, v) in enumerate(sorted(params.items())):
            if type(v) is float:
                # use str for representing floating point numbers
                # this way we get consistent representation across
                # architectures and versions.
                this_repr = '%s=%s' % (k, str(v))
            else:
                # use repr of the rest
                this_repr = '%s=%s' % (k, printer(v))

            # Abbreviate very long entries, keeping both ends
            if len(this_repr) > 500:
                this_repr = this_repr[:300] + '...' + this_repr[-100:]

            if i > 0:
                if (this_line_length + len(this_repr) >= 75 or
                        '\n' in this_repr):
                    params_list.append(line_sep)
                    this_line_length = len(line_sep)
                else:
                    params_list.append(', ')
                    this_line_length += 2

            params_list.append(this_repr)
            this_line_length += len(this_repr)
    finally:
        np.set_printoptions(**options)

    lines = ''.join(params_list)
    # Strip trailing space to avoid nightmare in doctests
    lines = '\n'.join(l.rstrip(' ') for l in lines.split('\n'))
    return lines