gpflow.experimental.check_shapes#
A library for annotating and checking the shapes of tensors.
The main entry point is check_shapes().
Example:
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes

@tf.function
@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    return tf.einsum("...i,i -> ...", features, weights)
Speed, and interactions with tf.function#
Shape checking has some performance impact; to alleviate this, it can be disabled. Shape checking can be set to one of three states:
- ENABLED. Shapes are checked wherever they can be.
- EAGER_MODE_ONLY. Shapes are not checked within anything wrapped in tf.function().
- DISABLED. Shapes are never checked.
The state can be set with set_enable_check_shapes():
set_enable_check_shapes(ShapeCheckingState.DISABLED)
Alternatively you can use disable_check_shapes() to disable shape checking in smaller scopes:
with disable_check_shapes():
    performance_sensitive_function()
Beware that any function declared while shape checking is disabled will continue not to check shapes, even if shape checking is later enabled again.
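For example, a minimal sketch of this pitfall (length is a hypothetical function, declared only for illustration):

with disable_check_shapes():

    @check_shapes(
        "v: [d]",
    )
    def length(v: AnyNDArray) -> float:
        return float(np.sqrt(np.sum(v ** 2)))

set_enable_check_shapes(ShapeCheckingState.ENABLED)
length(np.ones((3, 4)))  # Still not checked: `length` was declared while checking was disabled.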
The default state is EAGER_MODE_ONLY, which is appropriate for smaller projects, experiments, and notebooks. Write and debug your code in eager mode, and add tf.function() when you believe your code is correct and you want it to run fast. For larger projects you probably want to modify this setting. In particular you may want to enable all shape checks in your unit tests. If you use pytest you can do this by updating your root conftest.py with:
@pytest.fixture(autouse=True)
def enable_shape_checks() -> Iterable[None]:
    old_enable = get_enable_check_shapes()
    old_rewrite_docstrings = get_rewrite_docstrings()
    old_function_call_precompute = get_enable_function_call_precompute()

    set_enable_check_shapes(ShapeCheckingState.ENABLED)
    set_rewrite_docstrings(DocstringFormat.SPHINX)
    set_enable_function_call_precompute(True)

    yield

    set_enable_function_call_precompute(old_function_call_precompute)
    set_rewrite_docstrings(old_rewrite_docstrings)
    set_enable_check_shapes(old_enable)
If shape checking is set to ENABLED and your code is wrapped in tf.function(), shape checks are performed while tracing graphs, but not compiled into the actual graphs. This is considered a feature, as it means that check_shapes() doesn't impact the execution speed of your functions after they have been compiled.
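A small sketch of what this means in practice:

set_enable_check_shapes(ShapeCheckingState.ENABLED)

@tf.function
@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    return tf.einsum("...i,i -> ...", features, weights)

# The first call traces the function, and the shape checks run during tracing.
linear_model(tf.ones((4, 3)), tf.ones((3,)))
# Later calls with the same input signature reuse the compiled graph, so the
# checks add no per-call overhead.
linear_model(tf.ones((4, 3)), tf.ones((3,)))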
Best-effort checking#
This library will perform shape checks on a best-effort basis. Many things can prevent this library from being able to check shapes. For example:
- Unknown shapes. Sometimes the library is not able to determine the shape of an object, and thus cannot check that object. For example, Optional arguments with value None cannot be checked, and compiled TensorFlow code can have variables with an unknown shape.
- Use of variable-rank dimensions (see below). In general we cannot infer the size of variable-rank dimensions if there are multiple variable-rank specifications within the same shape specification (e.g. cov: [m..., n...]). This library will try to learn the size of these variable-rank dimensions from neighbouring shape specifications, but this is not always possible; see the sketch after this list. Use of broadcast with variable-rank dimensions makes it even harder to infer these values.
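For example, here is a minimal sketch (with hypothetical names) where the sizes of m... and n... in the return specification can be learned from the two input specifications:

@check_shapes(
    "a: [m...]",
    "b: [n...]",
    "return: [m..., n...]",
)
def outer_sum(a: AnyNDArray, b: AnyNDArray) -> AnyNDArray:
    # Broadcasting a reshaped `a` against `b` yields shape m... + n...
    return a.reshape(a.shape + (1,) * b.ndim) + b

Had the return value been the only place m... and n... appeared, the library would have no way to decide where one ends and the other begins.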
Check specification#
The shapes to check are specified by the arguments to check_shapes(). Each argument is a string of the format:
<argument specifier> ":" <shape specifier> ["if" <condition>] ["#" <note>]
Argument specification#
The <argument specifier> must start with either the name of an argument to the decorated function, or the special name return. The value return refers to the value returned by the function.
The <argument specifier> can then be modified to refer to elements of the object in several ways:
- Use .<name> to refer to an attribute of an object:

  @dataclass
  class Statistics:
      mean: AnyNDArray
      std: AnyNDArray

  @check_shapes(
      "data: [n_rows, n_columns]",
      "return.mean: [n_columns]",
      "return.std: [n_columns]",
  )
  def compute_statistics(data: AnyNDArray) -> Statistics:
      return Statistics(np.mean(data, axis=0), np.std(data, axis=0))

- Use [<index>] to refer to a specific element of a sequence. This is particularly useful if your function returns a tuple of values:

  @check_shapes(
      "data: [n_rows, n_columns]",
      "return[0]: [n_columns]",
      "return[1]: [n_columns]",
  )
  def compute_mean_and_std(data: AnyNDArray) -> Tuple[AnyNDArray, AnyNDArray]:
      return np.mean(data, axis=0), np.std(data, axis=0)

- Use [all] to select all elements of a collection:

  @check_shapes(
      "data[all]: [., n_columns]",
      "return: [., n_columns]",
  )
  def concat_rows(data: Sequence[AnyNDArray]) -> AnyNDArray:
      return np.concatenate(data, axis=0)

  concat_rows(
      [
          np.ones((1, 3)),
          np.ones((4, 3)),
      ]
  )

- Use .keys() to select all keys of a mapping:

  @check_shapes(
      "data.keys(): [.]",
      "return: []",
  )
  def sum_key_lengths(data: Mapping[Tuple[int, ...], str]) -> int:
      return sum(len(k) for k in data)

  sum_key_lengths(
      {
          (3,): "foo",
          (1, 2): "bar",
      }
  )

- Use .values() to select all values of a mapping:

  @check_shapes(
      "data.values(): [., n_columns]",
      "return: [., n_columns]",
  )
  def concat_rows(data: Mapping[str, AnyNDArray]) -> AnyNDArray:
      return np.concatenate(list(data.values()), axis=0)

  concat_rows(
      {
          "foo": np.ones((1, 3)),
          "bar": np.ones((4, 3)),
      }
  )
Note
We do not support looking up a specific key or value in a dict.
If the argument, or any of the looked-up values, is None, the check is skipped. This is useful for optional values:
@check_shapes(
    "x1: [n_rows_1, n_inputs]",
    "x2: [n_rows_2, n_inputs]",
    "return: [n_rows_1, n_rows_2]",
)
def squared_exponential_kernel(
    variance: float, x1: AnyNDArray, x2: Optional[AnyNDArray] = None
) -> AnyNDArray:
    if x2 is None:
        x2 = x1
    cov: AnyNDArray = variance * np.exp(
        -0.5 * np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=2)
    )
    return cov

squared_exponential_kernel(1.0, np.ones((3, 2)), np.ones((4, 2)))
squared_exponential_kernel(3.2, np.ones((3, 2)))
Shape specification#
Shapes are specified by the syntax:
"[" <dimension specifier 1> "," <dimension specifer 2> "," ... "," <dimension specifier n> "]"
where <dimension specifier i> is one of:
- <integer>, to require that dimension to have that exact size:

  @check_shapes(
      "v1: [2]",
      "v2: [2]",
  )
  def vector_2d_distance(v1: AnyNDArray, v2: AnyNDArray) -> float:
      return float(np.sqrt(np.sum((v1 - v2) ** 2)))

- <name>, to bind that dimension to a variable. Dimensions bound to the same variable must have the same size, though that size can be anything:

  @check_shapes(
      "v1: [d]",
      "v2: [d]",
  )
  def vector_distance(v1: AnyNDArray, v2: AnyNDArray) -> float:
      return float(np.sqrt(np.sum((v1 - v2) ** 2)))

- None or ., to allow exactly one single dimension without constraints:

  @check_shapes(
      "v: [None]",
  )
  def vector_length(v: AnyNDArray) -> float:
      return float(np.sqrt(np.sum(v ** 2)))

  or:

  @check_shapes(
      "v: [.]",
  )
  def vector_length(v: AnyNDArray) -> float:
      return float(np.sqrt(np.sum(v ** 2)))

- *<name> or <name>..., to bind any number of dimensions to a variable. Again, multiple uses of the same variable name must match the same dimension sizes:

  @check_shapes(
      "x: [*batch, n_columns]",
      "return: [*batch]",
  )
  def batch_mean(x: AnyNDArray) -> AnyNDArray:
      mean: AnyNDArray = np.mean(x, axis=-1)
      return mean

  or:

  @check_shapes(
      "x: [batch..., n_columns]",
      "return: [batch...]",
  )
  def batch_mean(x: AnyNDArray) -> AnyNDArray:
      mean: AnyNDArray = np.mean(x, axis=-1)
      return mean

- * or ..., to allow any number of dimensions without constraints:

  @check_shapes(
      "x: [*]",
  )
  def rank(x: AnyNDArray) -> int:
      return len(x.shape)

  or:

  @check_shapes(
      "x: [...]",
  )
  def rank(x: AnyNDArray) -> int:
      return len(x.shape)
A scalar shape is specified by []:
@check_shapes(
    "x: [...]",
    "return: []",
)
def mean(x: AnyNDArray) -> AnyNDArray:
    mean: AnyNDArray = np.sum(x) / x.size
    return mean
Any of the above can be prefixed with the keyword broadcast to allow any value that broadcasts to the specification. For example:
@check_shapes(
    "a: [broadcast batch...]",
    "b: [broadcast batch...]",
    "return: [batch...]",
)
def add(a: AnyNDArray, b: AnyNDArray) -> AnyNDArray:
    return a + b
Specifically, to mark a dimension as broadcast means:

- If the specification is that the dimension should have size n, then the actual dimension must have size 1 or n.
- If all leading dimension specifications are also marked broadcast, then the actual shape is allowed to be shorter than the specification; that is, the dimension is allowed to be missing entirely. See the sketch after this list.
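For example, a minimal sketch where the second argument has fewer dimensions than batch...:

@check_shapes(
    "a: [broadcast batch...]",
    "b: [broadcast batch...]",
    "return: [batch...]",
)
def add(a: AnyNDArray, b: AnyNDArray) -> AnyNDArray:
    return a + b

# `b` has rank 1 while `batch...` is bound to rank 2. This is allowed,
# because all of b's leading dimension specifications are marked `broadcast`.
add(np.ones((2, 3)), np.ones((3,)))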
Condition specification#
You can use the optional if <condition> syntax to conditionally evaluate shapes. If an if <condition> is used, the specification is only applied if <condition> evaluates to True.
This is useful if shapes depend on other input parameters. Valid conditions are:
- <argument specifier>, with the same syntax and rules as above, except that constructions that evaluate to multiple elements are disallowed. Uses the bool built-in to convert the value of the argument to a bool:

  @check_shapes(
      "a: [broadcast batch...] if check_a",
      "b: [broadcast batch...] if check_b",
      "return: [batch...]",
  )
  def add(a: AnyNDArray, b: AnyNDArray, check_a: bool = True, check_b: bool = True) -> AnyNDArray:
      return a + b

  add(np.ones((3, 1)), np.ones((1, 4)), check_b=False)
- <argument specifier> is None, and <argument specifier> is not None, with the usual rules for an <argument specifier>, to test whether an argument is, or is not, None. We currently only allow tests against None, not general Python equality tests:

  @check_shapes(
      "a: [n_a]",
      "b: [n_b]",
      "return: [n_a, n_a] if b is None",
      "return: [n_a, n_b] if b is not None",
  )
  def square(a: AnyNDArray, b: Optional[AnyNDArray] = None) -> AnyNDArray:
      if b is None:
          b = a
      result: AnyNDArray = a[:, None] * b[None, :]
      return result

  square(np.ones((3,)))
  square(np.ones((3,)), np.ones((4,)))
- <left> or <right>, which evaluates to True if either <left> or <right> evaluates to True, and to False otherwise:

  @check_shapes(
      "a: [broadcast batch...] if check_all or check_a",
      "b: [broadcast batch...] if check_all or check_b",
      "return: [batch...]",
  )
  def add(
      a: AnyNDArray,
      b: AnyNDArray,
      check_all: bool = False,
      check_a: bool = True,
      check_b: bool = True,
  ) -> AnyNDArray:
      return a + b

  add(np.ones((3, 1)), np.ones((1, 4)), check_b=False)
- <left> and <right>, which evaluates to False if either <left> or <right> evaluates to False, and to True otherwise:

  @check_shapes(
      "a: [broadcast batch...] if enable_checks and check_a",
      "b: [broadcast batch...] if enable_checks and check_b",
      "return: [batch...]",
  )
  def add(
      a: AnyNDArray,
      b: AnyNDArray,
      enable_checks: bool = True,
      check_a: bool = True,
      check_b: bool = True,
  ) -> AnyNDArray:
      return a + b

  add(np.ones((3, 1)), np.ones((1, 4)), check_b=False)
- not <right>, which evaluates to the opposite value of <right>:

  @check_shapes(
      "a: [broadcast batch...] if not disable_checks",
      "b: [broadcast batch...] if not disable_checks",
      "return: [batch...]",
  )
  def add(a: AnyNDArray, b: AnyNDArray, disable_checks: bool = False) -> AnyNDArray:
      return a + b

  add(np.ones((3, 1)), np.ones((1, 4)))
- (<exp>), which uses parentheses to change operator precedence, as usual.
Conditions can be composed to apply different specs, depending on function arguments:
@check_shapes(
    "a: [j] if a_vector",
    "a: [i, j] if (not a_vector)",
    "b: [j] if b_vector",
    "b: [j, k] if (not b_vector)",
    "return: [1, 1] if a_vector and b_vector",
    "return: [1, k] if a_vector and (not b_vector)",
    "return: [i, 1] if (not a_vector) and b_vector",
    "return: [i, k] if (not a_vector) and (not b_vector)",
)
def multiply(a: AnyNDArray, b: AnyNDArray, a_vector: bool, b_vector: bool) -> AnyNDArray:
    if a_vector:
        a = a[None, :]
    if b_vector:
        b = b[:, None]
    return a @ b

multiply(np.ones((4,)), np.ones((4, 5)), a_vector=True, b_vector=False)
Note
All specifications with either no if syntax, or a <condition> that evaluates to True, will be applied. It is possible for multiple specifications to apply to the same value.
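For example, in this minimal sketch (with a hypothetical strict flag), both specifications apply to x when strict is True, so x must satisfy both:

@check_shapes(
    "x: [n, n]",
    "x: [2, 2] if strict",
)
def trace(x: AnyNDArray, strict: bool = False) -> float:
    return float(np.trace(x))

trace(np.ones((3, 3)))               # Only the [n, n] specification applies.
trace(np.ones((2, 2)), strict=True)  # Both specifications apply.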
Note specification#
You can add notes to your specifications using a # followed by the note. These notes will be appended to relevant error messages and appear in rewritten docstrings. You can add notes in two places:

- On a single line by itself, to add a note to the entire function:

  @check_shapes(
      "features: [batch..., n_features]",
      "# linear_model currently only supports a single output.",
      "weights: [n_features]",
      "return: [batch...]",
  )
  def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
      prediction: AnyNDArray = np.einsum("...i,i -> ...", features, weights)
      return prediction

- After the specification of a single argument, to add a note to that argument only:

  @check_shapes(
      "features: [batch..., n_features]",
      "weights: [n_features] # linear_model currently only supports a single output.",
      "return: [batch...]",
  )
  def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
      prediction: AnyNDArray = np.einsum("...i,i -> ...", features, weights)
      return prediction
Shape reuse#
Just like with other code, it is useful to be able to specify a shape in one place and reuse the specification. In particular, this ensures that your code keeps its shapes internally consistent, even as it is refactored.
Class inheritance#
If you have a class hierarchy, you probably want to ensure that derived classes handle tensors with the same shapes as the base classes. You can use the inherit_check_shapes() decorator to inherit shapes from overridden methods:
class Model(ABC):
    @abstractmethod
    @check_shapes(
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        pass


class LinearModel(Model):
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @inherit_check_shapes
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction
Functional programming#
If you prefer functional over object-oriented programming, you may have functions that you require to handle the same shapes. To do this, remember that in Python a decorator is just a function, and functions are objects that can be stored:
check_metric_shapes = check_shapes(
    "actual: [n_rows, n_labels]",
    "predicted: [n_rows, n_labels]",
    "return: []",
)

@check_metric_shapes
def rmse(actual: AnyNDArray, predicted: AnyNDArray) -> float:
    return float(np.mean(np.sqrt(np.mean((predicted - actual) ** 2, axis=-1))))

@check_metric_shapes
def mape(actual: AnyNDArray, predicted: AnyNDArray) -> float:
    return float(np.mean(np.abs((predicted - actual) / actual)))
Other reuse of shapes#
You can use get_check_shapes() to get, and reuse, the shape definitions from a previously declared function. This is particularly useful to ensure fakes in tests use the same shapes as the production implementation:
class Model(ABC):
    @abstractmethod
    @check_shapes(
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        pass


@check_shapes(
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows]",
)
def evaluate_model(model: Model, test_features: AnyNDArray, test_labels: AnyNDArray) -> float:
    prediction = model.predict(test_features)
    return float(np.mean(np.sqrt(np.mean((prediction - test_labels) ** 2, axis=-1))))


def test_evaluate_model() -> None:
    fake_features = np.ones((10, 3))
    fake_labels = np.ones((10,))
    fake_predictions = np.ones((10,))

    @get_check_shapes(Model.predict)
    def fake_predict(features: AnyNDArray) -> AnyNDArray:
        assert features is fake_features
        return fake_predictions

    fake_model = MagicMock(spec=Model, predict=fake_predict)

    assert pytest.approx(0.0) == evaluate_model(fake_model, fake_features, fake_labels)
Checking intermediate results#
You can use the function check_shape() to check the shape of an intermediate result. This function will use the same namespace as the immediately surrounding check_shapes() decorator, and should only be called within functions that have such a decorator. For example:
@check_shapes(
    "weights: [n_features, n_labels]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows, n_labels]",
    "return: []",
)
def loss(weights: AnyNDArray, test_features: AnyNDArray, test_labels: AnyNDArray) -> AnyNDArray:
    prediction = check_shape(test_features @ weights, "[n_rows, n_labels]")
    error: AnyNDArray = check_shape(prediction - test_labels, "[n_rows, n_labels]")
    square_error = check_shape(error ** 2, "[n_rows, n_labels]")
    mean_square_error = check_shape(np.mean(square_error, axis=-1), "[n_rows]")
    root_mean_square_error = check_shape(np.sqrt(mean_square_error), "[n_rows]")
    loss: AnyNDArray = np.mean(root_mean_square_error)
    return loss
Checking shapes without a decorator#
While the check_shapes() decorator is the recommended way to use this library, it is possible to use it without the decorator. In fact the decorator is just a wrapper around the class ShapeChecker, which can be used to check shapes directly:
def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
    checker = ShapeChecker()
    checker.check_shape(features, "[batch..., n_features]")
    checker.check_shape(weights, "[n_features]")
    prediction: AnyNDArray = checker.check_shape(
        np.einsum("...i,i -> ...", features, weights), "[batch...]"
    )
    return prediction
You can use the function get_shape_checker() to get the ShapeChecker used by any immediately surrounding check_shapes() decorator.
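For example, a minimal sketch that uses the decorator's own checker to verify an intermediate value (equivalent to calling check_shape() directly):

@check_shapes(
    "x: [n]",
    "return: [n]",
)
def cumulative(x: AnyNDArray) -> AnyNDArray:
    checker = get_shape_checker()
    total: AnyNDArray = checker.check_shape(np.cumsum(x), "[n]")
    return total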
Documenting shapes#
The check_shapes() decorator rewrites the docstring (.__doc__) of the decorated function to add information about shapes, in a format compatible with Sphinx.
Only functions that already have a docstring will be updated. Functions that have no docstring at all will not have one added; this is so that we do not override a docstring that would have been inherited from a super class.
For example:
@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
    """
    Computes a prediction from a linear model.

    :param features: Data to make predictions from.
    :param weights: Model weights.
    :returns: Model predictions.
    """
    prediction: AnyNDArray = np.einsum("...i,i -> ...", features, weights)
    return prediction
will have .__doc__:
"""
Computes a prediction from a linear model.
:param features:
* **features** has shape [*batch*..., *n_features*].
Data to make predictions from.
:param weights:
* **weights** has shape [*n_features*].
Model weights.
:returns:
* **return** has shape [*batch*...].
Model predictions.
"""
If you do not wish to have your docstrings rewritten, you can disable it with set_rewrite_docstrings():
set_rewrite_docstrings(None)
Supported types#
This library has built-in support for checking the shapes of:
- Python built-in scalars: bool, int, float and str.
- Python built-in sequences: tuple and list.
- NumPy ndarrays.
- TensorFlow Tensors and Variables.
- TensorFlow Probability DeferredTensors, including TransformedVariable and gpflow.Parameter.
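For instance, here is a minimal sketch assuming nested Python sequences are given their natural rectangular shape:

@check_shapes(
    "rows: [n_rows, n_columns]",
    "return: [n_columns]",
)
def column_sums(rows: Sequence[Sequence[float]]) -> AnyNDArray:
    total: AnyNDArray = np.sum(np.asarray(rows), axis=0)
    return total

column_sums([[1.0, 2.0], [3.0, 4.0]])  # Checked as shape [2, 2].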
Shapes of custom types#
check_shapes uses the function get_shape() to extract the shape of an object. You can use register_get_shape() to extend get_shape() to extract shapes for your own custom types:
class LinearModel:
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @check_shapes(
        "self: [n_features]",
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction


@register_get_shape(LinearModel)
def get_linear_model_shape(model: LinearModel, context: ErrorContext) -> Shape:
    shape: Shape = model._weights.shape
    return shape


@check_shapes(
    "model: [n_features]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows]",
    "return: []",
)
def loss(model: LinearModel, test_features: AnyNDArray, test_labels: AnyNDArray) -> float:
    prediction = model.predict(test_features)
    return float(np.mean(np.sqrt(np.mean((prediction - test_labels) ** 2, axis=-1))))
Modules#
- gpflow.experimental.check_shapes.accessors
- gpflow.experimental.check_shapes.argument_ref
- gpflow.experimental.check_shapes.bool_specs
- gpflow.experimental.check_shapes.checker_context
- gpflow.experimental.check_shapes.decorator
- gpflow.experimental.check_shapes.error_contexts
- gpflow.experimental.check_shapes.exceptions
- gpflow.experimental.check_shapes.parser
- gpflow.experimental.check_shapes.shapes
- gpflow.experimental.check_shapes.specs
Classes#
gpflow.experimental.check_shapes.DocstringFormat#
gpflow.experimental.check_shapes.ErrorContext#
- class gpflow.experimental.check_shapes.ErrorContext[source]#
Bases: ABC
A context in which an error can occur.
Contexts should be immutable, and implement __eq__() - so that they can be composed using StackContext and ParallelContext.
The contexts are often created even if an error doesn't actually occur, so they should be cheap to create - prefer to do any slow computation in print(), rather than in __init__().
Maybe think of an ErrorContext as a factory of error messages.
- abstract print(builder)[source]#
Print this context to the given MessageBuilder.
- Parameters:
builder (MessageBuilder) –
- Return type:
None
gpflow.experimental.check_shapes.ShapeChecker#
- class gpflow.experimental.check_shapes.ShapeChecker[source]#
Bases: object
Mechanism for checking the shapes of tensors.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
Example:
def linear_model(features: AnyNDArray, weights: AnyNDArray) -> AnyNDArray:
    checker = ShapeChecker()
    checker.check_shape(features, "[batch..., n_features]")
    checker.check_shape(weights, "[n_features]")
    prediction: AnyNDArray = checker.check_shape(
        np.einsum("...i,i -> ...", features, weights), "[batch...]"
    )
    return prediction
- add_context(context)[source]#
Add arbitrary context to the shape checker.
This context will be included in any error messages.
- Parameters:
context (ErrorContext) – Context to add to this shape checker.
- Return type:
None
- check_shape(shaped, tensor_spec, context=None)[source]#
Raise an error if a tensor has the wrong shape.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
- Parameters:
shaped (TypeVar(S)) – The object whose shape to check.
tensor_spec (Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None]) – Specification to check the tensor against. Usually this is a str in the format described in Shape specification. Alternatively this can be a pre-parsed ParsedTensorSpec, or an actual Shape.
context (Optional[ErrorContext]) – Information about where shaped is coming from, for improved error messages.
- Return type:
TypeVar(S)
- Returns:
shaped, for convenience.
- check_shapes(checks)[source]#
Raise an error if any tensor has the wrong shape.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
- Parameters:
checks (Iterable[Union[Tuple[Any, Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None]], Tuple[Any, Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None], ErrorContext]]]) – Checks to perform. The elements can either be (shaped, tensor_spec) or (shaped, tensor_spec, context) tuples, where: shaped is the tensor whose shape to check; tensor_spec is the specification to check it against (see Shape specification); and context contains (optional) information about where shaped came from - for better error messages.
- Return type:
None
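Example (a minimal sketch, checking several tensors in one call using the documented tuple format):

features = np.ones((4, 3))
weights = np.ones((3,))

checker = ShapeChecker()
checker.check_shapes(
    [
        (features, "[batch..., n_features]"),
        (weights, "[n_features]"),
    ]
)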
gpflow.experimental.check_shapes.ShapeCheckingState#
- class gpflow.experimental.check_shapes.ShapeCheckingState(value)[source]#
Bases: Enum
Different states of whether to actually check shapes.
- DISABLED = 'disabled'#
Never check shapes.
- EAGER_MODE_ONLY = 'eager_mode_only'#
Only check shapes if tf.inside_function() is False.
- ENABLED = 'enabled'#
Always check shapes.
Functions#
gpflow.experimental.check_shapes.check_shape#
- gpflow.experimental.check_shapes.check_shape(shaped, tensor_spec, context=None)[source]#
Raise an error if a tensor has the wrong shape.
This uses the ShapeChecker from the wrapping check_shapes() decorator. Behaviour is undefined if you call this from a function that is not directly wrapped in check_shapes() or inherit_check_shapes().
Example:
@check_shapes(
    "weights: [n_features, n_labels]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows, n_labels]",
    "return: []",
)
def loss(weights: AnyNDArray, test_features: AnyNDArray, test_labels: AnyNDArray) -> AnyNDArray:
    prediction = check_shape(test_features @ weights, "[n_rows, n_labels]")
    error: AnyNDArray = check_shape(prediction - test_labels, "[n_rows, n_labels]")
    square_error = check_shape(error ** 2, "[n_rows, n_labels]")
    mean_square_error = check_shape(np.mean(square_error, axis=-1), "[n_rows]")
    root_mean_square_error = check_shape(np.sqrt(mean_square_error), "[n_rows]")
    loss: AnyNDArray = np.mean(root_mean_square_error)
    return loss
- Parameters:
shaped (TypeVar(S)) – The object whose shape to check.
tensor_spec (Union[ParsedTensorSpec, str, Tuple[Optional[int], ...], None]) – Specification to check the tensor against. See: Shape specification.
context (Optional[ErrorContext]) – Information about where shaped is coming from, for improved error messages.
- Return type:
TypeVar(S)
- Returns:
shaped, for convenience.
gpflow.experimental.check_shapes.check_shapes#
- gpflow.experimental.check_shapes.check_shapes(*specs, tf_decorator=False)[source]#
Decorator that checks the shapes of tensor arguments.
Example:
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes

@tf.function
@check_shapes(
    "features: [batch..., n_features]",
    "weights: [n_features]",
    "return: [batch...]",
)
def linear_model(features: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    return tf.einsum("...i,i -> ...", features, weights)
- Parameters:
specs (str) – Specification of arguments to check. See: Check specification.
tf_decorator (bool) – Whether to wrap the shape check with tf.compat.v1.flags.tf_decorator.make_decorator. Setting this True seems to solve some problems, particularly related to Keras models, but creates some other problems, particularly related to branching on tensors.
- Return type:
Callable[[TypeVar(C, bound=Callable[..., Any])], TypeVar(C, bound=Callable[..., Any])]
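A minimal sketch of setting the flag; whether it helps depends on your setup, as noted above:

@check_shapes(
    "features: [batch..., n_features]",
    "return: [batch...]",
    tf_decorator=True,
)
def predict(features: tf.Tensor) -> tf.Tensor:
    return tf.reduce_sum(features, axis=-1)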
gpflow.experimental.check_shapes.disable_check_shapes#
gpflow.experimental.check_shapes.get_check_shapes#
- gpflow.experimental.check_shapes.get_check_shapes(func)[source]#
Get the check_shapes that was applied to func.
- Raises:
ValueError – If no check_shapes was applied to func.
- Parameters:
func (Callable[..., Any]) –
- Return type:
Callable[[TypeVar(C, bound=Callable[..., Any])], TypeVar(C, bound=Callable[..., Any])]
gpflow.experimental.check_shapes.get_enable_check_shapes#
gpflow.experimental.check_shapes.get_enable_function_call_precompute#
gpflow.experimental.check_shapes.get_rewrite_docstrings#
gpflow.experimental.check_shapes.get_shape#
- gpflow.experimental.check_shapes.get_shape(shaped, context)[source]#
Returns the shape of the given object.
See also register_get_shape().
- Parameters:
shaped (Any) – The object whose shape to extract.
context (ErrorContext) – Context we are getting the shape in, for improved error messages.
- Return type:
Optional[Tuple[Optional[int], ...]]
- Returns:
The shape of shaped, or None if the shape exists, but is unknown.
- Raises:
NoShapeError – If objects of this type do not have shapes.
gpflow.experimental.check_shapes.get_shape_checker#
- gpflow.experimental.check_shapes.get_shape_checker()[source]#
Get the ShapeChecker from the wrapping check_shapes() decorator.
Behaviour is undefined if you call this from a function that is not directly wrapped in check_shapes() or inherit_check_shapes().
- Return type:
ShapeChecker
gpflow.experimental.check_shapes.inherit_check_shapes#
- gpflow.experimental.check_shapes.inherit_check_shapes(func)[source]#
Decorator that inherits the check_shapes() decoration from any overridden method in a super-class.
Example:
class Model(ABC):
    @abstractmethod
    @check_shapes(
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        pass


class LinearModel(Model):
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @inherit_check_shapes
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction
See: Class inheritance.
- Parameters:
func (TypeVar(C, bound=Callable[..., Any])) –
- Return type:
TypeVar(C, bound=Callable[..., Any])
gpflow.experimental.check_shapes.register_get_shape#
- gpflow.experimental.check_shapes.register_get_shape(shape_type)[source]#
Register a function for extracting the shape from a given type of objects.
Example:
class LinearModel:
    @check_shapes(
        "weights: [n_features]",
    )
    def __init__(self, weights: AnyNDArray) -> None:
        self._weights = weights

    @check_shapes(
        "self: [n_features]",
        "features: [batch..., n_features]",
        "return: [batch...]",
    )
    def predict(self, features: AnyNDArray) -> AnyNDArray:
        prediction: AnyNDArray = np.einsum("...i,i -> ...", features, self._weights)
        return prediction


@register_get_shape(LinearModel)
def get_linear_model_shape(model: LinearModel, context: ErrorContext) -> Shape:
    shape: Shape = model._weights.shape
    return shape


@check_shapes(
    "model: [n_features]",
    "test_features: [n_rows, n_features]",
    "test_labels: [n_rows]",
    "return: []",
)
def loss(model: LinearModel, test_features: AnyNDArray, test_labels: AnyNDArray) -> float:
    prediction = model.predict(test_features)
    return float(np.mean(np.sqrt(np.mean((prediction - test_labels) ** 2, axis=-1))))
See also get_shape().
- Parameters:
shape_type (Type[Any]) – Type of objects to extract shapes from.
- Return type:
Callable[[Callable[[Any, ErrorContext], Optional[Tuple[Optional[int], ...]]]], Callable[[Any, ErrorContext], Optional[Tuple[Optional[int], ...]]]]
gpflow.experimental.check_shapes.set_enable_check_shapes#
- gpflow.experimental.check_shapes.set_enable_check_shapes(enabled)[source]#
Set whether to enable check_shapes.
Shape checking has a non-zero impact on performance. If this is unacceptable to you, you can use this function to disable it.
Example:
set_enable_check_shapes(ShapeCheckingState.DISABLED)
See also disable_check_shapes().
- Parameters:
enabled (Union[ShapeCheckingState, str, bool]) –
- Return type:
None
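Given the parameter type above, the enum's string values should also be accepted; presumably a plain bool selects between the fully enabled and disabled states:

set_enable_check_shapes("eager_mode_only")  # Same as ShapeCheckingState.EAGER_MODE_ONLY.
set_enable_check_shapes(False)              # Presumably equivalent to ShapeCheckingState.DISABLED.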
gpflow.experimental.check_shapes.set_enable_function_call_precompute#
- gpflow.experimental.check_shapes.set_enable_function_call_precompute(enabled)[source]#
Set whether to precompute function call path and line numbers for debugging.
This is disabled by default, because it is (relatively) slow. Enabling this can give better error messages.
Example:
set_enable_function_call_precompute(True)
buggy_function()
- Parameters:
enabled (bool) –
- Return type:
None
gpflow.experimental.check_shapes.set_rewrite_docstrings#
- gpflow.experimental.check_shapes.set_rewrite_docstrings(docstring_format)[source]#
Set how check_shapes should rewrite docstrings.
See DocstringFormat for valid choices.
- Parameters:
docstring_format (Union[DocstringFormat, str, None]) –
- Return type:
None