TICTrace
- class cockpit.quantities.TICTrace(track_schedule, verbose=False, curvature='diag_h', epsilon=1e-07)
Quantity class for the TIC using the trace of curvature and gradient covariance.
Note: The Takeuchi Information Criterion (TIC) was rediscovered by
Thomas, V., et al., "On the interplay between noise and curvature and its effect on optimization and generalization" (2019). https://arxiv.org/abs/1906.07774
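To make the description concrete, the following is a minimal sketch of a trace-based TIC estimate, not cockpit's actual implementation. It assumes the quantity is the ratio of the trace of the (uncentered) gradient second moment to the trace of a diagonal curvature approximation, with epsilon added to the denominator for stability. The function name tic_trace, its arguments, and the exact normalization are illustrative assumptions.

```python
def tic_trace(individual_grads, curvature_diag, epsilon=1e-7):
    """Sketch of Tr(C) / (Tr(H) + epsilon).

    individual_grads: list of per-sample gradients (each a list of floats).
    curvature_diag:   diagonal of a curvature approximation (list of floats).
    Assumption: C is the uncentered second moment of the per-sample gradients.
    """
    n_samples = len(individual_grads)
    # Tr(C): mean over samples of the squared L2 norm of each gradient.
    cov_trace = sum(sum(g_i ** 2 for g_i in grad) for grad in individual_grads) / n_samples
    # Tr(H): sum of the diagonal curvature entries.
    curv_trace = sum(curvature_diag)
    # epsilon keeps the ratio finite when the curvature trace is (close to) zero.
    return cov_trace / (curv_trace + epsilon)


grads = [[1.0, 0.0], [0.0, 2.0]]
print(tic_trace(grads, [2.0, 3.0], epsilon=0.0))  # 0.5
```

With a zero curvature trace, the default epsilon avoids a division by zero, which is the role the stabilization constant plays in this sketch.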
Initialization sets the tracking schedule and creates the output dict.
Note
The curvature options "diag_h" and "diag_ggn_exact" are more expensive than "diag_ggn_mc", but more precise. For a classification task with C classes, the former require that C times more information be backpropagated through the computation graph.
- Parameters
track_schedule (callable) – Function that maps the global_step to a boolean, which determines if the quantity should be computed.
verbose (bool, optional) – Turns on verbose mode. Defaults to False.
curvature (str) – Which diagonal curvature approximation should be used. Options are "diag_h", "diag_ggn_exact", "diag_ggn_mc".
epsilon (float) – Stabilization constant. Defaults to 1e-7.
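As an illustration of the parameters above, the sketch below builds a track_schedule callable and passes it to the constructor using only the documented signature. The helper every_n_steps is hypothetical and not part of cockpit. The import is guarded so the schedule logic can be tried without the library installed.

```python
def every_n_steps(n):
    """Hypothetical helper: return a schedule that is True on every n-th step."""
    def schedule(global_step):
        return global_step % n == 0
    return schedule


# Compute the quantity at steps 0, 10, 20, ...
track_schedule = every_n_steps(10)
tracked_steps = [step for step in range(35) if track_schedule(step)]
print(tracked_steps)  # [0, 10, 20, 30]

try:
    from cockpit.quantities import TICTrace
except ImportError:
    TICTrace = None  # cockpit is not installed; only the schedule is demonstrated

if TICTrace is not None:
    # "diag_ggn_mc" is the cheaper, less precise curvature option.
    quantity = TICTrace(
        track_schedule,
        verbose=False,
        curvature="diag_ggn_mc",
        epsilon=1e-7,
    )
```

Choosing "diag_ggn_mc" here trades precision for cost; "diag_h" or "diag_ggn_exact" can be substituted when the extra backpropagation is affordable.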