Source code for cockpit.quantities.cabs

"""Class for tracking the CABS criterion for adaptive batch size."""

from backpack.extensions import BatchGrad

from cockpit.context import get_batch_size, get_optimizer
from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared
from cockpit.utils.optim import ComputeStep


class CABS(SingleStepQuantity):
    """CABS Quantity class for the suggested batch size using the CABS criterion.

    CABS uses the current learning rate and the variance of the stochastic
    gradients to suggest an optimal batch size. Only applies to SGD without
    momentum.

    Note:
        Proposed in

        - Balles, L., Romero, J., & Hennig, P.,
          Coupling adaptive batch sizes with learning rates (2017).
    """

    def get_lr(self, optimizer):
        """Extract the learning rate.

        Args:
            optimizer (torch.optim.Optimizer): A PyTorch optimizer.

        Returns:
            float: Learning rate

        Raises:
            ValueError: If the learning rate varies over parameter groups.
        """
        lrs = {group["lr"] for group in optimizer.param_groups}

        if len(lrs) != 1:
            raise ValueError(f"Found non-unique learning rates {lrs}")

        return lrs.pop()

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        return ext

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List of required BackPACK extension hooks for the
                current iteration.
        """
        hooks = []

        if self.should_compute(global_step):
            hooks.append(BatchGradTransformsHook_SumGradSquared())

        return hooks

    def _compute(self, global_step, params, batch_loss):
        """Compute the CABS rule. Return suggested batch size.

        Evaluates Equ. 22 of

        - Balles, L., Romero, J., & Hennig, P.,
          Coupling adaptive batch sizes with learning rates (2017).

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Batch size suggested by CABS.

        Raises:
            ValueError: If the optimizer differs from SGD with default arguments.
        """
        optimizer = get_optimizer(global_step)
        if not ComputeStep.is_sgd_default_kwargs(optimizer):
            raise ValueError("This criterion only supports zero-momentum SGD.")

        B = get_batch_size(global_step)
        lr = self.get_lr(optimizer)

        grad_squared = self._fetch_grad(params, aggregate=True) ** 2

        # compensate BackPACK's 1/B scaling
        sgs_compensated = (
            B ** 2
            * self._fetch_sum_grad_squared_via_batch_grad_transforms(
                params, aggregate=True
            )
        )

        return (
            lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss)
        ).item()
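
The returned value estimates the paper's rule, roughly lr * tr(Sigma) / loss, from the individual gradients of the current mini-batch. A minimal usage sketch follows; the model, data, and hyperparameters are placeholder assumptions, while the Cockpit calls follow the library's documented training-loop pattern (zero-momentum SGD is required, and the optimizer must be passed via ``info`` so ``get_optimizer`` can retrieve it).

# Hypothetical example: tracking CABS with Cockpit (model/data are placeholders).
import torch
from backpack import extend

from cockpit import Cockpit
from cockpit.quantities import CABS
from cockpit.utils import schedules

model = extend(torch.nn.Linear(10, 2))
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# CABS only supports SGD without momentum.
opt = torch.optim.SGD(model.parameters(), lr=0.01)

cockpit = Cockpit(
    model.parameters(), quantities=[CABS(schedules.linear(interval=1))]
)

global_step = 0
inputs, labels = torch.randn(8, 10), torch.randint(0, 2, (8,))

outputs = model(inputs)
loss = loss_fn(outputs, labels)
losses = individual_loss_fn(outputs, labels)

with cockpit(
    global_step,
    info={
        "batch_size": 8,
        "individual_losses": losses,
        "loss": loss,
        "optimizer": opt,  # needed by CABS via get_optimizer(global_step)
    },
):
    loss.backward(create_graph=cockpit.create_graph(global_step))

opt.step()
# The suggested batch size should then be available in cockpit.output
# for this iteration under the quantity's name.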