Variational Inference

ADVI(*args, **kwargs)

Automatic Differentiation Variational Inference (ADVI).

ASVGD([approx, estimator, kernel])

Amortized Stein Variational Gradient Descent.

SVGD([n_particles, jitter, model, start, ...])

Stein Variational Gradient Descent.

FullRankADVI(*args, **kwargs)

Full Rank Automatic Differentiation Variational Inference (ADVI).

ImplicitGradient(approx[, estimator, kernel])

Implicit Gradient for Variational Inference.

Inference(op, approx, tf, **kwargs)

Base class for Variational Inference.

KLqp(approx[, beta])

Kullback-Leibler Divergence Inference.

fit([n, method, model, random_seed, start, ...])

Handy shortcut for using inference methods in a functional way; see the sketch below.
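
A typical workflow either instantiates one of the inference classes above or calls fit directly. A minimal sketch, assuming a current PyMC install (the toy model and data are illustrative):

    import pymc as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", mu=0.0, sigma=1.0)
        pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

        # Functional shortcut: runs ADVI for n steps and returns a fitted
        # MeanField approximation. Other method strings select the classes
        # above, e.g. "fullrank_advi", "svgd", "asvgd".
        mean_field = pm.fit(n=10_000, method="advi")

        # Equivalent object-oriented form:
        # mean_field = pm.ADVI().fit(10_000)

    # Draw from the fitted variational posterior (return type varies by version).
    idata = mean_field.sample(1_000)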

Approximations

Empirical([trace, size])

Single Group Full Rank Approximation built from a given trace.

FullRank(*args, **kwargs)

Single Group Full Rank Approximation

MeanField(*args, **kwargs)

Single Group Mean Field Approximation

sample_approx(approx[, draws, ...])

Draw samples from the variational posterior.
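
The classes above are what the inference methods return. A short sketch of drawing samples from a fitted approximation, continuing the mean_field example above (the import path for sample_approx is an assumption and may differ across PyMC versions):

    from pymc.variational import sample_approx  # path may vary; check your version

    # Functional form listed above ...
    trace = sample_approx(mean_field, draws=500)

    # ... which is equivalent to the method available on any Approximation:
    trace = mean_field.sample(500)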

OPVI

Approximation(groups[, model])

Wrapper for grouped approximations.

Group([group, vfam, params])

Base class for grouping variables in VI.
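
Groups make it possible to give different blocks of latent variables different variational families, and Approximation stitches the groups together. A minimal sketch, continuing the model above and assuming it has a named free variable x (the module path and vfam spellings follow the opvi docstrings; treat this as illustrative):

    from pymc.variational.opvi import Approximation, Group

    with model:
        g_x = Group([model["x"]], vfam="mean_field")  # mean-field family for x
        g_rest = Group(None, vfam="full_rank")        # full-rank family for the remaining latents
        approx = Approximation([g_x, g_rest])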

Operators

KL(approx[, beta])

Operator based on Kullback-Leibler Divergence.

KSD(approx[, temperature])

Operator based on Kernelized Stein Discrepancy.
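
Operators are mostly internal plumbing: an inference object is an operator applied to an approximation. A rough sketch of the relationship, continuing the model above (illustrative only; KLqp does this wiring for you, and the KSD operator with an Empirical approximation similarly underlies SVGD):

    with model:
        approx = pm.MeanField()         # the q distribution
        inference = pm.KLqp(approx)     # applies the KL operator to it; this is ADVI
        approx = inference.fit(10_000)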

Special

Stein(approx[, kernel, use_histogram, ...])

Stein gradient estimator used by Kernelized Stein Discrepancy based inference.

adadelta([loss_or_grads, params, ...])

Adadelta updates.

adagrad([loss_or_grads, params, ...])

Adagrad updates.

adagrad_window([loss_or_grads, params, ...])

Return a function that returns parameter updates.

adam([loss_or_grads, params, learning_rate, ...])

Adam updates.

adamax([loss_or_grads, params, ...])

Adamax updates.

apply_momentum(updates[, params, momentum])

Return a modified update dictionary including momentum.

apply_nesterov_momentum(updates[, params, ...])

Return a modified update dictionary including Nesterov momentum.

momentum([loss_or_grads, params, ...])

Stochastic Gradient Descent (SGD) updates with momentum.

nesterov_momentum([loss_or_grads, params, ...])

Stochastic Gradient Descent (SGD) updates with Nesterov momentum.

norm_constraint(tensor_var, max_norm[, ...])

Max weight norm constraints and gradient clipping.

rmsprop([loss_or_grads, params, ...])

RMSProp updates.

sgd([loss_or_grads, params, learning_rate])

Stochastic Gradient Descent (SGD) updates.

total_norm_constraint(tensor_vars, max_norm)

Rescales a list of tensors based on their combined norm.
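
These optimizer factories (ported from Lasagne) produce the update rules the inference loop consumes. A hedged sketch of swapping the optimizer and clipping gradients through fit's keyword arguments (obj_optimizer and total_grad_norm_constraint are forwarded to the objective updates; defaults differ across versions):

    approx = pm.fit(
        n=20_000,
        method="advi",
        obj_optimizer=pm.adam(learning_rate=1e-3),  # replace the default optimizer
        total_grad_norm_constraint=10.0,            # rescale gradients by combined norm
    )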