Training and evaluating neural networks flexibly and transparently
Priyansi & Victor & Taras
Slides: https://vfdev-5.github.io/pybucaramanga-pytorch-ignite-slides/
Priyansi @Priyansi | CS Undergrad working on revamping PyTorch-Ignite's docs and managing the community | |
Victor @vfdev-5 | Software Engineer at Quansight working on AI-related open source projects | |
Taras @trsvchn | Open-source enthusiast with an MS degree in Biology |
| .
|
Computer Vision example with Fashion MNIST
Problem: 1 - How to classify images?
model(image) -> predicted label
2 - How to measure model performance?
predicted labels vs correct labels
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
# Setup training/test data
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, transform=ToTensor())

batch_size = 64

# Create data loaders.
# The training set is shuffled so that each epoch visits the samples in a new
# order (DataLoader does not shuffle by default); the test order does not
# affect the computed metrics, so it stays deterministic.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Optionally, for debugging: inspect the shapes of one test batch
for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
# Output:
# Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
# Shape of y: torch.Size([64]) torch.int64
import torch
from torch import nn
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 raw logits per sample.

    Each sample is flattened (it must contain 28 * 28 = 784 values, e.g. a
    [1, 28, 28] image) and passed through Linear(784, 512) -> ReLU ->
    Linear(512, 512) -> ReLU -> Linear(512, 10). No softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))
# Cross-entropy over the raw logits; the model lives on the selected device
# and is optimised with plain SGD.
loss_fn = nn.CrossEntropyLoss()
model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Each ``(inputs, targets)`` batch is moved to the module-level ``device``,
    the loss is computed, and the parameters are updated once per batch.
    """
    model.train()
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_loss = loss_fn(model(inputs), targets)
        # Backpropagation: clear old gradients, compute new ones, take a step
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print the average loss and accuracy.

    Batches are moved to the module-level ``device`` and evaluated without
    gradient tracking. The loss is averaged over batches (``loss_fn`` is
    assumed to return a per-batch mean, as ``nn.CrossEntropyLoss`` does by
    default). Accuracy is the fraction of samples whose arg-max logit equals
    the label.

    Returns:
        ``(avg_loss, accuracy)``; both are 0.0 if ``dataloader`` is empty.
    """
    model.eval()
    total_loss, num_batches = 0.0, 0
    correct, num_samples = 0, 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            total_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(dim=1) == y).sum().item()
            num_batches += 1
            num_samples += y.shape[0]
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / num_samples if num_samples else 0.0
    print(f"Test results - Accuracy: {100 * accuracy:.1f}%, Avg loss: {avg_loss:.6f}")
    return avg_loss, accuracy
# Alternate one training pass and one evaluation pass per epoch.
epochs = 5
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
For NN training and evaluation:
# Model, data loaders and optimisation setup.
# NOTE(review): Net, get_data_loaders, train_batch_size and val_batch_size are
# not defined in this snippet. NLLLoss expects log-probabilities, so Net
# presumably ends with log_softmax -- confirm.
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = torch.nn.NLLLoss()

# Loop schedule: number of epochs, and how often (in iterations) to
# validate / checkpoint.
max_epochs = 10
validate_every = checkpoint_every = 100
def validate(model, val_loader):
    """Return the fraction of correctly predicted samples in ``val_loader``.

    The model is put in eval mode and run without gradient tracking.

    Predictions are derived from the output shape:

    * outputs of shape (N, C) with C > 1 are per-class scores, reduced with
      ``argmax``. Rounding these, as before, broadcast against the (N,)
      targets and produced a meaningless score.
    * any other output (e.g. (N,) or (N, 1) probabilities) is rounded and
      reshaped to the target's shape, which preserves the binary case.

    Args:
        model: module mapping an input batch to predictions.
        val_loader: iterable of ``(input, target)`` batches.

    Returns:
        Accuracy in [0, 1]; 0.0 if ``val_loader`` yields no samples.
    """
    model = model.eval()
    num_correct = 0
    num_examples = 0
    with torch.no_grad():  # evaluation only: avoid building autograd graphs
        for inputs, target in val_loader:
            output = model(inputs)
            if output.dim() > 1 and output.shape[-1] > 1:
                pred = output.argmax(dim=-1)
            else:
                pred = torch.round(output).reshape(target.shape)
            correct = torch.eq(pred.type(target.type()), target).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
    # Guard against an empty loader instead of raising ZeroDivisionError
    return num_correct / num_examples if num_examples else 0.0
