Source code for otx.data.dataset.anomaly

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Anomaly Classification Dataset."""

from __future__ import annotations

from enum import Enum
from pathlib import Path

import cv2
import numpy as np
import torch
from datumaro import Dataset as DmDataset
from datumaro import DatasetItem, Image
from datumaro.components.annotation import AnnotationType, Bbox, Ellipse, Polygon
from datumaro.components.media import ImageFromBytes, ImageFromFile
from torchvision import io
from torchvision.transforms.v2.functional import to_dtype, to_image
from torchvision.tv_tensors import Mask

from otx.data.dataset.base import OTXDataset, Transforms
from otx.data.entity.base import ImageInfo
from otx.data.entity.torch import OTXDataItem
from otx.types.image import ImageColorChannel
from otx.types.label import AnomalyLabelInfo
from otx.types.task import OTXTaskType


class AnomalyLabel(Enum):
    """Anomaly label to tensor mapping."""

    NORMAL = torch.tensor(0.0)
    ANOMALOUS = torch.tensor(1.0)


[docs] class OTXAnomalyDataset(OTXDataset): """OTXDataset class for anomaly classification task.""" def __init__( self, task_type: OTXTaskType, dm_subset: DmDataset, transforms: Transforms, max_refetch: int = 1000, image_color_channel: ImageColorChannel = ImageColorChannel.RGB, stack_images: bool = True, to_tv_image: bool = True, data_format: str = "", ) -> None: self.task_type = task_type super().__init__( dm_subset, transforms, max_refetch, image_color_channel, stack_images, to_tv_image, ) self.label_info = AnomalyLabelInfo() self._label_mapping = self._map_id_to_label() def _get_item_impl( self, index: int, ) -> OTXDataItem: datumaro_item = self.dm_subset[index] img = datumaro_item.media_as(Image) # returns image in RGB format if self.image_color_channel is RGB img_data, img_shape, _ = self._get_img_data_and_shape(img) image = to_dtype(to_image(img_data), dtype=torch.float32) label = self._get_label(datumaro_item) item = OTXDataItem( image=image, img_info=ImageInfo( img_idx=index, img_shape=img_shape, ori_shape=img_shape, image_color_channel=self.image_color_channel, ), label=torch.tensor(label, dtype=torch.long), masks=Mask(self._get_mask(datumaro_item, label, img_shape)), ) return self._apply_transforms(item) # type: ignore[return-value] def _get_mask(self, datumaro_item: DatasetItem, label: torch.Tensor, img_shape: tuple[int, int]) -> torch.Tensor: """Get mask from datumaro_item. Converts bounding boxes to mask if mask is not available. """ if isinstance(datumaro_item.media, ImageFromFile): if label == AnomalyLabel.ANOMALOUS.value: mask = self._mask_image_from_file(datumaro_item, img_shape) else: mask = torch.zeros(1, *img_shape).to(torch.uint8) elif isinstance(datumaro_item.media, ImageFromBytes): mask = torch.zeros(1, *img_shape).to(torch.uint8) if label == AnomalyLabel.ANOMALOUS.value: for annotation in datumaro_item.annotations: # There is only one mask if isinstance(annotation, (Ellipse, Polygon)): polygons = np.asarray(annotation.as_polygon(), dtype=np.int32).reshape((-1, 1, 2)) mask = np.zeros(img_shape, dtype=np.uint8) mask = cv2.drawContours( mask, [polygons], 0, (1, 1, 1), thickness=cv2.FILLED, ) mask = torch.from_numpy(mask).to(torch.uint8).unsqueeze(0) break # If there is no mask, create a mask from bbox if isinstance(annotation, Bbox): bbox = annotation mask = self._bbox_to_mask(bbox, img_shape) break return mask def _bbox_to_mask(self, bbox: Bbox, img_shape: tuple[int, int]) -> torch.Tensor: mask = torch.zeros(1, *img_shape).to(torch.uint8) x1, y1, x2, y2 = bbox.get_bbox() x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) mask[:, y1:y2, x1:x2] = 1 return mask def _get_label(self, datumaro_item: DatasetItem) -> torch.LongTensor: """Get label from datumaro item.""" if isinstance(datumaro_item.media, ImageFromFile): # Note: This assumes that the dataset is in MVTec format. # We can't use datumaro label id as it returns some number like 3 for good from which it is hard to infer # whether the image is Anomalous or Normal. Because it leads to other questions like what do numbers 0,1,2 # mean? label: torch.LongTensor = AnomalyLabel.NORMAL if "good" in datumaro_item.id else AnomalyLabel.ANOMALOUS elif isinstance(datumaro_item.media, ImageFromBytes): label = self._label_mapping[datumaro_item.annotations[0].label] else: msg = f"Media type {type(datumaro_item.media)} is not supported." raise NotImplementedError(msg) return label.value def _map_id_to_label(self) -> dict[int, torch.Tensor]: """Map label id to label tensor.""" id_label_mapping = {} categories = self.dm_subset.categories()[AnnotationType.label] for label_item in categories.items: if any("normal" in attribute.lower() for attribute in label_item.attributes): label = AnomalyLabel.NORMAL else: label = AnomalyLabel.ANOMALOUS id_label_mapping[categories.find(label_item.name)[0]] = label return id_label_mapping def _mask_image_from_file(self, datumaro_item: DatasetItem, img_shape: tuple[int, int]) -> torch.Tensor: """Assumes MVTec format and returns mask from disk.""" mask_file_path = ( Path("/".join(datumaro_item.media.path.split("/")[:-3])) / "ground_truth" / f"{('/'.join(datumaro_item.media.path.split('/')[-2:])).replace('.png','_mask.png')}" ) if mask_file_path.exists(): return (io.read_image(str(mask_file_path), mode=io.ImageReadMode.GRAY) / 255).to(torch.uint8) # Note: This is a workaround to handle the case where mask is not available otherwise the tests fail. # This is problematic because it assigns empty masks to an Anomalous image. return torch.zeros(1, *img_shape).to(torch.uint8)