# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Module for OTXSegmentationDataset."""
from __future__ import annotations
from typing import TYPE_CHECKING
import cv2
import numpy as np
import torch
from datumaro.components.annotation import Bbox, Ellipse, Image, Mask, Polygon, RotatedBbox
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype, to_image
from otx.data.entity.base import ImageInfo
from otx.data.entity.torch import OTXDataItem
from otx.types.image import ImageColorChannel
from otx.types.label import SegLabelInfo
from .base import OTXDataset
if TYPE_CHECKING:
from datumaro import Dataset as DmDataset
from datumaro import DatasetItem
from otx.data.dataset.base import Transforms
# NOTE: Copied from https://github.com/open-edge-platform/datumaro/pull/1409
# and will be replaced by the native Datumaro implementation in the future.
def _make_index_mask(
binary_mask: np.ndarray,
index: int,
ignore_index: int = 0,
dtype: np.dtype | None = None,
) -> np.ndarray:
"""Create an index mask from a binary mask by filling a given index value.
Args:
binary_mask: Binary mask to create an index mask.
index: Scalar value to fill the ones in the binary mask.
ignore_index: Scalar value to fill in the zeros in the binary mask.
Defaults to 0.
dtype: Data type for the resulting mask. If not specified,
it will be inferred from the provided index. Defaults to None.
Returns:
np.ndarray: Index mask created from the binary mask.
Raises:
ValueError: If dtype is not specified and incompatible scalar types are used for index
and ignore_index.
Examples:
>>> binary_mask = np.eye(2, dtype=np.bool_)
>>> index_mask = make_index_mask(binary_mask, index=10, ignore_index=255, dtype=np.uint8)
>>> print(index_mask)
array([[ 10, 255],
[255, 10]], dtype=uint8)
"""
if dtype is None:
dtype = np.min_scalar_type(index)
if dtype != np.min_scalar_type(ignore_index):
raise ValueError
flipped_zero_np_scalar = ~np.full((), fill_value=0, dtype=dtype)
# NOTE: This dispatching rule is required for a performance boost
if ignore_index == flipped_zero_np_scalar:
flipped_index = ~np.full((), fill_value=index, dtype=dtype)
return ~(binary_mask * flipped_index)
mask = binary_mask * np.full((), fill_value=index, dtype=dtype)
if ignore_index == 0:
return mask
return np.where(binary_mask, mask, ignore_index)
def _extract_class_mask(item: DatasetItem, img_shape: tuple[int, int], ignore_index: int) -> np.ndarray:
    """Extract a class-index mask from the Datumaro annotations of ``item``.

    This is a temporary workaround and will be replaced with the native Datumaro interfaces
    after some works, e.g., https://github.com/open-edge-platform/datumaro/pull/1409 are done.

    Polygon-like annotations (``Ellipse``, ``Polygon``, ``Bbox``, ``RotatedBbox``) are rasterised
    with ``label + 1`` so that index 0 stays reserved for background. ``Mask`` annotations are
    written with their label unchanged. Annotations are painted in ascending ``z_order``, so
    later ones overwrite earlier ones.

    Args:
        item: Datumaro dataset item having mask annotations.
        img_shape: Image shape (H, W). Only the first two entries are used.
        ignore_index: Label value for pixels to ignore. It is used as the initial canvas value
            for non-polygon items and as the "not set" value inside ``Mask`` annotations.
            Must not exceed 255.

    Returns:
        2D ``uint8`` numpy array of shape ``img_shape[:2]``. The canvas starts at 0
        (background) when the item's first annotation is polygon-like, otherwise at
        ``ignore_index``. An item without any annotation yields a mask filled with
        ``ignore_index``.

    Raises:
        ValueError: If ``ignore_index`` exceeds 255, an annotation has no label, or a label
            index does not fit into ``uint8`` once encoded.
    """
    if ignore_index > 255:
        msg = "It is not currently support an ignore index which is more than 255."
        raise ValueError(msg, ignore_index)

    polygon_like = (Ellipse, Polygon, Bbox, RotatedBbox)
    annotations = item.annotations

    # Fill the mask with the background label if we have Polygon/Ellipse/Bbox annotations,
    # since those shapes only outline objects. An item with no annotations carries no label
    # information at all, so it is left fully ignored instead of crashing on annotations[0].
    fill_value = 0 if annotations and isinstance(annotations[0], polygon_like) else ignore_index
    class_mask = np.full(shape=img_shape[:2], fill_value=fill_value, dtype=np.uint8)

    for ann in sorted(
        (ann for ann in annotations if isinstance(ann, (Mask, *polygon_like))),
        key=lambda ann: ann.z_order,
    ):
        label = ann.label
        if label is None:
            msg = "Mask's label index should not be None."
            raise ValueError(msg)

        if isinstance(ann, polygon_like):
            class_index = label + 1  # NOTE: disregard the background index. Objects start from index=1
            if class_index > 255:
                # cv2 would otherwise saturate the colour to 255 and silently corrupt the label.
                msg = "Polygon-like annotation's label index should not be more than 254 (index 0 is background)."
                raise ValueError(msg, label)
            polygons = np.asarray(ann.as_polygon(), dtype=np.int32).reshape((-1, 1, 2))
            # cv2.drawContours paints into `class_mask` in place and returns it, so the
            # merge below leaves the result unchanged for these shapes.
            this_class_mask = cv2.drawContours(
                class_mask,
                [polygons],
                0,
                (class_index, class_index, class_index),
                thickness=cv2.FILLED,
            )
        else:  # Mask (guaranteed by the filter above)
            if label > 255:
                msg = "Mask's label index should not be more than 255."
                raise ValueError(msg, label)
            this_class_mask = _make_index_mask(
                binary_mask=ann.image,
                index=label,
                ignore_index=ignore_index,
                dtype=np.uint8,
            )

        if this_class_mask.shape != class_mask.shape:
            this_class_mask = cv2.resize(
                this_class_mask,
                dsize=(class_mask.shape[1], class_mask.shape[0]),  # NOTE: cv2.resize() uses (width, height) format
                interpolation=cv2.INTER_NEAREST,
            )
        # Only pixels this annotation actually labels overwrite the accumulated mask.
        class_mask = np.where(this_class_mask != ignore_index, this_class_mask, class_mask)

    return class_mask
