Source code for cockpit.quantities.hess_max_ev

"""Class for tracking the Maximum Hessian Eigenvalue."""

import warnings

import numpy as np
import torch
from backpack.hessianfree.hvp import hessian_vector_product
from backpack.utils.convert_parameters import vector_to_parameter_list
from scipy.sparse.linalg import LinearOperator, eigsh
from torch.nn.utils import parameters_to_vector

from cockpit.quantities.quantity import SingleStepQuantity


[docs]class HessMaxEV(SingleStepQuantity): """Quantity Class tracking the Hessian's largest eigenvalue.""" def __init__(self, track_schedule, verbose=False, use_power=True): """Initialization sets the tracking schedule & creates the output dict. Args: track_schedule (callable): Function that maps the ``global_step`` to a boolean, which determines if the quantity should be computed. verbose (bool, optional): Turns on verbose mode. Defaults to ``False``. use_power (bool): If ``True``, uses a power iteration that works on GPUs. If ``True``, use :func:`scipy.sparse.linalg.eigsh` for eigenvalue computation (requires transfers from GPU to CPU and ``torch`` to ``numpy``.) """ super().__init__(track_schedule, verbose=verbose) self._use_power = use_power def create_graph(self, global_step): """Return whether access to the forward pass computation graph is needed. Args: global_step (int): The current iteration number. Returns: bool: ``True`` if the computation graph shall not be deleted, else ``False``. """ return self.should_compute(global_step) def extensions(self, global_step): """Return list of BackPACK extensions required for the computation. Args: global_step (int): The current iteration number. Returns: list: (Potentially empty) list with required BackPACK quantities. """ return [] def _compute(self, global_step, params, batch_loss): """Compute the larges Hessian eigenvalue. Args: global_step (int): The current iteration number. params ([torch.Tensor]): List of torch.Tensors holding the network's parameters. batch_loss (torch.Tensor): Mini-batch loss from current step. Returns: float: Larges eigenvalue of the Hessian. """ HVP = HVPLinearOperator( batch_loss, params, grad_params=self._fetch_grad(params) ) if self._use_power: max_ev = HVP.power_iteration().item() else: max_ev = eigsh(HVP, k=1, which="LA", return_eigenvectors=False)[0] return np.float64(max_ev)
# Utility Function and Classes # class BaseLinearOperator(LinearOperator): """Base class for GGNVPs and HVPs.""" def _preprocess(self, v_numpy): """Convert to `torch`, block into parameters.""" v_torch = torch.from_numpy(v_numpy).to(self.device) return vector_to_parameter_list(v_torch, self.params) def _postprocess(self, v_torch): """Flatten and concatenate, then convert to `numpy` array.""" v_torch_flat = [v.contiguous() for v in v_torch] return parameters_to_vector(v_torch_flat).cpu().numpy() class HVPLinearOperator(BaseLinearOperator): """Scipy interface for multiplication with the Hessian in torch.""" def __init__(self, loss, params, grad_params=None, dtype=np.float32): """Multiplication with the Hessian of loss w.r.t. params.""" num_params = sum(p.numel() for p in params) shape = (num_params, num_params) super().__init__(dtype, shape) self.loss = loss self.params = params if grad_params is None: self.grad_params = torch.autograd.grad( loss, params, create_graph=True, retain_graph=True ) else: self.grad_params = grad_params self.device = loss.device def _matvec(self, v_numpy): """Multiply with the Hessian.""" v_torch = self._preprocess(v_numpy) Hv_torch = self.hessian_vector_product(v_torch) return self._postprocess(Hv_torch) def hessian_vector_product(self, v_torch): """Multiply by the Hessian using autodiff in torch.""" return hessian_vector_product( self.loss, self.params, v_torch, grad_params=self.grad_params ) def power_iteration(self, maxiter=100, rtol=1e-3, atol=1e-6): """Compute the largest eigenvalue by power iteration. Args: maxiter (int): Maximum number of iterations. rtol (float): Relative tolerance to determine convergence from consecutive eigenvalues. atol (float): Absolute tolerance to determine convergence from consecutive eigenvalues. Returns: (torch.Tensor): Maximum eigenvalue. Warns if maximum number of iterations was reached and returns the potentially unconverged estimate. """ def converged(old, new): return torch.allclose(old, new, rtol=rtol, atol=atol) def normalize(vecs): norm = sum((v ** 2).sum() for v in vecs).sqrt() for v in vecs: v /= norm def iteration(vecs): new_vecs = self.hessian_vector_product(vecs) new_eigval = sum((v * new_v).sum() for v, new_v in zip(vecs, new_vecs)) normalize(new_vecs) return new_vecs, new_eigval vecs = [torch.rand_like(p) for p in self.params] normalize(vecs) eigval = torch.Tensor([float("inf")]).to(vecs[0].device) for _ in range(maxiter): vecs, new_eigval = iteration(vecs) if converged(eigval, new_eigval): return new_eigval eigval = new_eigval warnings.warn(f"Exceeded maximum number of {maxiter} iterations") return eigval