Callbacks¶
List of callbacks used in different modules, in order to report different metrics during training. See Callbacks to get more information.
Base Metrics Callback¶
Base class to implement a callback for a specific metric
See Also¶
All the possible Metrics
- class alonet.callbacks.base_metrics_callback.InstancesBaseMetricsCallback(base_metric, *args, **kwargs)¶
Bases:
pytorch_lightning.callbacks.base.Callback
- Parameters
- base_metric: Metrics
A metric object of Metrics
- add_sample(base_metric, pred_boxes, gt_boxes, pred_masks=None, gt_masks=None)¶
Add a sample to some Metrics. One might want to override this method to edit the pred_boxes and gt_boxes before adding them (see the sketch after the parameter list).
- Parameters
- base_metric: Metrics
Metric instance.
- pred_boxes: BoundingBoxes2D
Predicted boxes2D.
- gt_boxes: BoundingBoxes2D
GT boxes2D.
- pred_masks: Mask
Predicted masks for segmentation task.
- gt_masks: Mask
GT masks in segmentation task.
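For instance, a minimal sketch of overriding this method in a subclass is shown below; the subclass name and the pre-processing step are hypothetical, not part of the library.
>>> from alonet.callbacks.base_metrics_callback import InstancesBaseMetricsCallback
>>> class EditedBoxesMetricsCallback(InstancesBaseMetricsCallback):
...     def add_sample(self, base_metric, pred_boxes, gt_boxes, pred_masks=None, gt_masks=None):
...         # Hypothetical spot to edit the boxes (e.g. filter or relabel them)
...         # before handing them to the parent implementation
...         return super().add_sample(base_metric, pred_boxes, gt_boxes, pred_masks=pred_masks, gt_masks=gt_masks)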
- inference(pl_module, m_outputs, **kwargs)¶
This method will call the inference() method of the module’s model and will expect to receive the predicted boxes2D and/or Masks.
- Parameters
- pl_module: pl.LightningModule
Pytorch lightning module with an inference function
- m_outputs: dict
Forward outputs
- Returns
BoundingBoxes2D
Boxes predicted
Mask
Masks predicted
Notes
If m_outputs does not contain the pred_masks attribute, a [None]*B list will be returned by default.
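A hypothetical call illustrating the return values for a batch of size B (the variable names are illustrative):
>>> pred_boxes, pred_masks = callback.inference(pl_module, m_outputs)
>>> # pred_boxes: predicted BoundingBoxes2D for each of the B samples
>>> # pred_masks: predicted Mask for each sample, or [None] * B when
>>> # m_outputs has no pred_masks attribute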
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)¶
Method called after each validation batch. This class is a pytorch lightning callback, therefore this method will be automatically called by pytorch lightning.
This method will call the inference method of the module’s model and will expect to receive the predicted boxes2D and/or Masks. These elements will be aggregated to compute the different metrics in the on_validation_end method. The inference method will be called using the m_outputs key from the outputs dict. If m_outputs is a list, then the list will be considered as a temporal list. Therefore, this callback will aggregate the predictions for each element of the sequence and will log the final results with the timestep prefix val/t/ instead of simply val/.
- Parameters
- trainer: pl.Trainer
Pytorch lightning trainer
- pl_module: pl.LightningModule
Pytorch lightning module. The m_outputs key is expected for this callback to work properly.
- outputs: dict
Training/Validation step outputs of the pl.LightningModule class.
- batch: list
Batch coming from the dataloader. Usually, a list of frames.
- batch_idx: int
Id of the batch
- dataloader_idx: int
Id of the dataloader.
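For illustration, the two accepted shapes of the m_outputs key might look as follows (the model output variables are hypothetical):
>>> outputs = {"m_outputs": m_outputs}                      # single output, logged under val/
>>> outputs = {"m_outputs": [m_outputs_t0, m_outputs_t1]}   # temporal list, logged under val/t/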
- on_validation_end(trainer, pl_module)¶
Method called at the end of each validation epoch. The method will use all the data aggregated over the epoch to log the final metrics on wandb. This class is a pytorch lightning callback, therefore this method will be automatically called by pytorch lightning.
This method is currently a WIP since some metrics are not logged due to some wandb error when loading Table.
- Parameters
- trainer: pl.Trainer
Pytorch lightning trainer
- pl_module: pl.LightningModule
Pytorch lightning module
AP Metrics Callback¶
Callback that stores samples in each step to calculate the AP for one IoU and one class
See Also¶
ApMetrics, the specific metric implemented in this callback
- class alonet.callbacks.map_metrics_callback.ApMetricsCallback(*args, **kwargs)¶
Bases:
alonet.callbacks.base_metrics_callback.InstancesBaseMetricsCallback
- on_validation_end(trainer, pl_module)¶
Method called at the end of each validation epoch. The method will use all the data aggregated over the epoch to log the final metrics on wandb. This class is a pytorch lightning callback, therefore this method will be automatically called by pytorch lightning.
This method is currently a WIP since some metrics are not logged due to some wandb error when loading Table.
- Parameters
- trainer: pl.Trainer
Pytorch lightning trainer
- pl_module: pl.LightningModule
Pytorch lightning module
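A minimal usage sketch, assuming the callback can be instantiated without arguments (it wires ApMetrics as its base metric) and that the LightningModule’s validation step outputs contain the m_outputs key; the model and dataloaders are hypothetical:
>>> import pytorch_lightning as pl
>>> from alonet.callbacks.map_metrics_callback import ApMetricsCallback
>>> trainer = pl.Trainer(callbacks=[ApMetricsCallback()])
>>> # trainer.fit(model, train_loader, val_loader)  # hypothetical model and dataloaders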
PQ Metrics Callback¶
Callback that stores samples in each step to calculate the different Panoptic Quality metrics
See Also¶
PQMetrics, the specific metric implemented in this callback
- class alonet.callbacks.pq_metrics_callback.PQMetricsCallback(*args, **kwargs)¶
Bases:
alonet.callbacks.base_metrics_callback.InstancesBaseMetricsCallback
- on_validation_end(trainer, pl_module)¶
Method called at the end of each validation epoch. The method will use all the data aggregated over the epoch to log the final metrics on wandb. This class is a pytorch lightning callback, therefore this method will be automatically called by pytorch lightning.
This method is currently a WIP since some metrics are not logged due to some wandb error when loading Table.
- Parameters
- trainer: pl.Trainer
Pytorch lightning trainer
- pl_module: pl.LightningModule
Pytorch lightning module
Metrics Callbacks¶
- class alonet.callbacks.metrics_callback.MetricsCallback(step_smooth=100, val_names=None)¶
Bases:
pytorch_lightning.callbacks.base.Callback
Callback for any training that needs to log metrics and loss results. To be used, the model outputs must be a dict containing the key “metrics”. During training, the scalar values will be averaged over step_smooth steps.
Additionally, if a metric name starts with histogram or scatter, this class will treat it as such and will try to log a histogram/scatter plot on wandb.
- Parameters
- step_smooth: int
Size in steps of the window used to compute the moving average of scalar metrics.
- val_names: list[str]
Names associated with each val_dataloader. Stats will be computed separately for each dataset and logged in wandb with the associated val_name as prefix.
Examples
Here is an example of an expected model forward output.
>>> outputs = {
...     "metrics": {
...         "loss": loss_value,
...         "cross_entropy": cross_entropy_value,
...         "histogram": torch.tensor(n,),
...         "scatter": (["name_x_axis", "name_y_axis"], torch.tensor(n, 2)),
...     }
... }
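A minimal sketch of attaching the callback to a trainer, assuming two hypothetical validation dataloaders whose stats should be logged under separate prefixes:
>>> import pytorch_lightning as pl
>>> from alonet.callbacks.metrics_callback import MetricsCallback
>>> metrics_cb = MetricsCallback(step_smooth=100, val_names=["val_clean", "val_noisy"])
>>> trainer = pl.Trainer(callbacks=[metrics_cb])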
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)¶
Method called after each training batch. This class is a pytorch lightning callback, therefore this method will be automatically called by pytorch lightning.
- Parameters
- trainer: pl.Trainer
Pytorch lightning trainer
- pl_module: pl.LightningModule
Pytorch lightning module
- outputs:
Training/Validation step outputs of the pl.LightningModule class. The metrics key is expected for this callback to work properly, and outputs["metrics"] must be a dict. For each key:
- If the key contains the keyword ‘histogram’, the tensor values will be aggregated to compute a histogram every trainer.log_every_n_steps.
- If the key contains the keyword ‘scatter’, the value must be a tuple (scatter_names: tuple, torch.tensor), with the first element being a list of length 2 holding the names of the X axis and the Y axis, and the second element a tensor of size (N, 2).
- Otherwise, the tensor is expected to be a single scalar with the value to log.
- batch: list
Batch coming from the dataloader. Usually, a list of frames.
- batch_idx: int
Id of the batch
- dataloader_idx: int
Id of the dataloader.
- on_train_epoch_end(trainer, pl_module)¶
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
- Implement training_epoch_end in the LightningModule and access outputs via the module, OR
- Cache data across train batch hooks inside the callback implementation to post-process in this hook.
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)¶
Method called after each validation batch. This class is a Pytorch Lightning callback, therefore this method will by automatically called by Pytorch Lightning.
- Parameters
- trainer: pl.Trainer
Pytorch lightning trainer
- pl_module: pl.LightningModule
Pytorch lightning module
- outputs:
Training/Validation step outputs of the pl.LightningModule class. The metrics key is expected for this callback to work properly, and outputs["metrics"] must be a dict. For each key:
- If the key contains the keyword ‘histogram’, the tensor values will be aggregated to compute a histogram every trainer.log_every_n_steps.
- If the key contains the keyword ‘scatter’, the value must be a tuple (scatter_names: tuple, torch.tensor), with the first element being a list of length 2 holding the names of the X axis and the Y axis, and the second element a tensor of size (N, 2).
- Otherwise, the tensor is expected to be a single scalar with the value to log.
- batch: list
Batch coming from the dataloader. Usually, a list of frames.
- batch_idx: int
Id of the batch
- dataloader_idx: int
Dataloader id.
- on_validation_epoch_end(trainer, pl_module)¶
Called when the val epoch ends.
- alonet.callbacks.metrics_callback.get_rank()¶
- alonet.callbacks.metrics_callback.is_dist_avail_and_initialized()¶
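These helpers carry no docstring here; a sketch of the usual implementation of such distributed utilities (an assumption, not necessarily the exact code of this module) is:
>>> import torch.distributed as dist
>>> def is_dist_avail_and_initialized():
...     # True only when torch.distributed is both available and initialized
...     return dist.is_available() and dist.is_initialized()
...
>>> def get_rank():
...     # Rank of the current process, or 0 outside of a distributed run
...     return dist.get_rank() if is_dist_avail_and_initialized() else 0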
Object Detector Callback¶
Callback for any object detection training that uses Frame as GT.
- class alonet.callbacks.object_detector_callback.ObjectDetectorCallback(val_frames)¶
Bases:
pytorch_lightning.callbacks.base.Callback
The callback loads frames every x training steps as well as once every validation step on the given val_frames and logs the different predicted objects.
- Parameters
- val_frames: Union[list, Frames]
List of samples from the validation set to use to log the validation progress
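A minimal usage sketch, assuming a hypothetical validation dataset from which a few aloscene frames are taken to monitor the training progress:
>>> import pytorch_lightning as pl
>>> from alonet.callbacks.object_detector_callback import ObjectDetectorCallback
>>> val_frames = [val_dataset[i] for i in range(2)]  # hypothetical dataset of Frame samples
>>> trainer = pl.Trainer(callbacks=[ObjectDetectorCallback(val_frames=val_frames)])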
- log_boxes_2d(frames, preds_boxes, trainer, name)¶
Given frames and predicted boxes2d, this method will log the images into wandb
- Parameters
- frames: list of Frame
Frames with GT boxes2d attached
- preds_boxes: list of BoundingBoxes2D
A set of predicted boxes2d
- trainer: pl.trainer.trainer.Trainer
Lightning trainer
- log_boxes_3d(frames, preds_boxes, trainer, name)¶
Given frames and predicted boxes3d, this method will log the images into wandb
- Parameters
- frames: Frame
Frame with GT boxes3d attached
- preds_boxes: BoundingBoxes3D
A set of predicted boxes3d
- trainer: pl.trainer.trainer.Trainer
Lightning trainer
- log_masks(frames, pred_masks, trainer, name)¶
Given frames and predicted masks in a segmentation task, this method will log the images into wandb
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)¶
Called when the train batch ends.
- on_validation_epoch_end(trainer, pl_module)¶
Called when the val epoch ends.
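A hypothetical subclass sketch showing how log_boxes_2d might be used from one of the hooks; the val_frames attribute, the pl_module.inference call and the log name are assumptions, not the library’s confirmed API:
>>> from alonet.callbacks.object_detector_callback import ObjectDetectorCallback
>>> class MyDetectorCallback(ObjectDetectorCallback):
...     def on_validation_epoch_end(self, trainer, pl_module):
...         # Hypothetical: run the model on the stored validation frames
...         pred_boxes = pl_module.inference(pl_module(self.val_frames))
...         self.log_boxes_2d(frames=self.val_frames, preds_boxes=pred_boxes, trainer=trainer, name="val/frames")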