Training Panoptic Head module
This tutorial explains how to use the LitPanopticDetr module to train the PanopticHead architecture from scratch, using the COCO2017 panoptic annotations and the COCO2017 detection dataset as inputs. The resulting architecture can predict both boxes and masks for objects.
Goals
Declaration of LitPanopticDetr and CocoPanoptic2Detr modules
Run training
Load trained weights and run inference with pre-trained weights
Warning
This guide requires the COCO2017 panoptic annotations and the COCO2017 detection dataset to be downloaded beforehand. The dataset module assumes that the data is stored as follows:
coco
├── train2017
│   ├── img_train_0.jpg
│   ├── img_train_1.jpg
│   ├── …
│   └── img_train_L.jpg
├── valid2017
│   ├── img_val_0.jpg
│   ├── img_val_1.jpg
│   ├── …
│   └── img_val_M.jpg
└── annotations
    ├── panoptic_train2017.json
    ├── panoptic_val2017.json
    ├── panoptic_train2017
    │   ├── img_ann_train_0.png
    │   ├── img_ann_train_1.png
    │   ├── …
    │   └── img_ann_train_L.png
    └── panoptic_val2017
        ├── img_ann_val_0.png
        ├── img_ann_val_1.png
        ├── …
        └── img_ann_val_M.png
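Before starting, the layout above can be checked with a short standard-library script. This is a minimal sketch: missing_coco_entries and the ~/data/coco root are hypothetical names chosen for illustration, and only the top-level entries of the tree are checked, not every image file.

```python
from pathlib import Path

# Top-level directories and files expected under the COCO root (taken from the tree above).
EXPECTED_ENTRIES = [
    "train2017",
    "valid2017",
    "annotations/panoptic_train2017.json",
    "annotations/panoptic_val2017.json",
    "annotations/panoptic_train2017",
    "annotations/panoptic_val2017",
]


def missing_coco_entries(root):
    """Hypothetical helper: return the expected entries that do not exist under `root`."""
    root = Path(root).expanduser()
    return [entry for entry in EXPECTED_ENTRIES if not (root / entry).exists()]


if __name__ == "__main__":
    missing = missing_coco_entries("~/data/coco")  # hypothetical dataset location
    if missing:
        print("Missing entries:", ", ".join(missing))
    else:
        print("COCO layout looks complete")
```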
See https://cocodataset.org/#panoptic-2018 for more information about panoptic tasks.
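In the COCO panoptic format described at that link, each annotation image is a PNG whose pixel colors encode segment ids, and the JSON file maps every segment id to a category through its segments_info list. The sketch below shows the color/id conversion, using the same formula as the rgb2id helper of the official panopticapi; the segments_info entry and the pixel value are made up for illustration.

```python
def rgb_to_segment_id(r, g, b):
    """Decode the color of a panoptic PNG pixel into a segment id (id = R + 256*G + 256^2*B)."""
    return r + 256 * g + 256 * 256 * b


def segment_id_to_rgb(segment_id):
    """Encode a segment id as the (R, G, B) color stored in the panoptic PNG."""
    return (segment_id % 256, (segment_id // 256) % 256, (segment_id // (256 * 256)) % 256)


# Made-up excerpt of the "segments_info" list of one annotation in panoptic_val2017.json
segments_info = [{"id": 2627, "category_id": 1}]
id_to_category = {segment["id"]: segment["category_id"] for segment in segments_info}

pixel = (67, 10, 0)  # made-up pixel color read from an annotation PNG
segment_id = rgb_to_segment_id(*pixel)
# Segment id 0 (black) marks unlabeled pixels, so it has no category in segments_info
print(segment_id, id_to_category.get(segment_id))  # prints: 2627 1
```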
1. LitPanopticDetr and CocoPanoptic2Detr
Aloception is built on the PyTorch Lightning framework and provides modules that simplify the use of datasets and the training of models.
LitPanopticDetr is based on LitDetr, so both modules are instantiated in the same way.
See also
Prior knowledge of how to train a Detr model
Similarly, CocoPanoptic2Detr follows the same logic as CocoDetection2Detr. The modules can therefore be declared as follows:
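from alonet.detr_panoptic import LitPanopticDetr
from alonet.detr import CocoPanoptic2Detr
# Lightning module for the Detr + PanopticHead model
lit_panoptic = LitPanopticDetr()
# Data module that serves COCO2017 images with their panoptic annotations
coco_loader = CocoPanoptic2Detr()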
Important
By default, LitPanopticDetr loads the DETR-R50 pretrained weights.
LitPanopticDetr does not have a num_classes attribute, because PanopticHead is a module that plugs onto the output of a DETR-based model and uses the number of classes defined by that model. There are therefore two ways to change the number of classes:
Wrap a DETR finetune model with PanopticHead and pass it to LitPanopticDetr in finetune mode:
from alonet.detr import DetrR50Finetune
from alonet.detr_panoptic import LitPanopticDetr, PanopticHead
# Define a Detr finetune model with a custom number of classes
my_detr_model = DetrR50Finetune(num_classes=2)
# Use it as the base of a new panoptic head model
my_model = PanopticHead(DETR_module=my_detr_model)
# Build the PyTorch Lightning module
lit_panoptic = LitPanopticDetr(model_name="finetune", model=my_model)
Use the DetrR50PanopticFinetune model directly:
from alonet.detr_panoptic import DetrR50PanopticFinetune, LitPanopticDetr
# Define Detr+PanopticHead with a custom number of classes
my_model = DetrR50PanopticFinetune(num_classes=2)
# Build the PyTorch Lightning module
lit_panoptic = LitPanopticDetr(model_name="finetune", model=my_model)
See also
The Detr finetune and Deformable Detr finetune tutorials.
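The reason why LitPanopticDetr has no num_classes of its own can be pictured with a toy, framework-free sketch. ToyDetr and ToyPanopticHead are hypothetical stand-ins, not the Aloception classes: the head wraps a detection model and always reads the class count from it, so changing the classes means changing the wrapped model, which is what both options above do.

```python
class ToyDetr:
    """Hypothetical stand-in for a DETR model: the number of classes is defined here."""

    def __init__(self, num_classes, num_queries=3):
        self.num_classes = num_classes
        self.num_queries = num_queries

    def __call__(self, image):
        # Pretend to predict `num_queries` objects, each with one score per class.
        return {"scores": [[0.0] * self.num_classes for _ in range(self.num_queries)]}


class ToyPanopticHead:
    """Hypothetical stand-in for a panoptic head built on top of a DETR-like model."""

    def __init__(self, detr_module):
        self.detr = detr_module

    @property
    def num_classes(self):
        # The head has no class count of its own: it always reads the wrapped model's.
        return self.detr.num_classes

    def __call__(self, image):
        outputs = self.detr(image)
        # A real head would predict one mask per detected object from the DETR outputs.
        outputs["masks"] = [None] * len(outputs["scores"])
        return outputs


head = ToyPanopticHead(ToyDetr(num_classes=2))
print(head.num_classes)  # prints 2
```

Keeping the class count in a single place avoids the head and the detector ever disagreeing about it.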
2. Training process
See also
The training process is based on the PyTorch Lightning Trainer. For more information, please consult its online documentation.
As an example, let’s use the COCO 2017 detection dataset as the training base. The usual pipeline is described below:
from argparse import ArgumentParser
import alonet.common
from alonet.detr_panoptic import LitPanopticDetr
from alonet.detr import CocoPanoptic2Detr
# Parameters definition
# Build a single parser that concatenates the arguments of every module of the project
parser = ArgumentParser(conflict_handler="resolve")
parser = CocoPanoptic2Detr.add_argparse_args(parser)
parser = LitPanopticDetr.add_argparse_args(parser)
parser = alonet.common.add_argparse_args(parser)  # Common arguments of the training process
args = parser.parse_args([])  # Empty list: keep the default values (e.g. inside a notebook)
# Dataset used for training
args.batch_size = 1  # Training is memory-intensive, so a batch size of 1 is recommended
coco_loader = CocoPanoptic2Detr(args)
lit_panoptic = LitPanopticDetr(args)
# Training process
# args.save = True  # Uncomment this line to store the trained weights
lit_panoptic.run_train(
    data_loader=coco_loader,
    args=args,
    project="panoptic",
    expe_name="coco",
)
Attention
This code has a high computational cost and takes several hours to train, because the panoptic head is initialized from scratch. We recommend skipping to the next section to see the results of an already trained network.
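The training script above builds one ArgumentParser with conflict_handler="resolve" and lets each module add its own options. The standard-library sketch below mimics that pattern with two hypothetical hooks, add_data_args and add_model_args, standing in for the add_argparse_args methods; it shows that when two modules define the same option, "resolve" keeps the definition added last instead of raising an error.

```python
from argparse import ArgumentParser


def add_data_args(parser):
    """Hypothetical data-module hook: registers its options on the shared parser."""
    parser.add_argument("--batch_size", type=int, default=2, help="samples per batch")
    return parser


def add_model_args(parser):
    """Hypothetical model-module hook: redefines --batch_size and adds --lr."""
    parser.add_argument("--batch_size", type=int, default=4, help="samples per batch")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    return parser


# Each hook extends the same parser, like the add_argparse_args calls in the tutorial.
parser = ArgumentParser(conflict_handler="resolve")
parser = add_data_args(parser)
parser = add_model_args(parser)  # "resolve": this --batch_size replaces the first one

args = parser.parse_args([])  # empty list: keep every default, e.g. inside a notebook
print(args.batch_size, args.lr)  # defaults after conflict resolution
args.batch_size = 1  # values can still be overridden after parsing
```

With the default conflict handler ("error"), the second --batch_size definition would raise argparse.ArgumentError, which is why the shared parser is created with "resolve".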
3. Run inference
Once training is finished, the trained weights can be loaded given the project and run ids (stored in the ~/.aloception/project_run_id/run_id path). For this, the load_training function from the alonet.common module can be used:
from argparse import Namespace
from alonet.common import load_training
from alonet.detr_panoptic import LitPanopticDetr
# Replace with the ids of your own training run
args = Namespace(project_run_id="project_run_id", run_id="run_id")
lit_panoptic = load_training(LitPanopticDetr, args=args)
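The run directory mentioned above can also be inspected by hand, for example to see which checkpoints a run has stored before calling load_training. This is a sketch under two assumptions: runs live under ~/.aloception/project_run_id/run_id as described above, and checkpoints use the default ".ckpt" extension of PyTorch Lightning. find_checkpoints is a hypothetical helper, not part of alonet.

```python
from pathlib import Path


def find_checkpoints(project_run_id, run_id, root="~/.aloception"):
    """Hypothetical helper: list the .ckpt files of one training run, newest first."""
    run_dir = Path(root).expanduser() / project_run_id / run_id
    if not run_dir.is_dir():
        raise FileNotFoundError(f"No run directory at {run_dir}")
    # Assumption: checkpoints are saved with PyTorch Lightning's default ".ckpt" extension.
    checkpoints = list(run_dir.rglob("*.ckpt"))
    # Sort by modification time so the most recent checkpoint comes first.
    return sorted(checkpoints, key=lambda path: path.stat().st_mtime, reverse=True)
```

For example, find_checkpoints("project_run_id", "run_id") returns the checkpoint paths of that run, or raises FileNotFoundError if the run directory does not exist.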
LitPanopticDetr can also download and load pre-trained weights through its weights argument:
from alonet.detr_panoptic import LitPanopticDetr
lit_panoptic = LitPanopticDetr(weights="detr-r50-panoptic")
Finally, we have a pre-trained model ready to make some detections.
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 120
from alonet.detr_panoptic import LitPanopticDetr
from alonet.detr import CocoPanoptic2Detr
import torch
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Dataset used for inference (only its validation split is used below)
coco_loader = CocoPanoptic2Detr(batch_size=1)
lit_panoptic = LitPanopticDetr(weights="detr-r50-panoptic")
lit_panoptic.model = lit_panoptic.model.eval().to(device)
# Check the result on one validation sample
frame = next(iter(coco_loader.val_dataloader()))
frame = frame[0].batch_list(frame).to(device)
with torch.no_grad():  # Gradients are not needed for inference
    pred_boxes, pred_masks = lit_panoptic.inference(lit_panoptic(frame))
pred_boxes, pred_masks = pred_boxes[0], pred_masks[0]
gt_boxes = frame[0].boxes2d
gt_masks = frame[0].segmentation
frame.get_view(
    [
        gt_boxes.get_view(frame[0], title="Ground truth boxes"),
        pred_boxes.get_view(frame[0], title="Predicted boxes"),
        gt_masks.get_view(frame[0], title="Ground truth masks"),
        pred_masks.get_view(frame[0], title="Predicted masks"),
    ]
).render()
What is next?
Learn about a more complex model based on the deformable attention module in the Training Deformable tutorial.