Source code for otx.data.torch.torch

"""Torch-specific data item implementations."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import asdict, dataclass, fields
from typing import TYPE_CHECKING, Any, Sequence

import torch
import torchvision.transforms.v2.functional as F  # noqa: N812
from torchvision import tv_tensors

from otx.core.data.entity.utils import register_pytree_node

from .validations import (
    ValidateBatchMixin,
    ValidateItemMixin,
)

if TYPE_CHECKING:
    import numpy as np
    from datumaro import Polygon
    from torchvision.tv_tensors import BoundingBoxes, Mask

    from otx.core.data.entity.base import ImageInfo


# NOTE: register_pytree_node and Mapping are required for torchvision.transforms.v2 to work with OTXDataEntity
# TODO(ashwinvaidya17): Remove this once custom transforms are removed
@register_pytree_node
@dataclass
class OTXDataItem(ValidateItemMixin, Mapping):
    """OTX data item implementation.

    Besides being a dataclass, an instance acts as a read-only mapping from
    field name to field value: iterating yields the field names, indexing by
    a field name returns that attribute, and ``len()`` is the number of fields.

    Attributes:
        image (torch.Tensor | np.ndarray): The image tensor
        label (torch.Tensor | None): The label tensor, optional.
        masks (Mask | None): The masks, optional.
        bboxes (BoundingBoxes | None): The bounding boxes, optional.
        keypoints (torch.Tensor | None): The keypoints, optional.
        polygons (list[Polygon] | None): The polygons, optional.
        img_info (ImageInfo | None): Additional image information, optional.
    """

    image: torch.Tensor | np.ndarray
    label: torch.Tensor | None = None
    masks: Mask | None = None
    bboxes: BoundingBoxes | None = None
    keypoints: torch.Tensor | None = None
    polygons: list[Polygon] | None = None
    img_info: ImageInfo | None = None  # TODO(ashwinvaidya17): revisit and try to remove this

    @staticmethod
    def collate_fn(items: list[OTXDataItem]) -> OTXDataBatch:
        """Collate TorchDataItems into a batch.

        Images are stacked into a single tensor when every item has the same
        image shape; otherwise they are kept as a plain list. All other
        fields are gathered into per-item lists without stacking.

        Args:
            items: List of TorchDataItems to batch

        Returns:
            Batched TorchDataItems with stacked tensors
        """
        # Check if all images have the same size. TODO(kprokofi): remove this check once OV IR models are moved.
        if all(item.image.shape == items[0].image.shape for item in items):
            images = torch.stack([item.image for item in items])
        else:
            # we need this only in case of OV inference, where no resize
            images = [item.image for item in items]

        return OTXDataBatch(
            batch_size=len(items),
            images=images,
            labels=[item.label for item in items],
            bboxes=[item.bboxes for item in items],
            keypoints=[item.keypoints for item in items],
            masks=[item.masks for item in items],
            polygons=[item.polygons for item in items],  # type: ignore[misc]
            imgs_info=[item.img_info for item in items],
        )

    def __iter__(self) -> Iterator[str]:
        """Yield the dataclass field names (the mapping keys)."""
        for field_ in fields(self):
            yield field_.name

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Return the attribute named ``key``."""
        return getattr(self, key)

    def __len__(self) -> int:
        """Return the number of dataclass fields."""
        return len(fields(self))

    def to_tv_image(self) -> OTXDataItem:
        """Return an instance whose `image` attribute is a TorchVision Image.

        Returns:
            This instance as is when `image` is already a ``tv_tensors.Image``.
            Otherwise a new instance whose `image` has been converted with
            ``torchvision.transforms.v2.functional.to_image``.
        """
        if isinstance(self.image, tv_tensors.Image):
            return self

        return self.wrap(image=F.to_image(self.image))

    def wrap(self, **kwargs) -> OTXDataItem:
        """Wrap this dataclass with the given keyword arguments.

        Fields that are not overridden are passed to the new instance by
        reference; they are not copied.

        Args:
            **kwargs: Keyword arguments to be overwritten on top of this dataclass

        Returns:
            Updated dataclass
        """
        # Gather the current field values by reference. ``dataclasses.asdict``
        # would recursively deep-copy every value (even the ones overwritten
        # just below) and turn nested dataclass values into plain dicts.
        updated_kwargs = {field_.name: getattr(self, field_.name) for field_ in fields(self)}
        updated_kwargs.update(kwargs)
        return self.__class__(**updated_kwargs)
