{ "cells": [ { "cell_type": "markdown", "id": "2d76c0d8", "metadata": {}, "source": [ "\n", "# Finetuning DETR\n", "\n", "This tutorial explains how to use the [Detr R50 Finetune] module to train a custom model based on the [DetrR50 architecture] for an object detection application.\n", "\n", "
\n", " \n", "**Goals**\n", "\n", "1. Train a model based on [DetrR50 architecture] to predict pets in [COCO detection 2017 dataset]\n", "2. Use the trained model to make inferences.\n", "
\n", "\n", "[DetrR50 architecture]: https://arxiv.org/abs/2005.12872\n", "[COCO detection 2017 dataset]: https://cocodataset.org/#detection-2017\n", "[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune" ] }, { "cell_type": "markdown", "id": "1f184a76", "metadata": {}, "source": [ "## 1. Train DETR50 Finetune\n", "\n", "The [Detr R50 Finetune] module is an extension (child class) of [Detr R50] that makes it possible to change the fixed number of 91 classes of the last embedded layer to a desired value, so that the robust model can be reused for a specific application (finetuning).\n", "\n", "
\n", "\n", "**See also**\n", " \n", "* See [Finetuning torchvision models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) to learn more about finetuning. \n", "* Check [Models] to see all possible configurations of the model.\n", "\n", "
\n", "\n", "[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune\n", "[Detr R50]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50\n", "[Models]: ../alonet/detr_models.rst" ] }, { "cell_type": "markdown", "id": "fcd240d4", "metadata": {}, "source": [ "Its declaration is the same as that of [Detr R50], with the difference that the `num_classes` **attribute is now mandatory**:\n", "\n", "[Detr R50]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50" ] }, { "cell_type": "code", "execution_count": null, "id": "ce3b6cca", "metadata": {}, "outputs": [], "source": [ "from alonet.detr import DetrR50Finetune\n", "\n", "detr_finetune = DetrR50Finetune(num_classes = 2)" ] }, { "cell_type": "markdown", "id": "07f5927f", "metadata": {}, "source": [ "Given that [Detr R50 Finetune] is a module based on [Detr R50], we can use it in conjunction with the [LitDetr] module for *training* purposes:\n", "\n", "[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune\n", "[Detr R50]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50\n", "[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr" ] }, { "cell_type": "code", "execution_count": null, "id": "0926106d", "metadata": {}, "outputs": [], "source": [ "from alonet.detr import LitDetr\n", "\n", "lit_detr = LitDetr(model = detr_finetune)" ] }, { "cell_type": "markdown", "id": "ccbe1dcd", "metadata": {}, "source": [ "Finally, we need to choose the dataset on which the model will be trained.\n", "The full code to train on the pet classes (cats and dogs) of the [COCO detection 2017 dataset] is shown below:\n", "\n", "[COCO detection 2017 dataset]: https://cocodataset.org/#detection-2017" ] }, { "cell_type": "code", "execution_count": null, "id": "3bdbc0a3", "metadata": {}, "outputs": [], "source": [ "from argparse import ArgumentParser\n", "import alonet\n", "from alonet.detr import CocoDetection2Detr, LitDetr, DetrR50Finetune\n", "\n", "# Build parser\n", "parser = ArgumentParser()\n", "parser = alonet.common.add_argparse_args(parser) # Common alonet parser\n", "args = parser.parse_args([])\n", "args.no_suffix = True # Fix run_id = expe_name\n", "args.limit_train_batches = 1000\n", "args.limit_val_batches = 200\n", "\n", "# Define COCO dataset as pl.LightningDataModule, restricted to the pet classes\n", "pets = ['cat', 'dog']\n", "coco_loader = CocoDetection2Detr(classes = pets)\n", "\n", "# Define architecture as pl.LightningModule, using PRETRAINED WEIGHTS\n", "lit_detr = LitDetr(model = DetrR50Finetune(len(pets), weights = 'detr-r50'))\n", "\n", "# Start train loop\n", "args.max_epochs = 5 # Thanks to finetuning, 5 epochs are enough to train this model\n", "args.save = True\n", "lit_detr.run_train(\n", "    data_loader = coco_loader,\n", "    args = args,\n", "    project = \"detr\",\n", "    expe_name = \"pets\",\n", ")" ] }, { "cell_type": "markdown", "id": "2cc2fbcd", "metadata": {}, "source": [ "Once the process has completed, the \\$HOME/.aloception/project_run_id/run_id folder will be created, containing the different checkpoint files.\n", "\n", "
\n", "\n", "**Warning**\n", "\n", "A common mistake in the use of pre-trained weights is to try to load the weights on [LitDetr] and not on the model ([Detr R50 Finetune]). By default, [LitDetr] will try to load the weights from the original [DetrR50 architecture], which will produce an error in all finetune models.\n", "\n", "
\n", "\n", "
\n", "\n", "**Important**\n", "\n", "The advantage of finetuning is fast convergence, because training starts from pre-trained weights instead of a random initialization.\n", "\n", "
\n", "\n", "
\n", "\n", "**Hint**\n", "\n", "Check the following links to learn more:\n", "\n", "- [Pytorch lightning data modules](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html)\n", "- [Pytorch lightning modules](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html)\n", "- [How to setup your data]\n", "- [Train a Detr model].\n", "\n", "
\n", "\n", "[DetrR50 architecture]: https://arxiv.org/abs/2005.12872\n", "[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune\n", "[LitDetr]: ../alonet/detr_training.rst#alonet.detr.train.LitDetr\n", "[How to setup your data]: data_setup.rst\n", "[Train a Detr model]: training_detr.ipynb" ] }, { "cell_type": "markdown", "id": "7c2a1bd5", "metadata": {}, "source": [ "## 2. Make inferences\n", "\n", "To make inferences on the dataset with the trained model, we first need to load its weights. [Alonet] provides a function for this purpose. We also need to remember **the project and run id that we used in the training process**:\n", "\n", "[Alonet]: ../alonet/alonet.rst" ] }, { "cell_type": "code", "execution_count": null, "id": "575ed292", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from argparse import Namespace\n", "from alonet.common import load_training\n", "\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "# Define the architecture\n", "detr_finetune = DetrR50Finetune(len(pets))\n", "\n", "# Load weights according to project_run_id and run_id\n", "args = Namespace(\n", " project_run_id = \"detr\",\n", " run_id = \"pets\"\n", ")\n", "lit_detr = load_training(\n", " LitDetr, \n", " args = args, \n", " model = detr_finetune,\n", ")\n", "lit_detr.model.to(device)" ] }, { "cell_type": "markdown", "id": "b3fb76ac", "metadata": {}, "source": [ "This allows us to use the validation dataset and show some results:" ] }, { "cell_type": "code", "execution_count": null, "id": "79f1dacf", "metadata": {}, "outputs": [], "source": [ "frames = next(iter(coco_loader.val_dataloader()))\n", "frames = frames[0].batch_list(frames).to(device)\n", "pred_boxes = lit_detr.inference(lit_detr(frames))[0] # Inference from forward result\n", "gt_boxes = frames[0].boxes2d # Get ground truth boxes\n", "\n", "print(pred_boxes)\n", "\n", 
"frames.get_view([\n", " gt_boxes.get_view(frames[0], title=\"Ground truth boxes\"),\n", " pred_boxes.get_view(frames[0], title=\"Predicted boxes\"),\n", "]).render()" ] }, { "cell_type": "markdown", "id": "c557c0c4", "metadata": {}, "source": [ "
\n", " \n", "**See also**\n", "\n", "See [Aloscene] to find out how to render images in [Aloception].\n", "\n", "
\n", "\n", "
\n", "\n", "**What is next?**\n", "\n", "Learn how to train a complex model based on the deformable attention module in the **[Training Deformable]** tutorial.\n", "\n", "
\n", "\n", "[Aloscene]: ../aloscene/aloscene.rst\n", "[Aloception]: ../index.rst\n", "[Training Deformable]: training_deformable_detr.ipynb" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }