# Source code for gpflow.experimental.check_shapes.parser

# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=unused-argument

from abc import ABC
from dataclasses import replace
from typing import Any, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type

from lark.exceptions import UnexpectedInput
from lark.lark import Lark
from lark.lexer import PatternRE, PatternStr, Token
from lark.tree import Tree

from .argument_ref import (
    RESULT_TOKEN,
    AllElementsRef,
    ArgumentRef,
    AttributeArgumentRef,
    IndexArgumentRef,
    KeysRef,
    RootArgumentRef,
    ValuesRef,
)
from .bool_specs import (
    BoolTest,
    ParsedAndBoolSpec,
    ParsedArgumentRefBoolSpec,
    ParsedBoolSpec,
    ParsedNotBoolSpec,
    ParsedOrBoolSpec,
)
from .config import DocstringFormat, get_rewrite_docstrings
from .error_contexts import (
    ArgumentContext,
    ErrorContext,
    LarkUnexpectedInputContext,
    MultipleElementBoolContext,
    StackContext,
)
from .exceptions import CheckShapesError, DocstringParseError, SpecificationParseError
from .specs import (
    ParsedArgumentSpec,
    ParsedDimensionSpec,
    ParsedFunctionSpec,
    ParsedNoteSpec,
    ParsedShapeSpec,
    ParsedTensorSpec,
)

_VARIABLE_RANK_LEADING_TOKEN = "*"
_VARIABLE_RANK_TRAILING_TOKEN = "..."


def _tree_children(tree: Tree[Token]) -> Iterable[Tree[Token]]:
    """ Lazily yield only those direct children of `tree` that are sub-trees (not tokens). """
    return filter(lambda node: isinstance(node, Tree), tree.children)


def _token_children(tree: Tree[Token]) -> Iterable[str]:
    """ Lazily yield the string value of every direct child of `tree` that is a token. """
    tokens = filter(lambda node: isinstance(node, Token), tree.children)
    return (tok.value for tok in tokens)


class _TreeVisitor(ABC):
    """
    Base class for visitors of Lark parse trees.

    Unlike the visitor / transformer classes that ship with Lark, `visit` forwards arbitrary extra
    positional and keyword arguments to the handler method.

    Subclasses handle a rule by defining a method named exactly like that Lark rule. The method
    receives the rule's parse tree first, followed by whatever extra `*args` / `**kwargs` were given
    to `visit`, and may return any value.
    """

    def visit(self, tree: Tree[Token], *args: Any, **kwargs: Any) -> Any:
        """ Dispatch `tree` to the method named after its rule, forwarding `args` and `kwargs`. """
        rule_name = tree.data
        handler = getattr(self, rule_name, None)
        assert handler, f"No method found with name {rule_name}."
        return handler(tree, *args, **kwargs)


