Source code for cockpit.instruments.performance_gauge

"""Performance Gauge."""

import warnings

import seaborn as sns

from cockpit.instruments.utils_instruments import (
    _add_last_value_to_legend,
    check_data,
    create_basic_plot,
)


[docs]def performance_gauge(self, fig, gridspec): """Plotting train/valid accuracy vs. epoch and mini-batch loss vs. iteration. This instruments visualizes the currently most popular diagnostic metrics. It shows the mini-batch loss in each iteration (overlayed with an exponentially weighted average) as well as accuracies for both the training as well as the validation set. The current accuracy numbers are also shown in the legend. **Preview** .. image:: ../../_static/instrument_previews/Performance.png :alt: Preview Performance Gauge **Requires** This instrument visualizes quantities passed via the :func:`cockpit.Cockpit.log()` method. Args: self (CockpitPlotter): The cockpit plotter requesting this instrument. fig (matplotlib.figure.Figure): Figure of the Cockpit. gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be placed """ # Plot Trace vs iteration title = "Performance Plot" # Check if the required data is available, else skip this instrument requires = ["iteration", "Loss"] plot_possible = check_data(self.tracking_data, requires) if not plot_possible: if self.debug: warnings.warn( "Couldn't get the loss data for the " + title + " instrument", stacklevel=1, ) return # Mini-batch train loss plot_args = { "x": "iteration", "y": "Loss", "data": self.tracking_data, "EMA": "y", "EMA_alpha": self.EMA_alpha, "EMA_cmap": self.cmap2, "x_scale": "symlog" if self.show_log_iter else "linear", "y_scale": "linear", "cmap": self.cmap, "title": title, "xlim": "tight", "ylim": None, "fontweight": "bold", "facecolor": self.bg_color_instruments2, } ax = fig.add_subplot(gridspec) create_basic_plot(**plot_args, ax=ax) requires = ["iteration", "train_accuracy", "valid_accuracy"] plot_possible = check_data(self.tracking_data, requires) if not plot_possible: if self.debug: warnings.warn( "Couldn't get the accuracy data for the " + title + " instrument", stacklevel=1, ) return else: clean_accuracies = self.tracking_data[ ["iteration", "train_accuracy", "valid_accuracy"] ].dropna() # Train Accuracy plot_args = { "x": "iteration", "y": "train_accuracy", "data": clean_accuracies, } ax2 = ax.twinx() sns.lineplot( **plot_args, ax=ax2, label=plot_args["y"].title().replace("_", " "), linewidth=2, color=self.primary_color, ) # Train Accuracy plot_args = { "x": "iteration", "y": "valid_accuracy", "data": clean_accuracies, } sns.lineplot( **plot_args, ax=ax2, label=plot_args["y"].title().replace("_", " "), linewidth=2, color=self.secondary_color, ) # Customization ax2.set_ylim([0, 1]) ax2.set_ylabel("Accuracy") _add_last_value_to_legend(ax2, percentage=True)