def checkpoint(model, optimizer, checkpoint_dir):
    """Save the model and optimizer state dicts to ``checkpoint_dir``.

    The directory is created if needed. The file ``checkpoint.pt`` inside it
    holds ``{"model": ..., "optimizer": ...}`` and is overwritten on each call.

    Returns:
        The ``pathlib.Path`` of the written file.
    """
    from pathlib import Path

    directory = Path(checkpoint_dir)
    directory.mkdir(parents=True, exist_ok=True)
    target = directory / "checkpoint.pt"
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, target)
    return target
def save_best_model(model, current_accuracy, best_accuracy, path="best_model.pt"):
    """Save the model weights if ``current_accuracy`` beats ``best_accuracy``.

    Args:
        model: module whose ``state_dict()`` is saved.
        current_accuracy: score of the latest validation run.
        best_accuracy: best score seen so far.
        path: destination file for the weights (overwritten on improvement).

    Returns:
        The updated best accuracy: ``current_accuracy`` if it improved,
        otherwise ``best_accuracy``.
    """
    if current_accuracy > best_accuracy:
        torch.save(model.state_dict(), path)
        return current_accuracy
    return best_accuracy
# Hand-written training loop: periodic validation, best-model saving and
# checkpointing are all managed manually.
iteration = 0
best_accuracy = 0.0
for epoch in range(max_epochs):
    for batch in train_loader:
        model = model.train()  # validate() switches the model to eval mode
        optimizer.zero_grad()
        inputs, target = batch
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if iteration % validate_every == 0:
            binary_accuracy = validate(model, val_loader)
            print(f"After {iteration} iterations, binary accuracy = {binary_accuracy:.2f}")
            save_best_model(model, binary_accuracy, best_accuracy)
            # Track the best score seen so far; otherwise every validation
            # would be compared against the initial 0.0.
            best_accuracy = max(best_accuracy, binary_accuracy)

        if iteration % checkpoint_every == 0:
            # NOTE(review): checkpoint_dir is not defined in this snippet;
            # define it before the loop.
            checkpoint(model, optimizer, checkpoint_dir)

        iteration += 1
High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
|
|
|
With PyTorch-Ignite:
Let’s train a MNIST classifier with PyTorch-Ignite!
Install PyTorch and TorchVision
$ pip install torch torchvision
Install PyTorch-Ignite
via pip
📦
$ pip install pytorch-ignite
or conda
🐍
$ conda install ignite -c pytorch
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import TensorboardLogger
# Convert images to tensors, then normalise the single channel with a fixed
# mean / std.
data_transform = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])

# Both splits share the same root, download flag and transform.
mnist_kwargs = {"download": True, "root": ".", "transform": data_transform}
train_dataset = MNIST(train=True, **mnist_kwargs)
val_dataset = MNIST(train=False, **mnist_kwargs)