class _ParseSpec(_TreeVisitor):
    """
    Parse-tree visitor that turns `check_shapes.lark` parse trees into `Parsed*Spec` objects.

    Each public method is named after the grammar rule it handles (see `_TreeVisitor.visit`). The
    tuple unpackings of `_tree_children` / `_token_children` also act as sanity checks on the
    expected shape of each parse-tree node.
    """

    def __init__(self, source: str) -> None:
        # The specification text being parsed; used when building error contexts.
        self._source = source

    def argument_spec(self, tree: Tree[Token]) -> ParsedArgumentSpec:
        """
        Build an argument spec from an argument reference and a shape, plus optional trailing
        specs: at most one condition (bool spec) and at most one note.
        """
        argument_ref, shape_spec, *other_specs = _tree_children(tree)
        # `False`: this argument reference is not part of a bool spec.
        argument = self.visit(argument_ref, False)
        shape = self.visit(shape_spec)
        condition = None
        note = None
        for other_spec in other_specs:
            other = self.visit(other_spec)
            if isinstance(other, ParsedBoolSpec):
                assert condition is None
                condition = other
            else:
                assert isinstance(other, ParsedNoteSpec)
                assert note is None
                note = other
        tensor = ParsedTensorSpec(shape, note)
        return ParsedArgumentSpec(argument, tensor, condition)

    def bool_spec_or(self, tree: Tree[Token]) -> ParsedBoolSpec:
        """ `left or right`. """
        left, right = _tree_children(tree)
        return ParsedOrBoolSpec(self.visit(left), self.visit(right))

    def bool_spec_and(self, tree: Tree[Token]) -> ParsedBoolSpec:
        """ `left and right`. """
        left, right = _tree_children(tree)
        return ParsedAndBoolSpec(self.visit(left), self.visit(right))

    def bool_spec_not(self, tree: Tree[Token]) -> ParsedBoolSpec:
        """ `not right`. """
        (right,) = _tree_children(tree)
        return ParsedNotBoolSpec(self.visit(right))

    def bool_spec_argument_ref_is_none(self, tree: Tree[Token]) -> ParsedBoolSpec:
        """ Condition testing that an argument reference is `None`. """
        (argument_ref,) = _tree_children(tree)
        # `True`: this argument reference is part of a bool spec.
        return ParsedArgumentRefBoolSpec(self.visit(argument_ref, True), BoolTest.IS_NONE)

    def bool_spec_argument_ref_is_not_none(self, tree: Tree[Token]) -> ParsedBoolSpec:
        """ Condition testing that an argument reference is not `None`. """
        (argument_ref,) = _tree_children(tree)
        return ParsedArgumentRefBoolSpec(self.visit(argument_ref, True), BoolTest.IS_NOT_NONE)

    def bool_spec_argument_ref(self, tree: Tree[Token]) -> ParsedBoolSpec:
        """ Condition using the truth value of an argument reference. """
        (argument_ref,) = _tree_children(tree)
        return ParsedArgumentRefBoolSpec(self.visit(argument_ref, True), BoolTest.BOOL)

    def argument_ref_root(self, tree: Tree[Token], is_for_bool_spec: bool) -> ArgumentRef:
        """ A bare argument name, which starts every reference chain. """
        (token,) = _token_children(tree)
        return RootArgumentRef(token)

    def argument_ref_attribute(self, tree: Tree[Token], is_for_bool_spec: bool) -> ArgumentRef:
        """ An attribute access on another reference; the token is the attribute name. """
        (source,) = _tree_children(tree)
        (token,) = _token_children(tree)
        return AttributeArgumentRef(self.visit(source, is_for_bool_spec), token)

    def argument_ref_index(self, tree: Tree[Token], is_for_bool_spec: bool) -> ArgumentRef:
        """ An integer index into another reference. """
        (source,) = _tree_children(tree)
        (token,) = _token_children(tree)
        return IndexArgumentRef(self.visit(source, is_for_bool_spec), int(token))

    def argument_ref_all(self, tree: Tree[Token], is_for_bool_spec: bool) -> ArgumentRef:
        """ All elements of another reference. Not allowed inside a bool spec. """
        (source,) = _tree_children(tree)
        self._disallow_multiple_element_bool_spec(source, is_for_bool_spec)
        return AllElementsRef(self.visit(source, is_for_bool_spec))

    def argument_ref_keys(self, tree: Tree[Token], is_for_bool_spec: bool) -> ArgumentRef:
        """ The keys of another reference. Not allowed inside a bool spec. """
        (source,) = _tree_children(tree)
        self._disallow_multiple_element_bool_spec(source, is_for_bool_spec)
        return KeysRef(self.visit(source, is_for_bool_spec))

    def argument_ref_values(self, tree: Tree[Token], is_for_bool_spec: bool) -> ArgumentRef:
        """ The values of another reference. Not allowed inside a bool spec. """
        (source,) = _tree_children(tree)
        self._disallow_multiple_element_bool_spec(source, is_for_bool_spec)
        return ValuesRef(self.visit(source, is_for_bool_spec))

    def _disallow_multiple_element_bool_spec(
        self, source: Tree[Token], is_for_bool_spec: bool
    ) -> None:
        """
        Reject a multi-element reference (all / keys / values) inside a bool spec.

        Presumably this is because a condition must evaluate to a single truth value; confirm.

        :raises SpecificationParseError: If `is_for_bool_spec` is true. The error context points at
            the end position of `source`.
        """
        if is_for_bool_spec:
            meta = source.meta
            raise SpecificationParseError(
                MultipleElementBoolContext(self._source, meta.end_line, meta.end_column)
            )

    def tensor_spec(self, tree: Tree[Token]) -> ParsedTensorSpec:
        """ A shape, optionally followed by a single note. """
        shape_spec, *note_specs = _tree_children(tree)
        shape = self.visit(shape_spec)
        if note_specs:
            (note_spec,) = note_specs
            note = self.visit(note_spec)
        else:
            note = None
        return ParsedTensorSpec(shape, note)

    def shape_spec(self, tree: Tree[Token]) -> ParsedShapeSpec:
        """ A shape, wrapping the tuple of its dimension specs. """
        (dimension_specs,) = _tree_children(tree)
        return ParsedShapeSpec(self.visit(dimension_specs))

    def dimension_specs(self, tree: Tree[Token]) -> Tuple[ParsedDimensionSpec, ...]:
        """
        Parse every dimension of a shape, in order.

        Each dimension visitor receives its index `i`; the dimension visitors in this class do not
        use it.
        """
        return tuple(
            self.visit(dimension_spec, i) for i, dimension_spec in enumerate(_tree_children(tree))
        )

    def dimension_spec_broadcast(self, tree: Tree[Token], i: int) -> ParsedDimensionSpec:
        """ A wrapped dimension spec, marked as broadcastable. """
        (dimension_spec,) = _tree_children(tree)
        child = self.visit(dimension_spec, i)
        assert isinstance(child, ParsedDimensionSpec)
        return replace(child, broadcastable=True)

    def dimension_spec_constant(self, tree: Tree[Token], i: int) -> ParsedDimensionSpec:
        """ A dimension of fixed integer size. """
        (token,) = _token_children(tree)
        return ParsedDimensionSpec(
            constant=int(token), variable_name=None, variable_rank=False, broadcastable=False
        )

    def dimension_spec_variable(self, tree: Tree[Token], i: int) -> ParsedDimensionSpec:
        """ A single dimension whose size is bound to a named variable. """
        (token,) = _token_children(tree)
        return ParsedDimensionSpec(
            constant=None, variable_name=token, variable_rank=False, broadcastable=False
        )

    def dimension_spec_anonymous(self, tree: Tree[Token], i: int) -> ParsedDimensionSpec:
        """ A single dimension with no constant size and no variable name. """
        return ParsedDimensionSpec(
            constant=None, variable_name=None, variable_rank=False, broadcastable=False
        )

    def dimension_spec_variable_rank(self, tree: Tree[Token], i: int) -> ParsedDimensionSpec:
        """
        A named, variable-rank group of dimensions.

        Two spellings are accepted: a leading `*` before the name, or a trailing `...` after it.
        """
        (token1, token2) = _token_children(tree)
        if token1 == _VARIABLE_RANK_LEADING_TOKEN:
            variable_name = token2
        else:
            assert token2 == _VARIABLE_RANK_TRAILING_TOKEN
            variable_name = token1
        return ParsedDimensionSpec(
            constant=None, variable_name=variable_name, variable_rank=True, broadcastable=False
        )

    def dimension_spec_anonymous_variable_rank(
        self, tree: Tree[Token], i: int
    ) -> ParsedDimensionSpec:
        """ An unnamed, variable-rank group of dimensions. """
        return ParsedDimensionSpec(
            constant=None, variable_name=None, variable_rank=True, broadcastable=False
        )

    def note_spec(self, tree: Tree[Token]) -> ParsedNoteSpec:
        """
        A free-text note.

        The first token is discarded (presumably the `#` marker, going by its name). The remaining
        tokens are stripped and joined with single spaces.
        """
        _hash_token, *note_tokens = _token_children(tree)
        return ParsedNoteSpec(" ".join(token.strip() for token in note_tokens))


