Source code for gpflow.experimental.check_shapes.argument_ref

# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code for (de)referencing arguments.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterable, List, Mapping, Sequence, Tuple

from .error_contexts import (
    ArgumentContext,
    AttributeContext,
    ErrorContext,
    IndexContext,
    MappingKeyContext,
    MappingValueContext,
    StackContext,
)
from .exceptions import ArgumentReferenceError

# Pseudo argument name that stands for a function's return value instead of one of its
# parameters. Because `return` is a reserved keyword in Python, no real parameter can ever carry
# this name, so it can never collide with an actual argument.
RESULT_TOKEN: str = "return"


class ArgumentRef(ABC):
    """
    A reference to an argument.
    """

    @property
    @abstractmethod
    def is_result(self) -> bool:
        """
        Whether this is a reference to the function result.
        """

    @property
    @abstractmethod
    def root_argument_name(self) -> str:
        """
        Name of the argument this reference eventually starts from.

        Returns `RESULT_TOKEN` if this is a reference to the function result.
        """

    @abstractmethod
    def get(
        self, arg_map: Mapping[str, Any], context: ErrorContext
    ) -> Sequence[Tuple[Any, ErrorContext]]:
        """
        Get the value(s) of this argument from the given argument map.

        A single reference may resolve to several values (for example, every element of a
        collection), so the result is a sequence.

        :param arg_map: Mapping from argument name to argument value.
        :param context: Outer error context, used when reporting a failure to resolve this
            reference.
        :returns: A sequence of `(value, relative_context)` pairs, where `relative_context`
            describes where the value was found, relative to `context`.
        """

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of this reference.
        """


@dataclass(frozen=True)
class RootArgumentRef(ArgumentRef):
    """
    A reference to a single, named argument.

    When `argument_name` equals `RESULT_TOKEN`, this refers to the function's return value
    instead of a parameter.
    """

    argument_name: str

    @property
    def is_result(self) -> bool:
        return self.argument_name == RESULT_TOKEN

    @property
    def root_argument_name(self) -> str:
        return self.argument_name

    def get(
        self, arg_map: Mapping[str, Any], context: ErrorContext
    ) -> Sequence[Tuple[Any, ErrorContext]]:
        """
        Look up `argument_name` in `arg_map`.

        Returns a one-element list pairing the value with an `ArgumentContext` for this
        argument. Any exception raised by the lookup is re-raised as an
        `ArgumentReferenceError` chained to the original.
        """
        name = self.argument_name
        here = ArgumentContext(name)
        try:
            found = arg_map[name]
        except Exception as e:
            raise ArgumentReferenceError(StackContext(context, here)) from e
        return [(found, here)]

    def __repr__(self) -> str:
        return self.argument_name


@dataclass(frozen=True)  # type: ignore[misc]
class DelegatingArgumentRef(ArgumentRef):
    r"""
    Abstract base class for :class:`ArgumentRef`\ s that delegate to a source.

    Subclasses implement :meth:`map_value` (and optionally override :meth:`map_context`) to
    derive zero or more values from each value produced by `source`.
    """

    source: ArgumentRef

    @property
    def is_result(self) -> bool:
        return self.source.is_result

    @property
    def root_argument_name(self) -> str:
        return self.source.root_argument_name

    @abstractmethod
    def map_value(self, value: Any, context: ErrorContext) -> Iterable[Tuple[Any, ErrorContext]]:
        """
        Map this value, from `self.source` to new value(s).

        :param value: A (non-`None`) value produced by `self.source`.
        :param context: The context returned by :meth:`map_context` for this value.
        :returns: `(value, relative_context)` pairs for the mapped value(s).
        """

    def map_context(self, context: ErrorContext) -> ErrorContext:
        """
        Pre-map this error context from `self.source`.

        The mapped value is both used for error messages and passed to :meth:`map_value`.
        The default implementation returns `context` unchanged.
        """
        return context

    def get(
        self, arg_map: Mapping[str, Any], context: ErrorContext
    ) -> Sequence[Tuple[Any, ErrorContext]]:
        """
        Resolve `self.source`, then map each of its values with :meth:`map_value`.

        `None` values from the source are passed through unchanged, together with their
        context, and are not mapped.

        :raises ArgumentReferenceError: If :meth:`map_context` or :meth:`map_value` fails.
            The original exception is chained as the cause.
        """
        results: List[Tuple[Any, ErrorContext]] = []
        sources = self.source.get(arg_map, context)
        for source, source_relative_context in sources:
            if source is None:
                # Nothing to map into; presumably e.g. an optional argument that was not
                # given. Propagate it as-is and let the caller decide how to handle it.
                results.append((source, source_relative_context))
                continue

            try:
                relative_context = self.map_context(source_relative_context)
            except Exception as e:
                # Include the source's context so the error points at the value being mapped,
                # consistently with the `map_value` failure below.
                raise ArgumentReferenceError(
                    StackContext(context, source_relative_context)
                ) from e

            try:
                results.extend(self.map_value(source, relative_context))
            except Exception as e:
                raise ArgumentReferenceError(StackContext(context, relative_context)) from e

        return results


@dataclass(frozen=True)
class AttributeArgumentRef(DelegatingArgumentRef):
    """
    A reference to a named attribute of the value produced by `source`.
    """

    attribute_name: str

    def map_context(self, context: ErrorContext) -> ErrorContext:
        # Extend the source's context with the attribute being accessed.
        return StackContext(context, AttributeContext(self.attribute_name))

    def map_value(self, value: Any, context: ErrorContext) -> Iterable[Tuple[Any, ErrorContext]]:
        attribute_value = getattr(value, self.attribute_name)
        return [(attribute_value, context)]

    def __repr__(self) -> str:
        return f"{self.source!r}.{self.attribute_name}"


@dataclass(frozen=True)
class IndexArgumentRef(DelegatingArgumentRef):
    """
    A reference to a single element, at position `index`, of the value produced by `source`.

    The element is looked up with `value[index]`.
    """

    index: int

    def map_context(self, context: ErrorContext) -> ErrorContext:
        # Extend the source's context with the position being read.
        return StackContext(context, IndexContext(self.index))

    def map_value(self, value: Any, context: ErrorContext) -> Iterable[Tuple[Any, ErrorContext]]:
        element = value[self.index]
        return [(element, context)]

    def __repr__(self) -> str:
        return f"{self.source!r}[{self.index}]"


@dataclass(frozen=True)
class AllElementsRef(DelegatingArgumentRef):
    """
    A reference to every element of the collection produced by `source`.

    Each element is paired with an `IndexContext` holding its position in iteration order.
    """

    def map_value(self, value: Any, context: ErrorContext) -> Iterable[Tuple[Any, ErrorContext]]:
        return [
            (element, StackContext(context, IndexContext(position)))
            for position, element in enumerate(value)
        ]

    def __repr__(self) -> str:
        return f"{self.source!r}[all]"


@dataclass(frozen=True)
class KeysRef(DelegatingArgumentRef):
    """
    A reference to all keys of the mapping produced by `source`.

    Keys are obtained by iterating the value directly, and each key is paired with a
    `MappingKeyContext` naming it.
    """

    def map_value(self, value: Any, context: ErrorContext) -> Iterable[Tuple[Any, ErrorContext]]:
        return [
            (key, StackContext(context, MappingKeyContext(key)))
            for key in value
        ]

    def __repr__(self) -> str:
        return f"{self.source!r}.keys()"


@dataclass(frozen=True)
class ValuesRef(DelegatingArgumentRef):
    """
    A reference to all values of the mapping produced by `source`.

    Each value is paired with a `MappingValueContext` naming the key it was stored under.
    """

    def map_value(self, value: Any, context: ErrorContext) -> Iterable[Tuple[Any, ErrorContext]]:
        return [
            (item, StackContext(context, MappingValueContext(key)))
            for key, item in value.items()
        ]

    def __repr__(self) -> str:
        return f"{self.source!r}.values()"