Welcome to MXFusion’s documentation!

MXFusion is a library for integrating probabilistic modelling with deep learning.

MXFusion helps you rapidly build and test new methods at scale, by focusing on the modularity of probabilistic models and their integration with modern deep learning techniques.

Installation

Dependencies / Prerequisites

MXFusion’s primary dependencies are MXNet >= 1.2 and NetworkX >= 2.1. See the requirements file for the full list.

Supported Architectures / Versions

MXFusion is tested on Python 3.5+ on macOS and Amazon Linux.

pip

If you just want to use MXFusion and not modify the source, you can install through pip:

pip install mxfusion

From source

To install MXFusion from source, after cloning the repository run the following from the top-level directory:

pip install .

Distributed Training

To enable distributed training of MXFusion using Horovod, install it through pip (note that MXFusion only supports Horovod versions below 0.18):

pip install horovod==0.16.4

More information about Horovod and detailed installation instructions can be found on the Horovod site.
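As a minimal, illustrative sketch (this is generic Horovod-with-MXNet usage, not an MXFusion-specific API), a distributed training script typically initializes Horovod before building the model and inference:

import horovod.mxnet as hvd
import mxnet as mx

hvd.init()  # start the Horovod communicator
# Each worker typically pins its own device, e.g. mx.gpu(hvd.local_rank()) on a GPU machine.
print('Running as worker %d of %d' % (hvd.rank(), hvd.size()))

The script is then launched with, for example, horovodrun -np 4 python train.py.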

API Reference

mxfusion The main module for MXFusion.

Tutorials!

Below is a list of tutorial / example notebooks demonstrating MXFusion’s functionality.

Getting Started

Zhenwen Dai (2018.10.22)

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

Introduction

MXFusion is a probabilistic programming language. It provides a convenient interface for designing probabilistic models and applying them to real world problems.

Probabilistic models describe the relationships in data through probability distributions of random variables. Probabilistic modeling is typically done by stating your prior belief about the data in terms of a probabilistic model and performing inference with the observations of some of the random variables.

[1]:
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'

A Simple Example

Let’s start with a toy example about estimating the mean and variance of a set of data. For simplicity, we generate 100 data points with a given mean and variance following a normal distribution.

[2]:
import numpy as np
np.random.seed(0)
mean_groundtruth = 3.
variance_groundtruth = 5.
N = 100
data = np.random.randn(N)*np.sqrt(variance_groundtruth) + mean_groundtruth

Let’s visualize our data by building a histogram.

[3]:
%matplotlib inline
from pylab import *
_=hist(data, 10)
_images/examples_notebooks_getting_started_7_0.png

Now, let’s pretend that we do not know the mean and variance that are used to generate the above data.

We still believe that the data come from a normal distribution, which is our model. It is formulated as

\[ y_n \sim \mathcal{N}(\mu, s), \quad Y=(y_1, \ldots, y_{100}) \]

where \(\mu\) is the mean, \(s\) is the variance and \(Y\) is the vector representing the data.

In MXFusion, the above model can be defined as follows:

[4]:
from mxfusion import Variable, Model
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions import Normal
from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'

m = Model()
m.mu = Variable()
m.s = Variable(transformation=PositiveTransformation())
m.Y = Normal.define_variable(mean=m.mu, variance=m.s, shape=(N,))

In the above definition, we start by defining a model instantiated from the class Model. The variables \(\mu\) and \(s\) are created from the class Variable. Both of them are assigned as members of the model instance m. This is how variables are organized in MXFusion. The variable s is created by passing a PositiveTransformation instance to the transformation argument. This constrains the value of the variable s to be positive through a “soft-plus” transformation. The variable Y is created from a normal distribution by specifying its mean, variance and shape.

Note that, in this example, the mean and variance variables are both scalars, with the shape (1,), while the random variable Y has the shape (100,). This indicates that the mean and variance variables are broadcast to the shape of the random variable, just like the broadcasting rules in NumPy array operations. In this case, this means that the individual entries of the random variable Y follow a scalar normal distribution with the same mean and variance.
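The same broadcasting behaviour can be illustrated with plain NumPy (this is only an analogy, not MXFusion code):

import numpy as np

mu = np.array([3.])   # scalar mean with shape (1,), like m.mu
s = np.array([5.])    # scalar variance with shape (1,), like m.s
# The (1,)-shaped mean and variance are broadcast over all 100 entries,
# mirroring how the entries of Y share the same mean and variance.
samples = np.random.randn(100) * np.sqrt(s) + mu
print(samples.shape)  # (100,)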

To list the content that is defined in the model instance, just print the model instance as follows:

[5]:
print(m)
Model (637b9)
Variable Y (e2721) ~ Normal(mean=Variable mu (3f2b3), variance=Variable s (cddb2))

After defining the probabilistic model, we want to estimate the mean and variance of the normal distribution in our model conditioned on the data that we generated. In MXFusion, this is done by creating an inference algorithm and passing it into the creation of an Inference instance. An inference algorithm represents a specific algorithm for probabilistic inference. In this example, we perform a maximum likelihood estimate by using the MAP class. The Inference class takes care of the initialization of parameters and the execution of inference.

In the following code, we create a MAP inference algorithm by specifying the model and the set of observed variables. Then, we create a GradBasedInference instance from the instantiated MAP inference algorithm.

The execution of inference is done by calling the run method. The run method takes all observed data (specified when creating the inference algorithm) as keyword arguments, where the keys are the names of the member variables of the model and the values are the corresponding MXNet NDArrays. In this example, we only observe the variable Y, so we pass “Y” as the key and the generated data as the value. We also specify the configuration parameters for the gradient optimizer such as the learning rate, the maximum number of iterations and whether to print the optimization progress. The default optimizer is Adam.

[6]:
from mxfusion.inference import GradBasedInference, MAP
import mxnet as mx

infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.Y]))
infr.run(Y=mx.nd.array(data, dtype='float64'), learning_rate=0.1, max_iter=2000, verbose=True)
Iteration 200 loss: 226.030
Iteration 400 loss: 223.629
Iteration 600 loss: 223.232
Iteration 800 loss: 223.163
Iteration 1000 loss: 223.152
Iteration 1200 loss: 223.151
Iteration 1400 loss: 223.151
Iteration 1600 loss: 223.151
Iteration 1800 loss: 223.151
Iteration 2000 loss: 223.151

After optimization, the estimated parameters are stored in an instance of the class InferenceParameters, which can be accessed from an Inference instance via infr.params.

We collect the estimated mean and variance and compare them with the generating parameters.

[7]:
mean_estimated = infr.params[m.mu].asnumpy()
variance_estimated = infr.params[m.s].asnumpy()

print('The estimated mean and variance: %f, %f.' % (mean_estimated, variance_estimated))
print('The true mean and variance: %f, %f.' % (mean_groundtruth, variance_groundtruth))
The estimated mean and variance: 3.133735, 5.079126.
The true mean and variance: 3.000000, 5.000000.

The estimated parameters are close to the generating parameters, but still off by a small amount. This difference is due to the small size of the dataset: with only 100 data points, the maximum likelihood estimate carries some estimation error.
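Because the maximum likelihood estimates of a normal distribution’s mean and variance have closed forms (the sample mean and the biased sample variance), we can sanity-check the optimization result directly with NumPy, reusing the data array from above:

# Closed-form maximum likelihood estimates for a normal distribution.
mle_mean = data.mean()
mle_variance = data.var()  # biased estimator (divides by N), matching the MLE
print('The closed-form MLE mean and variance: %f, %f.' % (mle_mean, mle_variance))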

A Bayesian model

In the above example, we computed a maximum likelihood estimate from the observed data. Due to the limited amount of data, the estimated parameters are not the same as the true parameters. An interesting question here is whether we can estimate how big the difference is. One approach to provide such an estimate is via Bayesian inference.

Following the above example, we need to assume prior distributions for the mean and variance of the normal distribution. We assume a normal prior for the mean with a relatively large variance, indicating that we do not have much knowledge about the parameter.

[8]:
m = Model()
m.mu = Normal.define_variable(mean=mx.nd.array([0], dtype='float64'),
                              variance=mx.nd.array([100], dtype='float64'), shape=(1,))

Then, we need to specify a prior distribution for the variance. This is a bit more complicated, as the variance needs to be positive. In principle, one could use a distribution over positive values such as the Gamma distribution. To enable inference with the reparameterization trick, we instead assume a random variable \(\hat{s}\) with a normal distribution and define the variance \(s\) as a function of \(\hat{s}\),

\[ \hat{s} \sim \mathcal{N}(5, 100), \quad s = \log(1+e^{\hat{s}}). \]

The above function is often referred to as the “soft-plus” function, which transforms a real number into a positive number. By applying the transformation, we indirectly specify the prior distribution for the variance.

To implement the above prior in MXFusion, we first create the variable s_hat with a normal distribution. Then, we define a function in the MXNet Gluon syntax, which is also called a Gluon block, for the “soft-plus” transformation. The MXNet function is brought into the MXFusion environment by applying a wrapper called MXFusionGluonFunction, in which we specify the number of outputs. We pass the variable s_hat as the input to the function and get the variable s as the return value.

[9]:
from mxfusion.components.functions import MXFusionGluonFunction

m.s_hat = Normal.define_variable(mean=mx.nd.array([5], dtype='float64'),
                                 variance=mx.nd.array([100], dtype='float64'),
                                 shape=(1,), dtype='float64')
trans_mxnet = mx.gluon.nn.HybridLambda(lambda F, x: F.Activation(x, act_type='softrelu'))
m.trans = MXFusionGluonFunction(trans_mxnet, num_outputs=1, broadcastable=True)
m.s = m.trans(m.s_hat)

We define the variable Y following a normal distribution with the mean mu and the variance s.

[10]:
m.Y = Normal.define_variable(mean=m.mu, variance=m.s, shape=(N,), dtype='float64')
print(m)
Model (e2f16)
Variable s_hat (459a5) ~ Normal(mean=Variable (f8b0f), variance=Variable (343b3))
Variable s (b43d3) = GluonFunctionEvaluation(hybridlambda0_input_0=Variable s_hat (459a5))
Variable mu (62f33) ~ Normal(mean=Variable (9057c), variance=Variable (03688))
Variable Y (9d22c) ~ Normal(mean=Variable mu (62f33), variance=Variable s (b43d3))

Inference for the above model is more complex, as the exact inference is intractable. We use variational inference with a Gaussian mean field posterior.

We construct the variational posterior by calling the function create_Gaussian_meanfield, which defines a Gaussian distribution for both the mean and the variance as the variational posterior. The content in the generated posterior can be listed by printing the posterior.

[11]:
from mxfusion.inference import create_Gaussian_meanfield

q = create_Gaussian_meanfield(model=m, observed=[m.Y])
print(q)
Posterior (42f50)
Variable s_hat (459a5) ~ Normal(mean=Variable (a44d5), variance=Variable (a74ce))
Variable mu (62f33) ~ Normal(mean=Variable (286b2), variance=Variable (57f6f))

Then, we create an instance of StochasticVariationalInference with both the model and the variational posterior. We also need to specify the number of samples used in inference, as it uses the Monte Carlo method to approximate the integral in the variational lower bound. The execution of inference follows the same interface.

[12]:
from mxfusion.inference import StochasticVariationalInference

infr = GradBasedInference(inference_algorithm=StochasticVariationalInference(
    model=m, posterior=q, num_samples=10, observed=[m.Y]))
infr.run(Y=mx.nd.array(data, dtype='float64'), learning_rate=0.1, verbose=True)
Iteration 200 loss: 235.140
Iteration 400 loss: 231.011
Iteration 600 loss: 229.947
Iteration 800 loss: 229.648
Iteration 1000 loss: 229.859
Iteration 1200 loss: 229.752
Iteration 1400 loss: 229.651
Iteration 1600 loss: 229.705
Iteration 1800 loss: 229.634
Iteration 2000 loss: 229.641

Let’s check the resulting posterior distribution.

[13]:
mu_mean = infr.params[q.mu.factor.mean].asscalar()
mu_std = np.sqrt(infr.params[q.mu.factor.variance].asscalar())
s_hat_mean = infr.params[q.s_hat.factor.mean].asscalar()
s_hat_std = np.sqrt(infr.params[q.s_hat.factor.variance].asscalar())
s_15 = np.log1p(np.exp(s_hat_mean - s_hat_std))
s_50 = np.log1p(np.exp(s_hat_mean))
s_85 = np.log1p(np.exp(s_hat_mean + s_hat_std))
print('The mean and standard deviation of the mean parameter is %f(%f). ' % (mu_mean, mu_std))
print('The 15th, 50th and 85th percentile of the variance parameter is %f, %f and %f.'%(s_15, s_50, s_85))
The mean and standard deviation of the mean parameter is 3.120117(0.221690).
The 15th, 50th and 85th percentile of the variance parameter is 4.604521, 5.309114 and 6.016289.

The true parameters sit within one standard deviation of the estimated posterior distribution for both the mean and variance. The above uncertainty estimate gives a good indication of how much we can trust the parameters that we estimate.
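To make this uncertainty estimate more tangible, we can draw samples from the fitted mean-field posterior using only NumPy and the quantities computed above; a small illustrative sketch:

# Sample from q(mu) and q(s_hat), and push s_hat through the soft-plus
# transformation to obtain samples of the variance s.
n_samples = 10000
mu_samples = np.random.randn(n_samples) * mu_std + mu_mean
s_samples = np.log1p(np.exp(np.random.randn(n_samples) * s_hat_std + s_hat_mean))
print('Estimated P(2.5 < mu < 3.5): %f' % np.mean((mu_samples > 2.5) & (mu_samples < 3.5)))
print('Estimated P(4.5 < s < 5.5): %f' % np.mean((s_samples > 4.5) & (s_samples < 5.5)))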

Probabilistic PCA Tutorial

This tutorial will demonstrate Probabilistic PCA, a factor analysis technique.

Maths and notation following Machine Learning: A Probabilistic Perspective.

Installation

Follow the installation instructions in the README file to get set up.

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

Probabilistic Modeling Introduction

Probabilistic models can be categorized into directed graphical models (DGM, Bayes Net) and undirected graphical models (UGM). Most popular probabilistic models are DGMs, so MXFusion only supports the definition of DGMs unless there is a strong customer need for UGMs in the future.

A DGM can be fully defined using 3 basic components: deterministic functions, probabilistic distributions, and random variables. We show the interface for defining a model using each of the three components below.

First lets import the basic libraries we’ll need to train our model and visualize some data.

[1]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Data Generation

We’ll use the log spiral function as the function to learn because it’s 2-dimensional and easy to visualize.

[2]:
def log_spiral(a,b,t):
    x = a * np.exp(b*t) * np.cos(t)
    y = a * np.exp(b*t) * np.sin(t)
    return np.vstack([x,y]).T

We parameterize the function with 100 data points and plot the resulting function.

[3]:
N = 100
D = 100
K = 2

a = 1
b = 0.1
t = np.linspace(0,6*np.pi,N)
r = log_spiral(a,b,t)
[4]:
r.shape
[4]:
(100, 2)
[5]:
plt.plot(r[:,0], r[:,1],'.')
[5]:
[<matplotlib.lines.Line2D at 0x1a1cbf4d68>]
_images/examples_notebooks_ppca_tutorial_10_1.png

We now project our \(K\)-dimensional r into a higher-dimensional space of dimension \(D\) using a matrix of random weights \(W\). Now that r is embedded in a \(D\)-dimensional space, the goal of PPCA is to recover r in its original low-dimensional \(K\)-dimensional space.

[6]:
w = np.random.randn(K,D)                              # random projection weights, shape (K, D)
x_train = np.dot(r,w) + np.random.randn(N,D) * 1e-3   # noisy high-dimensional observations, shape (N, D)
[7]:
# from sklearn.decomposition import PCA
# pca = PCA(n_components=2)
# new_r = pca.fit_transform(x_train)
# plt.plot(new_r[:,0], new_r[:,1],'.')

You can explore the higher dimensional data manually by changing dim1 and dim2 in the following cell.

[8]:
dim1 = 79
dim2 = 11
plt.scatter(x_train[:,dim1], x_train[:,dim2], color='blue', alpha=0.1)
plt.axis([-10, 10, -10, 10])
plt.title("Simulated data set")
plt.show()
_images/examples_notebooks_ppca_tutorial_15_0.png

MXFusion Model Definition

Import MXFusion and MXNet modelling components

[9]:
from mxfusion.models import Model
import mxnet.gluon.nn as nn
from mxfusion.components import Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.functions.operators import broadcast_to

The primary data structure in MXFusion is the Model. Models hold ModelComponents, such as Variables, Distributions, and Functions, which are what define a probabilistic model.

The model we’ll be defining for PPCA is:

\(p(z)\) ~ \(N(\mathbf{\mu}, \mathbf{\Sigma})\)

\(p(x | z,\theta)\) ~ \(N(\mathbf{Wz} + \mu, \Psi)\)

where:

\(z \in \mathbb{R}^{N \times K}, \mathbf{\mu} \in \mathbb{R}^K, \mathbf{\Sigma} \in \mathbb{R}^{N \times K \times K}, x \in \mathbb{R}^{N \times D}\)

\(\Psi \in \mathbb{R}^{N \times D \times D}, \Psi = [\Psi_0, \dots, \Psi_N], \Psi_i = \sigma^2\mathbf{I}\)

\(z\) here is our latent variable of interest, \(x\) is the observed data, and all other variables are parameters or constants of the model.

First we create an MXFusion Model object to build our PPCA model on.

[10]:
m = Model()

We attach Variable objects to our model to collect them in a centralized place. Internally, these are organized into a factor graph which is used during Inference.

[11]:
m.w = Variable(shape=(K,D), initial_value=mx.nd.array(np.random.randn(K,D)))

Because the mean of \(x\)’s distribution is composed of the dot product of \(z\) and \(W\), we need to create a dot product function. First we create a dot product function in MXNet and then wrap the function into MXFusion using the MXFusionGluonFunction class. m.dot can then be called like a normal python function and will apply to the variables it is called on.

[12]:
dot = nn.HybridLambda(function='dot')
m.dot = mf.functions.MXFusionGluonFunction(dot, num_outputs=1, broadcastable=False)

Now we define m.z which has an identity matrix covariance cov and zero mean.

m.z and sigma_2 are then used to define m.x.

Note that both sigma_2 and cov will be added implicitly into the Model because they are inputs to m.x.

[13]:
cov = mx.nd.broadcast_to(mx.nd.expand_dims(mx.nd.array(np.eye(K,K)), 0),shape=(N,K,K))
m.z = mf.distributions.MultivariateNormal.define_variable(mean=mx.nd.zeros(shape=(N,K)), covariance=cov, shape=(N,K))
m.sigma_2 = Variable(shape=(1,), transformation=PositiveTransformation())
m.x = mf.distributions.Normal.define_variable(mean=m.dot(m.z, m.w), variance=broadcast_to(m.sigma_2, (N,D)), shape=(N,D))

Posterior Definition

Now that we have our model, we need to define a posterior with parameters for the inference algorithm to optimize. When constructing a Posterior, we pass in the Model it is defined over and ModelComponent’s from the original Model are accessible and visible in the Posterior.

The covariance matrix must remain positive definite throughout the optimization process in order for the Cholesky decomposition to succeed when drawing samples or computing the log pdf of q.z. To satisfy this, we pass the covariance matrix parameters through a Gluon function that forces them into a symmetric matrix, for which suitable initialization values should maintain positive definiteness throughout the optimization procedure.

[14]:
from mxfusion.inference import BatchInferenceLoop, GradBasedInference, StochasticVariationalInference
class SymmetricMatrix(mx.gluon.HybridBlock):
    def hybrid_forward(self, F, x, *args, **kwargs):
        return F.sum((F.expand_dims(x, 3)*F.expand_dims(x, 2)), axis=-3)

While this model has an analytical solution, we will run Variational Inference to find the posterior to demonstrate inference in a setting where the answer is known.

We use a multivariate normal distribution for the posterior over \(z\) because that matches the form of \(z\)’s distribution in the model, so we don’t need to approximate anything in this case. Because the form we’re optimizing over matches the true model, the optimization is convex and, given enough iterations, will converge to the same answer as classical PCA.

[15]:
q = mf.models.Posterior(m)
sym = mf.components.functions.MXFusionGluonFunction(SymmetricMatrix(), num_outputs=1, broadcastable=False)
cov = Variable(shape=(N,K,K), initial_value=mx.nd.broadcast_to(mx.nd.expand_dims(mx.nd.array(np.eye(K,K) * 1e-2), 0),shape=(N,K,K)))
q.post_cov = sym(cov)
q.post_mean = Variable(shape=(N,K), initial_value=mx.nd.array(np.random.randn(N,K)))
q.z.set_prior(mf.distributions.MultivariateNormal(mean=q.post_mean, covariance=q.post_cov))

We now take our posterior and model, along with an observation pattern (in our case only m.x is observed) and create an inference algorithm. This inference algorithm is combined with a gradient loop to create the Inference method infr.

[16]:
observed = [m.x]
alg = StochasticVariationalInference(num_samples=3, model=m, posterior=q, observed=observed)
infr = GradBasedInference(inference_algorithm=alg,  grad_loop=BatchInferenceLoop())

The inference method is then initialized with our training data and we run optimization until convergence.

[17]:
infr.initialize(x=mx.nd.array(x_train))
[18]:
infr.run(max_iter=1000, learning_rate=1e-2, x=mx.nd.array(x_train))

Once training completes, we retrieve the posterior mean (our trained representation for \(\mathbf{Wz} + \mu\)) from the inference method and plot it. As shown, the plot recovers (up to rotation) the original 2D data quite well.

[19]:
post_z_mean = infr.params[q.z.factor.mean].asnumpy()
[20]:
plt.plot(post_z_mean[:,0], post_z_mean[:,1],'.')
[20]:
[<matplotlib.lines.Line2D at 0x1a1fc3e668>]
_images/examples_notebooks_ppca_tutorial_39_1.png
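As a sanity check, the commented-out scikit-learn snippet from earlier in the notebook can be adapted to compare the learned latent representation with classical PCA (assuming scikit-learn is installed; the two solutions should agree up to rotation and scaling):

from sklearn.decomposition import PCA

# Classical PCA on the high-dimensional training data.
pca = PCA(n_components=2)
new_r = pca.fit_transform(x_train)
plt.plot(new_r[:, 0], new_r[:, 1], '.', label='classical PCA')
plt.plot(post_z_mean[:, 0], post_z_mean[:, 1], 'x', label='variational PPCA')
plt.legend()
plt.show()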

Bayesian Neural Network (VI) for classification (under Development)

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================
[1]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference

Generate Synthetic Data

[2]:
import GPy
%matplotlib inline
from pylab import *

np.random.seed(4)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(200,1)
y = np.random.multivariate_normal(mean=np.zeros((200,)), cov=k.K(x), size=(1,)).T>0.
plot(x[:,0], y[:,0], '.')
[2]:
[<matplotlib.lines.Line2D at 0x1a2bc4aeb8>]
_images/examples_notebooks_bnn_classification_4_1.png
[3]:
D = 10
net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=1))
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=D))
    net.add(nn.Dense(2, flatten=False, in_units=D))
net.initialize(mx.init.Xavier(magnitude=1))
[4]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
from mxfusion.components.functions.operators import broadcast_to
from mxfusion.components.distributions import Normal, Categorical
from mxfusion import Variable, Model
from mxfusion.components.functions import MXFusionGluonFunction
[5]:
m = Model()
m.N = Variable()
m.f = MXFusionGluonFunction(net, num_outputs=1, broadcastable=False)
m.x = Variable(shape=(m.N,1))
m.r = m.f(m.x)
for _,v in m.r.factor.parameters.items():
    v.set_prior(Normal(mean=broadcast_to(mx.nd.array([0]), v.shape),
                       variance=broadcast_to(mx.nd.array([1.]), v.shape)))
m.y = Categorical.define_variable(log_prob=m.r, shape=(m.N,1), num_classes=2)
print(m)
Model (cf188)
Variable(ca45c) = BroadcastToOperator(data=Variable(8dec8))
Variable(383ef) = BroadcastToOperator(data=Variable(faf0a))
Variable(44754) ~ Normal(mean=Variable(383ef), variance=Variable(ca45c))
Variable(03371) = BroadcastToOperator(data=Variable(ad532))
Variable(88468) = BroadcastToOperator(data=Variable(2110c))
Variable(84fc2) ~ Normal(mean=Variable(88468), variance=Variable(03371))
Variable(1dc39) = BroadcastToOperator(data=Variable(f3d0f))
Variable(77d1c) = BroadcastToOperator(data=Variable(121a5))
Variable(4e7d9) ~ Normal(mean=Variable(77d1c), variance=Variable(1dc39))
Variable(dbd5d) = BroadcastToOperator(data=Variable(68ad8))
Variable(51f11) = BroadcastToOperator(data=Variable(1ac45))
Variable(dccd6) ~ Normal(mean=Variable(51f11), variance=Variable(dbd5d))
Variable(47fa5) = BroadcastToOperator(data=Variable(e966f))
Variable(90359) = BroadcastToOperator(data=Variable(f19f8))
Variable(daaa7) ~ Normal(mean=Variable(90359), variance=Variable(47fa5))
Variable(c310f) = BroadcastToOperator(data=Variable(1378f))
Variable(a21a3) = BroadcastToOperator(data=Variable(91998))
Variable(44c15) ~ Normal(mean=Variable(a21a3), variance=Variable(c310f))
r = GluonFunctionEvaluation(nn_input_0=x, nn_dense0_weight=Variable(44c15), nn_dense0_bias=Variable(daaa7), nn_dense1_weight=Variable(dccd6), nn_dense1_bias=Variable(4e7d9), nn_dense2_weight=Variable(84fc2), nn_dense2_bias=Variable(44754))
y ~ Categorical(log_prob=r)
[6]:
from mxfusion.inference import BatchInferenceLoop, create_Gaussian_meanfield, GradBasedInference, StochasticVariationalInference, MAP
[7]:
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(num_samples=5, model=m, posterior=q, observed=observed)
infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
[8]:
infr.initialize(y=mx.nd.array(y), x=mx.nd.array(x))
[9]:
for v_name, v in m.r.factor.parameters.items():
    infr.params[q[v].factor.mean] = net.collect_params()[v_name].data()
    infr.params[q[v].factor.variance] = mx.nd.ones_like(infr.params[q[v].factor.variance])*1e-6
[10]:
infr.run(max_iter=500, learning_rate=1e-1, y=mx.nd.array(y), x=mx.nd.array(x), verbose=True)
Iteration 51 loss: 661.41601562593755
Iteration 101 loss: 301.83178710937525
Iteration 151 loss: 166.81152343757812
Iteration 201 loss: 159.75297546386725
Iteration 251 loss: 154.61776733398438
Iteration 301 loss: 147.10989379882812
Iteration 351 loss: 153.09896850585938
Iteration 401 loss: 131.58213806152344
Iteration 451 loss: 147.08862304687556
Iteration 500 loss: 136.80494689941406
[11]:
# for uuid, v in infr.inference_algorithm.posterior.variables.items():
#     if uuid in infr.params.param_dict:
#         print(v.name, infr.params[v])
[12]:
xt = np.linspace(0,1,100)[:,None]
[13]:
infr2 = VariationalPosteriorForwardSampling(10, [m.x], infr, [m.r])
res = infr2.run(x=mx.nd.array(xt))
[14]:
yt = res[0].asnumpy()
[15]:
yt_mean = yt.mean(0)
yt_std = yt.std(0)
for i in range(yt.shape[0]):
    plot(xt[:,0],1./(1+np.exp(yt[i,:,0]-yt[i,:,1])),'k',alpha=0.2)
plot(x[:,0],y[:,0],'.')
[15]:
[<matplotlib.lines.Line2D at 0x1a2d368ac8>]
_images/examples_notebooks_bnn_classification_17_1.png
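The sampled outputs can also be summarized into a mean predictive probability curve; a short sketch reusing the yt and xt arrays computed above:

# Convert each sampled network output into the probability of class 1 and
# average over the posterior samples.
prob_class1 = 1. / (1. + np.exp(yt[:, :, 0] - yt[:, :, 1]))
plot(xt[:, 0], prob_class1.mean(0), 'b', label='mean predictive probability')
plot(x[:, 0], y[:, 0], '.', label='data')
_ = legend()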

Bayesian Neural Network (VI) for regression

Zhenwen Dai (2018-8-21)

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================
[1]:
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference

Generate Synthetic Data

[2]:
import GPy
%matplotlib inline
from pylab import *

np.random.seed(0)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(1000,1)
y = np.random.multivariate_normal(mean=np.zeros((1000,)), cov=k.K(x), size=(1,)).T
plot(x[:,0], y[:,0], '.')
[2]:
[<matplotlib.lines.Line2D at 0x10e9cf6d8>]
_images/examples_notebooks_bnn_regression_4_1.png
Model definition
[3]:
D = 50
net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    net.add(nn.Dense(D, activation="tanh", in_units=1))
    net.add(nn.Dense(D, activation="tanh", in_units=D))
    net.add(nn.Dense(1, flatten=True, in_units=D))
net.initialize(mx.init.Xavier(magnitude=3))
[4]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
from mxfusion.components.functions.operators import broadcast_to
from mxfusion.components.distributions import Normal
from mxfusion import Variable, Model
from mxfusion.components.functions import MXFusionGluonFunction
[5]:
m = Model()
m.N = Variable()
m.f = MXFusionGluonFunction(net, num_outputs=1,broadcastable=False)
m.x = Variable(shape=(m.N,1))
m.v = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=mx.nd.array([0.01]))
m.r = m.f(m.x)
for v in m.r.factor.parameters.values():
    v.set_prior(Normal(mean=broadcast_to(mx.nd.array([0]), v.shape),
                       variance=broadcast_to(mx.nd.array([1.]), v.shape)))
m.y = Normal.define_variable(mean=m.r, variance=broadcast_to(m.v, (m.N,1)), shape=(m.N,1))
print(m)
Model (a0f1e)
Variable(b08ec) = BroadcastToOperator(data=Variable(6c09d))
Variable(1cdae) = BroadcastToOperator(data=Variable(20ed6))
Variable(f6567) ~ Normal(mean=Variable(1cdae), variance=Variable(b08ec))
Variable(11427) = BroadcastToOperator(data=Variable(cb51c))
Variable(6068d) = BroadcastToOperator(data=Variable(a397c))
Variable(0d566) ~ Normal(mean=Variable(6068d), variance=Variable(11427))
Variable(a2806) = BroadcastToOperator(data=Variable(37171))
Variable(64e44) = BroadcastToOperator(data=Variable(58b81))
Variable(591da) ~ Normal(mean=Variable(64e44), variance=Variable(a2806))
Variable(04dac) = BroadcastToOperator(data=Variable(56e87))
Variable(a1d30) = BroadcastToOperator(data=Variable(e500b))
Variable(1caf4) ~ Normal(mean=Variable(a1d30), variance=Variable(04dac))
Variable(7a6fd) = BroadcastToOperator(data=Variable(39c80))
Variable(2bf77) = BroadcastToOperator(data=Variable(2d483))
Variable(c555f) ~ Normal(mean=Variable(2bf77), variance=Variable(7a6fd))
Variable(9c33c) = BroadcastToOperator(data=Variable(19481))
Variable(507a4) = BroadcastToOperator(data=Variable(56583))
Variable(5c091) ~ Normal(mean=Variable(507a4), variance=Variable(9c33c))
Variable(120d7) = BroadcastToOperator(data=v)
r = GluonFunctionEvaluation(nn_input_0=x, nn_dense0_weight=Variable(5c091), nn_dense0_bias=Variable(c555f), nn_dense1_weight=Variable(1caf4), nn_dense1_bias=Variable(591da), nn_dense2_weight=Variable(0d566), nn_dense2_bias=Variable(f6567))
y ~ Normal(mean=r, variance=Variable(120d7))
Inference with Meanfield
[6]:
from mxfusion.inference import BatchInferenceLoop, create_Gaussian_meanfield, GradBasedInference, StochasticVariationalInference
[7]:
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(num_samples=3, model=m, posterior=q, observed=observed)
infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
[8]:
infr.initialize(y=mx.nd.array(y), x=mx.nd.array(x))
[9]:
for v_name, v in m.r.factor.parameters.items():
    infr.params[q[v].factor.mean] = net.collect_params()[v_name].data()
    infr.params[q[v].factor.variance] = mx.nd.ones_like(infr.params[q[v].factor.variance])*1e-6
[10]:
infr.run(max_iter=2000, learning_rate=1e-2, y=mx.nd.array(y), x=mx.nd.array(x), verbose=True)
Iteration 201 loss: 15813.8652343755
Iteration 401 loss: 11816.2539062575
Iteration 601 loss: 8878.53613281255
Iteration 801 loss: 6882.62353515625
Iteration 1001 loss: 4587.8847656255
Iteration 1201 loss: 3141.453613281255
Iteration 1401 loss: 2384.0412597656255
Iteration 1601 loss: 1506.3929443359375
Iteration 1801 loss: 1371.0905761718755
Iteration 2000 loss: 1076.2847900390625
Use prediction to visualize the resulting BNN
[11]:
xt = np.linspace(0,1,100)[:,None]
[12]:
infr2 = VariationalPosteriorForwardSampling(10, [m.x], infr, [m.r])
res = infr2.run(x=mx.nd.array(xt))
[13]:
yt = res[0].asnumpy()
[14]:
yt_mean = yt.mean(0)
yt_std = yt.std(0)

for i in range(yt.shape[0]):
    plot(xt[:,0],yt[i,:,0],'k',alpha=0.2)
plot(x[:,0],y[:,0],'.')
[14]:
[<matplotlib.lines.Line2D at 0x10f6b75c0>]
_images/examples_notebooks_bnn_regression_19_1.png
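The sample mean and standard deviation computed above (yt_mean and yt_std) can also be plotted as a predictive band for the network output (note that this band does not include the observation noise m.v):

# Predictive mean and a two-standard-deviation band from the forward samples.
plot(xt[:, 0], yt_mean[:, 0], 'b-', label='mean')
plot(xt[:, 0], yt_mean[:, 0] - 2 * yt_std[:, 0], 'b--', label='2 x std')
plot(xt[:, 0], yt_mean[:, 0] + 2 * yt_std[:, 0], 'b--')
plot(x[:, 0], y[:, 0], '.', label='data points')
_ = legend()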

Variational Auto-Encoder (VAE)

Zhenwen Dai (2019-05-29)

Variational auto-encoder (VAE) is a latent variable model that uses a latent variable to generate data represented in vector form. Consider a latent variable \(x\) and an observed variable \(y\). The plain VAE is defined as

\begin{align} p(x) =& \mathcal{N}(0, I) \\ p(y|x) =& \mathcal{N}(f(x), \sigma^2I) \end{align}

where \(f\) is the deep neural network (DNN), often referred to as the decoder network.

The variational posterior of VAE is defined as

\begin{align} q(x) = \mathcal{N}\left(g_{\mu}(y), \sigma^2_x I\right) \end{align}

where \(g_{\mu}\) is the encoder network that generates the mean of the variational posterior of \(x\). For simplicity, we assume that all the data points share the same variance in the variational posterior. This can be extended by also generating the variance from the encoder network.

[1]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference
%matplotlib inline
from pylab import *
Load a toy dataset
[2]:
import GPy
data = GPy.util.datasets.oil_100()
Y = data['X']
label = data['Y'].argmax(1)
[3]:
N, D = Y.shape
Model Definition

We first define the encoder and decoder DNNs with MXNet Gluon blocks. Both DNNs have two hidden layers with tanh non-linearities.

[4]:
Q = 2
[5]:
H = 50
encoder = nn.HybridSequential(prefix='encoder_')
with encoder.name_scope():
    encoder.add(nn.Dense(H, in_units=D, activation="tanh", flatten=False))
    encoder.add(nn.Dense(H, in_units=H, activation="tanh", flatten=False))
    encoder.add(nn.Dense(Q, in_units=H, flatten=False))
encoder.initialize(mx.init.Xavier(magnitude=3))
[6]:
H = 50
decoder = nn.HybridSequential(prefix='decoder_')
with decoder.name_scope():
    decoder.add(nn.Dense(H, in_units=Q, activation="tanh", flatten=False))
    decoder.add(nn.Dense(H, in_units=H, activation="tanh", flatten=False))
    decoder.add(nn.Dense(D, in_units=H, flatten=False))
decoder.initialize(mx.init.Xavier(magnitude=3))

Then, we define the VAE model in MXFusion. Note that, for simplicity of implementation, we use scalar normal distributions defined over the individual entries of a matrix instead of multivariate normal distributions with diagonal covariance matrices.

[8]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion import Variable, Model, Posterior
from mxfusion.components.functions import MXFusionGluonFunction
from mxfusion.components.distributions import Normal
from mxfusion.components.functions.operators import broadcast_to

m = Model()
m.N = Variable()
m.decoder = MXFusionGluonFunction(decoder, num_outputs=1,broadcastable=True)
m.x = Normal.define_variable(mean=broadcast_to(mx.nd.array([0]), (m.N, Q)),
                             variance=broadcast_to(mx.nd.array([1]), (m.N, Q)), shape=(m.N, Q))
m.f = m.decoder(m.x)
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=mx.nd.array([0.01]))
m.y = Normal.define_variable(mean=m.f, variance=broadcast_to(m.noise_var, (m.N, D)),
                             shape=(m.N, D))
print(m)
Model (37a04)
Variable (b92c2) = BroadcastToOperator(data=Variable noise_var (a50d4))
Variable (39c2c) = BroadcastToOperator(data=Variable (e1aad))
Variable (b7150) = BroadcastToOperator(data=Variable (a57d4))
Variable x (53056) ~ Normal(mean=Variable (b7150), variance=Variable (39c2c))
Variable f (ad606) = GluonFunctionEvaluation(decoder_input_0=Variable x (53056), decoder_dense0_weight=Variable (b9b70), decoder_dense0_bias=Variable (d95aa), decoder_dense1_weight=Variable (73dc2), decoder_dense1_bias=Variable (b85dd), decoder_dense2_weight=Variable (7a61c), decoder_dense2_bias=Variable (eba91))
Variable y (23bca) ~ Normal(mean=Variable f (ad606), variance=Variable (b92c2))

We also define the variational posterior following the equation above.

[9]:
q = Posterior(m)
q.x_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=mx.nd.array([1e-6]))
q.encoder = MXFusionGluonFunction(encoder, num_outputs=1, broadcastable=True)
q.x_mean = q.encoder(q.y)
q.x.set_prior(Normal(mean=q.x_mean, variance=broadcast_to(q.x_var, q.x.shape)))
print(q)
Posterior (4ec05)
Variable x_mean (86d22) = GluonFunctionEvaluation(encoder_input_0=Variable y (23bca), encoder_dense0_weight=Variable (51b3d), encoder_dense0_bias=Variable (c0092), encoder_dense1_weight=Variable (ad9ef), encoder_dense1_bias=Variable (83db0), encoder_dense2_weight=Variable (78b82), encoder_dense2_bias=Variable (b856d))
Variable (6dc84) = BroadcastToOperator(data=Variable x_var (19d07))
Variable x (53056) ~ Normal(mean=Variable x_mean (86d22), variance=Variable (6dc84))
Variational Inference

Variational inference is done by creating an inference object and passing in the stochastic variational inference algorithm.

[11]:
from mxfusion.inference import BatchInferenceLoop, StochasticVariationalInference, GradBasedInference

observed = [m.y]
alg = StochasticVariationalInference(num_samples=3, model=m, posterior=q, observed=observed)
infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())

SVI is a gradient-based algorithm. We can run the algorithm by providing the data and specifying the parameters for the gradient optimizer (the default gradient optimizer is Adam).

[13]:
infr.run(max_iter=2000, learning_rate=1e-2, y=mx.nd.array(Y), verbose=True)
Iteration 200 loss: 1720.556396484375
Iteration 400 loss: 601.11962890625
Iteration 600 loss: 168.620849609375
Iteration 800 loss: -48.67474365234375
Iteration 1000 loss: -207.34835815429688
Iteration 1200 loss: -354.17742919921875
Iteration 1400 loss: -356.26409912109375
Iteration 1600 loss: -561.263427734375
Iteration 1800 loss: -697.8665161132812
Iteration 2000 loss: -753.83203125
Plot the training data in the latent space

Finally, we may be interested in visualizing the latent space of our dataset. We can do that by calling the encoder network.

[17]:
from mxfusion.inference import TransferInference

q_x_mean = q.encoder.gluon_block(mx.nd.array(Y)).asnumpy()
[18]:
for i in range(3):
    plot(q_x_mean[label==i,0], q_x_mean[label==i,1], '.')
_images/examples_notebooks_variational_auto_encoder_20_0.png
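We can also check how well the decoder reconstructs the data from the encoded means, calling the decoder’s underlying Gluon block directly in the same way as the encoder above (this assumes, as there, that the Gluon block holds the trained parameters):

# Decode the latent means back to the data space and measure the average
# squared reconstruction error.
Y_recon = m.decoder.gluon_block(mx.nd.array(q_x_mean)).asnumpy()
print('Mean squared reconstruction error: %f' % np.mean((Y_recon - Y) ** 2))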

Gaussian Processes

Gaussian Process Regression

Zhenwen Dai (2019-05-29)

Introduction

A Gaussian process (GP) is a Bayesian non-parametric model used for various machine learning problems such as regression and classification. This notebook shows how to use a Gaussian process regression model in MXFusion.

[1]:
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
Toy data

We generate some synthetic data for our regression example. The data set is generated from a sine function with some additive Gaussian noise.

[2]:
import numpy as np
%matplotlib inline
from pylab import *

np.random.seed(0)
X = np.random.uniform(-3.,3.,(20,1))
Y = np.sin(X) + np.random.randn(20,1)*0.05

The generated data are visualized as follows:

[3]:
plot(X, Y, 'rx', label='data points')
_=legend()
_images/examples_notebooks_gp_regression_6_0.png
Gaussian process regression with Gaussian likelihood

Denote a set of input points \(X \in \mathbb{R}^{N \times Q}\). A Gaussian process is often formulated as a multi-variate normal distribution conditioned on the inputs:

\[ p(F|X) = \mathcal{N}(F; 0, K), \]

where \(F \in \mathbb{R}^{N \times 1}\) is the vector of output points of the Gaussian process and \(K\) is the covariance matrix computed on the set of inputs according to a chosen kernel function \(k(\cdot, \cdot)\).

For a regression problem, \(F\) is often referred to as the noise-free output and we usually assume an additional probability distribution as the observation noise. In this case, we assume the noise distribution to be Gaussian:

\[ p(Y|F) = \mathcal{N}(Y; F, \sigma^2 \mathcal{I}), \]

where \(Y \in \mathbb{R}^{N \times 1}\) is the observed output and \(\sigma^2\) is the variance of the Gaussian distribution.

The following code defines the above GP regression in MXFusion. First, we change the default data dtype to double precision to avoid any potential numerical issues.

[4]:
from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'

In the code below, the variable Y is defined following the probabilistic module GPRegression. A probabilistic module in MXFusion is a pre-built probabilistic model with dedicated inference algorithms for computing log-pdf and drawing samples. In this case, GPRegression defines the above GP regression model with a Gaussian likelihood. It understands that the log-likelihood after marginalizing \(F\) is closed-form and exploits this property when computing log-pdf.

The model is defined by the input variable X with the shape (m.N, 1), where the value of m.N is discovered when data is given during inference. A positive noise variance variable m.noise_var is defined with an initial value of 0.01. For the GP, we define an RBF kernel with one input dimension and the initial values of the variance and lengthscale set to one. We define the variable m.Y following the GP regression distribution with the above kernel, input variable and noise variance.

[5]:
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import RBF
from mxfusion.modules.gp_modules import GPRegression

m = Model()
m.N = Variable()
m.X = Variable(shape=(m.N, 1))
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=0.01)
m.kernel = RBF(input_dim=1, variance=1, lengthscale=1)
m.Y = GPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var, shape=(m.N, 1))

In the above model, we have not defined any prior distributions for the hyper-parameters. To use the model for regression, we typically do a maximum likelihood estimate of all the hyper-parameters conditioned on the input and output variables. In MXFusion, this is done by first creating an inference algorithm, which is MAP in this case, by specifying the observed variables. Then, we create an inference body for gradient-based inference methods, which is called GradBasedInference. The inference is triggered by calling the run method, in which all the observed data are given as keyword arguments and any necessary configuration parameters are specified.

[6]:
import mxnet as mx
from mxfusion.inference import GradBasedInference, MAP

infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]))
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=100, learning_rate=0.05, verbose=True)
Iteration 10 loss: -13.09287954321266
Iteration 20 loss: -15.971970034359586
Iteration 30 loss: -16.725359053995163
Iteration 40 loss: -16.835084442759314
Iteration 50 loss: -16.850332113428053
Iteration 60 loss: -16.893812683762203
Iteration 70 loss: -16.900137667771077
Iteration 80 loss: -16.901158761459012
Iteration 90 loss: -16.903085976668137
Iteration 100 loss: -16.903135093930537

All the inference outcomes are in the attribute params of the inference body. The inferred value of a parameter can be accessed by passing the reference of the queried parameter to the params attribute. For example, to get the value of m.noise_var, we can call infr.params[m.noise_var]. The estimated parameters from the above experiment are as follows:

[7]:
print('The estimated variance of the RBF kernel is %f.' % infr.params[m.kernel.variance].asscalar())
print('The estimated length scale of the RBF kernel is %f.' % infr.params[m.kernel.lengthscale].asscalar())
print('The estimated variance of the Gaussian likelihood is %f.' % infr.params[m.noise_var].asscalar())
The estimated variance of the RBF kernel is 0.616992.
The estimated length scale of the RBF kernel is 1.649073.
The estimated variance of the Gaussian likelihood is 0.002251.

We can compare the estimated values with the same model implemented in GPy. The estimated values from GPy are very close to the ones from MXFusion.

[8]:
import GPy

m_gpy = GPy.models.GPRegression(X, Y, kernel=GPy.kern.RBF(1))
m_gpy.optimize()
print(m_gpy)

Name : GP regression
Objective : -16.903456670910902
Number of Parameters : 3
Number of Optimization Parameters : 3
Updates : True
Parameters:
  GP_regression.           |                 value  |  constraints  |  priors
  rbf.variance             |    0.6148038604494702  |      +ve      |
  rbf.lengthscale          |    1.6500299722611123  |      +ve      |
  Gaussian_noise.variance  |  0.002270049772204339  |      +ve      |
Prediction

The above section shows how to estimate the model hyper-parameters of a GP regression model. This is often referred to as training. After training, we are often interested in using the inferred model to predict on unseen inputs. The GP module offers two types of prediction: estimating the mean and variance of the output variable, or drawing samples from the predictive posterior distribution.

Mean and variance of the posterior distribution

To estimate the mean and variance of the predictive posterior distribution, we use the inference algorithm ModulePredictionAlgorithm, which takes the model, the observed variables and the target variables of prediction as input arguments. We use TransferInference as the inference body, which allows us to take the inference outcome from the previous inference. This is done by passing the inference parameters infr.params into the infr_params argument.

[9]:
from mxfusion.inference import TransferInference, ModulePredictionAlgorithm
infr_pred = TransferInference(ModulePredictionAlgorithm(model=m, observed=[m.X], target_variables=[m.Y]),
                              infr_params=infr.params)

To visualize the fitted model, we make predictions at 100 points evenly spaced from -5 to 5. We estimate the mean and variance of the noise-free output \(F\).

[10]:
xt = np.linspace(-5,5,100)[:, None]
res = infr_pred.run(X=mx.nd.array(xt, dtype='float64'))[0]
f_mean, f_var = res[0].asnumpy()[0], res[1].asnumpy()[0]

The resulting figure is shown as follows:

[11]:
plot(xt, f_mean[:,0], 'b-', label='mean')
plot(xt, f_mean[:,0]-2*np.sqrt(f_var), 'b--', label='2 x std')
plot(xt, f_mean[:,0]+2*np.sqrt(f_var), 'b--')
plot(X, Y, 'rx', label='data points')
ylabel('F')
xlabel('X')
_=legend()
_images/examples_notebooks_gp_regression_22_0.png
Posterior samples of Gaussian process

Apart from getting the mean and variance at every location, we may need to draw samples from the posterior GP. As the output variables at different locations are correlated with each other, each sample gives us some idea of a potential function from the posterior GP distribution.

To draw samples from the posterior distribution, we need to change the prediction inference algorithm attached to the GP module. The default prediction algorithm estimates the mean and variance of the output variable as shown above. We can attach another inference algorithm as the prediction algorithm. In the following code, we attach the GPRegressionSamplingPrediction algorithm as the prediction algorithm. The targets and conditionals arguments specify the target variables and the conditional variables of the algorithm. After specifying a name in the alg_name argument, such as gp_predict, we can access this inference algorithm by that name, e.g. gp.gp_predict. In the following code, we set the diagonal_variance attribute to False in order to draw samples from a full covariance matrix. To avoid numerical issues, we set a small jitter to help the matrix inversion. Then, we create the inference body in the same way as in the above example.

[12]:
from mxfusion.inference import TransferInference, ModulePredictionAlgorithm
from mxfusion.modules.gp_modules.gp_regression import GPRegressionSamplingPrediction

gp = m.Y.factor
gp.attach_prediction_algorithms(targets=gp.output_names, conditionals=gp.input_names,
            algorithm=GPRegressionSamplingPrediction(
                gp._module_graph, gp._extra_graphs[0], [gp._module_graph.X]),
            alg_name='gp_predict')
gp.gp_predict.diagonal_variance = False
gp.gp_predict.jitter = 1e-8
infr_pred = TransferInference(ModulePredictionAlgorithm(model=m, observed=[m.X], target_variables=[m.Y], num_samples=5),
                              infr_params=infr.params)

We draw five samples at the 100 evenly spaced input locations.

[13]:
xt = np.linspace(-5,5,100)[:, None]
y_samples = infr_pred.run(X=mx.nd.array(xt, dtype='float64'))[0].asnumpy()

We visualize the individual samples each with a different color.

[14]:
for i in range(y_samples.shape[0]):
    plot(xt, y_samples[i,:,0])
_images/examples_notebooks_gp_regression_28_0.png
Gaussian process with a mean function

In the previous example, we created a GP regression model without a mean function (the mean of the GP is zero). It is very easy to extend a GP model with a mean function. First, we create a mean function in MXNet (a neural network). For simplicity, we use a 1D linear function as the mean function.

[21]:
mean_func = mx.gluon.nn.Dense(1, in_units=1, flatten=False)
mean_func.initialize(mx.init.Xavier(magnitude=3))

We create the GP regression model in a similar way as above. The differences are: (1) we create a wrapper of the mean function in the model definition as m.mean_func; (2) we evaluate the mean function with the input of our GP model, which results in the mean of the GP; (3) we pass the resulting mean into the mean argument of the GP module.

[23]:
from mxfusion.components.functions import MXFusionGluonFunction

m = Model()
m.N = Variable()
m.X = Variable(shape=(m.N, 1))
m.mean_func = MXFusionGluonFunction(mean_func, num_outputs=1, broadcastable=True)
m.mean = m.mean_func(m.X)
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=0.01)
m.kernel = RBF(input_dim=1, variance=1, lengthscale=1)
m.Y = GPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var, mean=m.mean, shape=(m.N, 1))
[24]:
import mxnet as mx
from mxfusion.inference import GradBasedInference, MAP

infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]))
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=100, learning_rate=0.05, verbose=True)
Iteration 10 loss: -6.288699675985622
Iteration 20 loss: -13.938366520031717
Iteration 30 loss: -16.238146742572965
Iteration 40 loss: -16.214515784955303
Iteration 50 loss: -16.302410205174386
Iteration 60 loss: -16.423765889507315
Iteration 70 loss: -16.512277794947106
Iteration 80 loss: -16.5757306621185
Iteration 90 loss: -16.6410597628529
Iteration 100 loss: -16.702913078848557
[25]:
from mxfusion.inference import TransferInference, ModulePredictionAlgorithm
infr_pred = TransferInference(ModulePredictionAlgorithm(model=m, observed=[m.X], target_variables=[m.Y]),
                              infr_params=infr.params)
[26]:
xt = np.linspace(-5,5,100)[:, None]
res = infr_pred.run(X=mx.nd.array(xt, dtype='float64'))[0]
f_mean, f_var = res[0].asnumpy()[0], res[1].asnumpy()[0]
[27]:
plot(xt, f_mean[:,0], 'b-', label='mean')
plot(xt, f_mean[:,0]-2*np.sqrt(f_var), 'b--', label='2 x std')
plot(xt, f_mean[:,0]+2*np.sqrt(f_var), 'b--')
plot(X, Y, 'rx', label='data points')
ylabel('F')
xlabel('X')
_=legend()
_images/examples_notebooks_gp_regression_36_0.png

The effect of the mean function is not noticeable, because there is no linear trend in our data. We can print the values of the estimated parameters of the linear mean function.

[36]:
print("The weight is %f and the bias is %f." %(infr.params[m.mean_func.parameters['dense1_weight']].asnumpy(),
                                               infr.params[m.mean_func.parameters['dense1_bias']].asscalar()))
The weight is 0.021969 and the bias is 0.079038.
Variational sparse Gaussian process regression

In MXFusion, we also have variational sparse GP implemented as a module. A sparse GP model can be created in a similar way as the plain GP model.

[39]:
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import RBF
from mxfusion.modules.gp_modules import SparseGPRegression

m = Model()
m.N = Variable()
m.X = Variable(shape=(m.N, 1))
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=0.01)
m.kernel = RBF(input_dim=1, variance=1, lengthscale=1)
m.Y = SparseGPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var, shape=(m.N, 1), num_inducing=50)
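Training the sparse GP model follows the same pattern as the plain GP regression example above; a brief sketch, assuming the same X, Y data and the GradBasedInference and MAP imports from earlier in this notebook:

# Maximum likelihood estimation of the kernel and noise hyper-parameters,
# using the module's variational lower bound as the objective.
infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]))
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=100, learning_rate=0.05, verbose=True)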

Stochastic Variational Gaussian Process Regression

Zhenwen Dai (2019-05-29)

Introduction

Gaussian process (GP) regression is computationally expensive. A popular approach for scaling up GP regression to large data is stochastic variational inference with mini-batch training (Hensman et al., 2013). SVGP regression with Gaussian noise has been implemented as a module in MXFusion.

[32]:
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
Toy data

We generate some synthetic data for our regression example. The data set is generated from a sine function with some additive Gaussian noise.

[2]:
import numpy as np
%matplotlib inline
from pylab import *

np.random.seed(0)
X = np.random.uniform(-3.,3.,(1000,1))
Y = np.sin(X) + np.random.randn(1000,1)*0.05

The generated data are visualized as follows:

[3]:
plot(X, Y, 'rx', label='data points')
_=legend()
_images/examples_notebooks_svgp_regression_6_0.png
[4]:
from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'

The SVGP regression model is created as follows. Two SVGP-specific parameters are num_inducing, which specifies the number of inducing points used in the variational sparse GP approximation, and svgp_log_pdf.jitter, which is the jitter term added in the log-pdf calculation for numerical robustness.

[33]:
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import RBF
from mxfusion.modules.gp_modules import SVGPRegression

m = Model()
m.N = Variable()
m.X = Variable(shape=(m.N, 1))
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=0.01)
m.kernel = RBF(input_dim=1, variance=1, lengthscale=1)
m.Y = SVGPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var, shape=(m.N, 1), num_inducing=20)
m.Y.factor.svgp_log_pdf.jitter = 1e-6

Inference is done by creating an inference instance from the GradBasedInference class, in which we use a MAP inference algorithm, as there are no latent variables outside the SVGPRegression module. Additionally, we specify grad_loop to be MinibatchInferenceLoop, in which we set the mini-batch size and the scaling factor for mini-batch training.

Then, training is triggered by calling the run method.

[38]:
import mxnet as mx
from mxfusion.inference import GradBasedInference, MAP, MinibatchInferenceLoop

infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]),
                          grad_loop=MinibatchInferenceLoop(batch_size=10, rv_scaling={m.Y: 1000/10}))
infr.initialize(X=(1000,1), Y=(1000,1))
infr.params[m.Y.factor.inducing_inputs] = mx.nd.array(np.random.randn(20, 1), dtype='float64')
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=50, learning_rate=0.1, verbose=True)
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=50, learning_rate=0.01, verbose=True)
epoch 1 Iteration 100 loss: 933115.0603707978                   epoch-loss: 10413624.614005275
epoch 2 Iteration 100 loss: 524948.7079326594                   epoch-loss: 686034.5295730559
epoch 3 Iteration 100 loss: 345602.4022749258                   epoch-loss: 427065.8343717841
epoch 4 Iteration 100 loss: 277011.3760208657                   epoch-loss: 297071.493696023
epoch 5 Iteration 100 loss: 183347.13021907964                  epoch-loss: 219808.0871498559
epoch 6 Iteration 100 loss: 143763.11007552472                  epoch-loss: 169486.20729875282
epoch 7 Iteration 100 loss: 132031.47695326462                  epoch-loss: 134765.1471133905
epoch 8 Iteration 100 loss: 95632.60561449913                   epoch-loss: 109798.66321648406
epoch 9 Iteration 100 loss: 73957.6220462552                    epoch-loss: 91257.8705670977
epoch 10 Iteration 100 loss: 64840.07207031624                  epoch-loss: 77084.06942481917
epoch 11 Iteration 100 loss: 60780.27278575914                  epoch-loss: 65962.38163622493
epoch 12 Iteration 100 loss: 48546.66342698521                  epoch-loss: 57037.39009905885
epoch 13 Iteration 100 loss: 42676.907263579335                 epoch-loss: 49725.50869601666
epoch 14 Iteration 100 loss: 43266.74759690139                  epoch-loss: 43635.70855486856
epoch 15 Iteration 100 loss: 33139.32033870425                  epoch-loss: 38501.415430223606
epoch 16 Iteration 100 loss: 35129.68003531527                  epoch-loss: 34139.30892930683
epoch 17 Iteration 100 loss: 33309.08869286892                  epoch-loss: 30414.713307491817
epoch 18 Iteration 100 loss: 31058.180286752693                 epoch-loss: 27222.957705478882
epoch 19 Iteration 100 loss: 22781.668494776342                 epoch-loss: 24466.753696665117
epoch 20 Iteration 100 loss: 16921.53875526696                  epoch-loss: 22063.866203795988
epoch 21 Iteration 100 loss: 16866.27172281184                  epoch-loss: 19959.435781693166
epoch 22 Iteration 100 loss: 18001.39866328793                  epoch-loss: 18093.70564938978
epoch 23 Iteration 100 loss: 19268.435700542395                 epoch-loss: 16435.61461383947
epoch 24 Iteration 100 loss: 13586.70681551015                  epoch-loss: 14947.197437326102
epoch 25 Iteration 100 loss: 11842.634017398044                 epoch-loss: 13605.954880888436
epoch 26 Iteration 100 loss: 12304.581180033452                 epoch-loss: 12393.880316263208
epoch 27 Iteration 100 loss: 12712.095456995734                 epoch-loss: 11293.27810727986
epoch 28 Iteration 100 loss: 12662.540317512301                 epoch-loss: 10292.698091923068
epoch 29 Iteration 100 loss: 9789.253683769626                  epoch-loss: 9379.934609293405
epoch 30 Iteration 100 loss: 10336.484081253366                 epoch-loss: 8542.778732654882
epoch 31 Iteration 100 loss: 8427.615871046397                  epoch-loss: 7780.101399774407
epoch 32 Iteration 100 loss: 6243.338653452632                  epoch-loss: 7083.3906599663305
epoch 33 Iteration 100 loss: 5633.910939630758                  epoch-loss: 6442.360608787293
epoch 34 Iteration 100 loss: 6128.494674105952                  epoch-loss: 5856.924952855579
epoch 35 Iteration 100 loss: 5561.132651568278                  epoch-loss: 5319.662670742758
epoch 36 Iteration 100 loss: 5007.633559342303                  epoch-loss: 4827.494923733251
epoch 37 Iteration 100 loss: 4570.798941667555                  epoch-loss: 4375.152951451802
epoch 38 Iteration 100 loss: 3427.8776815125993                 epoch-loss: 3958.746627662967
epoch 39 Iteration 100 loss: 3145.271868648371                  epoch-loss: 3574.2727718396727
epoch 40 Iteration 100 loss: 3252.388844355417                  epoch-loss: 3216.389789008766
epoch 41 Iteration 100 loss: 2682.992323506939                  epoch-loss: 2880.8040627817663
epoch 42 Iteration 100 loss: 2776.54316335849                   epoch-loss: 2563.2893900928902
epoch 43 Iteration 100 loss: 2052.181117489573                  epoch-loss: 2259.124250598867
epoch 44 Iteration 100 loss: 1789.3450917418618                 epoch-loss: 1963.4009524512699
epoch 45 Iteration 100 loss: 1637.0460616480382                 epoch-loss: 1683.301052960261
epoch 46 Iteration 100 loss: 1250.3190196168575                 epoch-loss: 1421.08925599032
epoch 47 Iteration 100 loss: 1056.4280170128945                 epoch-loss: 1181.4882875552755
epoch 48 Iteration 100 loss: 934.1323712121834                  epoch-loss: 972.8920812023131
epoch 49 Iteration 100 loss: 743.6854774208032                  epoch-loss: 794.3919410633861
epoch 50 Iteration 100 loss: 592.0162492873271                  epoch-loss: 643.6129305537779
epoch 1 Iteration 100 loss: -617.7115390031664                  epoch-loss: 122.02590714978953
epoch 2 Iteration 100 loss: -1042.9322804366407                 epoch-loss: -861.8691127743712
epoch 3 Iteration 100 loss: -1246.1061590298375                 epoch-loss: -1142.8551043268158
epoch 4 Iteration 100 loss: -1422.4364206976472                 epoch-loss: -1248.3343954963652
epoch 5 Iteration 100 loss: -1364.319275718058                  epoch-loss: -1319.0632400945233
epoch 6 Iteration 100 loss: -1138.6014678286117                 epoch-loss: -1375.485088640635
epoch 7 Iteration 100 loss: -1468.2449906521865                 epoch-loss: -1415.3387799226973
epoch 8 Iteration 100 loss: -1331.0742440765116                 epoch-loss: -1398.7259993571608
epoch 9 Iteration 100 loss: -1023.1218294411456                 epoch-loss: -1406.2506096944428
epoch 10 Iteration 100 loss: -1491.0721525479291                        epoch-loss: -1425.3786072098467
epoch 11 Iteration 100 loss: -1487.9902441406107                        epoch-loss: -1385.4821177117121
epoch 12 Iteration 100 loss: -963.575720938497                  epoch-loss: -1148.7904243974
epoch 13 Iteration 100 loss: -1496.8723348964538                        epoch-loss: -1248.4710558849933
epoch 14 Iteration 100 loss: -1189.2469453417261                        epoch-loss: -1302.58240646708
epoch 15 Iteration 100 loss: -1354.0129933002445                        epoch-loss: -1422.9290660653176
epoch 16 Iteration 100 loss: -1375.0688655561046                        epoch-loss: -1296.0532055882159
epoch 17 Iteration 100 loss: -1601.7368685439442                        epoch-loss: -1432.8777691683824
epoch 18 Iteration 100 loss: -1140.428056593764                 epoch-loss: -1443.8657069101057
epoch 19 Iteration 100 loss: -1396.6869254783921                        epoch-loss: -1421.0467725977735
epoch 20 Iteration 100 loss: -1313.511818206805                 epoch-loss: -1411.19388568273
epoch 21 Iteration 100 loss: -1508.1672406497062                        epoch-loss: -1427.8889874691674
epoch 22 Iteration 100 loss: -1249.1642813846483                        epoch-loss: -1379.492333117903
epoch 23 Iteration 100 loss: -1214.1394062603918                        epoch-loss: -1356.5797617962307
epoch 24 Iteration 100 loss: -1554.6263005956837                        epoch-loss: -1358.5256191991677
epoch 25 Iteration 100 loss: -1419.5889498936215                        epoch-loss: -1405.5467914984783
epoch 26 Iteration 100 loss: -1262.3682620336267                        epoch-loss: -1409.6484860247688
epoch 27 Iteration 100 loss: -1327.4752015434606                        epoch-loss: -1368.1521038967614
epoch 28 Iteration 100 loss: -1256.4414309051297                        epoch-loss: -1351.3528504368003
epoch 29 Iteration 100 loss: -1178.4788588168844                        epoch-loss: -1413.816013007459
epoch 30 Iteration 100 loss: -1605.1239164704423                        epoch-loss: -1426.6550440932342
epoch 31 Iteration 100 loss: -1617.1795697144926                        epoch-loss: -1356.5267725452202
epoch 32 Iteration 100 loss: -1590.7237287842681                        epoch-loss: -1425.2884165221458
epoch 33 Iteration 100 loss: -1594.3448025229204                        epoch-loss: -1420.5483351285052
epoch 34 Iteration 100 loss: -1576.9397677615486                        epoch-loss: -1430.2946033617723
epoch 35 Iteration 100 loss: -1303.3497394593587                        epoch-loss: -1380.3330104443605
epoch 36 Iteration 100 loss: -1478.0145396344049                        epoch-loss: -1399.0665992260174
epoch 37 Iteration 100 loss: -1555.4119456067176                        epoch-loss: -1360.9939473244767
epoch 38 Iteration 100 loss: -1553.031887368961                 epoch-loss: -1419.1421503464217
epoch 39 Iteration 100 loss: -1427.3431059260865                        epoch-loss: -1415.0248356293594
epoch 40 Iteration 100 loss: -1137.8470272897798                        epoch-loss: -1398.6618957762776
epoch 41 Iteration 100 loss: -1551.999240061582                 epoch-loss: -1402.3061839927834
epoch 42 Iteration 100 loss: -1458.4434735943848                        epoch-loss: -1425.2654433536431
epoch 43 Iteration 100 loss: -1585.6542548185487                        epoch-loss: -1384.815978968837
epoch 44 Iteration 100 loss: -1410.7384899311965                        epoch-loss: -1400.3690408109871
epoch 45 Iteration 100 loss: -1343.7557878846794                        epoch-loss: -1402.4205821010662
epoch 46 Iteration 100 loss: -1309.0681838828461                        epoch-loss: -1412.2783526889364
epoch 47 Iteration 100 loss: -1125.0585501913108                        epoch-loss: -1391.0496208478644
epoch 48 Iteration 100 loss: -1470.087468755146                 epoch-loss: -1390.1175558545679
epoch 49 Iteration 100 loss: -1572.597159674086                 epoch-loss: -1389.4460105298315
epoch 50 Iteration 100 loss: -1113.9894360784558                        epoch-loss: -1403.3841449208112

The learned kernel parameters are as follows:

[44]:
print('The estimated variance of the RBF kernel is %f.' % infr.params[m.kernel.variance].asscalar())
print('The estimated length scale of the RBF kernel is %f.' % infr.params[m.kernel.lengthscale].asscalar())
print('The estimated variance of the Gaussian likelihood is %f.' % infr.params[m.noise_var].asscalar())
The estimated variance of the RBF kernel is 0.220715.
The estimated length scale of the RBF kernel is 0.498507.
The estimated variance of the Gaussian likelihood is 0.003107.
Prediction

Prediction with an SVGP model can be done by creating a TransferInference instance.

[45]:
from mxfusion.inference import TransferInference, ModulePredictionAlgorithm
infr_pred = TransferInference(ModulePredictionAlgorithm(model=m, observed=[m.X], target_variables=[m.Y]),
                              infr_params=infr.params)
m.Y.factor.svgp_predict.jitter = 1e-6

To visualize the fitted model, we make predictions on 100 points evenly spaced from -5 to 5. We estimate the mean and variance of the noise-free output \(F\).

[46]:
xt = np.linspace(-5,5,100)[:, None]
res = infr_pred.run(X=mx.nd.array(xt, dtype='float64'))[0]
f_mean, f_var = res[0].asnumpy()[0], res[1].asnumpy()[0]

The resulting figure is shown as follows:

[47]:
plot(xt, f_mean[:,0], 'b-', label='mean')
plot(xt, f_mean[:,0]-2*np.sqrt(f_var[:, 0]), 'b--', label='2 x std')
plot(xt, f_mean[:,0]+2*np.sqrt(f_var[:, 0]), 'b--')
plot(X, Y, 'rx', label='data points')
ylabel('F')
xlabel('X')
_=legend()
_images/examples_notebooks_svgp_regression_19_0.png

Distributed Training

Getting Started - Distributed Training

# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================
Introduction

As Bayesian probabilistic models grow, both the model size and the amount of data consumed can exceed what fits on a single processor, and training time increases significantly. Hence, MXFusion integrates Horovod to carry out distributed training of Bayesian probabilistic models, which can significantly reduce per-GPU memory consumption and training time.

We provide an easy interface to perform distributed training in MXFusion.

[ ]:
import warnings
warnings.filterwarnings('ignore')
Simple Example from Getting Started

We can start with the same toy example from Getting Started about estimating the mean and variance of a set of data. Again, we generate the same 100 data points with a given mean and variance following a normal distribution.

First of all, initialize Horovod with hvd.init(). We also want to set the global context to GPU or CPU, depending on where the code is executed.

[ ]:
import horovod.mxnet as hvd
import mxnet as mx
hvd.init()
mx.context.Context.default_ctx = mx.gpu(hvd.local_rank()) if mx.test_utils.list_gpus() else mx.cpu()

The code below defines the same data and model as in Getting Started.

[ ]:
import numpy as np
np.random.seed(0)
mean_groundtruth = 3.
variance_groundtruth = 5.
N = 100
data = np.random.randn(N)*np.sqrt(variance_groundtruth) + mean_groundtruth
[ ]:
from mxfusion import Variable, Model
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions import Normal
from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'

m = Model()
m.mu = Variable()
m.s = Variable(transformation=PositiveTransformation())
m.Y = Normal.define_variable(mean=m.mu, variance=m.s, shape=(N,))

To enable distributed training instead of single-processor training, the inference class used is DistributedGradBasedInference. Note that the code is not yet running distributed training with Horovod, since we have not launched it with the horovodrun command.

[ ]:
from mxfusion.inference import DistributedGradBasedInference, MAP

infr = DistributedGradBasedInference(inference_algorithm=MAP(model=m, observed=[m.Y]))
infr.run(Y=mx.nd.array(data, dtype='float64'), learning_rate=0.1, max_iter=2000, verbose=True)

After optimization, the estimated parameters are stored in an instance of the class InferenceParameters, which can be accessed from an Inference instance via infr.params.

We collect the estimated mean and variance and compare them with the generating parameters.

[ ]:
mean_estimated = infr.params[m.mu].asnumpy()
variance_estimated = infr.params[m.s].asnumpy()

print('The estimated mean and variance: %f, %f.' % (mean_estimated, variance_estimated))
print('The true mean and variance: %f, %f.' % (mean_groundtruth, variance_groundtruth))
Distributed Training on a Bayesian model

In the above example, we performed maximum likelihood estimation from the observed data with distributed training. As our distributed training also supports Bayesian models, we can now follow the second example of Getting Started, which uses Bayesian inference to estimate how much our estimated parameters differ from the true parameters.

[ ]:
m = Model()
m.mu = Normal.define_variable(mean=mx.nd.array([0], dtype='float64'),
                              variance=mx.nd.array([100], dtype='float64'), shape=(1,))
[ ]:
from mxfusion.components.functions import MXFusionGluonFunction

m.s_hat = Normal.define_variable(mean=mx.nd.array([5], dtype='float64'),
                                 variance=mx.nd.array([100], dtype='float64'),
                                 shape=(1,), dtype='float64')
trans_mxnet = mx.gluon.nn.HybridLambda(lambda F, x: F.Activation(x, act_type='softrelu'))
m.trans = MXFusionGluonFunction(trans_mxnet, num_outputs=1, broadcastable=True)
m.s = m.trans(m.s_hat)
m.Y = Normal.define_variable(mean=m.mu, variance=m.s, shape=(N,), dtype='float64')
[ ]:
from mxfusion.inference import create_Gaussian_meanfield

q = create_Gaussian_meanfield(model=m, observed=[m.Y])

To enable distributed training instead of single-processor training, the inference class used is DistributedGradBasedInference. Its default grad_loop is DistributedBatchInferenceLoop, as opposed to GradBasedInference, whose default is BatchInferenceLoop.

Note that the code is not yet running distributed training with Horovod, since we have not launched it with the horovodrun or mpirun command.

[ ]:
from mxfusion.inference import StochasticVariationalInference

infr = DistributedGradBasedInference(inference_algorithm=StochasticVariationalInference(
    model=m, posterior=q, num_samples=10, observed=[m.Y]))
infr.run(Y=mx.nd.array(data, dtype='float64'), learning_rate=0.1, verbose=True)

Let’s check the resulting posterior distribution.

[ ]:
mu_mean = infr.params[q.mu.factor.mean].asscalar()
mu_std = np.sqrt(infr.params[q.mu.factor.variance].asscalar())
s_hat_mean = infr.params[q.s_hat.factor.mean].asscalar()
s_hat_std = np.sqrt(infr.params[q.s_hat.factor.variance].asscalar())
s_15 = np.log1p(np.exp(s_hat_mean - s_hat_std))
s_50 = np.log1p(np.exp(s_hat_mean))
s_85 = np.log1p(np.exp(s_hat_mean + s_hat_std))
print('The mean and standard deviation of the mean parameter is %f(%f). ' % (mu_mean, mu_std))
print('The 15th, 50th and 85th percentile of the variance parameter is %f, %f and %f.'%(s_15, s_50, s_85))
Running Horovod

Currently, the only way to execute Horovod with MXFusion is via the horovodrun or mpirun command from the system. Hence, we first convert this notebook into a Python file and then execute the Python file from the command line.

[ ]:
!jupyter nbconvert --to script getting_started-distributed.ipynb

To run it on Horovod and enable distributed training, run horovodrun or mpirun from the system while specifying the number of processors. More details about running Horovod can be found here. A simple way to run it is with the format: horovodrun -np {number of processors} -H localhost:4 python {python file}

NOTE : Please restart this notebook before executing the code below.

[ ]:
!mpirun -np 4 -H localhost:4 python getting_started-distributed.py

Bayesian Neural Network (VI) for classification - Distributed Training

# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

The following example mirrors Bayesian Neural Network (VI) for classification, with Horovod’s distributed training added.

[ ]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference

First of all, initialize Horovod with hvd.init(). We also want to set the global context to GPU or CPU, depending on where the code is executed.

[ ]:
import horovod.mxnet as hvd
import mxnet as mx
hvd.init()
mx.context.Context.default_ctx = mx.gpu(hvd.local_rank()) if mx.test_utils.list_gpus() else mx.cpu()
Generate Synthetic Data
[ ]:
import GPy
from pylab import *
import matplotlib.pyplot as plt

np.random.seed(4)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(200,1)
y = np.random.multivariate_normal(mean=np.zeros((200,)), cov=k.K(x), size=(1,)).T>0.
plt.plot(x[:,0], y[:,0], '.')
[ ]:
D = 10
net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=1))
    net.add(nn.Dense(D, activation="tanh", flatten=False, in_units=D))
    net.add(nn.Dense(2, flatten=False, in_units=D))
net.initialize(mx.init.Xavier(magnitude=1))
[ ]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
from mxfusion.components.functions.operators import broadcast_to
from mxfusion.components.distributions import Normal, Categorical
from mxfusion import Variable, Model
from mxfusion.components.functions import MXFusionGluonFunction
[ ]:
m = Model()
m.N = Variable()
m.f = MXFusionGluonFunction(net, num_outputs=1, broadcastable=False)
m.x = Variable(shape=(m.N,1))
m.r = m.f(m.x)
for _,v in m.r.factor.parameters.items():
    v.set_prior(Normal(mean=broadcast_to(mx.nd.array([0]), v.shape),
                       variance=broadcast_to(mx.nd.array([1.]), v.shape)))
m.y = Categorical.define_variable(log_prob=m.r, shape=(m.N,1), num_classes=2)
[ ]:
from mxfusion.inference import DistributedBatchInferenceLoop, create_Gaussian_meanfield, DistributedGradBasedInference, StochasticVariationalInference, MAP

To enable distributed training instead of single-processor training, the inference class used is DistributedGradBasedInference. Its default grad_loop is DistributedBatchInferenceLoop, as opposed to GradBasedInference, whose default is BatchInferenceLoop.

Note that the code is not yet running distributed training with Horovod, since we have not launched it with the horovodrun or mpirun command.

[ ]:
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(num_samples=5, model=m, posterior=q, observed=observed)
infr = DistributedGradBasedInference(inference_algorithm=alg, grad_loop=DistributedBatchInferenceLoop())
[ ]:
infr.initialize(y=mx.nd.array(y), x=mx.nd.array(x))
[ ]:
for v_name, v in m.r.factor.parameters.items():
    infr.params[q[v].factor.mean] = net.collect_params()[v_name].data()
    infr.params[q[v].factor.variance] = mx.nd.ones_like(infr.params[q[v].factor.variance])*1e-6
[ ]:
infr.run(max_iter=500, learning_rate=1e-1, y=mx.nd.array(y), x=mx.nd.array(x), verbose=True)
[ ]:
# for uuid, v in infr.inference_algorithm.posterior.variables.items():
#     if uuid in infr.params.param_dict:
#         print(v.name, infr.params[v])
[ ]:
xt = np.linspace(0,1,100)[:,None]
[ ]:
infr2 = VariationalPosteriorForwardSampling(10, [m.x], infr, [m.r])
res = infr2.run(x=mx.nd.array(xt))
[ ]:
yt = res[0].asnumpy()
[ ]:
yt_mean = yt.mean(0)
yt_std = yt.std(0)
for i in range(yt.shape[0]):
    plt.plot(xt[:,0],1./(1+np.exp(yt[i,:,0]-yt[i,:,1])),'k',alpha=0.2)
plt.plot(x[:,0],y[:,0],'.')
plt.show()
Running Horovod

Currently, the only way to execute Horovod with MXFusion is via the horovodrun or mpirun command from the system. Hence, we first convert this notebook into a Python file and then execute the Python file from the command line.

[ ]:
!jupyter nbconvert --to script bnn_classification-distributed.ipynb

To run it on Horovod and enable distributed training, run horovodrun or mpirun from the system while specifying the number of processors. More details about running Horovod can be found here. A simple way to run it is with the format: horovodrun -np {number of processors} -H localhost:4 python {python file}

NOTE : Please restart this notebook before executing the code below.

[ ]:
!mpirun -np 4 -H localhost:4 python bnn_classification-distributed.py

Bayesian Neural Network (VI) for regression - Distributed Training

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

The following example mirrors Bayesian Neural Network (VI) for regression, with Horovod’s distributed training added.

[ ]:
import warnings
warnings.filterwarnings('ignore')
import mxfusion as mf
import mxnet as mx
import numpy as np
import mxnet.gluon.nn as nn
import mxfusion.components
import mxfusion.inference

First of all, initialize Horovod with hvd.init(). We also want to set the global context to GPU or CPU, depending on where the code is executed.

[ ]:
import horovod.mxnet as hvd
import mxnet as mx
hvd.init()
mx.context.Context.default_ctx = mx.gpu(hvd.local_rank()) if mx.test_utils.list_gpus() else mx.cpu()
Generate Synthetic Data
[ ]:
import GPy
from pylab import *
import matplotlib.pyplot as plt

np.random.seed(0)
k = GPy.kern.RBF(1, lengthscale=0.1)
x = np.random.rand(1000,1)
y = np.random.multivariate_normal(mean=np.zeros((1000,)), cov=k.K(x), size=(1,)).T
plt.plot(x[:,0], y[:,0], '.')
Model definition
[ ]:
D = 50
net = nn.HybridSequential(prefix='nn_')
with net.name_scope():
    net.add(nn.Dense(D, activation="tanh", in_units=1))
    net.add(nn.Dense(D, activation="tanh", in_units=D))
    net.add(nn.Dense(1, flatten=True, in_units=D))
net.initialize(mx.init.Xavier(magnitude=3))
[ ]:
from mxfusion.components.variables.var_trans import PositiveTransformation
from mxfusion.inference import VariationalPosteriorForwardSampling
from mxfusion.components.functions.operators import broadcast_to
from mxfusion.components.distributions import Normal
from mxfusion import Variable, Model
from mxfusion.components.functions import MXFusionGluonFunction
[ ]:
m = Model()
m.N = Variable()
m.f = MXFusionGluonFunction(net, num_outputs=1,broadcastable=False)
m.x = Variable(shape=(m.N,1))
m.v = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=mx.nd.array([0.01]))
m.r = m.f(m.x)
for v in m.r.factor.parameters.values():
    v.set_prior(Normal(mean=broadcast_to(mx.nd.array([0]), v.shape),
                       variance=broadcast_to(mx.nd.array([1.]), v.shape)))
m.y = Normal.define_variable(mean=m.r, variance=broadcast_to(m.v, (m.N,1)), shape=(m.N,1))
Inference with Meanfield
[ ]:
from mxfusion.inference import DistributedBatchInferenceLoop, create_Gaussian_meanfield, DistributedGradBasedInference, StochasticVariationalInference

To enable distributed training instead of single-processor training, the inference class used is DistributedGradBasedInference. Its default grad_loop is DistributedBatchInferenceLoop, as opposed to GradBasedInference, whose default is BatchInferenceLoop.

Note that the code is not yet running distributed training with Horovod, since we have not launched it with the horovodrun or mpirun command.

[ ]:
observed = [m.y, m.x]
q = create_Gaussian_meanfield(model=m, observed=observed)
alg = StochasticVariationalInference(num_samples=3, model=m, posterior=q, observed=observed)
infr = DistributedGradBasedInference(inference_algorithm=alg, grad_loop=DistributedBatchInferenceLoop())

We also need to specify the correct shape of the data when initializing the inference. In this case, if we are using 4 processors and the full data has shape (1000, 1), we have to divide the first dimension by 4, so each worker is initialized with shape (250, 1). The line below will produce an error in the notebook since it is not yet running under Horovod.

[ ]:
infr.initialize(y=(250,1), x=(250,1))
[ ]:
for v_name, v in m.r.factor.parameters.items():
    infr.params[q[v].factor.mean] = net.collect_params()[v_name].data()
    infr.params[q[v].factor.variance] = mx.nd.ones_like(infr.params[q[v].factor.variance])*1e-6
[ ]:
infr.run(max_iter=2000, learning_rate=1e-2, y=mx.nd.array(y), x=mx.nd.array(x), verbose=True)
Use prediction to visualize the resulting BNN
[ ]:
xt = np.linspace(0,1,100)[:,None]
[ ]:
infr2 = VariationalPosteriorForwardSampling(10, [m.x], infr, [m.r])
res = infr2.run(x=mx.nd.array(xt))
[ ]:
yt = res[0].asnumpy()
[ ]:
yt_mean = yt.mean(0)
yt_std = yt.std(0)

for i in range(yt.shape[0]):
    plt.plot(xt[:,0],yt[i,:,0],'k',alpha=0.2)
plt.plot(x[:,0],y[:,0],'.')
plt.show()
Running Horovod

Currently, the only way to execute Horovod with MXFusion is via the horovodrun or mpirun command from the system. Hence, we first convert this notebook into a Python file and then execute the Python file from the command line.

[ ]:
!jupyter nbconvert --to script bnn_regression-distributed.ipynb

To run it on Horovod and enable distributed training, run horovodrun or mpirun from the system while specifying the number of processors. More details about running Horovod can be found here. A simple way to run it is with the format: horovodrun -np {number of processors} -H localhost:4 python {python file}

NOTE : Please restart this notebook before executing the code below.

[ ]:
!mpirun -np 4 -H localhost:4 python bnn_regression-distributed.py

F.A.Q / Troubleshooting - Distributed Training

# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

The following is a list of frequent problems and troubleshooting steps regarding distributed training with Horovod and running MXFusion code on GPU.

ValueError while executing horovodrun
Problem

After recently installing Horovod on the machine, the following error may occur when executing the code with horovodrun in the terminal:

ValueError: Neither MPI nor Gloo support has been built. Try reinstalling Horovod ensuring that either MPI is installed (MPI) or CMake is installed (Gloo).

Steps to Reproduce

After installing Horovod with pip install horovod==0.16.4, execute an MXFusion distributed training script with horovodrun -np {number_of_processors} -H localhost:4 python {python_script}

Solution

Use mpirun instead of horovodrun. For example on terminal, type :

mpirun -np {number_of_processors} -H localhost:4 python {python_script}

Warning of CMA Support Not Available
Problem

When first executing MXFusion with Horovod after Ubuntu boots, the ptrace protection from Ubuntu blocks CMA support from being enabled, which prevents shared memory between processes. A warning will be shown in the terminal:

Linux kernel CMA support was requested via the btl_vader_single_copy_mechanism MCA variable, but CMA support is not available due to restrictive ptrace settings.
Steps to Reproduce

After Ubuntu boots, execute an MXFusion distributed training script with mpirun -np {number_of_processors} -H localhost:4 python {python_script}

Solution

Temporarily disable ptrace protection by typing the line below in the terminal. Note that, as a security measure, you may want to re-enable it with echo 1 after you have stopped using Horovod. Also note that ptrace_scope is reset to 1 every time Ubuntu boots. To disable ptrace protection, type the following in the terminal:

echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope

Segmentation fault : 11 with MXNet-cu100
Problem

When executing MXFusion on GPU, a Segmentation fault : 11 error will be thrown if MXNet-cu100 is installed.

Steps to Reproduce

Install MXNet-cu100 with pip install mxnet-cu100 on Deep Learning AMI (Ubuntu) Version 24.1 (ami-06f483a626f873983). Run an MXFusion distributed training script with mpirun -np {number_of_processors} -H localhost:4 python {python_script}.

Solution

Uninstall MXNet-cu100 and install MXNet-cu100mkl instead. In the terminal, type:

pip uninstall mxnet-cu100
pip install mxnet-cu100mkl
Segmentation fault : 11 with latest version of Horovod
Problem

MXFusion currently does not support Horovod version 0.18 and above. With the latest version of Horovod, when running MXFusion distributed training on CPU, the loss function and output will be inaccurate and inconsistent between processes. When running MXFusion distributed training on GPU, a Segmentation fault : 11 error will be thrown.

Steps to Reproduce

Install Horovod with pip install horovod. Run a distributed training script with mpirun -np {number_of_processors} -H localhost:4 python {python_script}.

Solution

MXFusion currently supports Horovod below version 0.18. Install the latest Horovod release before version 0.18 with:

pip install horovod==0.16.4

Error with dtype=’float64’ on GPU
Problem

When setting float64 as the data type and running the script on GPU, this error may occur:

mxnet.base.MXNetError: src/ndarray/ndarray_function.cu:58: Check failed: to->type_flag_ == from.type_flag_ (1 vs. 0) : Source and target must have the same data type when copying across devices.

Steps to Reproduce

On a GPU machine, change the value of config.DEFAULT_DTYPE and the dtype of the NDArrays to ‘float64’ in distributed_bnn_test.py and run the test. The error will occur in test_BNN_regression and test_BNN_regression_minibatch. In the terminal, from the MXFusion source root folder, type:

cd testing/inference
mpirun -np 4 -H localhost:4 pytest -s distributed_bnn_test.py
Solution

Set float32 as the data type. GPUs also run float32 at better speed than float64.

Developer Tutorials

Writing a new Distribution

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#   Licensed under the Apache License, Version 2.0 (the "License").
#   You may not use this file except in compliance with the License.
#   A copy of the License is located at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   or in the "license" file accompanying this file. This file is distributed
#   on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#   express or implied. See the License for the specific language governing
#   permissions and limitations under the License.
# ==============================================================================

To write and use a new Distribution class in MXFusion, fill out the Distribution interface and either the Univariate or Multivariate interface, depending on the type of distribution you are creating.

There are 4 primary methods to fill out for a Distribution in MXFusion:

* __init__ - This is the constructor for the Distribution. It takes in any parameters the distribution needs. It also defines names for the input variable[s] that the distribution takes and the output variable[s] it produces.
* log_pdf - This method returns the logarithm of the probability density function of the distribution. It is called at Inference time as necessary to perform the Inference algorithm.
* draw_samples - This method returns samples drawn from the distribution. It is called at Inference time as necessary to perform the Inference algorithm.
* define_variable - This is used to generate random variables drawn from the Distribution during model definition.

log_pdf and draw_samples are implemented using MXNet functions to compute on the input variables, which at Inference time are MXNet arrays or MXNet symbolic variables.

This notebook will take the Normal distribution as a reference.

File Structure

Code for distributions lives in the mxfusion/components/distributions directory.

If you’re implementing the FancyNew distribution then you should create a file called mxfusion/components/distributions/fancy_new.py for the class to live in.

Interface Implementation

Since this example is for a univariate Normal distribution, our class extends the UnivariateDistribution class.

The Normal distribution’s constructor takes in objects for its mean and variance, specifications for data type and context, and a random number generator if not the default.

In addition, a distribution can take in additional parameters used in its calculations that aren’t inputs. We refer to these additional parameters as the Distribution’s attributes. The difference between an input and an attribute is primarily that inputs are dynamic at inference time, while attributes are static throughout a given inference run.

In this case, minibatch_ratio is a static attribute, as it doesn’t change for a given minibatch size during inference.

The mean and variance can be either Variables or MXNet arrays if they are constants.

As mentioned above, you define names for the input and output variable[s] of the distribution here. These names are used when printing and generally inspecting the model, so give them meaningful names. We prefer names like mean and variance to ones like location and scale or Greek letters like mu and sigma.

[ ]:
class Normal(UnivariateDistribution):
    """
    The one-dimensional normal distribution. The normal distribution can be defined over a scalar random variable or an array of random variables. In case of an array of random variables, the mean and variance are broadcasted to the shape of the output random variable (array).

    :param mean: Mean of the normal distribution.
    :type mean: Variable
    :param variance: Variance of the normal distribution.
    :type variance: Variable
    :param rand_gen: the random generator (default: MXNetRandomGenerator)
    :type rand_gen: RandomGenerator
    :param dtype: the data type for float point numbers
    :type dtype: numpy.float32 or numpy.float64
    :param ctx: the mxnet context (default: None/current context)
    :type ctx: None or mxnet.cpu or mxnet.gpu
    """
    def __init__(self, mean, variance, rand_gen=None, minibatch_ratio=1.,
                 dtype=None, ctx=None):
        self.minibatch_ratio = minibatch_ratio
        if not isinstance(mean, Variable):
            mean = Variable(value=mean)
        if not isinstance(variance, Variable):
            variance = Variable(value=variance)

        inputs = [('mean', mean), ('variance', variance)]
        input_names = ['mean', 'variance']
        output_names = ['random_variable']
        super(Normal, self).__init__(inputs=inputs, outputs=None,
                                     input_names=input_names,
                                     output_names=output_names,
                                     rand_gen=rand_gen, dtype=dtype, ctx=ctx)

If your distribution’s __init__ function only takes in parameters that get passed onto its super constructor, you don’t need to implement replicate_self. If it does take additional parameters (as the Normal distribution does for minibatch_ratio), those parameters need to be copied over to the replicant Distribution before returning, as below.

[ ]:
    def replicate_self(self, attribute_map=None):
        """
        Replicates this Factor, using new inputs, outputs, and a new uuid.
        Used during model replication to functionally replicate a factor into a new graph.
        :param inputs: new input variables of the factor
        :type inputs: a dict of {'name' : Variable} or None
        :param outputs: new output variables of the factor.
        :type outputs: a dict of {'name' : Variable} or None
        """
        replicant = super(Normal, self).replicate_self(attribute_map)
        replicant.minibatch_ratio = self.minibatch_ratio
        return replicant

log_pdf and draw_samples are relatively straightforward to implement. These are the meaningful parts of the Distribution that you’re implementing, putting the math into code using MXNet operators for the computation.

If it’s a distribution that isn’t well documented on Wikipedia, please add a link to a paper or other resource that explains what it’s doing and why.

[ ]:
    def log_pdf(self, mean, variance, random_variable, F=None):
        """
        Computes the logarithm of the probability density function (PDF) of the normal distribution.

        :param mean: the mean of the normal distribution
        :type mean: MXNet NDArray or MXNet Symbol
        :param variance: the variance of the normal distributions
        :type variance: MXNet NDArray or MXNet Symbol
        :param random_variable: the random variable of the normal distribution
        :type random_variable: MXNet NDArray or MXNet Symbol
        :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray)
        :returns: log pdf of the distribution
        :rtypes: MXNet NDArray or MXNet Symbol
        """
        F = get_default_MXNet_mode() if F is None else F
        logvar = np.log(2 * np.pi) / -2 + F.log(variance) / -2
        logL = F.broadcast_add(logvar, F.broadcast_div(F.square(
            F.broadcast_minus(random_variable, mean)), -2 * variance))
        return logL
[ ]:
    def draw_samples(self, mean, variance, rv_shape, num_samples=1, F=None):
        """
        Draw samples from the normal distribution.

        :param mean: the mean of the normal distribution
        :type mean: MXNet NDArray or MXNet Symbol
        :param variance: the variance of the normal distributions
        :type variance: MXNet NDArray or MXNet Symbol
        :param rv_shape: the shape of each sample
        :type rv_shape: tuple
        :param num_samples: the number of drawn samples (default: one)
        :type num_samples: int
        :param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray)
        :returns: a set of samples from the normal distribution
        :rtypes: MXNet NDArray or MXNet Symbol
        """
        F = get_default_MXNet_mode() if F is None else F
        out_shape = (num_samples,) + rv_shape
        return F.broadcast_add(F.broadcast_mul(self._rand_gen.sample_normal(
            shape=out_shape, dtype=self.dtype, ctx=self.ctx),
            F.sqrt(variance)), mean)

define_variable is just a helper function for end users. All it does is take in parameters for the distribution, create a distribution based on those parameters, then return the output variables of that distribution.

[ ]:
    @staticmethod
    def define_variable(mean=0., variance=1., shape=None, rand_gen=None,
                        minibatch_ratio=1., dtype=None, ctx=None):
        """
        Creates and returns a random variable drawn from a normal distribution.

        :param mean: Mean of the distribution.
        :param variance: Variance of the distribution.
        :param shape: the shape of the random variable(s)
        :type shape: tuple or [tuple]
        :param rand_gen: the random generator (default: MXNetRandomGenerator)
        :type rand_gen: RandomGenerator
        :param dtype: the data type for float point numbers
        :type dtype: numpy.float32 or numpy.float64
        :param ctx: the mxnet context (default: None/current context)
        :type ctx: None or mxnet.cpu or mxnet.gpu
        :returns: the random variables drawn from the normal distribution.
        :rtypes: Variable
        """
        normal = Normal(mean=mean, variance=variance, rand_gen=rand_gen,
                        dtype=dtype, ctx=ctx)
        normal._generate_outputs(shape=shape)
        return normal.random_variable
Using your new distribution

At this point, you should be ready to start testing your new Distribution’s functionality by importing it like any other MXFusion component.
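
For instance, a minimal sketch of using the Normal distribution implemented above inside a model, following the same pattern as the tutorials earlier in this documentation (for a custom FancyNew distribution, substitute its name and parameters; the variable names are only illustrative):

from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions import Normal

# Define a model in which 100 observations follow the new distribution.
m = Model()
m.mu = Variable()
m.s = Variable(transformation=PositiveTransformation())
m.Y = Normal.define_variable(mean=m.mu, variance=m.s, shape=(100,))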

Testing

Before submitting your new code as a pull request, please write unit tests that verify it works as expected. This should include numerical checks against edge cases. See the existing test cases for the Normal or Categorical distributions for example tests.
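
As a rough illustration, the sketch below exercises a distribution end-to-end by running MAP estimation on synthetic data and checking that the recovered mean and variance are close to the sample statistics. The test name and tolerances are illustrative assumptions, not part of the MXFusion test suite.

import numpy as np
import mxnet as mx
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions import Normal
from mxfusion.inference import GradBasedInference, MAP

def test_normal_map_recovers_moments():
    np.random.seed(0)
    data = np.random.randn(200) * 2. + 3.  # synthetic data with known moments

    # Model: all observations are i.i.d. draws from the distribution under test.
    m = Model()
    m.mu = Variable()
    m.s = Variable(transformation=PositiveTransformation())
    m.Y = Normal.define_variable(mean=m.mu, variance=m.s, shape=(200,))

    infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.Y]))
    infr.run(Y=mx.nd.array(data), learning_rate=0.1, max_iter=2000)

    # The MAP estimates should be close to the sample mean and variance.
    assert np.allclose(infr.params[m.mu].asnumpy(), data.mean(), atol=0.1)
    assert np.allclose(infr.params[m.s].asnumpy(), data.var(), atol=0.3)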

FAQ

FAQ for MXFusion APIs

Zhenwen Dai (2019-05-30)

1. How to access a variable or a factor in my model?

There are a few ways to access a variable or a factor in a model:

  1. If a variable is named such as m.x = Variable(), we use x as the name of the variable and this variable can be accessed later by calling m.x.
  2. A factor can also be named in the same way as a variable, e.g., m.f = MXFusionGluonFunction(func, 1), in which we name the wrapper of the MXNet function func as f. This function can be accessed by calling m.f.
  3. If a variable is the random variable following a distribution or the output of a function, e.g., m.x = Normal.define_variable(mx.nd.array([0]), mx.nd.array([1]), shape=(1,)), the distribution or the function can be accessed by calling m.x.factor.
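
For example, a small sketch of these access patterns (the names m and x below are illustrative):

import mxnet as mx
from mxfusion import Model
from mxfusion.components.distributions import Normal

m = Model()
m.x = Normal.define_variable(mx.nd.array([0]), mx.nd.array([1]), shape=(1,))

print(m.x)         # the named random variable
print(m.x.factor)  # the Normal distribution that generated m.x
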
3. How to access the parameters after inference?

The inference in MXFusion is done by creating an Inference object, which takes an inference algorithm as the input argument. After the execution of the inference algorithm, all estimated parameters are stored in an InferenceParameters object. If we have an Inference instance infr, the InferenceParameters can be accessed via infr.params. The individual parameters in the model and posterior can be obtained by passing in the reference of the corresponding variables, e.g., infr.params[m.x] returns the estimated value of the parameter x in the model m.
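
For example, a brief sketch, assuming an Inference instance infr whose run method has completed for a model m with a variable m.x (names as in the question above):

x_estimate = infr.params[m.x]   # an MXNet NDArray holding the estimated value
print(x_estimate.asnumpy())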

4. How to serialize the inference results?

Serialization can be done conveniently in MXFusion by simply calling the save method of an Inference instance, which takes a filename as the input argument. An example is shown below:

m = Model()
...
infr = ...
infr.save('inference_file.zip')

To load back the inference result of a model, one needs to recreate the model, the posterior instance and the corresponding inference instance with exactly the same configuration. Then, the estimated parameters can be loaded by calling the load method of the Inference instance. See the example below:

m = Model()
...
infr = ...
infr.load('inference_file.zip')
5. How to run the computation in single/double (float32/float64) precision?

When creating random variables from probabilistic distributions and the Inference instance, the argument dtype specifies the precision of the corresponding objects. At the moment, we only support single and double precision, specified by the value “float32” or “float64”.
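
For example, a sketch of setting double precision for a single random variable, reusing the pattern from the distributed training tutorials above (assuming an existing Model instance m):

import mxnet as mx
from mxfusion.components.distributions import Normal

m.x = Normal.define_variable(mean=mx.nd.array([0], dtype='float64'),
                             variance=mx.nd.array([1], dtype='float64'),
                             shape=(1,), dtype='float64')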

Alternatively, the computation precision can be set globally by changing the default precision type:

from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'
6. How to run the computation on GPU?

When creating random variables from probabilistic distributions and the Inference instance, the argument ctx or context specifies the device on which the variables are expected to be stored. One can pass in an MXNet device reference such as mxnet.gpu() to switch the computation to run on GPU.
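
For example, a sketch of placing a single random variable on the first GPU (assuming a GPU is available; the variable name is illustrative):

import mxnet as mx
from mxfusion.components.distributions import Normal

y = Normal.define_variable(mean=mx.nd.array([0], ctx=mx.gpu(0)),
                           variance=mx.nd.array([1], ctx=mx.gpu(0)),
                           shape=(1,), ctx=mx.gpu(0))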

Alternatively, the computational device can also be set globally by changing the default device of MXNet:

import mxnet as mx
mx.context.Context.device_ctx = mx.gpu()
7. How to view TensorBoard logs?

To use TensorBoard to inspect inference logs you must have TensorBoard and MXBoard installed. Instructions for installing these packages can be found here.

To produce the logs required for TensorBoard, pass a Logger with a log_dir (and an optional log_name) to your inference object instantiation.

infr = Inference(logger=Logger(log_dir='logs'))

To run the TensorBoard server to view the results, run the following command (for more details see here):

$ tensorboard --logdir=path/to/log-directory

Now you can open the server in a browser and view the logs.

Design Overview

Topical Guides

Working in MXFusion breaks down into two primary phases. Model definition involves defining the variables, distributions, and functions that make up your model. Inference then takes in real data and learns the parameters of your model or gives predictions over the data.

Design Choices

Indices and tables