Source code for gpflow.experimental.check_shapes.error_contexts

# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=broad-except

"""
Infrastructure for describing the context of an error.

The :class:`MessageBuilder` is used to format / indent messages nicely.

The :class:`ErrorContext` is a reusable bit of information about where/why an error occurred that
can be written to a :class:`MessageBuilder`.

:class:`ErrorContext`\ s can be composed using the :class:`StackContext` and
:class:`ParallelContext`.

This allows reusable error messages in a consistent format.
"""
import inspect
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from lark.exceptions import UnexpectedCharacters, UnexpectedEOF, UnexpectedInput

from .base_types import Shape
from .config import get_enable_function_call_precompute

if TYPE_CHECKING:  # pragma: no cover
    from .argument_ref import ArgumentRef
    from .bool_specs import ParsedBoolSpec
    from .specs import ParsedNoteSpec, ParsedShapeSpec, ParsedTensorSpec

_UNKNOWN_FILE = "<Unknown file>"
_UNKNOWN_LINE = "<Unknown line>"
_DISABLED_FILE_AND_LINE = (
    f"{_UNKNOWN_FILE}:{_UNKNOWN_LINE}"
    " (Disabled. Call gpflow.experimental.check_shapes.set_enable_function_call_precompute(True) to"
    " see this. (Slow.))"
)
_NONE_SHAPE = "<Tensor is None or has unknown shape>"

_NO_VALUE = object()
"""
Sentinel to represent "no value" in places where ``None`` is a valid value.
"""