class _RewritedocString(_TreeVisitor):
    """
    Parse-tree visitor that rewrites a Sphinx docstring to document `check_shapes` specs.

    Construct it with the docstring text (`source`) and the parsed function specification. Then
    `visit` the parse tree of `source`. Visiting the `docstring` root returns the rewritten text:

    * Notes from the function specification are added after the free-form documentation.
    * Each argument (and the return value) with a specification gets a bullet list of shape
      descriptions. The list goes into that argument's existing `:param <name>:` / `:returns:` info
      field, or into a new info field appended at the end if the docstring has none.
    """

    def __init__(self, source: str, function_spec: ParsedFunctionSpec) -> None:
        # The original docstring text. Parse-tree positions (`meta.start_pos` etc.) index into it.
        self._source = source
        # Root argument name -> sorted Sphinx bullet lines describing that argument.
        self._spec_lines = self._argument_specs_to_sphinx(function_spec.arguments)
        self._notes = tuple(note.note for note in function_spec.notes)
        # Indentation of the docstring body, or `None` if it could not be inferred.
        self._indent = self._guess_indent(source)

    def _argument_specs_to_sphinx(
        self,
        argument_specs: Collection[ParsedArgumentSpec],
    ) -> Mapping[str, Sequence[str]]:
        """
        Render all argument specs, grouped by root argument name.

        :param argument_specs: The specs to render.
        :returns: Mapping from root argument name to the sorted Sphinx lines describing that
            argument. A key equal to `RESULT_TOKEN` is treated as the return value by the rest of
            this class. Key order follows first appearance in `argument_specs`.
        """
        result: Dict[str, List[str]] = {}
        for spec in argument_specs:
            result.setdefault(spec.argument_ref.root_argument_name, []).append(
                self._argument_spec_to_sphinx(spec)
            )
        for lines in result.values():
            lines.sort()
        return result

    def _argument_spec_to_sphinx(self, argument_spec: ParsedArgumentSpec) -> str:
        """
        Render a single argument spec as one Sphinx bullet line.

        Format: `* **<ref>** has shape [<dims>][ if <condition>].[ <note>]`, where `<ref>` is the
        `repr` of the argument reference.
        """
        tensor_spec = argument_spec.tensor
        shape_spec = tensor_spec.shape
        out = []
        out.append(f"* **{repr(argument_spec.argument_ref)}**")
        out.append(" has shape [")
        out.append(self._shape_spec_to_sphinx(shape_spec))
        out.append("]")

        if argument_spec.condition is not None:
            out.append(" if ")
            out.append(self._bool_spec_to_sphinx(argument_spec.condition, False))

        out.append(".")

        if tensor_spec.note is not None:
            note_spec = tensor_spec.note
            out.append(" ")
            out.append(note_spec.note)
        return "".join(out)

    def _bool_spec_to_sphinx(self, bool_spec: ParsedBoolSpec, paren_wrap: bool) -> str:
        """
        Render a condition as English text.

        :param bool_spec: The condition to render.
        :param paren_wrap: Whether to wrap the result in parentheses. Operands of `and` / `or` /
            `not` are always rendered with this set. A stand-alone argument reference (truth-value
            test) is never wrapped.
        """
        if isinstance(bool_spec, ParsedOrBoolSpec):
            result = (
                self._bool_spec_to_sphinx(bool_spec.left, True)
                + " or "
                + self._bool_spec_to_sphinx(bool_spec.right, True)
            )
        elif isinstance(bool_spec, ParsedAndBoolSpec):
            result = (
                self._bool_spec_to_sphinx(bool_spec.left, True)
                + " and "
                + self._bool_spec_to_sphinx(bool_spec.right, True)
            )
        elif isinstance(bool_spec, ParsedNotBoolSpec):
            result = "not " + self._bool_spec_to_sphinx(bool_spec.right, True)
        else:
            assert isinstance(bool_spec, ParsedArgumentRefBoolSpec)
            if bool_spec.bool_test == BoolTest.BOOL:
                paren_wrap = False  # Never wrap a stand-alone argument.
                result = f"*{bool_spec.argument_ref!r}*"
            elif bool_spec.bool_test == BoolTest.IS_NONE:
                result = f"*{bool_spec.argument_ref!r}* is *None*"
            else:
                assert bool_spec.bool_test == BoolTest.IS_NOT_NONE
                result = f"*{bool_spec.argument_ref!r}* is not *None*"

        if paren_wrap:
            result = f"({result})"

        return result

    def _shape_spec_to_sphinx(self, shape_spec: ParsedShapeSpec) -> str:
        """ Render all dimensions of a shape, comma-separated. """
        return ", ".join(self._dim_spec_to_sphinx(dim) for dim in shape_spec.dims)

    def _dim_spec_to_sphinx(self, dim_spec: ParsedDimensionSpec) -> str:
        """
        Render a single dimension.

        Rendering rules:

        * Broadcastable dimensions are prefixed with `broadcast `.
        * Constants are written as integers and variable names in italics.
        * Variable-rank dimensions get a `...` suffix.
        * An anonymous single dimension is written as `.`.
        """
        tokens = []

        if dim_spec.broadcastable:
            tokens.append("broadcast ")

        if dim_spec.constant is not None:
            tokens.append(str(dim_spec.constant))
        elif dim_spec.variable_name:
            tokens.append(f"*{dim_spec.variable_name}*")
        else:
            if not dim_spec.variable_rank:
                tokens.append(".")

        if dim_spec.variable_rank:
            tokens.append("...")

        return "".join(tokens)

    def _guess_indent(self, docstring: str) -> Optional[int]:
        """
        Infer the level of indentation of a docstring.

        Returns `None` if the indentation could not be inferred.
        """
        # Algorithm adapted from:
        #     https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation

        # Convert tabs to spaces (following the normal Python rules)
        # and split into a list of lines:
        lines = docstring.expandtabs().splitlines()
        # Determine minimum indentation (first line doesn't count):
        no_indent = -1
        indent = no_indent
        for line in lines[1:]:
            stripped = line.lstrip()
            if not stripped:
                continue
            line_indent = len(line) - len(stripped)
            if indent == no_indent or line_indent < indent:
                indent = line_indent
        return indent if indent != no_indent else None

    def _insert_spec_lines(
        self, out: List[str], pos: int, spec_lines: Sequence[str], docs: Tree[Token]
    ) -> int:
        """
        Insert `spec_lines` at the start of an existing info field's documentation `docs`.

        The method writes, in order:

        1. The source from `pos` up to `docs`, with trailing whitespace stripped.
        2. Each spec line on its own indented line.
        3. A blank line.
        4. The original `docs` text.

        Spec lines use the indentation of `docs` if it can be inferred. Otherwise they use the
        docstring indentation plus 4 spaces, or 4 spaces if the docstring indentation is unknown.

        :returns: The new `pos`: the end position of `docs`.
        """
        leading_str = self._source[pos : docs.meta.start_pos].rstrip()
        docs_start = pos + len(leading_str)
        docs_str = self._source[docs_start : docs.meta.end_pos]
        trailing_str = docs_str.lstrip()

        docs_indent = self._guess_indent(docs_str)
        if docs_indent is None:
            if self._indent is None:
                docs_indent = 4
            else:
                docs_indent = self._indent + 4
        indent_str = "\n" + docs_indent * " "

        out.append(leading_str)
        for spec_line in spec_lines:
            out.append(indent_str)
            out.append(spec_line)
        # "\n" followed by `indent_str` (which itself starts with "\n") leaves a blank line
        # between the spec lines and the original documentation.
        out.append("\n")
        out.append(indent_str)
        out.append(trailing_str)
        return docs.meta.end_pos

    def _insert_param_info_fields(
        self,
        is_first_info_field: bool,
        spec_lines: Mapping[str, Sequence[str]],
        out: List[str],
        pos: int,
    ) -> int:
        """
        Append new `:param <name>:` / `:returns:` info fields for every entry of `spec_lines`.

        The remaining source from `pos` (with trailing whitespace stripped) is written first.

        :param is_first_info_field: Whether the docstring has no info fields of its own.
        :param spec_lines: Argument name -> spec lines for arguments without an existing field.
        :returns: The new `pos`.
        """
        leading_str = self._source[pos:].rstrip()
        out.append(leading_str)
        pos += len(leading_str)

        # Decide based on what has actually been written so far, not on the raw source. Testing
        # the source missed notes added to an empty docstring, which glued the fields directly
        # onto the note. It also added blank lines before fields in a whitespace-only docstring.
        if not any(out):
            # Case where nothing precedes these fields. Just write them.
            needed_newlines = 0
        elif is_first_info_field:
            # Free-form documentation precedes these fields. Have 2 newlines to separate them.
            needed_newlines = 2
        else:
            # Another info-field precedes these fields.
            needed_newlines = 1

        indent = self._indent or 0
        indent_str = indent * " "
        indent_one_str = 4 * " "

        for arg_name, arg_lines in spec_lines.items():
            out.append(needed_newlines * "\n")
            needed_newlines = 1

            out.append(indent_str)
            if arg_name == RESULT_TOKEN:
                out.append(":returns:")
            else:
                out.append(f":param {arg_name}:")
            for arg_line in arg_lines:
                out.append("\n")
                out.append(indent_str)
                out.append(indent_one_str)
                out.append(arg_line)

        return pos

    def docstring(self, tree: Tree[Token]) -> str:
        """
        Visit the root of the docstring parse tree and return the rewritten docstring.

        The output is the free-form docs, then any notes, then the info fields with spec lines
        merged in, then the rest of the source.
        """
        # The strategy here is:
        # * `out` contains a list of strings that will be concatenated and form the final result.
        # * `pos` is the position such that `self._source[:pos]` has already been added to `out`,
        #   and `self._source[pos:]` still needs to be added.
        # * When visiting children we pass `out` and `pos`, and the children add content to `out`
        #   and return a new `pos`.
        docs, info_fields = _tree_children(tree)
        out: List[str] = []
        pos = 0

        if self._notes:
            if not docs.meta.empty:
                out.append(self._source[pos : docs.meta.end_pos])
                pos = docs.meta.end_pos
            indent = self._indent or 0
            indent_str = indent * " "
            for note in self._notes:
                if out:
                    out.append("\n\n")
                out.append(indent_str)
                out.append(note)

        pos = self.visit(info_fields, out, pos)
        out.append(self._source[pos:])

        return "".join(out)

    def info_fields(self, tree: Tree[Token], out: List[str], pos: int) -> int:
        """
        Process the existing info fields, then append fields for any remaining specs.

        :returns: The new `pos`.
        """
        spec_lines = dict(self._spec_lines)
        is_first_info_field = True
        for child in _tree_children(tree):
            # This will remove the self._spec_lines corresponding to found `:param:`'s.
            pos = self.visit(child, spec_lines, out, pos)
            is_first_info_field = False

        # Add any remaining `:param:`s:
        pos = self._insert_param_info_fields(is_first_info_field, spec_lines, out, pos)

        # Make sure info fields are terminated by a new-line:
        if self._spec_lines:
            if (pos >= len(self._source)) or (self._source[pos] != "\n"):
                out.append("\n")

        return pos

    def info_field_param(
        self, tree: Tree[Token], spec_lines: Dict[str, Sequence[str]], out: List[str], pos: int
    ) -> int:
        """
        Handle an existing `:param:` field.

        Any spec lines for the named argument are popped from `spec_lines` and inserted into the
        field.
        """
        info_field_args, docs = _tree_children(tree)
        arg_name = self.visit(info_field_args)
        arg_lines = spec_lines.pop(arg_name, None)
        if arg_lines:
            pos = self._insert_spec_lines(out, pos, arg_lines, docs)
        return pos

    def info_field_returns(
        self, tree: Tree[Token], spec_lines: Dict[str, Sequence[str]], out: List[str], pos: int
    ) -> int:
        """
        Handle an existing `:returns:` field.

        Any spec lines for `RESULT_TOKEN` are popped from `spec_lines` and inserted into the field.
        """
        (docs,) = _tree_children(tree)
        return_lines = spec_lines.pop(RESULT_TOKEN, None)
        if return_lines:
            pos = self._insert_spec_lines(out, pos, return_lines, docs)
        return pos

    def info_field_other(
        self, tree: Tree[Token], spec_lines: Dict[str, Sequence[str]], out: List[str], pos: int
    ) -> int:
        """ Any other info field is left untouched; it is copied later as part of the source. """
        return pos

    def info_field_args(self, tree: Tree[Token]) -> str:
        """
        Return the last token of an info field's arguments, or `""` if there are none.

        For `:param <type> <name>:` this is presumably the parameter name.
        """
        tokens = list(_token_children(tree))
        if not tokens:
            return ""
        return tokens[-1]


