# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""MaskRCNN model implementations."""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import cv2
import numpy as np
import torch
from datumaro import Polygon
from torch import nn
from torchvision import tv_tensors
from torchvision.ops import RoIAlign

from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.common.backbones import ResNet, build_model_including_pytorchcv
from otx.backend.native.models.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss, L1Loss
from otx.backend.native.models.common.utils.assigners import MaxIoUAssigner
from otx.backend.native.models.common.utils.coders import DeltaXYWHBBoxCoder
from otx.backend.native.models.common.utils.prior_generators import AnchorGenerator
from otx.backend.native.models.common.utils.samplers import RandomSampler
from otx.backend.native.models.detection.necks import FPN
from otx.backend.native.models.instance_segmentation.backbones.swin import SwinTransformer
from otx.backend.native.models.instance_segmentation.base import OTXInstanceSegModel
from otx.backend.native.models.instance_segmentation.heads import ConvFCBBoxHead, FCNMaskHead, RoIHead, RPNHead
from otx.backend.native.models.instance_segmentation.losses import ROICriterion, RPNCriterion
from otx.backend.native.models.instance_segmentation.segmentors.two_stage import TwoStageDetector
from otx.backend.native.models.instance_segmentation.utils.roi_extractors import SingleRoIExtractor
from otx.backend.native.models.modules.norm import build_norm_layer
from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
from otx.backend.native.models.utils.utils import load_checkpoint
from otx.config.data import TileConfig
from otx.data.entity.torch import OTXPredBatch
from otx.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

    from otx.backend.native.models.base import DataInputParams
    from otx.backend.native.schedulers import LRSchedulerListCallable
    from otx.metrics import MetricCallable
    from otx.types.label import LabelInfoTypes


class MaskRCNN(OTXInstanceSegModel):
    """Implementation of MaskRCNN for instance segmentation.

    Args:
        label_info (LabelInfoTypes): Information about the labels used in the model.
        data_input_params (DataInputParams): Parameters for the data input.
        model_name (str, optional): Name of the model. Defaults to "maskrcnn_resnet_50".
        optimizer (OptimizerCallable, optional): Optimizer for the model. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Metric for evaluating the model.
            Defaults to MaskRLEMeanAPFMeasureCallable.
        torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
        tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False).
    """

    pretrained_weights: ClassVar[dict[str, Any]] = {
        "maskrcnn_resnet_50": "https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_mstrain-poly_3x_coco/"
        "mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth",
        "maskrcnn_efficientnet_b2b": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/"
        "models/instance_segmentation/v2/efficientnet_b2b-mask_rcnn-576x576.pth",
        "maskrcnn_swin_tiny": "https://download.openmmlab.com/mmdetection/v2.0/swin/"
        "mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco/"
        "mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco_20210908_165006-90a4008c.pth",
    }

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: Literal[
            "maskrcnn_resnet_50",
            "maskrcnn_efficientnet_b2b",
            "maskrcnn_swin_tiny",
        ] = "maskrcnn_resnet_50",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
        torch_compile: bool = False,
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )

    def _create_model(self, num_classes: int | None = None) -> TwoStageDetector:
        num_classes = num_classes if num_classes is not None else self.num_classes

        # TODO(Kirill): deprecate train_cfg/test_cfg
        train_cfg = {
            "rpn": {
                "allowed_border": -1,
                "debug": False,
                "pos_weight": -1,
                "assigner": MaxIoUAssigner(
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.3,
                    min_pos_iou=0.3,
                    ignore_iof_thr=-1,
                    match_low_quality=True,
                ),
                "sampler": RandomSampler(
                    add_gt_as_proposals=False,
                    num=256,
                    pos_fraction=0.5,
                    neg_pos_ub=-1,
                ),
            },
            "rpn_proposal": {
                "max_per_img": 1000,
                "min_bbox_size": 0,
                "nms": {
                    "type": "nms",
                    "iou_threshold": 0.7,
                },
                "nms_pre": 2000,
            },
            "rcnn": {
                "assigner": MaxIoUAssigner(
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    ignore_iof_thr=-1,
                    match_low_quality=True,
                ),
                "sampler": RandomSampler(
                    add_gt_as_proposals=True,
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                ),
                "debug": False,
                "mask_size": 28,
                "pos_weight": -1,
            },
        }

        test_cfg = {
            "rpn": {
                "max_per_img": 1000,
                "min_bbox_size": 0,
                "nms": {
                    "type": "nms",
                    "iou_threshold": 0.7,
                },
                "nms_pre": 1000,
            },
            "rcnn": {
                "mask_thr_binary": 0.5,
                "max_per_img": 100,
                "nms": {
                    "type": "nms",
                    "iou_threshold": 0.5,
                },
                "score_thr": 0.05,
            },
        }

        rpn_assigner = MaxIoUAssigner(
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1,
            match_low_quality=True,
        )
        rpn_sampler = RandomSampler(
            add_gt_as_proposals=False,
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
        )
        rcnn_assigner = MaxIoUAssigner(
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1,
            match_low_quality=True,
        )
        rcnn_sampler = RandomSampler(
            add_gt_as_proposals=True,
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
        )

        # Assemble the two-stage pipeline: backbone -> FPN neck -> RPN -> RoI heads.
        backbone = self._build_backbone()
        neck = FPN(model_name=self.model_name)

        rpn_bbox_coder = DeltaXYWHBBoxCoder(
            target_means=(0.0, 0.0, 0.0, 0.0),
            target_stds=(1.0, 1.0, 1.0, 1.0),
        )
        rpn_head = RPNHead(
            model_name=self.model_name,
            anchor_generator=AnchorGenerator(
                strides=[4, 8, 16, 32, 64],
                ratios=[0.5, 1.0, 2.0],
                scales=[8],
            ),
            bbox_coder=rpn_bbox_coder,
            assigner=rpn_assigner,
            sampler=rpn_sampler,
            train_cfg=train_cfg["rpn"],
            test_cfg=test_cfg["rpn"],
        )

        roi_bbox_coder = DeltaXYWHBBoxCoder(
            target_means=(0.0, 0.0, 0.0, 0.0),
            target_stds=(0.1, 0.1, 0.2, 0.2),
        )
        bbox_head = ConvFCBBoxHead(
            model_name=self.model_name,
            num_classes=num_classes,
            bbox_coder=roi_bbox_coder,
        )
        bbox_roi_extractor = SingleRoIExtractor(
            featmap_strides=[4, 8, 16, 32],
            out_channels=rpn_head.feat_channels,
            roi_layer=RoIAlign(
                output_size=7,
                sampling_ratio=0,
                aligned=True,
                spatial_scale=1.0,
            ),
        )
        mask_roi_extractor = SingleRoIExtractor(
            featmap_strides=[4, 8, 16, 32],
            out_channels=rpn_head.feat_channels,
            roi_layer=RoIAlign(
                output_size=14,
                sampling_ratio=0,
                aligned=True,
                spatial_scale=1.0,
            ),
        )
        mask_head = FCNMaskHead(
            conv_out_channels=rpn_head.feat_channels,
            in_channels=rpn_head.feat_channels,
            num_classes=num_classes,
            num_convs=4,
        )
        roi_head = RoIHead(
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            assigner=rcnn_assigner,
            sampler=rcnn_sampler,
        )

        # Stage-wise losses for the RPN and the RoI (bbox + mask) heads.
        rpn_criterion = RPNCriterion(
            bbox_coder=DeltaXYWHBBoxCoder(
                target_means=(0.0, 0.0, 0.0, 0.0),
                target_stds=(1.0, 1.0, 1.0, 1.0),
            ),
            loss_bbox=L1Loss(loss_weight=1.0),
            loss_cls=CrossEntropyLoss(loss_weight=1.0, use_sigmoid=True),
        )
        roi_criterion = ROICriterion(
            num_classes=num_classes,
            bbox_coder=DeltaXYWHBBoxCoder(
                target_means=(0.0, 0.0, 0.0, 0.0),
                target_stds=(0.1, 0.1, 0.2, 0.2),
            ),
            loss_bbox=L1Loss(loss_weight=1.0),
            # TODO(someone): performance of CrossSigmoidFocalLoss is worse without mmcv
            # https://github.com/openvinotoolkit/training_extensions/pull/3431
            loss_cls=CrossSigmoidFocalLoss(loss_weight=1.0, use_sigmoid=False),
            loss_mask=CrossEntropyLoss(loss_weight=1.0, use_mask=True),
            class_agnostic=False,
        )

        model = TwoStageDetector(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            roi_criterion=roi_criterion,
            rpn_criterion=rpn_criterion,
        )
        load_checkpoint(model, self.pretrained_weights[self.model_name], map_location="cpu")
        return model

    def _build_backbone(self) -> nn.Module:
        """Builds the backbone for the model."""
        backbone_cfg: dict[str, Any] = {
            "maskrcnn_resnet_50": {
                "depth": 50,
                "frozen_stages": 1,
            },
            "maskrcnn_swin_tiny": {
                "drop_path_rate": 0.2,
                "patch_norm": True,
                "convert_weights": True,
            },
            "maskrcnn_efficientnet_b2b": {
                "type": "efficientnet_b2b",
                "out_indices": [2, 3, 4, 5],
                "frozen_stages": -1,
                "pretrained": True,
                "activation": nn.SiLU,
                "normalization": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
            },
        }
        if "resnet" in self.model_name:
            return ResNet(
                **backbone_cfg[self.model_name],
            )
        if "efficientnet" in self.model_name:
            cfg = backbone_cfg[self.model_name]
            return build_model_including_pytorchcv(cfg=cfg)
        if "swin" in self.model_name:
            return SwinTransformer(
                **backbone_cfg[self.model_name],
            )
        msg = f"Model {self.model_name} is not supported."
        raise ValueError(msg)

    @property
    def _exporter(self) -> OTXModelExporter:
        """Creates OTXModelExporter object that can export the model."""
        return OTXNativeModelExporter(
            task_level_export_parameters=self._export_parameters,
            data_input_params=self.data_input_params,
            resize_mode="fit_to_window",
            pad_value=0,
            swap_rgb=False,
            via_onnx=True,
            onnx_export_configuration={
                "input_names": ["image"],
                "output_names": ["boxes", "labels", "masks"],
                "dynamic_axes": {
                    "image": {0: "batch"},
                    "boxes": {0: "batch", 1: "num_dets"},
                    "labels": {0: "batch", 1: "num_dets"},
                    "masks": {0: "batch", 1: "num_dets", 2: "height", 3: "width"},
                },
                "opset_version": 11,
                "autograd_inlining": False,
            },
            output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"]
            if self.explain_mode
            else None,
        )
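
    # Hedged sketch: the exporter above routes the model through a standard
    # ``torch.onnx.export`` call with the configuration shown. A roughly equivalent
    # direct call would look like the following (the dummy input shape is an
    # assumption for illustration; the real call is assembled by OTXNativeModelExporter):
    #
    #     torch.onnx.export(
    #         model,
    #         torch.randn(1, 3, 1024, 1024),
    #         "maskrcnn.onnx",
    #         input_names=["image"],
    #         output_names=["boxes", "labels", "masks"],
    #         dynamic_axes={"image": {0: "batch"}},
    #         opset_version=11,
    #         autograd_inlining=False,
    #     )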

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX checkpoint according to OTX 2.0."""
        return OTXv1Helper.load_iseg_ckpt(state_dict, add_prefix)
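
    # Hedged note: this hook is intended to be invoked by OTX's checkpoint-loading
    # machinery when a legacy (OTX v1) checkpoint is detected, not called directly.
    # A sketch of its use (file name assumed for the example):
    #
    #     v1_state_dict = torch.load("otx_v1_maskrcnn.pth", map_location="cpu")
    #     v2_state_dict = model.load_from_otx_v1_ckpt(v1_state_dict)  # keys prefixed with "model."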

    @property
    def _optimization_config(self) -> dict[str, Any]:
        """PTQ config for the MaskRCNN variants."""
        if self.model_name == "maskrcnn_efficientnet_b2b":
            return {
                "ignored_scope": {
                    "types": ["Add", "Divide", "Multiply", "Sigmoid"],
                    "validate": False,
                },
                "preset": "mixed",
            }
        if self.model_name == "maskrcnn_swin_tiny":
            return {"model_type": "transformer"}
        return {}
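
# Construction sketch (illustrative; the DataInputParams field values below are
# assumptions for this example, not guaranteed by this module):
#
#     model = MaskRCNN(
#         label_info=3,
#         data_input_params=DataInputParams(
#             input_size=(1024, 1024),
#             mean=(123.675, 116.28, 103.53),
#             std=(58.395, 57.12, 57.375),
#         ),
#         model_name="maskrcnn_swin_tiny",
#     )
#
# Switching ``model_name`` is enough to select the ResNet-50, EfficientNet-B2B,
# or Swin-Tiny variant; ``_create_model`` wires the matching backbone, FPN neck,
# RPN, and RoI heads and loads the weights listed in ``pretrained_weights``.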


class RotatedMaskRCNNModel(MaskRCNN):
    """Base class for the rotated detection models used in OTX."""

    def predict_step(self, *args: Any, **kwargs: Any) -> OTXPredBatch:
        """Predict step for rotated detection task.

        Note: This method is overridden to convert masks to rotated bounding boxes.

        Returns:
            OTXPredBatch: The predicted polygons (rboxes), scores, labels, masks.
        """
        preds = super().predict_step(*args, **kwargs)

        batch_scores: list[torch.Tensor] = []
        batch_bboxes: list[tv_tensors.BoundingBoxes] = []
        batch_labels: list[torch.LongTensor] = []
        batch_polygons: list[list[Polygon]] = []
        batch_masks: list[tv_tensors.Mask] = []

        for field_name, field in zip(
            ["imgs_info", "bboxes", "scores", "labels", "masks"],
            [preds.imgs_info, preds.bboxes, preds.scores, preds.labels, preds.masks],
        ):
            if field is None:
                msg = f"Field '{field_name}' is None, which is not allowed."
                raise ValueError(msg)

        for img_info, pred_bboxes, pred_scores, pred_labels, pred_masks in zip(  # type: ignore[misc]
            preds.imgs_info,  # type: ignore[arg-type]
            preds.bboxes,  # type: ignore[arg-type]
            preds.scores,  # type: ignore[arg-type]
            preds.labels,  # type: ignore[arg-type]
            preds.masks,  # type: ignore[arg-type]
        ):
            boxes = []
            scores = []
            labels = []
            masks = []
            polygons = []
            for bbox, score, label, mask in zip(pred_bboxes, pred_scores, pred_labels, pred_masks):
                if mask.sum() == 0:
                    continue
                # findContours expects an 8-bit single-channel image.
                np_mask = mask.detach().cpu().numpy().astype(np.uint8)
                contours, hierarchies = cv2.findContours(np_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
                if hierarchies is None:
                    continue
                rbox_polygons = []
                for contour, hierarchy in zip(contours, hierarchies[0]):
                    # skip inner contours
                    if hierarchy[3] != -1 or len(contour) <= 2:
                        continue
                    rbox_points = Polygon(cv2.boxPoints(cv2.minAreaRect(contour)).reshape(-1))
                    rbox_polygons.append((rbox_points, rbox_points.get_area()))

                # select the largest polygon
                if len(rbox_polygons) > 0:
                    rbox_polygons.sort(key=lambda x: x[1], reverse=True)
                    polygons.append(rbox_polygons[0][0])
                    scores.append(score)
                    boxes.append(bbox)
                    labels.append(label)
                    masks.append(mask)

            if len(boxes):
                scores = torch.stack(scores)
                boxes = tv_tensors.BoundingBoxes(torch.stack(boxes), format="XYXY", canvas_size=img_info.ori_shape)  # type: ignore[union-attr]
                labels = torch.stack(labels)
                masks = torch.stack(masks)

            batch_scores.append(scores)
            batch_bboxes.append(boxes)
            batch_labels.append(labels)
            batch_polygons.append(polygons)
            batch_masks.append(masks)

        return OTXPredBatch(
            batch_size=preds.batch_size,
            images=preds.images,
            imgs_info=preds.imgs_info,
            scores=batch_scores,
            bboxes=batch_bboxes,
            masks=batch_masks,
            polygons=batch_polygons,
            labels=batch_labels,
        )
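

if __name__ == "__main__":
    # Standalone sketch (not part of the library API): reproduces the core
    # mask -> rotated-box step from ``RotatedMaskRCNNModel.predict_step`` on a
    # synthetic binary mask, so the contour/minAreaRect logic can be checked
    # in isolation.
    demo_mask = np.zeros((64, 64), dtype=np.uint8)
    cv2.rectangle(demo_mask, (10, 20), (50, 40), color=1, thickness=-1)  # filled blob

    demo_contours, demo_hierarchies = cv2.findContours(demo_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    candidates = []
    if demo_hierarchies is not None:
        for demo_contour, demo_hierarchy in zip(demo_contours, demo_hierarchies[0]):
            # Skip holes (contours with a parent) and degenerate contours.
            if demo_hierarchy[3] != -1 or len(demo_contour) <= 2:
                continue
            demo_polygon = Polygon(cv2.boxPoints(cv2.minAreaRect(demo_contour)).reshape(-1))
            candidates.append((demo_polygon, demo_polygon.get_area()))

    if candidates:
        # Keep the largest rotated box, exactly as predict_step does per instance.
        candidates.sort(key=lambda x: x[1], reverse=True)
        print("rotated box points:", candidates[0][0].points)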