Getting Started with PyTorch

Building a Neural Network in PyTorch

We recommend breaking the network up into smaller, reusable pieces. A network is an nn.Module built from operations or other nn.Modules. Loss functions are also nn.Modules and can therefore be integrated directly into the network.

A class inheriting from nn.Module must have a forward method implementing the forward pass of the respective layer or operation.

To apply an nn.Module to input data, call the module object directly. This uses the object's __call__() method to feed the input through the module:

output =
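The pattern can be sketched with a small module. `Scale` is a hypothetical toy module made up for this example. It subclasses nn.Module, defines forward(), and is applied by calling the object itself:

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    """Hypothetical toy module: multiplies its input by a learnable scalar."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, x):
        # forward() defines the computation performed by this module
        return self.weight * x

module = Scale()
x = torch.ones(3)
output = module(x)  # __call__ dispatches to forward() (and runs any registered hooks)
print(output.detach().tolist())  # [2.0, 2.0, 2.0]
```

Because `weight` is an nn.Parameter, the output is part of the autograd graph. That is why the result is detached before it is converted to a list.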

A Simple Network in PyTorch

Use the following pattern for simple networks with a single input and single output:

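import torch.nn as nn

class ConvBlock(nn.Module):
    # in_channels/out_channels and the layer settings below are illustrative
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        # recurring pattern: convolution -> activation -> normalization
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.block(x)

class SimpleNetwork(nn.Module):
    def __init__(self, num_resnet_blocks=6):
        super(SimpleNetwork, self).__init__()
        # here we add the individual layers
        layers = [ConvBlock(...)]
        for i in range(num_resnet_blocks):
            layers += [ResnetBlock(...)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # run the input through the whole model
        return self.net(x)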

Note the following:

  • We reuse simple, recurring building blocks such as ConvBlock, which consists of the same recurring pattern (convolution, activation, normalization), and put them into a separate nn.Module.
  • We build up a list of the desired layers and finally turn it into a model using nn.Sequential(). The * operator before the list unpacks it into separate arguments.
  • In the forward pass, we just run the input through the model.
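Wrapping the layer list in nn.Sequential does more than chain the calls. It also registers every layer as a submodule, so their parameters become visible to the optimizer. A plain Python list stored as an attribute is not registered. The sketch below shows the difference. `conv_block`, `PlainList` and `Stacked` are hypothetical names, and the channel sizes are illustrative:

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    # conv -> activation -> normalization, as in ConvBlock (sizes are illustrative)
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(out_ch))

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [conv_block(3, 8), conv_block(8, 8)]  # plain list: NOT registered

class Stacked(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [conv_block(3, 8), conv_block(8, 8)]
        self.net = nn.Sequential(*layers)  # unpacked into nn.Sequential: registered

    def forward(self, x):
        return self.net(x)

print(len(list(PlainList().parameters())))  # 0
print(len(list(Stacked().parameters())))    # 8 (conv weight/bias + BN weight/bias per block)
print(Stacked()(torch.randn(2, 3, 16, 16)).shape)  # torch.Size([2, 8, 16, 16])
```

If you need the layers in a container but want to call them yourself, nn.ModuleList is the registered counterpart of a plain list.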

A Network with skip connections in PyTorch

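import torch.nn as nn

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        # assumption: padding_type is an nn.Conv2d padding_mode ('zeros', 'reflect', 'replicate', 'circular')
        conv_block = []

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode=padding_type, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]  # dropout rate is illustrative
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode=padding_type, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        # skip connection: add the block's input to its output
        out = x + self.conv_block(x)
        return out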

Here the skip connection of the ResNet block is implemented directly in the forward pass. Because PyTorch builds the computation graph dynamically while forward runs, arbitrary tensor operations such as this addition can be used there.
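The skip connection is plain tensor arithmetic inside forward(). The following sketch makes this visible. `ResidualBlock` is hypothetical and simpler than ResnetBlock above, and its channel count and layers are illustrative. If the last convolution is zeroed, the residual branch outputs zero and the block reduces to the identity:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Minimal residual block: out = x + F(x) (layers and sizes are illustrative)."""
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # the skip connection is just an addition performed in forward()
        return x + self.body(x)

block = ResidualBlock(dim=8)
# Zero the last convolution so F(x) == 0; the block should then act as the identity.
nn.init.zeros_(block.body[2].weight)
nn.init.zeros_(block.body[2].bias)

x = torch.randn(1, 8, 16, 16)
with torch.no_grad():
    y = block(x)
print(torch.allclose(y, x))  # True
```

The same trick of initializing the last layer of each residual branch to zero is sometimes used deliberately, so that deep residual networks start out close to the identity.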

A Network with multiple outputs in PyTorch

For a network that requires multiple outputs, such as one used to build a perceptual loss from a pretrained VGG network, we use the following pattern:

import torch
import torch.nn as nn
from torchvision import models

class Vgg19(nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # torchvision >= 0.13 selects pretrained weights via `weights`
        # (older releases use models.vgg19(pretrained=True))
        vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()

        # copy consecutive layers of the pretrained feature extractor into each slice
        for x in range(7):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 21):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # freeze the network
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        out = [h_relu1, h_relu2, h_relu3]
        return out

Note the following:

  • We use a pretrained model provided by torchvision.
  • We split up the network into three slices. Each slice consists of layers from the pretrained model.
  • We freeze the network by setting requires_grad = False
  • We return a list with the three outputs of our slices
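The list returned by such a network is typically consumed by a perceptual loss, which compares the features of a prediction and a target slice by slice. The sketch below stays self-contained and avoids downloading weights by using a small hypothetical three-slice network (`MultiOutputFeatures`) instead of the pretrained VGG. The `perceptual_loss` helper, a sum of per-slice L1 distances, is also an illustrative choice:

```python
import torch
import torch.nn as nn

class MultiOutputFeatures(nn.Module):
    """Stand-in for Vgg19: three slices, one output per slice (layers are illustrative)."""
    def __init__(self, requires_grad=False):
        super().__init__()
        self.slice1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.slice2 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())
        self.slice3 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False  # frozen feature extractor

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        return [h1, h2, h3]

def perceptual_loss(features, prediction, target):
    # hypothetical helper: sum of L1 distances between matching feature maps
    l1 = nn.L1Loss()
    pred_feats = features(prediction)
    with torch.no_grad():
        target_feats = features(target)
    return sum(l1(p, t) for p, t in zip(pred_feats, target_feats))

features = MultiOutputFeatures()
prediction = torch.rand(2, 3, 32, 32, requires_grad=True)
target = torch.rand(2, 3, 32, 32)

loss = perceptual_loss(features, prediction, target)
loss.backward()
print([f.shape[1] for f in features(target)])              # [8, 16, 32]
print(prediction.grad is not None)                          # True: gradients reach the input
print(all(p.grad is None for p in features.parameters()))  # True: the extractor stays frozen
```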

Custom Loss

Although PyTorch already provides many standard loss functions, it is sometimes necessary to create your own. To do so, create a separate file and extend the nn.Module class to define your custom loss function:

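import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, x, y):
        loss = torch.mean((x - y) ** 2)  # mean squared error, written by hand
        return loss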
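Once defined, a custom loss is used like any built-in criterion, and autograd derives its gradient automatically. The sketch below checks the mean-squared-error example against nn.MSELoss and against the analytic gradient 2(x − y)/N, where N is the number of elements. The tensor shapes are arbitrary:

```python
import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.mean((x - y) ** 2)

criterion = CustomLoss()
prediction = torch.randn(8, 3, requires_grad=True)
target = torch.randn(8, 3)

loss = criterion(prediction, target)  # called like any other module
loss.backward()                       # no hand-written backward pass needed

expected_grad = 2 * (prediction - target).detach() / prediction.numel()
print(torch.allclose(loss, nn.MSELoss()(prediction, target)))  # True
print(torch.allclose(prediction.grad, expected_grad))          # True
```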

Recommended code structure for training your model

A full example is provided in the cifar10-example folder of this repository.

Note that we used the following patterns:

  • We use BackgroundGenerator from prefetch_generator to load the next batches in the background; see this issue for more information.
  • We use tqdm to monitor training progress and show the compute efficiency. This helps us find bottlenecks in our data loading pipeline.

# import statements
import argparse
import os
import time

import torch
import torch.nn as nn
from torch.utils import data
from torchvision import datasets, transforms
from prefetch_generator import BackgroundGenerator
from tqdm import tqdm
from tensorboardX import SummaryWriter

# set flags / seeds
torch.backends.cudnn.benchmark = True

# Start with main code
if __name__ == '__main__':
    # argparse for additional flags for experiment
    parser = argparse.ArgumentParser(description="Train a network for ...")
    # add the flags used below (img_size, path_to_data, lr, epochs, resume, path_to_checkpoint, ...)
    opt = parser.parse_args()
    # add code for datasets (we always use train and validation/ test set)
    data_transforms = transforms.Compose([
        transforms.Resize((opt.img_size, opt.img_size)),
        transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = datasets.ImageFolder(
        root=os.path.join(opt.path_to_data, "train"),
        transform=data_transforms)
    train_data_loader = data.DataLoader(train_dataset, ...)
    test_dataset = datasets.ImageFolder(
        root=os.path.join(opt.path_to_data, "test"),
        transform=data_transforms)
    test_data_loader = data.DataLoader(test_dataset, ...)
    # instantiate network (which has been imported from **)
    net = MyNetwork(...)
    # create losses (criterion in pytorch)
    criterion_L1 = torch.nn.L1Loss()
    # if running on GPU and we want to use cuda move model there
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    # create optimizers
    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    # load checkpoint if needed/ wanted
    start_n_iter = 0
    start_epoch = 0
    if opt.resume:
        ckpt = load_checkpoint(opt.path_to_checkpoint) # custom method for loading last checkpoint
        # assumes the checkpoint also stores the model and optimizer state dicts
        net.load_state_dict(ckpt['net'])
        optim.load_state_dict(ckpt['optim'])
        start_epoch = ckpt['epoch']
        start_n_iter = ckpt['n_iter']
        print("last checkpoint restored")
    # if we want to run experiment on multiple GPUs we move the models there
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    # typically we use tensorboardX to keep track of experiments
    writer = SummaryWriter(...)
    # now we start the main loop
    n_iter = start_n_iter
    for epoch in range(start_epoch, opt.epochs):
        # set models to train mode
        net.train()
        # use prefetch_generator and tqdm for iterating through data
        pbar = tqdm(enumerate(BackgroundGenerator(train_data_loader, ...)),
                    total=len(train_data_loader))
        start_time = time.time()
        # for loop going through dataset
        # (the loop variable is called batch so it does not shadow torch.utils.data)
        for i, batch in pbar:
            # data preparation
            img, label = batch
            if use_cuda:
                img = img.cuda()
                label = label.cuda()
            # It's very good practice to keep track of preparation time and computation time using tqdm to find any issues in your dataloader
            prepare_time = time.time() - start_time
            # forward and backward pass
            optim.zero_grad()
            output = net(img)
            loss = criterion_L1(output, label)  # adapt the criterion/target to your task
            loss.backward()
            optim.step()
            # update tensorboardX
            writer.add_scalar(..., n_iter)
            n_iter += 1
            # compute computation time and *compute_efficiency*
            process_time = time.time() - start_time - prepare_time
            pbar.set_description("Compute efficiency: {:.2f}, epoch: {}/{}:".format(
                process_time/(process_time+prepare_time), epoch, opt.epochs))
            start_time = time.time()
        # maybe do a test pass every x epochs
        if epoch % x == x-1:
            # bring models to evaluation mode
            net.eval()
            # do some tests
            pbar = tqdm(enumerate(BackgroundGenerator(test_data_loader, ...)),
                        total=len(test_data_loader))
            with torch.no_grad():
                for i, batch in pbar:
                    ...
            # save checkpoint if needed
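The compute-efficiency bookkeeping from the loop above can be tried in isolation. The sketch below assumes a toy TensorDataset and a linear model in place of the real dataset and MyNetwork. It uses time.perf_counter(), a monotonic high-resolution clock, instead of time.time(). The ratio is the fraction of each iteration spent computing rather than waiting for data:

```python
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# toy data and model, standing in for the real dataset and MyNetwork
dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
net = nn.Linear(10, 1)
criterion = nn.L1Loss()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

net.train()
efficiencies = []
start_time = time.perf_counter()
for img, label in loader:
    prepared_at = time.perf_counter()
    prepare_time = prepared_at - start_time            # waiting for the next batch
    optim.zero_grad()
    loss = criterion(net(img), label)
    loss.backward()
    optim.step()
    process_time = time.perf_counter() - prepared_at   # forward, backward and update
    total = prepare_time + process_time
    efficiencies.append(process_time / total if total > 0 else 0.0)
    start_time = time.perf_counter()

print(len(efficiencies))  # 8 batches of 32
print(f"mean compute efficiency: {sum(efficiencies) / len(efficiencies):.2f}")
```

A value close to 1 means the data pipeline keeps up. A low value points to a bottleneck in data loading.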

Training on Multiple GPUs in PyTorch

There are two distinct patterns in PyTorch for training on multiple GPUs. In our experience, both patterns are valid. The first, however, results in cleaner and shorter code. The second seems to have a slight performance advantage due to less communication between the GPUs. I asked a question about the two approaches in the official PyTorch forum here.

Split up the batch input of each network

The more common pattern is to simply split the batches of all networks across the individual GPUs.

A model running on 1 GPU with a batch size of 64 would therefore run on 2 GPUs with a batch size of 32 each. This can be done automatically by wrapping the model in nn.DataParallel(model).
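Here is a sketch of the first pattern with an illustrative nn.Linear model. The wrapper is applied only when more than one GPU is visible, so the snippet also runs on a single GPU or on the CPU. The last line mimics how a batch of 64 is scattered into two chunks of 32. Current PyTorch documentation recommends nn.parallel.DistributedDataParallel over nn.DataParallel, even on a single machine.

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 10)  # illustrative model
if torch.cuda.device_count() > 1:
    # replicates the model on every visible GPU and splits each batch along dim 0
    model = nn.DataParallel(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

batch = torch.randn(64, 128, device=device)
output = model(batch)  # per-GPU outputs are gathered back on the first device
print(output.shape)    # torch.Size([64, 10])

# With 2 GPUs the batch is split roughly like torch.chunk along dim 0:
print([c.shape[0] for c in batch.chunk(2, dim=0)])  # [32, 32]
```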

Pack all networks in a super network and split up input batch

This pattern is less commonly used. An implementation of this approach can be found here, in the pix2pixHD implementation by Nvidia.
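Below is a sketch of the second pattern. It follows the general idea rather than pix2pixHD's actual code. All sub-networks live in one hypothetical `SuperNetwork` whose forward() already computes the losses. Here the sub-networks are a toy generator and discriminator with illustrative sizes and a toy adversarial term. Under nn.DataParallel, each GPU computes the losses for its slice of the batch, and only those small loss tensors are gathered:

```python
import torch
import torch.nn as nn

class SuperNetwork(nn.Module):
    """Hypothetical wrapper holding every sub-network; forward() returns the losses."""
    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(nn.Linear(16, 16), nn.Tanh())  # toy sizes
        self.discriminator = nn.Linear(16, 1)
        self.criterion = nn.L1Loss()

    def forward(self, x, target):
        fake = self.generator(x)
        loss_rec = self.criterion(fake, target)
        loss_adv = -self.discriminator(fake).mean()  # toy adversarial term
        # return 1-element tensors so DataParallel can concatenate them across GPUs
        return loss_rec.unsqueeze(0), loss_adv.unsqueeze(0)

model = SuperNetwork()
if torch.cuda.device_count() > 1:
    # each GPU receives a slice of the batch and computes its losses locally
    model = nn.DataParallel(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

x = torch.randn(64, 16, device=device)
target = torch.randn(64, 16, device=device)
loss_rec, loss_adv = model(x, target)     # one entry per replica
loss = loss_rec.mean() + loss_adv.mean()  # average over replicas before backward
loss.backward()
```

In a real GAN, the generator and discriminator losses would be returned separately and stepped by separate optimizers. The single combined loss here only keeps the sketch short.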

Do’s and Don’ts

Avoid NumPy Code in the forward Method of an nn.Module

NumPy runs only on the CPU and is slower than torch code. Its operations are also not recorded by autograd, so no gradients flow through them. Since PyTorch was designed to resemble NumPy, most NumPy functions already have a PyTorch equivalent.
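The following example uses an arbitrary tanh-and-sum computation in place of real layer logic. The NumPy route has to detach the tensor and move it to the CPU, and its result sits outside autograd. The torch equivalent stays on the tensor's device and receives gradients:

```python
import numpy as np
import torch

x = torch.randn(4, 3, requires_grad=True)

# Avoid: leaving the graph via NumPy (needs detach(), runs on the CPU, gradient is lost)
y_np = np.tanh(x.detach().cpu().numpy()).sum(axis=1)

# Prefer: the torch equivalent stays on x's device and is tracked by autograd
y = torch.tanh(x).sum(dim=1)
y.sum().backward()

print(np.allclose(y_np, y.detach().cpu().numpy(), atol=1e-6))  # True: same values
print(x.grad.shape)                                             # torch.Size([4, 3])
```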

Separate the DataLoader from the main Code

The data loading pipeline should be independent of your main training code. PyTorch uses background workers for loading the data more efficiently and without disturbing the main training process.
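The sketch below shows this separation. `RandomImageDataset` is hypothetical and produces synthetic images in place of real file loading. `make_loader` is a hypothetical helper. Both would live in their own module, and the training code only sees the returned DataLoader. The example uses num_workers=0 so it runs anywhere. During training, a value greater than 0 lets worker processes prepare batches in the background:

```python
import torch
from torch.utils.data import Dataset, DataLoader

# --- would live in its own module, e.g. a hypothetical data_utils.py ---
class RandomImageDataset(Dataset):
    """Hypothetical dataset yielding (image, label) pairs; replace with real file loading."""
    def __init__(self, num_samples, img_size=32, num_classes=10):
        self.num_samples = num_samples
        self.img_size = img_size
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # seed per index so any worker produces the same sample for the same idx
        g = torch.Generator().manual_seed(idx)
        img = torch.rand(3, self.img_size, self.img_size, generator=g)
        label = idx % self.num_classes
        return img, label

def make_loader(dataset, batch_size=16, num_workers=0):
    # num_workers > 0 loads batches in background worker processes
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=torch.cuda.is_available())

# --- the main training code only consumes the loader ---
loader = make_loader(RandomImageDataset(num_samples=40), batch_size=16)
shapes = [img.shape for img, label in loader]
print(shapes[0])    # torch.Size([16, 3, 32, 32])
print(len(shapes))  # 3 batches: 16 + 16 + 8 samples
```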

Don’t log results at every step

Typically we train our models for thousands of steps. Therefore, it is enough to log the loss and other results every n-th step to reduce the overhead. Saving intermediate results as images, in particular, can be costly during training.
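Here is a sketch of logging every n steps. A synthetic loss and a plain list stand in for a real loss and a logger such as writer.add_scalar(), and `log_every` is a hypothetical interval. Averaging over the interval keeps the logged value informative even though most steps are never written out:

```python
import torch

log_every = 100  # hypothetical logging interval
running_loss = 0.0
logged = []      # stand-in for writer.add_scalar(...) calls

for n_iter in range(1, 1001):
    loss = torch.tensor(1.0 / n_iter)  # stand-in for the real training loss
    running_loss += loss.item()        # accumulate a Python float, not a graph-attached tensor
    if n_iter % log_every == 0:
        logged.append((n_iter, running_loss / log_every))
        running_loss = 0.0

print(len(logged))   # 10
print(logged[0][0])  # 100
```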

Use Command-line Arguments

It’s very handy to use command-line arguments to set parameters during code execution (batch size, learning rate, etc.). An easy way to keep track of the arguments for an experiment is to write out the dictionary of arguments returned by parse_args:

# saves arguments to config.txt file
# (parser is the argparse.ArgumentParser defined for the experiment)
opt = parser.parse_args()
with open("config.txt", "w") as f:
    f.write(str(vars(opt)))
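Here is a fuller sketch. The flag names and defaults are illustrative. It writes the arguments as JSON instead of plain text, which makes the config easy to reload later. An explicit argument list is parsed so the example runs as-is, and a script would call parser.parse_args() with no arguments:

```python
import argparse
import json
import os
import tempfile

parser = argparse.ArgumentParser(description="Train a network for ...")
# flag names and defaults are illustrative
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--epochs", type=int, default=10)

# parse an explicit list here; in a script, call parser.parse_args()
opt = parser.parse_args(["--lr", "3e-4"])

config_path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(config_path, "w") as f:
    json.dump(vars(opt), f, indent=2)  # vars() turns the Namespace into a dict

with open(config_path) as f:
    restored = json.load(f)
print(restored)  # {'batch_size': 64, 'lr': 0.0003, 'epochs': 10}
```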

Use .detach() to free tensors from the graph if possible

For automatic differentiation, PyTorch records all operations involving tensors that require gradients. Use .detach() to prevent unnecessary operations from being recorded.
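The following example shows what .detach() does. It returns a tensor that shares the same storage but carries no graph history. Later operations on the detached tensor are not recorded, which is useful when you keep outputs around, for example for visualization:

```python
import torch

w = torch.randn(3, requires_grad=True)
y = w * 2                # recorded by autograd
y_detached = y.detach()  # same data, no graph history

print(y.requires_grad, y.grad_fn is not None)        # True True
print(y_detached.requires_grad, y_detached.grad_fn)  # False None
print(y_detached.data_ptr() == y.data_ptr())         # True: storage is shared

z = y_detached * 3       # not recorded: z has no grad_fn
print(z.requires_grad)   # False
```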

Use .item() for printing scalar tensors

You can print variables directly; however, it’s recommended to use variable.detach() or variable.item(). In PyTorch versions before 0.4, you had to use .data to access the tensor of a Variable.
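The following example shows that .item() turns a one-element tensor into a plain Python number, which prints cleanly and holds no reference to the graph. It only works for tensors with exactly one element, so the second part checks that a two-element tensor is rejected:

```python
import torch

loss = (torch.randn(4, requires_grad=True) ** 2).mean()

value = loss.item()          # plain Python float, detached from the graph
print(type(value).__name__)  # float
print(f"loss: {value:.4f}")

# .item() only works for tensors with exactly one element
try:
    torch.ones(2).item()
    multi_element_ok = True
except (RuntimeError, ValueError):
    multi_element_ok = False
print(multi_element_ok)      # False
```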

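Use the call method instead of forward on an nn.Module

The two ways are not identical: calling the module runs any registered hooks around forward(), while calling forward() directly skips them, as pointed out in one of the issues here: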

output =
# they are not equal!
output =
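One concrete difference is that hooks registered on a module, for example with register_forward_hook, run inside __call__ and not inside forward(). The following example uses an nn.Linear layer and a hook that only records that it ran:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []
# forward hooks run inside nn.Module.__call__, after forward() returns;
# list.append returns None, so the hook leaves the output unchanged
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 4)
out_call = layer(x)             # runs forward() and the hook
out_forward = layer.forward(x)  # runs forward() only

print(calls)                               # ['hook']
print(torch.equal(out_call, out_forward))  # True: same values, but the hook ran only once
```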