Testing Your PyTorch Models with Torcheck

A convenient sanity check toolkit for PyTorch

Peng Yan
Towards Data Science

--

Photo by Scott Graham on Unsplash

Have you ever trained a PyTorch model for long hours, only to find that you had typed one line wrong in the model’s forward method? Have you ever obtained somewhat reasonable output from your model, yet were unsure whether that meant you had built the model right, or whether deep learning is just so powerful that even a wrong architecture produces decent results?

Speaking for myself, testing a deep learning model drives me crazy from time to time. The biggest pain points are:

  • Its black-box nature makes it hard to test. Making sense of the intermediate results takes considerable expertise, if it is possible at all.
  • The long training time greatly reduces the number of iterations.
  • There is no dedicated tool. Usually you will want to test your model on a small sample dataset, which means repeatedly writing boilerplate code for setting up the optimizer, calculating the loss, and running backpropagation.

To mitigate this overhead, I did some research and found a wonderful Medium post on this topic by Chase Roberts. The core idea is that we can never be one hundred percent sure that our model is correct, but at least it should be able to pass some sanity checks. In other words, these sanity checks are necessary but may not be sufficient.

To save you time, here is a summary of all the sanity checks he proposed:

  • A model parameter should always change during the training procedure, if it is not frozen on purpose. This can be a weight tensor for a PyTorch linear layer.
  • A model parameter should not change during the training procedure, if it is frozen. This can be a pre-trained layer you don’t want to update.
  • The range of model outputs should obey certain conditions, depending on the model. For example, if it is a classification model, its outputs should not all be in the range (0, 1); otherwise, it is highly likely that you incorrectly applied a softmax to the outputs before the loss calculation (a short snippet illustrating this follows the list).
  • (This one is not from that post, but is a common check.) A model parameter should not contain NaN (not a number) or Inf (infinity) values in most cases. The same applies to model outputs.
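
As a quick illustration of the softmax point: PyTorch’s nn.CrossEntropyLoss applies log-softmax internally, so it expects raw logits, which are not confined to (0, 1).

import torch
import torch.nn as nn

logits = torch.randn(4, 10)               # raw model outputs for 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))

criterion = nn.CrossEntropyLoss()         # applies log-softmax internally

loss = criterion(logits, targets)                      # correct: pass raw logits
bad_loss = criterion(logits.softmax(dim=1), targets)   # wrong: softmax applied before the loss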

In addition to proposing these checks, he also built a Python package that implements them. It is a nice package, but some pain points remain unsolved: it was built several years ago and is no longer maintained.

So, inspired by this idea of sanity checking and aiming for an easy-to-use Python package, I created torcheck! Its major innovations are:

  • No additional testing code is needed. Just add a few lines specifying the checks before training; torcheck takes over, performs the checks while training happens, and raises an informative error message when a check fails.
  • Checking your model at different levels is possible. Instead of checking the whole model, you can specify checks for a submodule, a linear layer, or even a weight tensor! This enables more customized checks for complicated architectures.

Next, we will give you a quick tutorial on torcheck.

Suppose we have coded up a ConvNet model for classifying the MNIST dataset. The full training routine looks like this:
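
The original gist is not reproduced here, so below is a minimal, hypothetical stand-in that matches the layer names used later in this article (conv1 through conv3, fc1, fc2); the exact architecture, hyperparameters, and data pipeline are assumptions, not the author’s original code.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 3 * 3, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        # In the original gist, one of these lines mistakenly reused x
        # on the right-hand side where output was intended.
        output = F.max_pool2d(F.relu(self.conv1(x)), 2)       # 28x28 -> 14x14
        output = F.max_pool2d(F.relu(self.conv2(output)), 2)  # 14x14 -> 7x7
        output = F.max_pool2d(F.relu(self.conv3(output)), 2)  # 7x7 -> 3x3
        output = output.flatten(start_dim=1)
        output = F.relu(self.fc1(output))
        return self.fc2(output)

train_dataset = datasets.MNIST(
    "data", train=True, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

model = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(2):
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()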

There is actually a subtle error in the model code, and some of you may have noticed it: on line 16 of the original gist, we carelessly put x on the right-hand side, where it should be output instead.

Now let’s see how torcheck helps you detect this hidden error!

Step 0: Installation

Before we start, install the package with a single command:

$ pip install torcheck

Step 1: Adding torcheck code

Next, we will add the torcheck code. It always resides after your model and optimizer instantiation and right before the training for loop, as shown below:

Step 1.1: Registering your optimizer(s)

First, register your optimizer(s) with torcheck:

torcheck.register(optimizer)

Step 1.2: Adding sanity checks

Next, add all the checks you want to perform, grouped into four categories.

1. Parameters change/not change

For our example, we want all the model parameters to change during the training procedure. Adding the check is simple:

# check all the model parameters will change
# module_name is optional, but it makes error messages more informative when checks fail
torcheck.add_module_changing_check(model, module_name="my_model")

Side Note

To demonstrate the full capability of torcheck, let’s say you later freeze the convolutional layers and only want to fine-tune the linear layers. Adding checks in this situation would look like this:

# check the first convolutional layer's parameters won't change
torcheck.add_module_unchanging_check(model.conv1, module_name="conv_layer_1")
# check the second convolutional layer's parameters won't change
torcheck.add_module_unchanging_check(model.conv2, module_name="conv_layer_2")
# check the third convolutional layer's parameters won't change
torcheck.add_module_unchanging_check(model.conv3, module_name="conv_layer_3")
# check the first linear layer's parameters will change
torcheck.add_module_changing_check(model.fc1, module_name="linear_layer_1")
# check the second linear layer's parameters will change
torcheck.add_module_changing_check(model.fc2, module_name="linear_layer_2")
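
As a side note on the freezing itself (which torcheck does not do for you), one common way to freeze the convolutional layers in plain PyTorch is sketched below; the layer names follow the hypothetical model above.

# freeze the convolutional layers
for layer in (model.conv1, model.conv2, model.conv3):
    for param in layer.parameters():
        param.requires_grad = False

# hand only the still-trainable (linear) parameters to the optimizer
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)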

2. Output range check

Since our model is a classification model, we want to add the check mentioned earlier: model outputs should not all be in the range (0, 1).

# check model outputs are not all within (0, 1)
# aka softmax hasn't been applied before loss calculation
torcheck.add_module_output_range_check(
    model,
    output_range=(0, 1),
    negate_range=True,
)

The negate_range=True argument carries the meaning of “not all”. If you simply want to check that model outputs are all within a certain range, just remove that argument.

Although not applicable to our example, torcheck enables you to check the intermediate outputs of submodules as well.
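
For example, the same function shown above can be attached to a single submodule; the submodule and range below are purely illustrative.

# check that the first linear layer's intermediate outputs stay within an assumed range
torcheck.add_module_output_range_check(
    model.fc1,
    output_range=(-100, 100),
)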

3. NaN check

We definitely want to make sure model parameters don’t become NaN during training, and model outputs don’t contain NaN. Adding the NaN check is simple:

# check whether model parameters become NaN or outputs contain NaN
torcheck.add_module_nan_check(model)

4. Inf check

Similarly, add the Inf check:

# check whether model parameters become infinite or outputs contain infinite values
torcheck.add_module_inf_check(model)

After adding all the checks of interest, the final training code looks like this:
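
The original gist is not shown here either, but a condensed sketch of how the pieces fit together (reusing the hypothetical model, loader, and criterion from before) might look like this:

import torcheck

model = ConvNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# torcheck setup: register the optimizer and attach the checks
torcheck.register(optimizer)
torcheck.add_module_changing_check(model, module_name="my_model")
torcheck.add_module_output_range_check(
    model,
    output_range=(0, 1),
    negate_range=True,
)
torcheck.add_module_nan_check(model)
torcheck.add_module_inf_check(model)

# the training loop itself is unchanged
for epoch in range(2):
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()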

Step 2: Training and fixing

Now let’s run the training as usual and see what happens:

$ python run.py
Traceback (most recent call last):
(stack trace information here)
RuntimeError: The following errors are detected while training:
Module my_model's conv1.weight should change.
Module my_model's conv1.bias should change.

Bang! We immediately get an error message saying that our model’s conv1.weight and conv1.bias don’t change. There must be something wrong with model.conv1.

As expected, we head over to the model code, notice the error, fix it, and rerun the training. Now everything works like a charm :)

(Optional) Step 3: Turning off checks

Yay! Our model has passed all the checks. To get rid of them, we can simply call

torcheck.disable()

This is useful when you want to run your model on a validation set, or you just want to remove the checking overhead from your model training.

If you ever want to turn on the checks again, just call

torcheck.enable()
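
For instance, a typical pattern is to disable the checks around an evaluation pass and re-enable them before resuming training; val_loader below is a hypothetical validation DataLoader.

# skip the sanity checks while evaluating on the validation set
torcheck.disable()
model.eval()
with torch.no_grad():
    for data, target in val_loader:
        val_loss = criterion(model(data), target)

# resume training with the checks back on
model.train()
torcheck.enable()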

And that’s pretty much all you need to start using torcheck. For a more comprehensive introduction, please refer to its GitHub page.

As the package is still at an early stage, there may be bugs here and there. Please do not hesitate to report them. Any contribution is welcome.

I hope you find this package useful and that it makes you more confident about your black-box models!
