# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""LightningDataModule extension for OTX."""
from __future__ import annotations
import logging as log
from typing import TYPE_CHECKING
from datumaro import Dataset as DmDataset
from lightning import LightningDataModule
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, RandomSampler
from torchvision.transforms.v2 import Normalize
from otx.backend.native.utils.instantiators import instantiate_sampler
from otx.backend.native.utils.utils import get_adaptive_num_workers
from otx.config.data import TileConfig
from otx.data.dataset.tile import OTXTileDatasetFactory
from otx.data.factory import OTXDatasetFactory
from otx.data.utils import adapt_input_size_to_dataset, adapt_tile_config
from otx.data.utils.pre_filtering import pre_filtering
from otx.types.device import DeviceType
from otx.types.image import ImageColorChannel
from otx.types.label import LabelInfo
from otx.types.task import OTXTaskType
if TYPE_CHECKING:
from lightning.pytorch.utilities.parsing import AttributeDict
from otx.config.data import SubsetConfig
from otx.data.dataset.base import OTXDataset
class OTXDataModule(LightningDataModule):
    """LightningDataModule that imports, filters, and serves OTX datasets for one task.

    Args:
        task (OTXTaskType): The type of task (e.g., classification, detection).
        data_format (str): The format of the data (e.g., 'coco', 'voc').
        data_root (str): The root directory where the data is stored.
        train_subset (SubsetConfig): Configuration for the training subset.
        val_subset (SubsetConfig): Configuration for the validation subset.
        test_subset (SubsetConfig): Configuration for the test subset.
        tile_config (TileConfig, optional): Configuration for tiling.
            Defaults to TileConfig(enable_tiler=False).
        image_color_channel (ImageColorChannel, optional): Color channel configuration for images.
            Defaults to ImageColorChannel.RGB.
        include_polygons (bool, optional): Whether to include polygons in the data. Defaults to False.
        ignore_index (int, optional): Index to ignore in segmentation tasks. Defaults to 255.
        unannotated_items_ratio (float, optional): Ratio of unannotated items to include. Defaults to 0.0.
        auto_num_workers (bool, optional): Whether to automatically determine the number of workers.
            Defaults to False.
        device (DeviceType, optional): Device type to use (e.g., 'cpu', 'gpu'). Defaults to DeviceType.auto.
        input_size (tuple[int, int] | str, optional): Final image or video shape after transformation.
            Can be "auto" to determine size automatically. Defaults to "auto".
        input_size_multiplier (int, optional): Multiplier for adaptive input size.
            Useful for models requiring specific input size multiples. Defaults to 1.

    Raises:
        ValueError: If ``input_size`` is neither "auto" nor a tuple/list, or if the
            subsets' label metadata disagree.
    """

    def __init__(
        self,
        task: OTXTaskType,
        data_format: str,
        data_root: str,
        train_subset: SubsetConfig,
        val_subset: SubsetConfig,
        test_subset: SubsetConfig,
        # NOTE(review): shared default instance; safe today because the default has
        # enable_tiler=False, so adapt_tile_config never mutates it — confirm if
        # TileConfig ever gains other in-place mutation paths.
        tile_config: TileConfig = TileConfig(enable_tiler=False),
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        include_polygons: bool = False,
        ignore_index: int = 255,
        unannotated_items_ratio: float = 0.0,
        auto_num_workers: bool = False,
        device: DeviceType = DeviceType.auto,
        input_size: tuple[int, int] | str = "auto",
        input_size_multiplier: int = 1,
    ) -> None:
        """Constructor.

        Imports the dataset from ``data_root``, pre-filters it, resolves the input
        size, builds one ``OTXDataset`` per configured subset (optionally wrapped
        for tiling), and validates that all subsets share the same label metadata.
        """
        super().__init__()
        self.task = task
        self.data_format = data_format
        self.data_root = data_root
        self.train_subset = train_subset
        self.val_subset = val_subset
        self.test_subset = test_subset
        self.tile_config = tile_config
        self.image_color_channel = image_color_channel
        self.include_polygons = include_polygons
        self.ignore_index = ignore_index
        self.unannotated_items_ratio = unannotated_items_ratio
        self.auto_num_workers = auto_num_workers
        self.device = device
        self.subsets: dict[str, OTXDataset] = {}
        # input_size may be replaced below (when "auto"); keep it out of hparams.
        self.save_hyperparameters(ignore=["input_size"])

        # Renamed from the original `dataset` to avoid rebinding the same name
        # inside the subset loop below (the loop used to shadow its own source).
        dm_dataset = DmDataset.import_from(self.data_root, format=self.data_format)

        # H-label classification and arrow-based keypoint datasets keep raw items;
        # everything else is pre-filtered.
        if self.task != OTXTaskType.H_LABEL_CLS and not (
            self.task == OTXTaskType.KEYPOINT_DETECTION and self.data_format == "arrow"
        ):
            dm_dataset = pre_filtering(
                dm_dataset,
                self.data_format,
                self.unannotated_items_ratio,
                # FIX: compare against the enum member rather than the bare string
                # "SEMANTIC_SEGMENTATION" — correct even if OTXTaskType stops
                # subclassing str.
                ignore_index=self.ignore_index if self.task == OTXTaskType.SEMANTIC_SEGMENTATION else None,
            )

        if isinstance(input_size, str) and input_size == "auto":
            input_size = adapt_input_size_to_dataset(
                dm_dataset,
                self.task,
                input_size_multiplier,
            )
        elif not isinstance(input_size, (tuple, list)):
            msg = f"input_size should be tuple/list of ints or 'auto', but got {input_size}"
            raise ValueError(msg)
        # Propagate the resolved input size to subsets that did not set one.
        for subset_cfg in [train_subset, val_subset, test_subset]:
            if subset_cfg.input_size is None:
                subset_cfg.input_size = input_size  # type: ignore[assignment]

        self.input_mean, self.input_std = self._extract_mean_std(train_subset)
        self.input_size = input_size

        if self.tile_config.enable_tiler and self.tile_config.enable_adaptive_tiling:
            adapt_tile_config(self.tile_config, dataset=dm_dataset, task=self.task)

        config_mapping = {
            self.train_subset.subset_name: self.train_subset,
            self.val_subset.subset_name: self.val_subset,
            self.test_subset.subset_name: self.test_subset,
        }

        if self.auto_num_workers:
            self._apply_adaptive_num_workers(config_mapping)

        label_infos: list[LabelInfo] = []
        for name, dm_subset in dm_dataset.subsets().items():
            if name not in config_mapping:
                log.warning(f"{name} is not available. Skip it")
                continue
            otx_dataset = OTXDatasetFactory.create(
                task=self.task,
                dm_subset=dm_subset.as_dataset(),
                cfg_subset=config_mapping[name],
                data_format=self.data_format,
                image_color_channel=image_color_channel,
                include_polygons=include_polygons,
                ignore_index=ignore_index,
            )
            if self.tile_config.enable_tiler:
                otx_dataset = OTXTileDatasetFactory.create(
                    task=self.task,
                    dataset=otx_dataset,
                    tile_config=self.tile_config,
                )
            self.subsets[name] = otx_dataset
            label_infos += [self.subsets[name].label_info]
            log.info(f"Add name: {name}, self.subsets: {self.subsets}")

        if self._is_meta_info_valid(label_infos) is False:
            msg = "All data meta infos of subsets should be the same."
            raise ValueError(msg)
        self.label_info = next(iter(label_infos))

    @staticmethod
    def _extract_mean_std(train_subset: SubsetConfig) -> tuple:
        """Return ``(mean, std)`` from the first Normalize transform of the train subset.

        Falls back to ``((0, 0, 0), (1, 1, 1))`` when no Normalize transform is found.
        """
        mean = (0.0, 0.0, 0.0)
        std = (1.0, 1.0, 1.0)
        if train_subset.transforms is not None:
            for transform in train_subset.transforms:
                if isinstance(transform, dict) and "Normalize" in transform.get("class_path", ""):
                    # CLI case with jsonargparse: transform is still a dict spec.
                    mean = transform["init_args"].get("mean", (0.0, 0.0, 0.0))
                    std = transform["init_args"].get("std", (1.0, 1.0, 1.0))
                    break
                if isinstance(transform, Normalize):
                    # Already-instantiated torchvision transform.
                    mean = transform.mean
                    std = transform.std
                    break
        return mean, std

    def _apply_adaptive_num_workers(self, config_mapping: dict) -> None:
        """Overwrite each subset's ``num_workers`` with an adaptively chosen value (GPU only)."""
        if self.device not in [DeviceType.gpu, DeviceType.auto]:
            log.warning(
                "Only GPU device type support auto_num_workers. "
                # FIX: "deveice" typo in the original message.
                f"Current device type is {self.device!s}. auto_num_workers is skipped.",
            )
        elif (num_workers := get_adaptive_num_workers()) is not None:
            for subset_name, subset_config in config_mapping.items():
                log.info(
                    f"num_workers of {subset_name} subset is changed : "
                    f"{subset_config.num_workers} -> {num_workers}",
                )
                subset_config.num_workers = num_workers