# Only the training data is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=256)
class Net(nn.Module):
    """ResNet-18 with 10 output classes, adapted to single-channel input.

    The backbone's first convolution is replaced by a 3x3, padding-1,
    bias-free convolution mapping 1 input channel to 64 channels.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet18(num_classes=10)
        # Accept 1-channel images instead of the backbone's original conv1
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.model = backbone

    def forward(self, x):
        """Return the logits of the wrapped ResNet-18 for the batch ``x``."""
        return self.model(x)
# Use the GPU when available instead of crashing on CPU-only machines
# (the hard-coded "cuda" failed there).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Net().to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

# Engine running the forward/backward/update step for each training batch
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

# Metrics the evaluator computes on the validation set
val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion),
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
# trainer: engine to train the model
# evaluator: engine to compute metrics on the validation set and save the best models
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(engine):
    """Every 100 iterations, print the epoch, iteration and latest trainer output (the batch loss)."""
    state = engine.state
    print(f"Epoch[{state.epoch}], Iter[{state.iteration}] Loss: {state.output:.2f}")
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    """After each training epoch, run the evaluator on ``val_loader`` and print its metrics."""
    evaluator.run(val_loader)
    acc = evaluator.state.metrics["accuracy"]
    avg_loss = evaluator.state.metrics["loss"]
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] "
        f"Avg accuracy: {acc:.2f} "
        f"Avg loss: {avg_loss:.2f}"
    )
# ModelCheckpoint handler with accuracy as a score function:
# checkpoints go to ./checkpoint with the "best" prefix, at most 2 are kept,
# ranked by the evaluator's "accuracy" metric.
def _accuracy_score(engine):
    return engine.state.metrics["accuracy"]


model_checkpoint = ModelCheckpoint(
    "checkpoint",
    filename_prefix="best",
    n_saved=2,
    score_name="accuracy",
    score_function=_accuracy_score,
)
# Considered for saving each time an evaluation run completes
evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})
# global_step_from_engine was used below without being imported
from ignite.handlers import global_step_from_engine

# Log the training batch loss every 100 iterations and all validation metrics
# after each evaluation to TensorBoard under ./tb-logger.
tb_logger = TensorboardLogger(log_dir="tb-logger")
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)
tb_logger.attach_output_handler(
    evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names="all",
    # Plot validation metrics against the trainer's epoch, not the evaluator's
    global_step_transform=global_step_from_engine(trainer),
)

trainer.run(train_loader, max_epochs=5)
# Flush pending events and release the event file once training is done
tb_logger.close()
Epoch[1], Iter[100] Loss: 0.19
Epoch[1], Iter[200] Loss: 0.13
Epoch[1], Iter[300] Loss: 0.08
Epoch[1], Iter[400] Loss: 0.11
Training Results - Epoch[1] Avg accuracy: 0.97 Avg loss: 0.09
Validation Results - Epoch[1] Avg accuracy: 0.97 Avg loss: 0.08
...
Epoch[5], Iter[1900] Loss: 0.02
Epoch[5], Iter[2000] Loss: 0.11
Epoch[5], Iter[2100] Loss: 0.05
Epoch[5], Iter[2200] Loss: 0.02
Epoch[5], Iter[2300] Loss: 0.01
Training Results - Epoch[5] Avg accuracy: 0.99 Avg loss: 0.02
Validation Results - Epoch[5] Avg accuracy: 0.99 Avg loss: 0.03
https://code-generator.pytorch-ignite.ai/
What is Code-Generator? A web app that quickly generates quick-start Python code for common deep learning training tasks.
Why use Code-Generator? To start working on a task without rewriting everything from scratch.
Any questions before we go on ?
.
| In its simplest form:
|
No more coding for/while
loops on epochs and iterations. Users instantiate engines and run them.
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy
# Setup training engine:
def train_step(engine, batch):
    """Process one training batch; the Engine calls this once per iteration.

    Users can do whatever they need on a single iteration, e.g. forward/backward
    passes for any number of models, optimizers, etc. Shown here: a single
    supervised update using module-level ``model``, ``optimizer`` and
    ``criterion`` (assumed to be defined by the user). The returned value
    becomes ``engine.state.output``.
    """
    model.train()
    optimizer.zero_grad()
    x, y = batch
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()
trainer = Engine(train_step)

# Setup single model evaluation engine
evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()})


def validation():
    """Evaluate on the validation loader and print the trainer's epoch and the metrics."""
    results = evaluator.run(validation_data_loader).metrics
    print(trainer.state.epoch, results)


# Run model's validation at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, validation)
# Start the training
trainer.run(training_data_loader, max_epochs=100)
Handlers can be any function: e.g. lambda, simple function, class method, etc.
trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))

# attach handler with args, kwargs
import logging

mydata = [1, 2, 3, 4]
# A real logger: the previous `logger = ...` placeholder was Ellipsis, so
# `logger.info` raised AttributeError when training completed.
logger = logging.getLogger(__name__)


def on_training_ended(data):
    """Print the data passed at registration time and log that training ended."""
    print(f"Training is ended. mydata={data}")
    # User can use variables from another scope
    logger.info("Training is ended")


trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
# call any number of functions on a single event
trainer.add_event_handler(Events.COMPLETED, lambda engine: print(engine.state.times))
@trainer.on(Events.ITERATION_COMPLETED)
def log_something(engine):
    """Print what the process function returned for the latest iteration."""
    latest_output = engine.state.output
    print(latest_output)
# run the validation every 5 epochs
@trainer.on(Events.EPOCH_COMPLETED(every=5))
def run_validation():
    """Called on every 5th completed epoch; the validation logic is left to the user."""
    # run validation
@trainer.on(Events.COMPLETED | Events.EPOCH_COMPLETED(every=10))
def run_another_validation():
    """Called on every 10th completed epoch and once more when training completes."""
    # ...
# change some training variable once on 20th epoch
@trainer.on(Events.EPOCH_STARTED(once=20))
def change_training_variable():
    """Called a single time, when the 20th epoch starts; the body is left to the user."""
    # ...
# Trigger handler with a custom-defined frequency
def first_x_iters(_, event):
    """Event filter accepting only the first 10 iterations.

    Ignite calls an ``event_filter`` with ``(engine, event)``, where ``event``
    is the current event count (here, the iteration number).
    """
    return event <= 10