@dataclass
class OTXDataBatch(ValidateBatchMixin):
    """Torch data item batch implementation."""

    batch_size: int  # TODO(ashwinvaidya17): Remove this
    images: torch.Tensor | list[torch.Tensor]
    labels: list[torch.Tensor] | None = None
    masks: list[Mask] | None = None
    bboxes: list[BoundingBoxes] | None = None
    keypoints: list[torch.Tensor] | None = None
    polygons: list[list[Polygon]] | None = None
    imgs_info: Sequence[ImageInfo | None] | None = None  # TODO(ashwinvaidya17): revisit

    def pin_memory(self) -> OTXDataBatch:
        """Pin memory for member tensor variables.

        Only ``images``, ``labels``, ``bboxes``, ``keypoints`` and ``masks``
        are processed. Entries that are ``None`` or not tensors are passed
        through unchanged.

        Returns:
            A new batch built with ``wrap`` that holds the pinned tensors.
        """
        # https://github.com/pytorch/pytorch/issues/116403
        kwargs: dict[str, Any] = {}

        def maybe_pin(x: Any) -> Any:  # noqa: ANN401
            if isinstance(x, torch.Tensor):
                return x.pin_memory()
            return x

        def maybe_wrap_tv(x: Any) -> Any:  # noqa: ANN401
            if isinstance(x, tv_tensors.TVTensor):
                # Re-wrap the pinned tensor like the original TVTensor.
                return tv_tensors.wrap(x.pin_memory(), like=x)
            return maybe_pin(x)

        # Handle images separately because of tv_tensors wrapping
        if self.images is not None:
            if isinstance(self.images, list):
                kwargs["images"] = [maybe_wrap_tv(img) for img in self.images]
            else:
                kwargs["images"] = maybe_wrap_tv(self.images)

        # Generic handler for all other fields
        for name in ["labels", "bboxes", "keypoints", "masks"]:
            value = getattr(self, name)
            if value is not None:
                kwargs[name] = [maybe_wrap_tv(v) if v is not None else None for v in value]

        return self.wrap(**kwargs)

    def wrap(self, **kwargs) -> OTXDataBatch:
        """Wrap this dataclass with the given keyword arguments.

        Fields that are not overridden are passed to the new instance by
        reference; they are not copied.

        Args:
            **kwargs: Keyword arguments to be overwritten on top of this dataclass

        Returns:
            Updated dataclass
        """
        # Gather the current field values by reference. ``dataclasses.asdict``
        # would recursively deep-copy every value (even the ones overwritten
        # just below) and turn nested dataclass values into plain dicts.
        updated_kwargs = {field_.name: getattr(self, field_.name) for field_ in fields(self)}
        updated_kwargs.update(kwargs)
        return self.__class__(**updated_kwargs)
@dataclass
class OTXPredItem(OTXDataItem):
    """Torch prediction data item implementation.

    Extends ``OTXDataItem`` with three optional prediction fields.

    Attributes:
        scores (torch.Tensor | None): The scores tensor, optional.
        feature_vector (torch.Tensor | None): The feature vector tensor, optional.
        saliency_map (torch.Tensor | None): The saliency map tensor, optional.
    """

    scores: torch.Tensor | None = None
    feature_vector: torch.Tensor | None = None
    saliency_map: torch.Tensor | None = None
@dataclass
class OTXPredBatch(OTXDataBatch):
    """Torch prediction data item batch implementation.

    Extends ``OTXDataBatch`` with three optional per-item prediction lists.

    Attributes:
        scores (list[torch.Tensor] | None): The scores tensors, optional.
        feature_vector (list[torch.Tensor] | None): The feature vector tensors, optional.
        saliency_map (list[torch.Tensor] | None): The saliency map tensors, optional.
    """

    scores: list[torch.Tensor] | None = None
    feature_vector: list[torch.Tensor] | None = None
    saliency_map: list[torch.Tensor] | None = None

    @property
    def has_xai_outputs(self) -> bool:
        """Check if the batch has XAI outputs.

        Necessary for compatibility with tests.

        Returns:
            True when ``saliency_map`` is set and non-empty, False otherwise.
        """
        # TODO(ashwinvaidya17): the tests should directly refer to saliency map.
        saliency = self.saliency_map
        if saliency is None:
            return False
        return len(saliency) > 0