# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""RTMDet model implementations."""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar, Literal
from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.common.losses import GIoULoss, QualityFocalLoss
from otx.backend.native.models.common.utils.assigners import DynamicSoftLabelAssigner
from otx.backend.native.models.common.utils.coders import DistancePointBBoxCoder
from otx.backend.native.models.common.utils.prior_generators import MlvlPointGenerator
from otx.backend.native.models.common.utils.samplers import PseudoSampler
from otx.backend.native.models.detection.backbones import CSPNeXt
from otx.backend.native.models.detection.base import OTXDetectionModel
from otx.backend.native.models.detection.detectors import SingleStageDetector
from otx.backend.native.models.detection.heads import RTMDetSepBNHead
from otx.backend.native.models.detection.losses import RTMDetCriterion
from otx.backend.native.models.detection.necks import CSPNeXtPAFPN
from otx.backend.native.models.utils.utils import load_checkpoint
from otx.config.data import TileConfig
from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable
from otx.types.export import TaskLevelExportParameters
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.backend.native.schedulers import LRSchedulerListCallable
from otx.metrics import MetricCallable
from otx.types.label import LabelInfoTypes
[docs]
class RTMDet(OTXDetectionModel):
    """RTMDet object detector for OTX.

    The network combines a CSPNeXt backbone, a CSPNeXtPAFPN neck and an RTMDetSepBNHead.
    These are assembled into a ``SingleStageDetector``, which is initialised from a
    pretrained checkpoint.

    Attributes:
        pretrained_weights (ClassVar[dict[str, str]]): Maps each supported ``model_name`` to
            the URL of its pretrained checkpoint.
        input_size_multiplier (int): Multiplier for the input size (32 for this model).

    Args:
        label_info (LabelInfoTypes): Label information for the task.
        data_input_params (DataInputParams): Input data parameters; forwarded to the base
            class and reused by the exporter.
        model_name (Literal["rtmdet_tiny"], optional): Model variant. It is also the key into
            ``pretrained_weights``. Defaults to "rtmdet_tiny".
        optimizer (OptimizerCallable, optional): Optimizer factory.
            Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Learning-rate
            scheduler factory. Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Metric factory.
            Defaults to MeanAveragePrecisionFMeasureCallable.
        torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
        tile_config (TileConfig, optional): Tiling configuration.
            Defaults to TileConfig(enable_tiler=False).
    """

    pretrained_weights: ClassVar[dict[str, str]] = {
        "rtmdet_tiny": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/object_detection/v2/rtmdet_tiny.pth",
    }
    input_size_multiplier = 32

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: Literal["rtmdet_tiny"] = "rtmdet_tiny",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MeanAveragePrecisionFMeasureCallable,
        torch_compile: bool = False,
        # NOTE(review): this default is evaluated once, when the class body runs, so every
        # instance created without an explicit ``tile_config`` shares the same object.
        # If TileConfig is ever mutated in place, consider a ``None`` sentinel instead.
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        # Every argument is forwarded to the base detection model unchanged.
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )

    def _create_model(self, num_classes: int | None = None) -> SingleStageDetector:
        """Assemble the RTMDet network and load its pretrained weights.

        Args:
            num_classes (int | None): Number of output classes. When ``None``,
                ``self.num_classes`` is used.

        Returns:
            SingleStageDetector: The assembled detector. ``init_weights()`` has been called
            on it, and the checkpoint at ``pretrained_weights[self.model_name]`` has then
            been loaded with ``map_location="cpu"``.
        """
        if num_classes is None:
            num_classes = self.num_classes

        # Training-time settings shared by the head and the detector.
        assigner = DynamicSoftLabelAssigner(topk=13)
        sampler = PseudoSampler()
        train_cfg = {
            "assigner": assigner,
            "sampler": sampler,
            "allowed_border": -1,
            "pos_weight": -1,
            "debug": False,
        }

        # Inference-time post-processing settings shared by the head and the detector.
        nms_cfg = {"type": "nms", "iou_threshold": 0.65}
        test_cfg = {
            "nms": nms_cfg,
            "score_thr": 0.001,
            "mask_thr_binary": 0.5,
            "max_per_img": 300,
            "min_bbox_size": 0,
            "nms_pre": 30000,
        }

        backbone = CSPNeXt(model_name=self.model_name)
        neck = CSPNeXtPAFPN(model_name=self.model_name)

        # Point priors on three feature levels (strides 8, 16, 32) with zero offset.
        prior_generator = MlvlPointGenerator(offset=0, strides=[8, 16, 32])
        box_coder = DistancePointBBoxCoder()
        bbox_head = RTMDetSepBNHead(
            model_name=self.model_name,
            num_classes=num_classes,
            anchor_generator=prior_generator,
            bbox_coder=box_coder,
            train_cfg=train_cfg,  # TODO(kirill): remove
            test_cfg=test_cfg,  # TODO(kirill): remove
        )

        # Classification uses sigmoid quality focal loss; box regression uses GIoU with weight 2.0.
        cls_loss = QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0)
        box_loss = GIoULoss(loss_weight=2.0)
        criterion = RTMDetCriterion(num_classes=num_classes, loss_cls=cls_loss, loss_bbox=box_loss)

        detector = SingleStageDetector(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            criterion=criterion,
            train_cfg=train_cfg,  # TODO(kirill): remove
            test_cfg=test_cfg,  # TODO(kirill): remove
        )

        # Initialise parameters first, then load the pretrained checkpoint for this variant.
        detector.init_weights()
        checkpoint_url = self.pretrained_weights[self.model_name]
        load_checkpoint(detector, checkpoint_url, map_location="cpu")
        return detector

    @property
    def _exporter(self) -> OTXModelExporter:
        """Return an OTXModelExporter configured to export this model via ONNX.

        Exporter settings:
            - Inputs are resized with "fit_to_window_letterbox" and padded with value 114.
            - ``swap_rgb`` is enabled.
            - The ONNX graph has the input ``image`` and the outputs ``boxes`` and
              ``labels``, with dynamic batch and ``num_dets`` axes.
            - In explain mode, the exporter also receives the output names ``bboxes``,
              ``labels``, ``feature_vector`` and ``saliency_map``; otherwise it receives
              ``None``.
        """
        export_params = self._export_parameters
        input_params = self.data_input_params

        dynamic_axes = {
            "image": {0: "batch"},
            "boxes": {0: "batch", 1: "num_dets"},
            "labels": {0: "batch", 1: "num_dets"},
        }
        onnx_config = {
            "input_names": ["image"],
            "output_names": ["boxes", "labels"],
            "dynamic_axes": dynamic_axes,
            "autograd_inlining": False,
        }

        if self.explain_mode:
            explain_output_names = ["bboxes", "labels", "feature_vector", "saliency_map"]
        else:
            explain_output_names = None

        return OTXNativeModelExporter(
            task_level_export_parameters=export_params,
            data_input_params=input_params,
            resize_mode="fit_to_window_letterbox",
            pad_value=114,
            swap_rgb=True,
            via_onnx=True,
            onnx_export_configuration=onnx_config,
            output_names=explain_output_names,
        )

    @property
    def _export_parameters(self) -> TaskLevelExportParameters:
        """Return the base detection export parameters wrapped with a "mixed" optimization preset."""
        base_parameters = super()._export_parameters
        return base_parameters.wrap(optimization_config={"preset": "mixed"})