
A library for annotating and checking the shapes of tensors.

The main entry point is check_shapes().


import tensorflow as tf

from gpflow.experimental.check_shapes import check_shapes

@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    return tf.einsum("...i,i -> ...", features, weights)

Check specification

The shapes to check are specified by the arguments to check_shapes(). Each argument is a string of the format:

<argument specifier> ":" <shape specifier> ["if" <condition>] ["#" <note>]
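To make the grammar concrete, here is a toy splitter that breaks one specification string into its four parts. It is a sketch built on a regular expression of our own; `split_spec` and `SpecParts` are hypothetical names, not part of check_shapes. It assumes shapes contain no nested brackets, and it ignores note-only lines such as "# ...", which the library does accept.

```python
import re
from typing import NamedTuple, Optional


class SpecParts(NamedTuple):
    argument: str
    shape: str
    condition: Optional[str]
    note: Optional[str]


# Toy grammar: <argument> ":" <shape> ["if" <condition>] ["#" <note>]
_SPEC = re.compile(
    r"^\s*(?P<argument>[^:]+?)\s*:\s*"  # argument specifier, up to the first colon
    r"(?P<shape>\[[^\]]*\])"  # shape specifier, assumed to have no nested brackets
    r"(?:\s+if\s+(?P<condition>[^#]+?))?"  # optional condition, stops before any note
    r"\s*(?:#\s*(?P<note>.*?))?\s*$"  # optional note, up to the end of the string
)


def split_spec(spec: str) -> SpecParts:
    """Split one check specification into its parts (toy sketch, not the library parser)."""
    match = _SPEC.match(spec)
    if match is None:
        raise ValueError(f"Cannot parse check specification: {spec!r}")
    return SpecParts(**match.groupdict())


print(split_spec("a: [broadcast batch...] if check_all or check_a"))
```

Missing parts come back as None. That makes it easy to see which optional pieces a given specification uses.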

Argument specification

The <argument specifier> must start with either the name of an argument to the decorated function, or the special name return. The value return refers to the value returned by the function.

The <argument specifier> can then be modified to refer to elements of the object in several ways:

  • Use .<name> to refer to an attribute of an object:

    @dataclass
    class Statistics:
        mean: AnyNDArray
        std: AnyNDArray

    @check_shapes(
        "data: [n_rows, n_columns]",
        "return.mean: [n_columns]",
        "return.std: [n_columns]",
    )
    def compute_statistics(data: AnyNDArray) -> Statistics:
        return Statistics(np.mean(data, axis=0), np.std(data, axis=0))
  • Use [<index>] to refer to a specific element of a sequence. This is particularly useful if your function returns a tuple of values:

    @check_shapes(
        "data: [n_rows, n_columns]",
        "return[0]: [n_columns]",
        "return[1]: [n_columns]",
    )
    def compute_mean_and_std(data: AnyNDArray) -> Tuple[AnyNDArray, AnyNDArray]:
        return np.mean(data, axis=0), np.std(data, axis=0)
  • Use [all] to select all elements of a collection:

    @check_shapes(
        "data[all]: [., n_columns]",
        "return: [., n_columns]",
    )
    def concat_rows(data: Collection[AnyNDArray]) -> AnyNDArray:
        return np.concatenate(data, axis=0)

    concat_rows([np.ones((1, 3)), np.ones((4, 3))])
  • Use .keys() to select all keys of a mapping:

    @check_shapes(
        "data.keys(): [.]",
        "return: []",
    )
    def sum_key_lengths(data: Mapping[Tuple[int, ...], str]) -> int:
        return sum(len(k) for k in data)

    sum_key_lengths({(3,): "foo", (1, 2): "bar"})
  • Use .values() to select all values of a mapping:

    @check_shapes(
        "data.values(): [., n_columns]",
        "return: [., n_columns]",
    )
    def concat_rows(data: Mapping[str, AnyNDArray]) -> AnyNDArray:
        return np.concatenate(list(data.values()), axis=0)

    concat_rows({"foo": np.ones((1, 3)), "bar": np.ones((4, 3))})


We do not support looking up a specific key or value in a dict.

If the argument, or any of the looked-up values, is None, the check is skipped. This is useful for optional values:

@check_shapes(
    "x1: [n_rows_1, n_inputs]",
    "x2: [n_rows_2, n_inputs]",
    "return: [n_rows_1, n_rows_2]",
)
def squared_exponential_kernel(
    variance: float, x1: AnyNDArray, x2: Optional[AnyNDArray] = None
) -> AnyNDArray:
    if x2 is None:
        x2 = x1
    cov: AnyNDArray = variance * np.exp(
        -0.5 * np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=2)
    )
    return cov

squared_exponential_kernel(1.0, np.ones((3, 2)), np.ones((4, 2)))
squared_exponential_kernel(3.2, np.ones((3, 2)))
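The selectors above can be mimicked in plain Python. The sketch below uses a hypothetical function, `select`, which is not part of check_shapes and supports only the five selector forms listed above. It resolves a selector path into the values that path refers to. The path is the part of the argument specifier after the argument name, such as ".mean" or "[all]". The function drops None values, so, as described above, they are not checked.

```python
import re
from types import SimpleNamespace
from typing import Any, List

# One token per selector form: .keys(), .values(), .<name>, [all], [<index>]
_SELECTOR = re.compile(r"\.keys\(\)|\.values\(\)|\.\w+|\[all\]|\[\d+\]")


def select(value: Any, path: str) -> List[Any]:
    """Return the values that `path` (e.g. ".mean" or "[all]") selects from `value`."""
    current: List[Any] = [value]
    pos = 0
    while pos < len(path):
        match = _SELECTOR.match(path, pos)
        if match is None:
            raise ValueError(f"Cannot parse selector at {path[pos:]!r}")
        token, pos = match.group(), match.end()
        selected: List[Any] = []
        for item in current:
            if item is None:
                continue  # None short-circuits: there is nothing further to look up
            if token == ".keys()":
                selected.extend(item.keys())
            elif token == ".values()":
                selected.extend(item.values())
            elif token == "[all]":
                selected.extend(item)
            elif token.startswith("["):
                selected.append(item[int(token[1:-1])])
            else:
                selected.append(getattr(item, token[1:]))
        current = selected
    # Values that are None are skipped rather than checked.
    return [item for item in current if item is not None]


stats = SimpleNamespace(mean=[0.5, 0.5], std=None)
print(select(stats, ".mean"))  # the mean is selected
print(select(stats, ".std"))  # std is None, so nothing is selected
```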

Shape specification

Shapes are specified by the syntax:

"[" <dimension specifier 1> "," <dimension specifier 2> "," ... "," <dimension specifier n> "]"

where <dimension specifier i> is one of:

  • <integer>, to require that dimension to have that exact size:

    @check_shapes(
        "v1: [2]",
        "v2: [2]",
    )
    def vector_2d_distance(v1: AnyNDArray, v2: AnyNDArray) -> float:
        return float(np.sqrt(np.sum((v1 - v2) ** 2)))
  • <name>, to bind that dimension to a variable. Dimensions bound to the same variable must have the same size, though that size can be anything:

    @check_shapes(
        "v1: [d]",
        "v2: [d]",
    )
    def vector_distance(v1: AnyNDArray, v2: AnyNDArray) -> float:
        return float(np.sqrt(np.sum((v1 - v2) ** 2)))
  • None or . to allow exactly one dimension, without constraints on its size:

    @check_shapes(
        "v: [None]",
    )
    def vector_length(v: AnyNDArray) -> float:
        return float(np.sqrt(np.sum(v ** 2)))

    @check_shapes(
        "v: [.]",
    )
    def vector_length(v: AnyNDArray) -> float:
        return float(np.sqrt(np.sum(v ** 2)))
  • *<name> or <name>..., to bind any number of dimensions to a variable. Again, multiple uses of the same variable name must match the same dimension sizes:

    @check_shapes(
        "x: [*batch, n_columns]",
        "return: [*batch]",
    )
    def batch_mean(x: AnyNDArray) -> AnyNDArray:
        mean: AnyNDArray = np.mean(x, axis=-1)
        return mean

    @check_shapes(
        "x: [batch..., n_columns]",
        "return: [batch...]",
    )
    def batch_mean(x: AnyNDArray) -> AnyNDArray:
        mean: AnyNDArray = np.mean(x, axis=-1)
        return mean
  • * or ..., to allow any number of dimensions without constraints:

    @check_shapes(
        "x: [*]",
    )
    def rank(x: AnyNDArray) -> int:
        return len(x.shape)

    @check_shapes(
        "x: [...]",
    )
    def rank(x: AnyNDArray) -> int:
        return len(x.shape)

A scalar shape is specified by []:

@check_shapes(
    "x: [...]",
    "return: []",
)
def mean(x: AnyNDArray) -> AnyNDArray:
    mean: AnyNDArray = np.sum(x) / x.size
    return mean
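The dimension specifiers above can be illustrated with a toy matcher. In the sketch below, `match_shape` and its helpers are hypothetical names, not the library's implementation. It checks a shape tuple against a specification string and records variable bindings in a dictionary. Because the dictionary is shared, reusing a name across several calls enforces consistent sizes, as a shared namespace does. For brevity it assumes at most one variadic specifier per shape and ignores the broadcast keyword.

```python
from typing import Dict, List, Tuple, Union

Binding = Union[int, Tuple[int, ...]]


def _tokens(spec: str) -> List[str]:
    """Split "[a, b...]" into ["a", "b..."]."""
    spec = spec.strip()
    if not (spec.startswith("[") and spec.endswith("]")):
        raise ValueError(f"Not a shape specification: {spec!r}")
    inner = spec[1:-1].strip()
    return [token.strip() for token in inner.split(",")] if inner else []


def _is_variadic(token: str) -> bool:
    return token.startswith("*") or token.endswith("...")


def _variadic_name(token: str) -> str:
    # "*batch" and "batch..." -> "batch"; "*" and "..." -> "" (anonymous)
    return token[:-3] if token.endswith("...") else token[1:]


def _bind(name: str, value: Binding, bindings: Dict[str, Binding]) -> bool:
    if not name:
        return True  # anonymous dimensions are unconstrained
    if name in bindings:
        return bindings[name] == value
    bindings[name] = value
    return True


def match_shape(shape: Tuple[int, ...], spec: str, bindings: Dict[str, Binding]) -> bool:
    """Toy check of `shape` against `spec`; updates `bindings` only on success."""
    tokens = _tokens(spec)
    trial = dict(bindings)  # work on a copy so a failed match binds nothing
    variadic = [i for i, token in enumerate(tokens) if _is_variadic(token)]
    if len(variadic) > 1:
        raise ValueError("This sketch supports at most one variadic specifier")
    if variadic:
        i = variadic[0]
        n_after = len(tokens) - i - 1
        if len(shape) < len(tokens) - 1:
            return False
        split = len(shape) - n_after
        if not _bind(_variadic_name(tokens[i]), tuple(shape[i:split]), trial):
            return False
        pairs = list(zip(tokens[:i], shape[:i])) + list(zip(tokens[i + 1:], shape[split:]))
    else:
        if len(shape) != len(tokens):
            return False
        pairs = list(zip(tokens, shape))
    for token, size in pairs:
        if token.isdigit():
            if int(token) != size:
                return False
        elif token not in (".", "None") and not _bind(token, size, trial):
            return False
    bindings.update(trial)
    return True


bindings: Dict[str, Binding] = {}
print(match_shape((5, 3), "[batch..., n_features]", bindings))  # features
print(match_shape((3,), "[n_features]", bindings))  # weights
print(bindings)
```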

Any of the above can be prefixed with the keyword broadcast to allow any value that broadcasts to the specification. For example:

@check_shapes(
    "a: [broadcast batch...]",
    "b: [broadcast batch...]",
    "return: [batch...]",
)
def add(a: AnyNDArray, b: AnyNDArray) -> AnyNDArray:
    return a + b
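The broadcast keyword relies on NumPy-style broadcasting. Below is a rough plain-Python illustration of that rule. The helper names are hypothetical, and the code restates the standard NumPy rule rather than reproducing check_shapes code. It computes the shape of a + b in the add example above, for a of shape (3, 1) and b of shape (1, 4), and confirms that both inputs broadcast to it.

```python
from itertools import zip_longest
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Shape of `a + b` under NumPy broadcasting; ValueError if incompatible."""
    result = []
    # Align dimensions from the right; missing leading dimensions act like size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"Shapes {a} and {b} do not broadcast")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def broadcasts_to(shape: Shape, target: Shape) -> bool:
    """True if `shape` can be broadcast to exactly `target`."""
    if len(shape) > len(target):
        return False
    return all(s == 1 or s == t for s, t in zip(reversed(shape), reversed(target)))


batch = broadcast_shape((3, 1), (1, 4))  # e.g. a of shape (3, 1) and b of shape (1, 4)
print(batch, broadcasts_to((3, 1), batch), broadcasts_to((1, 4), batch))
```

In this reading, "[broadcast batch...]" on both a and b, together with "return: [batch...]", asks that each input broadcasts to the shape of the result.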

Condition specification

You can use the optional if <condition> syntax to conditionally evaluate shapes. If an if <condition> is used, the specification is only applied if <condition> evaluates to True. This is useful if shapes depend on other input parameters. Valid conditions are:

  • <argument specifier>, with the same syntax and rules as above, except that constructions that evaluate to multiple elements are disallowed. The value of the argument is converted to a bool using the bool built-in:

    @check_shapes(
        "a: [broadcast batch...] if check_a",
        "b: [broadcast batch...] if check_b",
        "return: [batch...]",
    )
    def add(a: AnyNDArray, b: AnyNDArray, check_a: bool = True, check_b: bool = True) -> AnyNDArray:
        return a + b

    add(np.ones((3, 1)), np.ones((1, 4)), check_b=False)
  • <left> or <right>, evaluates to True if either <left> or <right> evaluates to True, and to False otherwise:

    @check_shapes(
        "a: [broadcast batch...] if check_all or check_a",
        "b: [broadcast batch...] if check_all or check_b",
        "return: [batch...]",
    )
    def add(
        a: AnyNDArray,
        b: AnyNDArray,
        check_all: bool = False,
        check_a: bool = True,
        check_b: bool = True,
    ) -> AnyNDArray:
        return a + b

    add(np.ones((3, 1)), np.ones((1, 4)), check_b=False)
  • <left> and <right>, evaluates to False if either <left> or <right> evaluates to False, and to True otherwise:

    @check_shapes(
        "a: [broadcast batch...] if enable_checks and check_a",
        "b: [broadcast batch...] if enable_checks and check_b",
        "return: [batch...]",
    )
    def add(
        a: AnyNDArray,
        b: AnyNDArray,
        enable_checks: bool = True,
        check_a: bool = True,
        check_b: bool = True,
    ) -> AnyNDArray:
        return a + b

    add(np.ones((3, 1)), np.ones((1, 4)), check_b=False)
  • not <right>, evaluates to the opposite value of <right>:

    @check_shapes(
        "a: [broadcast batch...] if not disable_checks",
        "b: [broadcast batch...] if not disable_checks",
        "return: [batch...]",
    )
    def add(a: AnyNDArray, b: AnyNDArray, disable_checks: bool = False) -> AnyNDArray:
        return a + b

    add(np.ones((3, 1)), np.ones((1, 4)))
  • (<exp>), uses parentheses to change operator precedence, as usual.

Conditions can be composed to apply different specs, depending on function arguments:

@check_shapes(
    "a: [j] if a_vector",
    "a: [i, j] if (not a_vector)",
    "b: [j] if b_vector",
    "b: [j, k] if (not b_vector)",
    "return: [1, 1] if a_vector and b_vector",
    "return: [1, k] if a_vector and (not b_vector)",
    "return: [i, 1] if (not a_vector) and b_vector",
    "return: [i, k] if (not a_vector) and (not b_vector)",
)
def multiply(a: AnyNDArray, b: AnyNDArray, a_vector: bool, b_vector: bool) -> AnyNDArray:
    if a_vector:
        a = a[None, :]
    if b_vector:
        b = b[:, None]

    return a @ b

multiply(np.ones((4,)), np.ones((4, 5)), a_vector=True, b_vector=False)


All specifications with either no if syntax or a <condition> that evaluates to True will be applied. It is possible for multiple specifications to apply to the same value.
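Conditions only combine argument values with not, and, or and parentheses. Their meaning therefore follows ordinary Boolean logic with the usual precedence: not binds tightest, then and, then or. The toy evaluator below is a hypothetical sketch, not the library's parser. It accepts only bare argument names (no .attribute or [index] selectors) and converts each value with bool. Applied to the multiply example above, it shows which return specification applies.

```python
import re
from typing import Any, List, Mapping, Optional

_TOKEN = re.compile(r"\(|\)|\w+|\S")


def evaluate_condition(condition: str, args: Mapping[str, Any]) -> bool:
    """Evaluate a toy condition over argument values, by recursive descent."""
    tokens: List[str] = _TOKEN.findall(condition)
    pos = 0

    def peek() -> Optional[str]:
        return tokens[pos] if pos < len(tokens) else None

    def take() -> str:
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError(f"Unexpected end of condition: {condition!r}")
        pos += 1
        return tokens[pos - 1]

    def parse_or() -> bool:  # lowest precedence
        value = parse_and()
        while peek() == "or":
            take()
            rhs = parse_and()  # always parse, so the tokens are consumed
            value = value or rhs
        return value

    def parse_and() -> bool:
        value = parse_not()
        while peek() == "and":
            take()
            rhs = parse_not()
            value = value and rhs
        return value

    def parse_not() -> bool:  # highest precedence
        if peek() == "not":
            take()
            return not parse_not()
        return parse_atom()

    def parse_atom() -> bool:
        token = take()
        if token == "(":
            value = parse_or()
            if take() != ")":
                raise ValueError(f"Expected ')' in {condition!r}")
            return value
        if not token.isidentifier() or token in ("and", "or", "not"):
            raise ValueError(f"Unexpected token {token!r} in {condition!r}")
        return bool(args[token])  # argument values are converted with bool()

    result = parse_or()
    if pos != len(tokens):
        raise ValueError(f"Trailing tokens in {condition!r}")
    return result


return_specs = {
    "[1, 1]": "a_vector and b_vector",
    "[1, k]": "a_vector and (not b_vector)",
    "[i, 1]": "(not a_vector) and b_vector",
    "[i, k]": "(not a_vector) and (not b_vector)",
}
args = {"a_vector": True, "b_vector": False}
print([shape for shape, cond in return_specs.items() if evaluate_condition(cond, args)])
```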

Note specification

You can add notes to your specifications using a # followed by the note. These notes will be appended to relevant error messages and appear in rewritten docstrings. You can add notes in two places:

  • On a single line by itself, to add a note to the entire function:

    @check_shapes(
        "features: [batch..., n_features]",
        "# linear_model currently only supports a single output.",
        "weights: [n_features]",
        "return: [batch...]",
    )
    def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, weights)
        return prediction
  • After the specification of a single argument, to add a note to that argument only:

    @check_shapes(
        "features: [batch..., n_features]",
        "weights: [n_features] # linear_model currently only supports a single output.",
        "return: [batch...]",
    )
    def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, weights)
        return prediction

Shape reuse

Just like with other code, it is useful to be able to specify a shape in one place and reuse the specification. In particular, this ensures that your code keeps having internally consistent shapes, even if it is refactored.

Class inheritance

If you have a class hierarchy, you probably want to ensure that derived classes handle tensors with the same shapes as the base classes. You can use the inherit_check_shapes() decorator to inherit shapes from overridden methods:

class Model(ABC):
    @abstractmethod
    @check_shapes(
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        pass

class LinearModel(Model):
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @inherit_check_shapes
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction

Functional programming

If you prefer functional over object-oriented programming, you may have functions that you require to handle the same shapes. To do this, remember that in Python a decorator is just a function, and functions are objects that can be stored:

check_metric_shapes = check_shapes(
    "actual: [n_rows, n_labels]",
    "predicted: [n_rows, n_labels]",
    "return: []",
)

@check_metric_shapes
def rmse(actual: AnyNDArray, predicted: AnyNDArray) -> float:
    return float(np.mean(np.sqrt(np.mean((predicted - actual) ** 2, axis=-1))))

@check_metric_shapes
def mape(actual: AnyNDArray, predicted: AnyNDArray) -> float:
    return float(np.mean(np.abs((predicted - actual) / actual)))
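The same pattern works for any Python decorator, because a decorator is just a value. As a stdlib-only sketch, one stored decorator can guard several metrics. The factory `require_equal_lengths` is hypothetical, and it checks len() rather than shapes.

```python
import functools
import inspect
from typing import Any, Callable, Sequence, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def require_equal_lengths(*names: str) -> Callable[[F], F]:
    """Hypothetical decorator factory: the named arguments must have equal len()."""

    def decorator(func: F) -> F:
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Map positional and keyword arguments onto parameter names.
            bound = signature.bind(*args, **kwargs)
            lengths = {name: len(bound.arguments[name]) for name in names}
            if len(set(lengths.values())) > 1:
                raise ValueError(f"{func.__name__}: mismatched lengths {lengths}")
            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator


# Build the decorator once and store it, like check_metric_shapes above.
check_metric_lengths = require_equal_lengths("actual", "predicted")


@check_metric_lengths
def mean_absolute_error(actual: Sequence[float], predicted: Sequence[float]) -> float:
    return sum(abs(p - a) for a, p in zip(actual, predicted)) / len(actual)


@check_metric_lengths
def max_error(actual: Sequence[float], predicted: Sequence[float]) -> float:
    return max(abs(p - a) for a, p in zip(actual, predicted))
```

Changing the stored decorator in one place changes the contract of every function that uses it.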

Other reuse of shapes

You can use get_check_shapes() to get, and reuse, the shape definitions from a previously declared function. This is particularly useful to ensure fakes in tests use the same shapes as the production implementation:

class Model(ABC):
    @abstractmethod
    @check_shapes(
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        pass

@check_shapes(
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows]",
)
def evaluate_model(model: Model, test_features: AnyNDArray, test_labels: AnyNDArray) -> float:
    prediction = model.predict(test_features)
    return float(np.mean(np.sqrt(np.mean((prediction - test_labels) ** 2, axis=-1))))

def test_evaluate_model() -> None:
    fake_features = np.ones((10, 3))
    fake_labels = np.ones((10,))
    fake_predictions = np.ones((10,))

    @get_check_shapes(Model.predict)
    def fake_predict(features: AnyNDArray) -> AnyNDArray:
        assert features is fake_features
        return fake_predictions

    fake_model = MagicMock(spec=Model, predict=fake_predict)

    assert pytest.approx(0.0) == evaluate_model(fake_model, fake_features, fake_labels)

Checking intermediate results

You can use the function check_shape() to check the shape of an intermediate result. This function uses the same namespace as the immediately surrounding check_shapes() decorator, and should only be called within functions that have such a decorator. For example:

@check_shapes(
    "weights: [n_features, n_labels]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows, n_labels]",
    "return: []",
)
def loss(weights: AnyNDArray, test_features: AnyNDArray, test_labels: AnyNDArray) -> AnyNDArray:
    prediction = check_shape(test_features @ weights, "[n_rows, n_labels]")
    error: AnyNDArray = check_shape(prediction - test_labels, "[n_rows, n_labels]")
    square_error = check_shape(error ** 2, "[n_rows, n_labels]")
    mean_square_error = check_shape(np.mean(square_error, axis=-1), "[n_rows]")
    root_mean_square_error = check_shape(np.sqrt(mean_square_error), "[n_rows]")
    loss: AnyNDArray = np.mean(root_mean_square_error)
    return loss

Checking shapes without a decorator

While the check_shapes() decorator is the recommended way to use this library, it is possible to use it without the decorator. In fact, the decorator is just a wrapper around the class ShapeChecker, which can be used to check shapes directly:

def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
    checker = ShapeChecker()
    checker.check_shape(features, "[batch..., n_features]")
    checker.check_shape(weights, "[n_features]")
    prediction: AnyNDArray = checker.check_shape(
        np.einsum("...i,i -> ...", features, weights), "[batch...]"
    )
    return prediction

You can use the function get_shape_checker() to get the ShapeChecker used by any immediately surrounding check_shapes() decorator.

Speed, and interactions with tf.function

If you want to wrap your function in both tf.function() and check_shapes(), we recommend putting tf.function() outermost, so that the shape checks run inside tf.function(). Shape checks are performed while tracing graphs, but are not compiled into the graphs themselves. This is deliberate: check_shapes() does not affect the execution speed of compiled functions. However, it also means that tensor dimensions of dynamic size are not verified in compiled mode.

Disabling shape checking

If your code is very performance sensitive and check_shapes is causing an unacceptable slowdown, it can be disabled. Preferably, use the context manager disable_check_shapes():

with disable_check_shapes():
    performance_sensitive_function()  # placeholder for your own code

Alternatively, check_shapes can also be disabled globally with set_enable_check_shapes():

set_enable_check_shapes(False)

Beware that any function declared while shape checking is disabled will continue not to check shapes, even if shape checking is otherwise enabled again.
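The warning above follows naturally if the decorator consults the global switch at the moment it is applied. The stdlib sketch below shows one plausible mechanism with that effect. `toy_check`, `set_enabled` and `checked_calls` are hypothetical stand-ins, not the library's implementation. The flag is read once at decoration time, so a function declared while checking is disabled is returned unwrapped for good. It is read again at call time, which is how a temporary switch can still affect functions decorated while checking was enabled.

```python
import functools
from typing import Any, Callable, List

_enabled = True
checked_calls: List[str] = []  # records which calls would have been checked


def set_enabled(enabled: bool) -> None:
    global _enabled
    _enabled = enabled


def toy_check(func: Callable[..., Any]) -> Callable[..., Any]:
    if not _enabled:
        return func  # read at decoration time: no wrapper, so never checked

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if _enabled:  # read again at call time
            checked_calls.append(func.__name__)
        return func(*args, **kwargs)

    return wrapper


set_enabled(False)


@toy_check
def declared_while_disabled() -> int:
    return 1


set_enabled(True)


@toy_check
def declared_while_enabled() -> int:
    return 2


declared_while_disabled()
declared_while_enabled()
print(checked_calls)  # only the function declared while enabled was checked
```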

Documenting shapes

The check_shapes() decorator rewrites the docstring (.__doc__) of the decorated function to add information about shapes, in a format compatible with Sphinx. Only parameters that already have a :param ...: section will be modified.

For example:

@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
    """
    Computes a prediction from a linear model.

    :param features: Data to make predictions from.
    :param weights: Model weights.
    :returns: Model predictions.
    """
    prediction: AnyNDArray = np.einsum("...i,i -> ...", features, weights)
    return prediction

will have .__doc__:

Computes a prediction from a linear model.

:param features:
    * **features** has shape [*batch*..., *n_features*].

    Data to make predictions from.
:param weights:
    * **weights** has shape [*n_features*].

    Model weights.
:returns:
    * **return** has shape [*batch*...].

    Model predictions.

If you do not wish to have your docstrings rewritten, you can disable it with set_rewrite_docstrings():

set_rewrite_docstrings(DocstringFormat.NONE)

Supported types

This library has built-in support for checking the shapes of:

  • Python built-in scalars: bool, int, float and str.

  • Python sequences, including tuple and list.

  • NumPy ndarrays.

  • TensorFlow Tensors and Variables.

  • TensorFlow Probability DeferredTensors.

Shapes of custom types

check_shapes uses the function get_shape() to extract the shape of an object. get_shape() uses functools.singledispatch() to branch on the type of object to get the shape from, and you can extend this to extract shapes for your own custom types:

class LinearModel:
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @check_shapes(
        "self: [n_features]",
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction

@get_shape.register(LinearModel)
def get_linear_model_shape(model: LinearModel, context: ErrorContext) -> Shape:
    shape: Shape = model._weights.shape
    return shape

@check_shapes(
    "model: [n_features]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows]",
    "return: []",
)
def loss(model: LinearModel, test_features: AnyNDArray, test_labels: AnyNDArray) -> float:
    prediction = model.predict(test_features)
    return float(np.mean(np.sqrt(np.mean((prediction - test_labels) ** 2, axis=-1))))
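To see how singledispatch-based extension works in isolation, here is a stdlib-only toy in the same spirit. All names are hypothetical. Unlike get_shape(), toy_get_shape() takes no context argument and handles only Python built-ins plus one custom class. Computing a sequence's shape from its length and the shape of its first element is an assumption of this sketch.

```python
import functools
from typing import Any, Optional, Sequence, Tuple

Shape = Optional[Tuple[Optional[int], ...]]


class NoShapeError(Exception):
    """Stand-in for the library's NoShapeError in this sketch."""


@functools.singledispatch
def toy_get_shape(shaped: Any) -> Shape:
    # Fallback for types nobody has registered.
    raise NoShapeError(f"Objects of type {type(shaped).__name__} have no shape")


@toy_get_shape.register(int)  # bool is a subclass of int, so it dispatches here too
@toy_get_shape.register(float)
@toy_get_shape.register(str)
def _scalar_shape(shaped: Any) -> Shape:
    return ()


@toy_get_shape.register(list)
@toy_get_shape.register(tuple)
def _sequence_shape(shaped: Sequence[Any]) -> Shape:
    # Assumes a non-ragged sequence: the first element stands in for all of them.
    if not shaped:
        return (0,)
    inner = toy_get_shape(shaped[0])
    return None if inner is None else (len(shaped),) + inner


class ToyLinearModel:
    def __init__(self, weights: Sequence[float]) -> None:
        self._weights = weights


@toy_get_shape.register(ToyLinearModel)
def _model_shape(model: ToyLinearModel) -> Shape:
    # A model "has the shape" of its weights, as in get_linear_model_shape above.
    return toy_get_shape(model._weights)


print(toy_get_shape(ToyLinearModel([0.1, 0.2, 0.3])))
```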




class gpflow.experimental.check_shapes.DocstringFormat(value)

Bases: enum.Enum

Enumeration of supported formats of docstrings.

NONE = 'none'

Do not rewrite docstrings.

SPHINX = 'sphinx'

Rewrite docstrings in the Sphinx format.


class gpflow.experimental.check_shapes.ErrorContext

Bases: abc.ABC

A context in which an error can occur.

Contexts should be immutable and implement __eq__(), so that they can be composed using StackContext and ParallelContext.

Contexts are often created even if an error doesn’t actually occur, so they should be cheap to create: prefer to do any slow computation in print(), rather than in __init__().

Maybe think of an ErrorContext as a factory of error messages.
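The guidance above maps naturally onto frozen dataclasses: they are immutable, they get a generated __eq__, they are cheap to construct, and all formatting can wait until print(). The sketch below is hypothetical. ToyMessageBuilder (with an add_line method), ToyErrorContext and the context classes are our own stand-ins, not the library's MessageBuilder, StackContext or ParallelContext APIs.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple


class ToyMessageBuilder:
    """Hypothetical stand-in for MessageBuilder: it just collects lines."""

    def __init__(self) -> None:
        self.lines: List[str] = []

    def add_line(self, line: str) -> None:
        self.lines.append(line)


class ToyErrorContext(ABC):
    @abstractmethod
    def print(self, builder: ToyMessageBuilder) -> None:
        """Describe this context by adding lines to `builder`."""


@dataclass(frozen=True)  # immutable, with generated __eq__ (and __hash__)
class ArgumentContext(ToyErrorContext):
    name: str

    def print(self, builder: ToyMessageBuilder) -> None:
        builder.add_line(f"Argument: {self.name}")


@dataclass(frozen=True)
class ShapeContext(ToyErrorContext):
    shape: Tuple[int, ...]  # cheap to store; formatting is deferred to print()

    def print(self, builder: ToyMessageBuilder) -> None:
        builder.add_line("Actual shape: [" + ", ".join(str(d) for d in self.shape) + "]")


@dataclass(frozen=True)
class ToyStackContext(ToyErrorContext):
    parent: ToyErrorContext
    child: ToyErrorContext

    def print(self, builder: ToyMessageBuilder) -> None:
        self.parent.print(builder)
        self.child.print(builder)


context = ToyStackContext(ArgumentContext("features"), ShapeContext((5, 3)))
builder = ToyMessageBuilder()
context.print(builder)  # only now is any message text produced
print("\n".join(builder.lines))
```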

abstract print(builder)

Print this context to the given MessageBuilder.

Parameters

builder (MessageBuilder) –

Return type

None

class gpflow.experimental.check_shapes.ShapeChecker

Bases: object

Mechanism for checking the shapes of tensors.

This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.

Example:

def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
    checker = ShapeChecker()
    checker.check_shape(features, "[batch..., n_features]")
    checker.check_shape(weights, "[n_features]")
    prediction: AnyNDArray = checker.check_shape(
        np.einsum("...i,i -> ...", features, weights), "[batch...]"
    )
    return prediction

add_context(context)

Add arbitrary context to the shape checker.

This context will be included in any error messages.

Parameters

context (ErrorContext) – Context to add to this shape checker.

Return type

None

check_shape(shaped, tensor_spec, context=None)

Raise an error if a tensor has the wrong shape.

This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.

Parameters

  • shaped (TypeVar(S)) – The object whose shape to check.

  • tensor_spec (Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None]) – Specification to check the tensor against. Usually this is a str in the format described in Shape specification. Alternatively this can be a pre-parsed ParsedTensorSpec, or an actual Shape.

  • context (Optional[ErrorContext]) – Information about where shaped is coming from, for improved error messages.

Return type

TypeVar(S)

Returns

shaped, for convenience.


check_shapes(checks)

Raise an error if any tensor has the wrong shape.

This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.

Parameters

checks (Iterable[Union[Tuple[Any, Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None]], Tuple[Any, Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None], ErrorContext]]]) – Checks to perform. The elements can be either (shaped, tensor_spec) or (shaped, tensor_spec, context) tuples, where: shaped is the tensor whose shape to check; tensor_spec is the specification to check it against (see Shape specification); and context contains (optional) information about where shaped came from, for better error messages.

Return type

None



gpflow.experimental.check_shapes.check_shape(shaped, tensor_spec, context=None)

Raise an error if a tensor has the wrong shape.

This uses the ShapeChecker from the wrapping check_shapes() decorator. Behaviour is undefined if you call this from a function that is not directly wrapped in check_shapes() or inherit_check_shapes().

Example:

@check_shapes(
    "weights: [n_features, n_labels]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows, n_labels]",
    "return: []",
)
def loss(weights: AnyNDArray, test_features: AnyNDArray, test_labels: AnyNDArray) -> AnyNDArray:
    prediction = check_shape(test_features @ weights, "[n_rows, n_labels]")
    error: AnyNDArray = check_shape(prediction - test_labels, "[n_rows, n_labels]")
    square_error = check_shape(error ** 2, "[n_rows, n_labels]")
    mean_square_error = check_shape(np.mean(square_error, axis=-1), "[n_rows]")
    root_mean_square_error = check_shape(np.sqrt(mean_square_error), "[n_rows]")
    loss: AnyNDArray = np.mean(root_mean_square_error)
    return loss

Parameters

  • shaped (TypeVar(S)) – The object whose shape to check.

  • tensor_spec (Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None]) – Specification to check the tensor against. See: Shape specification.

  • context (Optional[ErrorContext]) – Information about where shaped is coming from, for improved error messages.

Return type

TypeVar(S)

Returns

shaped, for convenience.



gpflow.experimental.check_shapes.check_shapes(*specs)

Decorator that checks the shapes of tensor arguments.

Example:

import tensorflow as tf

from gpflow.experimental.check_shapes import check_shapes

@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    return tf.einsum("...i,i -> ...", features, weights)

Parameters

specs (str) – Specification of arguments to check. See: Check specification.

Return type

Callable[[TypeVar(C, bound= Callable[..., Any])], TypeVar(C, bound= Callable[..., Any])]



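gpflow.experimental.check_shapes.disable_check_shapes()

Context manager that temporarily disables shape checking.

Example:

with disable_check_shapes():
    performance_sensitive_function()  # placeholder for your own code

Return type

Iterator[None]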



gpflow.experimental.check_shapes.get_check_shapes(func)

Get the check_shapes that was applied to func.

Raises

ValueError – If no check_shapes was applied to func.

Parameters

func (Callable[..., Any]) –

Return type

Callable[[TypeVar(C, bound= Callable[..., Any])], TypeVar(C, bound= Callable[..., Any])]



gpflow.experimental.check_shapes.get_enable_check_shapes()

Get whether to enable check_shapes.

Return type

bool



gpflow.experimental.check_shapes.get_rewrite_docstrings()

Get how check_shapes should rewrite docstrings.

Return type

DocstringFormat


gpflow.experimental.check_shapes.get_shape(shaped, context)

Returns the shape of the given object.

Parameters

  • shaped (Any) – The object whose shape to extract.

  • context (ErrorContext) – Context we are getting the shape in, for improved error messages.

Return type

Optional[Tuple[Optional[int], ...]]

Returns

The shape of shaped, or None if the shape exists, but is unknown.

Raises

NoShapeError – If objects of this type do not have shapes.



gpflow.experimental.check_shapes.get_shape_checker()

Get the ShapeChecker from the wrapping check_shapes() decorator.

Behaviour is undefined if you call this from a function that is not directly wrapped in check_shapes() or inherit_check_shapes().

Return type

ShapeChecker



gpflow.experimental.check_shapes.inherit_check_shapes(func)

Decorator that inherits the check_shapes() decoration from any overridden method in a super-class.

Example:

class Model(ABC):
    @abstractmethod
    @check_shapes(
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        pass

class LinearModel(Model):
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @inherit_check_shapes
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction

See: Class inheritance.

Parameters

func (TypeVar(C, bound= Callable[..., Any])) –

Return type

TypeVar(C, bound= Callable[..., Any])



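gpflow.experimental.check_shapes.set_enable_check_shapes(enabled)

Set whether to enable check_shapes.

Shape checking has a non-zero impact on performance. If this is unacceptable to you, you can use this function to disable it.

Example:

set_enable_check_shapes(False)

See also disable_check_shapes().

Parameters

enabled (bool) –

Return type

None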



gpflow.experimental.check_shapes.set_rewrite_docstrings(docstring_format)

Set how check_shapes should rewrite docstrings.

See DocstringFormat for valid choices.

Parameters

docstring_format (Union[DocstringFormat, str, None]) –

Return type

None