class OTXSegmentationDataset(OTXDataset):
    """OTXDataset class for segmentation task.

    Each item is turned into an ``OTXDataItem`` whose ``masks`` field holds a single-channel
    class-index map built by ``_extract_class_mask``. When the subset contains polygon-like
    annotation types, a background class is prepended to the label info at index 0.
    """

    def __init__(
        self,
        dm_subset: DmDataset,
        transforms: Transforms,
        max_refetch: int = 1000,
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        to_tv_image: bool = True,
        ignore_index: int = 255,
        data_format: str = "",
    ) -> None:
        super().__init__(
            dm_subset,
            transforms,
            max_refetch,
            image_color_channel,
            to_tv_image,
            data_format=data_format,
        )

        if self.has_polygons:
            # Polygon-like shapes only outline objects, so index 0 is reserved for background.
            self.label_info.label_names.insert(0, "otx_background_lbl")
            self.label_info.label_ids.insert(0, "None")

        # Re-wrap the (possibly extended) label info as segmentation label info.
        base_info = self.label_info
        self.label_info = SegLabelInfo(
            label_names=base_info.label_names,
            label_groups=base_info.label_groups,
            ignore_index=ignore_index,
            label_ids=base_info.label_ids,
        )
        self.ignore_index = ignore_index

    @property
    def has_polygons(self) -> bool:
        """Return True if any annotation type of the subset is polygon-like."""
        # All polygon-like formats are treated as polygons.
        polygon_like_names = {"polygon", "ellipse", "bbox", "rotated_bbox"}
        return any(ann_type.name in polygon_like_names for ann_type in self.dm_subset.ann_types())

    def _get_item_impl(self, index: int) -> OTXDataItem | None:
        dm_item = self.dm_subset[index]
        media = dm_item.media_as(Image)
        roi = dm_item.attributes.get("roi", None)
        img_data, img_shape, roi_meta = self._get_img_data_and_shape(media, roi)

        # The mask is rasterised on the full image and cropped to the ROI afterwards.
        full_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape
        class_mask = _extract_class_mask(item=dm_item, img_shape=full_shape, ignore_index=self.ignore_index)
        if roi_meta:
            rows = slice(roi_meta["y1"], roi_meta["y2"])
            cols = slice(roi_meta["x1"], roi_meta["x2"])
            class_mask = class_mask[rows, cols]

        mask_tensor = tv_tensors.Mask(class_mask[None], dtype=torch.long)
        image_tensor = to_dtype(to_image(img_data), dtype=torch.float32)
        # NOTE(review): ori_shape mirrors img_shape (the ROI crop when one is set), not the
        # full-image shape used for mask extraction above; presumably intentional since the
        # mask is cropped as well - confirm.
        info = ImageInfo(
            img_idx=index,
            img_shape=img_shape,
            ori_shape=img_shape,
            image_color_channel=self.image_color_channel,
            ignored_labels=[],
        )
        return self._apply_transforms(OTXDataItem(image=image_tensor, img_info=info, masks=mask_tensor))