Source code for cockpit.quantities.mean_gsnr

"""Class for tracking the Mean Gradient Signal to Noise Ratio (GSNR)."""


from backpack.extensions import BatchGrad

from cockpit.context import get_batch_size
from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared


class MeanGSNR(SingleStepQuantity):
    """Quantity class for the mean gradient signal-to-noise ratio (GSNR).

    Note:
        Mean gradient signal-to-noise ratio as defined by

        - Liu, J., et al.
          Understanding Why Neural Networks Generalize Well Through GSNR of
          Parameters (2020).
          https://arxiv.org/abs/2001.07384
    """

    def __init__(self, track_schedule, verbose=False, epsilon=1e-5):
        """Set up the quantity and store the stabilization constant.

        Args:
            track_schedule (callable): Function that maps the ``global_step``
                to a boolean, which determines if the quantity should be
                computed.
            verbose (bool, optional): Turns on verbose mode. Defaults to
                ``False``.
            epsilon (float): Stabilization constant that is added to the
                denominator of the GSNR. Defaults to 1e-5.
        """
        super().__init__(track_schedule, verbose=verbose)
        self._epsilon = epsilon

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: ``[BatchGrad()]`` if the quantity is computed at this
                iteration, otherwise an empty list.
        """
        if not self.should_compute(global_step):
            return []
        return [BatchGrad()]

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List holding a single
                ``BatchGradTransformsHook_SumGradSquared`` if the quantity is
                computed at this iteration, otherwise an empty list.
        """
        if not self.should_compute(global_step):
            return []
        return [BatchGradTransformsHook_SumGradSquared()]

    def _compute(self, global_step, params, batch_loss):
        """Track the mean GSNR.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the
                network's parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Mean GSNR of the current iteration.
        """
        gsnr = self._compute_gsnr(global_step, params, batch_loss)
        # Reduce to a single Python scalar.
        return gsnr.mean().item()

    def _compute_gsnr(self, global_step, params, batch_loss):
        """Compute gradient signal-to-noise ratio.

        The result is
        ``grad ** 2 / (batch_size * sum_grad_squared - grad ** 2 + epsilon)``
        and is returned without any reduction; taking the mean is left to the
        caller.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of parameters considered in the
                computation.
            batch_loss (torch.Tensor): Mini-batch loss from current step. Not
                used by this method.

        Returns:
            The un-reduced GSNR values (presumably a ``torch.Tensor``, as the
            caller invokes ``.mean().item()`` on it -- TODO confirm).
        """
        squared_grad = self._fetch_grad(params, aggregate=True) ** 2
        summed_squares = self._fetch_sum_grad_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
        n_samples = get_batch_size(global_step)

        # ``epsilon`` is added last so the denominator is stabilized.
        denominator = n_samples * summed_squares - squared_grad + self._epsilon
        return squared_grad / denominator