class _CachedParser:
    """
    Memoising wrapper around a Lark parser together with a `_TreeVisitor` transformer.

    It lets the different things we parse (tensor specs, argument specs, docstrings) share the
    same parse / transform / error-reporting / caching code.
    """

    def __init__(
        self,
        grammar_filename: str,
        start_symbol: str,
        parser_name: str,
        re_terminal_descriptions: Mapping[str, str],
        transformer_class: Type[_TreeVisitor],
        exception_class: Type[CheckShapesError],
    ) -> None:
        # Results keyed by (input text, transformer constructor arguments).
        self._cache: Dict[Tuple[str, Tuple[Any, ...]], Any] = {}
        self._parser = Lark.open(
            grammar_filename,
            rel_to=__file__,
            propagate_positions=True,
            start=start_symbol,
            parser=parser_name,
        )
        self._transformer_class = transformer_class
        self._exception_class = exception_class
        # Pre-compute nice terminal descriptions for our error messages:
        self._terminal_descriptions: Dict[str, str] = self._describe_terminals(
            re_terminal_descriptions
        )

    def _describe_terminals(self, re_terminal_descriptions: Mapping[str, str]) -> Dict[str, str]:
        """
        Build a human-readable description of every terminal of the grammar.

        String terminals are described by their quoted literal. Regular-expression terminals must
        be described in `re_terminal_descriptions`, and every entry there must be used.
        """
        descriptions: Dict[str, str] = {}
        missing = set()
        leftover = dict(re_terminal_descriptions)
        for terminal in self._parser.terminals:
            terminal_name = terminal.name
            pattern = terminal.pattern
            if isinstance(pattern, PatternStr):
                descriptions[terminal_name] = f'"{pattern.value}"'
                continue

            assert isinstance(pattern, PatternRE)
            # If we enter this `if` then the parser is misconfigured, so we never get here, even
            # during tests.
            if leftover.pop(terminal_name, None) is None:  # pragma: no cover
                missing.add(terminal_name)
                descriptions[terminal_name] = "ERROR"
            else:
                human_text = re_terminal_descriptions[terminal_name]
                descriptions[terminal_name] = f"{human_text} (re={pattern.value})"
        assert not leftover, f"Redundant terminal descriptions were provided: {sorted(leftover)}"
        assert not missing, f"Some RE terminals did not have a description: {sorted(missing)}"
        return descriptions

    def parse(self, text: str, transformer_args: Tuple[Any, ...], context: ErrorContext) -> Any:
        """
        Parse `text` and transform the tree with `transformer_class(*transformer_args)`.

        Results are cached per `(text, transformer_args)`. Lark input errors and any
        `CheckShapesError` raised by the transformer are re-raised as `exception_class`, stacked
        on `context`.
        """
        cache_key = (text, transformer_args)
        if cache_key in self._cache:
            return self._cache[cache_key]

        try:
            tree = self._parser.parse(text)
        except UnexpectedInput as ui:
            lark_context = LarkUnexpectedInputContext(text, ui, self._terminal_descriptions)
            raise self._exception_class(StackContext(context, lark_context)) from ui

        try:
            transformed = self._transformer_class(*transformer_args).visit(tree)
        except CheckShapesError as cse:
            raise self._exception_class(StackContext(context, cse.context)) from cse

        self._cache[cache_key] = transformed
        return transformed


