How to Visualize PyTorch Neural Networks – 3 Examples in Python
If you truly want to wrap your head around a deep learning model, visualizing it might be a good idea. Deep networks often have dozens of layers, and figuring out what’s going on from the text summary alone won’t get you far. That’s why today we’ll show you 3 ways to visualize PyTorch neural networks.
We’ll first build a simple feed-forward neural network for the well-known Iris dataset. You’ll see that visualizing a model’s architecture isn’t complicated at all and takes only a couple of lines of code.
Table of contents:
- Getting Started with PyTorch: Let’s Build a Neural Network
- Torchviz: Visualize PyTorch Neural Networks With a Single Function Call
- Netron: Desktop App for Visualizing ONNX Models
- TensorBoard: Visualize Machine Learning Workflow and Graphs
- Summing up How to Visualize PyTorch Neural Networks
Getting Started with PyTorch: Let’s Build a Neural Network
Building a neural network model from scratch in PyTorch is easier than it sounds. Previous experience with the library is desirable, but not required – you’ll have no trouble following if you prefer some other deep learning package.
We’ll build a model around the Iris dataset for two reasons:
- No data preparation is needed – the dataset is simple to understand, clean, and ready for supervised machine learning classification.
- You don’t need a huge network to get accurate results – which makes visualizing the network easier.
The code snippet below imports all Python libraries we’ll need for now and loads in the dataset:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

iris = pd.read_csv("https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv")
iris.head()
```
Now, PyTorch models can’t work with pandas DataFrames directly, so we’ll have to convert the dataset into tensors.
The features of the dataset can be passed straight into the torch.tensor() function, while the target variable requires some encoding (from string to integer):
```python
X = torch.tensor(iris.drop("variety", axis=1).values, dtype=torch.float)
y = torch.tensor(
    [0 if vty == "Setosa" else 1 if vty == "Versicolor" else 2 for vty in iris["variety"]],
    dtype=torch.long
)

print(X[:3])
print()
print(y[:3])
```
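The nested conditional above maps any label that isn’t "Setosa" or "Versicolor" to 2, so a typo or an unexpected value in the CSV would silently become "Virginica". A minimal alternative sketch uses an explicit dictionary with pandas’ Series.map() and fails loudly on unknown labels. It assumes the three species names used in this CSV (Setosa, Versicolor, Virginica); the three sample rows and the names iris_sample and label_map are illustrative stand-ins so the snippet runs without downloading the file.

```python
import pandas as pd
import torch

# Illustrative rows in the same column layout as the Iris CSV used above
iris_sample = pd.DataFrame({
    "sepal.length": [5.1, 7.0, 6.3],
    "sepal.width": [3.5, 3.2, 3.3],
    "petal.length": [1.4, 4.7, 6.0],
    "petal.width": [0.2, 1.4, 2.5],
    "variety": ["Setosa", "Versicolor", "Virginica"],
})

# Hypothetical explicit label-to-index mapping (assumed to cover every species in the data)
label_map = {"Setosa": 0, "Versicolor": 1, "Virginica": 2}

# Series.map() turns any label missing from label_map into NaN
encoded = iris_sample["variety"].map(label_map)
if encoded.isna().any():
    unknown = set(iris_sample["variety"][encoded.isna()])
    raise ValueError(f"Unknown labels: {unknown}")

# Features as float32, targets as int64, matching the tensors built above
X = torch.tensor(iris_sample.drop("variety", axis=1).values, dtype=torch.float)
y = torch.tensor(encoded.values, dtype=torch.long)
print(X.shape, y)
```

On the full dataset you’d pass iris instead of iris_sample; the resulting tensors have the same dtypes as in the original snippet.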
And that’s it. The dataset is ready to be passed into a PyTorch neural network model. Let’s build one next. It will have an input layer going from 4 features to 16 nodes, one hidden layer with 16 nodes, and an output layer going from 16 nodes to 3 class scores (logits). The model doesn’t apply softmax, so its outputs aren’t probabilities yet:
```python
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=4, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        return self.output(x)


model = Net()
print(model)
```
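Because the output layer returns raw logits, you’d apply softmax yourself if you need class probabilities. The sketch below repeats the Net class so it runs on its own, feeds it a random batch of 5 rows (a placeholder for real Iris samples), and counts the trainable parameters, which is a quick sanity check on the architecture before visualizing it:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Same architecture as the Net class above
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=4, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        return self.output(x)


model = Net()

# Weights + biases per layer: (4*16 + 16) + (16*16 + 16) + (16*3 + 3) = 80 + 272 + 51 = 403
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Random placeholder batch: 5 samples with 4 features each
batch = torch.randn(5, 4)
logits = model(batch)             # shape (5, 3), unnormalized class scores
probs = F.softmax(logits, dim=1)  # shape (5, 3), each row sums to 1
print(n_params, logits.shape)
```

With the real data you’d call model(X) instead of using the random batch. During training you’d typically pass the logits straight to nn.CrossEntropyLoss, which applies log-softmax internally.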
It’s easy to read the summary of this model since there are only a couple of layers, but imagine you had a deep network with dozens of layers: all of a sudden, the summary would be too large to fit on the screen.
In the following section, we’ll explore the first way to visualize PyTorch neural networks, and that is with the Torchviz library.
Torchviz: Visualize PyTorch Neural Networks With a Single Function Call
Torchviz is a Python package used to create visualizations of PyTorch execution graphs and traces. It depends on Graphviz, which you’ll have to install system-wide (a macOS example with Homebrew is shown below). Once Graphviz is installed, you can install Torchviz with pip:
```shell
brew install graphviz
pip install torchviz
```
To use Torchviz in Python, you’ll have to import the make_dot() function, make an instance of your neural network class, and calculate the model’s outputs for the entire training set or a batch of samples. Since the Iris dataset is small, we’ll calculate predictions for all flower instances:
```python
from torchviz import make_dot

model = Net()
y = model(X)
```
That’s all you need to visualize the network. Simply pass the mean of the output tensor alongside the model parameters to the make_dot() function.
You can also see what autograd saves for the backward pass by specifying two additional parameters:
```python
make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
```
The graph is more detailed, but that may be exactly what you’re after.
Next, we’ll explore a desktop app used to visualize any ONNX model.
Netron: Desktop App for Visualizing ONNX Models
Netron is a desktop and web viewer for neural network models from many libraries, including PyTorch. It works best if you export the model to the ONNX (Open Neural Network Exchange) format, which takes a single function call in PyTorch.
You can download the standalone desktop application or use the web version linked in the documentation. There are also Python server options, but we haven’t explored them.
To get started, specify names for the model inputs and outputs as lists of strings. Feel free to name these however you want. Once done, call the torch.onnx.export() function to export the model to a file:
```python
input_names = ["Iris"]
output_names = ["Iris Species Prediction"]

torch.onnx.export(model, X, "model.onnx", input_names=input_names, output_names=output_names)
```
The model is now saved to the model.onnx file, and you can easily load it into Netron. Here’s what it looks like:
Let’s explore another way to visualize PyTorch neural networks, one that TensorFlow users will find familiar.
TensorBoard: Visualize Machine Learning Workflow and Graphs
TensorBoard is a visualization toolkit for machine learning experimentation. It has many features useful to deep learning researchers and practitioners, one of them being model graph visualization.
That’s exactly the feature we’ll explore today. But first, make sure to install TensorBoard through pip:
```shell
pip install tensorboard
```
So, how can you connect the PyTorch model with TensorBoard? You’ll need to use the SummaryWriter class from PyTorch and add the network graph to a log directory. In our example, the logs will be saved to the torchlogs/ directory:
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("torchlogs/")
model = Net()
writer.add_graph(model, X)
writer.close()
```
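SummaryWriter also works as a context manager, which closes the writer (and flushes the event file) even if add_graph() raises. The sketch below is an alternative to the explicit close() call; it reuses the torchlogs/ directory, repeats the Net class, uses a random placeholder input, and then lists the event files TensorBoard will read:

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


# Same architecture as the Net class above
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=4, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        return self.output(x)


model = Net()
X = torch.randn(150, 4)  # placeholder for the Iris feature tensor

log_dir = "torchlogs/"
# The writer is closed automatically when the with-block exits
with SummaryWriter(log_dir) as writer:
    writer.add_graph(model, X)  # traces the model with X and logs the graph

# TensorBoard reads files named events.out.tfevents.* from the log directory
event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(event_files)
```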
Once the network graph is saved, navigate to the log directory from the shell and launch TensorBoard:
```shell
cd <path-to-logs-dir>
tensorboard --logdir=./
```
You’ll be able to see the model graph at http://localhost:6006. You can click on any graph element and TensorBoard will expand it for you, as shown in the figure below:
And that’s it for the ways to visualize PyTorch neural networks. Let’s make a short recap next.
Summing up How to Visualize PyTorch Neural Networks
If you want to understand what’s going on in a neural network model, visualizing the network graph is the way to go. Sure, you need to actually understand why the network is constructed the way it is, but that’s fundamental deep learning knowledge we assume you have.
We’ve explored three ways to visualize neural network models from PyTorch – with Torchviz, Netron, and TensorBoard. All are excellent, and there’s no way to pick a winner. Let us know which one you prefer.
Do you use some other tool to visualize neural network model graphs? Please let us know in the comment section below. Also, don’t hesitate to move the discussion to Twitter – @appsilon. We’d love to hear from you.