Source code for cockpit.instruments.mean_gsnr_gauge

"""Mean GSNR gauge."""

import warnings

from cockpit.instruments.utils_instruments import check_data, create_basic_plot


[docs]def mean_gsnr_gauge(self, fig, gridspec): """Mean GSNR gauge, showing the mean GSNR versus iteration. The mean GSNR describes the average gradient signal-to-noise-ratio. `Recent work <https://arxiv.org/abs/2001.07384>`_ used this quantity to study the generalization performances of neural networks, noting "that larger GSNR during training process leads to better generalization performance. The instrument shows the mean GSNR versus iteration, overlayed with an exponentially weighted average. **Preview** .. image:: ../../_static/instrument_previews/MeanGSNR.png :alt: Preview MeanGSNR Gauge **Requires** This instrument requires data from the :class:`~cockpit.quantities.MeanGSNR` quantity class. Args: self (CockpitPlotter): The cockpit plotter requesting this instrument. fig (matplotlib.figure.Figure): Figure of the Cockpit. gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be placed """ # Plot Trace vs iteration title = "Mean GSNR" # Check if the required data is available, else skip this instrument requires = ["MeanGSNR"] plot_possible = check_data(self.tracking_data, requires) if not plot_possible: if self.debug: warnings.warn( "Couldn't get the required data for the " + title + " instrument", stacklevel=1, ) return plot_args = { "x": "iteration", "y": "MeanGSNR", "data": self.tracking_data, "x_scale": "symlog" if self.show_log_iter else "linear", "y_scale": "linear", "cmap": self.cmap, "EMA": "y", "EMA_alpha": self.EMA_alpha, "EMA_cmap": self.cmap2, "title": title, "xlim": "tight", "ylim": None, "fontweight": "bold", "facecolor": self.bg_color_instruments, } ax = fig.add_subplot(gridspec) create_basic_plot(**plot_args, ax=ax)