check_shapes.checker

Class responsible for remembering and checking shapes.

Module Contents

class check_shapes.checker._ObservedDim

Storage of observed size of a single dimension variable.

check_and_update(actual, broadcast)

Attempt to merge new data into this observation.

Returns whether the new data is compatible with existing observations. If this method returns False, this object may have been left in an invalid state and should not be used again.

Parameters:
  • actual (Optional[int]) –

  • broadcast (bool) –

Return type:

bool
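The merge rule for a single dimension can be pictured with a small standalone sketch. It is not the library's implementation. It assumes that actual=None means the size is unknown (and therefore always compatible), and that with broadcast=True an observed size of 1 is compatible with any size, as in NumPy broadcasting. The class name ObservedDimSketch is hypothetical.

```python
from typing import Optional


class ObservedDimSketch:
    """Hypothetical stand-in for _ObservedDim: remembers one dimension size."""

    def __init__(self) -> None:
        # None until a concrete, non-broadcast size has been observed.
        self.size: Optional[int] = None

    def check_and_update(self, actual: Optional[int], broadcast: bool) -> bool:
        if actual is None:
            # Unknown size: nothing to compare against and nothing to learn.
            return True
        if broadcast and actual == 1:
            # Assumption: a broadcast size of 1 is compatible with any size.
            return True
        if self.size is None:
            # First concrete observation: record it.
            self.size = actual
            return True
        return self.size == actual


dim = ObservedDimSketch()
assert dim.check_and_update(3, broadcast=False)
assert dim.check_and_update(1, broadcast=True)  # broadcasts against 3
assert not dim.check_and_update(4, broadcast=False)
```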

class check_shapes.checker._ObservedDims

Storage of observed sizes of a var-rank / batch variable.

check_and_update(actual, broadcast, shape_possibly_truncated)

Attempt to merge new data into this observation.

Returns whether the new data is compatible with existing observations. If this method returns False, this object may have been left in an invalid state and should not be used again.

Parameters:
  • actual (Optional[Tuple[Optional[int], Ellipsis]]) –

  • broadcast (bool) –

  • shape_possibly_truncated (bool) –

Return type:

bool
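The var-rank case can be sketched the same way. This is a hypothetical illustration, not the library's code, and it rests on assumptions: shapes are compared right-aligned (as in NumPy broadcasting); shape_possibly_truncated=True means the observed tuple may be missing leading dimensions, so a shorter tuple is accepted; and each element is merged with the same per-dimension rule as for a single dimension. The names ObservedDimsSketch and _merge_one are invented for this sketch.

```python
from typing import List, Optional, Tuple


def _merge_one(known: Optional[int], new: Optional[int], broadcast: bool) -> Tuple[bool, Optional[int]]:
    """Merge one dimension; returns (compatible, updated known size)."""
    if new is None or (broadcast and new == 1):
        return True, known
    if known is None:
        return True, new
    return known == new, known


class ObservedDimsSketch:
    """Hypothetical stand-in for _ObservedDims: remembers a var-rank shape."""

    def __init__(self) -> None:
        self.dims: Optional[List[Optional[int]]] = None  # None until first observation
        self.rank_exact = False  # True once a non-truncated observation fixes the rank

    def check_and_update(
        self,
        actual: Optional[Tuple[Optional[int], ...]],
        broadcast: bool,
        shape_possibly_truncated: bool,
    ) -> bool:
        if actual is None:
            return True  # unknown rank: nothing to check
        new = list(actual)
        if self.dims is None:
            self.dims = [None] * len(new)
        if len(new) > len(self.dims):
            if self.rank_exact:
                return False  # more dimensions than the known rank
            # Earlier observations were truncated: pad with unknown leading dims.
            self.dims = [None] * (len(new) - len(self.dims)) + self.dims
        elif len(new) < len(self.dims) and not shape_possibly_truncated:
            return False  # a complete shape cannot be shorter than the known rank
        if not shape_possibly_truncated:
            self.rank_exact = True
        # Right-align the new shape against the stored one and merge element-wise.
        offset = len(self.dims) - len(new)
        for i, size in enumerate(new):
            ok, merged = _merge_one(self.dims[offset + i], size, broadcast)
            if not ok:
                return False
            self.dims[offset + i] = merged
        return True
```

A truncated observation such as (2, 3) followed by a complete (5, 2, 3) leaves dims == [5, 2, 3]; after that, a complete (2, 3) is rejected because the rank is already fixed at 3.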

class check_shapes.checker._ShapeCheck

A shape check that is waiting to be performed.

property finished: bool

Whether this entire check has been performed.

Return type:

bool

actual: _Shape

Actual observed shape.

Only actual[actual_begin:actual_end] is still waiting to be checked. Dimensions before actual_begin and from actual_end onward may already have been checked.

expected: check_shapes.specs.ParsedShapeSpec

Specification to check against.

Only expected[expected_begin:expected_end] is still waiting to be checked. Dimensions before expected_begin and from expected_end onward may already have been checked.
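One way to picture the begin/end windows is to match fixed-rank dimensions inward from both ends until a var-rank dimension blocks progress. If exactly one var-rank dimension is left, it takes whatever actual dimensions remain; if several are left, the windows stay open and the check has to wait. This is a hypothetical sketch: the (name, is_varrank) spec format and the function match_from_both_ends are invented for illustration, and mismatch detection is omitted.

```python
from typing import Dict, List, Tuple

# Hypothetical simplified spec: each expected dimension is (name, is_varrank).
DimSpec = Tuple[str, bool]


def match_from_both_ends(
    expected: List[DimSpec], actual: Tuple[int, ...]
) -> Tuple[Dict[str, Tuple[int, ...]], int, int, int, int]:
    """Bind fixed-rank dims at both ends; return bindings and the unchecked windows."""
    bindings: Dict[str, Tuple[int, ...]] = {}
    e_begin, e_end = 0, len(expected)
    a_begin, a_end = 0, len(actual)
    # Consume fixed-rank dimensions from the front.
    while e_begin < e_end and a_begin < a_end and not expected[e_begin][1]:
        bindings[expected[e_begin][0]] = (actual[a_begin],)
        e_begin += 1
        a_begin += 1
    # Consume fixed-rank dimensions from the back.
    while e_begin < e_end and a_begin < a_end and not expected[e_end - 1][1]:
        bindings[expected[e_end - 1][0]] = (actual[a_end - 1],)
        e_end -= 1
        a_end -= 1
    # A single remaining var-rank dimension takes everything that is left.
    if e_end - e_begin == 1 and expected[e_begin][1]:
        bindings[expected[e_begin][0]] = actual[a_begin:a_end]
        e_begin, a_begin = e_end, a_end
    return bindings, e_begin, e_end, a_begin, a_end


print(match_from_both_ends([("n", False), ("batch", True), ("d", False)], (2, 3, 4, 5)))
# ({'n': (2,), 'd': (5,), 'batch': (3, 4)}, 2, 2, 3, 3)
```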

class check_shapes.checker._VariableState

State that must be tracked for each variable.

uses: List[Tuple[check_shapes.specs.ParsedTensorSpec, check_shapes.error_contexts.ErrorContext]]

List of specs where this variable is used.

observed_dim: Optional[_ObservedDim]

Observed size of this variable, if the variable is rank-1.

This is None if the variable is varrank.

observed_dims: Optional[_ObservedDims]

Observed shape of this variable, if the variable is varrank.

This is None if the variable is rank-1.

waiting_for_varrank: Set[_ShapeCheck]

Checks that are waiting for the rank of this variable to be determined.

class check_shapes.checker.ShapeChecker

Mechanism for checking the shapes of tensors.

This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.

Example:


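import numpy as np
from numpy import ndarray

from check_shapes.checker import ShapeChecker


def linear_model(features: ndarray, weights: ndarray) -> ndarray:
    checker = ShapeChecker()
    checker.check_shape(features, "[batch..., n_features]")
    # n_features must match the size already observed in features.
    checker.check_shape(weights, "[n_features]")
    # check_shape returns its argument, so the result can be checked inline.
    prediction: ndarray = checker.check_shape(
        np.einsum("...i,i->...", features, weights), "[batch...]"
    )
    return prediction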
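The key idea, that a checker remembers variable sizes across calls, can be shown with a toy version. MiniShapeChecker and ShapeMismatch below are hypothetical and much simpler than ShapeChecker: they support only fixed-rank specs of names or integer literals (no "..." var-rank dimensions, no broadcasting), and to stay dependency-free they take shapes as plain tuples rather than arrays.

```python
from typing import Dict, Tuple


class ShapeMismatch(Exception):
    """Hypothetical error type for this sketch (the library raises ShapeMismatchError)."""


class MiniShapeChecker:
    """Toy checker: remembers variable sizes across calls; fixed-rank specs only."""

    def __init__(self) -> None:
        self.sizes: Dict[str, int] = {}

    def check_shape(self, shape: Tuple[int, ...], spec: str) -> Tuple[int, ...]:
        names = [d.strip() for d in spec.strip().strip("[]").split(",") if d.strip()]
        if len(names) != len(shape):
            raise ShapeMismatch(f"expected rank {len(names)} for {spec}, got shape {shape}")
        for name, size in zip(names, shape):
            if name.isdigit():
                expected = int(name)  # literal size, e.g. "[3]"
            else:
                expected = self.sizes.setdefault(name, size)  # bind on first sight
            if expected != size:
                raise ShapeMismatch(f"{name}: expected {expected}, got {size} in {spec}")
        return shape  # returned for convenience, like check_shape


checker = MiniShapeChecker()
checker.check_shape((32, 5), "[batch, n_features]")
checker.check_shape((5,), "[n_features]")
try:
    checker.check_shape((4,), "[n_features]")
except ShapeMismatch as e:
    print(e)  # n_features: expected 5, got 4 in [n_features]
```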

add_context(context)

Add arbitrary context to the shape checker.

This context will be included in any error messages.

Parameters:

context (check_shapes.error_contexts.ErrorContext) – Context to add to this shape checker.

Return type:

None

check_shape(shaped, tensor_spec, context=None)

Raise an error if a tensor has the wrong shape.

This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.

Parameters:
  • shaped (S) – The object whose shape to check.

  • tensor_spec (TensorSpecLike) – Specification to check the tensor against. Usually this is a str in the format described under “Shape specification” in our User Guide. Alternatively this can be a pre-parsed ParsedTensorSpec, or an actual Shape.

  • context (Optional[check_shapes.error_contexts.ErrorContext]) – Information about where shaped is coming from, for improved error messages.

Returns:

shaped, for convenience.

Return type:

S

check_shapes(checks)

Raise an error if any tensor has the wrong shape.

This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.

Parameters:

checks (Iterable[Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, check_shapes.error_contexts.ErrorContext]]]) – Checks to perform. Each element is either a (shaped, tensor_spec) or a (shaped, tensor_spec, context) tuple, where shaped is the tensor whose shape to check; tensor_spec is the specification to check it against (see “Shape specification” in our User Guide); and context contains optional information about where shaped came from, for better error messages.

Return type:

None

_parse_checks(checks)

Sanity-check, register, and parse the given checks into _ShapeCheck objects.

Parameters:

checks (Iterable[Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, check_shapes.error_contexts.ErrorContext]]]) –

Return type:

List[_ShapeCheck]
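The accepted check formats (two- or three-element tuples) suggest a normalization step before parsing. The sketch below shows only that step, under the assumption that a missing context is represented as None; normalize_checks is a hypothetical helper, and how the library actually validates and registers checks is not shown here.

```python
from typing import Any, Iterable, List, Optional, Tuple

# Hypothetical uniform form: (shaped, tensor_spec, context or None).
Check = Tuple[Any, Any, Optional[Any]]


def normalize_checks(checks: Iterable[tuple]) -> List[Check]:
    """Turn (shaped, spec) and (shaped, spec, context) tuples into uniform triples."""
    result: List[Check] = []
    for i, check in enumerate(checks):
        if len(check) == 2:
            shaped, spec = check
            context = None
        elif len(check) == 3:
            shaped, spec, context = check
        else:
            raise ValueError(f"check #{i} must have 2 or 3 elements, got {len(check)}")
        result.append((shaped, spec, context))
    return result


print(normalize_checks([("x", "[a]"), ("y", "[b]", "ctx")]))
# [('x', '[a]', None), ('y', '[b]', 'ctx')]
```

Returning a list also means a one-shot iterable such as a generator is consumed exactly once.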

_match_dims(shape_check)

Match expected dimensions against actual dimensions.

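If some dimensions cannot be determined yet (for example, because the rank of a var-rank variable is unknown), the shape check is added to _VariableState.waiting_for_varrank so that its remaining dimensions can be matched later.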

Parameters:

shape_check (_ShapeCheck) –

Return type:

Iterable[Tuple[check_shapes.specs.ParsedDimensionSpec, _Shape, bool]]
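Deferred matching can be illustrated with a worklist. In the hypothetical sketch below, every expected dimension is a var-rank variable, so a check can be split only when at most one of its variables has an unknown value. Otherwise the check is parked under each unknown variable, in the spirit of waiting_for_varrank, and re-queued once one of those variables is learned, much as _check_dim adds newly enabled checks to shape_checks. The function resolve and the check format are invented for illustration; the sketch assumes each variable appears at most once per check, and checks still parked at the end are simply left unchecked.

```python
from collections import deque
from typing import Deque, Dict, List, Set, Tuple

# Hypothetical simplified check: (variable names, actual shape); all variables are var-rank.
Check = Tuple[Tuple[str, ...], Tuple[int, ...]]


def resolve(checks: List[Check]) -> Dict[str, Tuple[int, ...]]:
    known: Dict[str, Tuple[int, ...]] = {}
    waiting: Dict[str, Set[int]] = {}  # variable -> indices of checks blocked on it
    queue: Deque[int] = deque(range(len(checks)))
    while queue:
        idx = queue.popleft()
        names, actual = checks[idx]
        unknown = {n for n in names if n not in known}
        if len(unknown) > 1:
            # Ambiguous split: park the check under every unknown variable.
            for n in unknown:
                waiting.setdefault(n, set()).add(idx)
            continue
        known_rank = sum(len(known[n]) for n in names if n in known)
        free = len(actual) - known_rank  # dimensions left for the unknown variable
        if free < 0 or (not unknown and free != 0):
            raise ValueError(f"rank mismatch in check {idx}")
        pos = 0
        for n in names:
            width = len(known[n]) if n in known else free
            piece = actual[pos:pos + width]
            if n in known:
                if known[n] != piece:
                    raise ValueError(f"{n}: expected {known[n]}, got {piece}")
            else:
                known[n] = piece
                # Newly learned value: re-queue every check that was waiting on n.
                queue.extend(waiting.pop(n, set()))
            pos += width
    return known


print(resolve([(("a", "b"), (1, 2, 3)), (("a",), (1,))]))
# {'a': (1,), 'b': (2, 3)}
```

Here the first check cannot be split until the second one fixes a to (1,); it is then re-queued, which determines b.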

_check_dim(expected, actual_dims, shape_possibly_truncated, shape_checks)

Check that actual_dims matches expected.

Newly learned information may enable the evaluation of deferred shape checks; any such checks are added to shape_checks.

Parameters:
  • expected (check_shapes.specs.ParsedDimensionSpec) –

  • actual_dims (_Shape) –

  • shape_possibly_truncated (bool) –

  • shape_checks (List[_ShapeCheck]) –

Return type:

None

_assert(condition)

Raise a nicely formatted ShapeMismatchError if condition is not True.

Parameters:

condition (bool) –

Return type:

None