HessTrace
class cockpit.quantities.HessTrace(track_schedule, verbose=False, curvature='diag_h')
Quantity class tracking the trace of the Hessian during training.
Initialization sets the tracking schedule & creates the output dict.
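The trace of the Hessian is the sum of its diagonal entries, so a diagonal curvature approximation is all that is needed to compute it. The sketch below is an illustration of that identity, not cockpit's implementation. It assumes a hypothetical quadratic loss 0.5 * w^T A w, whose Hessian is exactly A, and estimates the diagonal with central finite differences. The names `loss`, `hessian_diagonal`, and `A` are made up for this example.

```python
# Illustrative sketch: Hessian trace = sum of the Hessian's diagonal entries.
# The quadratic loss, the matrix A and the helper names are hypothetical.
from typing import Callable, List

# Symmetric matrix; for loss(w) = 0.5 * w^T A w the Hessian is exactly A.
A = [[4.0, 1.0, 0.0],
     [1.0, 3.0, 0.5],
     [0.0, 0.5, 2.0]]


def loss(w: List[float]) -> float:
    """Quadratic loss 0.5 * w^T A w."""
    n = len(w)
    return 0.5 * sum(w[i] * A[i][j] * w[j] for i in range(n) for j in range(n))


def hessian_diagonal(f: Callable[[List[float]], float], w: List[float],
                     eps: float = 1e-3) -> List[float]:
    """Estimate d^2 f / d w_i^2 for every i with central finite differences."""
    f_w = f(w)
    diag = []
    for i in range(len(w)):
        plus, minus = list(w), list(w)
        plus[i] += eps
        minus[i] -= eps
        diag.append((f(plus) - 2.0 * f_w + f(minus)) / eps**2)
    return diag


w = [0.3, -1.2, 0.7]
diag = hessian_diagonal(loss, w)
trace = sum(diag)  # only the diagonal is needed for the trace
print(round(trace, 3))  # close to 4 + 3 + 2 = 9
```

For a real network the same sum runs over the diagonal entries of every parameter tensor, so the off-diagonal curvature never has to be formed.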
Note
The curvature options "diag_h" and "diag_ggn_exact" are more expensive than "diag_ggn_mc", but more precise. For a classification task with C classes, the former require that C times more information be backpropagated through the computation graph.
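Where the factor of C comes from can be seen at the level of the network outputs, assuming a softmax cross-entropy loss. The sketch below is a generic illustration, not cockpit's code, and it covers only the GGN options ("diag_h" is not modeled). The loss Hessian with respect to the logits is diag(p) - p p^T, a C x C matrix that an exact GGN must factor into C vectors to backpropagate. A Monte Carlo estimate instead samples a label y ~ p and backpropagates the single vector g = p - e_y, whose outer product has expectation diag(p) - p p^T. All function names are hypothetical.

```python
# Illustrative sketch, assuming softmax cross-entropy: exact output Hessian
# (C backpropagated vectors) versus a one-vector-per-draw Monte Carlo estimate.
import math
import random
from typing import List

Matrix = List[List[float]]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_output_hessian(p: List[float]) -> Matrix:
    """Cross-entropy Hessian w.r.t. the logits: diag(p) - p p^T (C x C)."""
    C = len(p)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(C)]
            for i in range(C)]


def mc_output_hessian(p: List[float], num_samples: int,
                      rng: random.Random) -> Matrix:
    """Average of g g^T with g = p - e_y and y ~ Categorical(p).

    Since E[g g^T] = diag(p) - p p^T, each draw needs only one vector.
    """
    C = len(p)
    est = [[0.0] * C for _ in range(C)]
    for _ in range(num_samples):
        y = rng.choices(range(C), weights=p)[0]
        g = [p[k] - (1.0 if k == y else 0.0) for k in range(C)]
        for i in range(C):
            for j in range(C):
                est[i][j] += g[i] * g[j] / num_samples
    return est


p = softmax([2.0, 0.5, -1.0, 0.0])
exact = exact_output_hessian(p)
approx = mc_output_hessian(p, num_samples=20_000, rng=random.Random(0))
C = len(p)
max_err = max(abs(exact[i][j] - approx[i][j]) for i in range(C) for j in range(C))
print(f"max |exact - MC| = {max_err:.4f}")
```

With one Monte Carlo draw per sample, the estimate is cheaper by roughly a factor of C but noisy, which matches the precision trade-off described in the note.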
Parameters
- track_schedule (callable) – Function that maps the global_step to a boolean, which determines if the quantity should be computed.
- verbose (bool, optional) – Turns on verbose mode. Defaults to False.
- curvature (string, optional) – Which diagonal curvature approximation should be used. Options are "diag_h", "diag_ggn_exact", "diag_ggn_mc". Defaults to "diag_h".
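Any callable from the global step to a boolean can serve as track_schedule. The sketch below assumes a hypothetical factory `every_n_steps` (not part of cockpit) that fires every n steps, optionally starting at an offset. The constructor call is left as a comment so the snippet runs without cockpit installed.

```python
# Hypothetical schedule factory: returns a callable mapping global_step -> bool.
from typing import Callable


def every_n_steps(n: int, offset: int = 0) -> Callable[[int], bool]:
    """Schedule that fires at offset, offset + n, offset + 2n, ..."""
    if n <= 0:
        raise ValueError("n must be positive")

    def schedule(global_step: int) -> bool:
        return global_step >= offset and (global_step - offset) % n == 0

    return schedule


schedule = every_n_steps(10)
print([step for step in range(35) if schedule(step)])  # [0, 10, 20, 30]

# With cockpit installed, the callable would be passed to the constructor:
# HessTrace(track_schedule=schedule, curvature="diag_ggn_mc")
```

A sparse schedule such as this keeps the cost of the more expensive curvature options bounded, since the quantity is only computed on the steps where the schedule returns True.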