Source code for gpflow.experimental.check_shapes.shapes

# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code for extracting shapes from object.
"""
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple, Type, Union

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from .base_types import Shape
from .error_contexts import ErrorContext, IndexContext, ObjectTypeContext, StackContext
from .exceptions import NoShapeError

if TYPE_CHECKING:  # pragma: no cover
    # Avoid cyclic imports:
    from ...base import AnyNDArray
else:
    AnyNDArray = Any
# Signature of a shape getter: takes the object to inspect and the error context to report
# problems against, and returns the object's shape.
GetShape = Callable[[Any, ErrorContext], Shape]


# Registry mapping a type to the getter that extracts shapes from instances of that type.
# Populated by `register_get_shape` and consulted (along the MRO) by `get_shape`.
_GET_SHAPES: Dict[Type[Any], GetShape] = {}


def register_get_shape(shape_type: Type[Any]) -> Callable[[GetShape], GetShape]:
    """
    Register a function for extracting the shape from a given type of objects.

    Intended to be used as a decorator; the decorated function is returned unchanged, so
    several registrations can be stacked on the same function.

    Example:

    .. literalinclude:: /examples/test_check_shapes_examples.py
       :start-after: [custom_type]
       :end-before: [custom_type]
       :dedent:

    See also :func:`get_shape`.

    :param shape_type: Type of objects to extract shapes from.
    :returns: A decorator that stores its argument as the getter for ``shape_type``.
    """
    # This deliberately resembles `functools.singledispatch`. We cannot use the real thing,
    # because its implementation contains a for/else statement, which TensorFlow doesn't know
    # how to compile.

    def _store_getter(getter: GetShape) -> GetShape:
        # A later registration for the same type replaces the earlier one.
        _GET_SHAPES[shape_type] = getter
        return getter

    return _store_getter
def get_shape(shaped: Any, context: ErrorContext) -> Shape:
    """
    Return the shape of the given object.

    The getter is looked up by walking the method resolution order of the object's class, so
    a getter registered for a base class also applies to its subclasses, and the most
    specific registration wins.

    See also :func:`register_get_shape`.

    :param shaped: The object whose shape to extract.
    :param context: Context we are getting the shape in, for improved error messages.
    :returns: The shape of ``shaped``, or ``None`` if the shape exists, but is unknown.
    :raises NoShapeError: If objects of this type do not have shapes.
    """
    # Keep this a plain `for` loop without an `else` clause - see the note in
    # `register_get_shape` about what TensorFlow is able to compile.
    for candidate_type in inspect.getmro(shaped.__class__):
        extract = _GET_SHAPES.get(candidate_type)
        if extract is None:
            continue
        return extract(shaped, context)

    # No class in the MRO has a registered getter.
    raise NoShapeError(StackContext(context, ObjectTypeContext(shaped)))
@register_get_shape(bool)
@register_get_shape(int)
@register_get_shape(float)
@register_get_shape(str)
def get_scalar_shape(shaped: Any, context: ErrorContext) -> Shape:
    """
    Return the shape of a Python scalar, which is always the empty tuple.

    :param shaped: The scalar value; it is not inspected.
    :param context: Context we are getting the shape in; unused here.
    :returns: ``()``.
    """
    scalar_shape: Shape = ()
    return scalar_shape
@register_get_shape(list)
@register_get_shape(tuple)
def get_sequence_shape(shaped: Sequence[Any], context: ErrorContext) -> Shape:
    """
    Return the shape of a (possibly nested) ``list`` or ``tuple``.

    The shape is the length of the sequence followed by the shape of its first element. Only
    the first element is inspected; the remaining elements are not checked against it.

    :param shaped: The sequence whose shape to extract.
    :param context: Context we are getting the shape in, for improved error messages.
    :returns: The shape of ``shaped``, or ``None`` if it is empty or the shape of its first
        element is unknown.
    """
    length = len(shaped)
    if length == 0:
        # With no elements there is no first element to derive the trailing dimensions from,
        # so the shape is unknown.
        return None

    first_shape = get_shape(shaped[0], StackContext(context, IndexContext(0)))
    # An unknown element shape makes the whole shape unknown.
    return None if first_shape is None else (length,) + first_shape
@register_get_shape(np.ndarray)
def get_ndarray_shape(shaped: AnyNDArray, context: ErrorContext) -> Shape:
    """
    Return the shape of a NumPy array, as reported by its ``shape`` attribute.

    :param shaped: The array whose shape to extract.
    :param context: Context we are getting the shape in; unused here.
    :returns: ``shaped.shape``.
    """
    # The annotated local gives the type checker a concrete type for the returned value.
    array_shape: Tuple[int, ...] = shaped.shape
    return array_shape
@register_get_shape(tf.Tensor)
@register_get_shape(tf.Variable)
@register_get_shape(tfp.util.DeferredTensor)
def get_tensorflow_shape(
    shaped: Union[tf.Tensor, tf.Variable, tfp.util.DeferredTensor], context: ErrorContext
) -> Shape:
    """
    Return the shape of a TensorFlow tensor-like object.

    :param shaped: The tensor, variable or deferred tensor whose shape to extract.
    :param context: Context we are getting the shape in; unused here.
    :returns: The dimensions of ``shaped.shape`` as a tuple, or ``None`` if ``shaped.shape``
        is falsy.
    """
    tensor_shape = shaped.shape
    # NOTE(review): this relies on the truthiness of the shape object - presumably a
    # `tf.TensorShape` is falsy only when its rank is unknown, so scalars still yield `()`.
    # Confirm against the TensorFlow documentation before changing this test.
    if tensor_shape:
        return tuple(tensor_shape)
    return None