Writing a new Distribution

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

To write and use a new Distribution class in MXFusion, fill out the Distribution interface and either the Univariate or Multivariate interface, depending on the type of distribution you are creating.

There are four primary methods to fill out for a Distribution in MXFusion:

* __init__ - The constructor for the Distribution. It takes in any parameters the distribution needs, and it defines names for the input variable(s) the distribution takes and the output variable(s) it produces.
* log_pdf - Returns the logarithm of the probability density function of the distribution. This is called during inference as needed by the inference algorithm.
* draw_samples - Returns samples drawn from the distribution. This is called during inference as needed by the inference algorithm.
* define_variable - Generates random variables drawn from the Distribution, used during model definition.

log_pdf and draw_samples are implemented with MXNet functions that compute on the input variables, which at inference time are MXNet arrays or MXNet symbolic variables.
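The overall shape of the interface can be sketched as follows. This is a hypothetical outline only: `FancyNew` and its `rate` parameter are made-up names, and a trivial stub stands in for MXFusion's `UnivariateDistribution` base class so the sketch is self-contained (the real base class does much more).

```python
# Hypothetical stub standing in for
# mxfusion.components.distributions.UnivariateDistribution, so this
# outline runs on its own; the real base class handles inputs/outputs,
# dtype, context, and the random generator.
class UnivariateDistribution:
    def __init__(self, **kwargs):
        pass


class FancyNew(UnivariateDistribution):
    def __init__(self, rate, rand_gen=None, dtype=None, ctx=None):
        # Wrap parameters as Variables, name the input/output variables,
        # and pass everything to the super constructor.
        super(FancyNew, self).__init__()
        self.rate = rate

    def log_pdf(self, rate, random_variable, F=None):
        # Return the log density of random_variable, using MXNet ops via F.
        raise NotImplementedError

    def draw_samples(self, rate, rv_shape, num_samples=1, F=None):
        # Return num_samples draws, each of shape rv_shape, using MXNet ops.
        raise NotImplementedError

    @staticmethod
    def define_variable(rate, shape=None, rand_gen=None, dtype=None, ctx=None):
        # Build a FancyNew distribution and return its output random variable.
        raise NotImplementedError
```

The rest of this notebook fills in each of these methods for the Normal distribution.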

This notebook will take the Normal distribution as a reference.

File Structure

Code for distributions lives in the mxfusion/components/distributions directory.

If you’re implementing the FancyNew distribution then you should create a file called mxfusion/components/distributions/fancy_new.py for the class to live in.

Interface Implementation

Since this example is for a univariate Normal distribution, our class extends the UnivariateDistribution class.

The Normal distribution’s constructor takes in objects for its mean and variance, specifications for data type and context, and a random number generator if not the default.

In addition, a distribution can take additional parameters that are used in its computation but are not inputs. We refer to these additional parameters as the Distribution’s attributes. The difference between an input and an attribute is primarily that inputs are dynamic at inference time, while attributes stay static throughout a given inference run.

In this case, minibatch_ratio is a static attribute, as it doesn’t change for a given minibatch size during inference.

The mean and variance can be either Variables or MXNet arrays if they are constants.

As mentioned above, you define names here for the distribution’s input and output variable(s). These names are used when printing and generally inspecting the model, so give them meaningful names. We prefer names like mean and variance to ones like location and scale, or Greek letters like mu and sigma.

[ ]:
class Normal(UnivariateDistribution):
    """
    The one-dimensional normal distribution. The normal distribution can be defined
    over a scalar random variable or an array of random variables. In case of an
    array of random variables, the mean and variance are broadcast to the shape of
    the output random variable (array).

    :param mean: Mean of the normal distribution.
    :type mean: Variable
    :param variance: Variance of the normal distribution.
    :type variance: Variable
    :param rand_gen: the random generator (default: MXNetRandomGenerator)
    :type rand_gen: RandomGenerator
    :param dtype: the data type for floating point numbers
    :type dtype: numpy.float32 or numpy.float64
    :param ctx: the mxnet context (default: None/current context)
    :type ctx: None or mxnet.cpu or mxnet.gpu
    """
    def __init__(self, mean, variance, rand_gen=None, minibatch_ratio=1.,
                 dtype=None, ctx=None):
        self.minibatch_ratio = minibatch_ratio
        if not isinstance(mean, Variable):
            mean = Variable(value=mean)
        if not isinstance(variance, Variable):
            variance = Variable(value=variance)

        inputs = [('mean', mean), ('variance', variance)]
        input_names = ['mean', 'variance']
        output_names = ['random_variable']
        super(Normal, self).__init__(inputs=inputs, outputs=None,
                                     input_names=input_names,
                                     output_names=output_names,
                                     rand_gen=rand_gen, dtype=dtype, ctx=ctx)

If your distribution’s __init__ function only takes in parameters that get passed onto its super constructor, you don’t need to implement replicate_self. If it does take additional parameters (as the Normal distribution does for minibatch_ratio), those parameters need to be copied over to the replicant Distribution before returning, as below.

[ ]:
    def replicate_self(self, attribute_map=None):
        """
        Replicates this Factor, using new inputs, outputs, and a new uuid.
        Used during model replication to functionally replicate a factor into a new graph.

        :param attribute_map: a mapping from original attribute variables to their
            replacements in the new graph (default: None)
        :type attribute_map: dict or None
        replicant = super(Normal, self).replicate_self(attribute_map)
        replicant.minibatch_ratio = self.minibatch_ratio
        return replicant

log_pdf and draw_samples are relatively straightforward to implement. These are the meaningful parts of the Distribution you’re implementing: putting the math into code, using MXNet operators for the computation.

If the distribution isn’t well documented on Wikipedia, please add a link to a paper or other resource that explains what it does and why.

[ ]:
    def log_pdf(self, mean, variance, random_variable, F=None):
        """
        Computes the logarithm of the probability density function (PDF) of the normal distribution.

        :param mean: the mean of the normal distribution
        :type mean: MXNet NDArray or MXNet Symbol
        :param variance: the variance of the normal distribution
        :type variance: MXNet NDArray or MXNet Symbol
        :param random_variable: the random variable of the normal distribution
        :type random_variable: MXNet NDArray or MXNet Symbol
        :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray)
        :returns: log pdf of the distribution
        :rtype: MXNet NDArray or MXNet Symbol
        """
        F = get_default_MXNet_mode() if F is None else F
        logvar = np.log(2 * np.pi) / -2 + F.log(variance) / -2
        logL = F.broadcast_add(logvar, F.broadcast_div(F.square(
            F.broadcast_minus(random_variable, mean)), -2 * variance))
        return logL
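The math above can be checked against the closed-form Gaussian log density using plain NumPy. This is a sketch for verification only; the function name and the test values are arbitrary, and it mirrors the MXNet expression in log_pdf term by term:

```python
import numpy as np

def normal_log_pdf(x, mean, variance):
    # Closed-form Gaussian log density, mirroring the MXNet code above:
    # log N(x | m, v) = -0.5*log(2*pi) - 0.5*log(v) - (x - m)^2 / (2*v)
    logvar = np.log(2 * np.pi) / -2 + np.log(variance) / -2
    return logvar + np.square(x - mean) / (-2 * variance)

x = np.array([0.0, 1.0, -1.0])
val = normal_log_pdf(x, mean=0.0, variance=1.0)
# A standard normal has log pdf -0.5*log(2*pi) ≈ -0.9189 at x = 0.
```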
[ ]:
    def draw_samples(self, mean, variance, rv_shape, num_samples=1, F=None):
        """
        Draw samples from the normal distribution.

        :param mean: the mean of the normal distribution
        :type mean: MXNet NDArray or MXNet Symbol
        :param variance: the variance of the normal distribution
        :type variance: MXNet NDArray or MXNet Symbol
        :param rv_shape: the shape of each sample
        :type rv_shape: tuple
        :param num_samples: the number of drawn samples (default: one)
        :type num_samples: int
        :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray)
        :returns: a set of samples from the normal distribution
        :rtype: MXNet NDArray or MXNet Symbol
        """
        F = get_default_MXNet_mode() if F is None else F
        out_shape = (num_samples,) + rv_shape
        return F.broadcast_add(F.broadcast_mul(self._rand_gen.sample_normal(
            shape=out_shape, dtype=self.dtype, ctx=self.ctx),
            F.sqrt(variance)), mean)
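draw_samples uses the location-scale form x = mean + sqrt(variance) * eps, with eps drawn from a standard normal. The same transform can be sketched in plain NumPy (the sample count and parameter values here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
mean, variance = 2.0, 9.0
num_samples, rv_shape = 100000, (3,)

# eps ~ N(0, 1); scale by the standard deviation and shift by the mean,
# exactly as the broadcast_mul/broadcast_add chain above does with MXNet ops.
eps = rng.standard_normal(size=(num_samples,) + rv_shape)
samples = mean + np.sqrt(variance) * eps

print(samples.shape)   # (100000, 3)
print(samples.mean())  # close to 2.0
print(samples.std())   # close to 3.0
```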

define_variable is just a helper function for end users. It takes in parameters for the distribution, creates a distribution from those parameters, and returns that distribution’s output variables.

[ ]:
    @staticmethod
    def define_variable(mean=0., variance=1., shape=None, rand_gen=None,
                        minibatch_ratio=1., dtype=None, ctx=None):
        """
        Creates and returns a random variable drawn from a normal distribution.

        :param mean: Mean of the distribution.
        :param variance: Variance of the distribution.
        :param shape: the shape of the random variable(s)
        :type shape: tuple or [tuple]
        :param rand_gen: the random generator (default: MXNetRandomGenerator)
        :type rand_gen: RandomGenerator
        :param dtype: the data type for floating point numbers
        :type dtype: numpy.float32 or numpy.float64
        :param ctx: the mxnet context (default: None/current context)
        :type ctx: None or mxnet.cpu or mxnet.gpu
        :returns: the random variables drawn from the normal distribution.
        :rtype: Variable
        """
        normal = Normal(mean=mean, variance=variance, rand_gen=rand_gen,
                        minibatch_ratio=minibatch_ratio, dtype=dtype, ctx=ctx)
        normal._generate_outputs(shape=shape)
        return normal.random_variable

Using your new distribution

At this point, you should be ready to start testing your new Distribution’s functionality by importing it like any other MXFusion component.

Testing

Before submitting your new code as a pull request, please write unit tests that verify it works as expected. This should include numerical checks against edge cases. See the existing test cases for the Normal or Categorical distributions for example tests.
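A typical numerical check compares log_pdf against an independent implementation of the same density. The sketch below does this in plain NumPy with made-up test values; in the real test suite you would call the distribution's log_pdf with MXNet arrays instead of re-deriving the formula inline:

```python
import numpy as np

def reference_log_pdf(x, mean, variance):
    # Independent closed-form Gaussian log density to test against.
    return -0.5 * np.log(2 * np.pi * variance) - (x - mean) ** 2 / (2 * variance)

def test_normal_log_pdf():
    rng = np.random.default_rng(1)
    mean, variance = 0.5, 2.0
    x = rng.standard_normal(10)
    # The "implementation under test" here is the formula from the class above,
    # rewritten in NumPy; in MXFusion's tests it would be Normal(...).log_pdf(...).
    logvar = np.log(2 * np.pi) / -2 + np.log(variance) / -2
    computed = logvar + np.square(x - mean) / (-2 * variance)
    assert np.allclose(computed, reference_log_pdf(x, mean, variance))

test_normal_log_pdf()
```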
