# Source code for otx.data.dataset.detection
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Module for OTXDetectionDataset."""
from __future__ import annotations
import numpy as np
import torch
from datumaro import Bbox, Image
from torchvision import tv_tensors
from otx.data.entity.base import ImageInfo
from otx.data.entity.torch import OTXDataItem
from .base import OTXDataset
from .mixins import DataAugSwitchMixin
# [docs]
class OTXDetectionDataset(OTXDataset, DataAugSwitchMixin):  # type: ignore[misc]
    """Detection-task dataset that assembles one ``OTXDataItem`` per sample.

    Only ``Bbox`` annotations of a sample are used; every other annotation
    type found on the item is skipped.
    """

    def _get_item_impl(self, index: int) -> OTXDataItem | None:
        """Assemble the data item for ``index`` and pass it through the transforms.

        Args:
            index: Position of the sample inside ``self.dm_subset``.

        Returns:
            Whatever ``self._apply_transforms`` returns for the assembled item.
        """
        dm_item = self.dm_subset[index]
        media = dm_item.media_as(Image)
        # TODO: ignored labels are not read yet; they should be taken from the item.
        ignored_labels: list[int] = []
        img_data, img_shape, _ = self._get_img_data_and_shape(media)

        box_annotations = [ann for ann in dm_item.annotations if isinstance(ann, Bbox)]

        if box_annotations:
            box_array = np.stack(
                [ann.points for ann in box_annotations],
                axis=0,
            ).astype(np.float32)
        else:
            # Keep an explicit (0, 4) array so samples without boxes still
            # produce a well-formed, empty set of bounding boxes.
            box_array = np.zeros((0, 4), dtype=np.float32)

        # The same shape is recorded as both the current and the original shape.
        img_info = ImageInfo(
            img_idx=index,
            img_shape=img_shape,
            ori_shape=img_shape,
            image_color_channel=self.image_color_channel,
            ignored_labels=ignored_labels,
        )
        # The annotation points are handed over as XYXY boxes on the image canvas.
        bboxes = tv_tensors.BoundingBoxes(
            box_array,
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=img_shape,
            dtype=torch.float32,
        )
        label = torch.as_tensor(
            [ann.label for ann in box_annotations],
            dtype=torch.long,
        )

        entity = OTXDataItem(
            image=img_data,
            img_info=img_info,
            bboxes=bboxes,
            label=label,
        )

        # Let the mixin switch augmentations before the transforms run, when enabled.
        if self.has_dynamic_augmentation:
            self._apply_augmentation_switch()

        return self._apply_transforms(entity)