Source code for otx.data.dataset.tile

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OTX tile dataset."""

from __future__ import annotations

import logging as log
import operator
import warnings
from collections import defaultdict
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Callable

import numpy as np
import shapely.geometry as sg
import torch
from datumaro import Dataset as DmDataset
from datumaro import DatasetItem, Image
from datumaro.components.annotation import AnnotationType, Bbox, Ellipse, ExtractedMask, Polygon
from datumaro.plugins.tiling import Tile
from datumaro.plugins.tiling.tile import _apply_offset
from datumaro.plugins.tiling.util import (
    clip_x1y1x2y2,
    cxcywh_to_x1y1x2y2,
    x1y1x2y2_to_cxcywh,
    x1y1x2y2_to_xywh,
)
from torchvision import tv_tensors

from otx.backend.native.models.instance_segmentation.utils.structures.mask.mask_util import polygon_to_bitmap
from otx.data.dataset.segmentation import _extract_class_mask
from otx.data.entity.base import ImageInfo
from otx.data.entity.tile import (
    TileBatchDetDataEntity,
    TileBatchInstSegDataEntity,
    TileBatchSegDataEntity,
    TileDetDataEntity,
    TileInstSegDataEntity,
    TileSegDataEntity,
)
from otx.data.entity.torch import OTXDataItem
from otx.types.task import OTXTaskType

from .base import OTXDataset

if TYPE_CHECKING:
    from datumaro.components.media import BboxIntCoords

    from otx.config.data import TileConfig
    from otx.data.dataset.detection import OTXDetectionDataset
    from otx.data.dataset.instance_segmentation import OTXInstanceSegDataset
    from otx.data.dataset.segmentation import OTXSegmentationDataset

# ruff: noqa: SLF001
# NOTE: Disable private-member-access (SLF001).
# This is a workaround so we can apply the same transforms to tiles as to the original dataset.

# NOTE: Datumaro subset name should be standardized.
TRAIN_SUBSET_NAMES = ("train", "TRAINING")
VAL_SUBSET_NAMES = ("val", "VALIDATION")


class OTXTileTransform(Tile):
    """OTX tile transform.

    Unlike the original Datumaro Tile transform, OTXTileTransform takes
    tile_size and overlap as input instead of a grid size.

    Args:
        extractor (DmDataset): Dataset subset to extract tiles from.
        tile_size (tuple[int, int]): Tile size.
        overlap (tuple[float, float]): Overlap ratio.
            Overlap values are clipped between 0 and 0.9 to ensure the stride is not too small.
        threshold_drop_ann (float): Threshold to drop annotations.
        with_full_img (bool): Include full image in the tiles.
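
    Example:
        A minimal sketch (values are illustrative) of applying the transform through
        Datumaro, mirroring how ``OTXTileDataset.transform_item`` below uses it::

            tiled_subset = dm_subset.transform(
                OTXTileTransform,
                tile_size=(512, 512),
                overlap=(0.1, 0.1),
                threshold_drop_ann=0.5,
                with_full_img=False,
            )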
    """

    def __init__(
        self,
        extractor: DmDataset,
        tile_size: tuple[int, int],
        overlap: tuple[float, float],
        threshold_drop_ann: float,
        with_full_img: bool,
    ) -> None:
        # NOTE: clip overlap to [0, 0.9]
        overlap = max(0, min(overlap[0], 0.9)), max(0, min(overlap[1], 0.9))
        super().__init__(
            extractor,
            (0, 0),
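            # NOTE: (0, 0) is a grid-size placeholder for the parent Tile transform;
            # ROIs are instead computed from tile_size/overlap in _extract_rois below.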
            overlap=overlap,
            threshold_drop_ann=threshold_drop_ann,
        )
        self._tile_size = tile_size
        self._tile_ann_func_map[AnnotationType.polygon] = OTXTileTransform._tile_polygon
        self._tile_ann_func_map[AnnotationType.mask] = OTXTileTransform._tile_masks
        self._tile_ann_func_map[AnnotationType.ellipse] = OTXTileTransform._tile_ellipse
        self.with_full_img = with_full_img

    @staticmethod
    def _tile_polygon(
        ann: Polygon,
        roi_box: sg.Polygon,
        threshold_drop_ann: float = 0.8,
        *args,  # noqa: ARG004
        **kwargs,  # noqa: ARG004
    ) -> Polygon | None:
        polygon = sg.Polygon(ann.get_points())

        # NOTE: polygon may be invalid, e.g. self-intersecting
        if not roi_box.intersects(polygon) or not polygon.is_valid:
            return None

        # NOTE: intersection may return a GeometryCollection or MultiPolygon
        inter = polygon.intersection(roi_box)
        if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)):
            shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid]
            if not shapes:
                return None

            inter, _ = max(shapes, key=operator.itemgetter(1))

            if not isinstance(inter, sg.Polygon) and not inter.is_valid:
                return None

        prop_area = inter.area / polygon.area
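        # prop_area is the fraction of the original polygon that survives inside this
        # tile; annotations below threshold_drop_ann are dropped, the rest are shifted
        # into tile-local coordinates.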

        if prop_area < threshold_drop_ann:
            return None

        inter = _apply_offset(inter, roi_box)

        return ann.wrap(
            points=[p for xy in inter.exterior.coords for p in xy],
            attributes=deepcopy(ann.attributes),
        )

    @staticmethod
    def _tile_masks(
        ann: ExtractedMask,
        roi_int: BboxIntCoords,
        *args,  # noqa: ARG004
        **kwargs,  # noqa: ARG004
    ) -> ExtractedMask:
        """Extracts a tile mask from the given annotation.

        Note: Original Datumaro _tile_masks does not work with ExtractedMask.

        Args:
            ann (ExtractedMask): datumaro ExtractedMask annotation.
            roi_int (BboxIntCoords): ROI coordinates.

        Returns:
            ExtractedMask: ExtractedMask annotation.
        """
        x, y, w, h = roi_int
        return ann.wrap(
            index_mask=ann.index_mask()[y : y + h, x : x + w],
            attributes=deepcopy(ann.attributes),
        )

    @staticmethod
    def _tile_ellipse(
        ann: Ellipse,
        roi_box: sg.Polygon,
        threshold_drop_ann: float = 0.8,
        *args,  # noqa: ARG004
        **kwargs,  # noqa: ARG004
    ) -> Polygon | None:
        polygon = sg.Polygon(ann.get_points(num_points=10))

        # NOTE: polygon may be invalid, e.g. self-intersecting
        if not roi_box.intersects(polygon) or not polygon.is_valid:
            return None

        # NOTE: intersection may return a GeometryCollection or MultiPolygon
        inter = polygon.intersection(roi_box)
        if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)):
            shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid]
            if not shapes:
                return None

            inter, _ = max(shapes, key=operator.itemgetter(1))

            if not isinstance(inter, sg.Polygon) and not inter.is_valid:
                return None

        prop_area = inter.area / polygon.area

        if prop_area < threshold_drop_ann:
            return None

        inter = _apply_offset(inter, roi_box)

        return Polygon(
            points=[p for xy in inter.exterior.coords for p in xy],
            attributes=deepcopy(ann.attributes),
            label=ann.label,
        )

    def _extract_rois(self, image: Image) -> list[BboxIntCoords]:
        """Extracts Tile ROIs from the given image.

        Args:
            image (Image): Full image.

        Returns:
            list[BboxIntCoords]: list of ROIs.
        """
        if image.size is None:
            msg = "Image size is None"
            raise ValueError(msg)

        img_h, img_w = image.size
        tile_h, tile_w = self._tile_size
        h_ovl, w_ovl = self._overlap

        rois: set[BboxIntCoords] = set()
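        # Tile origins are spaced tile_size * (1 - overlap) apart, e.g. a 512 px
        # tile with 0.1 overlap gives a stride of int(512 * 0.9) = 460 px.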
        cols = range(0, img_w, int(tile_w * (1 - w_ovl)))
        rows = range(0, img_h, int(tile_h * (1 - h_ovl)))

        if self.with_full_img:
            rois.add(x1y1x2y2_to_xywh(0, 0, img_w, img_h))
        for offset_x, offset_y in product(cols, rows):
            x2 = min(offset_x + tile_w, img_w)
            y2 = min(offset_y + tile_h, img_h)
            c_x, c_y, w, h = x1y1x2y2_to_cxcywh(offset_x, offset_y, x2, y2)
            x1, y1, x2, y2 = cxcywh_to_x1y1x2y2(c_x, c_y, w, h)
            x1, y1, x2, y2 = clip_x1y1x2y2(x1, y1, x2, y2, img_w, img_h)
            x1, y1, x2, y2 = (int(v) for v in [x1, y1, x2, y2])
            rois.add(x1y1x2y2_to_xywh(x1, y1, x2, y2))

        log.info(f"image: {img_h}x{img_w} ~ tile_size: {self._tile_size}")
        log.info(f"{len(rows)}x{len(cols)} tiles -> {len(rois)} tiles")
        return list(rois)


class OTXTileDatasetFactory:
    """OTX tile dataset factory."""

    @classmethod
    def create(
        cls,
        task: OTXTaskType,
        dataset: OTXDataset,
        tile_config: TileConfig,
    ) -> OTXTileDataset:
        """Create a tile dataset based on the task type and subset type.

        Note: All tasks use the same OTXTileTrainDataset for training. At test time,
        a task-specific tile dataset is used because the annotation formats and data
        entities differ between tasks.

        Args:
            task (OTXTaskType): OTX task type.
            dataset (OTXDataset): OTX dataset.
            tile_config (TileConfig): Tile configuration.

        Returns:
            OTXTileDataset: Tile dataset.
        """
        if dataset.dm_subset[0].subset in TRAIN_SUBSET_NAMES:
            return OTXTileTrainDataset(dataset, tile_config)
        if task == OTXTaskType.DETECTION:
            return OTXTileDetTestDataset(dataset, tile_config)
        if task in [OTXTaskType.ROTATED_DETECTION, OTXTaskType.INSTANCE_SEGMENTATION]:
            return OTXTileInstSegTestDataset(dataset, tile_config)
        if task == OTXTaskType.SEMANTIC_SEGMENTATION:
            return OTXTileSemanticSegTestDataset(dataset, tile_config)
        msg = f"Unsupported task type: {task} for tiling"
        raise NotImplementedError(msg)
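
# A minimal usage sketch (illustrative only; ``task``, ``dataset`` and
# ``tile_config`` are assumed to be supplied by the caller, e.g. a datamodule):
#
#   tile_dataset = OTXTileDatasetFactory.create(task, dataset, tile_config)
#
# Training subsets always come back as OTXTileTrainDataset; validation/test
# subsets get the task-specific test dataset selected in ``create`` above.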


class OTXTileDataset(OTXDataset):
    """OTX tile dataset base class.

    Args:
        dataset (OTXDataset): OTX dataset.
        tile_config (TileConfig): Tile configuration.
    """

    def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
        super().__init__(
            dataset.dm_subset,
            dataset.transforms,
            dataset.max_refetch,
            dataset.image_color_channel,
            dataset.stack_images,
            dataset.to_tv_image,
        )
        self.tile_config = tile_config
        self._dataset = dataset
        # LabelInfo differs from SegLabelInfo, thus we need to update it for semantic segmentation.
        if self.label_info != dataset.label_info:
            msg = (
                "Replacing the label info with the dataset's label info "
                "as there is a mismatch between the dataset and the tile dataset."
            )
            log.warning(msg)
            self.label_info = dataset.label_info

    def __len__(self) -> int:
        return len(self._dataset)

    @property
    def collate_fn(self) -> Callable:
        """Collate function from the original dataset."""
        return self._dataset.collate_fn

    def _get_item_impl(self, index: int) -> OTXDataItem | None:
        """Get item implementation from the original dataset."""
        return self._dataset._get_item_impl(index)

    def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem:
        """Convert a tile dataset item to OTXDataItem."""
        msg = "Method _convert_entity is not implemented."
        raise NotImplementedError(msg)

    def transform_item(
        self,
        item: DatasetItem,
        tile_size: tuple[int, int],
        overlap: tuple[float, float],
        with_full_img: bool,
    ) -> DmDataset:
        """Transform a dataset item into a tile dataset that contains multiple tiles."""
        tile_ds = DmDataset.from_iterable([item])
        return tile_ds.transform(
            OTXTileTransform,
            tile_size=tile_size,
            overlap=overlap,
            threshold_drop_ann=0.5,
            with_full_img=with_full_img,
        )

    def get_tiles(
        self,
        image: np.ndarray,
        item: DatasetItem,
        parent_idx: int,
    ) -> tuple[list[OTXDataItem], list[dict]]:
        """Retrieve tiles from the given image and dataset item.

        Args:
            image (np.ndarray): The input image.
            item (DatasetItem): The dataset item.
            parent_idx (int): The parent index, used to keep track of the original
                dataset item index for merging.

        Returns:
            A tuple containing two lists:
                - tile_entities (list[OTXDataItem]): List of tile entities.
                - tile_attrs (list[dict]): List of tile attributes.
        """
        tile_ds = self.transform_item(
            item,
            tile_size=self.tile_config.tile_size,
            overlap=(self.tile_config.overlap, self.tile_config.overlap),
            with_full_img=self.tile_config.with_full_img,
        )

        if item.subset in VAL_SUBSET_NAMES:
            # NOTE: filter validation tiles to those with annotations to avoid evaluation on empty tiles.
            tile_ds = tile_ds.filter("/item/annotation", filter_annotations=True, remove_empty=True)
            # If the tile dataset is empty, the objects are too large to fit in any tile;
            # in this case, fall back to including the full image.
            if len(tile_ds) == 0:
                tile_ds = self.transform_item(
                    item,
                    tile_size=self.tile_config.tile_size,
                    overlap=(self.tile_config.overlap, self.tile_config.overlap),
                    with_full_img=True,
                )

        tile_entities: list[OTXDataItem] = []
        tile_attrs: list[dict] = []
        for tile in tile_ds:
            tile_entity = self._convert_entity(image, tile, parent_idx)
            # apply the same transforms as the original dataset
            transformed_tile = self._apply_transforms(tile_entity)
            if transformed_tile is None:
                msg = "Transformed tile is None"
                raise RuntimeError(msg)
            tile.attributes.update({"tile_size": self.tile_config.tile_size})
            tile_entities.append(transformed_tile)
            tile_attrs.append(tile.attributes)
        return tile_entities, tile_attrs
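
# NOTE: Each tile item carries its placement in the original image through its
# Datumaro attributes: the Tile transform sets ``attributes["roi"]`` as
# (x, y, w, h), and ``get_tiles`` above additionally records the configured
# ``tile_size``. The task-specific test datasets below use the ROI to crop the
# full-resolution image in ``_convert_entity``.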


class OTXTileTrainDataset(OTXTileDataset):
    """OTX tile train dataset.

    Args:
        dataset (OTXDataset): OTX dataset.
        tile_config (TileConfig): Tile configuration.
    """

    def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
        dm_dataset = dataset.dm_subset
        dm_dataset = dm_dataset.transform(
            OTXTileTransform,
            tile_size=tile_config.tile_size,
            overlap=(tile_config.overlap, tile_config.overlap),
            threshold_drop_ann=0.5,
            with_full_img=tile_config.with_full_img,
        )
        dm_dataset = dm_dataset.filter("/item/annotation", filter_annotations=True, remove_empty=True)
        # Include original dataset for training
        dm_dataset.update(dataset.dm_subset)
        dataset.dm_subset = dm_dataset
        super().__init__(dataset, tile_config)


class OTXTileDetTestDataset(OTXTileDataset):
    """OTX tile detection test dataset.

    OTXTileDetTestDataset wraps a list of tiles (OTXDataItem) into a single
    TileDetDataEntity for testing/predicting.

    Args:
        dataset (OTXDetectionDataset): OTX detection dataset.
        tile_config (TileConfig): Tile configuration.
    """

    def __init__(self, dataset: OTXDetectionDataset, tile_config: TileConfig) -> None:
        super().__init__(dataset, tile_config)

    @property
    def collate_fn(self) -> Callable:
        """Collate function for tile detection test dataset."""
        return TileBatchDetDataEntity.collate_fn

    def _get_item_impl(self, index: int) -> TileDetDataEntity:  # type: ignore[override]
        """Get item implementation.

        Transform a single dataset item to multiple tiles using the Datumaro tiling
        plugin, and wrap the tiles into a single TileDetDataEntity.

        Args:
            index (int): Index of the dataset item.

        Returns:
            TileDetDataEntity: tile detection data entity that wraps a list of
                detection data entities.

        Note:
            Ignoring the [override] check is necessary here since
            OTXDataset._get_item_impl only permits returning OTXDataItem,
            whereas tiling must wrap the tiles in a single TileDetDataEntity.
        """
        item = self.dm_subset[index]
        img = item.media_as(Image)
        img_data, img_shape, _ = self._get_img_data_and_shape(img)

        bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
        bboxes = (
            np.stack([ann.points for ann in bbox_anns], axis=0).astype(np.float32)
            if len(bbox_anns) > 0
            else np.zeros((0, 4), dtype=np.float32)
        )
        labels = torch.as_tensor([ann.label for ann in bbox_anns])

        tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

        return TileDetDataEntity(
            num_tiles=len(tile_entities),
            entity_list=tile_entities,
            tile_attr_list=tile_attrs,
            ori_img_info=ImageInfo(
                img_idx=index,
                img_shape=img_shape,
                ori_shape=img_shape,
            ),
            ori_bboxes=tv_tensors.BoundingBoxes(
                bboxes,
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=img_shape,
            ),
            ori_labels=labels,
        )

    def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem:  # type: ignore[override]
        """Convert a tile Datumaro dataset item to OTXDataItem."""
        x1, y1, w, h = dataset_item.attributes["roi"]
        tile_img = image[y1 : y1 + h, x1 : x1 + w]
        tile_shape = tile_img.shape[:2]
        img_info = ImageInfo(
            img_idx=parent_idx,
            img_shape=tile_shape,
            ori_shape=tile_shape,
        )
        return OTXDataItem(
            image=tile_img,
            img_info=img_info,
        )


class OTXTileInstSegTestDataset(OTXTileDataset):
    """OTX tile instance segmentation test dataset.

    OTXTileInstSegTestDataset wraps a list of tiles (OTXDataItem) into a single
    TileInstSegDataEntity for testing/predicting.

    Args:
        dataset (OTXInstanceSegDataset): OTX instance segmentation dataset.
        tile_config (TileConfig): Tile configuration.
    """

    def __init__(self, dataset: OTXInstanceSegDataset, tile_config: TileConfig) -> None:
        super().__init__(dataset, tile_config)

    @property
    def collate_fn(self) -> Callable:
        """Collate function for tile instance segmentation test dataset."""
        return TileBatchInstSegDataEntity.collate_fn

    def _get_item_impl(self, index: int) -> TileInstSegDataEntity:  # type: ignore[override]
        """Get item implementation.

        Transform a single dataset item to multiple tiles using the Datumaro tiling
        plugin, and wrap the tiles into a single TileInstSegDataEntity.

        Args:
            index (int): Index of the dataset item.

        Returns:
            TileInstSegDataEntity: tile instance segmentation data entity that wraps
                a list of instance segmentation data entities.

        Note:
            Ignoring the [override] check is necessary here since
            OTXDataset._get_item_impl only permits returning OTXDataItem,
            whereas tiling must wrap the tiles in a single TileInstSegDataEntity.
        """
        item = self.dm_subset[index]
        img = item.media_as(Image)
        img_data, img_shape, _ = self._get_img_data_and_shape(img)

        anno_collection: dict[str, list] = defaultdict(list)
        for anno in item.annotations:
            anno_collection[anno.__class__.__name__].append(anno)

        gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

        # TODO(Eugene): https://jira.devtools.intel.com/browse/CVS-159363
        # Temporary solution to handle multiple annotation types.
        # Ideally, we should pre-filter annotations during initialization of the dataset.
        if Polygon.__name__ in anno_collection:  # Polygon for InstSeg has higher priority
            for poly in anno_collection[Polygon.__name__]:
                bbox = Bbox(*poly.get_bbox()).points
                gt_bboxes.append(bbox)
                gt_labels.append(poly.label)

                if self._dataset.include_polygons:
                    gt_polygons.append(poly)
                else:
                    gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
        elif Bbox.__name__ in anno_collection:
            boxes = anno_collection[Bbox.__name__]
            gt_bboxes = [ann.points for ann in boxes]
            gt_labels = [ann.label for ann in boxes]
            for box in boxes:
                poly = Polygon(box.as_polygon())
                if self._dataset.include_polygons:
                    gt_polygons.append(poly)
                else:
                    gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
        elif Ellipse.__name__ in anno_collection:
            for ellipse in anno_collection[Ellipse.__name__]:
                bbox = Bbox(*ellipse.get_bbox()).points
                gt_bboxes.append(bbox)
                gt_labels.append(ellipse.label)
                poly = Polygon(ellipse.as_polygon(num_points=10))
                if self._dataset.include_polygons:
                    gt_polygons.append(poly)
                else:
                    gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
        else:
            warnings.warn(f"No valid annotations found for image {item.id}!", stacklevel=2)

        bboxes = np.stack(gt_bboxes, dtype=np.float32) if gt_bboxes else np.empty((0, 4), dtype=np.float32)
        masks = np.stack(gt_masks, axis=0) if gt_masks else np.empty((0, *img_shape), dtype=bool)
        labels = np.array(gt_labels, dtype=np.int64)

        tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

        return TileInstSegDataEntity(
            num_tiles=len(tile_entities),
            entity_list=tile_entities,
            tile_attr_list=tile_attrs,
            ori_img_info=ImageInfo(
                img_idx=index,
                img_shape=img_shape,
                ori_shape=img_shape,
            ),
            ori_bboxes=tv_tensors.BoundingBoxes(
                bboxes,
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=img_shape,
            ),
            ori_labels=torch.as_tensor(labels),
            ori_masks=tv_tensors.Mask(masks, dtype=torch.uint8),
            ori_polygons=gt_polygons,
        )

    def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem:  # type: ignore[override]
        """Convert a tile dataset item to OTXDataItem."""
        x1, y1, w, h = dataset_item.attributes["roi"]
        tile_img = image[y1 : y1 + h, x1 : x1 + w]
        tile_shape = tile_img.shape[:2]
        img_info = ImageInfo(
            img_idx=parent_idx,
            img_shape=tile_shape,
            ori_shape=tile_shape,
        )
        return OTXDataItem(
            image=tile_img,
            img_info=img_info,
            masks=tv_tensors.Mask(np.zeros((0, *tile_shape), dtype=bool)),
        )


class OTXTileSemanticSegTestDataset(OTXTileDataset):
    """OTX tile semantic segmentation test dataset.

    OTXTileSemanticSegTestDataset wraps a list of tiles (OTXDataItem) into a single
    TileSegDataEntity for testing/predicting.

    Args:
        dataset (OTXSegmentationDataset): OTX semantic segmentation dataset.
        tile_config (TileConfig): Tile configuration.
    """

    def __init__(self, dataset: OTXSegmentationDataset, tile_config: TileConfig) -> None:
        super().__init__(dataset, tile_config)
        self.ignore_index = self._dataset.ignore_index

    @property
    def collate_fn(self) -> Callable:
        """Collate function for tile semantic segmentation test dataset."""
        return TileBatchSegDataEntity.collate_fn

    def _get_item_impl(self, index: int) -> TileSegDataEntity:  # type: ignore[override]
        """Get item implementation.

        Transform a single dataset item to multiple tiles using the Datumaro tiling
        plugin, and wrap the tiles into a single TileSegDataEntity.

        Args:
            index (int): Index of the dataset item.

        Returns:
            TileSegDataEntity: tile semantic segmentation data entity that wraps a
                list of semantic segmentation data entities.
        """
        item = self.dm_subset[index]
        img = item.media_as(Image)
        img_data, img_shape, _ = self._get_img_data_and_shape(img)

        extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
        masks = tv_tensors.Mask(extracted_mask[None])

        tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

        return TileSegDataEntity(
            num_tiles=len(tile_entities),
            entity_list=tile_entities,
            tile_attr_list=tile_attrs,
            ori_img_info=ImageInfo(
                img_idx=index,
                img_shape=img_shape,
                ori_shape=img_shape,
            ),
            ori_masks=masks,
        )

    def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataItem:  # type: ignore[override]
        """Convert a tile Datumaro dataset item to OTXDataItem."""
        x1, y1, w, h = dataset_item.attributes["roi"]
        tile_img = image[y1 : y1 + h, x1 : x1 + w]
        tile_shape = tile_img.shape[:2]
        img_info = ImageInfo(
            img_idx=parent_idx,
            img_shape=tile_shape,
            ori_shape=tile_shape,
        )
        return OTXDataItem(
            image=tile_img,
            img_info=img_info,
            masks=tv_tensors.Mask(np.zeros((0, *tile_shape), dtype=bool)),
        )
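
# End-to-end sketch (illustrative; the DataLoader wiring is an assumption about
# how callers typically consume these datasets, not code from this module):
#
#   tile_dataset = OTXTileDatasetFactory.create(task, dataset, tile_config)
#   loader = torch.utils.data.DataLoader(
#       tile_dataset,
#       batch_size=1,
#       collate_fn=tile_dataset.collate_fn,
#   )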