HessTrace

class cockpit.quantities.HessTrace(track_schedule, verbose=False, curvature='diag_h')

Quantity class tracking the trace of the Hessian during training.

Initialization stores the tracking schedule and creates the output dictionary.
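For reference, the trace depends only on the diagonal of the Hessian. That is why a diagonal curvature approximation (the curvature parameter) is enough to compute it. In the formula below, the split of the parameters θ into tensors θ^(l) is notation introduced here for illustration. If the Hessian diagonal is replaced by a GGN diagonal ("diag_ggn_exact" or "diag_ggn_mc"), the same sum gives the trace of the generalized Gauss-Newton matrix, or a Monte Carlo estimate of it, instead.

```latex
\operatorname{tr}\!\left(\nabla_\theta^2 \mathcal{L}\right)
  = \sum_{i=1}^{D} \frac{\partial^2 \mathcal{L}}{\partial \theta_i^2}
  = \sum_{l} \mathbf{1}^\top \operatorname{diag}\!\left(\nabla_{\theta^{(l)}}^2 \mathcal{L}\right)
```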

Note

The curvature options "diag_h" and "diag_ggn_exact" are more precise than "diag_ggn_mc", but also more expensive. For a classification task with C classes, they require backpropagating C times more information through the computation graph.
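The C-fold cost difference can be illustrated with a toy model. The following is a minimal pure-Python sketch, not cockpit's or BackPACK's implementation. It assumes a single linear layer without bias, a single input, and a softmax cross-entropy loss. For this model the Hessian with respect to W coincides with the GGN, so the "exact" trace is also the Hessian trace. The exact variant propagates all C columns of a factor S with S Sᵀ = diag(p) − p pᵀ. The Monte Carlo variant mirrors the idea behind "diag_ggn_mc" by propagating one column (e_y − p) per sample, with y drawn from the predicted distribution. All function names are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax of a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_setup(num_classes=4, num_features=3, seed=0):
    """Random weights W (C x D) and one input x (length D) for a toy linear classifier."""
    rng = random.Random(seed)
    W = [[rng.gauss(0, 1) for _ in range(num_features)] for _ in range(num_classes)]
    x = [rng.gauss(0, 1) for _ in range(num_features)]
    return W, x


def probs(W, x):
    """Predicted class probabilities p = softmax(W x)."""
    logits = [sum(w_cj * x_j for w_cj, x_j in zip(row, x)) for row in W]
    return softmax(logits)


def trace_exact(W, x):
    """Exact trace: one propagated column per class, i.e. C columns in total."""
    p = probs(W, x)
    C = len(p)
    x_sq = sum(x_j ** 2 for x_j in x)
    trace = 0.0
    for k in range(C):
        # Column k of S: sqrt(p_k) * (e_k - p)
        col = [math.sqrt(p[k]) * ((1.0 if c == k else 0.0) - p[c]) for c in range(C)]
        # Diagonal entry for W[c][j] is x_j^2 * col[c]^2; summing over j gives ||x||^2
        trace += x_sq * sum(s ** 2 for s in col)
    return trace


def trace_mc(W, x, rng, num_samples=1):
    """MC estimate: one propagated column (e_y - p) per sample, with y ~ Cat(p)."""
    p = probs(W, x)
    C = len(p)
    x_sq = sum(x_j ** 2 for x_j in x)
    total = 0.0
    for _ in range(num_samples):
        y = rng.choices(range(C), weights=p)[0]
        col = [(1.0 if c == y else 0.0) - p[c] for c in range(C)]
        total += x_sq * sum(s ** 2 for s in col)
    return total / num_samples


W, x = toy_setup()
print(trace_exact(W, x), trace_mc(W, x, random.Random(1), num_samples=1))
```

With C = 4, `trace_exact` propagates four columns while a single-sample `trace_mc` propagates one. The MC result is unbiased: averaged over many samples, it approaches the exact value ‖x‖²(1 − Σ_c p_c²). A single sample, however, can be far off.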

Parameters
  • track_schedule (callable) – Function that maps the global_step to a boolean, which determines if the quantity should be computed.

  • verbose (bool, optional) – Turns on verbose mode. Defaults to False.

  • curvature (str, optional) – Which diagonal curvature approximation should be used. Options are "diag_h", "diag_ggn_exact", "diag_ggn_mc". Defaults to "diag_h".
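The track_schedule argument is any callable that maps the global step to a boolean. The following minimal sketch builds such a callable, which fires every `interval` steps starting at `offset`. The helper name `every_n_steps` and its parameters are hypothetical, chosen for illustration. They are not part of cockpit's API.

```python
def every_n_steps(interval, offset=0):
    """Hypothetical helper: build a track_schedule that is True at offset, offset + interval, ..."""
    if interval <= 0:
        raise ValueError("interval must be positive")

    def track_schedule(global_step):
        # Compute the quantity only on steps aligned with the interval
        return global_step >= offset and (global_step - offset) % interval == 0

    return track_schedule


schedule = every_n_steps(10)
print([step for step in range(35) if schedule(step)])  # [0, 10, 20, 30]
```

Following the signature above, the resulting callable could then be passed as `HessTrace(track_schedule=schedule, curvature="diag_ggn_mc")`.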