Source code for gpflow.experimental.check_shapes.checker

# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Class responsible for remembering and checking shapes.
"""
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

from ..utils import experimental
from .base_types import Shape
from .config import get_enable_check_shapes
from .error_contexts import (
    ErrorContext,
    FunctionCallContext,
    IndexContext,
    NoteContext,
    ParallelContext,
    ShapeContext,
    StackContext,
    TensorSpecContext,
    VariableContext,
)
from .exceptions import ShapeMismatchError, VariableTypeError
from .parser import parse_tensor_spec
from .shapes import get_shape
from .specs import ParsedDimensionSpec, ParsedShapeSpec, ParsedTensorSpec

S = TypeVar("S")


TensorSpecLike = Union[ParsedTensorSpec, str, Shape]


def _as_parsed_tensor_spec(tensor_spec: TensorSpecLike, context: ErrorContext) -> ParsedTensorSpec:
    if isinstance(tensor_spec, ParsedTensorSpec):
        return tensor_spec
    elif isinstance(tensor_spec, str):
        return parse_tensor_spec(tensor_spec, context)
    else:
        dimension_specs = []
        if isinstance(tensor_spec, tuple):
            for dim in tensor_spec:
                if isinstance(dim, int):
                    dimension_specs.append(ParsedDimensionSpec(dim, None, False, False))
                else:
                    assert dim is None
                    dimension_specs.append(ParsedDimensionSpec(None, None, False, False))
        else:
            assert tensor_spec is None
            dimension_specs.append(ParsedDimensionSpec(None, None, True, False))
        shape = ParsedShapeSpec(tuple(dimension_specs))
        return ParsedTensorSpec(shape, None)
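
# A minimal sketch (not part of the original module) of the three accepted
# `TensorSpecLike` forms; ``ctx`` stands for some hypothetical `ErrorContext`:
#
#     >>> spec_from_str = _as_parsed_tensor_spec("[batch..., n]", ctx)  # a string spec is parsed
#     >>> spec_from_shape = _as_parsed_tensor_spec((3, None), ctx)  # concrete shape: 3, then unknown
#     >>> spec_unknown = _as_parsed_tensor_spec(None, ctx)  # a completely unknown shape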


@dataclass
class _ObservedDim:
    """
    Storage of observed size of a single dimension variable.
    """

    size: Optional[int] = None

    def check_and_update(self, actual: Optional[int], broadcast: bool) -> bool:
        """
        Attempt to merge new data into this observation.

        Returns whether the new data is compatible with existing observations.  If this returns
        `False`, this object may have been left in an invalid state and should not be used again.
        """
        if (actual is None) or (broadcast and actual == 1):
            # Update contains no information. Nothing to do.
            return True

        if self.size is None:
            self.size = actual

        return self.size == actual
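
# A minimal sketch (not part of the original module) of how `_ObservedDim`
# merges observations:
#
#     >>> dim = _ObservedDim()
#     >>> dim.check_and_update(3, broadcast=False)  # first observation records size 3
#     True
#     >>> dim.check_and_update(1, broadcast=True)  # a broadcast 1 carries no information
#     True
#     >>> dim.check_and_update(4, broadcast=False)  # conflicts with the recorded size 3
#     False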


@dataclass
class _ObservedDims:
    """
    Storage of observed sizes of a var-rank / batch variable.
    """

    sizes: Optional[List[_ObservedDim]] = None
    known_rank: bool = False

    def check_and_update(
        self, actual: Optional[Tuple[Optional[int], ...]], broadcast: bool
    ) -> bool:
        """
        Attempt to merge new data into this observation.

        Returns whether the new data is compatible with existing observations.  If this returns
        `False`, this object may have been left in an invalid state and should not be used again.
        """
        if actual is None:
            # Update contains no information. Nothing to do.
            return True

        if self.sizes is None:
            self.sizes = []
            assert not self.known_rank

        # First make sure the lengths are set up and match.  `longer` counts how many more
        # dimensions we have already recorded than `actual` provides.
        longer = len(self.sizes) - len(actual)
        if self.known_rank:
            if broadcast:
                if longer < 0:
                    return False
            else:
                if longer != 0:
                    return False
        else:
            if longer < 0:
                self.sizes = [_ObservedDim() for _ in range(-longer)] + self.sizes
                longer = 0
            if broadcast:
                pass  # We don't know anything about total rank.
            else:
                if longer > 0:
                    return False
                self.known_rank = True
        assert longer >= 0

        # Then match individual dimensions.
        for i, actual_dim in enumerate(actual):
            if not self.sizes[i + longer].check_and_update(actual_dim, broadcast):
                return False

        return True
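
# A minimal sketch (not part of the original module) of how `_ObservedDims`
# merges variable-rank observations:
#
#     >>> dims = _ObservedDims()
#     >>> dims.check_and_update((2, 3), broadcast=False)  # fixes the rank and sizes (2, 3)
#     True
#     >>> dims.check_and_update((3,), broadcast=True)  # a shorter rank broadcasts against trailing dims
#     True
#     >>> dims.check_and_update((2, 4), broadcast=False)  # 4 conflicts with the recorded 3
#     False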


class ShapeChecker:
    """
    Mechanism for checking the shapes of tensors.

    This remembers observed shapes and specifications, so that tensors can be checked for
    compatibility across multiple calls, and so that we can provide good error messages.

    Example:

    .. literalinclude:: /examples/test_check_shapes_examples.py
       :start-after: [shape_checker__raw]
       :end-before: [shape_checker__raw]
       :dedent:
    """

    @experimental
    def __init__(self) -> None:
        self._seen_shapes: List[Tuple[Shape, ParsedTensorSpec, ErrorContext]] = []
        # Here we're (ab-)using `Dict[..., None]` instead of `set`, because `dict`s retain
        # ordering.
        self._specs_by_variable: Dict[str, Dict[Tuple[ParsedTensorSpec, ErrorContext], None]] = {}
        self._rank1_variables: Set[str] = set()
        self._varrank_variables: Set[str] = set()
        self._additional_context: List[ErrorContext] = []
        self._seen_dims: Dict[str, Union[_ObservedDim, _ObservedDims]] = {}

    def add_context(self, context: ErrorContext) -> None:
        """
        Add arbitrary context to the shape checker.

        This context will be included in any error messages.

        :param context: Context to add to this shape checker.
        """
        self._additional_context.append(context)

    def check_shape(
        self, shaped: S, tensor_spec: TensorSpecLike, context: Optional[ErrorContext] = None
    ) -> S:
        """
        Raise an error if a tensor has the wrong shape.

        This remembers observed shapes and specifications, so that tensors can be checked for
        compatibility across multiple calls, and so that we can provide good error messages.

        :param shaped: The object whose shape to check.
        :param tensor_spec: Specification to check the tensor against.
            Usually this is a ``str`` in the format described in `Shape specification`_.
            Alternatively this can be a pre-parsed :class:`ParsedTensorSpec`, or an actual
            :class:`Shape`.
        :param context: Information about where ``shaped`` is coming from, for improved error
            messages.
        :returns: ``shaped``, for convenience.
        """
        if not get_enable_check_shapes():
            return shaped

        if context is None:
            context = FunctionCallContext(self.check_shape).precompute()

        self.check_shapes([(shaped, tensor_spec, context)])

        return shaped

    def check_shapes(
        self,
        checks: Iterable[
            Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, ErrorContext]]
        ],
    ) -> None:
        """
        Raise an error if any tensor has the wrong shape.

        This remembers observed shapes and specifications, so that tensors can be checked for
        compatibility across multiple calls, and so that we can provide good error messages.

        :param checks: Checks to perform.  The elements can either be ``(shaped, tensor_spec)`` or
            ``(shaped, tensor_spec, context)`` tuples.  Where: ``shaped`` is the tensor whose
            shape to check; ``tensor_spec`` is the specification to check it against (see `Shape
            specification`_); and ``context`` contains (optional) information about where
            ``shaped`` came from - for better error messages.
        """
        if not get_enable_check_shapes():
            return

        new_shapes = []
        new_variables = set()
        call_context: Optional[ErrorContext] = None
        for i, check in enumerate(checks):
            shaped, tensor_spec, *contexts = check
            if contexts:
                (context,) = contexts
            else:
                if call_context is None:
                    call_context = FunctionCallContext(self.check_shapes).precompute()
                context = StackContext(call_context, IndexContext(i))

            parsed_tensor_check = _as_parsed_tensor_spec(tensor_spec, context)

            if shaped is None:
                shape: Shape = None
            else:
                shape = get_shape(shaped, context)
            self._seen_shapes.append((shape, parsed_tensor_check, context))
            if shape is not None:
                new_shapes.append((shape, parsed_tensor_check))

            for dim_spec in parsed_tensor_check.shape.dims:
                if dim_spec.variable_name is None:
                    continue
                self._specs_by_variable.setdefault(dim_spec.variable_name, {})[
                    (parsed_tensor_check, context)
                ] = None
                if dim_spec.variable_rank:
                    self._varrank_variables.add(dim_spec.variable_name)
                else:
                    self._rank1_variables.add(dim_spec.variable_name)
                new_variables.add(dim_spec.variable_name)

        new_variable_error_contexts = []
        for variable in sorted(new_variables):
            if (variable in self._rank1_variables) and (variable in self._varrank_variables):
                new_variable_error_contexts.append(
                    StackContext(
                        VariableContext(variable),
                        ParallelContext(
                            tuple(
                                StackContext(c, TensorSpecContext(s))
                                for s, c in self._specs_by_variable[variable]
                            )
                        ),
                    )
                )
        if new_variable_error_contexts:
            raise VariableTypeError(ParallelContext(tuple(new_variable_error_contexts)))

        def _assert(condition: bool) -> None:
            if not condition:
                contexts: List[ErrorContext] = []
                contexts.extend(self._additional_context)
                for shape, tensor_spec, context in self._seen_shapes:
                    shape_error_context: ErrorContext = ShapeContext(tensor_spec.shape, shape)
                    if tensor_spec.note is not None:
                        shape_error_context = ParallelContext(
                            (NoteContext(tensor_spec.note), shape_error_context)
                        )
                    contexts.append(StackContext(context, shape_error_context))
                raise ShapeMismatchError(ParallelContext(tuple(contexts)))

        for actual, tensor_spec in new_shapes:
            actual_len = len(actual)
            actual_i = 0

            shape_spec = tensor_spec.shape
            expected = shape_spec.dims
            expected_len = len(expected)
            n_variable_rank = sum(dim_spec.variable_rank for dim_spec in expected)
            assert n_variable_rank <= 1, "At most one variable-rank ParsedDimensionSpec allowed."
            if n_variable_rank == 0:
                _assert(expected_len == actual_len)
            else:
                _assert(expected_len - n_variable_rank <= actual_len)

            for dim_spec in expected:
                if dim_spec.variable_rank:
                    variable_rank_len = actual_len - (expected_len - n_variable_rank)
                    actual_dims = actual[actual_i : actual_i + variable_rank_len]
                    actual_i += variable_rank_len

                    expected_name = dim_spec.variable_name
                    if expected_name is None:
                        # Anonymous dimension spec - we don't care about the actual values.
                        continue

                    expected_dims = self._seen_dims.setdefault(expected_name, _ObservedDims())
                    assert isinstance(expected_dims, _ObservedDims)
                    _assert(expected_dims.check_and_update(actual_dims, dim_spec.broadcastable))
                else:
                    actual_dim = actual[actual_i]
                    actual_i += 1
                    if actual_dim is None:
                        continue

                    if dim_spec.constant is not None:
                        expected_dim = _ObservedDim()
                        assert expected_dim.check_and_update(dim_spec.constant, broadcast=False)
                    elif dim_spec.variable_name is not None:
                        expected_name = dim_spec.variable_name
                        maybe_expected_dim = self._seen_dims.setdefault(
                            expected_name, _ObservedDim()
                        )
                        assert isinstance(maybe_expected_dim, _ObservedDim)
                        expected_dim = maybe_expected_dim
                    else:
                        # Anonymous dimension - we don't care about the actual value.
                        continue

                    _assert(expected_dim.check_and_update(actual_dim, dim_spec.broadcastable))
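

# A minimal usage sketch with hypothetical values (the rendered documentation includes
# the official example from test_check_shapes_examples.py); assumes `get_shape`
# supports NumPy arrays:
#
#     >>> import numpy as np
#     >>> checker = ShapeChecker()
#     >>> weights = checker.check_shape(np.zeros((3, 4)), "[n_inputs, n_outputs]")
#     >>> bias = checker.check_shape(np.zeros((4,)), "[n_outputs]")
#     >>> checker.check_shape(np.zeros((5,)), "[n_outputs]")  # raises ShapeMismatchError: 5 != 4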