# Source code for otx.data.dataset.keypoint_detection

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for OTXKeypointDetectionDataset."""

from __future__ import annotations

from collections import defaultdict
from typing import Callable, List, Union

import numpy as np
import torch
from datumaro import AnnotationType, Bbox, Dataset, DatasetSubset, Image, Points
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype, to_image

from otx.data.entity.base import ImageInfo
from otx.data.entity.torch import OTXDataItem
from otx.data.transform_libs.torchvision import Compose
from otx.types.image import ImageColorChannel
from otx.types.label import LabelInfo

from .base import OTXDataset

# Accepted forms of the ``transforms`` constructor argument: a single
# ``Compose``, a single callable, a list of callables, or a dict that maps
# string keys to any of those three forms.
Transforms = Union[
    Compose,
    Callable,
    List[Callable],
    dict[str, Union[Compose, Callable, List[Callable]]],
]


class OTXKeypointDetectionDataset(OTXDataset):
    """OTXDataset class for keypoint detection task.

    On construction the incoming datumaro subset is re-split so that every
    dataset item holds the ``Bbox``/``Points`` annotations of exactly one
    annotation id (see ``_get_single_bbox_dataset``). Items are then served
    as ``OTXDataItem`` objects carrying the image, the boxes, the box labels
    and a keypoint tensor with one ``(x, y, visible)`` row per keypoint.
    """

    def __init__(
        self,
        dm_subset: DatasetSubset,
        transforms: Transforms,
        max_refetch: int = 1000,
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        stack_images: bool = True,
        to_tv_image: bool = True,
        data_format: str = "",
    ) -> None:
        """Initialize the dataset and derive keypoint label information.

        All arguments are forwarded unchanged, in this order, to
        ``OTXDataset.__init__``.

        Args:
            dm_subset: Datumaro subset to read items from.
            transforms: Transforms applied to every produced data item.
            max_refetch: Forwarded to the base class.
            image_color_channel: Color channel order stored in each item's ``ImageInfo``.
            stack_images: Forwarded to the base class.
            to_tv_image: Forwarded to the base class.
            data_format: Source data format. When it equals ``"arrow"`` the
                keypoint-specific label lookup below is skipped.

        Raises:
            ValueError: If no item with usable keypoints is found in ``dm_subset``.
        """
        super().__init__(
            dm_subset,
            transforms,
            max_refetch,
            image_color_channel,
            stack_images,
            to_tv_image,
            data_format,
        )

        # Replace the subset stored by the base class with the per-annotation-id split.
        self.dm_subset = self._get_single_bbox_dataset(dm_subset)

        # arrow doesn't follow common coco convention, no need to fetch kp-specific labels
        if self.dm_subset.categories() and data_format != "arrow":
            kp_labels = self.dm_subset.categories()[AnnotationType.points][0].labels
            self.label_info = LabelInfo(
                label_names=kp_labels,
                label_groups=[],
                label_ids=[str(i) for i in range(len(kp_labels))],
            )

    def _get_single_bbox_dataset(self, dm_subset: DatasetSubset) -> Dataset:
        """Split dataset items into one item per bbox/keypoint annotation id.

        ``Bbox`` and ``Points`` annotations of every item are grouped by their
        ``id``. A group becomes a new item (named ``"<item id>_<annotation id>"``)
        only when it contains at least one ``Points`` annotation whose largest
        coordinate is positive.

        Args:
            dm_subset: Datumaro subset to split.

        Returns:
            A datumaro ``Dataset`` built from the new items, using the
            categories of ``self.dm_subset``.

        Raises:
            ValueError: If no group with usable keypoints exists.
        """
        dm_items = []
        for item in dm_subset:
            # Group the relevant annotations of this item by annotation id.
            anns_by_id = defaultdict(list)
            for ann in item.annotations:
                if isinstance(ann, (Bbox, Points)):
                    anns_by_id[ann.id].append(ann)

            for ann_id, anns in anns_by_id.items():
                available_types = []
                for ann in anns:
                    # Degenerate boxes do not count as an available bbox.
                    if isinstance(ann, Bbox) and (ann.w <= 0 or ann.h <= 0):
                        continue
                    # ``default=0`` makes an empty coordinate list count as
                    # "no valid keypoints" instead of raising ValueError.
                    if isinstance(ann, Points) and max(ann.points, default=0) <= 0:
                        continue
                    available_types.append(ann.type)

                if AnnotationType.points not in available_types:
                    continue

                # NOTE: the whole group is kept, including annotations that
                # failed the validity checks above.
                dm_items.append(item.wrap(id=f"{item.id}_{ann_id}", annotations=anns))

        if not dm_items:
            msg = "No keypoints found in the dataset. Please, check dataset annotations."
            raise ValueError(msg)

        return Dataset.from_iterable(dm_items, categories=self.dm_subset.categories())

    def _get_item_impl(self, index: int) -> OTXDataItem | None:
        """Build the data item at ``index`` and run the transforms on it.

        Args:
            index: Position of the item in ``self.dm_subset``.

        Returns:
            Whatever ``self._apply_transforms`` returns for the assembled
            ``OTXDataItem``.
        """
        item = self.dm_subset[index]
        img = item.media_as(Image)
        ignored_labels: list[int] = []  # This should be assigned from item
        img_data, img_shape, _ = self._get_img_data_and_shape(img)

        bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
        bboxes = (
            np.stack([ann.points for ann in bbox_anns], axis=0).astype(np.float32)
            if bbox_anns
            else np.zeros((0, 4), dtype=np.float32)
        )

        # Keypoint coordinates are flattened to one (x, y) row per keypoint,
        # i.e. shape [N * K, 2]; the visibility flags are flattened to [N * K].
        keypoint_anns = [ann for ann in item.annotations if isinstance(ann, Points)]
        keypoints = (
            np.stack([ann.points for ann in keypoint_anns], axis=0).astype(np.float32)
            if keypoint_anns
            else np.zeros((0, len(self.label_info.label_names) * 2), dtype=np.float32)
        ).reshape(-1, 2)

        # Visibility values greater than 1 map to 1, everything else to 0
        # (presumably 2 is datumaro's "visible" state - confirm against
        # Points.Visibility). Without a ``visibility`` attribute the flag is
        # derived from the x coordinate clipped to at most 1.
        keypoints_visible = (
            (np.array([ann.visibility for ann in keypoint_anns]) > 1).reshape(-1).astype(np.int8)
            if keypoint_anns and hasattr(keypoint_anns[0], "visibility")
            else np.minimum(1, keypoints)[..., 0]
        )

        # Final layout: [N * K, 3] with columns (x, y, visible).
        keypoints = np.hstack((keypoints, keypoints_visible.reshape(-1, 1)))

        entity = OTXDataItem(
            image=to_dtype(to_image(img_data), torch.float32),
            img_info=ImageInfo(
                img_idx=index,
                img_shape=img_shape,
                ori_shape=img_shape,
                image_color_channel=self.image_color_channel,
                ignored_labels=ignored_labels,
            ),
            bboxes=tv_tensors.BoundingBoxes(
                bboxes,
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=img_shape,
            ),
            label=torch.as_tensor([ann.label for ann in bbox_anns], dtype=torch.long),
            keypoints=torch.as_tensor(keypoints, dtype=torch.float32),
        )

        return self._apply_transforms(entity)  # type: ignore[return-value]