Faster predictions by caching#

The default behaviour of `predict_f` in GPflow models is to compute the predictions from scratch on each call. This is convenient when predicting and training are interleaved, and it simplifies the use of these models. There are some use cases, such as Bayesian optimisation, where prediction (at different test points) happens much more frequently than training. In these cases it is convenient to cache the parts of the calculation that do not depend upon the test points, and to reuse those parts between predictions.

There are three models to which we want to add this caching capability: GPR, (S)VGP and SGPR. The VGP and SVGP can be considered together; the difference between the models is whether to condition on the full training data set (VGP) or on the inducing variables (SVGP).

Posterior predictive distribution#

The posterior predictive distribution of a Gaussian process model, evaluated at a set of test points \(\mathbf{x}_*\) given training data \(X, Y\), is Gaussian: \begin{equation*} p(\mathbf{f}_*|X, Y) = \mathcal{N}(\mu, \Sigma) \end{equation*}

In the case of the GPR model, the parameters \(\mu\) and \(\Sigma\) are given by: \begin{equation*} \mu = K_{nm}[K_{mm} + \sigma^2I]^{-1}\mathbf{y} \end{equation*} and \begin{equation*} \Sigma = K_{nn} - K_{nm}[K_{mm} + \sigma^2I]^{-1}K_{mn} \end{equation*} where the subscripts \(n\) and \(m\) index the test and training points respectively.

The posterior predictive distribution for the VGP and SVGP model is parameterised as follows: \begin{equation*} \mu = K_{nu}K_{uu}^{-1}\mathbf{u} \end{equation*} and \begin{equation*} \Sigma = K_{nn} - K_{nu}K_{uu}^{-1}K_{un} \end{equation*} where the subscript \(u\) indexes the points on which we condition: the full training data set for VGP, or the inducing points for SVGP.

Finally, the parameters for the SGPR model are: \begin{equation*} \mu = K_{nu}L^{-T}L_B^{-T}\mathbf{c} \end{equation*} and \begin{equation*} \Sigma = K_{nn} - K_{nu}L^{-T}(I - B^{-1})L^{-1}K_{un} \end{equation*} Here, as in the SGPR derivation, \(L\) is the lower-triangular Cholesky factor of \(K_{uu}\), \(B = I + L^{-1}K_{uf}K_{fu}L^{-T}/\sigma^2\) with Cholesky factor \(L_B\), and \(\mathbf{c} = L_B^{-1}L^{-1}K_{uf}\mathbf{y}/\sigma^2\).

Where the mean function is not the zero function, the mean function evaluated at the test points is added to the predictive mean \(\mu\).

What can be cached?#

We cache two separate values: \(\alpha\) and \(Q^{-1}\). These correspond to the parts of the mean and covariance functions, respectively, which do not depend upon the test points.

For the GPR model these are the same value: \begin{equation*} \alpha = Q^{-1} = [K_{mm} + \sigma^2I]^{-1} \end{equation*}

For the VGP and SVGP models: \begin{equation*} \alpha = K_{uu}^{-1}\mathbf{u}, \qquad Q^{-1} = K_{uu}^{-1} \end{equation*}

And for the SGPR model: \begin{equation*} \alpha = L^{-T}L_B^{-T}\mathbf{c}, \qquad Q^{-1} = L^{-T}(I - B^{-1})L^{-1} \end{equation*}

Note that in the (S)VGP case, \(\alpha\) is the parameter proposed by Opper and Archambeau ("The Variational Gaussian Approximation Revisited", 2009) for the mean of the predictive distribution.
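To make the caching pattern concrete, here is a minimal NumPy sketch of the GPR case. This is an illustration of the maths above, not GPflow's implementation: the helper names `make_gpr_cache` and `predict_from_cache` are hypothetical, and `kernel` is assumed to be any callable returning a covariance matrix.

import numpy as np
from scipy.linalg import cho_factor, cho_solve

def make_gpr_cache(kernel, X_train, y_train, noise_variance):
    # Everything here depends only on the training data, so it is computed
    # once: Q = K_mm + sigma^2 I (held as a Cholesky factor) and alpha.
    Q = kernel(X_train, X_train) + noise_variance * np.eye(len(X_train))
    chol = cho_factor(Q, lower=True)  # cached stand-in for Q^{-1}
    alpha = cho_solve(chol, y_train)  # [K_mm + sigma^2 I]^{-1} y
    return chol, alpha

def predict_from_cache(kernel, X_train, chol, alpha, X_test):
    # Only the covariances involving the test points are recomputed here.
    K_nm = kernel(X_test, X_train)
    mean = K_nm @ alpha
    cov = kernel(X_test, X_test) - K_nm @ cho_solve(chol, K_nm.T)
    return mean, cov

For a zero mean function this reproduces \(\mu\) and \(\Sigma\) above; the GPflow posterior objects introduced below implement the same idea, for all three model classes, in TensorFlow.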

[1]:
import numpy as np

import gpflow
from gpflow.ci_utils import reduce_in_tests

# Create some data.
n_data = reduce_in_tests(1000)  # reduced when running in CI tests
X = np.linspace(-1.1, 1.1, n_data)[:, None]
Y = np.sin(X)
Xnew = np.linspace(-1.1, 1.1, n_data)[:, None]  # test points
inducing_points = Xnew

GPR Example#

We will construct a GPR model to demonstrate the faster predictions obtained by using the cached data in the GPflow posterior classes (subclasses of `gpflow.posteriors.AbstractPosterior`).

[2]:
model = gpflow.models.GPR(
    (X, Y),
    gpflow.kernels.SquaredExponential(),
)

The `predict_f` method on the `GPModel` class performs no caching.

[3]:
%%timeit
model.predict_f(Xnew)
165 ms ± 5.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[4]:
# To make use of the caching, first retrieve the posterior class from the
# model. The posterior class has methods to predict the parameters of
# marginal distributions at test points, in the same way as the
# `predict_f` method of the `GPModel`.
posterior = model.posterior()
[5]:
%%timeit
posterior.predict_f(Xnew)
120 ms ± 4.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
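For repeated predictions in a tight loop, the cached posterior can optionally be wrapped in `tf.function` so that the TensorFlow graph is traced only once. This is a sketch of an optional extra step, independent of the caching itself; `compiled_predict_f` is a hypothetical name.

import tensorflow as tf

# Compile the cached prediction into a TensorFlow graph. Tracing happens on
# the first call; subsequent calls reuse the compiled graph.
compiled_predict_f = tf.function(posterior.predict_f)
mean, var = compiled_predict_f(Xnew)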

SVGP Example#

Likewise, we will construct an SVGP model to demonstrate the faster predictions obtained by using the cached data in the GPflow posterior classes.

[6]:
model = gpflow.models.SVGP(
    gpflow.kernels.SquaredExponential(),
    gpflow.likelihoods.Gaussian(),
    inducing_points,
)

The `predict_f` method on the `GPModel` class performs no caching.

[7]:
%%timeit
model.predict_f(Xnew)
214 ms ± 33.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

And again, using the posterior object and caching:

[8]:
posterior = model.posterior()
[9]:
%%timeit
posterior.predict_f(Xnew)
113 ms ± 3.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

SGPR Example#

Finally, we follow the same approach for the SGPR case.

[10]:
model = gpflow.models.SGPR(
    (X, Y), gpflow.kernels.SquaredExponential(), inducing_points
)

The `predict_f` method on the model instance performs no caching.

[11]:
%%timeit
model.predict_f(Xnew)
303 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Using the posterior object instead:

[12]:
posterior = model.posterior()
[13]:
%%timeit
posterior.predict_f(Xnew)
112 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
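Note that the cached \(\alpha\) and \(Q^{-1}\) are tied to the hyperparameters at the time `model.posterior()` was called: if the model is trained further, the posterior object should be rebuilt before predicting again. A minimal sketch, reusing the SGPR model above (the retraining step itself is only illustrative):

# Further training changes the hyperparameters, so any previously created
# posterior object now holds stale cached values.
opt = gpflow.optimizers.Scipy()
opt.minimize(model.training_loss, model.trainable_variables)

# Rebuild the posterior to recompute alpha and Q^{-1} once, then predict
# cheaply at new test points as before.
posterior = model.posterior()
mean, var = posterior.predict_f(Xnew)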