Source code for cockpit.instruments.distance_gauge

"""Distance Gauge."""

import warnings

from cockpit.instruments.utils_instruments import check_data, create_basic_plot
from cockpit.quantities.utils_quantities import _root_sum_of_squares


def distance_gauge(self, fig, gridspec):
    """Distance gauge showing two different quantities related to distance.

    This instrument shows two quantities at once. Firstly, the :math:`L_2`-distance
    of the current parameters to their initialization. This describes the total
    distance that the optimization trajectory "has traveled so far" and can be seen
    via the blue-to-green dots (and the left y-axis).

    Secondly, the update sizes of individual steps are shown via the yellow-to-blue
    dots (and the right y-axis). It measures the distance that a single parameter
    update covers.

    Both quantities are overlaid with an exponentially weighted average.

    .. image:: ../../_static/instrument_previews/Distances.png
        :alt: Preview Distances Gauge

    **Requires**

    The distance instrument requires data from both, the
    :class:`~cockpit.quantities.UpdateSize` and the
    :class:`~cockpit.quantities.Distance` quantity class.

    **Side effects**

    Adds (or overwrites) the columns ``"Distance_all"`` and ``"UpdateSize_all"``
    in ``self.tracking_data``. If the required data is missing, nothing is
    plotted; a warning is emitted only when ``self.debug`` is set.

    Args:
        self (CockpitPlotter): The cockpit plotter requesting this instrument.
        fig (matplotlib.figure.Figure): Figure of the Cockpit.
        gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument
            should be placed

    Returns:
        None
    """
    # Plot Trace vs iteration
    title = "Distance"

    # Check if the required data is available, else skip this instrument
    requires = ["Distance", "UpdateSize"]
    if not check_data(self.tracking_data, requires):
        if self.debug:
            warnings.warn(
                "Couldn't get the required data for the " + title + " instrument",
                stacklevel=1,
            )
        return

    def _aggregate(entry):
        # List-valued entries are reduced to a single value; anything else
        # (e.g. an already-scalar or missing entry) is passed through untouched.
        # isinstance (rather than an exact type comparison) also accepts
        # list subclasses.
        return _root_sum_of_squares(entry) if isinstance(entry, list) else entry

    # Compute
    self.tracking_data["Distance_all"] = self.tracking_data.Distance.map(_aggregate)
    self.tracking_data["UpdateSize_all"] = self.tracking_data.UpdateSize.map(
        _aggregate
    )

    # Settings that are identical for both overlaid plots.
    shared_args = {
        "x": "iteration",
        "data": self.tracking_data,
        "y_scale": "linear",
        "x_scale": "symlog" if self.show_log_iter else "linear",
        "EMA": "y",
        "EMA_alpha": self.EMA_alpha,
        "xlim": "tight",
        "ylim": None,
    }

    # Left y-axis: distance from initialization. Only this plot sets the
    # title and the background of the instrument.
    ax = fig.add_subplot(gridspec)
    create_basic_plot(
        **shared_args,
        y="Distance_all",
        cmap=self.cmap,
        EMA_cmap=self.cmap2,
        title=title,
        fontweight="bold",
        facecolor=self.bg_color_instruments,
        ax=ax,
    )

    # Right y-axis: per-step update size, drawn on a twin axis with reversed
    # color maps so both quantities can be told apart.
    ax2 = ax.twinx()
    create_basic_plot(
        **shared_args,
        y="UpdateSize_all",
        cmap=self.cmap.reversed(),
        EMA_cmap=self.cmap2.reversed(),
        marker=",",
        ax=ax2,
    )