HessTrace
- class cockpit.quantities.HessTrace(track_schedule, verbose=False, curvature='diag_h')
Quantity class tracking the trace of the Hessian during training.
Initialization sets the tracking schedule & creates the output dict.
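The trace is the sum of the Hessian's diagonal entries, tr(H) = Σ_i H_ii. This is why a diagonal curvature approximation is enough to obtain it. The sketch below shows that relationship on a toy function. It assumes central finite differences, which are used here only for illustration and are not how this quantity computes the diagonal (the quantity backpropagates through the computation graph using the selected curvature option). The helper names `hessian_diagonal`, `hessian_trace` and `toy_loss` are hypothetical.

```python
from typing import Callable, List


def hessian_diagonal(
    f: Callable[[List[float]], float], x: List[float], h: float = 1e-4
) -> List[float]:
    """Approximate each H_ii = d^2 f / dx_i^2 with a central finite difference."""
    diag = []
    fx = f(x)
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += h
        x_minus[i] -= h
        diag.append((f(x_plus) - 2.0 * fx + f(x_minus)) / (h * h))
    return diag


def hessian_trace(
    f: Callable[[List[float]], float], x: List[float], h: float = 1e-4
) -> float:
    """tr(H) is the sum of the diagonal entries, so only the diagonal is needed."""
    return sum(hessian_diagonal(f, x, h))


# Toy loss with a known Hessian: [[3, 1, 0], [1, 1, 0], [0, 0, 4]].
# The off-diagonal cross term x0*x1 does not contribute to the trace: 3 + 1 + 4 = 8.
def toy_loss(x: List[float]) -> float:
    return 1.5 * x[0] ** 2 + x[0] * x[1] + 0.5 * x[1] ** 2 + 2.0 * x[2] ** 2


print(round(hessian_trace(toy_loss, [0.3, -1.2, 0.7]), 3))  # prints 8.0
```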
Note
The curvature options "diag_h" and "diag_ggn_exact" are more expensive than "diag_ggn_mc", but more precise. For a classification task with C classes, the former require that C times more information be backpropagated through the computation graph.
Parameters
- track_schedule (callable) – Function that maps the global_step to a boolean, which determines whether the quantity should be computed at that step.
- verbose (bool, optional) – Turns on verbose mode. Defaults to False.
- curvature (string, optional) – Which diagonal curvature approximation should be used. Options are "diag_h", "diag_ggn_exact", "diag_ggn_mc". Defaults to "diag_h".
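Any callable that takes the global_step and returns a bool can serve as track_schedule. The sketch below builds such a schedule with a hypothetical helper, `every_n_steps`, written from scratch for illustration. It is not a cockpit utility; the library may ship its own schedule helpers.

```python
from typing import Callable


def every_n_steps(n: int, offset: int = 0) -> Callable[[int], bool]:
    """Hypothetical helper: build a schedule that is True every n steps, starting at offset."""
    if n <= 0:
        raise ValueError("n must be positive")

    def schedule(global_step: int) -> bool:
        # Compute the quantity only on steps offset, offset + n, offset + 2n, ...
        return global_step >= offset and (global_step - offset) % n == 0

    return schedule


track_schedule = every_n_steps(10)
print([step for step in range(35) if track_schedule(step)])  # [0, 10, 20, 30]
```

Such a schedule would then be passed as the first argument, for example `HessTrace(track_schedule, curvature="diag_ggn_mc")`. This call uses the cheaper Monte Carlo approximation, trading precision for a smaller backpropagation cost as described in the note.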