How to Visualize PyTorch Neural Networks - 3 Examples in Python

By: Dario Radečić
November 17, 2022

If you truly want to wrap your head around a deep learning model, visualizing it might be a good idea. These networks often have dozens of layers, and figuring out what's going on from the summary alone won't get you far. That's why today we'll show you 3 ways to visualize PyTorch neural networks.

We'll first build a simple feed-forward neural network model for the well-known <a href="https://gist.github.com/netj/8836201" target="_blank" rel="noopener">Iris dataset</a>. You'll see that visualizing model architectures isn't complicated at all and takes only a couple of lines of code.
Table of contents:
<ul><li><a href="#getting-started">Getting Started with PyTorch: Let's Build a Neural Network</a></li><li><a href="#torchviz">Torchviz: Visualize PyTorch Neural Networks With a Single Function Call</a></li><li><a href="#netron">Netron: Desktop App for Visualizing ONNX Models</a></li><li><a href="#tensorboard">TensorBoard: Visualize Machine Learning Workflow and Graphs</a></li><li><a href="#summary">Summing up How to Visualize PyTorch Neural Networks</a></li></ul>

<hr />

<h2 id="getting-started">Getting Started with PyTorch: Let's Build a Neural Network</h2>
Building a neural network model from scratch in PyTorch is easier than it sounds. Previous experience with the library is desirable, but not required - you'll have no trouble following if you prefer some other deep learning package.

We'll build a model around the Iris dataset for two reasons:
<ol><li><b>No data preparation is needed</b> - the dataset is simple to understand, clean, and ready for supervised machine learning classification.</li><li><b>You don't need a huge network to get accurate results</b> - which makes visualizing the network easier.</li></ol>
The code snippet below imports all Python libraries we'll need for now and loads in the dataset:
<pre><code class="language-python">import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

# Iris dataset: 150 rows, 4 numeric features, and a "variety" (species) column
iris = pd.read_csv("https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv")
iris.head()</code></pre>
<img class="wp-image-16095 size-full" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2f546c250f02d99708f_15037eaa_1-2.webp" alt="Image 1 - Head of the Iris dataset for the PyTorch neural network example" width="496" height="197" /> Image 1 - Head of the Iris dataset

Now, PyTorch models can't work with pandas DataFrames directly, so we'll have to convert the dataset into tensors.

The features of the dataset can be passed straight into the <code>torch.tensor()</code> function, while the target variable requires some encoding (from string to integer):
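<pre><code class="language-python"># Features: every column except the target, as a float32 tensor of shape (150, 4)
X = torch.tensor(iris.drop("variety", axis=1).values, dtype=torch.float)
# Target: encode species names as integers (Setosa=0, Versicolor=1, Virginica=2)
y = torch.tensor(
    [0 if vty == "Setosa" else 1 if vty == "Versicolor" else 2 for vty in iris["variety"]],
    dtype=torch.long
)

print(X[:3])
print()
print(y[:3])</code></pre>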
<img class="size-full wp-image-16097" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2f54d0cd7b298104144_bef4a3d6_2-2.webp" alt="Image 2 - Contents of the feature and target tensors" width="399" height="103" /> Image 2 - Contents of the feature and target tensors
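The nested conditional above silently maps any unexpected string to class 2. A slightly more defensive sketch, assuming the three species spellings used in the gist's <code>variety</code> column, uses an explicit mapping and fails loudly on unknown labels. The <code>SPECIES_TO_INDEX</code> dictionary and the <code>encode_species()</code> helper are hypothetical names, not part of the original code:

```python
import pandas as pd
import torch

# Hypothetical mapping; assumes the species spellings used in the Iris gist
SPECIES_TO_INDEX = {"Setosa": 0, "Versicolor": 1, "Virginica": 2}


def encode_species(varieties: pd.Series) -> torch.Tensor:
    """Map species names to integer class indices, rejecting unknown names."""
    codes = varieties.map(SPECIES_TO_INDEX)  # unknown names become NaN
    if codes.isna().any():
        unknown = sorted(varieties[codes.isna()].unique())
        raise ValueError(f"Unknown species labels: {unknown}")
    return torch.tensor(codes.to_numpy(), dtype=torch.long)


# Usage with a tiny hand-made Series (no download needed)
sample = pd.Series(["Setosa", "Virginica", "Versicolor"])
print(encode_species(sample))  # tensor([0, 2, 1])
```

With the real data, <code>y = encode_species(iris["variety"])</code> would replace the list comprehension.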

And that's it. The dataset is ready to be passed into a PyTorch neural network model. Let's build one next. It will have an input layer going from 4 features to 16 nodes, one hidden layer, and an output layer going from 16 nodes to 3 class scores (logits), one per species:
<pre><code class="language-python">class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 4 input features -> 16 hidden units
        self.input = nn.Linear(in_features=4, out_features=16)
        # 16 -> 16 hidden layer
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        # 16 -> 3 outputs, one raw score (logit) per species
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        # No softmax here; nn.CrossEntropyLoss expects raw logits during training
        return self.output(x)


model = Net()
print(model)</code></pre>
<img class="size-full wp-image-16099" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2f6e365060c184c6163_8693fd06_3-2.webp" alt="Image 3 - Summary of a neural network model" width="594" height="100" /> Image 3 - Summary of a neural network model
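Because the output layer returns raw logits rather than probabilities, applying a softmax is a quick way to confirm the network behaves as described. This sketch uses an <code>nn.Sequential</code> stand-in with the same layer sizes as <code>Net</code> and a random batch in place of the Iris tensor <code>X</code> (both assumptions for illustration). The probability values themselves are meaningless until the model is trained:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # reproducible random weights and inputs

# Stand-in with the same layer sizes as Net (4 -> 16 -> 16 -> 3)
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 3),
)
batch = torch.randn(5, 4)  # stand-in for 5 rows of the Iris feature tensor X

with torch.no_grad():
    logits = model(batch)             # shape (5, 3), arbitrary real values
    probs = F.softmax(logits, dim=1)  # shape (5, 3), each row sums to 1

print(logits.shape)         # torch.Size([5, 3])
print(probs.sum(dim=1))     # ones, up to floating-point rounding
print(probs.argmax(dim=1))  # predicted class index per row
```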

It's easy to read the summary of this model since there are only a couple of layers, but imagine a deep network with dozens of layers - all of a sudden, the summary would be too large to fit on the screen.
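When printing the whole model stops being practical, a table of parameter shapes and counts is one compact alternative that also gives you a total at a glance. This is a minimal sketch built only on <code>named_parameters()</code>; the <code>parameter_table()</code> helper is a hypothetical name, and <code>Net</code> is repeated so the block runs on its own. For the 4-16-16-3 network above, the total works out to (4 × 16 + 16) + (16 × 16 + 16) + (16 × 3 + 3) = 80 + 272 + 51 = 403 parameters:

```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same architecture as in the article
    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=4, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        return self.output(x)


def parameter_table(model: nn.Module) -> list[tuple[str, tuple[int, ...], int]]:
    """Return (name, shape, element count) for every trainable parameter."""
    return [
        (name, tuple(p.shape), p.numel())
        for name, p in model.named_parameters()
        if p.requires_grad
    ]


rows = parameter_table(Net())
for name, shape, count in rows:
    print(f"{name:<16} {str(shape):<9} {count:>4}")
print(f"{'total':<16} {'':<9} {sum(count for _, _, count in rows):>4}")
```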

In the following section, we'll explore the first way to visualize PyTorch neural networks: the Torchviz library.
<h2 id="torchviz">Torchviz: Visualize PyTorch Neural Networks With a Single Function Call</h2>
<a href="https://github.com/szagoruyko/pytorchviz" target="_blank" rel="noopener">Torchviz</a> is a Python package used to create visualizations of PyTorch execution graphs and traces. It depends on Graphviz, which you'll have to install system-wide (the macOS Homebrew command is shown below). Once Graphviz is installed, you can install Torchviz with pip:
<pre><code class="language-shell">brew install graphviz
pip install torchviz</code></pre>
To use Torchviz in Python, import the <code>make_dot()</code> function, create an instance of your neural network class, and run a forward pass on either the entire training set or a batch of samples. Since the Iris dataset is small, we'll pass in all 150 flower instances:
<pre><code class="language-python">from torchviz import make_dot

model = Net()
# Forward pass over all 150 samples; the output holds raw class scores (logits).
# A new name is used so the label tensor y from earlier isn't overwritten.
y_pred = model(X)</code></pre>
That's all you need to visualize the network. Pass the mean of the output tensor (a single scalar whose autograd graph covers the whole forward pass) alongside the model's named parameters to the <code>make_dot()</code> function. The parameter names are used to label the weight and bias nodes in the graph:
<pre><code class="language-python">make_dot(y_pred.mean(), params=dict(model.named_parameters()))</code></pre>
<img class="wp-image-16101 size-full" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2f7519e83c01afa4847_817eb63f_4-2.webp" alt="Image 4 - Visualizing a neural network model with torchviz (1)" width="578" height="806" /> Image 4 - Visualizing model with torchviz (1)

You can also see what autograd saves for the backward pass by setting two additional parameters, <code>show_attrs=True</code> and <code>show_saved=True</code>:
<pre><code class="language-python">make_dot(y_pred.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)</code></pre>
<img class="size-full wp-image-16103" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2f9a11449a24ce93e71_1489e78a_5-2.webp" alt="Image 5 - Visualizing model with torchviz (2)" width="796" height="1157" /> Image 5 - Visualizing model with torchviz (2)

This graph is considerably more detailed, which may be exactly what you need when inspecting the backward pass.

Next, we'll explore a desktop app you can use to visualize any ONNX model.
<h2 id="netron">Netron: Desktop App for Visualizing ONNX Models</h2>
<a href="https://github.com/lutzroeder/netron" target="_blank" rel="noopener">Netron</a> is a desktop and web viewer for neural network models from many different frameworks, including PyTorch. It works best if you export the model to the ONNX (Open Neural Network Exchange) format, which takes a single function call in PyTorch.

You can download the standalone desktop application or use the web version linked in the documentation. Netron can also be run from Python as a local server, but we haven't explored that option.

To get started, define the input and output names as lists of strings. The names are arbitrary; they only label the corresponding nodes in the exported graph. Then call the <code>torch.onnx.export()</code> function with the model, an example input, and the output file path:
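<pre><code class="language-python"># Arbitrary labels for the graph's input and output nodes
input_names = ["Iris"]
output_names = ["Iris Species Prediction"]

# X serves as the example input the exporter uses to capture the model's graph
torch.onnx.export(model, X, "model.onnx", input_names=input_names, output_names=output_names)</code></pre>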
The model is now saved to the <code>model.onnx</code> file, which you can open in Netron. Here's what it looks like:

<img class="wp-image-16105 size-full" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2fa75b180b12ad3c1ce_e973b5fa_6.gif" alt="Image 6 - Visualizing PyTorch neural network model with Netron" width="1002" height="734" /> Image 6 - Visualizing model with Netron

Let's explore another way to visualize PyTorch neural networks, one that TensorFlow users will find familiar.
<h2 id="tensorboard">TensorBoard: Visualize Machine Learning Workflow and Graphs</h2>
<a href="https://www.tensorflow.org/tensorboard" target="_blank" rel="noopener">TensorBoard</a> is a visualization toolkit for machine learning experimentation. It offers many features useful to deep learning researchers and practitioners, one of which is visualizing the model graph.

That's exactly the feature we'll explore today. But first, make sure to install TensorBoard through pip:
<pre><code class="language-shell">pip install tensorboard</code></pre>
So, how can you connect the PyTorch model with TensorBoard? You'll need the <code>SummaryWriter</code> class from <code>torch.utils.tensorboard</code>, which writes event files to a log directory. Its <code>add_graph()</code> method traces the model with an example input and records the resulting graph. In our example, the logs will be saved to the <code>torchlogs/</code> folder:
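<pre><code class="language-python">from torch.utils.tensorboard import SummaryWriter

# Event files are written to torchlogs/
writer = SummaryWriter("torchlogs/")
model = Net()
# Trace the model with the example input X and record its graph
writer.add_graph(model, X)
# Flush pending events to disk and release the file
writer.close()</code></pre>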
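<code>SummaryWriter</code> can also be used as a context manager, which closes the writer (flushing events to disk) even if an exception occurs. This sketch writes the graph into a timestamped subdirectory of <code>torchlogs/</code> so that separate runs don't mix; the directory layout and the <code>nn.Sequential</code> stand-in for <code>Net</code> are assumptions, not requirements:

```python
import glob
import os
from datetime import datetime

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Stand-in with the same layer sizes as Net (4 -> 16 -> 16 -> 3)
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 3),
)

# One subdirectory per run, e.g. torchlogs/20221117-101500 (hypothetical layout)
log_dir = os.path.join("torchlogs", datetime.now().strftime("%Y%m%d-%H%M%S"))

with SummaryWriter(log_dir) as writer:
    # Trace the model with an example batch and store its graph
    writer.add_graph(model, torch.randn(8, 4))

# Event files are named like events.out.tfevents.<timestamp>.<host>...
event_files = glob.glob(os.path.join(log_dir, "events.out.tfevents.*"))
print(f"Wrote {len(event_files)} event file(s) to {log_dir}")
```

Pointing TensorBoard at the parent <code>torchlogs/</code> folder then lists each subdirectory as a separate run.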
Once the network graph is saved, navigate to the log directory from the shell and launch TensorBoard:
<pre><code class="language-shell">cd &lt;path-to-logs-dir&gt;
tensorboard --logdir=./</code></pre>
<img class="size-full wp-image-16107" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2fbc51a64697dafbbf2_ec8781ce_7-2.webp" alt="Image 7 - Starting Tensorboard from the shell" width="1149" height="604" /> Image 7 - Starting Tensorboard from the shell

You'll find the model graph under the Graphs tab at <code>http://localhost:6006</code>. You can click on any graph element and TensorBoard will expand it for you, as shown in the figure below:

<img class="wp-image-16109 size-full" src="https://webflow-prod-assets.s3.amazonaws.com/6525256482c9e9a06c7a9d3c%2F65b7d2fb46c250f02d997a40_9c3a4b1b_8-2.webp" alt="Image 8 - Visualizing neural network model with Tensorboard" width="1303" height="1266" /> Image 8 - Visualizing model with Tensorboard

And that's it for the ways to visualize PyTorch neural networks. Let's make a short recap next.

<hr />

<h2 id="summary">Summing up How to Visualize PyTorch Neural Networks</h2>
If you want to understand what's going on in a neural network model, visualizing the network graph is the way to go. Of course, you still need to understand why the network is constructed the way it is, but that's fundamental deep learning knowledge we assume you have.
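We've explored three ways to visualize neural network models from PyTorch - with Torchviz, Netron, and TensorBoard. Each has its strengths: Torchviz draws the autograd graph with a single function call, Netron offers an interactive view of an exported ONNX file, and TensorBoard fits the graph into a broader experiment-tracking workflow. Let us know which one you prefer.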

Do you use some other tool to visualize neural network model graphs? Please let us know in our community Slack channel. Also, don't hesitate to move the discussion to Twitter - @appsilon. We'd love to hear from you.