def _is_meta_info_valid(self, label_infos: list[LabelInfo]) -> bool:
"""Check whether there are mismatches in the metainfo for the all subsets."""
return bool(all(label_info == label_infos[0] for label_info in label_infos))
def _get_dataset(self, subset: str) -> OTXDataset:
if (dataset := self.subsets.get(subset)) is None:
msg = f"Dataset has no '{subset}'. Available subsets = {list(self.subsets.keys())}"
raise KeyError(msg)
return dataset
def train_dataloader(self) -> DataLoader:
"""Get train dataloader."""
config = self.train_subset
dataset = self._get_dataset(config.subset_name)
sampler = instantiate_sampler(config.sampler, dataset=dataset, batch_size=config.batch_size)
common_args = {
"dataset": dataset,
"batch_size": config.batch_size,
"num_workers": config.num_workers,
"pin_memory": True,
"collate_fn": dataset.collate_fn,
"persistent_workers": config.num_workers > 0,
"sampler": sampler,
"shuffle": sampler is None,
}
tile_config = self.tile_config
if tile_config.enable_tiler and tile_config.sampling_ratio < 1:
num_samples = max(1, int(len(dataset) * tile_config.sampling_ratio))
log.info(f"Using tiled sampling with {num_samples} samples")
common_args.update(
{
"shuffle": False,
"sampler": RandomSampler(dataset, num_samples=num_samples),
},
)
return DataLoader(**common_args)
def val_dataloader(self) -> DataLoader:
"""Get val dataloader."""
config = self.val_subset
dataset = self._get_dataset(config.subset_name)
return DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
collate_fn=dataset.collate_fn,
persistent_workers=config.num_workers > 0,
)
def test_dataloader(self) -> DataLoader:
"""Get test dataloader."""
config = self.test_subset
dataset = self._get_dataset(config.subset_name)
return DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
collate_fn=dataset.collate_fn,
persistent_workers=config.num_workers > 0,
)
def predict_dataloader(self) -> DataLoader:
"""Get test dataloader."""
config = self.test_subset
dataset = self._get_dataset(config.subset_name)
return DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
collate_fn=dataset.collate_fn,
persistent_workers=config.num_workers > 0,
)
    def setup(self, stage: str) -> None:
        """Setup for each stage.

        Intentionally a no-op: all subsets are built eagerly in ``__init__``,
        so there is nothing stage-specific to construct here.
        """
    def teardown(self, stage: str) -> None:
        """Teardown for each stage.

        Intentionally a no-op: no per-stage resources are acquired by this
        datamodule, so there is nothing to release.
        """
        # clean up after fit or test
        # called on every process in DDP
@property
def hparams_initial(self) -> AttributeDict:
"""The collection of hyperparameters saved with `save_hyperparameters()`. It is read-only.
The reason why we override is that we have some custom resolvers for `DictConfig`.
Some resolved Python objects has not a primitive type, so that is not loggable without errors.
Therefore, we need to unresolve it this time.
"""
hp = super().hparams_initial
for key, value in hp.items():
if isinstance(value, DictConfig):
# It should be unresolved to make it loggable
hp[key] = OmegaConf.to_container(value, resolve=False)
return hp
def __reduce__(self):
"""Re-initialize object when unpickled."""
return (
self.__class__,
(
self.task,
self.data_format,
self.data_root,
self.train_subset,
self.val_subset,
self.test_subset,
self.tile_config,
self.image_color_channel,
self.include_polygons,
self.ignore_index,
self.unannotated_items_ratio,
self.auto_num_workers,
self.device,
self.input_size,
),
)