Stochastic Variational Gaussian Process Regression

Zhenwen Dai (2019-05-29)

Introduction

Exact Gaussian process (GP) regression is computationally expensive, as its cost grows cubically with the number of data points. A popular approach to scaling up GP regression to large data sets is stochastic variational inference with mini-batch training (Hensman et al., 2013). Stochastic variational GP (SVGP) regression with Gaussian noise is implemented as a module in MXFusion.
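The module optimizes the evidence lower bound of Hensman et al. (2013), which decomposes as a sum over data points plus a single KL term; it is this decomposition that makes unbiased mini-batch estimates possible. For orientation, the standard form of the bound is (stated here as the textbook expression, not necessarily MXFusion's exact internal one):

\[
\mathcal{L} = \sum_{i=1}^{N} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left(q(\mathbf{u}) \,\|\, p(\mathbf{u})\right),
\]

where \(\mathbf{u}\) are the function values at the inducing points and \(q(\mathbf{u}) = \mathcal{N}(\mathbf{m}, \mathbf{S})\) is the variational posterior.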

[32]:
import warnings
warnings.filterwarnings('ignore')  # suppress library warnings in the notebook output
import os
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'  # run MXNet operations synchronously

Toy data

We generate some synthetic data for our regression example. The data set is generated from a sine function with additive Gaussian noise.

[2]:
import numpy as np
%matplotlib inline
from pylab import *

np.random.seed(0)  # fix the random seed for reproducibility
X = np.random.uniform(-3., 3., (1000, 1))  # 1000 inputs drawn uniformly from [-3, 3]
Y = np.sin(X) + np.random.randn(1000, 1)*0.05  # noisy sine observations (noise std 0.05)

The generated data are visualized as follows:

[3]:
plot(X, Y, 'rx', label='data points')
_=legend()
[Figure: scatter plot of the generated data points]
We set the default data type of MXFusion to 64-bit floats, which is advisable for the numerically sensitive covariance computations in GP models.

[4]:
from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'

The SVGP regression model is created as follows. Two SVGP-specific parameters are num_inducing, which specifies the number of inducing points used in the variational sparse GP approximation, and svgp_log_pdf.jitter, which is a small jitter term added in the log-pdf computation for numerical robustness.

[33]:
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import RBF
from mxfusion.modules.gp_modules import SVGPRegression

m = Model()
m.N = Variable()  # the number of data points, left unspecified until data are given
m.X = Variable(shape=(m.N, 1))
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=0.01)  # Gaussian likelihood variance, constrained to be positive
m.kernel = RBF(input_dim=1, variance=1, lengthscale=1)
m.Y = SVGPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var, shape=(m.N, 1), num_inducing=20)
m.Y.factor.svgp_log_pdf.jitter = 1e-6  # jitter for numerical robustness of the log-pdf

Inference is done by creating an inference instance from the GradBasedInference class, in which we use the MAP inference algorithm, as there are no latent variables outside the SVGPRegression module. Additionally, we specify grad_loop to be MinibatchInferenceLoop, in which we set the mini-batch size and the scaling factor for mini-batch training.
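The scaling factor compensates for only seeing a subset of the data in each step: rescaling the mini-batch data-fit term by the ratio of data size to batch size keeps the stochastic objective an unbiased estimate of the full bound. Sketched in the notation of the bound above, with \(N = 1000\) and \(|B| = 10\), hence the rv_scaling of 100 below:

\[
\mathcal{L} \approx \frac{N}{|B|} \sum_{i \in B} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left(q(\mathbf{u}) \,\|\, p(\mathbf{u})\right).
\]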

Then, training is triggered by calling the run method.

[38]:
import mxnet as mx
from mxfusion.inference import GradBasedInference, MAP, MinibatchInferenceLoop

# The scaling factor is the data size divided by the mini-batch size (1000/10).
infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]),
                          grad_loop=MinibatchInferenceLoop(batch_size=10, rv_scaling={m.Y: 1000/10}))
infr.initialize(X=(1000,1), Y=(1000,1))
# Initialize the 20 inducing inputs at random locations.
infr.params[m.Y.factor.inducing_inputs] = mx.nd.array(np.random.randn(20, 1), dtype='float64')
# Run 50 epochs with a large learning rate, then 50 more with a smaller one to refine the fit.
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=50, learning_rate=0.1, verbose=True)
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'),
         max_iter=50, learning_rate=0.01, verbose=True)
epoch 1 Iteration 100 loss: 933115.0603707978                   epoch-loss: 10413624.614005275
epoch 2 Iteration 100 loss: 524948.7079326594                   epoch-loss: 686034.5295730559
epoch 3 Iteration 100 loss: 345602.4022749258                   epoch-loss: 427065.8343717841
epoch 4 Iteration 100 loss: 277011.3760208657                   epoch-loss: 297071.493696023
epoch 5 Iteration 100 loss: 183347.13021907964                  epoch-loss: 219808.0871498559
epoch 6 Iteration 100 loss: 143763.11007552472                  epoch-loss: 169486.20729875282
epoch 7 Iteration 100 loss: 132031.47695326462                  epoch-loss: 134765.1471133905
epoch 8 Iteration 100 loss: 95632.60561449913                   epoch-loss: 109798.66321648406
epoch 9 Iteration 100 loss: 73957.6220462552                    epoch-loss: 91257.8705670977
epoch 10 Iteration 100 loss: 64840.07207031624                  epoch-loss: 77084.06942481917
epoch 11 Iteration 100 loss: 60780.27278575914                  epoch-loss: 65962.38163622493
epoch 12 Iteration 100 loss: 48546.66342698521                  epoch-loss: 57037.39009905885
epoch 13 Iteration 100 loss: 42676.907263579335                 epoch-loss: 49725.50869601666
epoch 14 Iteration 100 loss: 43266.74759690139                  epoch-loss: 43635.70855486856
epoch 15 Iteration 100 loss: 33139.32033870425                  epoch-loss: 38501.415430223606
epoch 16 Iteration 100 loss: 35129.68003531527                  epoch-loss: 34139.30892930683
epoch 17 Iteration 100 loss: 33309.08869286892                  epoch-loss: 30414.713307491817
epoch 18 Iteration 100 loss: 31058.180286752693                 epoch-loss: 27222.957705478882
epoch 19 Iteration 100 loss: 22781.668494776342                 epoch-loss: 24466.753696665117
epoch 20 Iteration 100 loss: 16921.53875526696                  epoch-loss: 22063.866203795988
epoch 21 Iteration 100 loss: 16866.27172281184                  epoch-loss: 19959.435781693166
epoch 22 Iteration 100 loss: 18001.39866328793                  epoch-loss: 18093.70564938978
epoch 23 Iteration 100 loss: 19268.435700542395                 epoch-loss: 16435.61461383947
epoch 24 Iteration 100 loss: 13586.70681551015                  epoch-loss: 14947.197437326102
epoch 25 Iteration 100 loss: 11842.634017398044                 epoch-loss: 13605.954880888436
epoch 26 Iteration 100 loss: 12304.581180033452                 epoch-loss: 12393.880316263208
epoch 27 Iteration 100 loss: 12712.095456995734                 epoch-loss: 11293.27810727986
epoch 28 Iteration 100 loss: 12662.540317512301                 epoch-loss: 10292.698091923068
epoch 29 Iteration 100 loss: 9789.253683769626                  epoch-loss: 9379.934609293405
epoch 30 Iteration 100 loss: 10336.484081253366                 epoch-loss: 8542.778732654882
epoch 31 Iteration 100 loss: 8427.615871046397                  epoch-loss: 7780.101399774407
epoch 32 Iteration 100 loss: 6243.338653452632                  epoch-loss: 7083.3906599663305
epoch 33 Iteration 100 loss: 5633.910939630758                  epoch-loss: 6442.360608787293
epoch 34 Iteration 100 loss: 6128.494674105952                  epoch-loss: 5856.924952855579
epoch 35 Iteration 100 loss: 5561.132651568278                  epoch-loss: 5319.662670742758
epoch 36 Iteration 100 loss: 5007.633559342303                  epoch-loss: 4827.494923733251
epoch 37 Iteration 100 loss: 4570.798941667555                  epoch-loss: 4375.152951451802
epoch 38 Iteration 100 loss: 3427.8776815125993                 epoch-loss: 3958.746627662967
epoch 39 Iteration 100 loss: 3145.271868648371                  epoch-loss: 3574.2727718396727
epoch 40 Iteration 100 loss: 3252.388844355417                  epoch-loss: 3216.389789008766
epoch 41 Iteration 100 loss: 2682.992323506939                  epoch-loss: 2880.8040627817663
epoch 42 Iteration 100 loss: 2776.54316335849                   epoch-loss: 2563.2893900928902
epoch 43 Iteration 100 loss: 2052.181117489573                  epoch-loss: 2259.124250598867
epoch 44 Iteration 100 loss: 1789.3450917418618                 epoch-loss: 1963.4009524512699
epoch 45 Iteration 100 loss: 1637.0460616480382                 epoch-loss: 1683.301052960261
epoch 46 Iteration 100 loss: 1250.3190196168575                 epoch-loss: 1421.08925599032
epoch 47 Iteration 100 loss: 1056.4280170128945                 epoch-loss: 1181.4882875552755
epoch 48 Iteration 100 loss: 934.1323712121834                  epoch-loss: 972.8920812023131
epoch 49 Iteration 100 loss: 743.6854774208032                  epoch-loss: 794.3919410633861
epoch 50 Iteration 100 loss: 592.0162492873271                  epoch-loss: 643.6129305537779
epoch 1 Iteration 100 loss: -617.7115390031664                  epoch-loss: 122.02590714978953
epoch 2 Iteration 100 loss: -1042.9322804366407                 epoch-loss: -861.8691127743712
epoch 3 Iteration 100 loss: -1246.1061590298375                 epoch-loss: -1142.8551043268158
epoch 4 Iteration 100 loss: -1422.4364206976472                 epoch-loss: -1248.3343954963652
epoch 5 Iteration 100 loss: -1364.319275718058                  epoch-loss: -1319.0632400945233
epoch 6 Iteration 100 loss: -1138.6014678286117                 epoch-loss: -1375.485088640635
epoch 7 Iteration 100 loss: -1468.2449906521865                 epoch-loss: -1415.3387799226973
epoch 8 Iteration 100 loss: -1331.0742440765116                 epoch-loss: -1398.7259993571608
epoch 9 Iteration 100 loss: -1023.1218294411456                 epoch-loss: -1406.2506096944428
epoch 10 Iteration 100 loss: -1491.0721525479291                        epoch-loss: -1425.3786072098467
epoch 11 Iteration 100 loss: -1487.9902441406107                        epoch-loss: -1385.4821177117121
epoch 12 Iteration 100 loss: -963.575720938497                  epoch-loss: -1148.7904243974
epoch 13 Iteration 100 loss: -1496.8723348964538                        epoch-loss: -1248.4710558849933
epoch 14 Iteration 100 loss: -1189.2469453417261                        epoch-loss: -1302.58240646708
epoch 15 Iteration 100 loss: -1354.0129933002445                        epoch-loss: -1422.9290660653176
epoch 16 Iteration 100 loss: -1375.0688655561046                        epoch-loss: -1296.0532055882159
epoch 17 Iteration 100 loss: -1601.7368685439442                        epoch-loss: -1432.8777691683824
epoch 18 Iteration 100 loss: -1140.428056593764                 epoch-loss: -1443.8657069101057
epoch 19 Iteration 100 loss: -1396.6869254783921                        epoch-loss: -1421.0467725977735
epoch 20 Iteration 100 loss: -1313.511818206805                 epoch-loss: -1411.19388568273
epoch 21 Iteration 100 loss: -1508.1672406497062                        epoch-loss: -1427.8889874691674
epoch 22 Iteration 100 loss: -1249.1642813846483                        epoch-loss: -1379.492333117903
epoch 23 Iteration 100 loss: -1214.1394062603918                        epoch-loss: -1356.5797617962307
epoch 24 Iteration 100 loss: -1554.6263005956837                        epoch-loss: -1358.5256191991677
epoch 25 Iteration 100 loss: -1419.5889498936215                        epoch-loss: -1405.5467914984783
epoch 26 Iteration 100 loss: -1262.3682620336267                        epoch-loss: -1409.6484860247688
epoch 27 Iteration 100 loss: -1327.4752015434606                        epoch-loss: -1368.1521038967614
epoch 28 Iteration 100 loss: -1256.4414309051297                        epoch-loss: -1351.3528504368003
epoch 29 Iteration 100 loss: -1178.4788588168844                        epoch-loss: -1413.816013007459
epoch 30 Iteration 100 loss: -1605.1239164704423                        epoch-loss: -1426.6550440932342
epoch 31 Iteration 100 loss: -1617.1795697144926                        epoch-loss: -1356.5267725452202
epoch 32 Iteration 100 loss: -1590.7237287842681                        epoch-loss: -1425.2884165221458
epoch 33 Iteration 100 loss: -1594.3448025229204                        epoch-loss: -1420.5483351285052
epoch 34 Iteration 100 loss: -1576.9397677615486                        epoch-loss: -1430.2946033617723
epoch 35 Iteration 100 loss: -1303.3497394593587                        epoch-loss: -1380.3330104443605
epoch 36 Iteration 100 loss: -1478.0145396344049                        epoch-loss: -1399.0665992260174
epoch 37 Iteration 100 loss: -1555.4119456067176                        epoch-loss: -1360.9939473244767
epoch 38 Iteration 100 loss: -1553.031887368961                 epoch-loss: -1419.1421503464217
epoch 39 Iteration 100 loss: -1427.3431059260865                        epoch-loss: -1415.0248356293594
epoch 40 Iteration 100 loss: -1137.8470272897798                        epoch-loss: -1398.6618957762776
epoch 41 Iteration 100 loss: -1551.999240061582                 epoch-loss: -1402.3061839927834
epoch 42 Iteration 100 loss: -1458.4434735943848                        epoch-loss: -1425.2654433536431
epoch 43 Iteration 100 loss: -1585.6542548185487                        epoch-loss: -1384.815978968837
epoch 44 Iteration 100 loss: -1410.7384899311965                        epoch-loss: -1400.3690408109871
epoch 45 Iteration 100 loss: -1343.7557878846794                        epoch-loss: -1402.4205821010662
epoch 46 Iteration 100 loss: -1309.0681838828461                        epoch-loss: -1412.2783526889364
epoch 47 Iteration 100 loss: -1125.0585501913108                        epoch-loss: -1391.0496208478644
epoch 48 Iteration 100 loss: -1470.087468755146                 epoch-loss: -1390.1175558545679
epoch 49 Iteration 100 loss: -1572.597159674086                 epoch-loss: -1389.4460105298315
epoch 50 Iteration 100 loss: -1113.9894360784558                        epoch-loss: -1403.3841449208112

The learned kernel and likelihood parameters are as follows:

[44]:
print('The estimated variance of the RBF kernel is %f.' % infr.params[m.kernel.variance].asscalar())
print('The estimated length scale of the RBF kernel is %f.' % infr.params[m.kernel.lengthscale].asscalar())
print('The estimated variance of the Gaussian likelihood is %f.' % infr.params[m.noise_var].asscalar())
The estimated variance of the RBF kernel is 0.220715.
The estimated length scale of the RBF kernel is 0.498507.
The estimated variance of the Gaussian likelihood is 0.003107.
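As a quick sanity check, the estimated likelihood variance (about 0.0031) is close to the true noise variance used to generate the data, \(0.05^2 = 0.0025\).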

Prediction

Prediction with an SVGP model is done by creating a TransferInference instance with the ModulePredictionAlgorithm, which reuses the parameters learned during training. As with the log-pdf, we set a jitter term for the prediction computation.

[45]:
from mxfusion.inference import TransferInference, ModulePredictionAlgorithm
infr_pred = TransferInference(ModulePredictionAlgorithm(model=m, observed=[m.X], target_variables=[m.Y]),
                              infr_params=infr.params)
m.Y.factor.svgp_predict.jitter = 1e-6

To visualize the fitted model, we make predictions at 100 points evenly spaced from -5 to 5. We estimate the mean and variance of the noise-free output \(F\).
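For orientation, these quantities come from the standard sparse variational predictive, obtained by marginalizing the inducing outputs under \(q(\mathbf{u}) = \mathcal{N}(\mathbf{m}, \mathbf{S})\) (the textbook formula, not necessarily MXFusion's exact code path):

\[
q(\mathbf{f}_*) = \mathcal{N}\!\left(\mathbf{K}_{*Z}\mathbf{K}_{ZZ}^{-1}\mathbf{m},\; \mathbf{K}_{**} - \mathbf{K}_{*Z}\mathbf{K}_{ZZ}^{-1}\left(\mathbf{K}_{ZZ} - \mathbf{S}\right)\mathbf{K}_{ZZ}^{-1}\mathbf{K}_{Z*}\right),
\]

where \(Z\) denotes the inducing inputs and \(*\) the test inputs.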

[46]:
xt = np.linspace(-5, 5, 100)[:, None]  # 100 test inputs evenly spaced in [-5, 5]
res = infr_pred.run(X=mx.nd.array(xt, dtype='float64'))[0]
f_mean, f_var = res[0].asnumpy()[0], res[1].asnumpy()[0]  # predictive mean and variance of F

The resulting figure is shown as follows:

[47]:
plot(xt, f_mean[:,0], 'b-', label='mean')
plot(xt, f_mean[:,0]-2*np.sqrt(f_var[:, 0]), 'b--', label='2 x std')
plot(xt, f_mean[:,0]+2*np.sqrt(f_var[:, 0]), 'b--')
plot(X, Y, 'rx', label='data points')
ylabel('F')
xlabel('X')
_=legend()
[Figure: predictive mean and 2-standard-deviation band of F, with the data points overlaid]
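The band above covers the noise-free function \(F\). To get a predictive band for noisy observations \(Y\), the learned likelihood variance can be added to the latent variance. A minimal sketch using the arrays defined above (this cell is a hypothetical addition, not part of the original notebook):

[ ]:
# Predictive std of Y = F + noise: add the learned Gaussian noise variance to the latent variance.
noise_var = infr.params[m.noise_var].asscalar()
y_std = np.sqrt(f_var[:, 0] + noise_var)
plot(xt, f_mean[:,0]-2*y_std, 'g--', label='2 x std (Y)')
plot(xt, f_mean[:,0]+2*y_std, 'g--')
_=legend()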