Source code for otx.core.model.instance_segmentation

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Class definition for instance segmentation model entity used in OTX."""

# type: ignore[override]

from __future__ import annotations

import copy
import logging as log
import types
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal

import torch
from model_api.tilers import InstanceSegmentationTiler
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchvision import tv_tensors
from torchvision.models.detection.image_list import ImageList

from otx.algo.explain.explain_algo import InstSegExplainAlgo, feature_vector_fn
from otx.algo.instance_segmentation.segmentors.maskrcnn_tv import MaskRCNN
from otx.algo.instance_segmentation.segmentors.two_stage import TwoStageDetector
from otx.algo.utils.utils import InstanceData, load_checkpoint
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity
from otx.core.data.entity.utils import stack_batch
from otx.core.metrics import MetricInput
from otx.core.metrics.fmeasure import FMeasure
from otx.core.metrics.mean_ap import MaskRLEMeanAPFMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
from otx.core.types.label import LabelInfo, LabelInfoTypes
from otx.core.utils.mask_util import encode_rle, polygon_to_rle
from otx.core.utils.tile_merge import InstanceSegTileMerge
from otx.data import OTXDataBatch, OTXPredBatch

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
    from model_api.adapters import OpenvinoAdapter
    from model_api.models.utils import InstanceSegmentationResult
    from torch import nn

    from otx.core.metrics import MetricCallable
    from otx.core.model.base import DataInputParams


[docs] class OTXInstanceSegModel(OTXModel): """Base class for the Instance Segmentation models used in OTX. Args: label_info (LabelInfoTypes): Information about the labels used in the model. data_input_params (DataInputParams): Parameters for the data input. model_name (str, optional): Name of the model. Defaults to "inst_segm_model". optimizer (OptimizerCallable, optional): Optimizer for the model. Defaults to DefaultOptimizerCallable. scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Scheduler for the model. Defaults to DefaultSchedulerCallable. metric (MetricCallable, optional): Metric for evaluating the model. Defaults to MaskRLEMeanAPFMeasureCallable. torch_compile (bool, optional): Whether to use torch compile. Defaults to False. tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False). explain_mode (bool, optional): Whether to enable explainable AI mode. Defaults to False. """ def __init__( self, label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: str = "inst_segm_model", optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MaskRLEMeanAPFMeasureCallable, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False), ) -> None: super().__init__( label_info=label_info, data_input_params=data_input_params, model_name=model_name, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, tile_config=tile_config, ) self.model.feature_vector_fn = feature_vector_fn self.model.explain_fn = self.get_explain_fn() self.model.get_results_from_head = self.get_results_from_head def _create_model(self, num_classes: int | None = None) -> nn.Module: num_classes = num_classes if num_classes is not None else self.num_classes detector = self._build_model(num_classes) if hasattr(detector, "init_weights"): detector.init_weights() if isinstance(self.load_from, dict): load_checkpoint(detector, self.load_from[self.model_name], map_location="cpu") elif self.load_from is not None: load_checkpoint(detector, self.load_from, map_location="cpu") return detector def _customize_inputs(self, entity: OTXDataBatch) -> dict[str, Any]: if isinstance(entity.images, list): entity.images, entity.imgs_info = stack_batch(entity.images, entity.imgs_info, pad_size_divisor=32) # type: ignore[assignment,arg-type] inputs: dict[str, Any] = {} inputs["entity"] = entity inputs["mode"] = "loss" if self.training else "predict" return inputs def _customize_outputs( self, outputs: list[InstanceData] | dict, inputs: OTXDataBatch, ) -> OTXPredBatch | OTXBatchLossEntity: if self.training: if not isinstance(outputs, dict): raise TypeError(outputs) losses = OTXBatchLossEntity() for loss_name, loss_value in outputs.items(): if isinstance(loss_value, Tensor): losses[loss_name] = loss_value elif isinstance(loss_value, list): losses[loss_name] = sum(_loss.mean() for _loss in loss_value) losses.pop("acc", None) return losses scores: list[Tensor] = [] bboxes: list[tv_tensors.BoundingBoxes] = [] labels: list[torch.LongTensor] = [] masks: list[tv_tensors.Mask] = [] predictions = outputs["predictions"] if isinstance(outputs, dict) else outputs for img_info, prediction in zip(inputs.imgs_info, predictions): # type: ignore[arg-type] scores.append(prediction.scores) bboxes.append( tv_tensors.BoundingBoxes( prediction.bboxes, format="XYXY", canvas_size=img_info.ori_shape, # type: ignore[union-attr] ), ) output_masks = tv_tensors.Mask( prediction.masks, dtype=torch.bool, ) masks.append(output_masks) labels.append(prediction.labels) if self.explain_mode: if not isinstance(outputs, dict): msg = f"Model output should be a dict, but got {type(outputs)}." raise ValueError(msg) if "feature_vector" not in outputs: msg = "No feature vector in the model output." raise ValueError(msg) if "saliency_map" not in outputs: msg = "No saliency maps in the model output." raise ValueError(msg) saliency_map = outputs["saliency_map"].detach().cpu().numpy() feature_vector = outputs["feature_vector"].detach().cpu().numpy() return OTXPredBatch( batch_size=len(predictions), images=inputs.images, imgs_info=inputs.imgs_info, scores=scores, bboxes=bboxes, masks=masks, labels=labels, saliency_map=list(saliency_map), feature_vector=list(feature_vector), ) return OTXPredBatch( batch_size=len(predictions), images=inputs.images, imgs_info=inputs.imgs_info, scores=scores, bboxes=bboxes, masks=masks, labels=labels, )
[docs] def forward_tiles(self, inputs: OTXTileBatchDataEntity) -> OTXPredBatch: """Unpack instance segmentation tiles. Args: inputs (OTXTileBatchDataEntity): Tile batch data entity. Returns: TorchPredBatch: Merged instance segmentation prediction. """ tile_preds: list[OTXPredBatch] = [] tile_attrs: list[list[dict[str, int | str]]] = [] merger = InstanceSegTileMerge( inputs.imgs_info, self.num_classes, self.tile_config, self.explain_mode, ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) if isinstance(output, OTXBatchLossEntity): msg = "Loss output is not supported for tile merging" raise TypeError(msg) tile_preds.append(output) tile_attrs.append(batch_tile_attrs) pred_entities = merger.merge(tile_preds, tile_attrs) pred_entity = OTXPredBatch( batch_size=inputs.batch_size, images=[pred_entity.image for pred_entity in pred_entities], imgs_info=[pred_entity.img_info for pred_entity in pred_entities], scores=[pred_entity.scores for pred_entity in pred_entities], bboxes=[pred_entity.bboxes for pred_entity in pred_entities], labels=[pred_entity.label for pred_entity in pred_entities], masks=[pred_entity.masks for pred_entity in pred_entities], polygons=[pred_entity.polygons for pred_entity in pred_entities], # type: ignore[misc] ) if self.explain_mode: pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities] pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities] return pred_entity
[docs] def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]: """Forward function for export.""" shape = (int(inputs.shape[2]), int(inputs.shape[3])) meta_info = { "pad_shape": shape, "batch_input_shape": shape, "img_shape": shape, "scale_factor": (1.0, 1.0), } meta_info_list = [meta_info] * len(inputs) return self.model.export(inputs, meta_info_list, explain_mode=self.explain_mode)
@property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" modified_label_info = copy.deepcopy(self.label_info) # Instance segmentation needs to add empty label to satisfy MAPI wrapper requirements modified_label_info.label_names.insert(0, "otx_empty_lbl") modified_label_info.label_ids.insert(0, "None") return super()._export_parameters.wrap( model_type="MaskRCNN", task_type="instance_segmentation", confidence_threshold=self.hparams.get("best_confidence_threshold", 0.05), iou_threshold=0.5, tile_config=self.tile_config if self.tile_config.enable_tiler else None, label_info=modified_label_info, )
[docs] def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None: """Load state_dict from checkpoint. For detection, it is need to update confidence threshold information when the metric is FMeasure. """ if best_confidence_threshold := ckpt.get("confidence_threshold", None) or ( (hyper_parameters := ckpt.get("hyper_parameters", None)) and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None)) ): self.hparams["best_confidence_threshold"] = best_confidence_threshold super().on_load_checkpoint(ckpt)
def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None: if key == "val": retval = super()._log_metrics(meter, key) # NOTE: Validation metric logging can update `best_confidence_threshold` if ( isinstance(meter, MetricCollection) and (fmeasure := getattr(meter, "FMeasure", None)) and (best_confidence_threshold := getattr(fmeasure, "best_confidence_threshold", None)) ) or ( isinstance(meter, FMeasure) and (best_confidence_threshold := getattr(meter, "best_confidence_threshold", None)) ): self.hparams["best_confidence_threshold"] = best_confidence_threshold return retval if key == "test": # NOTE: Test metric logging should use `best_confidence_threshold` found previously. best_confidence_threshold = self.hparams.get("best_confidence_threshold", None) compute_kwargs = ( {"best_confidence_threshold": best_confidence_threshold} if best_confidence_threshold else {} ) return super()._log_metrics(meter, key, **compute_kwargs) raise ValueError(key) def _convert_pred_entity_to_compute_metric( self, preds: OTXPredBatch, # type: ignore[override] inputs: OTXDataBatch, # type: ignore[override] ) -> MetricInput: """Convert the prediction entity to the format that the metric can compute and cache the ground truth. This function will convert mask to RLE format and cache the ground truth for the current batch. Args: preds (TorchPredBatch): Current batch predictions. inputs (TorchDataBatch): Current batch ground-truth inputs. Returns: dict[str, list[dict[str, Tensor]]]: The converted predictions and ground truth. """ pred_info = [] target_info = [] for i in range(len(preds.imgs_info)): # type: ignore[arg-type] bboxes = preds.bboxes[i] if preds.bboxes is not None else None masks = preds.masks[i] if preds.masks is not None else None scores = preds.scores[i] if preds.scores is not None else None labels = preds.labels[i] if preds.labels is not None else None pred_info.append( { "boxes": bboxes.data, "masks": [encode_rle(mask) for mask in masks.data], "scores": scores, "labels": labels, }, ) for i in range(len(inputs.imgs_info)): # type: ignore[arg-type] imgs_info = inputs.imgs_info[i] if inputs.imgs_info is not None else None bboxes = inputs.bboxes[i] if inputs.bboxes is not None else None masks = inputs.masks[i] if inputs.masks is not None else None polygons = inputs.polygons[i] if inputs.polygons is not None else None labels = inputs.labels[i] if inputs.labels is not None else None rles = ( [encode_rle(mask) for mask in masks.data] if len(masks) else polygon_to_rle(polygons, *imgs_info.ori_shape) # type: ignore[union-attr,arg-type] ) target_info.append( { "boxes": bboxes.data, "masks": rles, "labels": labels, }, ) return {"preds": pred_info, "target": target_info}
[docs] def get_dummy_input(self, batch_size: int = 1) -> OTXDataBatch: # type: ignore[override] """Returns a dummy input for instance segmentation model.""" images = [torch.rand(3, *self.data_input_params.input_size) for _ in range(batch_size)] infos = [] for i, img in enumerate(images): infos.append( ImageInfo( img_idx=i, img_shape=img.shape, ori_shape=img.shape, ), ) return OTXDataBatch(batch_size, images, imgs_info=infos)
[docs] def forward_explain(self, inputs: OTXDataBatch) -> OTXPredBatch: """Model forward function.""" if isinstance(inputs, OTXTileBatchDataEntity): return self.forward_tiles(inputs) self.model.feature_vector_fn = feature_vector_fn self.model.explain_fn = self.get_explain_fn() outputs = ( self._forward_explain_inst_seg(self.model, **self._customize_inputs(inputs)) if self._customize_inputs != OTXInstanceSegModel._customize_inputs else self._forward_explain_inst_seg(self.model, inputs) ) return ( self._customize_outputs(outputs, inputs) if self._customize_outputs != OTXInstanceSegModel._customize_outputs else outputs["predictions"] )
@staticmethod @torch.no_grad() def _forward_explain_inst_seg( self: TwoStageDetector, entity: OTXDataBatch, mode: str = "tensor", # noqa: ARG004 ) -> dict[str, Tensor]: """Forward func of the BaseDetector instance, which located in is in ExplainableOTXInstanceSegModel().model.""" x = self.backbone(entity.images) if isinstance(self, MaskRCNN) else self.extract_feat(entity.images) feature_vector = self.feature_vector_fn(x) predictions = self.get_results_from_head(x, entity) if isinstance(predictions, tuple) and isinstance(predictions[0], Tensor): # Export case, consists of tensors # For OV task saliency map are generated on MAPI side saliency_map = torch.empty(1, dtype=torch.uint8) elif isinstance(predictions, list) and isinstance(predictions[0], (InstanceData, dict)): # Predict case, consists of InstanceData or dict saliency_map = self.explain_fn(predictions) else: msg = f"Unexpected predictions type: {type(predictions)}" raise TypeError(msg) return { "predictions": predictions, "feature_vector": feature_vector, "saliency_map": saliency_map, }
[docs] def get_results_from_head( self, x: tuple[Tensor], entity: OTXDataBatch, ) -> tuple[Tensor, Tensor, Tensor] | list[InstanceData] | list[dict[str, Tensor]]: """Get the results from the head of the instance segmentation model. Args: x (tuple[Tensor]): The features from backbone and neck. data_samples (OptSampleList | None): A list of data samples. Returns: tuple[Tensor, Tensor, Tensor] | list[InstanceData]: The predicted results from the head of the model. Tuple for the Export case, list for the Predict case. """ from otx.algo.instance_segmentation.maskrcnn_tv import MaskRCNNTV from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInst if isinstance(self, MaskRCNNTV): ori_shapes = [img_info.ori_shape for img_info in entity.imgs_info] # type: ignore[union-attr] img_shapes = [img_info.img_shape for img_info in entity.imgs_info] # type: ignore[union-attr] image_list = ImageList(entity.images, img_shapes) proposals, _ = self.model.rpn(image_list, x) detections, _ = self.model.roi_heads( x, proposals, image_list.image_sizes, ) scale_factors = [ img_meta.scale_factor if img_meta.scale_factor else (1.0, 1.0) # type: ignore[union-attr] for img_meta in entity.imgs_info # type: ignore[union-attr] ] return self.model.postprocess(detections, ori_shapes, scale_factors) if isinstance(self, RTMDetInst): return self.model.bbox_head.predict(x, entity, rescale=False) rpn_results_list = self.model.rpn_head.predict(x, entity, rescale=False) return self.model.roi_head.predict(x, rpn_results_list, entity, rescale=True)
[docs] def get_explain_fn(self) -> Callable: """Returns explain function.""" explainer = InstSegExplainAlgo(num_classes=self.num_classes) return explainer.func
[docs] @contextmanager def export_model_forward_context(self) -> Iterator[None]: """A context manager for managing the model's forward function during model exportation. It temporarily modifies the model's forward function to generate output sinks for explain results during the model graph tracing. """ try: self._reset_model_forward() yield finally: self._restore_model_forward()
def _reset_model_forward(self) -> None: if not self.explain_mode: return self.model.explain_fn = self.get_explain_fn() forward_with_explain = self._forward_explain_inst_seg self.original_model_forward = self.model.forward func_type = types.MethodType # Patch class method model_class = type(self.model) model_class.forward = func_type(forward_with_explain, self.model) def _restore_model_forward(self) -> None: if not self.explain_mode: return if not self.original_model_forward: msg = "Original model forward was not saved." raise RuntimeError(msg) func_type = types.MethodType self.model.forward = func_type(self.original_model_forward, self.model) self.original_model_forward = None
[docs] class OVInstanceSegmentationModel( OVModel, ): """Instance segmentation model compatible for OpenVINO IR inference. It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX detection model compatible for OTX testing pipeline. """ def __init__( self, model_name: str, model_type: str = "MaskRCNN", async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = MaskRLEMeanAPFMeasureCallable, **kwargs, ) -> None: super().__init__( model_name=model_name, model_type=model_type, async_inference=async_inference, max_num_requests=max_num_requests, use_throughput_mode=use_throughput_mode, model_api_configuration=model_api_configuration, metric=metric, ) def _setup_tiler(self) -> None: """Setup tiler for tile task.""" execution_mode = "async" if self.async_inference else "sync" # Note: Disable async_inference as tiling has its own sync/async implementation self.async_inference = False self.model = InstanceSegmentationTiler(self.model, execution_mode=execution_mode) log.info( f"Enable tiler with tile size: {self.model.tile_size} \ and overlap: {self.model.tiles_overlap}", ) def _get_hparams_from_adapter(self, model_adapter: OpenvinoAdapter) -> None: """Reads model configuration from ModelAPI OpenVINO adapter. Args: model_adapter (OpenvinoAdapter): target adapter to read the config """ if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]): best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value self.hparams["best_confidence_threshold"] = float(best_confidence_threshold) else: msg = ( "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. " "Please check whether this model is trained by OTX or not. " "Without this information, it can produce a wrong F1 metric score. " "At this time, it will be set as the default value = None." ) log.warning(msg) self.hparams["best_confidence_threshold"] = None def _customize_outputs( self, outputs: list[InstanceSegmentationResult], inputs: OTXDataBatch, ) -> OTXPredBatch | OTXBatchLossEntity: # add label index bboxes = [] scores = [] labels = [] masks = [] for output in outputs: bboxes.append( tv_tensors.BoundingBoxes( data=output.bboxes, format="XYXY", canvas_size=inputs.imgs_info[-1].img_shape, # type: ignore[union-attr,index] device=self.device, dtype=torch.float32, ), ) # NOTE: OTX 1.5 filter predictions with result_based_confidence_threshold, # but OTX 2.0 doesn't have it in configuration. scores.append(torch.tensor(output.scores.reshape(-1), device=self.device)) masks.append(torch.tensor(output.masks, device=self.device)) labels.append(torch.tensor(output.labels.reshape(-1) - 1, device=self.device, dtype=torch.long)) if outputs and outputs[0].saliency_map: predicted_s_maps = [] for out in outputs: image_map = torch.tensor( [s_map for s_map in out.saliency_map if s_map.size > 0], dtype=torch.float32, device=self.device, ) predicted_s_maps.append(image_map) # Squeeze dim 2D => 1D, (1, internal_dim) => (internal_dim) predicted_f_vectors = [out.feature_vector[0] for out in outputs] return OTXPredBatch( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, scores=scores, bboxes=bboxes, masks=masks if any(mask.numel() > 0 for mask in masks) else None, labels=labels, saliency_map=predicted_s_maps, feature_vector=predicted_f_vectors, ) return OTXPredBatch( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, scores=scores, bboxes=bboxes, masks=masks if any(mask.numel() > 0 for mask in masks) else None, labels=labels, ) def _convert_pred_entity_to_compute_metric( self, preds: OTXPredBatch, # type: ignore[override] inputs: OTXDataBatch, # type: ignore[override] ) -> MetricInput: """Convert the prediction entity to the format that the metric can compute and cache the ground truth. This function will convert mask to RLE format and cache the ground truth for the current batch. Args: preds (TorchPredBatch): Current batch predictions. inputs (TorchDataBatch): Current batch ground-truth inputs. Returns: dict[str, list[dict[str, Tensor]]]: The converted predictions and ground truth. """ target_info = [] _bboxes = preds.bboxes if preds.bboxes is not None else None _masks = preds.masks if preds.masks is not None else None pred_info = [ { "boxes": _bboxes[idx].data if _bboxes is not None else torch.empty((0, 4)), "masks": [encode_rle(mask) for mask in _masks[idx].data] if _masks is not None and len(_masks) else torch.empty((0,)), "scores": preds.scores[idx], # type: ignore[index] "labels": preds.labels[idx], # type: ignore[index] } for idx in range(len(preds.labels)) # type: ignore[arg-type] ] _bboxes = inputs.bboxes if inputs.bboxes is not None else None _masks = inputs.masks if inputs.masks is not None else None for idx in range(len(inputs.labels)): # type: ignore[arg-type] rles = ( [encode_rle(mask) for mask in _masks[idx].data] if _masks is not None and len(_masks[idx]) > 0 else polygon_to_rle(inputs.polygons[idx], *inputs.imgs_info[idx].ori_shape) # type: ignore[index,union-attr] ) target_info.append( { "boxes": _bboxes[idx].data if _bboxes is not None else torch.empty((0, 4)), "masks": rles, "labels": inputs.labels[idx], # type: ignore[index] }, ) return {"preds": pred_info, "target": target_info} def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None: best_confidence_threshold = self.hparams.get("best_confidence_threshold", None) compute_kwargs = {"best_confidence_threshold": best_confidence_threshold} return super()._log_metrics(meter, key, **compute_kwargs) def _create_label_info_from_ov_ir(self) -> LabelInfo: ov_model = self.model.get_model() if ov_model.has_rt_info(["model_info", "label_info"]): serialized = ov_model.get_rt_info(["model_info", "label_info"]).value ir_label_info = LabelInfo.from_json(serialized) # workaround to hide extra otx_empty_lbl if ir_label_info.label_names[0] == "otx_empty_lbl": ir_label_info.label_names.pop(0) ir_label_info.label_ids.pop(0) ir_label_info.label_groups[0].pop(0) return ir_label_info return super()._create_label_info_from_ov_ir()