gpflow.experimental.check_shapes¶
check_shapes¶
A library for annotating and checking the shapes of tensors.
This library is compatible with both TensorFlow and NumPy.
The main entry point is check_shapes().
For example:
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes

@tf.function
@check_shapes(
    "features: [batch_shape..., n_features]",
    "weights: [n_features]",
    "return: [batch_shape...]",
)
def linear_model(
    features: tf.Tensor, weights: tf.Tensor
) -> tf.Tensor:
    ...
Check specification¶
The shapes to check are specified by the arguments to check_shapes(). Each argument is a string of the format:
<argument specifier>: <shape specifier>
Argument specification¶
The <argument specifier> must start with either the name of an argument to the decorated function, or the special name return, which refers to the value returned by the function.
The <argument specifier> can then be modified to refer to elements of the object in two ways:
- Use .<name> to refer to attributes of the object.
- Use [<index>] to refer to elements of a sequence. This is particularly useful if your function returns a tuple of values.
Looking up values in a dict is not supported.
For example:
@check_shapes(
    "weights: ...",
    "data.training_data: ...",
    "return: ...",
    "return[0]: ...",
    "something[0].foo.bar[23]: ...",
)
def f(...):
    ...
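To make the reference syntax concrete, here is a minimal sketch of how an argument specifier could be resolved against a mapping of argument names to values. It is an illustration only and not the library's implementation: the helper resolve_argument, the _TOKEN pattern, and the sample arg_map are all hypothetical names introduced for this example.

```python
import re
from types import SimpleNamespace
from typing import Any, Mapping

# Hypothetical helper, not part of gpflow.
# One modifier is either ".<name>" (attribute) or "[<index>]" (sequence element).
_TOKEN = re.compile(r"\.([A-Za-z_]\w*)|\[(\d+)\]")


def resolve_argument(specifier: str, arg_map: Mapping[str, Any]) -> Any:
    """Follow a specifier such as ``data.training_data[1]`` through ``arg_map``."""
    root = re.match(r"[A-Za-z_]\w*", specifier)
    if root is None:
        raise ValueError(f"Specifier must start with a name: {specifier!r}")
    # The root is an argument name, or the special name "return".
    value = arg_map[root.group(0)]
    pos = root.end()
    while pos < len(specifier):
        token = _TOKEN.match(specifier, pos)
        if token is None:
            raise ValueError(f"Cannot parse {specifier!r} at position {pos}")
        name, index = token.groups()
        # Attribute lookup for ".<name>", sequence indexing for "[<index>]".
        value = getattr(value, name) if name is not None else value[int(index)]
        pos = token.end()
    return value


# Sample values standing in for a call's arguments and its return value.
arg_map = {
    "weights": [0.5, 0.25],
    "data": SimpleNamespace(training_data=("inputs", "targets")),
    "return": ("mean", "variance"),
}
```

With these sample values, resolve_argument("return[0]", arg_map) yields the first element of the returned tuple, and resolve_argument("data.training_data[1]", arg_map) follows an attribute and then an index. Dictionary keys are deliberately not handled, mirroring the restriction stated above.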
Shape specification¶
Shapes are specified by the syntax [<dimension specifier 1>, <dimension specifier 2>, ..., <dimension specifier n>], where <dimension specifier i> is one of:
- <integer>, to require that dimension to have that exact size.
- <name>, to bind that dimension to a variable. Dimensions bound to the same variable must have the same size, though that size can be anything.
- *<name> or <name>..., to bind any number of dimensions to a variable. Again, multiple uses of the same variable name must match the same dimension sizes.
A scalar shape is specified by [].
For example:
@check_shapes(
    "...: []",
    "...: [3, 4]",
    "...: [width, height]",
    "...: [n_samples, *batch]",
    "...: [batch..., 2]",
)
def f(...):
    ...
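The matching rules for dimension specifiers can be sketched in plain Python. The functions parse_shape_spec and match_shape below are hypothetical and only illustrate the rules described in this section; they are not the library's code. The sketch assumes at most one variable-rank specifier per shape and assumes that "any number of dimensions" includes zero.

```python
from typing import Dict, List, Sequence, Tuple, Union

Binding = Union[int, Tuple[int, ...]]


def parse_shape_spec(spec: str) -> List[str]:
    """Split ``"[a, b..., 3]"`` into its dimension specifiers."""
    inner = spec.strip()
    if not (inner.startswith("[") and inner.endswith("]")):
        raise ValueError(f"Shape specifier must be bracketed: {spec!r}")
    inner = inner[1:-1].strip()
    return [dim.strip() for dim in inner.split(",")] if inner else []


def _is_variadic(dim: str) -> bool:
    # Both "*name" and "name..." bind any number of dimensions.
    return dim.startswith("*") or dim.endswith("...")


def match_shape(spec: str, shape: Sequence[int], bindings: Dict[str, Binding]) -> bool:
    """Check ``shape`` against ``spec``, recording variable bindings in ``bindings``.

    Note: on a mismatch, bindings made earlier in the same call are kept.
    """
    dims = parse_shape_spec(spec)
    n_variadic = sum(_is_variadic(dim) for dim in dims)
    if n_variadic > 1:
        raise ValueError("This sketch supports at most one variable-rank specifier.")
    n_fixed = len(dims) - n_variadic
    n_var = len(shape) - n_fixed  # dimensions absorbed by the variadic specifier
    if n_var < 0 or (n_variadic == 0 and n_var != 0):
        return False
    pos = 0
    for dim in dims:
        if dim.isdigit():
            # <integer>: the dimension must have exactly this size.
            if shape[pos] != int(dim):
                return False
            pos += 1
            continue
        if _is_variadic(dim):
            name = dim[1:] if dim.startswith("*") else dim[:-3]
            value: Binding = tuple(shape[pos:pos + n_var])
            pos += n_var
        else:
            name, value = dim, shape[pos]
            pos += 1
        # The first use binds the variable; later uses must agree with it.
        if bindings.setdefault(name, value) != value:
            return False
    return True
```

Sharing one bindings dictionary across all arguments of a call is what makes a name such as n_features mean the same size everywhere it appears.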
Class inheritance¶
If you have a class hierarchy, you probably want to ensure that derived classes handle tensors with the same shapes as the base classes. You can use the inherit_check_shapes() decorator to inherit shapes from overridden methods.
Example:
class SuperClass(ABC):
    @abstractmethod
    @check_shapes(
        "a: [batch..., 4]",
        "return: [batch..., 1]",
    )
    def f(self, a: tf.Tensor) -> tf.Tensor:
        ...

class SubClass(SuperClass):
    @inherit_check_shapes
    def f(self, a: tf.Tensor) -> tf.Tensor:
        ...
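One way such inheritance could work is to look up the overridden method along the class's method resolution order and reuse its specifications. The sketch below shows that idea with plain Python only. The names record_specs, inherit_specs, and the shape_specs attribute are hypothetical, and nothing here is taken from the library's actual implementation.

```python
from typing import Any, Callable


def record_specs(*specs: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical stand-in for a checking decorator: it only stores the specs."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        func.shape_specs = specs  # type: ignore[attr-defined]
        return func

    return decorator


class inherit_specs:
    """Hypothetical decorator that copies specs from the overridden method."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func

    def __set_name__(self, owner: type, name: str) -> None:
        # Called once the owning class has been created, so its MRO is available.
        for base in owner.__mro__[1:]:
            specs = getattr(base.__dict__.get(name), "shape_specs", None)
            if specs is not None:
                self.func.shape_specs = specs  # type: ignore[attr-defined]
                # Replace this placeholder with the plain function.
                setattr(owner, name, self.func)
                return
        raise TypeError(f"No overridden method with shape specs found for {name!r}.")


class Base:
    @record_specs("a: [batch..., 4]", "return: [batch..., 1]")
    def f(self, a: Any) -> Any:
        return a


class Derived(Base):
    @inherit_specs
    def f(self, a: Any) -> Any:
        return a
```

After class creation, Derived.f carries the same specifications as Base.f without repeating them, which is the behaviour the decorator above is meant to provide for real shape checks.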
Speed, and interactions with tf.function¶
If you want to wrap your function in both tf.function() and check_shapes(), it is recommended that you put tf.function() outermost, so that the shape checks are inside tf.function(). Shape checks are performed while tracing graphs, but are not compiled into the actual graphs. This is considered a feature, as it means that check_shapes() doesn’t impact the execution speed of compiled functions. However, it also means that tensor dimensions of dynamic size are not verified in compiled mode.
Documenting shapes¶
The check_shapes() decorator rewrites the docstring (.__doc__) of the decorated function to add information about shapes, in a format compatible with Sphinx. Only parameters that already have a :param …: section will be modified.
For example:
@tf.function
@check_shapes(
    "features: [batch_shape..., n_features]",
    "weights: [n_features]",
    "return: [batch_shape...]",
)
def linear_model(
    features: tf.Tensor, weights: tf.Tensor
) -> tf.Tensor:
    """
    Computes a prediction from a linear model.

    :param features: Data to make predictions from.
    :param weights: Model weights.
    :returns: Model predictions.
    """
    ...
will have .__doc__:
"""
Computes a prediction from a linear model.

:param features:
    * **features** has shape [*batch_shape*..., *n_features*].

    Data to make predictions from.
:param weights:
    * **weights** has shape [*n_features*].

    Model weights.
:returns:
    * **return** has shape [*batch_shape*...].

    Model predictions.
"""
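The kind of rewrite shown above can be approximated with a few lines of string processing. The function add_shapes_to_doc below is a hypothetical sketch, not the library's implementation. It assumes each :param …: or :returns: field fits on a single line, and it leaves fields without a shape entry untouched.

```python
import re
from typing import List, Mapping

# Matches ":param <name>: text" or ":returns: text", keeping the indentation.
_FIELD = re.compile(r"^(\s*)(:param (\w+):|:returns:)\s*(.*)$")


def add_shapes_to_doc(doc: str, shapes: Mapping[str, str]) -> str:
    """Insert a shape bullet under every documented field listed in ``shapes``."""
    out: List[str] = []
    for line in doc.splitlines():
        match = _FIELD.match(line)
        if match is None:
            out.append(line)
            continue
        indent, field, param, text = match.groups()
        name = param if param is not None else "return"
        if name not in shapes:
            # Fields without a shape entry are kept unchanged.
            out.append(line)
            continue
        out.append(f"{indent}{field}")
        out.append(f"{indent}    * **{name}** has shape {shapes[name]}.")
        out.append("")
        if text:
            out.append(f"{indent}    {text}")
    return "\n".join(out)


doc = (
    "Computes a prediction from a linear model.\n"
    "\n"
    ":param features: Data to make predictions from.\n"
    ":param other: Left unchanged.\n"
    ":returns: Model predictions."
)
new_doc = add_shapes_to_doc(
    doc,
    {"features": "[*batch_shape*..., *n_features*]", "return": "[*batch_shape*...]"},
)
```

Because only lines that already match a field are rewritten, undocumented parameters gain no new section, consistent with the rule that only parameters with an existing :param …: section are modified.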
gpflow.experimental.check_shapes.ArgumentReferenceError¶
- class gpflow.experimental.check_shapes.ArgumentReferenceError(func, arg_map, arg_ref)[source]¶
Bases: Exception
Error raised if the argument to check the shape of could not be resolved.
- Attributes
  - args
- Methods
  - with_traceback: Exception.with_traceback(tb) -- set self.__traceback__ to tb and return self.
- Parameters
  - func (TypeVar(C, bound=Callable[..., Any])) –
  - arg_map (Mapping[str, Any]) –
  - arg_ref (ArgumentRef) –
gpflow.experimental.check_shapes.ShapeMismatchError¶
- class gpflow.experimental.check_shapes.ShapeMismatchError(func, specs, arg_map)[source]¶
Bases: Exception
Error raised if a function is called with tensors of the wrong shape.
- Attributes
  - args
- Methods
  - with_traceback: Exception.with_traceback(tb) -- set self.__traceback__ to tb and return self.
- Parameters
  - func (TypeVar(C, bound=Callable[..., Any])) –
  - specs (Sequence[ParsedArgumentSpec]) –
  - arg_map (Mapping[str, Any]) –
gpflow.experimental.check_shapes.check_shapes¶
- gpflow.experimental.check_shapes.check_shapes(*specs)[source]¶
Decorator that checks the shapes of tensor arguments.
See: check_shapes.
- Parameters
  - specs (str) – Specification of arguments to check. See: Argument specification.
- Return type
  Callable[[TypeVar(C, bound=Callable[..., Any])], TypeVar(C, bound=Callable[..., Any])]
gpflow.experimental.check_shapes.inherit_check_shapes¶
- gpflow.experimental.check_shapes.inherit_check_shapes(func)[source]¶
Decorator that inherits the check_shapes() decoration from any overridden method in a super-class.
See: Class inheritance.
- Parameters
  - func (TypeVar(C, bound=Callable[..., Any])) –
- Return type
  TypeVar(C, bound=Callable[..., Any])