[ ]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import alonet
# for DETR
from alonet.detr import Detr, DetrR50, DetrR50Finetune, LitDetr
from alonet.detr.trt_exporter import DetrTRTExporter
# for Deformable DETR
from alonet.deformable_detr import (
    DeformableDETR,
    DeformableDetrR50Refinement,
    DeformableDetrR50RefinementFinetune)
from alonet.deformable_detr.trt_exporter import DeformableDetrTRTExporter
from alonet.common import pl_helpers
from alonet.torch2trt import TRTExecutor
import aloscene
from aloscene import Frame
In this tutorial, we will convert DETR to TensorRT in order to reduce its memory footprint and inference time.
The model weights can be loaded either from a .pth file or from a run_id using the Aloception API.
This notebook might crash if there is not enough GPU memory. In that case, you can reduce the image size or run only the cells from one of the subsections: Load from .pth checkpoint, Inference with TensorRT, or Load weight from run_id.
Now, let’s define some constants that we will use throughout this tutorial.
[1]:
INPUT_SHAPE = [3, 1280, 1920] # [C, H, W], change input shape if needed
BATCH_SIZE = 1
PRECISION = "fp16" # or "fp32"
The input dimension is [B, C, H, W], where [C, H, W] is defined by INPUT_SHAPE and B by BATCH_SIZE. PRECISION defines the precision of the model weights: it is either “fp32” or “fp16”, for float32 and float16 respectively.
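As a rough sanity check, which is not part of the original workflow, you can allocate a dummy tensor with the configured shape to verify that a single batch fits in GPU memory before running the export. The cell below is only an illustrative sketch.
[ ]:
# Illustrative sketch: check that one batch of the configured shape fits on the GPU
if torch.cuda.is_available():
    dummy = torch.zeros(BATCH_SIZE, *INPUT_SHAPE, device="cuda")
    size_mb = dummy.element_size() * dummy.nelement() / 1e6
    print(f"Dummy input {tuple(dummy.shape)} uses {size_mb:.1f} MB")
    del dummy
    torch.cuda.empty_cache()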
We will run inference on a test image for a qualitative comparison between PyTorch and TensorRT.
[ ]:
image_path = "PATH/TO/IMAGE"
img = Frame(image_path)
frame = img.resize(INPUT_SHAPE[1:]).norm_resnet()
frame = Frame.batch_list([frame])
img.get_view().render()
DETR¶
Load from .pth checkpoint¶
In this example, we use DETR-R50 weights trained on COCO from the official DETR repository, but the workflow is valid for any finetuned model with its associated .pth file.
[ ]:
# 1. Instantiate the model and load the trained weights
weight_path = "PATH/TO/CHECKPOINT.pth"
num_classes = background_class = 91  # COCO classes
torch_model = DetrR50(
    num_classes=num_classes,
    aux_loss=False,  # we don't want auxiliary outputs
    return_dec_outputs=False,  # we don't want decoder outputs
)
torch_model.eval()
alonet.common.load_weights(torch_model, weight_path, torch_model.device)
[ ]:
# 2. Instantiate the corresponding exporter
model_name = "".join(os.path.basename(weight_path).split(".")[:-1])
# Because the exporter uses the ONNX format as an intermediate bridge
# between PyTorch and TensorRT, we need to specify a path where the ONNX file will be saved.
onnx_path = os.path.join(os.path.dirname(weight_path), model_name + ".onnx")
exporter = DetrTRTExporter(
    model=torch_model,
    onnx_path=onnx_path,
    input_shapes=(INPUT_SHAPE,),
    input_names=["img"],
    batch_size=BATCH_SIZE,
    precision=PRECISION,
)
[ ]:
# 3. Run the exporter
exporter.export_engine()
engine_path = exporter.engine_path
After the export, two files will be created in the same directory as the checkpoint file.
The .onnx file is an ONNX graph which serves as an intermediate bridge between PyTorch and TensorRT. The .engine file is the model serialized as a TensorRT engine. For deployment and inference, the .engine file will be deserialized and executed by TensorRT.
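As a quick sanity check, which is a sketch and not part of the original notebook, you can list the generated artifacts and their sizes using the onnx_path and engine_path variables defined above.
[ ]:
# Illustrative sketch: verify the exported artifacts exist and print their sizes
for path in (onnx_path, engine_path):
    print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")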
Inference with TensorRT¶
[ ]:
class DetrInference():
    def __init__(self, background_class=91) -> None:
        self.background_class = background_class

    def get_outs_filter(self, *args, **kwargs):
        return Detr.get_outs_filter(self, *args, **kwargs)

    def __call__(self, forward_out, **kwargs):
        forward_out = {key: torch.tensor(forward_out[key]) for key in forward_out}
        return Detr.inference(self, forward_out, **kwargs)
In order to benefit from the inference logic implemented in alonet DETR without instantiating the whole model in PyTorch, we create a helper class which calls the Detr.inference method.
[ ]:
trt_model = TRTExecutor(engine_path)
trt_model.print_bindings_info()
The input img shape is (B, C, H, W) with C=4, because we concatenate the RGB image (B, 3, H, W) with its mask of shape (B, 1, H, W), which contains 1 on padded pixels.
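If you want to inspect the padding mask yourself, the short sketch below (not part of the original notebook) prints its shape and values; it only relies on the frame built earlier.
[ ]:
# Illustrative sketch: the mask should contain 1 on padded pixels and 0 elsewhere
mask = frame.mask.as_tensor()
print("mask shape:", tuple(mask.shape), "- unique values:", torch.unique(mask))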
[ ]:
m_input = np.concatenate([frame.as_tensor(), frame.mask.as_tensor()], axis=1, dtype=np.float32)
trt_m_outputs = trt_model(m_input)
trt_pred_boxes = DetrInference(background_class=background_class)(trt_m_outputs)
# visualize the result
trt_pred_boxes[0].get_view(frame=img).render()
[ ]:
# compare with the result from the model in PyTorch
with torch.no_grad():
    torch_m_outputs = torch_model(frame)
    torch_pred_boxes = torch_model.inference(torch_m_outputs)
torch_pred_boxes[0].get_view(frame=img).render()
A quick qualitative comparison shows that the two models give nearly identical results. The minor differences come from the fact that we use float16 precision for our TensorRT engine, which is not the case for the PyTorch model.
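For a rough quantitative check (a sketch, not part of the original notebook), you can also compare the raw box predictions of the two models. The output key "pred_boxes" is an assumption based on the usual DETR output names; verify the actual names with trt_model.print_bindings_info() or trt_m_outputs.keys() first.
[ ]:
# Illustrative sketch: compare raw TensorRT and PyTorch outputs.
# "pred_boxes" is an assumed key name; check your output dict keys first.
trt_boxes = np.asarray(trt_m_outputs["pred_boxes"], dtype=np.float32)
torch_boxes = torch_m_outputs["pred_boxes"].cpu().numpy().astype(np.float32)
print("max absolute difference on boxes:", np.abs(trt_boxes - torch_boxes).max())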
Load weight from run_id¶
After training your DETR model with the aloception API, we can load the model from a run_id and export it to TensorRT using the same workflow.
[ ]:
# Define the train project and the run_id from which we want to load weights
project = "YOUR_PROJECT_NAME"
run_id = "YOUR_RUN_ID"
model_name = "MODEL_NAME"
num_classes = ...  # number of classes in your finetuned model
[ ]:
# 1. Instantiate the model and load the weights from the run_id
torch_model = DetrR50Finetune(
    num_classes=num_classes,
    aux_loss=False,  # we don't want auxiliary outputs
    return_dec_outputs=False,  # we don't want decoder outputs
)
lit_model = pl_helpers.load_training(
    LitDetr,  # the PyTorch Lightning module that was used in training
    project_run_id=project,
    run_id=run_id,
    model=torch_model,
)
torch_model = lit_model.model.eval()
[ ]:
# 2. Instantiate the exporter
# Because the exporter uses the ONNX format as an intermediate bridge
# between PyTorch and TensorRT, we need to specify a path where the ONNX file will be saved.
project_dir, run_id_dir, _ = pl_helpers.get_expe_infos(project, run_id)
onnx_path = os.path.join(run_id_dir, model_name + ".onnx")
exporter = DetrTRTExporter(
    model=torch_model,
    onnx_path=onnx_path,
    input_shapes=(INPUT_SHAPE,),
    input_names=["img"],
    batch_size=BATCH_SIZE,
    precision=PRECISION,
)
[ ]:
# 3. Run the exporter
exporter.export_engine()
engine_path = exporter.engine_path
[ ]:
# Test inference
trt_model = TRTExecutor(engine_path)
trt_model.print_bindings_info()
[ ]:
class DetrInference():
    def __init__(self, background_class=91) -> None:
        self.background_class = background_class

    def get_outs_filter(self, *args, **kwargs):
        return Detr.get_outs_filter(self, *args, **kwargs)

    def __call__(self, forward_out, **kwargs):
        forward_out = {key: torch.tensor(forward_out[key]) for key in forward_out}
        return Detr.inference(self, forward_out, **kwargs)
[ ]:
# Test inference
m_input = np.concatenate([frame.as_tensor(), frame.mask.as_tensor()], axis=1, dtype=np.float32)
trt_m_outputs = trt_model(m_input)
trt_pred_boxes = DetrInference(background_class=num_classes)(trt_m_outputs)
# visualize the result
trt_pred_boxes[0].get_view(frame=img).render()
[ ]:
# compare with the result from the model in PyTorch
with torch.no_grad():
    torch_m_outputs = torch_model(frame)
    torch_pred_boxes = torch_model.inference(torch_m_outputs)
torch_pred_boxes[0].get_view(frame=img).render()
As explained above, the comparison shows that the two models give nearly identical results. The minor differences come from the fact that we use float16 precision for our TensorRT engine.
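To also verify the inference-time gain mentioned at the beginning of the tutorial, the sketch below (not part of the original notebook) times both models on the same input. The number of iterations and the use of torch.cuda.synchronize are illustrative choices; it assumes the cells above have been run so that torch_model, frame, trt_model and m_input are available.
[ ]:
import time

# Illustrative sketch: rough latency comparison between PyTorch and TensorRT
n_iters = 20

with torch.no_grad():
    torch_model(frame)  # warm-up
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        torch_model(frame)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"PyTorch: {(time.perf_counter() - start) / n_iters * 1000:.1f} ms / image")

trt_model(m_input)  # warm-up
start = time.perf_counter()
for _ in range(n_iters):
    trt_model(m_input)
print(f"TensorRT: {(time.perf_counter() - start) / n_iters * 1000:.1f} ms / image")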