Advanced Example
In this second example, we look at a few more advanced features and ways to customize Cockpit.
Note
Just like before, you also need the utility file, which provides the training data, a convolutional network, and a logpath. You can copy all example files from our repository.
"""A slightly advanced example of using Cockpit with PyTorch for Fashion-MNIST."""

import torch
from _utils_examples import cnn, fmnist_data, get_logpath
from backpack import extend, extensions

from cockpit import Cockpit, CockpitPlotter, quantities
from cockpit.utils import schedules

# Build Fashion-MNIST classifier
fmnist_data = fmnist_data()
model = extend(cnn())  # Use a basic convolutional network
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
individual_loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="none"))

# Create SGD Optimizer
opt = torch.optim.SGD(model.parameters(), lr=5e-1)

# Create Cockpit and a plotter
# Customize the tracked quantities and their tracking schedule
quantities = [
    quantities.GradNorm(schedules.linear(interval=1)),
    quantities.Distance(schedules.linear(interval=1)),
    quantities.UpdateSize(schedules.linear(interval=1)),
    quantities.HessMaxEV(schedules.linear(interval=3)),
    quantities.GradHist1d(schedules.linear(interval=10), bins=10),
]
cockpit = Cockpit(model.parameters(), quantities=quantities)
plotter = CockpitPlotter()

# Main training loop
max_steps, global_step = 50, 0
for inputs, labels in iter(fmnist_data):
    opt.zero_grad()

    # forward pass
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    losses = individual_loss_fn(outputs, labels)

    # backward pass
    with cockpit(
        global_step,
        extensions.DiagHessian(),  # Other BackPACK quantities can be computed as well
        info={
            "batch_size": inputs.shape[0],
            "individual_losses": losses,
            "loss": loss,
            "optimizer": opt,
        },
    ):
        loss.backward(create_graph=cockpit.create_graph(global_step))

    # optimizer step
    opt.step()
    global_step += 1

    print(f"Step: {global_step:5d} | Loss: {loss.item():.4f}")

    if global_step % 10 == 0:
        plotter.plot(
            cockpit,
            savedir=get_logpath(),
            show_plot=False,
            save_plot=True,
            savename_append=str(global_step),
        )

    if global_step >= max_steps:
        break

# Write Cockpit to json file.
cockpit.write(get_logpath())

# Plot results from file
plotter.plot(
    get_logpath(),
    savedir=get_logpath(),
    show_plot=False,
    save_plot=True,
    savename_append="_final",
)
To run this example script, run
python 02_advanced_fmnist.py
This time, no Cockpit plot is shown. Instead, the plots are saved directly to files. Everything that Cockpit tracked during training is also stored, and both this log file and the plots can be inspected and analyzed after training has finished.
$ python 02_advanced_fmnist.py
Step:     1 | Loss: 2.2965
Step:     2 | Loss: 2.3185
Step:     3 | Loss: 2.2932
Step:     4 | Loss: 2.2893
Step:     5 | Loss: 2.2865
Step:     6 | Loss: 2.2780
Step:     7 | Loss: 2.2485
Step:     8 | Loss: 2.2600
Step:     9 | Loss: 2.1823
Step:    10 | Loss: 2.0557
[cockpit|plot] Saving figure in ~/logfiles/cockpit_output__primary10.png
Step:    11 | Loss: 2.0539
Step:    12 | Loss: 2.4792
[...]
We will now go over the main changes compared to the basic example. The relevant lines, including the most important changes, are highlighted above.
Network Architecture
model = extend(cnn())  # Use a basic convolutional network
In contrast to the basic example, we use a convolutional network instead of a dense one. For Cockpit this does not change anything! The network still needs to be extended, and this works in exactly the same way as before.
Customizing the Quantities
# Create Cockpit and a plotter
# Customize the tracked quantities and their tracking schedule
quantities = [
    quantities.GradNorm(schedules.linear(interval=1)),
    quantities.Distance(schedules.linear(interval=1)),
    quantities.UpdateSize(schedules.linear(interval=1)),
    quantities.HessMaxEV(schedules.linear(interval=3)),
    quantities.GradHist1d(schedules.linear(interval=10), bins=10),
]
cockpit = Cockpit(model.parameters(), quantities=quantities)
Cockpit lets you fully customize what you track and how often you track it.
Instead of using a pre-defined configuration, we define our own set of quantities here.
We use five quantities: GradNorm, Distance, UpdateSize, HessMaxEV, and GradHist1d.
Each quantity gets its own tracking schedule, e.g., we track the distance and update size in every step, but the largest eigenvalue of the Hessian only every third step and the gradient histogram only every tenth step.
We further customize the gradient histogram by specifying that we only want to use 10 bins.
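To make the effect of the different intervals concrete, here is a minimal sketch in plain Python. The function `linear` below is a hypothetical stand-in written for this illustration, not Cockpit's own `schedules.linear`; it assumes that a linear schedule fires at steps `offset + n * interval` and that `global_step` starts at 0, as in the training loop above. The helper `tracked_steps` is likewise only illustrative.

```python
def linear(interval, offset=0):
    """Hypothetical stand-in for a linear schedule (assumption, not Cockpit's code).

    The returned function fires at steps offset, offset + interval, ...
    """

    def schedule(global_step):
        shifted = global_step - offset
        return shifted >= 0 and shifted % interval == 0

    return schedule


def tracked_steps(schedule, num_steps):
    """Return the steps in [0, num_steps) at which the schedule fires."""
    return [step for step in range(num_steps) if schedule(step)]


# Intervals taken from the quantity list in the example
intervals = {"UpdateSize": 1, "HessMaxEV": 3, "GradHist1d": 10}

for name, interval in intervals.items():
    print(f"{name:>10}: {tracked_steps(linear(interval), num_steps=12)}")
```

Under these assumptions, within the first twelve steps the update size would be tracked at every step, the largest Hessian eigenvalue at steps 0, 3, 6, and 9, and the gradient histogram at steps 0 and 10. Expensive quantities can therefore be given a larger interval to keep the overhead low.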
Additional BackPACK Extensions
extensions.DiagHessian(),  # Other BackPACK quantities can be computed as well
Cockpit uses BackPACK for most of the background computations to extract more
information from a single backward pass. If you want to use additional BackPACK
extensions, you can simply pass them to the with cockpit() context.
Plotting Options
    if global_step % 10 == 0:
        plotter.plot(
            cockpit,
            savedir=get_logpath(),
            show_plot=False,
            save_plot=True,
            savename_append=str(global_step),
        )

    if global_step >= max_steps:
        break

# Write Cockpit to json file.
cockpit.write(get_logpath())

# Plot results from file
plotter.plot(
    get_logpath(),
    savedir=get_logpath(),
    show_plot=False,
    save_plot=True,
    savename_append="_final",
)
In this example, we create a Cockpit view every tenth iteration. Instead of showing it in real time, however, we save it directly to disk. At the end of the training process, we write all Cockpit information to a log file. We can then also plot from this file, which we do in the last step.
Writing and then plotting from a log file allows Cockpit to not only be used as a real-time diagnostic tool, but also to examine experiments later or compare multiple runs by contrasting their Cockpit views.
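The control flow that decides when a plot is saved can be sketched in isolation. The snippet below does not call Cockpit at all; it only replays the step counter of the training loop and collects the values that would be passed as savename_append. The names `plot_suffixes` and `plot_interval` are hypothetical and exist only for this illustration, and the sketch assumes the data loader provides at least `max_steps` batches.

```python
def plot_suffixes(max_steps, plot_interval=10):
    """Collect the savename_append values produced by the loop's control flow."""
    suffixes = []
    global_step = 0
    while True:
        # ... forward pass, backward pass and optimizer step would happen here ...
        global_step += 1

        # A plot is saved whenever the incremented step is a multiple of the interval
        if global_step % plot_interval == 0:
            suffixes.append(str(global_step))

        if global_step >= max_steps:
            break

    # The plot created from the log file after training
    suffixes.append("_final")
    return suffixes


print(plot_suffixes(max_steps=50))
```

With max_steps=50, this yields the suffixes 10, 20, 30, 40, and 50 for the intermediate plots, followed by _final for the plot created from the log file. The first of these matches the file name ending in 10 in the console output above.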
The final Cockpit plot is also saved and will look similar to the image below.
You can see that the CockpitPlotter only shows the instruments whose corresponding quantities have been tracked.