Finetuning DETR¶
This tutorial explains how to use the Detr R50 Finetune module to train a custom model based on the DetrR50 architecture for an object detection application.
Goals
Train a model based on the DetrR50 architecture to detect pets in the COCO detection 2017 dataset
Use the trained model to make inferences.
1. Train Detr R50 Finetune¶
The Detr R50 Finetune module is an extension (child class) of Detr R50 that makes it possible to change the fixed number of 91 classes in the last embedding layer to any desired value, so that the robust pre-trained model can be reused for a specific application (finetuning).
See also
See Finetuning torchvision models to learn more about finetuning.
Check Models to know all possible configurations of the model.
Its instantiation is the same as that of Detr R50, with the difference that the num_classes
attribute is now mandatory:
[ ]:
from alonet.detr import DetrR50Finetune
detr_finetune = DetrR50Finetune(num_classes = 2)
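To check that the classification head was indeed resized, we can inspect the module. This is only a sketch: it assumes that, as in the original DETR implementation, the head is registered under the name class_embed; adjust the filter if your alonet version names it differently.
[ ]:
# Sketch (assumption): print any submodule whose name contains "class_embed",
# the classification-head name used by the original DETR code.
for name, module in detr_finetune.named_modules():
    if "class_embed" in name:
        print(name, module)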
Given that Detr R50 Finetune is a module based on Detr R50, we can use it in conjunction with the LitDetr module for training purposes:
[ ]:
from alonet.detr import LitDetr
lit_detr = LitDetr(model = detr_finetune)
Finally, we need to choose the dataset on which the model will be trained. The full code is shown below, training on the pet classes of the COCO detection 2017 dataset:
[ ]:
from argparse import ArgumentParser
import alonet
from alonet.detr import CocoDetection2Detr, LitDetr, DetrR50Finetune
# Build parser
parser = ArgumentParser()
parser = alonet.common.add_argparse_args(parser) # Common alonet parser
args = parser.parse_args([])
args.no_suffix = True # Fix run_id = expe_name
args.limit_train_batches = 1000
args.limit_val_batches = 200
# Define COCO dataset as pl.LightningDataModule, keeping only the pet classes
pets = ['cat', 'dog']
coco_loader = CocoDetection2Detr(classes = pets)
# Define architecture as pl.LightningModule, using PRETRAINED WEIGHTS
lit_detr = LitDetr(model = DetrR50Finetune(len(pets), weights = 'detr-r50'))
# Start train loop
args.max_epochs = 5 # Due to finetune, we just need 5 epochs to train this model
args.save = True
lit_detr.run_train(
data_loader = coco_loader,
args = args,
project = "detr",
expe_name = "pets",
)
Once the process has been completed, the $HOME/.aloception/project_run_id/run_id folder will be created with the different checkpoint files.
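With project = "detr", expe_name = "pets" and no_suffix = True as above, the run folder resolves to $HOME/.aloception/detr/pets. A quick way to list what was saved, using only the standard library:
[ ]:
from pathlib import Path

# List everything written under the run folder created by run_train above.
# Path layout: $HOME/.aloception/<project_run_id>/<run_id>
run_dir = Path.home() / ".aloception" / "detr" / "pets"
for path in sorted(run_dir.glob("**/*")):
    print(path.relative_to(run_dir))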
Warning
A common mistake when using pre-trained weights is to try to load them on LitDetr rather than on the model (Detr R50 Finetune). By default, LitDetr will try to load the weights of the original DetrR50 architecture, which produces an error for all finetune models.
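To make the correct pattern explicit, the construction below repeats the one already used in the training script; the failing case is only sketched as a comment, since its exact form depends on the LitDetr signature:
[ ]:
# Correct: the pre-trained DetrR50 weights are loaded by the finetune model,
# which then replaces the 91-class head with a freshly initialized one.
lit_detr = LitDetr(model = DetrR50Finetune(len(pets), weights = 'detr-r50'))

# Error-prone: asking LitDetr itself to load 'detr-r50' weights would target
# the original 91-class DetrR50 architecture and fail on a finetuned model.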
Important
The advantage of finetuning is fast convergence: starting from pre-trained weights, the model converges much faster than it would from a random initialization.
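To speed convergence further, a common companion to finetuning is freezing the pre-trained backbone so that only the remaining layers (including the new classification head) are updated. The sketch below uses plain PyTorch and assumes the backbone parameters are registered under the backbone prefix, as in the original DETR implementation:
[ ]:
# Optional sketch: freeze the pre-trained backbone parameters.
# Assumption: the backbone is registered under the name "backbone";
# check detr_finetune.named_parameters() if your alonet version differs.
for name, param in detr_finetune.named_parameters():
    if name.startswith("backbone"):
        param.requires_grad = False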
2. Make inferences¶
In order to make some inferences on the dataset using the trained model, we need to load the trained weights. Alonet provides the load_training function for this purpose. We also need to keep in mind the project and run id used during the training process:
[ ]:
%matplotlib inline
import torch
from argparse import Namespace
from alonet.common import load_training
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Define the architecture
detr_finetune = DetrR50Finetune(len(pets))
# Load weights according project_run_id and run_id
args = Namespace(
project_run_id = "detr",
run_id = "pets"
)
lit_detr = load_training(
LitDetr,
args = args,
model = detr_finetune,
)
lit_detr.model.to(device)
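Since the loaded model is a standard torch.nn.Module, the usual PyTorch inference hygiene applies. This is independent of alonet and only a suggestion:
[ ]:
# Standard PyTorch practice before running inference:
lit_detr.model.eval()  # disable dropout and batch-norm updates
# Forward passes can also be wrapped in torch.no_grad() to skip gradient
# bookkeeping, e.g.:
# with torch.no_grad():
#     outputs = lit_detr(frames)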
This enables us to use the validation dataset and show some results:
[ ]:
frames = next(iter(coco_loader.val_dataloader()))
frames = frames[0].batch_list(frames).to(device)
pred_boxes = lit_detr.inference(lit_detr(frames))[0] # Inference from forward result
gt_boxes = frames[0].boxes2d # Get ground truth boxes
print(pred_boxes)
frames.get_view([
gt_boxes.get_view(frames[0], title="Ground truth boxes"),
pred_boxes.get_view(frames[0], title="Predicted boxes"),
]).render()
See also
See Aloscene to find out how to render images in Aloception
What is next?
Learn how to train a more complex model based on the deformable attention module in the Training Deformable tutorial.