Source code for cockpit.quantities.norm_test

"""Class for tracking the Norm Test."""

from backpack.extensions import BatchGrad

from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchL2Grad


class NormTest(SingleStepQuantity):
    """Quantity class for the norm test.

    Note:
        Norm test as proposed in

        - Byrd, R., Chin, G., Nocedal, J., & Wu, Y.,
          Sample size selection in optimization methods for machine learning (2012).
          https://link.springer.com/article/10.1007%2Fs10107-012-0572-5
    """

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        return ext

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List of required BackPACK extension hooks for the
                current iteration.
        """
        hooks = []

        if self.should_compute(global_step):
            hooks.append(BatchGradTransformsHook_BatchL2Grad())

        return hooks

    def _compute(self, global_step, params, batch_loss):
        """Track the practical version of the norm test.

        Return maximum θ for which the norm test would pass.

        The norm test is defined by Equation (3.9) in byrd2012sample.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the
                network's parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            float: Maximum θ for which the norm test would pass.
        """
        batch_l2_squared = self._fetch_batch_l2_squared_via_batch_grad_transforms(
            params, aggregate=True
        )
        grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True)
        batch_size = batch_l2_squared.size(0)

        var_l1 = self._compute_variance_l1(
            batch_size, batch_l2_squared, grad_l2_squared
        )

        return self._compute_theta_max(batch_size, var_l1, grad_l2_squared).item()

    def _compute_theta_max(self, batch_size, var_l1, grad_l2_squared):
        """Return maximum θ for which the norm test would pass.

        Args:
            batch_size (int): Mini-batch size.
            var_l1 (torch.Tensor): ℓ₁ norm of the sample variance of the
                individual gradients.
            grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.

        Returns:
            torch.Tensor: Maximum θ for which the norm test would pass.
        """
        return (var_l1 / batch_size / grad_l2_squared).sqrt()

    def _compute_variance_l1(self, batch_size, batch_l2_squared, grad_l2_squared):
        """Compute the ℓ₁ norm of the sample variance of individual gradients.

        It shows up in Equations (3.9) and (3.11) in byrd2012sample and relies
        on the sample variance (Equation 3.6). The ℓ₁ norm can be computed from
        the individual gradients' squared ℓ₂ norms and the mini-batch
        gradient's squared ℓ₂ norm.

        Args:
            batch_size (int): Mini-batch size.
            batch_l2_squared (torch.Tensor): Squared ℓ₂ norms of the individual
                gradients.
            grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.

        Returns:
            torch.Tensor: ℓ₁ norm of the sample variance.
        """
        return (1 / (batch_size - 1)) * (
            batch_size ** 2 * batch_l2_squared.sum() - batch_size * grad_l2_squared
        )
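

For reference, below is a minimal standalone sketch (not part of cockpit; tensor names like ``grads`` are illustrative) that checks the identity behind ``_compute_variance_l1`` against a direct computation with ``torch.var``. It assumes the 1/B scaling that BackPACK's ``BatchGrad`` applies to individual gradients of a mean-reduced loss, which the ``batch_size ** 2`` factor undoes.

import torch

# Hypothetical check of the identity used in `_compute_variance_l1`.
torch.manual_seed(0)
B, D = 8, 5
grads = torch.randn(B, D)  # individual gradients g_i
grad = grads.mean(0)       # mini-batch gradient g

# Direct route: sum of per-coordinate sample variances (Equation 3.6).
var_l1_direct = grads.var(dim=0, unbiased=True).sum()

# Route taken by NormTest: BatchGrad yields g_i / B for a mean-reduced loss,
# so its squared l2 norms are ||g_i||^2 / B^2, which the B^2 factor undoes.
batch_l2_squared = (grads / B).pow(2).sum(1)
grad_l2_squared = grad.pow(2).sum()
var_l1 = (1 / (B - 1)) * (B ** 2 * batch_l2_squared.sum() - B * grad_l2_squared)

print(torch.allclose(var_l1_direct, var_l1))  # True

# Maximum theta for which the norm test (Equation 3.9) would pass.
theta_max = (var_l1 / B / grad_l2_squared).sqrt()

The identity avoids materializing the D-dimensional variance vector: only B scalar norms and one mini-batch norm are needed, which is why the quantity only requests ``BatchL2Grad``-style transforms rather than full individual gradients.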
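

A hedged usage sketch follows, assuming cockpit's documented training-loop protocol (the ``info`` keys, the ``track_schedule`` argument, and ``cockpit.create_graph`` follow the cockpit documentation, but details may differ across versions):

import torch
from backpack import extend
from cockpit import Cockpit
from cockpit.quantities import NormTest

model = extend(torch.nn.Linear(10, 2))          # extend model for BackPACK
lossfunc = extend(torch.nn.CrossEntropyLoss())  # mean-reduced loss
individual_lossfunc = torch.nn.CrossEntropyLoss(reduction="none")
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Track the norm test every 10th iteration (any step -> bool callable works).
cockpit = Cockpit(
    model.parameters(),
    quantities=[NormTest(track_schedule=lambda step: step % 10 == 0)],
)

inputs, labels = torch.randn(8, 10), torch.randint(0, 2, (8,))
global_step = 0

outputs = model(inputs)
loss = lossfunc(outputs, labels)
losses = individual_lossfunc(outputs, labels).detach()

with cockpit(
    global_step,
    info={
        "batch_size": inputs.shape[0],
        "individual_losses": losses,
        "loss": loss,
        "optimizer": opt,
    },
):
    loss.backward(create_graph=cockpit.create_graph(global_step))
opt.step()

Inside the ``with`` block, cockpit attaches the ``BatchGrad`` extension and the ``BatchGradTransformsHook_BatchL2Grad`` hook requested above, so the tracked value at ``global_step`` is the maximum θ from ``_compute``.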