# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""TV MaskRCNN model implementations."""
# type: ignore[override]
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Literal
import torch
from torch import Tensor
from torchvision import tv_tensors
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, _default_anchorgen
from torchvision.models.detection.mask_rcnn import (
MaskRCNN_ResNet50_FPN_V2_Weights,
MaskRCNNPredictor,
)
from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.instance_segmentation.base import OTXInstanceSegModel
from otx.backend.native.models.instance_segmentation.heads import TVRoIHeads
from otx.backend.native.models.instance_segmentation.segmentors.maskrcnn_tv import (
FastRCNNConvFCHead,
MaskRCNN,
MaskRCNNBackbone,
MaskRCNNHeads,
RPNHead,
)
from otx.config.data import TileConfig
from otx.data.entity.base import OTXBatchLossEntity
from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
from otx.data.entity.utils import stack_batch
from otx.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from otx.backend.native.models.base import DataInputParams
from otx.backend.native.schedulers import LRSchedulerListCallable
from otx.metrics import MetricCallable
from otx.types.label import LabelInfoTypes
class MaskRCNNTV(OTXInstanceSegModel):
"""Implementation of torchvision MaskRCNN for instance segmentation.
Args:
label_info (LabelInfoTypes): Information about the labels used in the model.
data_input_params (DataInputParams): Parameters for the data input.
model_name (str, optional): Name of the model. Defaults to "maskrcnn_resnet_50".
optimizer (OptimizerCallable, optional): Optimizer for the model. Defaults to DefaultOptimizerCallable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model.
Defaults to DefaultSchedulerCallable.
metric (MetricCallable, optional): Metric for evaluating the model.
Defaults to MaskRLEMeanAPFMeasureCallable.
torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False).
explain_mode (bool, optional): Whether to enable explainable AI mode. Defaults to False.
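
    Example:
        A minimal construction sketch. The integer ``label_info`` shorthand and the
        ``DataInputParams`` values below are illustrative assumptions, not required
        settings::

            from otx.backend.native.models.base import DataInputParams

            model = MaskRCNNTV(
                label_info=3,
                data_input_params=DataInputParams(
                    input_size=(1024, 1024),
                    mean=(123.675, 116.28, 103.53),
                    std=(58.395, 57.12, 57.375),
                ),
            )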
"""
pretrained_weights: ClassVar[dict[str, Any]] = {
"maskrcnn_resnet_50": MaskRCNN_ResNet50_FPN_V2_Weights.verify("DEFAULT"),
}
def __init__(
self,
label_info: LabelInfoTypes,
data_input_params: DataInputParams,
model_name: Literal["maskrcnn_resnet_50"] = "maskrcnn_resnet_50",
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
torch_compile: bool = False,
tile_config: TileConfig = TileConfig(enable_tiler=False),
) -> None:
super().__init__(
label_info=label_info,
data_input_params=data_input_params,
model_name=model_name,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
tile_config=tile_config,
)
    def _create_model(self, num_classes: int | None = None) -> MaskRCNN:
        """Build a torchvision MaskRCNN, load pretrained weights, then swap in task-specific heads."""
        num_classes = num_classes if num_classes is not None else self.num_classes
        # NOTE: Add 1 to num_classes to account for the background class.
        num_classes = num_classes + 1
        weights = self.pretrained_weights[self.model_name]
        # Initialize the model components, build the model itself, and load the weights.
rpn_anchor_generator = _default_anchorgen()
backbone = MaskRCNNBackbone(model_name=self.model_name)
rpn_head = RPNHead(model_name=self.model_name, anchorgen=rpn_anchor_generator)
box_head = FastRCNNConvFCHead(model_name=self.model_name)
mask_head = MaskRCNNHeads(model_name=self.model_name)
model = MaskRCNN(
backbone,
            num_classes=91,  # COCO's 91 categories; must match the pretrained checkpoint before the heads are replaced.
rpn_anchor_generator=rpn_anchor_generator,
rpn_head=rpn_head,
box_head=box_head,
mask_head=mask_head,
)
model.load_state_dict(weights.get_state_dict(progress=True, check_hash=True))
# Replace RoIHeads since torchvision does not allow customized roi_heads.
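        # The sampling hyperparameters below mirror torchvision's own MaskRCNN defaults
        # (fg/bg IoU thresholds of 0.5, 512 sampled proposals per image, 25% positives).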
model.roi_heads = TVRoIHeads(
model.roi_heads.box_roi_pool,
model.roi_heads.box_head,
model.roi_heads.box_predictor,
fg_iou_thresh=0.5,
bg_iou_thresh=0.5,
batch_size_per_image=512,
positive_fraction=0.25,
bbox_reg_weights=None,
score_thresh=model.roi_heads.score_thresh,
nms_thresh=model.roi_heads.nms_thresh,
detections_per_img=model.roi_heads.detections_per_img,
mask_roi_pool=model.roi_heads.mask_roi_pool,
mask_head=model.roi_heads.mask_head,
mask_predictor=model.roi_heads.mask_predictor,
)
        # Get the number of input features for the box classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # Replace the pretrained box head with one sized for our classes
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        # Get the number of input features for the mask predictor
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = model.roi_heads.mask_predictor.conv5_mask.out_channels
        # Replace the mask predictor with one sized for our classes
        model.roi_heads.mask_predictor = MaskRCNNPredictor(
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask,
hidden_layer,
num_classes,
)
return model
    def _customize_inputs(self, entity: OTXDataBatch) -> dict[str, Any]:
        """Stack list-form images into a single padded batch tensor before the forward pass."""
if isinstance(entity.images, list):
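            # Pad each image so H and W are multiples of 32, matching the backbone's largest feature stride.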
entity.images, entity.imgs_info = stack_batch(entity.images, entity.imgs_info, pad_size_divisor=32) # type: ignore[arg-type,assignment]
return {"entity": entity}
def _customize_outputs(
self,
outputs: dict | list[dict], # type: ignore[override]
inputs: OTXDataBatch,
    ) -> OTXPredBatch | OTXBatchLossEntity:
        """Convert raw torchvision outputs into OTX loss entities (training) or prediction batches."""
if self.training:
            if not isinstance(outputs, dict):
                msg = f"Training outputs should be a dict of losses, but got {type(outputs)}."
                raise TypeError(msg)
losses = OTXBatchLossEntity()
for loss_name, loss_value in outputs.items():
if isinstance(loss_value, Tensor):
losses[loss_name] = loss_value
elif isinstance(loss_value, list):
losses[loss_name] = sum(_loss.mean() for _loss in loss_value)
            # Drop "acc": it is an accuracy metric reported alongside the losses, not a loss.
losses.pop("acc", None)
return losses
scores: list[Tensor] = []
bboxes: list[tv_tensors.BoundingBoxes] = []
labels: list[torch.LongTensor] = []
masks: list[tv_tensors.Mask] = []
        # In explain mode, XAI wraps the predictions in a dictionary under the "predictions" key.
predictions = outputs["predictions"] if isinstance(outputs, dict) else outputs
for img_info, prediction in zip(inputs.imgs_info, predictions): # type: ignore[arg-type]
scores.append(prediction["scores"])
bboxes.append(
tv_tensors.BoundingBoxes(
prediction["boxes"],
format="XYXY",
canvas_size=img_info.ori_shape, # type: ignore[union-attr]
),
)
output_masks = tv_tensors.Mask(
prediction["masks"],
dtype=torch.bool,
)
masks.append(output_masks)
labels.append(prediction["labels"])
if self.explain_mode:
if not isinstance(outputs, dict):
msg = f"Model output should be a dict, but got {type(outputs)}."
raise ValueError(msg)
if "feature_vector" not in outputs:
msg = "No feature vector in the model output."
raise ValueError(msg)
if "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)
saliency_map = outputs["saliency_map"].detach().cpu().numpy()
feature_vector = outputs["feature_vector"].detach().cpu().numpy()
return OTXPredBatch(
                batch_size=len(predictions),  # `outputs` is a dict here; count per-image predictions instead.
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
masks=masks,
labels=labels,
saliency_map=list(saliency_map),
feature_vector=list(feature_vector),
)
return OTXPredBatch(
            batch_size=len(predictions),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
masks=masks,
labels=labels,
)
@property
def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
data_input_params=self.data_input_params,
resize_mode="fit_to_window",
pad_value=0,
swap_rgb=False,
via_onnx=True,
onnx_export_configuration={
"input_names": ["image"],
"output_names": ["boxes", "labels", "masks"],
"dynamic_axes": {
"image": {0: "batch"},
"boxes": {0: "batch", 1: "num_dets"},
"labels": {0: "batch", 1: "num_dets"},
"masks": {0: "batch", 1: "num_dets", 2: "height", 3: "width"},
},
"opset_version": 11,
"autograd_inlining": False,
},
output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"] if self.explain_mode else None,
)
def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]:
"""Forward function for export."""
shape = (int(inputs.shape[2]), int(inputs.shape[3]))
meta_info = {
"image_shape": shape,
}
meta_info_list = [meta_info] * len(inputs)
return self.model.export(inputs, meta_info_list, explain_mode=self.explain_mode)