# Source code for cockpit.quantities.ortho_test

"""Class for tracking the Orthogonality Test."""

from backpack.extensions import BatchGrad

from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchDotGrad


class OrthoTest(SingleStepQuantity):
    """Quantity Class for the orthogonality test.

    Note: Orthogonality test as proposed in

        - Bollapragada, R., Byrd, R., & Nocedal, J.,
          Adaptive Sampling Strategies for Stochastic Optimization (2017).
          https://arxiv.org/abs/1710.11258
    """

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
                Holds a single ``BatchGrad`` extension when the quantity is
                computed at ``global_step``, otherwise it is empty.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        return ext

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List of required BackPACK extension hooks for the current
                iteration. Holds a single ``BatchGradTransformsHook_BatchDotGrad``
                when the quantity is computed at ``global_step``, otherwise it is
                empty.
        """
        hooks = []

        if self.should_compute(global_step):
            hooks.append(BatchGradTransformsHook_BatchDotGrad())

        return hooks

    def _compute(self, global_step, params, batch_loss):
        """Track the practical version of the orthogonality test.

        Return maximum ν for which the orthogonality test would pass.

        The orthogonality test is defined by Equation (3.3) in
        bollapragada2017adaptive.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step. Not used
                by this quantity.

        Returns:
            float: Maximum ν for which the orthogonality test would pass.
        """
        # Pairwise dot products of the individual gradients, aggregated over all
        # parameters. Its first dimension is taken as the mini-batch size, and
        # its diagonal is used later. NOTE(review): assumed to be a square
        # [batch_size, batch_size] matrix; confirm against the transform hook.
        batch_dot = self._fetch_batch_dot_via_batch_grad_transforms(
            params, aggregate=True
        )
        batch_size = batch_dot.size(0)
        grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True)

        var_orthogonal_projection = self._compute_orthogonal_projection_variance(
            batch_size, batch_dot, grad_l2_squared
        )

        return self._compute_nu_max(
            batch_size, var_orthogonal_projection, grad_l2_squared
        ).item()

    def _compute_nu_max(self, batch_size, var_orthogonal_projection, grad_l2_squared):
        """Return maximum ν for which the orthogonality test would pass.

        The orthogonality test is defined by Equation (3.3) in
        bollapragada2017adaptive. The returned value is
        ``sqrt(var_orthogonal_projection / (batch_size * grad_l2_squared))``.

        NOTE(review): Eq. (3.3) bounds the variance term from above by ν²‖g‖².
        It is therefore satisfied for every ν at least this value, so "maximum"
        refers to the critical ν of the test. Confirm the wording.

        Args:
            batch_size (int): Mini-batch size.
            var_orthogonal_projection (torch.Tensor): Sample variance of the
                individual gradients' components orthogonal to the mini-batch
                gradient, as computed by ``_compute_orthogonal_projection_variance``.
            grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.

        Returns:
            torch.Tensor: Maximum ν for which the orthogonality test would pass.
        """
        return (var_orthogonal_projection / batch_size / grad_l2_squared).sqrt()

    def _compute_orthogonal_projection_variance(
        self, batch_size, batch_dot, grad_l2_squared
    ):
        """Compute sample variance of individual gradient orthogonal projections.

        The sample variance of orthogonal projections shows up in Equation (3.3)
        in bollapragada2017adaptive (https://arxiv.org/pdf/1710.11258.pdf).

        Args:
            batch_size (int): Mini-batch size.
            batch_dot (torch.Tensor): Individual gradient pairwise dot product.
            grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.

        Returns:
            torch.Tensor: The sample variance of individual gradient orthogonal
                projections on the mini-batch gradient, clamped to be
                non-negative. NaN inputs still propagate as NaN.

        Raises:
            ZeroDivisionError: If ``batch_size`` is the Python int ``1``, because
                the sample variance uses the ``1 / (batch_size - 1)`` correction.
        """
        # Squared norms of the individual gradients stored in ``batch_dot``.
        batch_l2_squared = batch_dot.diag()
        # Row sums of ``batch_dot``: each individual gradient dotted with the sum
        # of all individual gradients. NOTE(review): the ``batch_size`` factors
        # here and below suggest the stored individual gradients carry a
        # 1 / batch_size scaling (mean-reduced loss). Confirm against BackPACK's
        # BatchGrad.
        projections = batch_size * batch_dot.sum(1)

        variance = (1 / (batch_size - 1)) * (
            batch_size ** 2 * batch_l2_squared.sum()
            - (projections ** 2 / grad_l2_squared).sum()
        )

        # Mathematically, this is a sum of squared norms of orthogonal residuals,
        # which is non-negative by Cauchy-Schwarz. Floating-point cancellation can
        # still push it slightly below zero when the individual gradients are
        # (nearly) parallel to the mini-batch gradient. Clamp it so that the
        # ``.sqrt()`` in ``_compute_nu_max`` does not turn round-off into NaN.
        return variance.clamp(min=0.0)