November 14, 2022

A Gentle Introduction to PyTorch for Beginners (2023)

When doing machine learning with Python, you have multiple options for which library or framework to use. However, if you're moving toward deep learning, you should probably use either TensorFlow or PyTorch, the two best-known deep learning frameworks.

In this article, we'll go through a quick introduction to the PyTorch framework, going all the way from the initial concepts to training and testing a first image classification model.

We won't dive deep into complex concepts and mathematics, as this article is intended as a hands-on guide to getting started with PyTorch as a tool, not with deep learning as a concept.

Therefore, we assume you have some intermediate Python knowledge--including classes and object-oriented programming--and you're familiar with the main concepts of deep learning.

PyTorch

PyTorch is a powerful, yet easy-to-use deep learning library for Python, mainly used for applications such as computer vision and natural language processing.

While TensorFlow was developed by Google, PyTorch was developed by Facebook's AI Research lab (FAIR). Management of the framework has recently moved to the newly created PyTorch Foundation, which operates under the Linux Foundation.

PyTorch is flexible, making it easy to integrate new data types and algorithms. It is also efficient and scalable: it was designed to minimize the number of computations required and to run on a variety of hardware architectures.

Tensors

In deep learning, tensors are a fundamental data structure that is very similar to arrays and matrices, with which we can efficiently perform mathematical operations on large sets of data. A tensor can be represented as a matrix, but also as a vector, a scalar, or a higher-dimensional array.

To make it easier to visualize, you can think of a tensor as a simple array containing scalars or other arrays. In PyTorch, a tensor is very similar to a NumPy ndarray, with the difference that tensors can run on a GPU, which can dramatically speed up computation.
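To make the idea of dimensions concrete, here is a small sketch showing a scalar, a vector, a matrix, and a three-dimensional tensor. The ndim attribute reports how many dimensions each one has; the variable names are just illustrative.

```python
import torch

scalar = torch.tensor(3.14)              # 0 dimensions: a single number
vector = torch.tensor([1.0, 2.0, 3.0])   # 1 dimension
matrix = torch.tensor([[1, 2], [3, 4]])  # 2 dimensions
cube = torch.zeros(2, 3, 4)              # 3 dimensions: two 3x4 matrices stacked

for name, tensor in [("scalar", scalar), ("vector", vector),
                     ("matrix", matrix), ("cube", cube)]:
    # ndim is the number of dimensions; shape is the size of each one
    print(name, tensor.ndim, tuple(tensor.shape))
```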

It's simple to create a tensor from a NumPy ndarray:

import torch
import numpy as np

ndarray = np.array([0, 1, 2])
t = torch.from_numpy(ndarray)
print(t)
    tensor([0, 1, 2])
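One detail worth knowing: torch.from_numpy() does not copy the data, so the tensor and the ndarray share the same memory, while torch.tensor() makes an independent copy. The sketch below (variable names are illustrative) shows the difference by modifying the NumPy array after both tensors are created.

```python
import numpy as np
import torch

ndarray = np.array([0, 1, 2])
shared = torch.from_numpy(ndarray)  # shares memory with ndarray (no copy)
copied = torch.tensor(ndarray)      # allocates new memory and copies the data

ndarray[0] = 100  # modify the NumPy array in place

print(shared)  # the change is visible through the shared tensor
print(copied)  # the copy still holds the original value

back = shared.numpy()  # converting a CPU tensor back to NumPy also shares memory
```

If you need the tensor to stay independent of the original array, torch.tensor() is the safer choice.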

Every PyTorch tensor has three attributes we'll look at here:

  • shape: the size of each of the tensor's dimensions
  • dtype: the type of data stored in the tensor
  • device: the device (CPU or GPU) on which the tensor is stored

If we print these attributes for the tensor t we created above, we get the following:

print(t.shape)
print(t.dtype)
print(t.device)
    torch.Size([3])
    torch.int64
    cpu

This means we have a one-dimensional tensor of size 3, containing 64-bit integers stored on the CPU.
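Both the data type and the device can be changed with the tensor's to() method, which returns a new tensor and leaves the original untouched. The sketch below assumes nothing about your hardware: it picks the "cuda" device only if torch.cuda.is_available() reports a GPU, and otherwise stays on the CPU.

```python
import torch

t = torch.tensor([0, 1, 2])  # int64 on the CPU by default

# Pick the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

t_float = t.to(torch.float32)  # same values, different dtype
t_moved = t.to(device)         # same values, on the chosen device

print(t_float.dtype)      # torch.float32
print(t_moved.device)     # a CUDA device on a GPU machine, cpu otherwise
print(t.dtype, t.device)  # the original tensor is unchanged
```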

We can also create a tensor directly from a Python list:

t = torch.tensor([0, 1, 2])
print(t)
    tensor([0, 1, 2])

Tensors can also be multidimensional:

ndarray = np.array([[0, 1, 2], [3, 4, 5]])
t = torch.from_numpy(ndarray)
print(t)
    tensor([[0, 1, 2],
            [3, 4, 5]])

It's also possible to create a tensor from another tensor. In this case, the new tensor keeps the properties of the original one, such as its shape and, unless we override it, its data type. The example below creates a tensor of random numbers based on the tensor we just created:

new_t = torch.rand_like(t, dtype=torch.float)
print(new_t)
    tensor([[0.1366, 0.5994, 0.3963],
            [0.1126, 0.8860, 0.8233]])

Note that rand_like() creates a new tensor with the same shape as t, (2, 3). However, since it draws random values from the interval [0, 1), it needs a floating-point data type, so we had to override t's integer data type with float.
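rand_like() belongs to a family of "_like" functions that copy the shape (and, by default, the dtype) of an existing tensor. As a small illustration, zeros_like() and ones_like() can keep the integer dtype of t, while rand_like() is given a float dtype explicitly, as in the example above.

```python
import torch

t = torch.tensor([[0, 1, 2], [3, 4, 5]])  # int64, shape (2, 3)

zeros = torch.zeros_like(t)                   # same shape and dtype as t
ones = torch.ones_like(t)                     # same shape and dtype as t
rand = torch.rand_like(t, dtype=torch.float)  # same shape, float dtype

print(zeros.dtype, zeros.shape)  # torch.int64 torch.Size([2, 3])
print(rand.dtype, rand.shape)    # torch.float32 torch.Size([2, 3])
```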

We can also create a random tensor simply from the shape we expect it to have:

my_shape = (3, 3)
rand_t = torch.rand(my_shape)
print(rand_t)
    tensor([[0.8099, 0.8816, 0.3071],
            [0.1003, 0.3190, 0.3503],
            [0.9088, 0.0844, 0.0547]])
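Because torch.rand() draws new values on every call, the numbers you get will differ from the ones printed above. If you need repeatable results, for example when comparing experiments, you can seed PyTorch's random number generator with torch.manual_seed(). The seed value 42 below is arbitrary.

```python
import torch

torch.manual_seed(42)        # seed PyTorch's global random number generator
first = torch.rand((3, 3))

torch.manual_seed(42)        # re-seeding with the same value...
second = torch.rand((3, 3))  # ...reproduces exactly the same values

print(torch.equal(first, second))  # True
```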

Tensor Operations

Just like in NumPy, there are multiple possible operations we can perform with tensors--like slicing, transposing, and multiplying matrices, among others.

Tensors are indexed and sliced just like NumPy arrays. Consider the tensor below:

zeros_tensor = torch.zeros((2, 3))
print(zeros_tensor)
    tensor([[0., 0., 0.],
            [0., 0., 0.]])

We can easily index the first row or the first column:

print(zeros_tensor[0])     # first row
print(zeros_tensor[:, 0])  # first column
    tensor([0., 0., 0.])
    tensor([0., 0.])
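Since every element of zeros_tensor is zero, the output above can't show which row or column was selected. The sketch below repeats the indexing on a tensor with distinct values, built with torch.arange(), so each selection is visible.

```python
import torch

t = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

print(t[0])      # first row: values 0, 1, 2
print(t[1])      # second row: values 3, 4, 5
print(t[:, 0])   # first column: values 0, 3
print(t[:, -1])  # last column: values 2, 5
print(t[1, 2])   # a single element: 5
print(t[:, 1:])  # all rows, columns 1 onward: [[1, 2], [4, 5]]
```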

We can also transpose zeros_tensor, turning its shape from (2, 3) into (3, 2):

transposed = zeros_tensor.T
print(transposed)
    tensor([[0., 0.],
            [0., 0.],
            [0., 0.]])

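Finally, we can multiply tensors. Here, torch.matmul() multiplies our (2, 3) tensor by a (3, 3) tensor of ones, producing a (2, 3) matrix: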

ones_tensor = torch.ones(3, 3)
product = torch.matmul(zeros_tensor, ones_tensor)
print(product)
    tensor([[0., 0., 0.],
            [0., 0., 0.]])

Notice that we used the zeros() and ones() functions to create tensors filled with zeros and ones, respectively, with the shape we passed.
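A product of zeros is always zeros, so the result above doesn't reveal much. The sketch below uses a tensor with distinct values to contrast matrix multiplication (the @ operator, equivalent to torch.matmul()) with element-wise multiplication (the * operator), which requires matching (or broadcastable) shapes.

```python
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
b = torch.ones(3, 3)

product = a @ b      # matrix product: each entry is a row sum of a -> [[3, 3, 3], [12, 12, 12]]
elementwise = a * a  # element-wise product -> [[0, 1, 4], [9, 16, 25]]

print(product)
print(elementwise)
print(a.T.shape)     # torch.Size([3, 2])
```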

These operations are just a fraction of what PyTorch can do. The purpose of this article is not to cover each of them, but to give a general idea of how they work. If you want to learn more, PyTorch's documentation covers all of them in detail.

Loading Data

PyTorch's domain libraries, such as torchvision (computer vision), torchaudio (audio and speech), and torchtext (natural language processing), provide ready-to-use datasets for many deep learning applications. This means you can build your own neural network without having to collect and process the data yourself.

As an example, we'll download the MNIST dataset, a collection of 28x28 grayscale images of handwritten digits with a training set of 60,000 images and a test set of 10,000 images.

We'll use the datasets module from torchvision to download the data:

from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.MNIST(root=".", train=True, download=True, transform=ToTensor())

test_data = datasets.MNIST(root=".", train=False, download=True, transform=ToTensor())
    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
    Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

    Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
    Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
    Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
    Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

The MNIST dataset class takes the following parameters:

  1. root: the directory where the data will be saved. You can pass a string with the directory's path. A dot (as in the example) saves the files in the current working directory.

  2. train: whether to load the training set (True) or the test set (False).

  3. download: whether to download the data if it isn't already available at root.

  4. transform: a transformation applied to each image. We use ToTensor(), which converts each image into a float tensor with pixel values scaled to the range [0, 1].
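To see roughly what ToTensor() does to a grayscale image, here is a manual sketch of the same conversion. It is an approximation for illustration, not torchvision's implementation, and the random 28x28 array is a hypothetical stand-in for a raw MNIST digit: pixel values from 0 to 255 are divided by 255 (which is why values such as 0.9922, that is 253/255, show up in the output below), and a leading channel dimension is added.

```python
import numpy as np
import torch

# Hypothetical stand-in for one raw MNIST digit: 28x28 pixels, 8-bit values 0-255
raw_image = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)

# Approximate ToTensor() for a grayscale uint8 image:
# add a channel dimension and rescale to floats in [0, 1]
tensor_image = torch.from_numpy(raw_image).unsqueeze(0).float() / 255

print(tensor_image.shape)  # torch.Size([1, 28, 28])
print(tensor_image.dtype)  # torch.float32
```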

If we print the first element of the train set, we'll see the following:

training_data[0]
    (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
               0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
               0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
               0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
               0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
               0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
               0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
               0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
               0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
               0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
               0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
               0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
               0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
               0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
               0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
               0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
               0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
               0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
               0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
               0.0000, 0.0000, 0.0000, 0.0000]]]), 5)

The output is a tuple containing the image, a tensor of shape (1, 28, 28) holding pixel intensities between 0 and 1, and its label, which here is 5.

On their own, these numbers don't mean much to us. Since they represent images, though, we can use matplotlib to display them as actual images:

figure = plt.figure(figsize=(8, 8)) 
cols, rows = 5, 5

for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

[Figure: a 5x5 grid of randomly sampled MNIST training images]

We can also use the classes attribute to see the classes inside the data:

training_data.classes

    ['0 - zero',
     '1 - one',
     '2 - two',
     '3 - three',
     '4 - four',
     '5 - five',
     '6 - six',
     '7 - seven',
     '8 - eight',
     '9 - nine']

Once trained, the model can receive new images as input and classify each one as one of these classes.

Now that we have downloaded the data, we'll wrap it in a DataLoader. This lets us iterate over the dataset in mini-batches instead of one observation at a time, and shuffle the training data at every epoch. Here's the code:

from torch.utils.data import DataLoader

loaded_train = DataLoader(training_data, batch_size=64, shuffle=True)
loaded_test = DataLoader(test_data, batch_size=64, shuffle=False)  # no need to shuffle for evaluation
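To see what a DataLoader actually yields without downloading anything, the sketch below wraps 200 random "images" in a TensorDataset, a hypothetical stand-in for MNIST. Each iteration returns one mini-batch of inputs and labels; with 200 samples and batch_size=64, we get three full batches and a final batch of 8. By the same arithmetic, loaded_train yields 938 batches, the last one holding 32 images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 200 random 1x28x28 "images" with labels 0-9
images = torch.rand(200, 1, 28, 28)
labels = torch.randint(0, 10, (200,))
fake_dataset = TensorDataset(images, labels)

loader = DataLoader(fake_dataset, batch_size=64, shuffle=True)

print(len(loader))  # 4 batches: 64 + 64 + 64 + 8
for X, y in loader:
    print(X.shape, y.shape)  # e.g. torch.Size([64, 1, 28, 28]) torch.Size([64])
```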

Neural Networks

In deep learning, a neural network is a type of model used to learn complex patterns in data. Its design is loosely inspired by the human brain: it is made of layers of interconnected processing nodes, or neurons, each of which combines its inputs and passes the result on to the next layer. Stacking many of these layers lets the network learn complex relationships from large amounts of data.

In PyTorch, everything related to neural networks is built using the torch.nn module. The network itself is written as a class that inherits from nn.Module, and, inside the class, we'll use nn to build the layers. The following is a simple implementation taken from the PyTorch documentation:

from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

Although it's out of the scope of this article to go deep into what the layers are, how they work, and how to implement them, let's briefly go over what the code above does.

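  • nn.Flatten transforms each input image from a multidimensional tensor of shape (1, 28, 28) into a one-dimensional vector of 784 values, keeping the batch dimension intact.

  • The nn.Sequential container creates a sequence of layers inside the network; the data passes through them in the order they're listed.

  • Inside the container, we have layers. Each type of layer transforms the data in a different way: nn.Linear applies a linear transformation using learned weights and biases, and nn.ReLU applies the ReLU activation function, which sets negative values to zero. The last layer outputs 10 values (logits), one for each digit class.

  • The forward method defines the computation performed on the input. It runs when we call the model itself, as in model(X); we should not call forward() directly, since calling the model also runs any hooks registered on the module.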
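The shape changes described above can be checked on a fake batch. The sketch below uses a deliberately tiny, hypothetical model (TinyNet, with a single Linear layer) rather than the network above, to show that nn.Flatten turns (N, 1, 28, 28) into (N, 784) and that calling the model runs its forward() method for us.

```python
import torch
from torch import nn

class TinyNet(nn.Module):  # hypothetical minimal model, for illustration only
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.flatten(x)    # (N, 1, 28, 28) -> (N, 784)
        return self.linear(x)  # (N, 784) -> (N, 10)

batch = torch.rand(64, 1, 28, 28)  # a fake batch of 64 grayscale 28x28 images
model = TinyNet()

flat = nn.Flatten()(batch)
print(flat.shape)    # torch.Size([64, 784])

logits = model(batch)  # call the model itself; nn.Module invokes forward()
print(logits.shape)  # torch.Size([64, 10])
```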

The following line instantiates our model:

model = NeuralNetwork()
print(model)
    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )
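The printed summary also tells us how many parameters the network has to learn. Each Linear layer holds in_features x out_features weights plus out_features biases: 784 x 512 + 512 = 401,920, then 512 x 512 + 512 = 262,656, then 512 x 10 + 10 = 5,130, for a total of 669,706. The sketch below checks this with an nn.Sequential built from the same layers as NeuralNetwork (an equivalent rewrite for illustration, not the class above).

```python
import torch
from torch import nn

# Same layers as NeuralNetwork, written as a single nn.Sequential
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 10),
)

# numel() counts the elements of each weight and bias tensor
total = sum(p.numel() for p in model.parameters())
print(total)  # 669706

for name, p in model.named_parameters():
    print(name, tuple(p.shape))
```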

Training the Neural Network

Now that we have defined our neural network, we can put it to use. Before starting the training, we need to set a loss function. The loss function measures how far the model's predictions are from the correct labels, and it's what we'll try to minimize while training the network. Cross-entropy is a common loss function for classification tasks, and it's the one we'll use here. We initialize it as follows:

loss_function = nn.CrossEntropyLoss()
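nn.CrossEntropyLoss expects raw scores (logits) and the index of the correct class for each sample; it applies log-softmax internally, which is why our network's last layer has no softmax. The sketch below, with made-up logits for two samples and three classes, checks that the loss equals the average negative log-probability of the correct class.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Made-up raw scores (logits) for 2 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])  # correct class index for each sample

loss = nn.CrossEntropyLoss()(logits, targets)

# Same value by hand: average negative log-probability of the correct class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(2), targets].mean()

print(loss.item(), manual.item())  # the two values match
```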

One last step before training is to set an optimization algorithm. The optimizer adjusts the model's parameters during training in order to minimize the error measured by the loss function we chose above. A common choice for this kind of task is stochastic gradient descent (SGD), although PyTorch offers several other optimizers in the torch.optim module. Below is the code:

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

The lr parameter is the learning rate, which controls how large a step the optimizer takes when updating the model's parameters at each iteration of training.
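To see the learning rate at work, here is a single plain SGD step (no momentum, which is the default) on a toy problem whose gradient is easy to compute by hand. Each parameter is updated as w = w - lr * gradient; the loss function and values below are made up for illustration.

```python
import torch

# A toy parameter vector and a loss whose gradient is easy to compute by hand
w = torch.tensor([1.0, -2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()  # d(loss)/dw = 2 * w = [2.0, -4.0]
optimizer.zero_grad()
loss.backward()
optimizer.step()       # w <- w - lr * grad = [1.0 - 0.2, -2.0 + 0.4]

print(w)  # approximately [0.8, -1.6]
```

A larger lr would take a bigger step in the same direction; too large a value can overshoot the minimum.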

Finally, it's time to train and test the network. We'll implement a function for each task. The train function loops through the data one batch at a time; for each batch, it computes the model's predictions and the loss, then uses backpropagation and the optimizer to adjust the model's parameters. This implementation follows the one in PyTorch's quickstart tutorial:

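def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # put the model in training mode
    for batch, (X, y) in enumerate(dataloader):
        # Compute the prediction and the loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: reset gradients, compute new ones, update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 batches
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")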

Notice that, in each iteration, enumerate() gives us both a batch of data to feed the model and the batch index, which we use to print the loss and the number of samples processed so far every 100 batches.
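The call to optimizer.zero_grad() matters because PyTorch accumulates gradients: each backward() call adds to whatever is already stored in .grad. The toy example below shows the accumulation on a single made-up parameter and then clears it by hand, which is the effect zero_grad() has on every parameter (recent PyTorch versions reset the gradients to None rather than to zero, with the same practical result).

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

(2 * w).backward()  # d(2w)/dw = 2
print(w.grad)       # tensor(2.)

(2 * w).backward()  # without clearing, the new gradient is ADDED to the old one
accumulated = w.grad.item()
print(w.grad)       # tensor(4.)

w.grad.zero_()      # clear the accumulated gradient by hand
(2 * w).backward()
print(w.grad)       # tensor(2.)
```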

And then we have the test function, which computes the accuracy and the loss, this time using the test set:

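def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # put the model in evaluation mode
    test_loss, correct = 0, 0

    # Gradients aren't needed for evaluation, so skip tracking them
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count predictions whose highest-scoring class matches the label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches  # average loss per batch
    correct /= size           # fraction of correct predictions
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")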
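The accuracy line in test() can be traced on a tiny made-up batch: argmax(1) picks the highest-scoring class in each row, and comparing it with the labels counts the correct predictions. The sketch also shows that inside torch.no_grad() PyTorch doesn't track gradients for the results, which saves memory during evaluation.

```python
import torch

# Made-up logits for 4 samples and 3 classes, plus their true labels
pred = torch.tensor([[0.1, 2.0, 0.3],
                     [1.5, 0.2, 0.1],
                     [0.0, 0.1, 0.9],
                     [0.3, 0.2, 0.1]])
y = torch.tensor([1, 0, 1, 0])

predicted_classes = pred.argmax(1)  # highest score in each row: [1, 0, 2, 0]
correct = (predicted_classes == y).type(torch.float).sum().item()
accuracy = correct / len(y)
print(predicted_classes, correct, accuracy)  # 3 of 4 correct -> 0.75

layer = torch.nn.Linear(3, 2)
with torch.no_grad():
    out = layer(pred)
print(out.requires_grad)  # False: no computation graph is recorded
```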

We then set the number of epochs to train our model. An epoch is one complete pass over the training dataset. For instance, setting epochs=5 means the network goes through the entire training set five times, with a test run after each pass. Up to a point, more training tends to produce better results.
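A quick back-of-the-envelope calculation shows how much work five epochs involve with our settings (60,000 training images, batch size 64). The optimizer takes one step per batch, and train() prints a progress line whenever the batch index is a multiple of 100.

```python
import math

train_size = 60_000  # MNIST training images
batch_size = 64      # as set in the DataLoader above
epochs = 5

batches_per_epoch = math.ceil(train_size / batch_size)  # the last batch holds only 32 images
total_updates = batches_per_epoch * epochs              # one optimizer.step() per batch

# train() prints when batch % 100 == 0, i.e. at batches 0, 100, ..., 900
printed_per_epoch = len(range(0, batches_per_epoch, 100))

print(batches_per_epoch, total_updates, printed_per_epoch)  # 938 4690 10
```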

Here is the training loop, again following PyTorch's tutorial, along with its output:

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(loaded_train, model, loss_function, optimizer)
    test(loaded_test, model, loss_function)
print("Done!")
    Epoch 1
    -------------------------------
    loss: 2.296232  [    0/60000]
    Test Error: 
     Accuracy: 47.3%, Avg loss: 2.254638 

    Epoch 2
    -------------------------------
    loss: 2.260034  [    0/60000]
    Test Error: 
     Accuracy: 63.2%, Avg loss: 2.183432 

    Epoch 3
    -------------------------------
    loss: 2.173747  [    0/60000]
    Test Error: 
     Accuracy: 66.9%, Avg loss: 2.062604 

    Epoch 4
    -------------------------------
    loss: 2.078938  [    0/60000]
    Test Error: 
     Accuracy: 72.4%, Avg loss: 1.859960 

    Epoch 5
    -------------------------------
    loss: 1.871736  [    0/60000]
    Test Error: 
     Accuracy: 75.8%, Avg loss: 1.562622 
    Done!

Notice that the training loss printed during each epoch keeps getting lower. Also, after each epoch, we can see the test accuracy increasing as the average loss decreases.

If we had set more epochs--let's say 10, 50, or even 100--chances are we'd see even better results, but the outputs would be much longer and much harder to visualize and understand.

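With our model finally trained, we can save its learned parameters (its state_dict, which is the approach the PyTorch documentation recommends) and load them back into a fresh instance of the model when necessary:

torch.save(model.state_dict(), "model.pth")     # save only the learned parameters

model = NeuralNetwork()                         # recreate the architecture
model.load_state_dict(torch.load("model.pth"))  # load the saved parameters into it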
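To confirm that a save/load round trip preserves the model, a quick check like the one below compares the outputs of the original and the restored model on the same input. It is a sketch using a small stand-in model and a temporary directory, so it doesn't touch your real model.pth file.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # small stand-in model

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(model.state_dict(), path)  # save only the learned parameters

    restored = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # same architecture
    restored.load_state_dict(torch.load(path))  # copy the saved parameters in

model.eval()
restored.eval()
x = torch.rand(4, 1, 28, 28)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)  # True
```

Because only the parameters are saved, the code that defines the architecture must be available when loading.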

Conclusion

In this article, we covered the basics of using PyTorch for deep learning, including:

  • Tensors and how to use them

  • How to load and prepare the data

  • Neural networks and how to define them in PyTorch

  • How to train your first image classification model

  • Saving and loading models in PyTorch

Otávio Simões Silveira

About the author


Otávio is an economist and data scientist from Brazil. In his free time, he writes about Python and Data Science on the internet. You can find him on LinkedIn.