Training¶
PyTorch Lightning module to train models based on the detr module.
- class alonet.detr.train.LitDetr(args=None, model=None, **kwargs)¶
Bases:
pytorch_lightning.core.lightning.LightningModule
- Parameters
- args: Namespace, optional
Attributes stored in a specific Namespace, by default None
- weights: str, optional
Name of the weights to load, by default None
- gradient_clip_val: float, optional
pytorch_lightning.trainer.trainer parameter. 0 means don’t clip, by default 0.1
- accumulate_grad_batches: int, optional
Accumulates grads every k batches or as set up in the dict, by default 4
- model_name: str, optional
Name used to define the model, by default “detr-r50”
- model: torch.nn, optional
Custom model to train
Notes
Arguments entered by the user (kwargs) replace those stored in the args attribute.
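The override rule in the note can be pictured with plain `argparse.Namespace` objects. This is a minimal sketch of the described behaviour, not the library's code; the helper name `merge_args` is hypothetical.

```python
from argparse import Namespace


def merge_args(args=None, **kwargs):
    """Return a new Namespace in which kwargs take precedence over args.

    Hypothetical helper, used only to illustrate the note above.
    """
    merged = dict(vars(args)) if args is not None else {}
    merged.update(kwargs)  # user-supplied keyword arguments win
    return Namespace(**merged)


args = Namespace(model_name="detr-r50", accumulate_grad_batches=4)
merged = merge_args(args, accumulate_grad_batches=8)
print(merged.accumulate_grad_batches, merged.model_name)
```

The original `args` object is left untouched, so the same Namespace can be reused for several modules.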
- static add_argparse_args(parent_parser, parser=None)¶
Add arguments to parent parser with default values
- Parameters
- parent_parser: ArgumentParser
Object to append the new arguments to
- parser: ArgumentParser.argument_group, optional
Argument group to append the parameters to, by default None
- Returns
- ArgumentParser
Object with the new arguments appended
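The parent-parser pattern described above can be sketched with the standard `argparse` module. The function below is an illustrative stand-in, not the library's implementation: the option names mirror the parameters listed for LitDetr, and assuming they are exposed as flags with exactly these names is a guess.

```python
import argparse


def add_argparse_args(parent_parser, parser=None):
    """Illustrative stand-in: append options to a parser, optionally through a group."""
    # Use the given argument group if there is one, otherwise create a group on the parent.
    group = parent_parser.add_argument_group("LitDetr") if parser is None else parser
    # Flag names and defaults are assumptions taken from the parameter list above.
    group.add_argument("--weights", type=str, default=None)
    group.add_argument("--gradient_clip_val", type=float, default=0.1)
    group.add_argument("--accumulate_grad_batches", type=int, default=4)
    group.add_argument("--model_name", type=str, default="detr-r50")
    return parent_parser


parser = add_argparse_args(argparse.ArgumentParser())
args = parser.parse_args(["--accumulate_grad_batches", "8"])
print(args)
```

Returning the parent parser lets several modules chain their own options onto one command line.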
- assert_input(frames, inference=False)¶
Check that the input frames have the correct format
- Parameters
- frames: Frames
Input frames
- inference: bool, optional
Check the input of the inference procedure, by default False
- build_criterion(matcher=None, loss_ce_weight=1, loss_boxes_weight=5, loss_giou_weight=2, eos_coef=0.1, losses=['labels', 'boxes'], aux_loss_stage=6)¶
Build the default criterion
- Parameters
- matcher: torch.nn, optional
Specific matcher to use in the criterion, by default the output of build_matcher()
- loss_ce_weight: float, optional
Weight of the cross-entropy loss in the total loss, by default 1
- loss_boxes_weight: float, optional
Weight of the boxes loss in the total loss, by default 5
- loss_giou_weight: float, optional
Weight of the GIoU loss in the total loss, by default 2
- eos_coef: float, optional
Background/End of the Sequence (EOS) coefficient, by default 0.1
- losses: list, optional
List of losses to include in the total loss, by default [“labels”, “boxes”]. Possible values: [“labels”, “boxes”, “masks”] (use the last one in segmentation tasks)
- aux_loss_stage: int, optional
Number of stages under the aux_outputs key in the forward outputs, by default 6
- Returns
- DetrCriterion
Criterion used to train the model
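How the three weights enter the total loss can be illustrated with a plain weighted sum. This is only a sketch: the dictionary keys `loss_ce`, `loss_boxes` and `loss_giou` are hypothetical names derived from the weight parameters above, and the loss values are made up.

```python
def total_loss(losses, weights):
    """Weighted sum of the individual loss terms that have a weight."""
    return sum(weights[name] * value for name, value in losses.items() if name in weights)


# Default weights from the signature above; the key names are assumptions.
weights = {"loss_ce": 1, "loss_boxes": 5, "loss_giou": 2}
# Made-up loss values for one training step.
losses = {"loss_ce": 0.5, "loss_boxes": 0.1, "loss_giou": 0.25}

total = total_loss(losses, weights)
print(total)
```

With these defaults, a small box regression error counts five times as much as the same amount of classification error.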
- build_matcher(cost_class=1, cost_boxes=5, cost_giou=2)¶
Build the default matcher
- Parameters
- cost_class: float, optional
Weight of the class cost in the Hungarian matcher, by default 1
- cost_boxes: float, optional
Weight of the boxes cost in the Hungarian matcher, by default 5
- cost_giou: float, optional
Weight of the GIoU cost in the Hungarian matcher, by default 2
- Returns
- DetrHungarianMatcher
Hungarian matcher, as a PyTorch model
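The three weights suggest that the matching cost is a weighted sum of a class cost, a box cost and a GIoU cost. The sketch below assumes that combination and uses nested lists in place of tensors; `combine_costs` is a hypothetical helper and the cost values are made up.

```python
def combine_costs(c_class, c_boxes, c_giou, cost_class=1, cost_boxes=5, cost_giou=2):
    """Element-wise weighted sum of three [num_queries][num_targets] cost matrices."""
    return [
        [cost_class * a + cost_boxes * b + cost_giou * c for a, b, c in zip(ra, rb, rc)]
        for ra, rb, rc in zip(c_class, c_boxes, c_giou)
    ]


# One prediction against two targets, with made-up costs.
c_class = [[-0.5, -0.25]]
c_boxes = [[0.5, 1.0]]
c_giou = [[-0.5, 0.25]]

cost = combine_costs(c_class, c_boxes, c_giou)
print(cost)
```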
- build_model(num_classes=91, aux_loss=True, weights=None)¶
Build the default model
- Parameters
- num_classes: int, optional
Number of classes in the embed layer, by default 91
- aux_loss: bool, optional
Return auxiliary outputs in the forward output, by default True
- weights: str, optional
Path or id of the weights to load, by default None
- Returns
- detr
PyTorch model
- Raises
- Exception
Only detr-r50 models are supported for now.
- callbacks(data_loader)¶
Given a data loader, this method will return the default callbacks of the training loop.
- Parameters
- data_loader: torch.utils.data.DataLoader
Dataloader from which a sample is taken for use in object_detector_callback
- Returns
- List[Callbacks]
Callbacks used in the training process
- configure_optimizers()¶
AdamW optimizer configuration, using different learning rates for the backbone and the other parameters
- Returns
- torch.optim
AdamW optimizer to update weights
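Using one learning rate for the backbone and another for the rest is usually done with per-parameter option groups. The sketch below builds such groups in plain Python; selecting parameters by the substring `"backbone"` and the two learning-rate values are assumptions, not values taken from this module.

```python
def build_param_groups(named_parameters, lr=1e-4, lr_backbone=1e-5):
    """Split (name, parameter) pairs into two groups with their own learning rates."""
    backbone, others = [], []
    for name, param in named_parameters:
        # Assumption: backbone parameters can be recognised by their name.
        (backbone if "backbone" in name else others).append(param)
    return [
        {"params": others, "lr": lr},
        {"params": backbone, "lr": lr_backbone},
    ]


# Strings stand in for real parameter tensors.
named = [
    ("backbone.conv1.weight", "w0"),
    ("transformer.encoder.weight", "w1"),
    ("class_embed.bias", "w2"),
]
groups = build_param_groups(named)
print(groups)
```

With PyTorch, the pairs would come from `model.named_parameters()` and the resulting list could be passed to `torch.optim.AdamW`, which accepts a list of such dicts.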
- forward(frames)¶
Run a forward pass through the model.
- inference(m_outputs)¶
Given the model forward outputs, this method will return a BoundingBoxes2D tensor.
- Parameters
- m_outputs: dict
Dict with the model forward outputs
- Returns
- List[BoundingBoxes2D]
Set of boxes for each batch
- run_train(data_loader, args=None, project='detr', expe_name='detr_50', callbacks=None)¶
Train the model using pytorch lightning
- Parameters
- data_loader: torch.utils.data.DataLoader
Dataloader used in the callbacks() function
- project: str, optional
Project name used to save checkpoints, by default “detr”
- expe_name: str, optional
Specific experiment name used to save checkpoints, by default “detr_50”
- callbacks: list, optional
List of callbacks to use, by default the callbacks() output
- args: Namespace, optional
Additional arguments used in the training process, by default None
- training: bool¶
- training_step(frames, batch_idx)¶
Train the model for one step
- validation_step(frames, batch_idx)¶
Run one step of validation
Criterion¶
This class computes the loss for DETR. The process happens in two steps:
- We compute the Hungarian assignment between the ground-truth boxes and the outputs of the model.
- We supervise each pair of matched ground-truth / prediction (supervise class and box).
- class alonet.detr.criterion.DetrCriterion(matcher, loss_ce_weight, loss_boxes_weight, loss_giou_weight, eos_coef, aux_loss_stage, losses)¶
Bases:
torch.nn.modules.module.Module
Create the criterion.
- Parameters
- num_classes: int
number of object categories, omitting the special no-object category
- matcher: nn.Module
module able to compute a matching between targets and proposed boxes
- loss_ce_weight: float
Weight of the cross-entropy class loss
- loss_boxes_weight: float
Weight of the boxes L1 loss
- loss_giou_weight: float
Weight of the boxes GIoU loss
- eos_coef: float
relative classification weight applied to the no-object category
- aux_loss_stage: int
Number of auxiliary stages
- losses: list
list of all the losses to be applied. See get_loss() for the list of available losses.
- forward(m_outputs, frames, matcher_frames=None, compute_statistical_metrics=False, **kwargs)¶
This performs the loss computation.
- Parameters
- m_outputs: dict
Dict of tensors, see the output specification of the model for the format
- frames: Frames
Target frames
- compute_statistical_metrics: bool
Whether to compute statistical data about the model outputs/inputs, by default False
- Returns
- torch.tensor
Total loss as a weighted sum of the losses
- dict
Individual losses
- get_loss(loss, outputs, frames, indices, num_boxes, update_loss_map=None, **kwargs)¶
Compute a loss given the model outputs, the target frame, the results from the matcher and the total number of boxes across the devices.
- Parameters
- loss: str
Name of the loss to compute
- outputs: dict
Detr model forward outputs
- frames: Frames
Target frame with boxes2d and labels
- indices: list
List of tuples with predicted indices and target indices
- num_boxes: torch.Tensor
Number of total target boxes
- update_loss_map: dict
New loss functions to take into account in the total loss, by default None
- Returns
- Dict
Losses of the loss procedure.
- get_metrics(outputs, frames, indices, num_boxes, **kwargs)¶
Compute some useful metrics related to the model performance
- Parameters
- outputs: dict
Detr model forward outputs
- frames: Frames
Target frame with boxes2d and labels
- indices: list
List of tuples with predicted indices and target indices
- num_boxes: torch.Tensor
Number of total target boxes
- Returns
- metrics: dict
objectness_recall: Percentage of detected objects among the GT objects (class invariant)
recall: Percentage of detected classes among all the GT classes
objectness_true_pos: Among the positive predictions of the model, how many are really positive? (class invariant)
precision: Among the positive predictions of the model, how many are well classified?
true_neg: Among the negative predictions of the model, how many are really negative?
slot_true_neg: Among all the negative slots, how many are predicted as negative? (class invariant)
Notes
Important
The metrics described above do not directly reflect the true performance of the model. They are only directly correlated with the loss and the Hungarian matching used to train the model. Therefore, the computed recall is NOT the true recall but the recall based on the slots and the Hungarian choice. That being said, it is still useful information for monitoring the training progress.
- get_statistical_metrics(outputs, frames, indices, num_boxes, **kwargs)¶
Compute some useful statistical metrics about the model outputs and the inputs.
- Parameters
- outputs: dict
Detr model forward outputs
- frames: Frames
Target frame with boxes2d and labels
- indices: list
List of tuples with predicted indices and target indices
- num_boxes: torch.Tensor
Number of total target boxes
- Returns
- dict
Metrics
- loss_boxes(outputs, frames, indices, num_boxes, **kwargs)¶
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
- Parameters
- outputs: dict
Detr model forward outputs
- frames: Frames
Target frame with boxes2d and labels
- indices: list
List of tuples with predicted indices and target indices
- num_boxes: torch.Tensor
Number of total target boxes
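The two box terms can be illustrated on a single matched pair. The sketch below uses the standard L1 distance and the standard GIoU definition on boxes given as `(x1, y1, x2, y2)` tuples; that coordinate format is an assumption for the example, and the library's own tensors and normalisation by `num_boxes` are not reproduced here.

```python
def l1_distance(pred, target):
    """Sum of absolute coordinate differences between two boxes."""
    return sum(abs(p - t) for p, t in zip(pred, target))


def area(box):
    """Area of an (x1, y1, x2, y2) box; empty boxes have area 0."""
    x1, y1, x2, y2 = box
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def giou(a, b):
    """Generalized IoU: IoU minus the share of the enclosing box not covered by the union."""
    inter = area((max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])))
    union = area(a) + area(b) - inter
    hull = area((min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])))
    return inter / union - (hull - union) / hull


pred = (0.0, 0.0, 2.0, 2.0)
target = (1.0, 1.0, 3.0, 3.0)

loss_l1 = l1_distance(pred, target)
loss_giou = 1.0 - giou(pred, target)  # a perfect match gives a GIoU loss of 0
print(loss_l1, loss_giou)
```

Unlike plain IoU, GIoU stays informative when the boxes do not overlap, because the enclosing-box term still changes as they move closer.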
- loss_labels(outputs, frames, indices, num_boxes, **kwargs)¶
Compute the classification loss
- Parameters
- outputs: dict
Detr model forward outputs
- frames: Frames
Target frame with boxes2d and labels
- indices: list
List of tuples with predicted indices and target indices
- num_boxes: torch.Tensor
Number of total target boxes
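The role of `eos_coef` in the classification loss can be sketched as a class-weighted cross-entropy in which the no-object class gets a smaller weight. This is an illustration in plain Python under two assumptions: the no-object class is the last index, and the loss is normalised by the sum of the target weights (the convention of a weighted mean).

```python
import math


def weighted_cross_entropy(logits, targets, class_weights):
    """Class-weighted cross-entropy, averaged over the weights of the targets."""
    total, norm = 0.0, 0.0
    for row, target in zip(logits, targets):
        log_z = math.log(sum(math.exp(v) for v in row))  # log of the softmax denominator
        nll = log_z - row[target]  # negative log-likelihood of the target class
        total += class_weights[target] * nll
        norm += class_weights[target]
    return total / norm


eos_coef = 0.1
class_weights = [1.0, 1.0, eos_coef]  # two object classes, then no-object (assumed last)

# Two prediction slots with uniform logits: one matched to class 0, one unmatched.
logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 2]

loss = weighted_cross_entropy(logits, targets, class_weights)
print(loss)
```

Since most slots are unmatched, the reduced weight keeps the no-object class from dominating the loss.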
- training: bool¶
Matcher¶
Modules to compute the matching cost and solve the corresponding LSAP.
- class alonet.detr.matcher.DetrHungarianMatcher(cost_class=1, cost_boxes=1, cost_giou=1)¶
Bases:
torch.nn.modules.module.Module
This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don’t include the no_object. Because of this, in general, there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are un-matched (and thus treated as non-objects).
- forward(m_outputs, frames, **kwargs)¶
Performs the matching
- Parameters
- m_outputs: dict
Dict output of the alonet.detr.Detr model. This is a dict that contains at least these entries:
“pred_logits”: Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
“pred_boxes”: Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
- frames: aloscene.Frame
Target frame with a set of boxes2d named “gt_boxes_2d”, with labels.
- Returns
- A list of size batch_size, containing tuples of (index_i, index_j) where:
index_i is the indices of the selected predictions (in order)
index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
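The shape of the returned indices can be illustrated on one batch element. The sketch below finds the minimum-cost one-to-one assignment by brute force over permutations, which is only practical for tiny inputs; it stands in for a real assignment solver and is not the library's implementation. The cost values are made up.

```python
from itertools import permutations


def match(cost):
    """Return (index_i, index_j) for a [num_queries][num_targets] cost matrix.

    Assumes num_queries >= num_targets, as is generally the case here.
    """
    num_queries, num_targets = len(cost), len(cost[0])
    # preds[t] is the prediction assigned to target t; try every distinct choice.
    best = min(
        permutations(range(num_queries), num_targets),
        key=lambda preds: sum(cost[p][t] for t, p in enumerate(preds)),
    )
    pairs = sorted(zip(best, range(num_targets)))  # order by prediction index
    index_i = [p for p, _ in pairs]
    index_j = [t for _, t in pairs]
    return index_i, index_j


# Three predictions, two targets.
cost = [
    [0.9, 0.2],
    [0.1, 0.8],
    [0.5, 0.6],
]
index_i, index_j = match(cost)
print(index_i, index_j)
```

Prediction 2 is left unmatched and would be treated as a non-object.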
- hungarian(batch_cost_matrix, **kwargs)¶
- hungarian_cost_class(tgt_boxes, m_outputs, **kwargs)¶
Compute the class cost for the Hungarian matcher
- Parameters
- m_outputs: dict
Dict output of the alonet.detr.Detr model. This is a dict that contains at least these entries:
“pred_logits”: Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
“pred_boxes”: Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
- tgt_boxes: aloscene.BoundingBoxes2D
Target boxes2d across the batch
- hungarian_cost_giou_boxes(tgt_boxes, m_outputs, **kwargs)¶
Compute the GIoU cost between boxes
- Parameters
- m_outputs: dict
Dict output of the alonet.detr.Detr model. This is a dict that contains at least these entries:
“pred_logits”: Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
“pred_boxes”: Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
- tgt_boxes: aloscene.BoundingBoxes2D
Target boxes2d across the batch
- hungarian_cost_l1_boxes(tgt_boxes, m_outputs, **kwargs)¶
Compute the L1 cost between boxes
- Parameters
- m_outputs: dict
Dict output of the alonet.detr.Detr model. This is a dict that contains at least these entries:
“pred_logits”: Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
“pred_boxes”: Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
- tgt_boxes: aloscene.BoundingBoxes2D
Target boxes2d across the batch
- training: bool¶
- alonet.detr.matcher.build_matcher(args)¶
Callbacks¶
DETR callback for object detection training that uses frames as ground truth (GT).
- class alonet.detr.callbacks.DetrObjectDetectorCallback(*args, **kwargs)¶
Bases:
alonet.callbacks.object_detector_callback.ObjectDetectorCallback
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)¶
Called when the train batch ends.
- on_validation_epoch_end(trainer, pl_module)¶
Called when the val epoch ends.