# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Class definition for instance segmentation model entity used in OTX."""
from __future__ import annotations
import logging as log
from typing import TYPE_CHECKING, Any
import torch
from model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors
from otx.backend.native.models.instance_segmentation.utils.structures.mask.mask_util import encode_rle, polygon_to_rle
from otx.backend.openvino.models.base import OVModel
from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
from otx.metrics import MetricInput
from otx.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable
from otx.types.label import LabelInfo
from otx.types.task import OTXTaskType
if TYPE_CHECKING:
from model_api.adapters import OpenvinoAdapter
from model_api.models.utils import InstanceSegmentationResult
from torchmetrics import Metric
from otx.metrics import MetricCallable
from otx.types import PathLike
[docs]
class OVInstanceSegmentationModel(OVModel):
"""Instance segmentation model compatible for OpenVINO IR inference.
It can consume OpenVINO IR model path or model name from Intel OMZ repository
and create the OTX detection model compatible for OTX testing pipeline.
"""
def __init__(
self,
model_path: PathLike,
model_type: str = "MaskRCNN",
async_inference: bool = True,
max_num_requests: int | None = None,
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
metric: MetricCallable = MaskRLEMeanAPFMeasureCallable,
**kwargs,
) -> None:
"""Initialize the instance segmentation model.
Args:
model_path (PathLike): Path to the OpenVINO IR model.
model_type (str): Type of the model (default: "MaskRCNN").
async_inference (bool): Whether to use asynchronous inference (default: True).
max_num_requests (int | None): Maximum number of inference requests (default: None).
use_throughput_mode (bool): Whether to use throughput mode (default: True).
model_api_configuration (dict[str, Any] | None): Model API configuration (default: None).
metric (MetricCallable): Metric callable for evaluation (default: MaskRLEMeanAPFMeasureCallable).
**kwargs: Additional keyword arguments.
"""
super().__init__(
model_path=model_path,
model_type=model_type,
async_inference=async_inference,
max_num_requests=max_num_requests,
use_throughput_mode=use_throughput_mode,
model_api_configuration=model_api_configuration,
metric=metric,
)
self._task = OTXTaskType.INSTANCE_SEGMENTATION
def _setup_tiler(self) -> None:
"""Set up the tiler for tiled inference.
This method configures the tiler for the instance segmentation task,
enabling tiled inference with specified tile size and overlap.
"""
execution_mode = "async" if self.async_inference else "sync"
self.async_inference = False # Disable async_inference as tiling has its own implementation
self.model = InstanceSegmentationTiler(self.model, execution_mode=execution_mode)
log.info(
f"Enable tiler with tile size: {self.model.tile_size} \
and overlap: {self.model.tiles_overlap}",
)
def _get_hparams_from_adapter(self, model_adapter: OpenvinoAdapter) -> None:
"""Retrieve hyperparameters from the OpenVINO adapter.
Args:
model_adapter (OpenvinoAdapter): The OpenVINO adapter to read the configuration from.
This method reads the confidence threshold from the model's runtime information
and updates the hyperparameters accordingly.
"""
if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
else:
msg = (
"Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
"Please check whether this model is trained by OTX or not. "
"Without this information, it can produce a wrong F1 metric score. "
"At this time, it will be set as the default value = None."
)
log.warning(msg)
self.hparams["best_confidence_threshold"] = None
def _customize_outputs(
self,
outputs: list[InstanceSegmentationResult],
inputs: OTXDataBatch,
) -> OTXPredBatch:
"""Customize the model outputs for OTX compatibility.
Args:
outputs (list[InstanceSegmentationResult]): Model outputs.
inputs (OTXDataBatch): Input data batch.
Returns:
OTXPredBatch: Customized predictions batch.
"""
bboxes = []
scores = []
labels = []
masks = []
for output in outputs:
bboxes.append(
tv_tensors.BoundingBoxes(
data=output.bboxes,
format="XYXY",
canvas_size=inputs.imgs_info[-1].img_shape, # type: ignore[union-attr,index]
dtype=torch.float32,
),
)
scores.append(torch.tensor(output.scores.reshape(-1)))
masks.append(torch.tensor(output.masks))
labels.append(torch.tensor(output.labels.reshape(-1) - 1, dtype=torch.long))
if outputs and outputs[0].saliency_map:
predicted_s_maps = []
for out in outputs:
image_map = torch.tensor(
[s_map for s_map in out.saliency_map if s_map.size > 0],
dtype=torch.float32,
)
predicted_s_maps.append(image_map)
predicted_f_vectors = [out.feature_vector[0] for out in outputs]
return OTXPredBatch(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
masks=masks if any(mask.numel() > 0 for mask in masks) else None,
labels=labels,
saliency_map=predicted_s_maps,
feature_vector=predicted_f_vectors,
)
return OTXPredBatch(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
bboxes=bboxes,
masks=masks if any(mask.numel() > 0 for mask in masks) else None,
labels=labels,
)
[docs]
def compute_metrics(self, metric: Metric) -> dict:
"""Compute evaluation metrics for the model.
Args:
metric (Metric): Metric object to compute the evaluation metrics.
Returns:
dict: Computed metrics.
"""
best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
return super()._compute_metrics(metric, **compute_kwargs)
def _create_label_info_from_ov_ir(self) -> LabelInfo:
"""Create label information from the OpenVINO IR model.
Reads label information from the OpenVINO IR model's runtime information
and constructs a LabelInfo object.
Returns:
LabelInfo: Label information extracted from the OpenVINO IR model.
"""
ov_model = self.model.get_model()
if ov_model.has_rt_info(["model_info", "label_info"]):
serialized = ov_model.get_rt_info(["model_info", "label_info"]).value
ir_label_info = LabelInfo.from_json(serialized)
if ir_label_info.label_names[0] == "otx_empty_lbl":
ir_label_info.label_names.pop(0)
ir_label_info.label_ids.pop(0)
ir_label_info.label_groups[0].pop(0)
return ir_label_info
return super()._create_label_info_from_ov_ir()