# Copyright 2017-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import tensorflow as tf
from ..base import SamplesMeanAndVariance
from ..inducing_variables import InducingVariables
from ..kernels import Kernel
from .dispatch import conditional, sample_conditional
from .util import sample_mvn
@sample_conditional.register(object, object, Kernel, object)
@sample_conditional.register(object, InducingVariables, Kernel, object)
def _sample_conditional(
    Xnew: tf.Tensor,
    inducing_variable: InducingVariables,
    kernel: Kernel,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
    num_samples: Optional[int] = None,
) -> SamplesMeanAndVariance:
    """
    `sample_conditional` will return a sample from the conditional distribution.
    In most cases this means calculating the conditional mean m and variance v and then
    returning m + sqrt(v) * eps, with eps ~ N(0, 1).
    However, for some combinations of Mok and Mof more efficient sampling routines exist.
    The dispatcher will make sure that we use the most efficient one.

    This is the generic fallback registration: it computes the mean and covariance
    with `conditional(...)` and then draws samples with `sample_mvn(...)`.

    :param Xnew: new input locations; forwarded unchanged to `conditional`.
    :param inducing_variable: inducing variable (or other conditioning object);
        forwarded unchanged to `conditional`.
    :param kernel: kernel; forwarded unchanged to `conditional`.
    :param f: conditioning values; forwarded unchanged to `conditional`.
    :param full_cov: if True, sample using the covariance across the N points
        (cov is expected as [..., P, N, N]).
    :param full_output_cov: if True, sample using the covariance across the P
        outputs (cov is expected as [..., N, P, P]).
    :param q_sqrt: forwarded unchanged to `conditional`.
    :param white: forwarded unchanged to `conditional`.
    :param num_samples: number of samples S passed to `sample_mvn`; if None,
        the sample axis is omitted.
    :raises NotImplementedError: if both `full_cov` and `full_output_cov` are True.
    :return: samples, mean, cov
        samples has shape [num_samples, N, P] or [N, P] if num_samples is None
        mean and cov as for conditional()
    """
    if full_cov and full_output_cov:
        msg = "The combination of both `full_cov` and `full_output_cov` is not permitted."
        raise NotImplementedError(msg)
    mean, cov = conditional(
        Xnew,
        inducing_variable,
        kernel,
        f,
        q_sqrt=q_sqrt,
        white=white,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
    )
    if full_cov:
        # mean: [..., N, P]
        # cov: [..., P, N, N]
        # sample_mvn needs the N axis last to match cov's [N, N] trailing
        # dims, so swap the last two axes of the mean before sampling and
        # swap back afterwards.
        mean_for_sample = tf.linalg.adjoint(mean)  # [..., P, N]
        samples = sample_mvn(
            mean_for_sample, cov, full_cov=True, num_samples=num_samples
        )  # [..., (S), P, N]
        samples = tf.linalg.adjoint(samples)  # [..., (S), N, P]
    else:
        # mean: [..., N, P]
        # cov: [..., N, P] or [..., N, P, P]
        # Here the covariance (if any) is over the output axis P, which is
        # already last, so no transposition is needed.
        samples = sample_mvn(
            mean, cov, full_cov=full_output_cov, num_samples=num_samples
        )  # [..., (S), N, P]
    return samples, mean, cov