Source code for otx.data.dataset.classification

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for OTXClassificationDatasets."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from datumaro import Image, Label
from datumaro.components.annotation import AnnotationType
from torch.nn import functional
from torchvision.transforms.v2.functional import to_dtype, to_image

from otx.data.dataset.base import OTXDataset
from otx.data.entity.base import ImageInfo
from otx.data.entity.torch import OTXDataItem
from otx.types.label import HLabelInfo

if TYPE_CHECKING:
    from collections.abc import Callable


[docs] class OTXMulticlassClsDataset(OTXDataset): """OTXDataset class for multi-class classification task.""" def _get_item_impl(self, index: int) -> OTXDataItem | None: item = self.dm_subset[index] img = item.media_as(Image) roi = item.attributes.get("roi", None) img_data, img_shape, _ = self._get_img_data_and_shape(img, roi) image = to_dtype(to_image(img_data), dtype=torch.float32) if roi: # extract labels from ROI labels_ids = [ label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION" ] if self.data_format == "arrow": label_anns = [self.label_info.label_ids.index(label_id) for label_id in labels_ids] else: label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids] else: # extract labels from annotations label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)] if len(label_anns) > 1: msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}" raise ValueError(msg) entity = OTXDataItem( image=image, label=torch.as_tensor(label_anns, dtype=torch.long), img_info=ImageInfo( img_idx=index, img_shape=img_shape, ori_shape=img_shape, image_color_channel=self.image_color_channel, ), ) return self._apply_transforms(entity) @property def collate_fn(self) -> Callable: """Collection function to collect MulticlassClsDataEntity into MulticlassClsBatchDataEntity in data loader.""" return OTXDataItem.collate_fn
[docs] class OTXMultilabelClsDataset(OTXDataset): """OTXDataset class for multi-label classification task.""" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.num_classes = len(self.dm_subset.categories()[AnnotationType.label]) def _get_item_impl(self, index: int) -> OTXDataItem | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item img_data, img_shape, _ = self._get_img_data_and_shape(img) img_data = to_dtype(to_image(img_data), dtype=torch.float32) label_ids = set() for ann in item.annotations: # multilabel information stored in 'multi_label_ids' attribute when the source format is arrow if "multi_label_ids" in ann.attributes: for lbl_idx in ann.attributes["multi_label_ids"]: label_ids.add(lbl_idx) if isinstance(ann, Label): label_ids.add(ann.label) else: # If the annotation is not Label, it should be converted to Label. # For Chained Task: Detection (Bbox) -> Classification (Label) label = Label(label=ann.label) label_ids.add(label.label) labels = torch.as_tensor(list(label_ids)) entity = OTXDataItem( image=img_data, label=self._convert_to_onehot(labels, ignored_labels), img_info=ImageInfo( img_idx=index, img_shape=img_shape, ori_shape=img_shape, image_color_channel=self.image_color_channel, ignored_labels=ignored_labels, ), ) return self._apply_transforms(entity) def _convert_to_onehot(self, labels: torch.tensor, ignored_labels: list[int]) -> torch.tensor: """Convert label to one-hot vector format.""" # Torch's one_hot() expects the input to be of type long # However, when labels are empty, they are of type float32 onehot = functional.one_hot(labels.long(), self.num_classes).sum(0).clamp_max_(1) if ignored_labels: for ignore_label in ignored_labels: onehot[ignore_label] = -1 return onehot
[docs] class OTXHlabelClsDataset(OTXDataset): """OTXDataset class for H-label classification task.""" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.dm_categories = self.dm_subset.categories()[AnnotationType.label] # Hlabel classification used HLabelInfo to insert the HLabelData. if self.data_format == "arrow": # arrow format stores label IDs as names, have to deal with that here self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories) else: self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories) self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names)) self.id_to_name_mapping[""] = "" if self.label_info.num_multiclass_heads == 0: msg = "The number of multiclass heads should be larger than 0." raise ValueError(msg) if self.data_format != "arrow": for dm_item in self.dm_subset: self._add_ancestors(dm_item.annotations) def _add_ancestors(self, label_anns: list[Label]) -> None: """Add ancestors recursively if some label miss the ancestor information. If the label tree likes below, object - vehicle -- car |- bus |- truck And annotation = ['car'], it should be ['car', 'vehicle', 'object'], to include the ancestor. This function add the ancestors to the annotation if missing. """ def _label_idx_to_name(idx: int) -> str: return self.dm_categories[idx].name def _label_name_to_idx(name: str) -> int: indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name] return indices[0] def _get_label_group_idx(label_name: str) -> int: if isinstance(self.label_info, HLabelInfo): if self.data_format == "arrow": return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0] return self.label_info.class_to_group_idx[label_name][0] msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}" raise ValueError(msg) def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]: _, dm_label_category = self.dm_categories.find(label_name) parent_name = dm_label_category.parent if dm_label_category else "" if parent_name != "": ancestors.append(parent_name) _find_ancestor_recursively(parent_name, ancestors) return ancestors def _get_all_label_names_in_anns(anns: list[Label]) -> list[str]: return [_label_idx_to_name(ann.label) for ann in anns] all_label_names = _get_all_label_names_in_anns(label_anns) ancestor_dm_labels = [] for ann in label_anns: label_idx = ann.label label_name = _label_idx_to_name(label_idx) ancestors = _find_ancestor_recursively(label_name, []) for i, ancestor in enumerate(ancestors): if ancestor not in all_label_names: ancestor_dm_labels.append( Label( label=_label_name_to_idx(ancestor), id=len(label_anns) + i, group=_get_label_group_idx(ancestor), ), ) label_anns.extend(ancestor_dm_labels) def _get_item_impl(self, index: int) -> OTXDataItem | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item img_data, img_shape, _ = self._get_img_data_and_shape(img) img_data = to_dtype(to_image(img_data), dtype=torch.float32) label_ids = set() for ann in item.annotations: # in h-cls scenario multilabel information stored in 'multi_label_ids' attribute if "multi_label_ids" in ann.attributes: for lbl_idx in ann.attributes["multi_label_ids"]: label_ids.add(lbl_idx) if isinstance(ann, Label): label_ids.add(ann.label) else: # If the annotation is not Label, it should be converted to Label. # For Chained Task: Detection (Bbox) -> Classification (Label) label = Label(label=ann.label) label_ids.add(label.label) hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels) entity = OTXDataItem( image=img_data, label=torch.as_tensor(hlabel_labels), img_info=ImageInfo( img_idx=index, img_shape=img_shape, ori_shape=img_shape, image_color_channel=self.image_color_channel, ignored_labels=ignored_labels, ), ) return self._apply_transforms(entity) def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_labels: list[int]) -> list[int]: """Convert format of the label to the h-label. It converts the label format to h-label format. Total length of result is sum of number of hierarchy and number of multilabel classes. i.e. Let's assume that we used the same dataset with example of the definition of HLabelData and the original labels are ["Rigid", "Triangle", "Lion"]. Then, h-label format will be [0, 1, 1, 0]. The first N-th indices represent the label index of multiclass heads (N=num_multiclass_heads), others represent the multilabel labels. [Multiclass Heads] 0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head 1-st index = 1 -> ["Rectangle"(O), "Triangle"(X), "Circle"(X)] <- Second multiclass head [Multilabel Head] 2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)] """ if not isinstance(self.label_info, HLabelInfo): msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}." raise TypeError(msg) num_multiclass_heads = self.label_info.num_multiclass_heads num_multilabel_classes = self.label_info.num_multilabel_classes class_indices = [0] * (num_multiclass_heads + num_multilabel_classes) for i in range(num_multiclass_heads): class_indices[i] = -1 for ann in label_anns: if self.data_format == "arrow": # skips unknown labels for instance, the empty one if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping: continue ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name] else: ann_name = self.dm_categories.items[ann.label].name group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name] if group_idx < num_multiclass_heads: class_indices[group_idx] = in_group_idx elif ann.label not in ignored_labels: class_indices[num_multiclass_heads + in_group_idx] = 1 else: class_indices[num_multiclass_heads + in_group_idx] = -1 return class_indices