"""Class for tracking the EB criterion for early stopping."""
from backpack.extensions import BatchGrad
from cockpit.context import get_batch_size, get_optimizer
from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared
from cockpit.utils.optim import ComputeStep
class EarlyStopping(SingleStepQuantity):
    """Quantity class for the evidence-based (EB) early-stopping criterion.

    This criterion uses local statistics of the gradients to indicate when training
    should be stopped. If the criterion exceeds zero, training should be stopped.

    Note: Proposed in

    - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
      Early stopping without a validation set (2017).
    """

    def __init__(self, track_schedule, verbose=False, epsilon=1e-5):
        """Initialization sets the tracking schedule & creates the output dict.

        Args:
            track_schedule (callable): Function that maps the ``global_step``
                to a boolean, which determines if the quantity should be computed.
            verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
            epsilon (float, optional): Stabilization constant added to the gradient
                variance to avoid division by zero. Defaults to ``1e-5``.
        """
        super().__init__(track_schedule, verbose=verbose)
        self._epsilon = epsilon

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Individual gradients (``BatchGrad``) are needed to estimate the
        gradient variance used by the criterion.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        return ext

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        The hook compresses individual gradients into their sum of squares
        during the backward pass, avoiding storage of all per-sample gradients.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List of required BackPACK extension hooks for the current
                iteration.
        """
        hooks = []

        if self.should_compute(global_step):
            hooks.append(BatchGradTransformsHook_SumGradSquared())

        return hooks

    def _compute(self, global_step, params, batch_loss):
        """Compute the EB early stopping criterion.

        Evaluates the left hand side of Equ. 7 in

        - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
          Early stopping without a validation set (2017).

        If this value exceeds 0, training should be stopped.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Result of the Early stopping criterion. Training should stop
                if it is larger than 0.

        Raises:
            ValueError: If the used optimizer differs from SGD with default parameters.
        """
        # The criterion's derivation assumes plain SGD updates (no momentum,
        # no weight decay); reject other optimizers to avoid silent misuse.
        if not ComputeStep.is_sgd_default_kwargs(get_optimizer(global_step)):
            raise ValueError("This criterion only supports zero-momentum SGD.")

        B = get_batch_size(global_step)

        grad_squared = self._fetch_grad(params, aggregate=True) ** 2

        # compensate BackPACK's 1/B scaling of individual gradients
        sgs_compensated = (
            B ** 2
            * self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True
            )
        )

        # Unbiased (Bessel-corrected) per-coordinate gradient variance.
        # NOTE(review): assumes B > 1; a single-sample batch would divide by zero.
        diag_variance = (sgs_compensated - B * grad_squared) / (B - 1)

        # Signal-to-noise ratio per coordinate, stabilized by epsilon.
        snr = grad_squared / (diag_variance + self._epsilon)

        return 1 - B * snr.mean().item()