PyTorch-Ignite: what and why? 🤔

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

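from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Precision, Recall
from ignite.handlers import ModelCheckpoint

def train_step(engine, batch):
  #  ... any training logic ...
  return batch_loss

trainer = Engine(train_step)

# Compose your pipeline ...

# train_data: placeholder name for the training data loader
trainer.run(train_data, max_epochs=100)

metrics = {
  "precision": Precision(),
  "recall": Recall()
}

evaluator = create_supervised_evaluator(
  model, metrics=metrics
)

@trainer.on(Events.EPOCH_COMPLETED)
def run_evaluation():
  # val_data: placeholder name for the validation data loader
  evaluator.run(val_data)

handler = ModelCheckpoint(
  '/tmp/models', 'checkpoint'
)
# EPOCH_COMPLETED is used here as an example event
trainer.add_event_handler(
  Events.EPOCH_COMPLETED, handler,
  {'model': model}
)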

What makes PyTorch-Ignite unique?

  • Composable and interoperable components
  • Simple and understandable code
  • Open-source community involvement

Alice: How does it differ from other similar projects?

Bob: See other PyTorch Community Voices editions :)

Key concepts in a nutshell

PyTorch-Ignite is about:

  1. Engine and Event System
  2. Out-of-the-box metrics to easily evaluate models
  3. Built-in handlers to compose a training pipeline
  4. Distributed Training support

Engine and Event System

  • Engine

    • Loops on user data
    • Applies an arbitrary user function on batches
  • Event system

    • Customizable event collections
    • Triggers handlers attached to events
In its simplest form:
epoch = 0
while epoch < max_epochs:
    epoch += 1
    for batch in data:
        output = train_step(batch)
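
To make the idea concrete, here is a minimal pure-Python sketch of an engine that loops over data, applies a user function to each batch, and triggers handlers attached to named events. It is only an illustration of the concept: `MiniEngine`, its string event names, and the toy `train_step` are hypothetical and are not PyTorch-Ignite's actual implementation.

```python
from collections import defaultdict


class MiniEngine:
    """Toy engine: loops over data, calls a step function, fires events."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.epoch = 0
        self.iteration = 0
        self.output = None

    def on(self, event_name):
        # Decorator that registers a handler for a named event.
        def decorator(handler):
            self.handlers[event_name].append(handler)
            return handler
        return decorator

    def fire(self, event_name):
        # Call every handler attached to this event, in registration order.
        for handler in self.handlers[event_name]:
            handler(self)

    def run(self, data, max_epochs=1):
        self.fire("started")
        while self.epoch < max_epochs:
            self.epoch += 1
            self.fire("epoch_started")
            for batch in data:
                self.iteration += 1
                self.output = self.process_function(self, batch)
                self.fire("iteration_completed")
            self.fire("epoch_completed")
        self.fire("completed")


def train_step(engine, batch):
    # Stand-in for real training logic: the "loss" is just the batch sum.
    return sum(batch)


trainer = MiniEngine(train_step)
log = []


@trainer.on("epoch_completed")
def record_epoch(engine):
    # Record (epoch, iteration, last output) at the end of each epoch.
    log.append((engine.epoch, engine.iteration, engine.output))


trainer.run([[1, 2], [3, 4]], max_epochs=2)
print(log)
```

The point of the design is that the loop itself never changes: everything specific to a project lives in the step function and in the handlers.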


Simplified training and validation loop

No more coding for/while loops on epochs and iterations. Users instantiate engines and run them.

from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy

# Setup training engine:
def train_step(engine, batch):
    # Users can do whatever they need on a single iteration,
    # e.g. forward/backward pass for any number of models, optimizers, etc.
    ...

trainer = Engine(train_step)

# Setup single model evaluation engine
evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()})

def validation():
    # val_data: placeholder name for the validation data loader
    state = evaluator.run(val_data)
    # print computed metrics
    print(trainer.state.epoch, state.metrics)

# Run model's validation at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, validation)

# Start the training (train_data: placeholder name for the training data loader)
trainer.run(train_data, max_epochs=100)

Power of Events & Handlers 🚀

1. Execute any number of functions whenever you wish

Handlers can be any function: e.g. lambda, simple function, class method, etc.

trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))

# attach handler with args, kwargs
mydata = [1, 2, 3, 4]
logger = ...

def on_training_ended(data):
    print(f"Training is ended. mydata={data}")
    # User can use variables from another scope
    logger.info("Training is ended")

trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
# call any number of functions on a single event
trainer.add_event_handler(Events.COMPLETED, lambda engine: print(engine.state.times))

@trainer.on(Events.ITERATION_COMPLETED)
def log_something(engine):
    print(engine.state.output)

Power of Events & Handlers

2. Built-in events filtering and stacking

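# run the validation every 5 epochs
@trainer.on(Events.EPOCH_COMPLETED(every=5))
def run_validation():
    # run validation
    ...

# stack events: run on training completion and every 10 epochs
@trainer.on(Events.COMPLETED | Events.EPOCH_COMPLETED(every=10))
def run_another_validation():
    ...

# change some training variable once on 20th epoch
@trainer.on(Events.EPOCH_STARTED(once=20))
def change_training_variable():
    ...

# Trigger handler with a custom frequency
# first_x_iters: user-defined filter, called as first_x_iters(engine, event) -> bool
@trainer.on(Events.ITERATION_COMPLETED(event_filter=first_x_iters))
def log_gradients():
    ...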
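
The filtering described in the comments above (every N epochs, once at a given epoch, a custom frequency) can be pictured as simple predicates over an event counter. The sketch below is pure Python and only illustrates that idea; the helpers `every`, `once`, `first_x`, and `fired_epochs` are hypothetical names, not part of PyTorch-Ignite.

```python
def every(n):
    # Let the handler run on counts n, 2n, 3n, ...
    return lambda count: count % n == 0


def once(n):
    # Let the handler run only when the counter equals n.
    return lambda count: count == n


def first_x(x):
    # Custom frequency: let the handler run only for the first x counts.
    return lambda count: count <= x


def fired_epochs(event_filter, max_epochs):
    """Return the 1-based epoch numbers at which the filter lets a handler run."""
    return [epoch for epoch in range(1, max_epochs + 1) if event_filter(epoch)]


validation_epochs = fired_epochs(every(5), max_epochs=20)
change_epochs = fired_epochs(once(20), max_epochs=20)
early_epochs = fired_epochs(first_x(3), max_epochs=20)
print(validation_epochs, change_epochs, early_epochs)
```

Stacking events with `|` then amounts to running the handler whenever any of the combined conditions holds.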

Power of Events & Handlers

3. Custom events to go beyond standard events

from ignite.engine import EventEnum

# Define custom events
class BackpropEvents(EventEnum):
    BACKWARD_STARTED = 'backward_started'
    BACKWARD_COMPLETED = 'backward_completed'
    OPTIM_STEP_COMPLETED = 'optim_step_completed'

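def train_step(engine, batch):
    # ...
    loss = criterion(y_pred, y)
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    loss.backward()
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
    optimizer.step()
    engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)
    # ...

trainer = Engine(train_step)
trainer.register_events(*BackpropEvents)

@trainer.on(BackpropEvents.BACKWARD_STARTED)
def function_before_backprop(engine):
    ...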

Out-of-the-box metrics 📈

50+ distributed-ready, out-of-the-box metrics to easily evaluate models.

  • Dedicated to many Deep Learning tasks
  • Easily composable to assemble a custom metric
  • Easily extendable to create custom metrics
precision = Precision(average=False)
recall = Recall(average=False)
F1_per_class = (precision * recall * 2 / (precision + recall))
F1_mean = F1_per_class.mean()  # torch mean method
F1_mean.attach(engine, "F1")
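
For readers who want to check the arithmetic behind the composed metric above, here is a small pure-Python worked example of the same formula, 2 * P * R / (P + R) per class followed by a mean over classes. The per-class precision and recall values are made up for illustration, and the sketch assumes P + R is never zero.

```python
def f1_per_class(precision, recall):
    """Element-wise 2 * P * R / (P + R) over per-class lists."""
    return [2 * p * r / (p + r) for p, r in zip(precision, recall)]


# Hypothetical per-class values for a 3-class problem.
precision = [0.5, 1.0, 0.8]
recall = [0.5, 0.5, 0.8]

per_class = f1_per_class(precision, recall)
f1_mean = sum(per_class) / len(per_class)
print(per_class, f1_mean)
```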

Built-in Handlers

  • Logging to experiment tracking systems
  • Checkpointing
  • Early stopping
  • Profiling
  • Parameter scheduling
  • etc.
# model checkpoint handler
checkpoint = ModelCheckpoint('/tmp/ckpts', 'training')
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, {'model': model})

# early stopping handler
def score_function(engine):
    # higher score is better, so validation accuracy can be used directly
    val_acc = engine.state.metrics['acc']
    return val_acc
es = EarlyStopping(3, score_function, trainer)
evaluator.add_event_handler(Events.COMPLETED, es)

# Piecewise linear parameter scheduler
scheduler = PiecewiseLinear(optimizer, 'lr', [(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# TensorBoard logger: batch loss, metrics
tb_logger = TensorboardLogger(log_dir="tb-logger")
tb_logger.attach_output_handler(
    trainer, event_name=Events.ITERATION_COMPLETED(every=100), tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)
tb_logger.attach_output_handler(
    evaluator, event_name=Events.EPOCH_COMPLETED,
    tag="training", metric_names="all",
)
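
To see what the piecewise linear milestones in the scheduler snippet above mean, the following pure-Python sketch interpolates a value between `(step, value)` pairs. It is a simplified illustration, not the library's code: the function `piecewise_linear` is hypothetical, and it assumes the first value is held before the first milestone and the last value after the last one. The library may count events slightly differently.

```python
def piecewise_linear(milestones, step):
    """Linearly interpolate a value between (step, value) milestones.

    Before the first milestone the first value is used; after the last
    milestone the last value is used.
    """
    first_step, first_value = milestones[0]
    if step <= first_step:
        return first_value
    for (s0, v0), (s1, v1) in zip(milestones, milestones[1:]):
        if s0 <= step <= s1:
            return v0 + (v1 - v0) * (step - s0) / (s1 - s0)
    return milestones[-1][1]


# Same milestones as in the scheduler snippet above.
milestones = [(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)]
for step in (0, 10, 15, 20, 21, 30, 50):
    print(step, piecewise_linear(milestones, step))
```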

Distributed Training support

Run the same code across all supported backends seamlessly

  • Backends from native torch distributed configuration: nccl, gloo, mpi
  • Horovod framework with gloo or nccl communication backend
  • XLA on TPUs via pytorch/xla
import ignite.distributed as idist

def training(local_rank, *args, **kwargs):
    dataloader_train = idist.auto_dataloader(dataset, ...)

    model = ...
    model = idist.auto_model(model)

    optimizer = ...
    optimizer = idist.auto_optim(optimizer)

backend = 'nccl'  # or 'gloo', 'horovod', 'xla-tpu' or None
with idist.Parallel(backend) as parallel:
    parallel.run(training)

Distributed Training support

Distributed launchers

Handle distributed launchers with the same code

  • torch.multiprocessing.spawn
  • torch.distributed.launch
  • horovodrun
  • slurm

Distributed Training support

Unified Distributed API

  • High-level helper methods

    • idist.auto_model()
    • idist.auto_optim()
    • idist.auto_dataloader()
  • Collective operations

    • all_reduce, all_gather, and more
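
As a reminder of what the collective operations compute, here is a single-process pure-Python simulation in which each list element stands for the value held by one rank. The functions below are illustrative stand-ins that only share a name with the real operations; they involve no actual communication and are not the `idist` API.

```python
def all_reduce(values_per_rank, op=sum):
    """After the call, every rank holds the same reduced value."""
    reduced = op(values_per_rank)
    return [reduced for _ in values_per_rank]


def all_gather(values_per_rank):
    """After the call, every rank holds the list of all ranks' values."""
    return [list(values_per_rank) for _ in values_per_rank]


# Hypothetical: number of correct predictions counted on each of 4 processes.
correct_per_rank = [90, 85, 95, 80]

print(all_reduce(correct_per_rank))
print(all_gather(correct_per_rank))
```

This is the pattern that lets a metric be computed over the whole dataset even though each process only sees its own shard.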

The Big Picture

Quick-start example 👩‍💻👨‍💻

Let’s train a MNIST classifier with PyTorch-Ignite!


With pip:

$ pip install pytorch-ignite

or with conda:

$ conda install ignite -c pytorch

Import, import, import…

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import TensorboardLogger

Start with PyTorch code

Set up the dataflow, define a model (adapted ResNet18), a loss and an optimizer.
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_dataset = MNIST(download=True, root=".", transform=data_transform, train=True)
val_dataset = MNIST(download=True, root=".", transform=data_transform, train=False)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)

device = "cuda"
model = Net().to(device)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

Here goes PyTorch-Ignite!

trainer = create_supervised_trainer(model, optimizer, criterion, device)

val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion)
}

evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
  • trainer engine to train the model
  • evaluator engine to compute metrics on validation set + save the best models

Add handlers for logging the progress

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(engine):
    print(f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(f"Validation Results - Epoch[{trainer.state.epoch}] "
          f"Avg accuracy: {metrics['accuracy']:.2f} "
          f"Avg loss: {metrics['loss']:.2f}")

Add ModelCheckpoint handler with accuracy as a score function

model_checkpoint = ModelCheckpoint(
    ...,  # checkpoint directory and any other options
    score_function=lambda e: e.state.metrics["accuracy"],
)

evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

Add Tensorboard Logger

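tb_logger = TensorboardLogger(log_dir="tb-logger")

tb_logger.attach_output_handler(
    trainer, event_name=Events.ITERATION_COMPLETED(every=100), tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)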


🚀 Liftoff! 🚀

trainer.run(train_loader, max_epochs=5)
Epoch[1], Iter[100] Loss: 0.19
Epoch[1], Iter[200] Loss: 0.13
Epoch[1], Iter[300] Loss: 0.08
Epoch[1], Iter[400] Loss: 0.11
Training Results - Epoch[1] Avg accuracy: 0.97 Avg loss: 0.09
Validation Results - Epoch[1] Avg accuracy: 0.97 Avg loss: 0.08
Epoch[5], Iter[1900] Loss: 0.02
Epoch[5], Iter[2000] Loss: 0.11
Epoch[5], Iter[2100] Loss: 0.05
Epoch[5], Iter[2200] Loss: 0.02
Epoch[5], Iter[2300] Loss: 0.01
Training Results - Epoch[5] Avg accuracy: 0.99 Avg loss: 0.02
Validation Results - Epoch[5] Avg accuracy: 0.99 Avg loss: 0.03

Complete code

PyTorch-Ignite Code-Generator

  • What is Code-Generator? A web app that quickly produces quick-start Python code for common deep learning training tasks.

  • Why use Code-Generator? To start working on a task without rewriting everything from scratch.

🔥 Convert PyTorch to Ignite ❤️‍🔥

How to translate pure PyTorch code to PyTorch+Ignite

About “PyTorch-Ignite” project

Community-driven open source and NumFOCUS Affiliated Project

maintained by volunteers in the PyTorch community:

@vfdev-5, @ydcjeff, @KickItLikeShika, @sdesrozis, @alykhantejani, @anmolsjoshi,
@trsvchn, @Moh-Yakoub, ..., @fco-dv, @gucifer, @Priyansi, ...

With the support of:

Projects using PyTorch-Ignite

More details here:

Community Engagement

  • Google Summer of Code 2021

    • Mentored two great students (Ahmed and Arpan)
  • Google Season of Docs 2021

    • Working with a great tech writer (Priyansi)
  • Hacktoberfest 2020 and the upcoming 2021 edition

  • PyData Global Mentored Sprint 2020

  • Our new website development (thanks to Jeff Yang!)

  • PyTorch-Ignite Code-Generator project

Stay tuned for upcoming events …

Join the PyTorch-Ignite Community

We are looking for motivated contributors to help out with the project.

Everyone is welcome to contribute


How to start:

Thanks for watching

and listening!



