Training Deformable DETR

This tutorial explains how to use the LitDeformableDetr module to train the Deformable DETR R50 architecture from scratch on the COCO 2017 detection dataset.

Goals

  1. Learn how to instantiate LitDeformableDetr and run a training

  2. Load trained weights and run inference with pre-trained weights

  3. Compare the performance of LitDetr and LitDeformableDetr

1. Deformable DETR trainer

Aloception is built on the PyTorch Lightning framework and provides modules that simplify working with datasets and training models. Like LitDetr, LitDeformableDetr lets you initialize its parameters at four levels:

[ ]:
from alonet.deformable_detr import LitDeformableDetr, DeformableDetrR50Finetune
from argparse import ArgumentParser, Namespace

def params2Namespace(litdetr, level):
    print(f"[INFO] LEVEL {level}:", Namespace(
        accumulate_grad_batches=litdetr.accumulate_grad_batches,
        gradient_clip_val=litdetr.gradient_clip_val,
        model_name=litdetr.model_name,
        weights=litdetr.weights
    ))

# Level 1
# Create LightningModule with default parameters
lit_deformable = LitDeformableDetr()
params2Namespace(lit_deformable,1)

# Level 2
# Define LightningModule changing some parameters
lit_deformable = LitDeformableDetr(accumulate_grad_batches=5)
params2Namespace(lit_deformable,2)

# Level 3
# Use namespace to define attribute values
parser = ArgumentParser()
args = LitDeformableDetr.add_argparse_args(parser).parse_args([])
lit_deformable = LitDeformableDetr(args)
params2Namespace(lit_deformable,3)

# Level 4
# Combine previous approaches.
my_model = DeformableDetrR50Finetune(num_classes = 2, weights = "deformable-detr-r50")
lit_deformable = LitDeformableDetr(args, model = my_model, model_name = "finetune")
params2Namespace(lit_deformable,4)
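The four levels differ only in where each value comes from. As a rough mental model, and assuming that LitDeformableDetr starts from its argparse defaults, then applies a given Namespace, then applies explicit keyword arguments, the resolution can be sketched in plain Python. build_parser and resolve_params are hypothetical stand-ins, not Aloception's API, and the default values are placeholders.

```python
from argparse import ArgumentParser, Namespace
from typing import Optional


def build_parser() -> ArgumentParser:
    # Hypothetical subset of trainer options; the defaults are placeholders,
    # not Aloception's actual values
    parser = ArgumentParser()
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--gradient_clip_val", type=float, default=0.0)
    parser.add_argument("--model_name", type=str, default="deformable-detr-r50-refinement")
    parser.add_argument("--weights", type=str, default=None)
    return parser


def resolve_params(args: Optional[Namespace] = None, **kwargs) -> Namespace:
    """Assumed precedence: parser defaults < Namespace values < keyword arguments."""
    # Level 1: start from the parser defaults
    params = vars(build_parser().parse_args([]))
    # Level 3: values carried by a Namespace override the defaults
    if args is not None:
        params.update(vars(args))
    # Levels 2 and 4: explicit keyword arguments take precedence over everything
    params.update(kwargs)
    return Namespace(**params)


# Level 4 style: a parsed Namespace combined with a keyword override
ns = build_parser().parse_args(["--gradient_clip_val", "0.5"])
print(resolve_params(ns, model_name="finetune"))
```

With no arguments the sketch returns the defaults (level 1), and a keyword argument always wins over the same key coming from a Namespace. If LitDeformableDetr resolves its parameters differently, the values printed by params2Namespace above are the ones to trust.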

Hint

For a more detailed explanation, see the tutorial on training a DETR model.

A typical training pipeline in Aloception looks as follows:

[ ]:
from argparse import ArgumentParser

import alonet
from alonet.detr import CocoDetection2Detr
from alonet.deformable_detr import LitDeformableDetr

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Parameters definition
# Build parser (concatenates arguments to modify the entire project)
parser = ArgumentParser(conflict_handler="resolve")
parser = CocoDetection2Detr.add_argparse_args(parser)
parser = LitDeformableDetr.add_argparse_args(parser)
parser = alonet.common.add_argparse_args(parser)  # Add common arguments in train process
args = parser.parse_args([])

# Dataset used for training
coco_loader = CocoDetection2Detr(args)
lit_deformable = LitDeformableDetr(args)

# Train process
# args.save = True # Uncomment this line to store trained weights
lit_deformable.run_train(
    data_loader=coco_loader,
    args=args,
    project="deformable",
    expe_name="coco_detr",
)
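The parser above is assembled by letting several modules add their own options to a single ArgumentParser. conflict_handler="resolve" is a standard argparse setting: when two modules register the same flag, the later definition replaces the earlier one instead of raising an error. The sketch below reproduces this with two hypothetical helpers, add_data_args and add_model_args, standing in for the real add_argparse_args methods; the flag names and defaults are illustrative only.

```python
from argparse import ArgumentError, ArgumentParser


def add_data_args(parser: ArgumentParser) -> ArgumentParser:
    # Hypothetical dataset options
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=8)
    return parser


def add_model_args(parser: ArgumentParser) -> ArgumentParser:
    # Hypothetical model options; "--batch_size" clashes with the dataset helper
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    return parser


def build(conflict_handler: str = "resolve") -> ArgumentParser:
    # Concatenate the options of every module into one parser
    parser = ArgumentParser(conflict_handler=conflict_handler)
    parser = add_data_args(parser)
    return add_model_args(parser)


# With "resolve", the last registered definition of --batch_size wins
args = build().parse_args([])
print(args.batch_size, args.num_workers, args.lr)  # 4 8 0.0002

# With argparse's default "error" handler, the duplicate flag is rejected
try:
    build(conflict_handler="error")
except ArgumentError as exc:
    print("conflict:", exc)
```

Because the last definition wins, the order in which the add_argparse_args calls are chained decides which module's default survives for a shared flag.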

Note

By default, LitDeformableDetr instantiates the Deformable DETR R50 architecture with refinement. To use the plain Deformable DETR R50 architecture instead, set model_name = "deformable-detr-r50" when instantiating the class.

Hint

Learn the difference between Deformable DETR R50 and Deformable DETR R50 with refinement (Iterative Bounding Box Refinement) in the article Deformable DETR: Deformable Transformers for End-to-End Object Detection.
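In the article, iterative bounding box refinement lets each decoder layer correct the boxes predicted by the previous layer: the layer predicts an offset that is added to the previous box in inverse-sigmoid space and mapped back with a sigmoid, b_d = sigmoid(delta_d + inverse_sigmoid(b_{d-1})), so every coordinate stays normalized in [0, 1]. The pure-Python sketch below applies that update to one box in normalized (cx, cy, w, h) format. The per-layer offsets are made-up numbers standing in for the output of each decoder layer's box head, and the eps clamp is an assumption similar in spirit to the reference implementation.

```python
import math
from typing import List, Sequence


def inverse_sigmoid(x: float, eps: float = 1e-5) -> float:
    # Clamp away from 0 and 1 to avoid log(0)
    x = min(max(x, eps), 1.0 - eps)
    return math.log(x / (1.0 - x))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def refine_box(prev_box: Sequence[float], delta: Sequence[float]) -> List[float]:
    """One refinement step: b_d = sigmoid(delta_d + inverse_sigmoid(b_{d-1}))."""
    return [sigmoid(d + inverse_sigmoid(b)) for b, d in zip(prev_box, delta)]


# Initial reference box (cx, cy, w, h), normalized to [0, 1]
box = [0.5, 0.5, 0.2, 0.2]
# Made-up offsets, one row per decoder layer
layer_deltas = [[0.1, -0.2, 0.3, 0.0], [0.05, 0.0, -0.1, 0.2]]
for delta in layer_deltas:
    box = refine_box(box, delta)
    print([round(v, 3) for v in box])
```

A zero offset leaves the box unchanged, which is why each layer only has to learn a residual correction on top of the previous estimate.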

Attention

Training from scratch has a high computational cost and takes several hours. We recommend skipping to the next section to see the results of an already trained network.

2. Running inference

Once training has finished, the trained weights can be loaded from the project and run IDs (stored under ~/.aloception/project_run_id/run_id). The load_training function from Aloception's common module handles this:

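from argparse import Namespace
from alonet.common import load_training
from alonet.deformable_detr import LitDeformableDetr

# Replace with the identifiers of your own training run
args = Namespace(project_run_id = "project_run_id", run_id = "run_id")
lit_deformable = load_training(LitDeformableDetr, args = args)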
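Before loading a run, it can help to check which checkpoints it produced. The pathlib sketch below assumes the ~/.aloception/project_run_id/run_id layout described above and that checkpoints use PyTorch Lightning's usual .ckpt extension; run_dir and list_checkpoints are hypothetical helpers, not part of alonet.

```python
from pathlib import Path
from typing import List, Optional


def run_dir(project_run_id: str, run_id: str, root: Optional[Path] = None) -> Path:
    # Assumed layout: <root>/<project_run_id>/<run_id>, with root defaulting to ~/.aloception
    root = Path.home() / ".aloception" if root is None else Path(root)
    return root / project_run_id / run_id


def list_checkpoints(project_run_id: str, run_id: str, root: Optional[Path] = None) -> List[Path]:
    """Return the *.ckpt files of a run sorted by name, or an empty list if the run is missing."""
    directory = run_dir(project_run_id, run_id, root)
    if not directory.is_dir():
        return []
    return sorted(directory.glob("*.ckpt"))


# Placeholder identifiers, as in the snippet above
for ckpt in list_checkpoints("project_run_id", "run_id"):
    print(ckpt.name)
```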

LitDeformableDetr can also download and load pre-trained weights through the weights argument:

[ ]:
lit_deformable = LitDeformableDetr(
    weights = "deformable-detr-r50",
    model_name = "deformable-detr-r50"
)

Note

If weights = "deformable-detr-r50-refinement" is set and model_name is omitted, the default Deformable DETR R50 with refinement architecture is instantiated and loaded with its pre-trained weights.

Finally, we can run some detections with the following code:

[ ]:
%matplotlib inline
from alonet.detr import CocoDetection2Detr
from alonet.deformable_detr import LitDeformableDetr

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# COCO dataset (a batch from the validation split is used below)
coco_loader = CocoDetection2Detr()
lit_deformable = LitDeformableDetr(weights = "deformable-detr-r50-refinement")
lit_deformable.model = lit_deformable.model.eval().to(device)

# Check the result on one validation batch
frame = next(iter(coco_loader.val_dataloader()))
frame = frame[0].batch_list(frame).to(device)
with torch.no_grad():  # No gradients are needed for inference
    pred_boxes = lit_deformable.inference(lit_deformable(frame))[0]  # Inference from forward result
gt_boxes = frame[0].boxes2d
gt_boxes = frame[0].boxes2d

frame.get_view(
    [
        gt_boxes.get_view(frame[0], title="Ground truth boxes"),
        pred_boxes.get_view(frame[0], title="Predicted boxes"),
    ], size = (1920,1080)
).render()
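lit_deformable.inference turns the raw forward output into boxes; its exact implementation is not shown in this tutorial. As an illustration of the general technique, the reference Deformable DETR post-processing scores every (query, class) pair with a sigmoid, keeps the top-k pairs and converts the normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2) corners. The pure-Python sketch below follows that recipe on made-up logits and boxes; post_process is a hypothetical name.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def cxcywh_to_xyxy(box: Sequence[float], width: int, height: int) -> Tuple[float, ...]:
    # Normalized center format -> absolute corner coordinates in pixels
    cx, cy, w, h = box
    return ((cx - w / 2) * width, (cy - h / 2) * height,
            (cx + w / 2) * width, (cy + h / 2) * height)


def post_process(logits: Sequence[Sequence[float]], boxes: Sequence[Sequence[float]],
                 width: int, height: int, k: int = 3) -> List[Tuple[float, int, Tuple[float, ...]]]:
    """Keep the k highest-scoring (query, class) pairs as (score, label, xyxy_box), best first."""
    num_classes = len(logits[0])
    # Flatten the per-class sigmoid probabilities of all queries into one list
    flat = [sigmoid(v) for row in logits for v in row]
    top = heapq.nlargest(k, range(len(flat)), key=flat.__getitem__)
    results = []
    for idx in top:
        # Recover which query and which class this flat index refers to
        query, label = divmod(idx, num_classes)
        results.append((flat[idx], label, cxcywh_to_xyxy(boxes[query], width, height)))
    return results


logits = [[2.0, -1.0], [-3.0, 0.5], [-2.0, -2.5]]  # 3 queries, 2 classes
boxes = [[0.5, 0.5, 0.2, 0.4], [0.25, 0.25, 0.1, 0.1], [0.8, 0.8, 0.2, 0.2]]
for score, label, xyxy in post_process(logits, boxes, width=100, height=50, k=2):
    print(f"label={label} score={score:.3f} box={[round(v, 1) for v in xyxy]}")
```

Scoring (query, class) pairs independently, rather than taking one softmax class per query, lets a single query contribute more than one detection.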

3. Model performance comparison

To compare the performance of the object detection models presented in the training tutorials, we can use the ApMetrics module, which computes several metrics based on Average Precision (AP).

[ ]:
from alonet.metrics import ApMetrics
from alonet.detr import CocoDetection2Detr, LitDetr
from alonet.deformable_detr import LitDeformableDetr
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Models to compare in COCO dataset
lit_models = {
    "lit_detr": {
        "model": LitDetr(weights = "detr-r50"),
        "metrics": ApMetrics()
    },
    "lit_deformable": {
        "model": LitDeformableDetr(weights = "deformable-detr-r50-refinement"),
        "metrics": ApMetrics()
    }
}
coco_loader = CocoDetection2Detr(batch_size = 1)

for lit_model in lit_models.values():
    lit_model["model"].model.to(device)
    lit_model["model"].model.eval()

with torch.no_grad():  # Evaluation only, no gradients needed
    for it, data in enumerate(coco_loader.val_dataloader()):
        frame = data[0].batch_list(data)
        frame = frame.to(device)

        gt_boxes = frame.boxes2d[0]
        for lit_model in lit_models.values():
            model = lit_model["model"]
            pred_boxes = model.inference(model(frame))[0]

            lit_model["metrics"].add_sample(pred_boxes, gt_boxes)

        print(f"it:{it}", end="\r")

for name, lit_model in lit_models.items():
    print(f"Results for {name} model:")
    lit_model["metrics"].calc_map(print_result = True)
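ApMetrics handles matching predictions to ground truth and aggregating the results. To make the underlying computation concrete, here is a simplified single-class sketch at one IoU threshold: predictions are sorted by score, each is greedily matched to the best still-unmatched ground-truth box with IoU >= 0.5, and AP is the area under the interpolated precision-recall curve (all-point interpolation). This is the generic technique, not ApMetrics' exact implementation, which may differ; COCO-style AP, for instance, uses 101 recall points and averages over IoU thresholds from 0.50 to 0.95.

```python
from typing import Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    # Intersection rectangle (empty if the boxes do not overlap)
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def average_precision(preds: Sequence[Tuple[float, Box]], gts: Sequence[Box],
                      iou_thr: float = 0.5) -> float:
    """Single-class AP with greedy matching and all-point interpolation."""
    if not gts:
        return 0.0
    matched = [False] * len(gts)
    tp_flags = []
    for _score, box in sorted(preds, key=lambda p: p[0], reverse=True):
        # Best still-unmatched ground-truth box for this prediction
        best_iou, best_j = 0.0, -1
        for j, gt in enumerate(gts):
            overlap = iou(box, gt)
            if not matched[j] and overlap > best_iou:
                best_iou, best_j = overlap, j
        if best_j >= 0 and best_iou >= iou_thr:
            matched[best_j] = True
            tp_flags.append(1)  # True positive
        else:
            tp_flags.append(0)  # False positive
    # Precision and recall after each ranked prediction
    precisions, recalls, tp = [], [], 0
    for i, flag in enumerate(tp_flags, start=1):
        tp += flag
        precisions.append(tp / i)
        recalls.append(tp / len(gts))
    # Make precision non-increasing from right to left, then sum the recall steps
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    ap, prev_recall = 0.0, 0.0
    for p, r in zip(precisions, recalls):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


gts = [(0, 0, 10, 10), (20, 20, 30, 30)]
preds = [(0.9, (0, 0, 10, 10)), (0.8, (50, 50, 60, 60)), (0.7, (20, 20, 30, 31))]
print(f"AP@0.5 = {average_precision(preds, gts):.3f}")  # AP@0.5 = 0.833
```

Here the false positive ranked second lowers precision at that step, so AP drops below 1.0 even though both ground-truth boxes are eventually found.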

What is next?

Learn how to train a custom architecture in the Finetuning Deformable DETR tutorial.