Source code for otx.data.entity.base

# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for OTX base data entities."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict

import torch
import torchvision.transforms.v2.functional as F  # noqa: N812
from torch import Tensor
from torch.utils._pytree import tree_flatten
from torchvision import tv_tensors
from torchvision.utils import _log_api_usage_once

from otx.types.image import ImageColorChannel, ImageType

if TYPE_CHECKING:
    from collections.abc import Mapping


def custom_wrap(wrappee: Tensor, *, like: tv_tensors.TVTensor, **kwargs) -> tv_tensors.TVTensor:
    """Add `Points` in tv_tensors.wrap.

    If `like` is
        - tv_tensors.BoundingBoxes : the `format` and `canvas_size` of `like` are assigned to `wrappee`
        - Points : the `canvas_size` of `like` is assigned to `wrappee`
    Unless, they are passed as `kwargs`.

    Args:
        wrappee (Tensor): The tensor to convert.
        like (tv_tensors.TVTensor): The reference. `wrappee` will be converted into the same subclass as `like`.
        kwargs: Can contain "format" and "canvas_size" if `like` is a tv_tensor.BoundingBoxes,
            or "canvas_size" if `like` is a `Points`. Ignored otherwise.
    """
    if isinstance(like, tv_tensors.BoundingBoxes):
        return tv_tensors.BoundingBoxes._wrap(  # noqa: SLF001
            wrappee,
            format=kwargs.get("format", like.format),
            canvas_size=kwargs.get("canvas_size", like.canvas_size),
        )
    elif isinstance(like, Points):  # noqa: RET505
        return Points._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))  # noqa: SLF001

    # # TODO(Vlad): remove this after torch upgrade. This workaround prevents a failure when like is also a Tensor
    # if type(like) == type(wrappee):
    #     return wrappee

    return wrappee.as_subclass(type(like))


tv_tensors.wrap = custom_wrap
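
# A minimal usage sketch of the patched `tv_tensors.wrap` (illustrative only; the
# 2x4 box tensor and the 10x10 canvas are made-up example values):
#
# >>> boxes = tv_tensors.BoundingBoxes(torch.zeros(2, 4), format="XYXY", canvas_size=(10, 10))
# >>> wrapped = tv_tensors.wrap(torch.ones(2, 4), like=boxes)  # BoundingBoxes with format/canvas_size of `boxes`
# >>> isinstance(wrapped, tv_tensors.BoundingBoxes), wrapped.canvas_size  # -> (True, (10, 10))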


class ImageInfo(tv_tensors.TVTensor):
    """Meta info for image.

    Attributes:
        img_idx: Image index
        img_shape: Image shape (height, width) after preprocessing
        ori_shape: Image shape (height, width) right after loading it
        padding: Number of pixels to pad all borders (left, top, right, bottom)
        scale_factor: Scale factor (height, width) if the image is resized during preprocessing.
            Default value is `(1.0, 1.0)` when there is no resizing. However, if the image is cropped,
            it will lose the scaling information and be `None`.
        normalized: If True, this image is normalized with `norm_mean` and `norm_std`
        norm_mean: Mean vector used to normalize this image
        norm_std: Standard deviation vector used to normalize this image
        image_color_channel: Color channel type of this image, RGB or BGR.
        ignored_labels: Labels that should be ignored in this image. Defaults to None.
    """

    img_idx: int
    img_shape: tuple[int, int]
    ori_shape: tuple[int, int]
    padding: tuple[int, int, int, int] = (0, 0, 0, 0)
    scale_factor: tuple[float, float] | None = (1.0, 1.0)
    normalized: bool = False
    norm_mean: tuple[float, float, float] = (0.0, 0.0, 0.0)
    norm_std: tuple[float, float, float] = (1.0, 1.0, 1.0)
    image_color_channel: ImageColorChannel = ImageColorChannel.RGB
    ignored_labels: list[int]

    @classmethod
    def _wrap(
        cls,
        dummy_tensor: Tensor,
        *,
        img_idx: int,
        img_shape: tuple[int, int],
        ori_shape: tuple[int, int],
        padding: tuple[int, int, int, int] = (0, 0, 0, 0),
        scale_factor: tuple[float, float] | None = (1.0, 1.0),
        normalized: bool = False,
        norm_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
        norm_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        ignored_labels: list[int] | None = None,
    ) -> ImageInfo:
        image_info = dummy_tensor.as_subclass(cls)
        image_info.img_idx = img_idx
        image_info.img_shape = img_shape
        image_info.ori_shape = ori_shape
        image_info.padding = padding
        image_info.scale_factor = scale_factor
        image_info.normalized = normalized
        image_info.norm_mean = norm_mean
        image_info.norm_std = norm_std
        image_info.image_color_channel = image_color_channel
        image_info.ignored_labels = ignored_labels if ignored_labels else []
        return image_info

    def __new__(  # noqa: D102
        cls,
        img_idx: int,
        img_shape: tuple[int, int],
        ori_shape: tuple[int, int],
        padding: tuple[int, int, int, int] = (0, 0, 0, 0),
        scale_factor: tuple[float, float] | None = (1.0, 1.0),
        normalized: bool = False,
        norm_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
        norm_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        ignored_labels: list[int] | None = None,
    ) -> ImageInfo:
        return cls._wrap(
            dummy_tensor=Tensor(),
            img_idx=img_idx,
            img_shape=img_shape,
            ori_shape=ori_shape,
            padding=padding,
            scale_factor=scale_factor,
            normalized=normalized,
            norm_mean=norm_mean,
            norm_std=norm_std,
            image_color_channel=image_color_channel,
            ignored_labels=ignored_labels,
        )

    @classmethod
    def _wrap_output(
        cls,
        output: Tensor,
        args: tuple[()] = (),
        kwargs: Mapping[str, Any] | None = None,
    ) -> ImageType:
        """Wrap an output (`torch.Tensor`) obtained from PyTorch function.

        For example, this function will be called when

        >>> img_info = ImageInfo(img_idx=0, img_shape=(10, 10), ori_shape=(10, 10))
        >>> # `_wrap_output()` will be called after the PyTorch function `to()` is called
        >>> img_info = img_info.to(device="cuda")
        """
        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))

        if isinstance(output, Tensor) and not isinstance(output, ImageInfo):
            image_info = next(x for x in flat_params if isinstance(x, ImageInfo))
            output = ImageInfo._wrap(
                dummy_tensor=output,
                img_idx=image_info.img_idx,
                img_shape=image_info.img_shape,
                ori_shape=image_info.ori_shape,
                padding=image_info.padding,
                scale_factor=image_info.scale_factor,
                normalized=image_info.normalized,
                norm_mean=image_info.norm_mean,
                norm_std=image_info.norm_std,
                image_color_channel=image_info.image_color_channel,
                ignored_labels=image_info.ignored_labels,
            )
        elif isinstance(output, (tuple, list)):
            image_infos = [x for x in flat_params if isinstance(x, ImageInfo)]
            output = type(output)(
                ImageInfo._wrap(
                    dummy_tensor=dummy_tensor,
                    img_idx=image_info.img_idx,
                    img_shape=image_info.img_shape,
                    ori_shape=image_info.ori_shape,
                    padding=image_info.padding,
                    scale_factor=image_info.scale_factor,
                    normalized=image_info.normalized,
                    norm_mean=image_info.norm_mean,
                    norm_std=image_info.norm_std,
                    image_color_channel=image_info.image_color_channel,
                    ignored_labels=image_info.ignored_labels,
                )
                for dummy_tensor, image_info in zip(output, image_infos)
            )
        return output

    def __repr__(self) -> str:
        return (
            "ImageInfo("
            f"img_idx={self.img_idx}, "
            f"img_shape={self.img_shape}, "
            f"ori_shape={self.ori_shape}, "
            f"padding={self.padding}, "
            f"scale_factor={self.scale_factor}, "
            f"normalized={self.normalized}, "
            f"norm_mean={self.norm_mean}, "
            f"norm_std={self.norm_std}, "
            f"image_color_channel={self.image_color_channel}, "
            f"ignored_labels={self.ignored_labels})"
        )
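
# A small sketch of how `ImageInfo` carries its metadata through tensor ops
# (the shapes below are made-up example values):
#
# >>> info = ImageInfo(img_idx=0, img_shape=(224, 224), ori_shape=(480, 640))
# >>> info = info.to(dtype=torch.float32)  # `_wrap_output` re-attaches the metadata
# >>> info.img_shape, info.ori_shape  # -> ((224, 224), (480, 640))
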

@F.register_kernel(functional=F.resize, tv_tensor_cls=ImageInfo)
def _resize_image_info(image_info: ImageInfo, size: list[int], **kwargs) -> ImageInfo:  # noqa: ARG001
    """Register ImageInfo to TorchVision v2 resize kernel."""
    if len(size) == 2:
        image_info.img_shape = (size[0], size[1])
    elif len(size) == 1:
        image_info.img_shape = (size[0], size[0])
    else:
        raise ValueError(size)

    ori_h, ori_w = image_info.ori_shape
    new_h, new_w = image_info.img_shape
    image_info.scale_factor = (new_h / ori_h, new_w / ori_w)
    return image_info


@F.register_kernel(functional=F.crop, tv_tensor_cls=ImageInfo)
def _crop_image_info(
    image_info: ImageInfo,
    height: int,
    width: int,
    **kwargs,  # noqa: ARG001
) -> ImageInfo:
    """Register ImageInfo to TorchVision v2 crop kernel."""
    image_info.img_shape = (height, width)
    image_info.scale_factor = None
    return image_info


@F.register_kernel(functional=F.resized_crop, tv_tensor_cls=ImageInfo)
def _resized_crop_image_info(
    image_info: ImageInfo,
    size: list[int],
    **kwargs,  # noqa: ARG001
) -> ImageInfo:
    """Register ImageInfo to TorchVision v2 resized_crop kernel."""
    if len(size) == 2:
        image_info.img_shape = (size[0], size[1])
    elif len(size) == 1:
        image_info.img_shape = (size[0], size[0])
    else:
        raise ValueError(size)

    image_info.scale_factor = None
    return image_info


@F.register_kernel(functional=F.center_crop, tv_tensor_cls=ImageInfo)
def _center_crop_image_info(
    image_info: ImageInfo,
    output_size: list[int],
    **kwargs,  # noqa: ARG001
) -> ImageInfo:
    """Register ImageInfo to TorchVision v2 center_crop kernel."""
    img_shape = F._geometry._center_crop_parse_output_size(output_size)  # noqa: SLF001

    image_info.img_shape = (img_shape[0], img_shape[1])
    image_info.scale_factor = None
    return image_info


@F.register_kernel(functional=F.pad, tv_tensor_cls=ImageInfo)
def _pad_image_info(
    image_info: ImageInfo,
    padding: int | list[int],
    **kwargs,  # noqa: ARG001
) -> ImageInfo:
    """Register ImageInfo to TorchVision v2 pad kernel."""
    left, right, top, bottom = F._geometry._parse_pad_padding(padding)  # noqa: SLF001
    height, width = image_info.img_shape

    image_info.padding = (left, top, right, bottom)
    image_info.img_shape = (height + top + bottom, width + left + right)
    return image_info


@F.register_kernel(functional=F.normalize, tv_tensor_cls=ImageInfo)
def _normalize_image_info(
    image_info: ImageInfo,
    mean: list[float],
    std: list[float],
    **kwargs,  # noqa: ARG001
) -> ImageInfo:
    """Register ImageInfo to TorchVision v2 normalize kernel."""
    image_info.normalized = True
    image_info.norm_mean = (mean[0], mean[1], mean[2])
    image_info.norm_std = (std[0], std[1], std[2])
    return image_info
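
# A rough sketch of the kernels above in action (example numbers only): resizing an
# `ImageInfo` touches no pixel data, it only updates the bookkeeping fields.
#
# >>> info = ImageInfo(img_idx=0, img_shape=(480, 640), ori_shape=(480, 640))
# >>> info = F.resize(info, size=[240, 320])
# >>> info.img_shape, info.scale_factor  # -> ((240, 320), (0.5, 0.5))
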

class Points(tv_tensors.TVTensor):
    """`torch.Tensor` subclass for points.

    Attributes:
        data: Any data that can be turned into a tensor with `torch.as_tensor`.
        canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
        dtype (torch.dtype, optional): Desired data type of the point. If omitted, will be inferred from `data`.
        device (torch.device, optional): Desired device of the point. If omitted and `data` is a
            `torch.Tensor`, the device is taken from it. Otherwise, the point is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the point.
            If omitted and `data` is a `torch.Tensor`, the value is taken from it. Otherwise, defaults to `False`.
    """

    canvas_size: tuple[int, int]

    @classmethod
    def _wrap(cls, tensor: Tensor, *, canvas_size: tuple[int, int]) -> Points:
        points = tensor.as_subclass(cls)
        points.canvas_size = canvas_size
        return points

    def __new__(  # noqa: D102
        cls,
        data: Any,  # noqa: ANN401
        *,
        canvas_size: tuple[int, int],
        dtype: torch.dtype | None = None,
        device: torch.device | str | int | None = None,
        requires_grad: bool | None = None,
    ) -> Points:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor, canvas_size=canvas_size)

    @classmethod
    def _wrap_output(
        cls,
        output: Tensor,
        args: tuple[()] = (),
        kwargs: Mapping[str, Any] | None = None,
    ) -> Points:
        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))
        first_point_from_args = next(x for x in flat_params if isinstance(x, Points))
        canvas_size = first_point_from_args.canvas_size

        if isinstance(output, Tensor) and not isinstance(output, Points):
            output = Points._wrap(output, canvas_size=canvas_size)
        elif isinstance(output, (tuple, list)):
            output = type(output)(Points._wrap(part, canvas_size=canvas_size) for part in output)
        return output

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # noqa: ANN401
        return self._make_repr(canvas_size=self.canvas_size)
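
# A brief sketch of `Points` construction (coordinates and canvas size are made-up
# example values):
#
# >>> pts = Points([[5.0, 5.0], [2.0, 8.0]], canvas_size=(10, 10))
# >>> shifted = pts + 1  # plain tensor math; `_wrap_output` keeps the `Points` type
# >>> isinstance(shifted, Points), shifted.canvas_size  # -> (True, (10, 10))
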

def resize_points(
    points: torch.Tensor,
    canvas_size: tuple[int, int],
    size: tuple[int, int] | list[int],
    max_size: int | None = None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Resize points."""
    old_height, old_width = canvas_size
    new_height, new_width = F._geometry._compute_resized_output_size(  # noqa: SLF001
        canvas_size,
        size=size,
        max_size=max_size,
    )

    if (new_height, new_width) == (old_height, old_width):
        return points, canvas_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio], device=points.device)
    return (
        points.mul(ratios).to(points.dtype),
        (new_height, new_width),
    )


@F.register_kernel(functional=F.resize, tv_tensor_cls=Points)
def _resize_points_dispatch(
    inpt: Points,
    size: tuple[int, int] | list[int],
    max_size: int | None = None,
    **kwargs,  # noqa: ARG001
) -> Points:
    output, canvas_size = resize_points(
        inpt.as_subclass(torch.Tensor),
        inpt.canvas_size,
        size,
        max_size=max_size,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def pad_points(
    points: torch.Tensor,
    canvas_size: tuple[int, int],
    padding: tuple[int, ...] | list[int],
    padding_mode: str = "constant",
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Pad points."""
    if padding_mode not in ["constant"]:  # TODO(sungchul): add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with points")  # noqa: EM102, TRY003

    left, right, top, bottom = F._geometry._parse_pad_padding(padding)  # noqa: SLF001

    pad = [left, top]
    points = points + torch.tensor(pad, dtype=points.dtype, device=points.device)

    height, width = canvas_size
    height += top + bottom
    width += left + right
    canvas_size = (height, width)

    return clamp_points(points, canvas_size=canvas_size), canvas_size


@F.register_kernel(functional=F.pad, tv_tensor_cls=Points)
def _pad_points_dispatch(
    inpt: Points,
    padding: tuple[int, ...] | list[int],
    padding_mode: str = "constant",
    **kwargs,  # noqa: ARG001
) -> Points:
    output, canvas_size = pad_points(
        inpt.as_subclass(torch.Tensor),
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@F.register_kernel(functional=F.get_size, tv_tensor_cls=Points)
def get_size_points(point: Points) -> list[int]:
    """Get size of points."""
    return list(point.canvas_size)


def _clamp_points(points: Tensor, canvas_size: tuple[int, int]) -> Tensor:
    in_dtype = points.dtype
    points = points.clone() if points.is_floating_point() else points.float()
    points[..., 0].clamp_(min=0, max=canvas_size[1])
    points[..., 1].clamp_(min=0, max=canvas_size[0])
    return points.to(in_dtype)


def clamp_points(inpt: Tensor, canvas_size: tuple[int, int] | None = None) -> Tensor:
    """Clamp point range."""
    if not torch.jit.is_scripting():
        _log_api_usage_once(clamp_points)

    if torch.jit.is_scripting() or F._utils.is_pure_tensor(inpt):  # noqa: SLF001
        if canvas_size is None:
            raise ValueError("For pure tensor inputs, `canvas_size` has to be passed.")  # noqa: EM101, TRY003
        return _clamp_points(inpt, canvas_size=canvas_size)
    elif isinstance(inpt, Points):  # noqa: RET505
        if canvas_size is not None:
            raise ValueError("For point tv_tensor inputs, `canvas_size` must not be passed.")  # noqa: EM101, TRY003
        output = _clamp_points(inpt.as_subclass(Tensor), canvas_size=inpt.canvas_size)
        return tv_tensors.wrap(output, like=inpt)
    else:
        raise TypeError(  # noqa: TRY003
            f"Input can either be a plain tensor or a point tv_tensor, but got {type(inpt)} instead.",  # noqa: EM102
        )


class OTXBatchLossEntity(Dict[str, Tensor]):
    """Data entity to represent model output losses."""
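
# A short end-to-end sketch of the `Points` kernels and `clamp_points` defined above
# (all numbers are made-up example values):
#
# >>> pts = Points([[4.0, 4.0]], canvas_size=(8, 8))
# >>> pts = F.resize(pts, size=[16, 16])  # -> [[8.0, 8.0]], canvas_size (16, 16)
# >>> pts = F.pad(pts, padding=[2, 2])    # -> [[10.0, 10.0]], canvas_size (20, 20)
# >>> clamp_points(torch.tensor([[30.0, -1.0]]), canvas_size=(20, 20))  # -> [[20.0, 0.0]]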