# Descriptions of the regular-expression terminals of `check_shapes.lark`, for error messages.
_CHECK_SHAPES_RE_TERMINAL_DESCRIPTIONS: Mapping[str, str] = {
    "NOTE_TEXT": "note / comment text",
    "CNAME": "variable name",
    "INT": "integer",
    "WS": "whitespace",
}


def _make_check_shapes_parser(start_symbol: str) -> _CachedParser:
    """ Create a LALR parser for `check_shapes.lark`, starting at the rule `start_symbol`. """
    return _CachedParser(
        grammar_filename="check_shapes.lark",
        start_symbol=start_symbol,
        parser_name="lalr",
        re_terminal_descriptions=_CHECK_SHAPES_RE_TERMINAL_DESCRIPTIONS,
        transformer_class=_ParseSpec,
        exception_class=SpecificationParseError,
    )


# Parses a single tensor specification.
_TENSOR_SPEC_PARSER = _make_check_shapes_parser("tensor_spec")
# Parses one `check_shapes` argument: either an argument spec or a note.
_ARGUMENT_SPEC_PARSER = _make_check_shapes_parser("argument_or_note_spec")
# Parses (and rewrites) Sphinx docstrings. Uses the Earley parser.
_SPHINX_DOCSTRING_PARSER = _CachedParser(
    grammar_filename="docstring.lark",
    start_symbol="docstring",
    parser_name="earley",
    re_terminal_descriptions={
        "ANY": "any text",
        "CNAME": "variable name",
        "INFO_FIELD_OTHER": "Sphinx info field",
        "PARAM": "Sphinx parameter field",
        "RETURNS": "Sphinx `return` field",
        "WS": "whitespace",
    },
    transformer_class=_RewritedocString,
    exception_class=DocstringParseError,
)


