{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Training Deformable DETR\n", "\n", "This tutorial explains how to use the [LitDeformableDetr] module to train the [Deformable DETR R50 architecture] from scratch on the [COCO2017 detection dataset].\n", "\n", "**Goals**\n", "\n", "1. Learn how to instantiate [LitDeformableDetr] and launch a training run\n", "2. Load trained weights and run inference with pre-trained weights\n", "3. Compare the performance of [LitDetr] and [LitDeformableDetr]\n", "\n", "[Deformable DETR R50 architecture]: https://arxiv.org/abs/2010.04159\n", "[COCO2017 detection dataset]: https://cocodataset.org/#detection-2017\n", "[LitDeformableDetr]: ../alonet/deformable_training.rst#alonet.deformable_detr.train.LitDeformableDetr\n", "[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Deformable DETR trainer\n", "\n", "[Aloception] is built on the [PyTorch Lightning] framework and provides modules that simplify working with datasets and training models. Like [LitDetr], [LitDeformableDetr] lets you initialize its parameters at four levels:\n", "\n", "[LitDeformableDetr]: ../alonet/deformable_training.rst#alonet.deformable_detr.train.LitDeformableDetr\n", "[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr\n", "[Aloception]: ../index.rst\n", "[PyTorch Lightning]: https://www.pytorchlightning.ai/" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from argparse import ArgumentParser, Namespace\n", "\n", "from alonet.deformable_detr import LitDeformableDetr, DeformableDetrR50Finetune\n", "\n", "def params2Namespace(litdetr, level):\n", "    # Print the main attributes of a LitDeformableDetr instance\n", "    print(f\"[INFO] LEVEL {level}:\", Namespace(\n", "        accumulate_grad_batches=litdetr.accumulate_grad_batches,\n", "        gradient_clip_val=litdetr.gradient_clip_val,\n", "        model_name=litdetr.model_name,\n", "        weights=litdetr.weights\n", "    ))\n", "\n", "# Level 1\n", "# Create the LightningModule with default parameters\n", "lit_deformable = LitDeformableDetr()\n", "params2Namespace(lit_deformable, 1)\n", "\n", "# Level 2\n", "# Override some parameters with keyword arguments\n", "lit_deformable = LitDeformableDetr(accumulate_grad_batches=5)\n", "params2Namespace(lit_deformable, 2)\n", "\n", "# Level 3\n", "# Use a Namespace (e.g. parsed from the command line) to set attribute values\n", "parser = ArgumentParser()\n", "args = LitDeformableDetr.add_argparse_args(parser).parse_args([])\n", "lit_deformable = LitDeformableDetr(args)\n", "params2Namespace(lit_deformable, 3)\n", "\n", "# Level 4\n", "# Combine the previous approaches and provide a custom model\n", "my_model = DeformableDetrR50Finetune(num_classes=2, weights=\"deformable-detr-r50\")\n", "lit_deformable = LitDeformableDetr(args, model=my_model, model_name=\"finetune\")\n", "params2Namespace(lit_deformable, 4)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Hint**\n", "\n", "For a more detailed explanation, see the tutorial on [how to train a DETR model].\n", "\n", "A typical training pipeline in [Aloception] looks as follows:\n", "\n", "[how to train a DETR model]: training_detr.ipynb\n", "[Aloception]: ../index.rst" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from argparse import ArgumentParser\n", "\n", "import alonet\n", "from alonet.detr import CocoDetection2Detr\n", "from alonet.deformable_detr import LitDeformableDetr\n", "\n", "import torch\n", "\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "# Parameters definition\n", "# Build a parser that gathers the arguments of every component of the project\n", "parser = ArgumentParser(conflict_handler=\"resolve\")\n", "parser = CocoDetection2Detr.add_argparse_args(parser)\n", "parser = LitDeformableDetr.add_argparse_args(parser)\n", "parser = alonet.common.add_argparse_args(parser)  # Add common training arguments\n", "args = parser.parse_args([])\n", "\n", "# Dataset used for training and the Lightning module to train\n", "coco_loader = CocoDetection2Detr(args)\n", "lit_deformable = LitDeformableDetr(args)\n", "\n", "# Training\n", "# args.save = True  # Uncomment this line to store the trained weights\n", "lit_deformable.run_train(\n", "    data_loader=coco_loader,\n", "    args=args,\n", "    project=\"deformable\",\n", "    expe_name=\"coco_detr\",\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Note**\n", "\n", "By default, the [LitDeformableDetr] class uses the [Deformable DETR R50 with refinement] architecture. To use the [Deformable DETR R50] architecture instead, set `model_name=\"deformable-detr-r50\"` when instantiating the class.\n", "\n", "**Hint**\n", "\n", "The difference between [Deformable DETR R50] and [Deformable DETR R50 with refinement] (**Iterative Bounding Box Refinement**) is explained in the article [Deformable DETR: Deformable Transformers for End-to-End Object Detection].\n", "\n", "**Attention**\n", "\n", "Since the model is initialized from scratch, this code has a high computational cost and requires several hours of training. We recommend skipping to the next section to see the results of an already trained network.\n", "\n", "[Deformable DETR: Deformable Transformers for End-to-End Object Detection]: https://arxiv.org/abs/2010.04159\n", "[Deformable DETR R50 with refinement]: ../alonet/deformable_models.rst#module-alonet.deformable_detr.deformable_detr_r50_refinement\n", "[Deformable DETR R50]: ../alonet/deformable_models.rst#module-alonet.deformable_detr.deformable_detr_r50\n", "[LitDeformableDetr]: ../alonet/deformable_training.rst#alonet.deformable_detr.train.LitDeformableDetr" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Run inference\n", "\n", "Once training is finished, the trained weights can be loaded from the project and run IDs (stored under the `~/.aloception/project_run_id/run_id` path). The `load_training` function of the `alonet.common` module handles this:\n", "\n", "```python\n", "from argparse import Namespace\n", "\n", "from alonet.common import load_training\n", "from alonet.deformable_detr import LitDeformableDetr\n", "\n", "args = Namespace(project_run_id=\"project_run_id\", run_id=\"run_id\")\n", "lit_deformable = load_training(LitDeformableDetr, args=args)\n", "```\n", "\n", "[LitDeformableDetr] can also download and load pre-trained weights through the `weights` argument:\n", "\n", "[LitDeformableDetr]: ../alonet/deformable_training.rst#alonet.deformable_detr.train.LitDeformableDetr" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from alonet.deformable_detr import LitDeformableDetr\n", "\n", "lit_deformable = LitDeformableDetr(\n", "    weights=\"deformable-detr-r50\",\n", "    model_name=\"deformable-detr-r50\"\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Note**\n", "\n", "If `model_name` is omitted and `weights=\"deformable-detr-r50-refinement\"` is set, the default [Deformable DETR R50 with refinement] model is instantiated and loaded with its pre-trained weights.\n", "\n", "Finally, we can run some detections with the following code:\n", "\n", "[Deformable DETR R50 with refinement]: ../alonet/deformable_models.rst#module-alonet.deformable_detr.deformable_detr_r50_refinement" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from alonet.detr import CocoDetection2Detr\n", "from alonet.deformable_detr import LitDeformableDetr\n", "\n", "import torch\n", "\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "# COCO dataset (its validation split is used below) and pre-trained model\n", "coco_loader = CocoDetection2Detr()\n", "lit_deformable = LitDeformableDetr(weights=\"deformable-detr-r50-refinement\")\n", "lit_deformable.model = lit_deformable.model.eval().to(device)\n", "\n", "# Check the result on the first validation batch\n", "frame = next(iter(coco_loader.val_dataloader()))\n", "frame = frame[0].batch_list(frame).to(device)\n", "pred_boxes = lit_deformable.inference(lit_deformable(frame))[0]  # Inference from the forward output\n", "gt_boxes = frame[0].boxes2d\n", "\n", "frame.get_view(\n", "    [\n", "        gt_boxes.get_view(frame[0], title=\"Ground truth boxes\"),\n", "        pred_boxes.get_view(frame[0], title=\"Predicted boxes\"),\n", "    ],\n", "    size=(1920, 1080),\n", ").render()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Model performance comparison\n", "\n", "To compare the performance of the object detection models presented in the training tutorials, we can use the [AP Metrics] module, which computes several metrics based on Average Precision (AP).\n", "\n", "[AP Metrics]: ../alonet/alonet.metrics.rst" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from alonet.metrics import ApMetrics\n", "from alonet.detr import CocoDetection2Detr, LitDetr\n", "from alonet.deformable_detr import LitDeformableDetr\n", "import torch\n", "\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "# Models to compare on the COCO dataset\n", "lit_models = {\n", "    \"lit_detr\": {\n", "        \"model\": LitDetr(weights=\"detr-r50\"),\n", "        \"metrics\": ApMetrics(),\n", "    },\n", "    \"lit_deformable\": {\n", "        \"model\": LitDeformableDetr(weights=\"deformable-detr-r50-refinement\"),\n", "        \"metrics\": ApMetrics(),\n", "    },\n", "}\n", "coco_loader = CocoDetection2Detr(batch_size=1)\n", "\n", "for lit_model in lit_models.values():\n", "    lit_model[\"model\"].model.to(device)\n", "    lit_model[\"model\"].model.eval()\n", "\n", "for it, data in enumerate(coco_loader.val_dataloader()):\n", "    frame = data[0].batch_list(data)\n", "    frame = frame.to(device)\n", "\n", "    gt_boxes = frame.boxes2d[0]\n", "    for lit_model in lit_models.values():\n", "        model = lit_model[\"model\"]\n", "        pred_boxes = model.inference(model(frame))[0]\n", "        lit_model[\"metrics\"].add_sample(pred_boxes, gt_boxes)\n", "\n", "    print(f\"it:{it}\", end=\"\\r\")\n", "\n", "# Print the AP results of each model\n", "for name, lit_model in lit_models.items():\n", "    print(f\"Results for {name} model:\")\n", "    lit_model[\"metrics\"].calc_map(print_result=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**What is next?**\n", "\n", "Learn how to train a custom architecture in the **[Finetuning Deformable DETR]** tutorial.\n", "\n", "[Finetuning Deformable DETR]: finetuning_deformable_detr.rst" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 4 }