Source code for otx.backend.native.models.instance_segmentation.rtmdet_inst

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""RTMDetInst model implementations."""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, ClassVar, Literal

from torch import nn

from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.common.losses import GIoULoss, QualityFocalLoss
from otx.backend.native.models.common.utils.assigners import DynamicSoftLabelAssigner
from otx.backend.native.models.common.utils.coders import DistancePointBBoxCoder
from otx.backend.native.models.common.utils.prior_generators import MlvlPointGenerator
from otx.backend.native.models.common.utils.samplers import PseudoSampler
from otx.backend.native.models.detection.backbones import CSPNeXt
from otx.backend.native.models.detection.detectors import SingleStageDetector
from otx.backend.native.models.detection.necks import CSPNeXtPAFPN
from otx.backend.native.models.instance_segmentation.base import OTXInstanceSegModel
from otx.backend.native.models.instance_segmentation.heads import RTMDetInstSepBNHead
from otx.backend.native.models.instance_segmentation.losses import DiceLoss, RTMDetInstCriterion
from otx.backend.native.models.modules.norm import build_norm_layer
from otx.backend.native.models.utils.utils import load_checkpoint
from otx.config.data import TileConfig
from otx.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
    from torch import Tensor

    from otx.backend.native.models.base import DataInputParams
    from otx.backend.native.schedulers import LRSchedulerListCallable
    from otx.metrics import MetricCallable
    from otx.types.label import LabelInfoTypes


class RTMDetInst(OTXInstanceSegModel):
    """Implementation of RTMDetInst for instance segmentation.

    Args:
        label_info (LabelInfoTypes): Information about the labels used in the model.
        data_input_params (DataInputParams): Parameters for the data input.
        model_name (str, optional): Name of the model. Defaults to "rtmdet_inst_tiny".
        optimizer (OptimizerCallable, optional): Optimizer for the model.
            Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Metric for evaluating the model.
            Defaults to MaskRLEMeanAPFMeasureCallable.
        torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
        tile_config (TileConfig, optional): Configuration for tiling.
            Defaults to TileConfig(enable_tiler=False).

    Note:
        ``explain_mode`` is not a constructor argument of this class. It is read
        from the instance (``self.explain_mode``) by ``_exporter`` and
        ``forward_for_tracing``; where it is set is not visible in this module.
    """

    # Checkpoint URL per supported ``model_name``; indexed in ``_create_model``.
    pretrained_weights: ClassVar[dict[str, str]] = {
        "rtmdet_inst_tiny": (
            "https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/"
            "rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth"
        ),
    }

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: Literal["rtmdet_inst_tiny"] = "rtmdet_inst_tiny",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
        torch_compile: bool = False,
        # NOTE(review): this default instance is created once, when the function is
        # defined, and is shared by every call that omits ``tile_config`` -- confirm
        # that it is never mutated downstream.
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        """Forward every argument unchanged to the parent class constructor."""
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )

    def _create_model(self, num_classes: int | None = None) -> SingleStageDetector:
        """Assemble the detector and load its pretrained checkpoint.

        Args:
            num_classes (int | None, optional): Class count passed to the head and
                the criterion. Falls back to ``self.num_classes`` when ``None``.

        Returns:
            SingleStageDetector: Detector built from a ``CSPNeXt`` backbone, a
            ``CSPNeXtPAFPN`` neck, an ``RTMDetInstSepBNHead`` head and an
            ``RTMDetInstCriterion`` criterion, after ``init_weights()`` and
            ``load_checkpoint`` have been applied.

        Raises:
            KeyError: If ``self.model_name`` has no entry in ``pretrained_weights``.
        """
        num_classes = num_classes if num_classes is not None else self.num_classes

        # The same two config dicts are handed to both the head and the detector.
        train_cfg = {
            "assigner": DynamicSoftLabelAssigner(topk=13),
            "sampler": PseudoSampler(),
            "allowed_border": -1,
            "pos_weight": -1,
            "debug": False,
        }

        test_cfg = {
            "nms": {"type": "nms", "iou_threshold": 0.5},
            "score_thr": 0.05,
            "mask_thr_binary": 0.5,
            "max_per_img": 100,
            "min_bbox_size": 0,
            "nms_pre": 300,
        }

        backbone = CSPNeXt(model_name=self.model_name)
        neck = CSPNeXtPAFPN(model_name=self.model_name)
        bbox_head = RTMDetInstSepBNHead(
            num_classes=num_classes,
            in_channels=96,
            stacked_convs=2,
            share_conv=True,
            pred_kernel_size=1,
            feat_channels=96,
            normalization=partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
            activation=partial(nn.SiLU, inplace=True),
            anchor_generator=MlvlPointGenerator(
                offset=0,
                strides=[8, 16, 32],
            ),
            bbox_coder=DistancePointBBoxCoder(),
            train_cfg=train_cfg,
            test_cfg=test_cfg,
        )
        criterion = RTMDetInstCriterion(
            num_classes=num_classes,
            loss_cls=QualityFocalLoss(
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0,
            ),
            loss_bbox=GIoULoss(loss_weight=2.0),
            loss_mask=DiceLoss(
                loss_weight=2.0,
                eps=5.0e-06,
                reduction="mean",
            ),
        )

        model = SingleStageDetector(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            criterion=criterion,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
        )
        # Initialise first, then load the checkpoint registered for this model name.
        model.init_weights()
        load_checkpoint(model, self.pretrained_weights[self.model_name], map_location="cpu")
        return model

    @property
    def _exporter(self) -> OTXModelExporter:
        """Creates OTXModelExporter object that can export the model.

        Returns:
            OTXModelExporter: An ``OTXNativeModelExporter`` configured to export via
            ONNX (opset 11) with dynamic axes declared for ``image``, ``boxes``,
            ``labels`` and ``masks``.
        """
        return OTXNativeModelExporter(
            task_level_export_parameters=self._export_parameters,
            data_input_params=self.data_input_params,
            resize_mode="fit_to_window_letterbox",
            pad_value=114,
            swap_rgb=False,
            via_onnx=True,
            onnx_export_configuration={
                "input_names": ["image"],
                "output_names": ["boxes", "labels", "masks"],
                "dynamic_axes": {
                    "image": {0: "batch"},
                    "boxes": {0: "batch", 1: "num_dets"},
                    "labels": {0: "batch", 1: "num_dets"},
                    "masks": {0: "batch", 1: "num_dets", 2: "height", 3: "width"},
                },
                "opset_version": 11,
                "autograd_inlining": False,
            },
            # TODO(Eugene): Add XAI support for RTMDetInst
            # NOTE(review): the first name here is "bboxes" while the ONNX
            # configuration above uses "boxes" -- confirm the difference is intended.
            output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"]
            if self.explain_mode
            else None,
        )

    def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]:
        """Forward function for export.

        NOTE : RTMDetInst uses explain_mode unlike other models.

        Args:
            inputs (Tensor): Input batch. Dimensions 2 and 3 are read as the image
                shape (assumed to be height and width of a 4-D image tensor --
                confirm against the caller).

        Returns:
            tuple[Tensor, ...]: Whatever ``self.model.export`` returns for the given
            inputs, meta information and ``self.explain_mode``.
        """
        shape = (int(inputs.shape[2]), int(inputs.shape[3]))
        meta_info = {
            "pad_shape": shape,
            "batch_input_shape": shape,
            "img_shape": shape,
            "scale_factor": (1.0, 1.0),
        }
        # One entry per item in the batch; every entry is the same dict object.
        meta_info_list = [meta_info] * len(inputs)
        return self.model.export(inputs, meta_info_list, explain_mode=self.explain_mode)