def parse_tensor_spec(tensor_spec: str, context: ErrorContext) -> ParsedTensorSpec:
    """
    Parse a `check_shapes` tensor specification.

    :param tensor_spec: The specification text to parse.
    :param context: Error context that parse errors are stacked onto.
    :returns: The parsed tensor specification. Results are cached per specification string.
    :raises SpecificationParseError: If `tensor_spec` cannot be parsed.
    """
    result = _TENSOR_SPEC_PARSER.parse(tensor_spec, (tensor_spec,), context)
    assert isinstance(result, ParsedTensorSpec)
    return result


def parse_function_spec(function_spec: Sequence[str], context: ErrorContext) -> ParsedFunctionSpec:
    """
    Parse all `check_shapes` argument and note specifications for a single function.

    :param function_spec: The specification strings. Each one is either an argument
        specification or a note.
    :param context: Error context. Errors for the `i`-th specification are additionally stacked
        with `ArgumentContext(i)`.
    :returns: The parsed argument specs and note specs, each in their original order.
    :raises SpecificationParseError: If any specification cannot be parsed.
    """
    arguments = []
    notes = []
    for i, spec in enumerate(function_spec):
        argument_context = StackContext(context, ArgumentContext(i))
        parsed_spec = _ARGUMENT_SPEC_PARSER.parse(spec, (spec,), argument_context)
        if isinstance(parsed_spec, ParsedArgumentSpec):
            arguments.append(parsed_spec)
        else:
            assert isinstance(parsed_spec, ParsedNoteSpec)
            notes.append(parsed_spec)
    return ParsedFunctionSpec(tuple(arguments), tuple(notes))


def parse_and_rewrite_docstring(
    docstring: Optional[str], function_spec: ParsedFunctionSpec, context: ErrorContext
) -> Optional[str]:
    """
    Rewrite `docstring` to include the shapes specified by the `argument_specs`.

    :param docstring: The docstring to rewrite, or `None`.
    :param function_spec: The parsed specification of the function the docstring belongs to.
    :param context: Error context that parse errors are stacked onto.
    :returns: `None` if `docstring` is `None`. The unchanged `docstring` if docstring rewriting is
        disabled (`DocstringFormat.NONE`). Otherwise the Sphinx-rewritten docstring.
    :raises DocstringParseError: If `docstring` cannot be parsed.
    """
    if docstring is None:
        return None

    docstring_format = get_rewrite_docstrings()
    if docstring_format == DocstringFormat.NONE:
        return docstring

    assert docstring_format == DocstringFormat.SPHINX, (
        f"Current docstring format is {docstring_format}, but I don't know how to rewrite that."
        " See `gpflow.experimental.check_shapes.config.set_rewrite_docstrings`."
    )
    result = _SPHINX_DOCSTRING_PARSER.parse(docstring, (docstring, function_spec), context)
    assert isinstance(result, str)
    return result