mxfusion.inference.variational

Members

class mxfusion.inference.variational.VariationalInference(model, posterior, observed)

Bases: mxfusion.inference.inference_alg.InferenceAlgorithm

The base class for Variational Inference (VI) algorithms.

Parameters:
  • model (Model) – the definition of the probabilistic model
  • posterior (Posterior) – the definition of the variational posterior of the probabilistic model
  • observed ([Variable]) – a list of observed variables
posterior

Return the variational posterior.
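VI algorithms generally fit the variational posterior q(z) over the latent variables z by maximizing the variational lower bound (ELBO) on the log marginal likelihood of the observed variables x. The standard definition is given below for orientation. How a particular subclass estimates or optimizes this quantity is up to that subclass.

```latex
\log p(x) \;\ge\; \mathcal{L}(q)
  = \mathbb{E}_{q(z)}\big[\log p(x, z) - \log q(z)\big]
  = \log p(x) - \mathrm{KL}\big(q(z) \,\|\, p(z \mid x)\big)
```

Because log p(x) does not depend on q, maximizing the lower bound is equivalent to minimizing KL(q(z) ‖ p(z | x)).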

class mxfusion.inference.variational.VariationalSamplingAlgorithm(model, posterior, observed, num_samples=1, target_variables=None)

Bases: mxfusion.inference.inference_alg.SamplingAlgorithm

The base class for the sampling algorithms that are applied to the models with variational approximation.

Parameters:
  • model (Model) – the definition of the probabilistic model
  • posterior (Posterior) – the definition of the variational posterior of the probabilistic model
  • observed ([Variable]) – a list of observed variables
  • num_samples (int) – the number of samples used in estimating the variational lower bound
  • target_variables ([UUID]) – (optional) the target variables to sample
posterior

Return the variational posterior.
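The following minimal NumPy sketch shows what drawing samples from a variational approximation looks like. It assumes a mean-field Gaussian posterior stored in a hypothetical `{name: (mean, std)}` dictionary. The function `draw_posterior_samples` is illustrative and is not MXFusion API. Its `num_samples` and `target_variables` arguments only mirror the documented constructor arguments. In this sketch, `target_variables=None` means "sample every variable".

```python
import numpy as np


def draw_posterior_samples(posterior_params, num_samples=1, target_variables=None, seed=None):
    """Draw num_samples samples per variable from a mean-field Gaussian posterior.

    posterior_params maps a variable name to a (mean, std) pair (hypothetical format).
    """
    rng = np.random.default_rng(seed)
    if target_variables is None:  # in this sketch, None means "sample every variable"
        target_variables = list(posterior_params)
    samples = {}
    for name in target_variables:
        mean, std = posterior_params[name]
        mean = np.asarray(mean, dtype=float)
        std = np.asarray(std, dtype=float)
        # Leading axis indexes the samples; trailing axes follow the variable's shape.
        noise = rng.standard_normal((num_samples,) + mean.shape)
        samples[name] = mean + std * noise
    return samples
```

Keeping the sample index on the leading axis makes it straightforward to average any downstream quantity over samples.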

class mxfusion.inference.variational.StochasticVariationalInference(num_samples, model, posterior, observed)

Bases: mxfusion.inference.variational.VariationalInference

The class of the Stochastic Variational Inference (SVI) algorithm.

Parameters:
  • num_samples (int) – the number of samples used in estimating the variational lower bound
  • model (Model) – the definition of the probabilistic model
  • posterior (Posterior) – the definition of the variational posterior of the probabilistic model
  • observed ([Variable]) – a list of observed variables
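The "stochastic" part of the algorithm can be illustrated with a minimal NumPy sketch: stochastic gradient ascent on a Monte Carlo estimate of the lower bound, using `num_samples` draws per step. The sketch rests on several assumptions:

  • a toy conjugate model, z ~ N(0, 1) and x_i | z ~ N(z, 1)
  • a Gaussian q(z) = N(m, s²)
  • gradients obtained through the reparameterization trick z = m + s·ε, with the Gaussian entropy handled analytically

This is one common way to obtain the gradients and is not necessarily how MXFusion computes them. The function `svi_gaussian_mean` and its arguments are hypothetical.

```python
import numpy as np


def svi_gaussian_mean(x, num_samples=100, lr=0.05, num_steps=1000, seed=0):
    """Fit q(z) = N(m, s^2) to p(z | x) for the toy model
    z ~ N(0, 1), x_i | z ~ N(z, 1), by stochastic gradient ascent on the ELBO."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    n, sum_x = x.size, x.sum()
    m, log_s = 0.0, 0.0  # variational parameters; s stays positive via log_s
    for _ in range(num_steps):
        s = np.exp(log_s)
        eps = rng.standard_normal(num_samples)
        z = m + s * eps  # reparameterized draws from q
        # d/dz of log p(x, z) = sum_i log N(x_i | z, 1) + log N(z | 0, 1)
        dlogp_dz = sum_x - (n + 1) * z
        # Monte Carlo gradients of E_q[log p(x, z)], plus d/dlog_s of the entropy (log s + const)
        grad_m = dlogp_dz.mean()
        grad_log_s = (dlogp_dz * eps * s).mean() + 1.0
        m += lr * grad_m
        log_s += lr * grad_log_s
    return float(m), float(np.exp(log_s))
```

Because the toy model is conjugate, the exact posterior is N(Σx_i / (n + 1), 1 / (n + 1)). This gives a direct check on where the fitted m and s should land. Each step uses a noisy gradient estimate, and a larger `num_samples` lowers its variance at a higher per-step cost.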
compute(F, variables)

Run the inference algorithm and return its outcome.

Parameters:
  • F (Python module) – the execution context (mxnet.ndarray or mxnet.symbol)
  • variables ({str(UUID): MXNet NDArray or MXNet Symbol}) – a dictionary mapping variable UUIDs (as strings) to the MXNet arrays that hold the variables' values at runtime

Returns: the outcome of the inference algorithm

Return type: mxnet.ndarray.ndarray.NDArray or mxnet.symbol.symbol.Symbol
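Passing the execution context as `F` lets one function body run either imperatively (mxnet.ndarray) or symbolically (mxnet.symbol), provided it calls only operations that the module passed in provides. The sketch below mimics that style, with NumPy standing in for `F`, and has only been reasoned about for `F=numpy`. It is a hypothetical free function named after the documented method. It returns a Monte Carlo estimate of the negative lower bound, a loss to minimize, for the toy model z ~ N(0, 1), x_i | z ~ N(z, 1) with q(z) = N(q_mean, q_std²). The dictionary keys "x", "q_mean" and "q_std" are hypothetical stand-ins for the str(UUID) keys. The noise is drawn with NumPy regardless of `F`.

```python
import math

import numpy as np

LOG_2PI = math.log(2.0 * math.pi)


def normal_log_pdf(F, x, mean, std):
    # Elementwise log N(x | mean, std^2), using only operations from the module F.
    return -0.5 * LOG_2PI - F.log(std) - 0.5 * F.square((x - mean) / std)


def compute(F, variables, num_samples=10, seed=0):
    """Monte Carlo estimate of the negative variational lower bound for the toy model
    z ~ N(0, 1), x_i | z ~ N(z, 1), with q(z) = N(q_mean, q_std^2)."""
    x = variables["x"]
    q_mean, q_std = variables["q_mean"], variables["q_std"]
    eps = np.random.default_rng(seed).standard_normal(num_samples)  # noise drawn with NumPy
    z = q_mean + q_std * F.array(eps)  # shape (num_samples,)
    # log p(x, z): likelihood summed over data points, plus the prior on z
    log_lik = normal_log_pdf(F, F.expand_dims(x, axis=0), F.expand_dims(z, axis=1), 1.0)
    log_p = F.sum(log_lik, axis=1) + normal_log_pdf(F, z, 0.0, 1.0)
    log_q = normal_log_pdf(F, z, q_mean, q_std)
    return -F.mean(log_p - log_q)  # loss to minimize
```

When q is set to the exact posterior N(Σx_i / (n + 1), 1 / (n + 1)), log p(x, z) − log q(z) equals log p(x) for every sample. The estimate then has zero variance and returns −log p(x). Any other q gives a larger value in expectation, by exactly KL(q(z) ‖ p(z | x)).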