[docs]class MessageBuilder: """ Utility for formatting nested text. """ def __init__(self, indent_str: str = "") -> None: self._indent_str = indent_str self._lines: List[Union[str, Tuple[str, ...], "MessageBuilder"]] = [] self._column_widths: List[int] = []
[docs] def add_line(self, text: Any) -> None: """ Write a line, indented at the current level. Input is converted to a `str` using `str(text)`. """ text_str = str(text) self._lines.append(text_str)
[docs] def add_columned_line(self, *texts: Any) -> None: """ Write a line of several values, left-aligned within the current indentation level. Inputs are converted to `str`s using `str(text)`. """ widths = self._column_widths text_strs = tuple(str(t) for t in texts) for i, t in enumerate(text_strs): w = len(t) if i < len(widths): widths[i] = max(widths[i], w) else: widths.append(w) self._lines.append(text_strs)
[docs] @contextmanager def indent(self) -> Iterator["MessageBuilder"]: """ Context manager for indenting text. """ child = MessageBuilder(self._indent_str + " ") self._lines.append(child) yield child
[docs] def build(self) -> str: """ Compile all collected text into a single string. """ fragments: List[str] = [] def _collect(builder: "MessageBuilder") -> None: # pylint: disable=protected-access for line in builder._lines: if isinstance(line, str): fragments.extend((builder._indent_str, line, "\n")) elif isinstance(line, tuple): fragments.append(builder._indent_str) for i, t in enumerate(line): fragments.append(t) if i + 1 < len(line): fragments.append(" " * (builder._column_widths[i] - len(t) + 1)) fragments.append("\n") else: assert isinstance(line, MessageBuilder) _collect(line) _collect(self) return "".join(fragments)
[docs]class ErrorContext(ABC): """ A context in which an error can occur. Contexts should be immutable, and implement :func:`__eq__` - so that they can be composed using :class:`StackContext` and :class:`ParallelContext`. The contexts are often created even if an error doesn't actually occur, so they should be cheap to create - prefer to do any slow computation in :func:`print`, rather than in :func:`__init__`. Maybe think of an :class:`ErrorContext` as a factory of error messages. """
[docs] @abstractmethod def print(self, builder: MessageBuilder) -> None: """ Print this context to the given `MessageBuilder`. """
[docs]@dataclass(frozen=True) class StackContext(ErrorContext): """ Error context where one context is "inside" another one. """ parent: ErrorContext child: ErrorContext
[docs] def print(self, builder: MessageBuilder) -> None: def _print(context: ErrorContext) -> None: nonlocal builder if isinstance(context, StackContext): _print(context.parent) _print(context.child) else: context.print(builder) with builder.indent() as b: builder = b _print(self)
[docs]@dataclass(frozen=True) class ParallelContext(ErrorContext): """ Error context with many contexts in parallel. """ children: Tuple[ErrorContext, ...]
[docs] def print(self, builder: MessageBuilder) -> None: # Merge any nested parallel contexts. flat: List[ErrorContext] = [] def _flatten(context: ErrorContext) -> None: if isinstance(context, ParallelContext): for child in context.children: _flatten(child) else: flat.append(context) _flatten(self) # If any `StackContexts`s have the same (grand-) parent, merge them, so that we don't repeat # stuff in the output. def _split_head(context: StackContext) -> Tuple[ErrorContext, ErrorContext]: parent = context.parent child = context.child if isinstance(parent, StackContext): # Transform: # /\ /\ # /\ C --> A /\ # A B B C # So we get: /\ # head = A, body = B C return _split_head(StackContext(parent.parent, StackContext(parent.child, child))) else: return parent, child by_head: Dict[ErrorContext, List[ErrorContext]] = {} for child in flat: if isinstance(child, StackContext): head, body = _split_head(child) else: head = child body = None bodies = by_head.setdefault(head, []) if body is not None: bodies.append(body) for head, bodies in by_head.items(): head.print(builder) with builder.indent() as b: ParallelContext(tuple(bodies)).print(b)
[docs]@dataclass(frozen=True) class FunctionCallContext(ErrorContext): """ An error occured inside a function that was called. Normally `print` should be called from within the called function: `func`. If that is impossible, first call `precompute` from within `func`. """ func: Callable[..., Any] path_and_line: Optional[str] = None def _get_target(self) -> Optional[Tuple[str, str]]: func = inspect.unwrap(self.func) search_name: str = func.__name__ search_path = inspect.getsourcefile(func) if search_path is None: return None return search_name, search_path def _count_wrappers(self) -> int: wrappers = -1 f: Optional[Callable[..., Any]] = self.func while f is not None: wrappers += 1 f = getattr(f, "__wrapped__", None) return wrappers def _get_calling_frame(self) -> Optional[inspect.FrameInfo]: search_name_path = self._get_target() if search_name_path is None: return None stack = inspect.stack(0) for i, frame in enumerate(stack): name_path = (frame.function, frame.filename) if search_name_path == name_path: return stack[i + self._count_wrappers() + 1] return None def _compute(self) -> str: if self.path_and_line is not None: return self.path_and_line frame = self._get_calling_frame() if frame is None: path = _UNKNOWN_FILE line = _UNKNOWN_LINE else: path = frame.filename line = str(frame.lineno) return f"{path}:{line}"
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line(f"{self.func.__name__} called at:", self._compute())
[docs] def precompute(self) -> "FunctionCallContext": """ Precompute the values to print. This is useful to capture the position of the stack of the relevant call, if this object is saved for later use. :returns: A new instance with precomptued values. """ if get_enable_function_call_precompute(): return FunctionCallContext(self.func, self._compute()) else: return FunctionCallContext(self.func, _DISABLED_FILE_AND_LINE)
[docs]@dataclass(frozen=True) class FunctionDefinitionContext(ErrorContext): """ An error occured in the context of a function definition. """ func: Callable[..., Any]
[docs] def print(self, builder: MessageBuilder) -> None: name = self.func.__qualname__ try: path = inspect.getsourcefile(self.func) except Exception: # pragma: no cover path = _UNKNOWN_FILE try: _, line_int = inspect.getsourcelines(self.func) line = str(line_int) except Exception: # pragma: no cover line = _UNKNOWN_LINE path_and_line = f"{path}:{line}" builder.add_columned_line("Function:", name) with builder.indent() as b: b.add_columned_line("Declared:", path_and_line)
[docs]@dataclass(frozen=True) class VariableContext(ErrorContext): """ An error occurred in the context of a shape specification variable. """ variable_name: str
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Variable:", self.variable_name)
[docs]@dataclass(frozen=True) class TensorSpecContext(ErrorContext): """ An error occurred in the context of a tensor specification. """ spec: "ParsedTensorSpec"
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Specification:", self.spec)
[docs]@dataclass(frozen=True) class ArgumentContext(ErrorContext): """ An error occurred in the context of an argument to a function. """ name_or_index: Union[str, int] value: Any = _NO_VALUE
[docs] def print(self, builder: MessageBuilder) -> None: if isinstance(self.name_or_index, str): builder.add_columned_line("Argument:", self.name_or_index) else: builder.add_columned_line("Argument number (0-indexed):", self.name_or_index) if self.value is not _NO_VALUE: with builder.indent() as b: b.add_columned_line("Value:", self.value)
def __eq__(self, other: Any) -> bool: return ( type(self) == type(other) and self.name_or_index == other.name_or_index and self.value is other.value ) def __hash__(self) -> int: return hash((self.name_or_index, id(self.value)))
[docs]@dataclass(frozen=True) class AttributeContext(ErrorContext): """ An error occurred in the context of an attribute on an object. """ name: str value: Any = _NO_VALUE
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Attribute:", f".{self.name}") if self.value is not _NO_VALUE: with builder.indent() as b: b.add_columned_line("Value:", self.value)
def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.name == other.name and self.value is other.value def __hash__(self) -> int: return hash((self.name, id(self.value)))
[docs]@dataclass(frozen=True) class IndexContext(ErrorContext): """ An error occurred in the context of an index in a sequence. """ index: int value: Any = _NO_VALUE
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Index:", f"[{self.index}]") if self.value is not _NO_VALUE: with builder.indent() as b: b.add_columned_line("Value:", self.value)
def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.index == other.index and self.value is other.value def __hash__(self) -> int: return hash((self.index, id(self.value)))
[docs]@dataclass(frozen=True) class MappingKeyContext(ErrorContext): """ An error occurent in the context of a key in a map. """ key: Any
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Map key:", f"[{self.key!r}]")
def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.key is other.key def __hash__(self) -> int: return id(self.key)
[docs]@dataclass(frozen=True) class MappingValueContext(ErrorContext): """ An error occurent in the context of a value in a map. """ key: Any value: Any = _NO_VALUE
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Map value, of key:", f"[{self.key!r}]") if self.value is not _NO_VALUE: with builder.indent() as b: b.add_columned_line("Value:", self.value)
def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.key is other.key and self.value is other.value def __hash__(self) -> int: return hash((id(self.key), id(self.value)))
[docs]@dataclass(frozen=True) class ConditionContext(ErrorContext): """ An error occurred in a conditional context. """ condition: "ParsedBoolSpec"
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Condition:", self.condition)
[docs]@dataclass(frozen=True) class ShapeContext(ErrorContext): """ An error occurred in the context of the shapes of function arguments. """ expected: "ParsedShapeSpec" actual: Shape
[docs] def print(self, builder: MessageBuilder) -> None: if self.actual is None: actual_str = _NONE_SHAPE else: actual_str = f"[{', '.join(str(dim) for dim in self.actual)}]" builder.add_columned_line("Expected:", self.expected) builder.add_columned_line("Actual:", actual_str)
[docs]@dataclass(frozen=True) class NoteContext(ErrorContext): """ An error occurred in a context where a user has added a note. """ note: "ParsedNoteSpec"
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Note:", self.note.note)
[docs]@dataclass(frozen=True) class ObjectValueContext(ErrorContext): """ An error was caused by the value of an object. """ obj: Any
[docs] def print(self, builder: MessageBuilder) -> None: builder.add_columned_line("Value:", repr(self.obj))
def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.obj is other.obj def __hash__(self) -> int: return id(self.obj)
[docs]@dataclass(frozen=True) class ObjectTypeContext(ErrorContext): """ An error was caused by the type of an object. """ obj: Any
[docs] def print(self, builder: MessageBuilder) -> None: t = type(self.obj) builder.add_columned_line("Object type:", f"{t.__module__}.{t.__qualname__}")
def __eq__(self, other: Any) -> bool: return type(self) == type(other) and self.obj is other.obj def __hash__(self) -> int: return id(self.obj)
[docs]@dataclass(frozen=True) # type: ignore[misc] # mypy doesn't like abstract `dataclasses`. class ParserInputContext(ErrorContext, ABC): """ Abstract base class for contexts that relate to parser input. """ text: str def print_line(self, builder: MessageBuilder, line: int, column: int) -> None: if line > 0: line_content = self.text.split("\n")[line - 1] builder.add_columned_line("Line:", f'"{line_content}"') if column > 0: builder.add_columned_line("", " " * column + "^")
[docs]@dataclass(frozen=True) class LarkUnexpectedInputContext(ParserInputContext): """ An error was caused by an `UnexpectedInput` error from `Lark`. """ error: UnexpectedInput terminal_descriptions: Mapping[str, str]
[docs] def print(self, builder: MessageBuilder) -> None: self.print_line( builder, getattr(self.error, "line", -1), getattr(self.error, "column", -1), ) expected: Sequence[str] = getattr(self.error, "accepts", []) if not expected: expected = getattr(self.error, "expected", []) expected = sorted(expected) expected_key = "Expected:" if len(expected) <= 1 else "Expected one of:" for expected_name in expected: expected_value = self.terminal_descriptions.get(expected_name, expected_name) builder.add_columned_line(expected_key, expected_value) expected_key = "" if isinstance(self.error, UnexpectedCharacters): builder.add_line("Found unexpected character.") if isinstance(self.error, UnexpectedEOF): builder.add_line("Found unexpected end of input.")
def __hash__(self) -> int: return hash((self.error, *sorted(self.terminal_descriptions.items())))
[docs]@dataclass(frozen=True) class MultipleElementBoolContext(ParserInputContext): """ An error was caused by trying to use a multi-element argument specification as a bool. """ line: int column: int
[docs] def print(self, builder: MessageBuilder) -> None: self.print_line(builder, self.line, self.column) builder.add_line( "Argument references that evaluate to multiple values are not supported for boolean" " expressions." )