"""Quantity for Takeuchi Information Criterion (TIC)."""
from backpack import extensions
from cockpit.context import get_batch_size
from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared
class TIC(SingleStepQuantity):
"""Base class for different Takeuchi Information Criterion approximations.
Note: Takeuchi Information criterion (TIC) rediscovered by
- Thomas, V., et al.
On the interplay between noise and curvature and its effect on
optimization and generalization (2019).
https://arxiv.org/abs/1906.07774
"""
extensions_from_str = {
"diag_h": extensions.DiagHessian,
"diag_ggn_exact": extensions.DiagGGNExact,
"diag_ggn_mc": extensions.DiagGGNMC,
}
def __init__(
self,
track_schedule,
verbose=False,
curvature="diag_h",
epsilon=1e-7,
):
"""Initialization sets the tracking schedule & creates the output dict.
Note:
The curvature options ``"diag_h"`` and ``"diag_ggn_exact"`` are more
expensive than ``"diag_ggn_mc"``, but more precise. For a classification
task with ``C`` classes, the former require that ``C`` times more
information be backpropagated through the computation graph.
Args:
track_schedule (callable): Function that maps the ``global_step``
to a boolean, which determines if the quantity should be computed.
verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
curvature (str): Which diagonal curvature approximation should be used.
Options are ``"diag_h"``, ``"diag_ggn_exact"``, ``"diag_ggn_mc"``.
epsilon (float): Stabilization constant. Defaults to ``1e-7``.
"""
super().__init__(track_schedule, verbose=verbose)
self._curvature = curvature
self._epsilon = epsilon
def extensions(self, global_step):
"""Return list of BackPACK extensions required for the computation.
Args:
global_step (int): The current iteration number.
Raises:
KeyError: If curvature string has unknown associated extension.
Returns:
list: (Potentially empty) list with required BackPACK quantities.
"""
ext = []
if self.should_compute(global_step):
ext.append(extensions.BatchGrad())
try:
ext.append(self.extensions_from_str[self._curvature]())
except KeyError as e:
available = list(self.extensions_from_str.keys())
raise KeyError(f"Available: {available}") from e
return ext
def extension_hooks(self, global_step):
"""Return list of BackPACK extension hooks required for the computation.
Args:
global_step (int): The current iteration number.
Returns:
[callable]: List of required BackPACK extension hooks for the current
iteration.
"""
hooks = []
if self.should_compute(global_step):
hooks.append(BatchGradTransformsHook_SumGradSquared())
return hooks
[docs]class TICDiag(TIC):
"""Quantity class for tracking the TIC using diagonal curvature approximation.
The diagonal curvature approximation provide cheap inversion.
Note: Takeuchi Information criterion (TIC) rediscovered by
- Thomas, V., et al.
On the interplay between noise and curvature and its effect on
optimization and generalization (2019).
https://arxiv.org/abs/1906.07774
"""
def _compute(self, global_step, params, batch_loss):
"""Compute the TICDiag using a diagonal curvature approximation.
Args:
global_step (int): The current iteration number.
params ([torch.Tensor]): List of torch.Tensors holding the network's
parameters.
batch_loss (torch.Tensor): Mini-batch loss from current step.
Returns:
float: TIC computed using a diagonal curvature approximation.
"""
sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
params, aggregate=True
)
curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)
batch_size = get_batch_size(global_step)
return (
(batch_size * sum_grad_squared / (curvature + self._epsilon)).sum().item()
)
[docs]class TICTrace(TIC):
"""Quantity class for the TIC using the trace of curvature and gradient covariance.
Note: Takeuchi Information criterion (TIC) rediscovered by
- Thomas, V., et al.
On the interplay between noise and curvature and its effect on
optimization and generalization (2019).
https://arxiv.org/abs/1906.07774
"""
def _compute(self, global_step, params, batch_loss):
"""Compute the TICTrace using a trace approximation.
Args:
global_step (int): The current iteration number.
params ([torch.Tensor]): List of torch.Tensors holding the network's
parameters.
batch_loss (torch.Tensor): Mini-batch loss from current step.
Returns:
float: TIC computed using a trace approximation.
"""
sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
params, aggregate=True
)
curvature = self._fetch_diag_curvature(params, self._curvature, aggregate=True)
batch_size = get_batch_size(global_step)
return (
batch_size * sum_grad_squared.sum() / (curvature.sum() + self._epsilon)
).item()