Source code for cockpit.quantities.grad_hist

"""Histograms of individual gradient transformations."""

import torch
from backpack.extensions import BatchGrad

from cockpit.quantities.bin_adaptation import NoAdaptation
from cockpit.quantities.quantity import SingleStepQuantity
from cockpit.quantities.utils_hists import histogram2d
from cockpit.quantities.utils_transforms import BatchGradTransformsHook


class GradHist1d(SingleStepQuantity):
    """Quantity class for one-dimensional histograms of individual gradient elements.

    The histograms consider individual gradient elements, with outliers being
    clipped to lie in the visible range.
    """

    def __init__(
        self,
        track_schedule,
        verbose=False,
        bins=100,
        range=(-2, 2),
        adapt=None,
    ):
        """Initialization sets the tracking schedule & creates the output dict.

        Args:
            track_schedule (callable): Function that maps the ``global_step``
                to a boolean, which determines if the quantity should be computed.
            verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
            bins (int, optional): Number of bins. Default: ``100``.
            range ([float, float], optional): Lower and upper limit of the bin
                range. Default: ``(-2, 2)``.
            adapt (BinAdaptation, optional): Policy for adapting the bin limits.
                Per default, no adaptation is performed.
        """
        super().__init__(track_schedule, verbose=verbose)

        self._range = range
        self._bins = bins
        self._adapt = NoAdaptation(verbose=verbose) if adapt is None else adapt

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        ext += self._adapt.extensions(global_step)

        return ext

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List of required BackPACK extension hooks for the current
                iteration.
        """
        hooks = []

        if self.should_compute(global_step):
            hooks.append(
                BatchGradTransformsHook(transforms={"hist_1d": self._compute_histogram})
            )

        hooks += self._adapt.extension_hooks(global_step)

        return hooks

    def track(self, global_step, params, batch_loss):
        """Perform scheduled computations and store result.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
        """
        super().track(global_step, params, batch_loss)

        # update limits
        if self._adapt.should_compute(global_step):
            self._range = self._adapt.compute(
                global_step, params, batch_loss, self._range
            )

    def _compute(self, global_step, params, batch_loss):
        """Evaluate the individual gradient histogram.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.

        Returns:
            dict: Entry ``'hist'`` holds the histogram, entry ``'edges'`` holds
                the bin limits.
        """
        hist = sum(p.grad_batch_transforms["hist_1d"][0] for p in params).detach()
        edges = params[0].grad_batch_transforms["hist_1d"][1]

        return {"hist": hist, "edges": edges}

    def _compute_histogram(self, batch_grad):
        """Transform individual gradients into histogram data.

        This function is applied to every parameter's individual gradient during
        backpropagation. If `p` denotes the parameter, then the return value is
        stored in the `"hist_1d"` field of `p.grad_batch_transforms`.

        Args:
            batch_grad (torch.Tensor): Individual gradient of a parameter `p`. If
                `p` is of shape `(*)`, the individual gradients have shape
                `(N, *)`, where `N` denotes the batch size.

        Returns:
            (torch.Tensor, torch.Tensor): First tensor represents histogram counts,
                second tensor holds the bin edges. Both are on the input's device.
""" # NOTE ``batch_grad`` is 1/B ∇ℓᵢ so we need to compensate the 1/B B = batch_grad.shape[0] individual_gradients = B * batch_grad start, end = self._range individual_gradients = torch.clamp(individual_gradients, start, end) hist = torch.histc(individual_gradients, bins=self._bins, min=start, max=end) edges = torch.linspace(start, end, self._bins + 1, device=batch_grad.device) return hist, edges
class GradHist2d(SingleStepQuantity):
    """Quantity class for two-dimensional histograms over gradient and parameters.

    Tracks a two-dimensional histogram of individual gradient elements over
    parameters. Individual gradient values are binned on the x-axis, parameter
    values are binned on the y-axis.
    """

    def __init__(
        self,
        track_schedule,
        verbose=False,
        bins=(40, 50),
        range=((-1, 1), (-2, 2)),
        adapt=(None, None),
        keep_individual=False,
    ):
        """Initialize the 2D histogram of individual gradient elements over parameters.

        Args:
            track_schedule (callable): Function that maps the ``global_step``
                to a boolean, which determines if the quantity should be computed.
            verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
            bins ([int, int], optional): Number of bins in x and y direction.
                Default: ``(40, 50)``.
            range ([[float, float], [float, float]], optional): Bin limits in x and
                y direction. Default: ``((-1, 1), (-2, 2))``.
            adapt ([BinAdaptation or None, BinAdaptation or None], optional): Policy
                for adapting the bin limits in x and y direction. ``None`` indicates
                no adaptation. Default: ``(None, None)``.
            keep_individual (bool, optional): Whether to keep individual parameter
                histograms. Defaults to ``False``.
        """
        super().__init__(track_schedule, verbose=verbose)

        self._range = list(range)
        self._bins = bins
        self._adapt = [NoAdaptation(verbose=verbose) if a is None else a for a in adapt]
        self._keep_individual = keep_individual

    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        for adapt in self._adapt:
            ext += adapt.extensions(global_step)

        return ext

    def extension_hooks(self, global_step):
        """Return list of BackPACK extension hooks required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            [callable]: List of required BackPACK extension hooks for the current
                iteration.
        """
        hooks = []

        if self.should_compute(global_step):
            hooks.append(BatchGradTransformsHook({"hist_2d": self._compute_histogram}))

        for adapt in self._adapt:
            hooks += adapt.extension_hooks(global_step)

        return hooks

    def track(self, global_step, params, batch_loss):
        """Perform scheduled computations and store result.

        Args:
            global_step (int): The current iteration number.
            params ([torch.Tensor]): List of torch.Tensors holding the network's
                parameters.
            batch_loss (torch.Tensor): Mini-batch loss from current step.
        """
        super().track(global_step, params, batch_loss)

        # update limits
        for dim, adapt in enumerate(self._adapt):
            if adapt.should_compute(global_step):
                self._range[dim] = adapt.compute(
                    global_step, params, batch_loss, self._range[dim]
                )

    def _compute(self, global_step, params, batch_loss):
        """Aggregate histogram data over parameters and save to output."""
        hist = sum(p.grad_batch_transforms["hist_2d"][0] for p in params).detach()
        edges = params[0].grad_batch_transforms["hist_2d"][1]

        result = {"hist": hist, "edges": edges}

        if self._keep_individual:
            result["param_groups"] = len(params)

            for idx, p in enumerate(params):
                hist, edges = p.grad_batch_transforms["hist_2d"]
                hist = hist.detach()
                result[f"param_{idx}"] = {"hist": hist, "edges": edges}

        return result

    def _compute_histogram(self, batch_grad):
        """Transform individual gradients and parameters into a 2d histogram.

        Args:
            batch_grad (torch.Tensor): Individual gradient of a parameter `p`.
                If `p` is of shape `(*)`, the individual gradients have shape
                `(N, *)`, where `N` denotes the batch size.

        Returns:
            (torch.Tensor, (torch.Tensor, torch.Tensor)): First tensor represents
                histogram counts, second tuple holds the bin edges in x and y
                direction. All are on the input's device.
        """
        # NOTE ``batch_grad`` is 1/B ∇ℓᵢ so we need to compensate the 1/B
        B = batch_grad.shape[0]
        data = [B * batch_grad, batch_grad._param_weakref().data]

        for dim, data_dim in enumerate(data):
            lower, upper = self._range[dim]
            bins = self._bins[dim]

            # Histogram implementation does not include the limits, clip to bin center
            bin_size = (upper - lower) / bins
            data[dim] = torch.clamp(
                data_dim, min=lower + bin_size / 2, max=upper - bin_size / 2
            )

        return self.__hist_high_mem(*data)

    def __hist_high_mem(self, individual_gradients, param):
        """Compute histogram with memory-intensive strategy.

        Args:
            individual_gradients (torch.Tensor): Individual gradients, clipped to
                the histogram range.
            param (torch.Tensor): Parameter, clipped to the histogram range.

        Returns:
            (torch.Tensor, (torch.Tensor, torch.Tensor)): First tensor represents
                histogram counts, second tuple holds the bin edges in x and y
                direction. All are on the input's device.
        """
        batch_size = individual_gradients.shape[0]
        expand_arg = [batch_size] + param.dim() * [-1]
        data = torch.stack(
            (
                individual_gradients.flatten(),
                param.unsqueeze(0).expand(*expand_arg).flatten(),
            )
        )

        return histogram2d(data, bins=self._bins, range=self._range)
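
``GradHist2d`` delegates the joint binning to ``cockpit.quantities.utils_hists.histogram2d``, whose implementation is not shown on this page. The sketch below only illustrates the general technique with plain ``torch`` operations (bin indices via ``torch.bucketize`` combined with ``torch.bincount``); it is an assumption about how such a 2D histogram can be computed, not the library's ``histogram2d``, and the helper name ``hist2d_sketch`` as well as the fake inputs are made up for the example.

import torch


def hist2d_sketch(x, y, bins=(40, 50), limits=((-1, 1), (-2, 2))):
    """Joint histogram of two flattened tensors (illustrative only)."""
    (nx, ny), ((xlo, xhi), (ylo, yhi)) = bins, limits
    x_edges = torch.linspace(xlo, xhi, nx + 1)
    y_edges = torch.linspace(ylo, yhi, ny + 1)
    # Bin indices; the clamp folds boundary and out-of-range values into edge bins.
    x_idx = torch.clamp(torch.bucketize(x, x_edges, right=True) - 1, 0, nx - 1)
    y_idx = torch.clamp(torch.bucketize(y, y_edges, right=True) - 1, 0, ny - 1)
    counts = torch.bincount(x_idx * ny + y_idx, minlength=nx * ny).reshape(nx, ny)
    return counts, (x_edges, y_edges)


# Fake inputs mimicking ``__hist_high_mem``: individual gradients of shape (N, *)
# and the parameter broadcast along the batch dimension.
N, param = 8, torch.randn(3, 5)
individual_gradients = torch.randn(N, 3, 5)
hist, (x_edges, y_edges) = hist2d_sketch(
    individual_gradients.flatten(),
    param.unsqueeze(0).expand(N, -1, -1).flatten(),
)
print(hist.shape)  # torch.Size([40, 50])

Like this sketch, ``__hist_high_mem`` expands the parameter across the batch dimension before flattening, trading memory for a single vectorized binning pass over all (gradient, parameter) pairs.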