check_shapes.checker
Class responsible for remembering and checking shapes.
Module Contents
- class check_shapes.checker._ObservedDim
Storage of observed size of a single dimension variable.
- check_and_update(actual, broadcast)
Attempt to merge new data into this observation.
Returns whether the new data is compatible with existing observations. If this method returns False, this object may have been left in an invalid state and should not be used again.
- Parameters:
actual (Optional[int]) –
broadcast (bool) –
- Return type:
bool
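The merge behaviour described above can be illustrated with a small stand-alone sketch. This is not the library's implementation: the class `ObservedDimSketch` is hypothetical, and its rules (an unknown size `None` is always compatible; under broadcasting a size of 1 is compatible with any size) are assumptions made for illustration.

```python
from typing import Optional


class ObservedDimSketch:
    """Hypothetical stand-in for a single-dimension observation (illustration only)."""

    def __init__(self) -> None:
        self.size: Optional[int] = None  # None means "nothing observed yet"

    def check_and_update(self, actual: Optional[int], broadcast: bool) -> bool:
        if actual is None:
            # Unknown size: nothing to learn and nothing to contradict.
            return True
        if broadcast and actual == 1:
            # Assumed rule: a size-1 dimension broadcasts against any size.
            return True
        if self.size is None:
            # First concrete observation: remember it.
            self.size = actual
            return True
        # Later observations must agree with the remembered size.
        return self.size == actual


dim = ObservedDimSketch()
results = [
    dim.check_and_update(3, broadcast=False),
    dim.check_and_update(None, broadcast=False),
    dim.check_and_update(1, broadcast=True),
    dim.check_and_update(4, broadcast=False),
]
print(results)
```

In this sketch only the last call reports an incompatibility, because 4 contradicts the remembered size 3.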
- class check_shapes.checker._ObservedDims
Storage of observed sizes of a var-rank / batch variable.
- check_and_update(actual, broadcast, shape_possibly_truncated)
Attempt to merge new data into this observation.
Returns whether the new data is compatible with existing observations. If this method returns False, this object may have been left in an invalid state and should not be used again.
- Parameters:
actual (Optional[Tuple[Optional[int], Ellipsis]]) –
broadcast (bool) –
shape_possibly_truncated (bool) –
- Return type:
bool
- class check_shapes.checker._ShapeCheck
A shape check that is waiting to be performed.
- property finished: bool
Whether this entire check has been performed.
- Return type:
bool
- actual: _Shape
Actual observed shape.
Only actual[actual_begin:actual_end] is still waiting to be checked. The beginning and end may already have been checked.
- expected: check_shapes.specs.ParsedShapeSpec
Specification to check against.
Only expected[expected_begin:expected_end] is still waiting to be checked. The beginning and end may already have been checked.
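The begin/end bookkeeping described for `actual` and `expected` can be sketched as a pair of shrinking windows. The class `ShapeCheckSketch` below is hypothetical: it uses plain dimension names instead of a parsed specification, and the rule used for `finished` is an assumption, not the library's definition.

```python
from typing import Optional, Sequence, Tuple


class ShapeCheckSketch:
    """Hypothetical stand-in: tracks which slice of a shape is still unchecked."""

    def __init__(self, actual: Sequence[Optional[int]], expected: Sequence[str]) -> None:
        self.actual = tuple(actual)
        self.expected = tuple(expected)
        # Initially the whole shape and the whole specification are pending.
        self.actual_begin, self.actual_end = 0, len(self.actual)
        self.expected_begin, self.expected_end = 0, len(self.expected)

    @property
    def finished(self) -> bool:
        # Assumed rule: done once no expected dimensions remain unchecked.
        return self.expected_begin >= self.expected_end

    def pending_actual(self) -> Tuple[Optional[int], ...]:
        return self.actual[self.actual_begin : self.actual_end]


check = ShapeCheckSketch(actual=(32, 7, 5), expected=("batch...", "n_features"))

# Check the trailing fixed dimension by shrinking both windows from the end.
check.actual_end -= 1
check.expected_end -= 1
pending = check.pending_actual()
finished_midway = check.finished

# Resolve the var-rank part, which consumes everything that is left.
check.expected_begin += 1
check.actual_begin = check.actual_end
print(pending, finished_midway, check.finished)
```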
- class check_shapes.checker._VariableState
The state tracked for each variable.
- uses: List[Tuple[check_shapes.specs.ParsedTensorSpec, check_shapes.error_contexts.ErrorContext]]
List of specs where this variable is used.
- observed_dim: Optional[_ObservedDim]
Observed size of this variable, if the variable is rank-1.
Set this to None if this variable is var-rank.
- observed_dims: Optional[_ObservedDims]
Observed shape of this variable, if the variable is var-rank.
Set this to None if this variable is rank-1.
- waiting_for_varrank: Set[_ShapeCheck]
Checks that are waiting for the rank of this variable to be determined.
- class check_shapes.checker.ShapeChecker
Mechanism for checking the shapes of tensors.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
Example:
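    import numpy as np
    from numpy import ndarray

    from check_shapes.checker import ShapeChecker


    def linear_model(features: ndarray, weights: ndarray) -> ndarray:
        checker = ShapeChecker()
        # The leading batch dimensions are var-rank; the last one is n_features.
        checker.check_shape(features, "[batch..., n_features]")
        checker.check_shape(weights, "[n_features]")
        # check_shape returns its first argument, so it can wrap an expression.
        prediction: ndarray = checker.check_shape(
            np.einsum("...i,i -> ...", features, weights), "[batch...]"
        )
        return prediction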
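How sizes are remembered across calls can be illustrated with a toy checker. `TinyShapeChecker` is hypothetical and far simpler than `ShapeChecker`: it takes plain shape tuples rather than tensors, supports only fixed-rank specifications such as `[n_rows, n_features]` (no var-rank `batch...`), and raises `ValueError` rather than `ShapeMismatchError`.

```python
from typing import Dict, Sequence, Tuple


class TinyShapeChecker:
    """Toy checker (illustration only): fixed-rank specs like "[n_rows, n_cols]"."""

    def __init__(self) -> None:
        self.sizes: Dict[str, int] = {}  # variable name -> first observed size

    def check_shape(self, shape: Sequence[int], spec: str) -> Tuple[int, ...]:
        names = [name.strip() for name in spec.strip("[] ").split(",")]
        if len(names) != len(shape):
            raise ValueError(f"rank mismatch: expected {spec}, got {tuple(shape)}")
        for name, size in zip(names, shape):
            # Remember the first observation; later ones must agree with it.
            if self.sizes.setdefault(name, size) != size:
                raise ValueError(
                    f"{name} was {self.sizes[name]} earlier, but is {size} in {spec}"
                )
        return tuple(shape)


checker = TinyShapeChecker()
checker.check_shape((100, 3), "[n_rows, n_features]")
checker.check_shape((3,), "[n_features]")  # consistent with the first call
try:
    checker.check_shape((4,), "[n_features]")  # contradicts n_features == 3
    mismatch = None
except ValueError as error:
    mismatch = str(error)
print(mismatch)
```

The third call fails only because the same checker instance saw `n_features` with a different size earlier; a fresh instance would accept it.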
- add_context(context)
Add arbitrary context to the shape checker.
This context will be included in any error messages.
- Parameters:
context (check_shapes.error_contexts.ErrorContext) – Context to add to this shape checker.
- Return type:
None
- check_shape(shaped, tensor_spec, context=None)
Raise an error if a tensor has the wrong shape.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
- Parameters:
shaped (S) – The object whose shape to check.
tensor_spec (TensorSpecLike) – Specification to check the tensor against. Usually this is a str in the format described under “Shape specification” in our User Guide. Alternatively this can be a pre-parsed ParsedTensorSpec, or an actual Shape.
context (Optional[check_shapes.error_contexts.ErrorContext]) – Information about where shaped is coming from, for improved error messages.
- Returns:
shaped, for convenience.
- Return type:
S
- check_shapes(checks)
Raise an error if any tensor has the wrong shape.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
- Parameters:
checks (Iterable[Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, check_shapes.error_contexts.ErrorContext]]]) – Checks to perform. The elements can either be (shaped, tensor_spec) or (shaped, tensor_spec, context) tuples, where: shaped is the tensor whose shape to check; tensor_spec is the specification to check it against (see “Shape specification” in our User Guide); and context contains (optional) information about where shaped came from, for better error messages.
- Return type:
None
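The two accepted tuple forms can be brought into one uniform shape before any checking happens. The helper `normalize_checks` below is hypothetical and only illustrates that idea; it uses a plain string where the library expects an `ErrorContext`.

```python
from typing import Any, Iterable, List, Optional, Tuple


def normalize_checks(
    checks: Iterable[Tuple[Any, ...]]
) -> List[Tuple[Any, Any, Optional[Any]]]:
    """Hypothetical helper: turn 2- and 3-tuples into uniform 3-tuples."""
    normalized = []
    for check in checks:
        if len(check) == 2:
            shaped, tensor_spec = check
            context = None  # no context supplied for this check
        elif len(check) == 3:
            shaped, tensor_spec, context = check
        else:
            raise ValueError(f"expected a 2- or 3-tuple, got {len(check)} elements")
        normalized.append((shaped, tensor_spec, context))
    return normalized


# Plain tuples stand in for tensors; the string context is an assumption.
normalized = normalize_checks(
    [
        ((100, 3), "[n_rows, n_features]"),
        ((3,), "[n_features]", "weights argument"),
    ]
)
print(normalized)
```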
- _parse_checks(checks)
Sanity check, register and parse the given checks into _ShapeCheck objects.
- Parameters:
checks (Iterable[Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, check_shapes.error_contexts.ErrorContext]]]) –
- Return type:
List[_ShapeCheck]
- _match_dims(shape_check)
Match expected dimensions against actual dimensions.
If some dimensions cannot be determined, the remaining dimensions are added to _VariableState.waiting_for_varrank.
- Parameters:
shape_check (_ShapeCheck) –
- Return type:
Iterable[Tuple[check_shapes.specs.ParsedDimensionSpec, _Shape, bool]]
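One way to picture the matching step is to pair fixed dimensions from both ends of the shape and give whatever remains in the middle to the var-rank variable. The function `match_dims_sketch` is hypothetical: it works on dimension names rather than parsed specifications, omits the boolean element of the real return type, and simply raises where the deferral described above would be needed.

```python
from typing import List, Sequence, Tuple


def match_dims_sketch(
    expected: Sequence[str], actual: Sequence[int]
) -> List[Tuple[str, Tuple[int, ...]]]:
    """Pair each expected dimension name with the actual sizes it covers."""
    varrank = [i for i, name in enumerate(expected) if name.endswith("...")]
    if len(varrank) > 1:
        # The real checker would defer such checks; this sketch just gives up.
        raise NotImplementedError("more than one var-rank variable")
    n_fixed = len(expected) - len(varrank)
    if len(actual) < n_fixed or (not varrank and len(actual) != n_fixed):
        raise ValueError(f"cannot match {list(expected)} against {tuple(actual)}")
    if not varrank:
        return [(name, (size,)) for name, size in zip(expected, actual)]

    i = varrank[0]
    split = len(actual) - (len(expected) - i - 1)  # where the trailing dims start
    matched = [(name, (size,)) for name, size in zip(expected[:i], actual[:i])]
    matched.append((expected[i], tuple(actual[i:split])))  # var-rank takes the middle
    matched += [(name, (size,)) for name, size in zip(expected[i + 1 :], actual[split:])]
    return matched


pairs = match_dims_sketch(["batch...", "n_features"], (32, 7, 5))
print(pairs)
```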
- _check_dim(expected, actual_dims, shape_possibly_truncated, shape_checks)
Checks that actual_dims matches expected.
Newly learned information may enable the evaluation of deferred shape checks; any such checks will be added to shape_checks.
- Parameters:
expected (check_shapes.specs.ParsedDimensionSpec) –
actual_dims (_Shape) –
shape_possibly_truncated (bool) –
shape_checks (List[_ShapeCheck]) –
- Return type:
None
- _assert(condition)
Raise a nicely formatted ShapeMismatchError if condition is not True.
- Parameters:
condition (bool) –
- Return type:
None