Source code for gpflow.utilities.model_utils
from typing import Any, Callable
import tensorflow as tf
from check_shapes import check_shapes
from ..base import TensorType
from ..likelihoods import Gaussian
[docs]
def assert_params_false(
called_method: Callable[..., Any],
**kwargs: bool,
) -> None:
"""
Asserts that parameters are ``False``.
:param called_method: The method or function that is calling this. Used for nice error messages.
:param kwargs: Parameters that must be ``False``.
:raises NotImplementedError: If any ``kwargs`` are ``True``.
"""
errors_str = ", ".join(f"{param}={value}" for param, value in kwargs.items() if value)
if errors_str:
raise NotImplementedError(
f"{called_method.__qualname__} does not currently support: {errors_str}"
)
[docs]
@check_shapes(
"K: [batch..., N, N]",
"likelihood_variance: [broadcast batch..., broadcast N]",
"return: [batch..., N, N]",
)
def add_noise_cov(K: tf.Tensor, likelihood_variance: TensorType) -> tf.Tensor:
"""
Returns K + σ², where σ² is the diagonal likelihood noise variance.
"""
k_diag = tf.linalg.diag_part(K)
return tf.linalg.set_diag(K, k_diag + likelihood_variance)
[docs]
@check_shapes(
"K: [batch..., N, N]",
"X: [batch..., N, D]",
"return: [batch..., N, N]",
)
def add_likelihood_noise_cov(K: tf.Tensor, likelihood: Gaussian, X: TensorType) -> tf.Tensor:
"""
Returns K + σ², where σ² is the likelihood noise variance.
"""
return add_noise_cov(K, tf.squeeze(likelihood.variance_at(X), axis=-1))