Source code for cockpit.instruments.alpha_gauge

"""Alpha Gauge."""

import math
import warnings

import numpy as np
import seaborn as sns
from scipy import stats

from cockpit.instruments.utils_instruments import _beautify_plot, check_data


[docs]def alpha_gauge(self, fig, gridspec): r"""Showing a distribution of the alpha values. This alpha instruments provides a summary of the alpha values of all tracked iterations (shown in gray) as well as the last 10% of them (shown in blue). The alpha value uses loss and gradient information to build a noise-informed univariate quadratic approximation of the loss function to assess to which point on this parabola the optimizer moves. The parabola (shown in orange) is always normalized such that the starting point of each iteration is at :math:`x=-1`. If the optimzer takes a single step to the local minimum (e.g. the valley floor) it would be indicated by an :math:`\alpha` value of :math:`0`. Analogously, taking a step to the exact opposite side of the valley, will be descrbied by :math:`\alpha=1`. The instruments shows a histogram and a distribution fit of all alpha values, as well as the last 10% of tracked iterations. The mean values of both distributions are also shown in the figure's legend. For a more detailed explanation of the alpha value and why - perhaps against intuition - values larger than zero might be desirable, have a look at the Cockpit paper: - `Schneider, F., Dangel, F., & Hennig, P., Cockpit: A Practical Debugging Tool for Training Deep Neural Networks (2021). <https://arxiv.org/abs/2102.06604>`_ .. image:: ../../_static/instrument_previews/Alpha.png :alt: Preview Alpha Gauge **Requires** This instrument requires data from the :class:`~cockpit.quantities.Alpha` quantity class. Args: self (CockpitPlotter): The cockpit plotter requesting this instrument. fig (matplotlib.figure.Figure): Figure of the Cockpit. gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be placed. """ # Plot Alpha Distribution title = "Alpha Distribution" # Check if the required data is available, else skip this instrument requires = ["Alpha"] plot_possible = check_data(self.tracking_data, requires, min_elements=2) if not plot_possible: if self.debug: warnings.warn( "Couldn't get the required data for the " + title + " instrument", stacklevel=1, ) return plot_args = { "xlabel": "Local Step Length", "ylabel": "Stand. Loss", "title": title, "xlim": [-1.5, 1.5], "ylim": [0, 1.75], "fontweight": "bold", "facecolor": self.bg_color_instruments, "zero_lines": True, "center": [True, False], } color_all = "gray" color_last = self.primary_color color_parabola = self.secondary_color ax = fig.add_subplot(gridspec) # Plot unit parabola x = np.linspace(plot_args["xlim"][0], plot_args["xlim"][1], 100) y = x ** 2 ax.plot(x, y, linewidth=2, color=color_parabola) _beautify_plot(**plot_args, ax=ax) # Alpha Histogram ax2 = ax.twinx() # All alphas sns.histplot( self.tracking_data["Alpha"].dropna(), ax=ax2, kde=True, color=color_all, kde_kws={"cut": 10}, alpha=0.5, stat="probability", label="all", ) (mu_all, _) = stats.norm.fit(self.tracking_data["Alpha"].dropna()) # Last 10% alphas len_last_elements = int(len(self.tracking_data["Alpha"]) / 10) sns.histplot( self.tracking_data["Alpha"].dropna().tail(len_last_elements), ax=ax2, kde=True, color=color_last, kde_kws={"cut": 10}, alpha=0.5, stat="probability", label="last 10 %", ) if len_last_elements == 0: mu_last = math.nan else: (mu_last, _) = stats.norm.fit( self.tracking_data["Alpha"].dropna().tail(len_last_elements) ) # Manually beautify the plot: # Adding Zone Lines ax.axvline(0, ls="-", color="#ababba", linewidth=1.5, zorder=0) ax.axvline(-1, ls="-", color="#ababba", linewidth=1.5, zorder=0) ax.axvline(1, ls="-", color="#ababba", linewidth=1.5, zorder=0) ax.axhline(0, ls="-", color="#ababba", linewidth=1.5, zorder=0) ax.axhline(1, ls="-", color="#ababba", linewidth=0.5, zorder=0) # Labels ax.set_xlabel(r"Local step length $\alpha$") ax2.set_ylabel(r"$\alpha$ density") _add_indicators( self, ax, mu_last, plot_args, color_all, color_last, len_last_elements ) # Legend # Get the fitted parameters used by sns lines2, labels2 = ax2.get_legend_handles_labels() for idx, lab in enumerate(labels2): if "all" in lab and not math.isnan(mu_all): labels2[idx] = lab + " ($\mu=${0:.2f})".format(mu_all) # noqa: W605 if "last 10 %" in lab and not math.isnan(mu_last): labels2[idx] = lab + " ($\mu=${0:.2f})".format(mu_last) # noqa: W605 ax2.legend(lines2, labels2)
def _add_indicators( self, ax, mu_last, plot_args, color_all, color_last, len_last_elements ): """Adds indicators that some alpha values were outside of ploting range.""" # Add indicator for outliers if ( not math.isnan(mu_last) and max(self.tracking_data["Alpha"].dropna().tail(len_last_elements)) > plot_args["xlim"][1] ): ax.annotate( "", xy=(1.8, 0.3), xytext=(1.7, 0.3), size=20, arrowprops=dict(color=color_last), ) elif max(self.tracking_data["Alpha"]) > plot_args["xlim"][1]: ax.annotate( "", xy=(1.8, 0.3), xytext=(1.7, 0.3), size=20, arrowprops=dict(color=color_all), ) if ( not math.isnan(mu_last) and min(self.tracking_data["Alpha"].dropna().tail(len_last_elements)) < plot_args["xlim"][0] ): ax.annotate( "", xy=(-1.8, 0.3), xytext=(-1.7, 0.3), size=20, arrowprops=dict(color=color_last), ) elif min(self.tracking_data["Alpha"]) < plot_args["xlim"][0]: ax.annotate( "", xy=(-1.8, 0.3), xytext=(-1.7, 0.3), size=20, arrowprops=dict(color=color_all), )