# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""MaskRCNN model implementations."""
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Literal
import cv2
import numpy as np
import torch
from datumaro import Polygon
from torch import nn
from torchvision import tv_tensors
from torchvision.ops import RoIAlign
from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.common.backbones import ResNet, build_model_including_pytorchcv
from otx.backend.native.models.common.losses import CrossEntropyLoss, CrossSigmoidFocalLoss, L1Loss
from otx.backend.native.models.common.utils.assigners import MaxIoUAssigner
from otx.backend.native.models.common.utils.coders import DeltaXYWHBBoxCoder
from otx.backend.native.models.common.utils.prior_generators import AnchorGenerator
from otx.backend.native.models.common.utils.samplers import RandomSampler
from otx.backend.native.models.detection.necks import FPN
from otx.backend.native.models.instance_segmentation.backbones.swin import SwinTransformer
from otx.backend.native.models.instance_segmentation.base import OTXInstanceSegModel
from otx.backend.native.models.instance_segmentation.heads import ConvFCBBoxHead, FCNMaskHead, RoIHead, RPNHead
from otx.backend.native.models.instance_segmentation.losses import ROICriterion, RPNCriterion
from otx.backend.native.models.instance_segmentation.segmentors.two_stage import TwoStageDetector
from otx.backend.native.models.instance_segmentation.utils.roi_extractors import SingleRoIExtractor
from otx.backend.native.models.modules.norm import build_norm_layer
from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
from otx.backend.native.models.utils.utils import load_checkpoint
from otx.config.data import TileConfig
from otx.data.entity.torch import OTXPredBatch
from otx.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.backend.native.models.base import DataInputParams
from otx.backend.native.schedulers import LRSchedulerListCallable
from otx.metrics import MetricCallable
from otx.types.label import LabelInfoTypes
class MaskRCNN(OTXInstanceSegModel):
"""Implementation of MaskRCNN for instance segmentation.
Args:
label_info (LabelInfoTypes): Information about the labels used in the model.
data_input_params (DataInputParams): Parameters for the data input.
model_name (str, optional): Name of the model. Defaults to "maskrcnn_resnet_50".
optimizer (OptimizerCallable, optional): Optimizer for the model. Defaults to DefaultOptimizerCallable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model.
Defaults to DefaultSchedulerCallable.
metric (MetricCallable, optional): Metric for evaluating the model.
Defaults to MaskRLEMeanAPFMeasureCallable.
torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False).
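
    Example:
        A minimal construction sketch; ``label_info`` and ``data_input_params`` here
        stand for objects prepared by the caller (e.g. by the OTX engine):

        >>> model = MaskRCNN(  # doctest: +SKIP
        ...     label_info=label_info,
        ...     data_input_params=data_input_params,
        ...     model_name="maskrcnn_efficientnet_b2b",
        ... )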
"""
pretrained_weights: ClassVar[dict[str, Any]] = {
"maskrcnn_resnet_50": "https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_mstrain-poly_3x_coco/"
"mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth",
"maskrcnn_efficientnet_b2b": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/"
"models/instance_segmentation/v2/efficientnet_b2b-mask_rcnn-576x576.pth",
"maskrcnn_swin_tiny": "https://download.openmmlab.com/mmdetection/v2.0/swin/"
"mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco/"
"mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco_20210908_165006-90a4008c.pth",
}
def __init__(
self,
label_info: LabelInfoTypes,
data_input_params: DataInputParams,
model_name: Literal[
"maskrcnn_resnet_50",
"maskrcnn_efficientnet_b2b",
"maskrcnn_swin_tiny",
] = "maskrcnn_resnet_50",
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
torch_compile: bool = False,
tile_config: TileConfig = TileConfig(enable_tiler=False),
) -> None:
super().__init__(
label_info=label_info,
data_input_params=data_input_params,
model_name=model_name,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
tile_config=tile_config,
)

    def _create_model(self, num_classes: int | None = None) -> TwoStageDetector:
        num_classes = num_classes if num_classes is not None else self.num_classes
        # TODO(Kirill): deprecate train_cfg/test_cfg
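        # Training-time hyperparameters for both stages: IoU-based target assignment
        # and sampling for the RPN, proposal NMS limits, and R-CNN assignment/sampling
        # with 28x28 mask targets.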
train_cfg = {
"rpn": {
"allowed_border": -1,
"debug": False,
"pos_weight": -1,
"assigner": MaxIoUAssigner(
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1,
match_low_quality=True,
),
"sampler": RandomSampler(
add_gt_as_proposals=False,
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
),
},
"rpn_proposal": {
"max_per_img": 1000,
"min_bbox_size": 0,
"nms": {
"type": "nms",
"iou_threshold": 0.7,
},
"nms_pre": 2000,
},
"rcnn": {
"assigner": MaxIoUAssigner(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1,
match_low_quality=True,
),
"sampler": RandomSampler(
add_gt_as_proposals=True,
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
),
"debug": False,
"mask_size": 28,
"pos_weight": -1,
},
}
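        # Inference-time settings: proposal filtering and NMS for the RPN; score
        # threshold, NMS, and the binary mask threshold for the R-CNN stage.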
test_cfg = {
"rpn": {
"max_per_img": 1000,
"min_bbox_size": 0,
"nms": {
"type": "nms",
"iou_threshold": 0.7,
},
"nms_pre": 1000,
},
"rcnn": {
"mask_thr_binary": 0.5,
"max_per_img": 100,
"nms": {
"type": "nms",
"iou_threshold": 0.5,
},
"score_thr": 0.05,
},
}
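        # Standalone assigner/sampler instances handed to the heads; their values
        # mirror the train_cfg entries above.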
rpn_assigner = MaxIoUAssigner(
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1,
match_low_quality=True,
)
rpn_sampler = RandomSampler(
add_gt_as_proposals=False,
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
)
rcnn_assigner = MaxIoUAssigner(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1,
match_low_quality=True,
)
rcnn_sampler = RandomSampler(
add_gt_as_proposals=True,
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
)
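        # Assemble the two-stage detector: backbone + FPN neck feeding an RPN and
        # an RoI head with box and mask branches.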
backbone = self._build_backbone()
neck = FPN(model_name=self.model_name)
rpn_bbox_coder = DeltaXYWHBBoxCoder(
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(1.0, 1.0, 1.0, 1.0),
)
rpn_head = RPNHead(
model_name=self.model_name,
anchor_generator=AnchorGenerator(
strides=[4, 8, 16, 32, 64],
ratios=[0.5, 1.0, 2.0],
scales=[8],
),
bbox_coder=rpn_bbox_coder,
assigner=rpn_assigner,
sampler=rpn_sampler,
train_cfg=train_cfg["rpn"],
test_cfg=test_cfg["rpn"],
)
roi_bbox_coder = DeltaXYWHBBoxCoder(
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(0.1, 0.1, 0.2, 0.2),
)
bbox_head = ConvFCBBoxHead(
model_name=self.model_name,
num_classes=num_classes,
bbox_coder=roi_bbox_coder,
)
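        # RoIAlign extractors over the FPN levels (strides 4-32): 7x7 output for
        # the box head, 14x14 for the mask head.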
bbox_roi_extractor = SingleRoIExtractor(
featmap_strides=[4, 8, 16, 32],
out_channels=rpn_head.feat_channels,
roi_layer=RoIAlign(
output_size=7,
sampling_ratio=0,
aligned=True,
spatial_scale=1.0,
),
)
mask_roi_extractor = SingleRoIExtractor(
featmap_strides=[4, 8, 16, 32],
out_channels=rpn_head.feat_channels,
roi_layer=RoIAlign(
output_size=14,
sampling_ratio=0,
aligned=True,
spatial_scale=1.0,
),
)
mask_head = FCNMaskHead(
conv_out_channels=rpn_head.feat_channels,
in_channels=rpn_head.feat_channels,
num_classes=num_classes,
num_convs=4,
)
roi_head = RoIHead(
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
mask_roi_extractor=mask_roi_extractor,
mask_head=mask_head,
assigner=rcnn_assigner,
sampler=rcnn_sampler,
)
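        # Loss criteria: binary cross-entropy + L1 for the RPN; classification,
        # box regression (L1), and per-class mask losses for the RoI stage.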
rpn_criterion = RPNCriterion(
bbox_coder=DeltaXYWHBBoxCoder(
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(1.0, 1.0, 1.0, 1.0),
),
loss_bbox=L1Loss(loss_weight=1.0),
loss_cls=CrossEntropyLoss(loss_weight=1.0, use_sigmoid=True),
)
roi_criterion = ROICriterion(
num_classes=num_classes,
bbox_coder=DeltaXYWHBBoxCoder(
target_means=(0.0, 0.0, 0.0, 0.0),
target_stds=(0.1, 0.1, 0.2, 0.2),
),
loss_bbox=L1Loss(loss_weight=1.0),
# TODO(someone): performance of CrossSigmoidFocalLoss is worse without mmcv
# https://github.com/openvinotoolkit/training_extensions/pull/3431
loss_cls=CrossSigmoidFocalLoss(loss_weight=1.0, use_sigmoid=False),
loss_mask=CrossEntropyLoss(loss_weight=1.0, use_mask=True),
class_agnostic=False,
)
model = TwoStageDetector(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
roi_criterion=roi_criterion,
rpn_criterion=rpn_criterion,
)
load_checkpoint(model, self.pretrained_weights[self.model_name], map_location="cpu")
return model
def _build_backbone(self) -> nn.Module:
"""Builds the backbone for the model."""
backbone_cfg: dict[str, Any] = {
"maskrcnn_resnet_50": {
"depth": 50,
"frozen_stages": 1,
},
"maskrcnn_swin_tiny": {
"drop_path_rate": 0.2,
"patch_norm": True,
"convert_weights": True,
},
"maskrcnn_efficientnet_b2b": {
"type": "efficientnet_b2b",
"out_indices": [2, 3, 4, 5],
"frozen_stages": -1,
"pretrained": True,
"activation": nn.SiLU,
"normalization": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
},
}
if "resnet" in self.model_name:
return ResNet(
**backbone_cfg[self.model_name],
)
if "efficientnet" in self.model_name:
cfg = backbone_cfg[self.model_name]
return build_model_including_pytorchcv(cfg=cfg)
if "swin" in self.model_name:
return SwinTransformer(
**backbone_cfg[self.model_name],
)
        msg = f"Model {self.model_name} is not supported."
        raise ValueError(msg)
@property
def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
data_input_params=self.data_input_params,
resize_mode="fit_to_window",
pad_value=0,
swap_rgb=False,
via_onnx=True,
onnx_export_configuration={
"input_names": ["image"],
"output_names": ["boxes", "labels", "masks"],
"dynamic_axes": {
"image": {0: "batch"},
"boxes": {0: "batch", 1: "num_dets"},
"labels": {0: "batch", 1: "num_dets"},
"masks": {0: "batch", 1: "num_dets", 2: "height", 3: "width"},
},
"opset_version": 11,
"autograd_inlining": False,
},
output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"] if self.explain_mode else None,
)
def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_iseg_ckpt(state_dict, add_prefix)
@property
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for MaskRCNN-Eff."""
if self.model_name == "maskrcnn_efficientnet_b2b":
return {
"ignored_scope": {
"types": ["Add", "Divide", "Multiply", "Sigmoid"],
"validate": False,
},
"preset": "mixed",
}
if self.model_name == "maskrcnn_swin_t":
return {"model_type": "transformer"}
return {}

class RotatedMaskRCNNModel(MaskRCNN):
    """MaskRCNN variant that converts predicted masks to rotated bounding boxes for the rotated detection task in OTX."""

    def predict_step(self, *args: Any, **kwargs: Any) -> OTXPredBatch:
"""Predict step for rotated detection task.
Note: This method is overridden to convert masks to rotated bounding boxes.
Returns:
TorchPredBatch: The predicted polygons (rboxes), scores, labels, masks.
"""
preds = super().predict_step(*args, **kwargs)
batch_scores: list[torch.Tensor] = []
batch_bboxes: list[tv_tensors.BoundingBoxes] = []
batch_labels: list[torch.LongTensor] = []
batch_polygons: list[list[Polygon]] = []
batch_masks: list[tv_tensors.Mask] = []
for field_name, field in zip(
["imgs_info", "bboxes", "scores", "labels", "masks"],
[preds.imgs_info, preds.bboxes, preds.scores, preds.labels, preds.masks],
):
if field is None:
msg = f"Field '{field_name}' is None, which is not allowed."
raise ValueError(msg)
for img_info, pred_bboxes, pred_scores, pred_labels, pred_masks in zip( # type: ignore[misc]
preds.imgs_info, # type: ignore[arg-type]
preds.bboxes, # type: ignore[arg-type]
preds.scores, # type: ignore[arg-type]
preds.labels, # type: ignore[arg-type]
preds.masks, # type: ignore[arg-type]
):
boxes = []
scores = []
labels = []
masks = []
polygons = []
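            # Fit a minimum-area rotated rectangle (cv2.minAreaRect) around the
            # largest outer contour of each instance mask and keep its corner
            # points as the polygon prediction.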
for bbox, score, label, mask in zip(pred_bboxes, pred_scores, pred_labels, pred_masks):
if mask.sum() == 0:
continue
                np_mask = mask.detach().cpu().numpy().astype(np.uint8)  # cv2.findContours requires a uint8 single-channel mask
contours, hierarchies = cv2.findContours(np_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
if hierarchies is None:
continue
rbox_polygons = []
for contour, hierarchy in zip(contours, hierarchies[0]):
# skip inner contours
if hierarchy[3] != -1 or len(contour) <= 2:
continue
rbox_points = Polygon(cv2.boxPoints(cv2.minAreaRect(contour)).reshape(-1))
rbox_polygons.append((rbox_points, rbox_points.get_area()))
# select the largest polygon
if len(rbox_polygons) > 0:
rbox_polygons.sort(key=lambda x: x[1], reverse=True)
polygons.append(rbox_polygons[0][0])
scores.append(score)
boxes.append(bbox)
labels.append(label)
masks.append(mask)
if len(boxes):
scores = torch.stack(scores)
boxes = tv_tensors.BoundingBoxes(torch.stack(boxes), format="XYXY", canvas_size=img_info.ori_shape) # type: ignore[union-attr]
labels = torch.stack(labels)
masks = torch.stack(masks)
batch_scores.append(scores)
batch_bboxes.append(boxes)
batch_labels.append(labels)
batch_polygons.append(polygons)
batch_masks.append(masks)
return OTXPredBatch(
batch_size=preds.batch_size,
images=preds.images,
imgs_info=preds.imgs_info,
scores=batch_scores,
bboxes=batch_bboxes,
masks=batch_masks,
polygons=batch_polygons,
labels=batch_labels,
)