Source code for gpflow.monitor.base

# Copyright 2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" MonitorTask base classes """

from abc import ABC, abstractmethod
from typing import Any, Callable, Collection, Union

import tensorflow as tf

# Public names exported by `from <this module> import *`: the classes defined below.
__all__ = [
    "ExecuteCallback",
    "Monitor",
    "MonitorTask",
    "MonitorTaskGroup",
]


class MonitorTask(ABC):
    """
    Abstract base class shared by all monitoring tasks.

    A monitoring task is a callable object: calling it records the optimisation step
    and then delegates to `run`. Concrete subclasses must implement `run`, which
    holds the actual body of the task.
    """

    def __call__(self, step: int, **kwargs: Any) -> None:
        """
        Record the current step and execute the task body.

        :param step: current step in the optimisation.
        :param kwargs: additional keyword arguments, forwarded unchanged to the `run`
            method of the task. This is in particular handy for passing keyword
            arguments to the callback of `ScalarToTensorBoard`.
        """
        # The step is cast to `tf.int64` before it is stored, so `run` reads the cast
        # value from `self.current_step` rather than the raw `step` argument.
        step_as_int64 = tf.cast(step, tf.int64)
        self.current_step = step_as_int64
        self.run(**kwargs)

    @abstractmethod
    def run(self, **kwargs: Any) -> None:
        """
        Body of the task, executed by `__call__`.

        The step passed to `__call__` is available through `self.current_step`.

        :param kwargs: keyword arguments that were passed to `__call__`.
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
class ExecuteCallback(MonitorTask):
    """Monitoring task that simply invokes a user-supplied callback."""

    def __init__(self, callback: Callable[..., None]) -> None:
        """
        :param callback: callable to be executed every time the task runs. Arguments
            are handed to it as keyword arguments.
        """
        super().__init__()
        self.callback = callback

    def run(self, **kwargs: Any) -> None:
        """
        Invoke the stored callback.

        :param kwargs: keyword arguments forwarded unchanged to the callback.
        """
        callback = self.callback
        callback(**kwargs)
class MonitorTaskGroup:
    """
    Class for grouping `MonitorTask` instances. A group defines all the tasks that are
    run at the same frequency, given by `period`.

    A `MonitorTaskGroup` consists of either a single `MonitorTask` instance or a
    collection of `MonitorTask` instances.
    """

    def __init__(
        self, task_or_tasks: Union[Collection[MonitorTask], MonitorTask], period: int = 1
    ) -> None:
        """
        :param task_or_tasks: a single instance or a list of `MonitorTask` instances.
            Each `MonitorTask` in the list will be run with the given `period`.
        :param period: defines how often to run the tasks; they will execute every
            `period`th step. For large values of `period` the tasks will be less
            frequently run. Defaults to running at every step (`period = 1`).
        :raises TypeError: if `task_or_tasks` is neither a `MonitorTask` nor a
            collection.
        """
        self._tasks: Collection[MonitorTask] = []
        # Assigning through the property normalises the argument to a list.
        self.tasks = task_or_tasks  # type: ignore[assignment]
        self._period = period

    @property
    def tasks(self) -> Collection[MonitorTask]:
        """The tasks of this group, stored as a list."""
        return self._tasks

    @tasks.setter
    def tasks(self, task_or_tasks: Union[Collection[MonitorTask], MonitorTask]) -> None:
        """Ensures the tasks are stored as a list. Even if there is only a single task."""
        if isinstance(task_or_tasks, MonitorTask):
            self._tasks = [task_or_tasks]
        elif isinstance(task_or_tasks, Collection):
            self._tasks = list(task_or_tasks)
        else:
            # Raised explicitly instead of using `assert`, which is stripped under
            # `python -O` and would let arbitrary iterables through unchecked.
            raise TypeError(
                "task_or_tasks must be a MonitorTask or a collection of MonitorTask "
                f"instances, got {type(task_or_tasks).__name__}"
            )

    def __call__(self, step: int, **kwargs: Any) -> None:
        """
        Call each task in the group, but only on steps that are a multiple of the period.

        :param step: current step in the optimisation.
        :param kwargs: keyword arguments forwarded unchanged to every task.
        """
        if step % self._period == 0:
            for task in self.tasks:
                task(step, **kwargs)
class Monitor:
    r"""
    Accepts any number of `MonitorTaskGroup` instances, and runs them according to
    their specified periodicity.

    Example use-case::

        # Create some monitor tasks
        log_dir = "logs"
        model_task = ModelToTensorBoard(log_dir, model)
        image_task = ImageToTensorBoard(log_dir, plot_prediction, "image_samples")
        lml_task = ScalarToTensorBoard(log_dir, lambda: model.log_marginal_likelihood(), "lml")

        # Plotting tasks can be quite slow, so we want to run them less frequently.
        # We group them in a `MonitorTaskGroup` and set the period to 5.
        slow_tasks = MonitorTaskGroup(image_task, period=5)

        # The other tasks are fast. We run them at each iteration of the optimisation.
        fast_tasks = MonitorTaskGroup([model_task, lml_task], period=1)

        # We pass both groups to the `Monitor`
        monitor = Monitor(fast_tasks, slow_tasks)
    """

    def __init__(self, *task_groups: MonitorTaskGroup) -> None:
        """
        :param task_groups: the `MonitorTaskGroup`s to be executed, given as
            positional arguments.
        """
        # Kept as the tuple of positional arguments, in the order they were given.
        self.task_groups = task_groups

    def __call__(self, step: int, **kwargs: Any) -> None:
        """
        Invoke every task group, in order, with the given step.

        :param step: current step in the optimisation.
        :param kwargs: keyword arguments forwarded unchanged to every group.
        """
        for task_group in self.task_groups:
            task_group(step, **kwargs)