# Source code for otx.data.entity.tile

# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for OTX tile data entities."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence

import torch
from torchvision import tv_tensors

from otx.data.entity.torch import OTXDataBatch, OTXDataItem
from otx.data.entity.utils import stack_batch
from otx.types.task import OTXTaskType

from .base import ImageInfo

if TYPE_CHECKING:
    from datumaro import Polygon
    from torch import LongTensor


@dataclass
class TileDataEntity:
    """Base data entity for tile task.

    Groups the tile-level data items produced from one original image together with
    the per-tile attributes and the original image information.

    Attributes:
        num_tiles (int): The number of tiles.
        entity_list (Sequence[OTXDataItem]): The per-tile data items.
        tile_attr_list (list[dict[str, int | str]]): The tile attributes including tile index and tile RoI information.
        ori_img_info (ImageInfo): The image information about the original image.
    """

    num_tiles: int
    entity_list: Sequence[OTXDataItem]
    tile_attr_list: list[dict[str, int | str]]
    ori_img_info: ImageInfo

    @property
    def task(self) -> OTXTaskType:
        """OTX Task type definition.

        Task-specific subclasses override this property.

        Raises:
            NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError


@dataclass
class TileDetDataEntity(TileDataEntity):
    """Tile data entity for the detection task.

    Adds the bounding boxes and labels of the original (un-tiled) image to the
    base tile entity.

    Attributes:
        ori_bboxes (tv_tensors.BoundingBoxes): Bounding boxes of the original image.
        ori_labels (LongTensor): Labels of the original image.
    """

    ori_bboxes: tv_tensors.BoundingBoxes
    ori_labels: LongTensor

    @property
    def task(self) -> OTXTaskType:
        """Return the OTX task type handled by this entity (detection)."""
        task_type = OTXTaskType.DETECTION
        return task_type
TileAttrDictList = list[dict[str, int | str]] @dataclass class OTXTileBatchDataEntity: """Base batch data entity for tile task. Attributes: batch_size (int): The size of the batch. batch_tiles (list[list[tv_tensors.Image]]): The batch of tile images. batch_tile_img_infos (list[list[ImageInfo]]): The batch of tiles image information. batch_tile_attr_list (list[list[dict[str, int | str]]]): The batch of tile attributes including tile index and tile RoI information. imgs_info (list[ImageInfo]): The image information about the original image. """ batch_size: int batch_tiles: list[list[tv_tensors.Image]] batch_tile_img_infos: list[list[ImageInfo]] batch_tile_attr_list: list[TileAttrDictList] imgs_info: list[ImageInfo] def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]: """Unbind batch data entity.""" raise NotImplementedError
@dataclass
class TileBatchDetDataEntity(OTXTileBatchDataEntity):
    """Batch data entity for detection tile task.

    Attributes:
        bboxes (list[tv_tensors.BoundingBoxes]): The bounding boxes of each original image.
        labels (list[LongTensor]): The labels of each original image.
    """

    bboxes: list[tv_tensors.BoundingBoxes]
    labels: list[LongTensor]

    def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]:
        """Unbind batch data entity for detection task.

        The tiles of all original images are flattened into one sequence and re-chunked
        into groups of at most ``self.batch_size`` tiles. A chunk may therefore mix tiles
        from different original images. Each chunk's images are stacked with ``stack_batch``.

        Returns:
            list[tuple[TileAttrDictList, OTXDataBatch]]: One ``(tile attributes, tile batch)``
            pair per chunk, in tile order.

        Raises:
            ValueError: If the attribute chunks and the tile chunks differ in count
                (raised by ``zip(..., strict=True)``).
        """
        tiles = [tile for per_image in self.batch_tiles for tile in per_image]
        tile_infos = [info for per_image in self.batch_tile_img_infos for info in per_image]
        tile_attrs = [attr for per_image in self.batch_tile_attr_list for attr in per_image]

        step = self.batch_size
        batch_tile_attr_list = [tile_attrs[i : i + step] for i in range(0, len(tile_attrs), step)]

        batch_data_entities = []
        for start in range(0, len(tiles), step):
            chunk_tiles = tiles[start : start + step]
            stacked_images, updated_img_info = stack_batch(chunk_tiles, tile_infos[start : start + step])
            batch_data_entities.append(
                OTXDataBatch(
                    # The last chunk may hold fewer than ``step`` tiles: report its real size.
                    batch_size=len(chunk_tiles),
                    images=stacked_images,
                    imgs_info=updated_img_info,
                ),
            )
        return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

    @classmethod
    def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetDataEntity:
        """Collate function to collect TileDetDataEntity into TileBatchDetDataEntity in data loader.

        Args:
            batch_entities (list[TileDetDataEntity]): One tile entity per original image.

        Returns:
            TileBatchDetDataEntity: The collated batch, with ``batch_size`` equal to
            ``len(batch_entities)``.

        Raises:
            RuntimeError: If ``batch_entities`` is empty.
            TypeError: If any tile item is not an ``OTXDataItem``.
            ValueError: If any tile item has ``img_info`` set to ``None``.
        """
        if (batch_size := len(batch_entities)) == 0:
            msg = "collate_fn() input should have > 0 entities"
            raise RuntimeError(msg)

        for tile_entity in batch_entities:
            for entity in tile_entity.entity_list:
                if not isinstance(entity, OTXDataItem):
                    msg = "All entities should be OTXDataItem before collate_fn()"
                    raise TypeError(msg)
                if entity.img_info is None:
                    msg = "All entities should have img_info, but found None"
                    raise ValueError(msg)

        return cls(
            batch_size=batch_size,
            batch_tiles=[[entity.image for entity in tile_entity.entity_list] for tile_entity in batch_entities],
            batch_tile_img_infos=[
                # isinstance() narrows the type; None was already rejected above.
                [entity.img_info for entity in tile_entity.entity_list if isinstance(entity.img_info, ImageInfo)]
                for tile_entity in batch_entities
            ],
            batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities],
            imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities],
            bboxes=[tile_entity.ori_bboxes for tile_entity in batch_entities],
            labels=[tile_entity.ori_labels for tile_entity in batch_entities],
        )
@dataclass
class TileInstSegDataEntity(TileDataEntity):
    """Tile data entity for the instance segmentation task.

    Adds the bounding boxes, labels, masks and polygons of the original
    (un-tiled) image to the base tile entity.

    Attributes:
        ori_bboxes (tv_tensors.BoundingBoxes): Bounding boxes of the original image.
        ori_labels (LongTensor): Labels of the original image.
        ori_masks (tv_tensors.Mask): Masks of the original image.
        ori_polygons (list[Polygon]): Polygons of the original image.
    """

    ori_bboxes: tv_tensors.BoundingBoxes
    ori_labels: LongTensor
    ori_masks: tv_tensors.Mask
    ori_polygons: list[Polygon]

    @property
    def task(self) -> OTXTaskType:
        """Return the OTX task type handled by this entity (instance segmentation)."""
        task_type = OTXTaskType.INSTANCE_SEGMENTATION
        return task_type
@dataclass
class TileBatchInstSegDataEntity(OTXTileBatchDataEntity):
    """Batch data entity for instance segmentation tile task.

    Attributes:
        bboxes (list[tv_tensors.BoundingBoxes]): The bounding boxes of each original image.
        labels (list[LongTensor]): The labels of each original image.
        masks (list[tv_tensors.Mask]): The masks of each original image.
        polygons (list[list[Polygon]]): The polygons of each original image.
    """

    bboxes: list[tv_tensors.BoundingBoxes]
    labels: list[LongTensor]
    masks: list[tv_tensors.Mask]
    polygons: list[list[Polygon]]

    def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]:
        """Unbind batch data entity for instance segmentation task.

        The tiles of all original images are flattened into one sequence and re-chunked
        into groups of at most ``self.batch_size`` tiles. A chunk may therefore mix tiles
        from different original images. Each chunk's images are passed on as a list and
        are not stacked.

        Returns:
            list[tuple[TileAttrDictList, OTXDataBatch]]: One ``(tile attributes, tile batch)``
            pair per chunk, in tile order.

        Raises:
            ValueError: If the attribute chunks and the tile chunks differ in count
                (raised by ``zip(..., strict=True)``).
        """
        tiles = [tile for per_image in self.batch_tiles for tile in per_image]
        tile_infos = [info for per_image in self.batch_tile_img_infos for info in per_image]
        tile_attrs = [attr for per_image in self.batch_tile_attr_list for attr in per_image]

        step = self.batch_size
        batch_tile_attr_list = [tile_attrs[i : i + step] for i in range(0, len(tile_attrs), step)]

        batch_data_entities = []
        for start in range(0, len(tiles), step):
            chunk_tiles = tiles[start : start + step]
            batch_data_entities.append(
                OTXDataBatch(
                    # The last chunk may hold fewer than ``step`` tiles: report its real size.
                    batch_size=len(chunk_tiles),
                    images=chunk_tiles,
                    imgs_info=tile_infos[start : start + step],
                ),
            )
        return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

    @classmethod
    def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchInstSegDataEntity:
        """Collate function to collect TileInstSegDataEntity into TileBatchInstSegDataEntity in data loader.

        Args:
            batch_entities (list[TileInstSegDataEntity]): One tile entity per original image.

        Returns:
            TileBatchInstSegDataEntity: The collated batch, with ``batch_size`` equal to
            ``len(batch_entities)``.

        Raises:
            RuntimeError: If ``batch_entities`` is empty.
            TypeError: If any tile item is not an ``OTXDataItem``.
            ValueError: If any tile item has ``img_info`` set to ``None``.
        """
        if (batch_size := len(batch_entities)) == 0:
            msg = "collate_fn() input should have > 0 entities"
            raise RuntimeError(msg)

        for tile_entity in batch_entities:
            for entity in tile_entity.entity_list:
                if not isinstance(entity, OTXDataItem):
                    msg = "All entities should be OTXDataItem before collate_fn()"
                    raise TypeError(msg)
                if entity.img_info is None:
                    msg = "All entities should have img_info, but found None"
                    raise ValueError(msg)

        return cls(
            batch_size=batch_size,
            batch_tiles=[[entity.image for entity in tile_entity.entity_list] for tile_entity in batch_entities],
            batch_tile_img_infos=[
                # isinstance() narrows the type; None was already rejected above.
                [entity.img_info for entity in tile_entity.entity_list if isinstance(entity.img_info, ImageInfo)]
                for tile_entity in batch_entities
            ],
            batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities],
            imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities],
            bboxes=[tile_entity.ori_bboxes for tile_entity in batch_entities],
            labels=[tile_entity.ori_labels for tile_entity in batch_entities],
            masks=[tile_entity.ori_masks for tile_entity in batch_entities],
            polygons=[tile_entity.ori_polygons for tile_entity in batch_entities],
        )
@dataclass
class TileSegDataEntity(TileDataEntity):
    """Tile data entity for the semantic segmentation task.

    Adds the masks of the original (un-tiled) image to the base tile entity.

    Attributes:
        ori_masks (tv_tensors.Mask): Masks of the original image.
    """

    ori_masks: tv_tensors.Mask

    @property
    def task(self) -> OTXTaskType:
        """Return the OTX task type handled by this entity (semantic segmentation)."""
        task_type = OTXTaskType.SEMANTIC_SEGMENTATION
        return task_type
@dataclass
class TileBatchSegDataEntity(OTXTileBatchDataEntity):
    """Batch data entity for semantic segmentation tile task.

    Attributes:
        masks (list[tv_tensors.Mask]): The masks of each original image.
    """

    masks: list[tv_tensors.Mask]

    def unbind(self) -> list[tuple[TileAttrDictList, OTXDataBatch]]:
        """Unbind batch data entity for semantic segmentation task.

        The tiles of all original images are flattened into one sequence and re-chunked
        into groups of at most ``self.batch_size`` tiles. A chunk may therefore mix tiles
        from different original images. Each chunk's images are stacked with ``torch.stack``
        and re-wrapped as the tensor type of the first tile.

        Returns:
            list[tuple[TileAttrDictList, OTXDataBatch]]: One ``(tile attributes, tile batch)``
            pair per chunk, in tile order.

        Raises:
            ValueError: If the attribute chunks and the tile chunks differ in count
                (raised by ``zip(..., strict=True)``).
        """
        tiles = [tile for per_image in self.batch_tiles for tile in per_image]
        tile_infos = [info for per_image in self.batch_tile_img_infos for info in per_image]
        tile_attrs = [attr for per_image in self.batch_tile_attr_list for attr in per_image]

        step = self.batch_size
        batch_tile_attr_list = [tile_attrs[i : i + step] for i in range(0, len(tile_attrs), step)]

        batch_data_entities = []
        for start in range(0, len(tiles), step):
            chunk_tiles = tiles[start : start + step]
            batch_data_entities.append(
                OTXDataBatch(
                    # The last chunk may hold fewer than ``step`` tiles: report its real size.
                    batch_size=len(chunk_tiles),
                    images=tv_tensors.wrap(torch.stack(chunk_tiles), like=tiles[0]),
                    imgs_info=tile_infos[start : start + step],
                    # One dummy (1, 1, 1) mask per tile actually present in this chunk.
                    masks=[torch.empty((1, 1, 1)) for _ in chunk_tiles],
                ),
            )
        # strict=True matches the detection/instance-segmentation variants: a count
        # mismatch raises instead of silently dropping trailing tiles.
        return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

    @classmethod
    def collate_fn(cls, batch_entities: list[TileSegDataEntity]) -> TileBatchSegDataEntity:
        """Collate function to collect TileSegDataEntity into TileBatchSegDataEntity in data loader.

        Args:
            batch_entities (list[TileSegDataEntity]): One tile entity per original image.

        Returns:
            TileBatchSegDataEntity: The collated batch, with ``batch_size`` equal to
            ``len(batch_entities)``.

        Raises:
            RuntimeError: If ``batch_entities`` is empty.
            TypeError: If any tile item is not an ``OTXDataItem``.
            ValueError: If any tile item has ``img_info`` set to ``None``.
        """
        if (batch_size := len(batch_entities)) == 0:
            msg = "collate_fn() input should have > 0 entities"
            raise RuntimeError(msg)

        for tile_entity in batch_entities:
            for entity in tile_entity.entity_list:
                if not isinstance(entity, OTXDataItem):
                    msg = "All entities should be OTXDataItem before collate_fn()"
                    raise TypeError(msg)
                if entity.img_info is None:
                    msg = "All entities should have img_info, but found None"
                    raise ValueError(msg)

        return cls(
            batch_size=batch_size,
            batch_tiles=[[entity.image for entity in tile_entity.entity_list] for tile_entity in batch_entities],
            batch_tile_img_infos=[
                # isinstance() narrows the type; None was already rejected above.
                [entity.img_info for entity in tile_entity.entity_list if isinstance(entity.img_info, ImageInfo)]
                for tile_entity in batch_entities
            ],
            batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities],
            imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities],
            masks=[tile_entity.ori_masks for tile_entity in batch_entities],
        )