Source code for gpflow.models.cglb

# Copyright 2021 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, NamedTuple, Optional, Tuple

import tensorflow as tf

from ..base import InputData, MeanAndVariance, Parameter, RegressionData, TensorType
from ..config import default_float, default_int
from ..covariances import Kuf
from ..experimental.check_shapes import check_shapes, inherit_check_shapes
from ..utilities import add_noise_cov, assert_params_false, to_default_float
from .sgpr import SGPR


[docs]class CGLB(SGPR): """ Conjugate Gradient Lower Bound. The key reference is :cite:t:`pmlr-v139-artemev21a`. :param cg_tolerance: Determines accuracy to which conjugate gradient is run when evaluating the elbo. Running more iterations of CG would increase the ELBO by at most `cg_tolerance`. :param max_cg_iters: Maximum number of iterations of CG to run per evaluation of the ELBO (or mean prediction). :param restart_cg_iters: How frequently to restart the CG iteration. Can be useful to avoid build up of numerical errors when many steps of CG are run. :param v_grad_optimization: If False, in every evaluation of the ELBO, CG is run to select a new auxilary vector `v`. If False, no CG is run when evaluating the ELBO but gradients with respect to `v` are tracked so that it can be optimized jointly with other parameters. """ @check_shapes( "data[0]: [N, D]", "data[1]: [N, P]", ) def __init__( self, data: RegressionData, *args: Any, cg_tolerance: float = 1.0, max_cg_iters: int = 100, restart_cg_iters: int = 40, v_grad_optimization: bool = False, **kwargs: Any, ) -> None: super().__init__(data, *args, **kwargs) n, b = self.data[1].shape self._v = Parameter(tf.zeros((b, n), dtype=default_float()), trainable=v_grad_optimization) self._cg_tolerance = cg_tolerance self._max_cg_iters = max_cg_iters self._restart_cg_iters = restart_cg_iters @property # type: ignore[misc] # Mypy doesn't like decorated properties. @check_shapes( "return: [N, P]", ) def aux_vec(self) -> Parameter: return self._v
[docs] @inherit_check_shapes def logdet_term(self, common: SGPR.CommonTensors) -> tf.Tensor: r""" Compute a lower bound on :math:`-0.5 * \log |K + σ²I|` based on a low-rank approximation to K. .. math:: \log |K + σ²I| <= \log |Q + σ²I| + n * \log(1 + \textrm{tr}(K - Q)/(σ²n)). This bound is at least as tight as .. math:: \log |K + σ²I| <= \log |Q + σ²I| + \textrm{tr}(K - Q)/σ², which appears in SGPR. """ LB = common.LB AAT = common.AAT x, y = self.data num_data = to_default_float(tf.shape(y)[0]) output_dim = to_default_float(tf.shape(y)[1]) sigma_sq = self.likelihood.variance kdiag = self.kernel(x, full_cov=False) # t / σ² trace = tf.reduce_sum(kdiag) / sigma_sq - tf.reduce_sum(tf.linalg.diag_part(AAT)) logdet_b = tf.reduce_sum(tf.math.log(tf.linalg.diag_part(LB))) logsigma_sq = num_data * tf.math.log(sigma_sq) # Correction term from Jensen's inequality logtrace = num_data * tf.math.log(1 + trace / num_data) return -output_dim * (logdet_b + 0.5 * logsigma_sq + 0.5 * logtrace)
[docs] @inherit_check_shapes def quad_term(self, common: SGPR.CommonTensors) -> tf.Tensor: """ Computes a lower bound on the quadratic term in the log marginal likelihood of conjugate GPR. The bound is based on an auxiliary vector, v. For :math:`Q ≺ K` and :math:`r=y - Kv` .. math:: -0.5 * (rᵀQ⁻¹r + 2yᵀv - vᵀ K v ) <= -0.5 * yᵀK⁻¹y <= -0.5 * (2yᵀv - vᵀKv). Equality holds if :math:`r=0`, i.e. :math:`v = K⁻¹y`. If `self.aux_vec` is trainable, gradients are computed with respect to :math:`v` as well and :math:`v` can be optimized using gradient based methods. Otherwise, :math:`v` is updated with the method of conjugate gradients (CG). CG is run until :math:`0.5 * rᵀQ⁻¹r <= ϵ`, which ensures that the maximum bias due to this term is not more than :math:`ϵ`. The :math:`ϵ` is the CG tolerance. """ x, y = self.data err = y - self.mean_function(x) sigma_sq = self.likelihood.variance K = add_noise_cov(self.kernel.K(x), sigma_sq) A = common.A LB = common.LB preconditioner = NystromPreconditioner(A, LB, sigma_sq) err_t = tf.transpose(err) v_init = self.aux_vec if not v_init.trainable: v = cglb_conjugate_gradient( K, err_t, v_init, preconditioner, self._cg_tolerance, self._max_cg_iters, self._restart_cg_iters, ) else: v = v_init Kv = v @ K r = err_t - Kv _, error_bound = preconditioner(r) lb = tf.reduce_sum(v * (r + 0.5 * Kv)) ub = lb + 0.5 * error_bound if not v_init.trainable: v_init.assign(v) return -ub
[docs] @inherit_check_shapes def predict_f( self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False, cg_tolerance: Optional[float] = 1e-3, ) -> MeanAndVariance: """ The posterior mean for CGLB model is given by .. :math:: m(xs) = K_{sf}v + Q_{ff}Q⁻¹r where :math:`r = y - K v` is the residual from CG. Note that when :math:`v=0`, this agree with the SGPR mean, while if :math:`v = K⁻¹ y`, then :math:`r=0`, and the exact GP mean is recovered. :param cg_tolerance: float or None: If None, the cached value of :math:`v` is used. If float, conjugate gradient is run until :math:`rᵀQ⁻¹r < ϵ`. """ assert_params_false(self.predict_f, full_output_cov=full_output_cov) x, y = self.data err = y - self.mean_function(x) kxx = self.kernel(x, x) ksf = self.kernel(Xnew, x) sigma_sq = self.likelihood.variance sigma = tf.sqrt(sigma_sq) iv = self.inducing_variable kernel = self.kernel matmul = tf.linalg.matmul trisolve = tf.linalg.triangular_solve kmat = add_noise_cov(kxx, sigma_sq) common = self._common_calculation() A, LB, L = common.A, common.LB, common.L v = self.aux_vec if cg_tolerance is not None: preconditioner = NystromPreconditioner(A, LB, sigma_sq) err_t = tf.transpose(err) v = cglb_conjugate_gradient( kmat, err_t, v, preconditioner, cg_tolerance, self._max_cg_iters, self._restart_cg_iters, ) cg_mean = matmul(ksf, v, transpose_b=True) res = err - matmul(kmat, v, transpose_b=True) Kus = Kuf(iv, kernel, Xnew) Ares = matmul(A, res) # The god of war! c = trisolve(LB, Ares, lower=True) / sigma tmp1 = trisolve(L, Kus, lower=True) tmp2 = trisolve(LB, tmp1, lower=True) sgpr_mean = matmul(tmp2, c, transpose_a=True) if full_cov: var = ( kernel(Xnew) + matmul(tmp2, tmp2, transpose_a=True) - matmul(tmp1, tmp1, transpose_a=True) ) var = tf.tile(var[None, ...], [self.num_latent_gps, 1, 1]) # [P, N, N] else: var = ( kernel(Xnew, full_cov=False) + tf.reduce_sum(tf.square(tmp2), 0) - tf.reduce_sum(tf.square(tmp1), 0) ) var = tf.tile(var[:, None], [1, self.num_latent_gps]) mean = sgpr_mean + cg_mean + self.mean_function(Xnew) return mean, var
[docs] @inherit_check_shapes def predict_y( self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False, cg_tolerance: Optional[float] = 1e-3, ) -> MeanAndVariance: """ Compute the mean and variance of the held-out data at the input points. """ assert_params_false(self.predict_y, full_cov=full_cov, full_output_cov=full_output_cov) f_mean, f_var = self.predict_f( Xnew, full_cov=full_cov, full_output_cov=full_output_cov, cg_tolerance=cg_tolerance ) return self.likelihood.predict_mean_and_var(Xnew, f_mean, f_var)
[docs] @inherit_check_shapes def predict_log_density( self, data: RegressionData, full_cov: bool = False, full_output_cov: bool = False, cg_tolerance: Optional[float] = 1e-3, ) -> tf.Tensor: """ Compute the log density of the data at the new data points. """ assert_params_false( self.predict_log_density, full_cov=full_cov, full_output_cov=full_output_cov ) x, y = data f_mean, f_var = self.predict_f( x, full_cov=full_cov, full_output_cov=full_output_cov, cg_tolerance=cg_tolerance ) return self.likelihood.predict_log_density(x, f_mean, f_var, y)
[docs]class NystromPreconditioner: """ Preconditioner of the form :math:`Q=(Q_ff + σ²I)⁻¹`, where L is lower triangular with :math: `LLᵀ = Kᵤᵤ` :math:`A = σ⁻²L⁻¹Kᵤₓ` and :math:`B = AAᵀ + I = LᵦLᵦᵀ` """ @check_shapes( "A: [M, N]", "LB: [M, M]", ) def __init__(self, A: tf.Tensor, LB: tf.Tensor, sigma_sq: float) -> None: self.A = A self.LB = LB self.sigma_sq = sigma_sq @check_shapes( "v: [B, N]", "return[0]: [B, N]", "return[1]: []", ) def __call__(self, v: TensorType) -> Tuple[tf.Tensor, tf.Tensor]: """ Computes :math:`vᵀQ^{-1}` and `vᵀQ^{-1}v`. Note that this is implemented as multipication of a row vector on the right. :param v: Vector we want to backsolve. """ sigma_sq = self.sigma_sq A = self.A LB = self.LB trans = tf.transpose trisolve = tf.linalg.triangular_solve matmul = tf.linalg.matmul v = trans(v) Av = matmul(A, v) LBinvAv = trisolve(LB, Av) LBinvtLBinvAv = trisolve(trans(LB), LBinvAv, lower=False) rv = v - matmul(A, LBinvtLBinvAv, transpose_a=True) vtrv = tf.reduce_sum(rv * v) return trans(rv) / sigma_sq, vtrv / sigma_sq
[docs]@check_shapes( "K: [N, N]", "b: [B, N]", "initial: [P, N]", "return: [P, N]", ) def cglb_conjugate_gradient( K: TensorType, b: TensorType, initial: TensorType, preconditioner: NystromPreconditioner, cg_tolerance: float, max_steps: int, restart_cg_step: int, ) -> tf.Tensor: """ Conjugate gradient algorithm used in CGLB model. The method of conjugate gradient (Hestenes and Stiefel, 1952) produces a sequence of vectors :math:`v_0, v_1, v_2, ..., v_N` such that :math:`v_0` = initial, and (in exact arithmetic) :math:`Kv_n = b`. In practice, the v_i often converge quickly to approximate :math:`K^{-1}b`, and the algorithm can be stopped without running N iterations. We assume the preconditioner, :math:`Q`, satisfies :math:`Q ≺ K`, and stop the algorithm when :math:`r_i = b - Kv_i` satisfies :math:`||rᵢᵀ||_{Q⁻¹r}^2 = rᵢᵀQ⁻¹rᵢ <= ϵ`. :param K: Matrix we want to backsolve from. Must be PSD. :param b: Vector we want to backsolve. :param initial: Initial vector solution. :param preconditioner: Preconditioner function. :param cg_tolerance: Expected maximum error. This value is used as a decision boundary against stopping criteria. :param max_steps: Maximum number of CG iterations. :param restart_cg_step: Restart step at which the CG resets the internal state to the initial position using the currect solution vector :math:`v`. Can help avoid build up of numerical errors. :return: `v` where `v` approximately satisfies :math:`Kv = b`. """ class CGState(NamedTuple): i: tf.Tensor v: tf.Tensor r: tf.Tensor p: tf.Tensor rz: tf.Tensor def stopping_criterion(state: CGState) -> tf.Tensor: return (0.5 * state.rz > cg_tolerance) and (state.i < max_steps) def cg_step(state: CGState) -> List[CGState]: Ap = state.p @ K denom = tf.reduce_sum(state.p * Ap, axis=-1) gamma = state.rz / denom v = state.v + gamma * state.p i = state.i + 1 r = tf.cond( state.i % restart_cg_step == restart_cg_step - 1, lambda: b - v @ K, lambda: state.r - gamma * Ap, ) z, new_rz = preconditioner(r) p = tf.cond( state.i % restart_cg_step == restart_cg_step - 1, lambda: z, lambda: z + state.p * new_rz / state.rz, ) return [CGState(i, v, r, p, new_rz)] Kv = initial @ K r = b - Kv z, rz = preconditioner(r) p = z i = tf.constant(0, dtype=default_int()) initial_state = CGState(i, initial, r, p, rz) final_state = tf.while_loop(stopping_criterion, cg_step, [initial_state]) final_state = tf.nest.map_structure(tf.stop_gradient, final_state) return final_state[0].v