@trainer.on(Events.ITERATION_COMPLETED(event_filter=first_x_iters))
def log_gradients():
    """Called on iterations accepted by ``first_x_iters``; the logging is left to the user."""
    # ...
from ignite.engine import EventEnum
# Define custom events
class BackpropEvents(EventEnum):
    """Custom events fired by the training step around backward and the optimizer step."""

    BACKWARD_STARTED = "backward_started"
    BACKWARD_COMPLETED = "backward_completed"
    OPTIM_STEP_COMPLETED = "optim_step_completed"
def train_step(engine, batch):
    """Run one training iteration, firing ``BackpropEvents`` at each backprop stage.

    Uses module-level ``model``, ``optimizer`` and ``criterion`` (assumed to be
    defined by the user). The batch is unpacked into ``x, y`` so that ``y`` and
    ``y_pred`` are actually defined. Returns the batch loss as a float.
    """
    model.train()
    optimizer.zero_grad()
    x, y = batch
    y_pred = model(x)
    loss = criterion(y_pred, y)
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    loss.backward()
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
    optimizer.step()
    engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)
    return loss.item()
trainer = Engine(train_step)
# Custom events must be registered on the engine before handlers can be attached to them
trainer.register_events(*BackpropEvents)


@trainer.on(BackpropEvents.BACKWARD_STARTED)
def function_before_backprop(engine):
    """Called right before ``loss.backward()`` in each training step; the body is left to the user."""
    # ...
50+ distributed-ready, out-of-the-box metrics to easily evaluate models.
from ignite.metrics import Precision, Recall

# Per-class precision and recall (average=False keeps one value per class)
precision = Precision(average=False)
recall = Recall(average=False)
# Metric arithmetic builds a new composed metric. The tiny epsilon prevents
# 0 / 0 = NaN for a class whose precision and recall are both 0.
F1_per_class = precision * recall * 2 / (precision + recall + 1e-20)
F1_mean = F1_per_class.mean()  # torch mean method
F1_mean.attach(engine, "F1")
.
|
|
Run the same code across all supported backends seamlessly:
- native torch.distributed: nccl, gloo, mpi
- Horovod: gloo or nccl communication backend
- XLA on TPUs: pytorch/xla
import ignite.distributed as idist
def training(local_rank, *args, **kwargs):
    """Per-process training function executed by ``idist.Parallel(...).run``.

    The ``idist.auto_*`` helpers adapt the data loader, model and optimizer to
    the current distributed configuration. ``dataset`` and the ``...`` values
    are placeholders for user code.
    """
    # Correct helper names: auto_dataloader / auto_optim
    # (auto_dataloder / auto_optimizer do not exist in ignite.distributed).
    dataloader_train = idist.auto_dataloader(dataset, ...)
    model = ...
    model = idist.auto_model(model)
    optimizer = ...
    optimizer = idist.auto_optim(optimizer)
# Backend choices shown on the slide: 'nccl', 'gloo', 'horovod', 'xla-tpu' or None
backend = "nccl"
with idist.Parallel(backend=backend) as parallel:
    # Runs `training` in each participating process
    parallel.run(training)
Handle distributed launchers with the same code
torch.multiprocessing.spawn
torch.distributed.launch
horovodrun
slurm
High-level helper methods
idist.auto_model()
idist.auto_optim()
idist.auto_dataloader()
Collective operations
all_reduce, all_gather, and more
Any questions before we go on?
How to translate pure PyTorch code to PyTorch+Ignite
Any questions before we go on ?
Community-driven open source and NumFOCUS Affiliated Project
maintained by volunteers in the PyTorch community:
@vfdev-5, @ydcjeff, @KickItLikeShika, @sdesrozis, @alykhantejani, @anmolsjoshi,
@trsvchn, @Moh-Yakoub, ..., @fco-dv, @gucifer, @Priyansi, ...
With the support of:
More details here: https://pytorch-ignite.ai/ecosystem/
Google Summer of Code 2021
Google Season of Docs 2021
Hacktoberfest 2020 and 2021
PyData Global Mentored Sprint 2020 and the upcoming 2021 edition (end of October)
Our new website development
Public meetings on Discord, open to everyone
Stay tuned for upcoming events …
The repositories participating:
How it works:
We are looking for motivated contributors to help out with the project.
Everyone is welcome to contribute
Thanks for your attention!
Questions? 👩‍💻 👨‍💻 👩‍💻 | Follow us on and check out our new website: |