check_shapes.checker
Class responsible for remembering and checking shapes.
Module Contents
- class check_shapes.checker._ObservedDim
Storage of observed size of a single dimension variable.
- check_and_update(actual, broadcast)
Attempt to merge new data into this observation.
Returns whether the new data is compatible with existing observations. If this method returns False, this object may have been left in an invalid state and should not be used again.
- Parameters:
actual (Optional[int]) –
broadcast (bool) –
- Return type:
bool
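To make the merge semantics concrete, here is a minimal sketch of what check_and_update might do for a single dimension. It is not the library's implementation: the class name ObservedDimSketch is hypothetical, and it assumes that an actual size of None means "unknown" and that, when broadcast is True, a size of 1 is compatible with any observed size.

```python
from typing import Optional


class ObservedDimSketch:
    """Hypothetical stand-in for _ObservedDim: remembers one dimension size."""

    def __init__(self) -> None:
        self.size: Optional[int] = None  # None means "not observed yet".

    def check_and_update(self, actual: Optional[int], broadcast: bool) -> bool:
        if actual is None:
            # Unknown sizes carry no information, so they are always compatible.
            return True
        if broadcast and actual == 1:
            # Assumption: a broadcast dimension of size 1 matches any size.
            return True
        if self.size is None:
            self.size = actual  # First observation: remember it.
            return True
        return self.size == actual
```

With these assumptions, observing 3 and then a broadcast 1 succeeds and keeps the size at 3, while a later plain 4 is reported as incompatible.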
- class check_shapes.checker._ObservedDims
Storage of observed sizes of a var-rank / batch variable.
- check_and_update(actual, broadcast, shape_possibly_truncated)
Attempt to merge new data into this observation.
Returns whether the new data is compatible with existing observations. If this method returns False, this object may have been left in an invalid state and should not be used again.
- Parameters:
actual (Optional[Tuple[Optional[int], Ellipsis]]) –
broadcast (bool) –
shape_possibly_truncated (bool) –
- Return type:
bool
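A similar sketch for a var-rank variable follows. ObservedDimsSketch is hypothetical, and it assumes two things not stated above: a shape that is possibly truncated may be missing leading dimensions (as with NumPy-style broadcasting), so only the overlapping trailing dimensions are compared; and, with broadcast set, a size of 1 yields to the other observation.

```python
from typing import Optional, Tuple

Dims = Tuple[Optional[int], ...]


class ObservedDimsSketch:
    """Hypothetical stand-in for _ObservedDims: remembers a var-rank shape."""

    def __init__(self) -> None:
        self.dims: Optional[Dims] = None  # None means "not observed yet".

    def check_and_update(
        self, actual: Optional[Dims], broadcast: bool, shape_possibly_truncated: bool
    ) -> bool:
        if actual is None:
            return True  # Unknown shape: nothing to compare.
        if self.dims is None:
            self.dims = actual
            return True
        old, new = self.dims, actual
        if len(old) != len(new) and not shape_possibly_truncated:
            return False  # Different ranks are only allowed for truncated shapes.
        # Assumption: truncation drops *leading* dims, so align from the right.
        merged = list(old if len(old) >= len(new) else new)
        for i in range(1, min(len(old), len(new)) + 1):
            a, b = old[-i], new[-i]
            if b is None or (broadcast and b == 1):
                merged[-i] = a  # Nothing new learned; keep the old size.
            elif a is None or (broadcast and a == 1):
                merged[-i] = b  # Learn (or widen to) the new size.
            elif a != b:
                return False
        self.dims = tuple(merged)
        return True
```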
- class check_shapes.checker._ShapeCheck
A shape check that is waiting to be performed.
- property finished: bool
Whether this entire check has been performed.
- Return type:
bool
- actual: _Shape
Actual observed shape.
Only actual[actual_begin:actual_end] is still waiting to be checked. The beginning and end may already have been checked.
- expected: check_shapes.specs.ParsedShapeSpec
Specification to check against.
Only expected[expected_begin:expected_end] is still waiting to be checked. The beginning and end may already have been checked.
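The begin/end bookkeeping can be pictured with the hypothetical dataclass below. It assumes the spec is a tuple of dimension names and that a check is finished once both pending slices are empty; the real class works with parsed specs rather than strings.

```python
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class ShapeCheckSketch:
    """Hypothetical stand-in for _ShapeCheck with string dimension names."""

    actual: Tuple[Optional[int], ...]
    expected: Tuple[str, ...]
    # Only actual[actual_begin:actual_end] and expected[expected_begin:expected_end]
    # are still pending; everything outside those slices has been checked.
    actual_begin: int = field(init=False)
    actual_end: int = field(init=False)
    expected_begin: int = field(init=False)
    expected_end: int = field(init=False)

    def __post_init__(self) -> None:
        self.actual_begin, self.actual_end = 0, len(self.actual)
        self.expected_begin, self.expected_end = 0, len(self.expected)

    @property
    def finished(self) -> bool:
        # Assumption: done once nothing is pending on either side.
        return (
            self.expected_begin >= self.expected_end
            and self.actual_begin >= self.actual_end
        )

    def pop_front(self) -> Tuple[str, Optional[int]]:
        """Pair the first pending expected dim with the first pending actual dim."""
        pair = (self.expected[self.expected_begin], self.actual[self.actual_begin])
        self.expected_begin += 1
        self.actual_begin += 1
        return pair
```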
- class check_shapes.checker._VariableState
Structure holding everything we need to know about a single variable.
- uses: List[Tuple[check_shapes.specs.ParsedTensorSpec, check_shapes.error_contexts.ErrorContext]]
List of specs where this variable is used.
- observed_dim: Optional[_ObservedDim]
Observed size of this variable, if the variable is rank-1.
This is None if the variable is varrank.
- observed_dims: Optional[_ObservedDims]
Observed shape of this variable, if the variable is varrank.
This is None if the variable is rank-1.
- waiting_for_varrank: Set[_ShapeCheck]
Checks that are waiting for the rank of this variable to be determined.
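As an illustration, the per-variable state could be kept in a dictionary keyed by variable name. VariableStateSketch is hypothetical and simplified: uses holds spec strings rather than (ParsedTensorSpec, ErrorContext) pairs, observed sizes are plain ints and tuples, and deferred checks are identified by ints.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import DefaultDict, List, Optional, Set, Tuple


@dataclass
class VariableStateSketch:
    """Hypothetical, simplified stand-in for _VariableState."""

    uses: List[str] = field(default_factory=list)  # Specs mentioning the variable.
    observed_dim: Optional[int] = None  # Used if the variable is rank-1.
    observed_dims: Optional[Tuple[Optional[int], ...]] = None  # Used if varrank.
    waiting_for_varrank: Set[int] = field(default_factory=set)  # Deferred check ids.

    def is_consistent(self) -> bool:
        # At most one of the two observation slots may be in use.
        return self.observed_dim is None or self.observed_dims is None


# One state per variable name, created on first use.
states: DefaultDict[str, VariableStateSketch] = defaultdict(VariableStateSketch)
states["n_features"].uses.append("[batch..., n_features]")
states["n_features"].observed_dim = 10
states["batch"].observed_dims = (32,)
```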
- class check_shapes.checker.ShapeChecker
Mechanism for checking the shapes of tensors.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
Example:

    import numpy as np
    from numpy import ndarray

    from check_shapes.checker import ShapeChecker


    def linear_model(features: ndarray, weights: ndarray) -> ndarray:
        checker = ShapeChecker()
        checker.check_shape(features, "[batch..., n_features]")
        checker.check_shape(weights, "[n_features]")
        prediction: ndarray = checker.check_shape(
            np.einsum("...i,i -> ...", features, weights), "[batch...]"
        )
        return prediction

- add_context(context)
Add arbitrary context to the shape checker.
This context will be included in any error messages.
- Parameters:
context (check_shapes.error_contexts.ErrorContext) – Context to add to this shape checker.
- Return type:
None
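A minimal sketch of the idea, assuming contexts are plain strings rather than ErrorContext objects; the class ContextCheckerSketch and its error method are hypothetical.

```python
from typing import List


class ContextCheckerSketch:
    """Hypothetical sketch: stored contexts are appended to every error message."""

    def __init__(self) -> None:
        self._contexts: List[str] = []

    def add_context(self, context: str) -> None:
        self._contexts.append(context)

    def error(self, message: str) -> ValueError:
        # Build (not raise) an exception whose message lists every added context.
        lines = [message] + [f"  context: {c}" for c in self._contexts]
        return ValueError("\n".join(lines))
```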
- check_shape(shaped, tensor_spec, context=None)
Raise an error if a tensor has the wrong shape.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
- Parameters:
shaped (S) – The object whose shape to check.
tensor_spec (TensorSpecLike) – Specification to check the tensor against. Usually this is a str in the format described under “Shape specification” in our User Guide. Alternatively, this can be a pre-parsed ParsedTensorSpec, or an actual Shape.
context (Optional[check_shapes.error_contexts.ErrorContext]) – Information about where shaped is coming from, for improved error messages.
- Returns:
shaped, for convenience.
- Return type:
S
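The toy checker below illustrates how remembering sizes across calls lets a later call catch an inconsistency, and why returning shaped is convenient. It is a hypothetical sketch that only understands fixed-rank specs like "[n, 3]" (no var-rank dimensions, broadcasting or contexts), and FakeTensor is a stand-in for a real array.

```python
from typing import Dict, Tuple, TypeVar

S = TypeVar("S")


class FakeTensor:
    """Hypothetical stand-in for an array: only carries a .shape tuple."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


class MiniChecker:
    """Hypothetical sketch of check_shape for fixed-rank specs such as "[n, 3]"."""

    def __init__(self) -> None:
        self.sizes: Dict[str, int] = {}  # Remembered across calls.

    def check_shape(self, shaped: S, tensor_spec: str) -> S:
        shape: Tuple[int, ...] = tuple(shaped.shape)  # type: ignore[attr-defined]
        names = [d.strip() for d in tensor_spec.strip("[]").split(",") if d.strip()]
        if len(names) != len(shape):
            raise ValueError(f"rank mismatch: {shape} vs {tensor_spec}")
        for name, size in zip(names, shape):
            # Integer literals are fixed; names take the first size seen for them.
            expected = int(name) if name.isdigit() else self.sizes.setdefault(name, size)
            if expected != size:
                raise ValueError(f"{name}: expected {expected}, got {size}")
        return shaped  # Returned for convenience, mirroring check_shape.
```

Because the argument is returned, a call can wrap an expression inline, as the linear_model example does around np.einsum.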
- check_shapes(checks)
Raise an error if any tensor has the wrong shape.
This remembers observed shapes and specifications, so that tensors can be checked for compatibility across multiple calls, and so that we can provide good error messages.
- Parameters:
checks (Iterable[Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, check_shapes.error_contexts.ErrorContext]]]) – Checks to perform. The elements can either be
(shaped, tensor_spec) or (shaped, tensor_spec, context) tuples, where: shaped is the tensor whose shape to check; tensor_spec is the specification to check it against (see “Shape specification” in our User Guide); and context contains (optional) information about where shaped came from, for better error messages.
- Return type:
None
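The two accepted tuple forms can be normalized before processing. normalize_checks is a hypothetical helper written for illustration, not part of check_shapes.

```python
from typing import Any, Iterable, List, Optional, Tuple


def normalize_checks(
    checks: Iterable[Tuple[Any, ...]],
) -> List[Tuple[Any, Any, Optional[Any]]]:
    """Hypothetical helper: pad (shaped, spec) pairs with a None context."""
    result: List[Tuple[Any, Any, Optional[Any]]] = []
    for check in checks:
        if len(check) == 2:
            shaped, spec = check
            result.append((shaped, spec, None))
        elif len(check) == 3:
            shaped, spec, context = check
            result.append((shaped, spec, context))
        else:
            raise ValueError(f"Expected a 2- or 3-tuple, got {len(check)} elements")
    return result
```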
- _parse_checks(checks)
Sanity check, register and parse the given checks into _ShapeCheck objects.
- Parameters:
checks (Iterable[Union[Tuple[Any, TensorSpecLike], Tuple[Any, TensorSpecLike, check_shapes.error_contexts.ErrorContext]]]) –
- Return type:
List[_ShapeCheck]
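As a rough picture of the parsing step, the hypothetical parse_spec below turns a spec string such as "[batch..., n_features]" into (name, is_varrank) pairs. It handles only comma-separated names with an optional "..." suffix; the real grammar (see “Shape specification” in the User Guide) is richer.

```python
from typing import List, Tuple


def parse_spec(spec: str) -> List[Tuple[str, bool]]:
    """Hypothetical parser: "[batch..., n]" -> [("batch", True), ("n", False)]."""
    spec = spec.strip()
    if not (spec.startswith("[") and spec.endswith("]")):
        raise ValueError(f"Spec must be wrapped in brackets: {spec!r}")
    dims: List[Tuple[str, bool]] = []
    for token in spec[1:-1].split(","):
        token = token.strip()
        if not token:
            continue  # Allows "[]" for scalars.
        varrank = token.endswith("...")
        name = token[: -len("...")] if varrank else token
        dims.append((name, varrank))
    return dims
```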
- _match_dims(shape_check)
Match expected dimensions against actual dimensions.
If some dimensions cannot be determined, the remaining dimensions are added to _VariableState.waiting_for_varrank.
- Parameters:
shape_check (_ShapeCheck) –
- Return type:
Iterable[Tuple[check_shapes.specs.ParsedDimensionSpec, _Shape, bool]]
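The matching strategy can be sketched as follows, under stated assumptions: dimensions are (name, is_varrank) pairs, known_ranks maps var-rank variables whose rank has already been observed, and match_dims is hypothetical (the real method also yields a shape_possibly_truncated flag, omitted here). Dimensions of known width are matched from the front and then from the back; a single remaining var-rank dimension takes the middle, while two or more are returned unmatched, to be retried once more ranks are known.

```python
from typing import Dict, List, Optional, Tuple

Dim = Tuple[str, bool]  # (name, is_varrank)
Shape = Tuple[Optional[int], ...]


def match_dims(
    expected: List[Dim], actual: Shape, known_ranks: Dict[str, int]
) -> Tuple[List[Tuple[Dim, Shape]], List[Dim]]:
    """Hypothetical sketch: returns (matched pairs, expected dims left deferred)."""
    matched: List[Tuple[Dim, Shape]] = []
    exp_begin, exp_end = 0, len(expected)
    act_begin, act_end = 0, len(actual)

    def width(dim: Dim) -> Optional[int]:
        # A rank-1 dim covers one axis; a var-rank dim only if its rank is known.
        name, varrank = dim
        return known_ranks.get(name) if varrank else 1

    # Match from the front until we hit a dim of unknown width.
    while exp_begin < exp_end:
        w = width(expected[exp_begin])
        if w is None:
            break
        matched.append((expected[exp_begin], actual[act_begin : act_begin + w]))
        exp_begin += 1
        act_begin += w
    # Then match from the back, likewise.
    while exp_begin < exp_end:
        w = width(expected[exp_end - 1])
        if w is None:
            break
        matched.append((expected[exp_end - 1], actual[act_end - w : act_end]))
        exp_end -= 1
        act_end -= w
    remaining = expected[exp_begin:exp_end]
    if act_begin > act_end or (not remaining and act_begin != act_end):
        raise ValueError(f"rank mismatch: {actual} vs {expected}")
    if len(remaining) == 1:
        # Exactly one unknown var-rank dim: it takes all the middle axes.
        matched.append((remaining[0], actual[act_begin:act_end]))
        return matched, []
    # Several unknown var-rank dims: defer until more ranks are known.
    return matched, remaining
```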
- _check_dim(expected, actual_dims, shape_possibly_truncated, shape_checks)
Checks that actual_dims matches expected.
Newly learned information may enable the evaluation of deferred shape checks; any such checks will be added to shape_checks.
- Parameters:
expected (check_shapes.specs.ParsedDimensionSpec) –
actual_dims (_Shape) –
shape_possibly_truncated (bool) –
shape_checks (List[_ShapeCheck]) –
- Return type:
None
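The deferral mechanism can be sketched in isolation. In the hypothetical class below, learning the rank of a var-rank variable for the first time moves every check waiting on it into shape_checks, so the caller can run it again; checks are represented by string ids for simplicity.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple


class DeferredChecksSketch:
    """Hypothetical sketch: learning a var-rank variable's rank re-enables checks."""

    def __init__(self) -> None:
        self.ranks: Dict[str, int] = {}
        # Variable name -> ids of checks that cannot proceed until its rank is known.
        self.waiting_for_varrank: DefaultDict[str, List[str]] = defaultdict(list)

    def defer(self, check_id: str, variables: List[str]) -> None:
        for name in variables:
            self.waiting_for_varrank[name].append(check_id)

    def learn_varrank(
        self, name: str, dims: Tuple[int, ...], shape_checks: List[str]
    ) -> None:
        if name in self.ranks:
            if self.ranks[name] != len(dims):
                raise ValueError(f"{name}: rank {len(dims)} != {self.ranks[name]}")
            return  # Already known: nothing new to unblock.
        self.ranks[name] = len(dims)
        # The new rank may unblock deferred checks: hand them back to the caller.
        for check_id in self.waiting_for_varrank.pop(name, []):
            if check_id not in shape_checks:
                shape_checks.append(check_id)
```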
- _assert(condition)
Raise a nicely formatted ShapeMismatchError if condition is not True.
- Parameters:
condition (bool) –
- Return type:
None
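A minimal sketch of such a helper, assuming the error message simply lists the specs involved. ShapeMismatchErrorSketch and assert_shapes are hypothetical stand-ins for ShapeMismatchError and _assert; since the real method takes only condition, it presumably draws the details for its message from the checker's own state.

```python
from typing import List


class ShapeMismatchErrorSketch(Exception):
    """Hypothetical stand-in for check_shapes' ShapeMismatchError."""


def assert_shapes(condition: bool, uses: List[str]) -> None:
    """Raise a formatted error listing the specs involved unless condition holds."""
    if condition:
        return
    lines = ["Tensor shape mismatch. Specifications involved:"]
    lines += [f"  {use}" for use in uses]
    raise ShapeMismatchErrorSketch("\n".join(lines))
```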