otx.models#

Re-exports models from different backends for user-friendly imports.

Classes

Padim(data_input_params[, label_info, ...])

OTX Padim model.

Stfpm(data_input_params[, label_info, ...])

OTX STFPM model.

Uflow(data_input_params[, label_info, ...])

OTX UFlow model.

EfficientNet(label_info, data_input_params)

Factory class for EfficientNet models.

TimmModel(label_info, data_input_params[, ...])

Factory class for TimmModel models.

MobileNetV3(label_info, data_input_params[, ...])

Factory class for MobileNetV3 models.

TVModel(label_info, data_input_params[, ...])

Factory class for Torch Vision models.

VisionTransformer(label_info, data_input_params)

Factory class for VisionTransformer models.

ATSS(label_info, data_input_params, ...[, ...])

OTX Detection model class for ATSS.

DFine(label_info, data_input_params, ...[, ...])

OTX Detection model class for DFine.

SSD(label_info, data_input_params, ...[, ...])

OTX Detection model class for SSD.

RTMDet(label_info, data_input_params, ...[, ...])

OTX Detection model class for RTMDet.

RTDETR(label_info, data_input_params, ...[, ...])

OTX Detection model class for RTDETR.

MaskRCNN(label_info, data_input_params, ...)

Implementation of MaskRCNN for instance segmentation.

MaskRCNNTV(label_info, data_input_params, ...)

Implementation of torchvision MaskRCNN for instance segmentation.

RTMDetInst(label_info, data_input_params, ...)

Implementation of RTMDetInst for instance segmentation.

RTMPose(label_info, data_input_params, ...)

RTMPose Model.

DinoV2Seg(label_info, data_input_params, ...)

DinoV2Seg for Semantic Segmentation model.

LiteHRNet(label_info, data_input_params, ...)

LiteHRNet Model.

SegNext(label_info, data_input_params, ...)

SegNext Model.

OVModel(model_path, model_type, ...)

Base class for the OpenVINO model.

OVDetectionModel(model_path, model_type, ...)

OVDetectionModel: Object detection model compatible for OpenVINO IR inference.

OVMulticlassClassificationModel(model_path, ...)

Classification model compatible for OpenVINO IR inference.

OVMultilabelClassificationModel(model_path, ...)

Multilabel classification model compatible for OpenVINO IR inference.

OVSegmentationModel(model_path, model_type, ...)

Semantic segmentation model compatible for OpenVINO IR inference.

OVHlabelClassificationModel(model_path, ...)

Hierarchical classification model compatible for OpenVINO IR inference.

OVInstanceSegmentationModel(model_path, ...)

Instance segmentation model compatible for OpenVINO IR inference.

OVKeypointDetectionModel(model_path, ...)

Keypoint detection model compatible for OpenVINO IR inference.

class otx.models.ATSS(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['atss_mobilenetv2', 'atss_resnext101'] = 'atss_mobilenetv2', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXDetectionModel

OTX Detection model class for ATSS.

pretrained_weights#

Dictionary containing URLs for pretrained weights.

Type:

ClassVar[dict[str, str]]

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (Literal, optional) – Name of the model to use. Defaults to “atss_mobilenetv2”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') dict[source]#

Load the previous OTX checkpoint according to OTX 2.0.
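
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count, input size, and normalization values are illustrative.

>>> from otx.models import ATSS
>>> model = ATSS(
...     label_info=3,
...     data_input_params={"input_size": (800, 800),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="atss_mobilenetv2",
... )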

class otx.models.DFine(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['dfine_hgnetv2_n', 'dfine_hgnetv2_s', 'dfine_hgnetv2_m', 'dfine_hgnetv2_l', 'dfine_hgnetv2_x'] = 'dfine_hgnetv2_x', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mean_ap_f_measure_callable>, multi_scale: bool = False, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXDetectionModel

OTX Detection model class for DFine.

pretrained_weights#

Dictionary containing URLs for pretrained weights.

Type:

ClassVar[dict[str, str]]

input_size_multiplier#

Multiplier for the input size.

Type:

int

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (Literal, optional) – Name of the model to use. Defaults to “dfine_hgnetv2_x”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable.

  • multi_scale (bool, optional) – Whether to use multi-scale training. Defaults to False.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

configure_optimizers() tuple[list[Optimizer], list[dict[str, Any]]][source]#

Configure an optimizer and learning-rate schedulers.

Set up the optimizer and schedulers from the provided inputs. Typically, a warmup scheduler is used initially, followed by the main scheduler.

Returns:

Two lists: the former contains the optimizer, and the latter contains the LR scheduler configurations in dictionary format.
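
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count and input values below are illustrative.

>>> from otx.models import DFine
>>> model = DFine(
...     label_info=3,
...     data_input_params={"input_size": (640, 640),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="dfine_hgnetv2_x",
... )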

class otx.models.DinoV2Seg(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['dinov2-small-seg'] = 'dinov2-small-seg', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _segm_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXSegmentationModel

DinoV2Seg for Semantic Segmentation model.

Parameters:
  • label_info (LabelInfoTypes) – Information about the hierarchical labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (Literal, optional) – Name of the model. Defaults to “dinov2-small-seg”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to SegmCallable.

  • torch_compile (bool, optional) – Flag to indicate whether to use torch.compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None
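
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count and input values below are illustrative.

>>> from otx.models import DinoV2Seg
>>> model = DinoV2Seg(
...     label_info=2,
...     data_input_params={"input_size": (518, 518),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="dinov2-small-seg",
... )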

class otx.models.EfficientNet(label_info: LabelInfoTypes, data_input_params: DataInputParams, task: Literal['multi_class', 'multi_label', 'h_label'] = 'multi_class', model_name: Literal['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8'] = 'efficientnet_b0', freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False)[source]#

Bases: object

Factory class for EfficientNet models.

Factory method to create EfficientNet models based on the task type.

Parameters:
  • label_info (LabelInfoTypes) – The label information.

  • data_input_params (DataInputParams | dict) – The data input parameters that consists of input size, mean and std.

  • freeze_backbone (bool, optional) – Whether to freeze the backbone during training. Defaults to False. Note: only multiclass classification supports this argument.

  • (Literal["efficientnet_b0" (model_name) – “efficientnet_b4”, “efficientnet_b5”, “efficientnet_b6”, “efficientnet_b7”, “efficientnet_b8”], optional): The model name. Defaults to “efficientnet_b0”.

  • "efficientnet_b1" – “efficientnet_b4”, “efficientnet_b5”, “efficientnet_b6”, “efficientnet_b7”, “efficientnet_b8”], optional): The model name. Defaults to “efficientnet_b0”.

  • "efficientnet_b2" – “efficientnet_b4”, “efficientnet_b5”, “efficientnet_b6”, “efficientnet_b7”, “efficientnet_b8”], optional): The model name. Defaults to “efficientnet_b0”.

  • "efficientnet_b3" – “efficientnet_b4”, “efficientnet_b5”, “efficientnet_b6”, “efficientnet_b7”, “efficientnet_b8”], optional): The model name. Defaults to “efficientnet_b0”.

:param“efficientnet_b4”, “efficientnet_b5”, “efficientnet_b6”, “efficientnet_b7”,

“efficientnet_b8”], optional): The model name. Defaults to “efficientnet_b0”.

Parameters:
  • task (Literal["multi_class", "multi_label", "h_label"], optional) – The task type. Can be “multi_class”, “multi_label”, or “h_label”. Defaults to “multi_class”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using TorchScript. Defaults to False.

Examples

>>> # Basic usage
>>> model = EfficientNet(
...     task="multi_class",
...     label_info=10,
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="efficientnet_b0",
... )
class otx.models.LiteHRNet(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['lite_hrnet_s', 'lite_hrnet_18', 'lite_hrnet_x'] = 'lite_hrnet_18', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _segm_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXSegmentationModel

LiteHRNet Model.

Parameters:
  • label_info (LabelInfoTypes) – Information about the hierarchical labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (Literal, optional) – Name of the model. Defaults to “lite_hrnet_18”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to SegmCallable.

  • torch_compile (bool, optional) – Flag to indicate whether to use torch.compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.model.') dict[source]#

Load the previous OTX checkpoint according to OTX 2.0.

property ignore_scope: dict[str, Any]#

Get the ignored scope for LiteHRNet.
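
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count and input values below are illustrative.

>>> from otx.models import LiteHRNet
>>> model = LiteHRNet(
...     label_info=2,
...     data_input_params={"input_size": (512, 512),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="lite_hrnet_18",
... )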

class otx.models.MaskRCNN(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['maskrcnn_resnet_50', 'maskrcnn_efficientnet_b2b', 'maskrcnn_swin_tiny'] = 'maskrcnn_resnet_50', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXInstanceSegModel

Implementation of MaskRCNN for instance segmentation.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels used in the model.

  • data_input_params (DataInputParams) – Parameters for the data input.

  • model_name (str, optional) – Name of the model. Defaults to “maskrcnn_resnet_50”.

  • optimizer (OptimizerCallable, optional) – Optimizer for the model. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Scheduler for the model. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Metric for evaluating the model. Defaults to MaskRLEMeanAPFMeasureCallable.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

  • explain_mode (bool, optional) – Whether to enable explainable AI mode. Defaults to False.

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') dict[source]#

Load the previous OTX checkpoint according to OTX 2.0.
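
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count and input values below are illustrative.

>>> from otx.models import MaskRCNN
>>> model = MaskRCNN(
...     label_info=3,
...     data_input_params={"input_size": (1024, 1024),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="maskrcnn_resnet_50",
... )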

class otx.models.MaskRCNNTV(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['maskrcnn_resnet_50'] = 'maskrcnn_resnet_50', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXInstanceSegModel

Implementation of torchvision MaskRCNN for instance segmentation.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels used in the model.

  • data_input_params (DataInputParams) – Parameters for the data input.

  • model_name (str, optional) – Name of the model. Defaults to “maskrcnn_resnet_50”.

  • optimizer (OptimizerCallable, optional) – Optimizer for the model. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Scheduler for the model. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Metric for evaluating the model. Defaults to MaskRLEMeanAPFMeasureCallable.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

  • explain_mode (bool, optional) – Whether to enable explainable AI mode. Defaults to False.

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

forward_for_tracing(inputs: Tensor) tuple[Tensor, ...][source]#

Forward function for export.

class otx.models.MobileNetV3(label_info: LabelInfoTypes, data_input_params: DataInputParams | dict, task: Literal['multi_class', 'multi_label', 'h_label'] = 'multi_class', freeze_backbone: bool = False, model_name: Literal['mobilenetv3_large', 'mobilenetv3_small'] = 'mobilenetv3_large', optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False)[source]#

Bases: object

Factory class for MobileNetV3 models.

Factory method to create MobileNetV3 models based on the task type.

Parameters:
  • label_info (LabelInfoTypes) – The label information.

  • data_input_params (DataInputParams | dict) – The data input parameters that consists of input size, mean and std.

  • freeze_backbone (bool, optional) – Whether to freeze the backbone during training. Defaults to False. Note: only multiclass classification supports this argument.

  • model_name (str, optional) – The model name. Defaults to “mobilenetv3_large”.

  • task (Literal["multi_class", "multi_label", "h_label"], optional) – The task type. Can be “multi_class”, “multi_label”, or “h_label”. Defaults to “multi_class”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using TorchScript. Defaults to False.

Examples

>>> # Basic usage
>>> model = MobileNetV3(
...     task="multi_class",
...     label_info=10,
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="mobilenetv3_small",
... )

>>> # Multi-label classification
>>> model = MobileNetV3(
...     task="multi_label",
...     model_name="mobilenetv3_large",
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     label_info=[1, 5, 10]  # Multi-label setup
... )
class otx.models.OVDetectionModel(model_path: PathLike, model_type: str = 'SSD', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _mean_ap_f_measure_callable>, **kwargs)[source]#

Bases: OVModel

OVDetectionModel: Object detection model compatible for OpenVINO IR inference.

This class is designed to work with OpenVINO IR models or models from the Intel OMZ repository. It provides compatibility with the OTX testing pipeline for object detection tasks.

Initialize the OVDetectionModel.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model.

  • model_type (str) – Type of the model (default: “SSD”).

  • async_inference (bool) – Whether to use asynchronous inference (default: True).

  • max_num_requests (int | None) – Maximum number of inference requests (default: None).

  • use_throughput_mode (bool) – Whether to use throughput mode (default: True).

  • model_api_configuration (dict[str, Any] | None) – Configuration for the model API (default: None).

  • metric (MetricCallable) – Metric callable for evaluation (default: MeanAveragePrecisionFMeasureCallable).

  • **kwargs – Additional keyword arguments.

The class docstring also describes the following behavior:

  • Tiler setup – Configures the tiler with the appropriate execution mode and disables asynchronous inference, as tiling has its own sync/async implementation.

  • Hyperparameter extraction – Reads the confidence threshold from the model’s runtime information (rt_info) via the model adapter (OpenvinoAdapter); if unavailable, logs a warning and sets the confidence threshold to None.

  • Output customization – Converts the model outputs (list[DetectionResult]) and inputs (OTXDataBatch) into an OTXPredBatch of bounding boxes, scores, and labels, optionally with saliency maps and feature vectors.

  • Metric input preparation – Converts predictions (OTXPredBatch) and ground-truth inputs (OTXDataBatch) into a MetricInput dictionary with ‘preds’ and ‘target’ keys.

  • Metric computation – Computes evaluation metrics from a Metric object and returns a dictionary of computed metric values.

Initialize the OVModel instance.

Parameters:
  • model_path (PathLike) – Path to the model file.

  • model_type (str) – Type of the model.

  • async_inference (bool) – Whether to enable asynchronous inference.

  • force_cpu (bool) – Whether to force the use of CPU.

  • max_num_requests (int | None) – Maximum number of inference requests.

  • use_throughput_mode (bool) – Whether to use throughput mode.

  • model_api_configuration (dict[str, Any] | None) – Configuration for the Model API.

  • metric (MetricCallable) – Metric callable for evaluation.

compute_metrics(metric: Metric) dict[source]#

Compute metrics for the model.

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Convert prediction and input entities to a format suitable for metric computation.

Parameters:
  • preds (OTXPredBatch) – The predicted batch entity containing predicted bboxes.

  • inputs (OTXDataBatch) – The input batch entity containing ground truth bboxes.

Returns:

A dictionary containing ‘preds’ and ‘target’ keys corresponding to the predicted and target bboxes for metric evaluation.

Return type:

MetricInput
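
Example

A hedged inference sketch; the IR file path below is hypothetical and must point to an exported OpenVINO detection model.

>>> from otx.models import OVDetectionModel
>>> model = OVDetectionModel(
...     model_path="exported_model.xml",  # hypothetical path to an OpenVINO IR file
...     model_type="SSD",
... )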

class otx.models.OVHlabelClassificationModel(model_path: PathLike, model_type: str = 'Classification', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _mixed_hlabel_accuracy>, **kwargs)[source]#

Bases: OVModel

Hierarchical classification model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX classification model compatible for OTX testing pipeline.

Initialize the hierarchical classification model.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model.

  • model_type (str) – Type of the model (default: “Classification”).

  • async_inference (bool) – Whether to enable asynchronous inference (default: True).

  • max_num_requests (int | None) – Maximum number of inference requests (default: None).

  • use_throughput_mode (bool) – Whether to use throughput mode (default: True).

  • model_api_configuration (dict[str, Any] | None) – Configuration for the model API (default: None).

  • metric (MetricCallable) – Metric callable for evaluation (default: HLabelClsMetricCallable).

  • **kwargs – Additional keyword arguments.

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Prepare inputs for metric computation.

Converts predictions and ground truth inputs into a format suitable for metric evaluation.

Parameters:
  • preds (OTXPredBatch) – Predicted batch containing labels and scores.

  • inputs (OTXDataBatch) – Input batch containing ground truth labels.

Returns:

A dictionary with ‘preds’ and ‘target’ keys for metric evaluation.

Return type:

MetricInput

class otx.models.OVInstanceSegmentationModel(model_path: PathLike, model_type: str = 'MaskRCNN', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, **kwargs)[source]#

Bases: OVModel

Instance segmentation model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX detection model compatible for OTX testing pipeline.

Initialize the instance segmentation model.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model.

  • model_type (str) – Type of the model (default: “MaskRCNN”).

  • async_inference (bool) – Whether to use asynchronous inference (default: True).

  • max_num_requests (int | None) – Maximum number of inference requests (default: None).

  • use_throughput_mode (bool) – Whether to use throughput mode (default: True).

  • model_api_configuration (dict[str, Any] | None) – Model API configuration (default: None).

  • metric (MetricCallable) – Metric callable for evaluation (default: MaskRLEMeanAPFMeasureCallable).

  • **kwargs – Additional keyword arguments.

compute_metrics(metric: Metric) dict[source]#

Compute evaluation metrics for the model.

Parameters:

metric (Metric) – Metric object to compute the evaluation metrics.

Returns:

Computed metrics.

Return type:

dict

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Prepare inputs for metric computation.

Converts predictions and ground truth to the format required by the metric and caches the ground truth for the current batch.

Parameters:
  • preds (OTXPredBatch) – Current batch predictions.

  • inputs (OTXDataBatch) – Current batch ground-truth inputs.

Returns:

Dictionary containing predictions and ground truth.

Return type:

MetricInput

class otx.models.OVKeypointDetectionModel(model_path: PathLike, model_type: str = 'keypoint_detection', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _pck_measure_callable>)[source]#

Bases: OVModel

Keypoint detection model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX keypoint detection model compatible for OTX testing pipeline.

Initialize the keypoint detection model.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model.

  • model_type (str) – Type of the model. Defaults to “keypoint_detection”.

  • async_inference (bool) – Whether to enable asynchronous inference. Defaults to True.

  • max_num_requests (int | None) – Maximum number of inference requests. Defaults to None.

  • use_throughput_mode (bool) – Whether to enable throughput mode. Defaults to True.

  • model_api_configuration (dict[str, Any] | None) – Configuration for the model API. Defaults to None.

  • metric (MetricCallable) – Metric callable for evaluation. Defaults to PCKMeasureCallable.

compute_metrics(metric: Metric) dict[source]#

Compute evaluation metrics for the keypoint detection model.

Parameters:

metric (Metric) – Metric object used for evaluation.

Returns:

A dictionary containing computed metric values.

Return type:

dict

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Prepare inputs for metric computation.

Converts prediction and input entities to a format suitable for metric evaluation.

Parameters:
  • preds (OTXPredBatch) – The predicted batch entity containing predicted keypoints.

  • inputs (OTXDataBatch) – The input batch entity containing ground truth keypoints.

Returns:

A dictionary containing ‘preds’ and ‘target’ keys corresponding to the predicted and target keypoints for metric evaluation.

Return type:

MetricInput

Raises:

ValueError – If ground truth keypoints, predicted keypoints, or scores are missing, or if the number of predicted and ground truth keypoints does not match.

class otx.models.OVModel(model_path: PathLike, model_type: str, async_inference: bool = True, force_cpu: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _null_metric_callable>)[source]#

Bases: object

Base class for the OpenVINO model.

This is a base class representing the interface for interacting with OpenVINO Intermediate Representation (IR) models. OVModel can create and validate an OpenVINO IR model directly from a local path or from the OpenVINO OMZ repository (only PyTorch models are supported). OVModel supports both synchronous and asynchronous inference.

Parameters:

num_classes – Number of classes this model can predict.

Initialize the OVModel instance.

Parameters:
  • model_path (PathLike) – Path to the model file.

  • model_type (str) – Type of the model.

  • async_inference (bool) – Whether to enable asynchronous inference.

  • force_cpu (bool) – Whether to force the use of CPU.

  • max_num_requests (int | None) – Maximum number of inference requests.

  • use_throughput_mode (bool) – Whether to use throughput mode.

  • model_api_configuration (dict[str, Any] | None) – Configuration for the Model API.

  • metric (MetricCallable) – Metric callable for evaluation.

__call__(*args, **kwds)[source]#

Call the model for inference.

Parameters:
  • *args – Positional arguments.

  • **kwds – Keyword arguments.

Returns:

Model output.

Return type:

Any

compute_metrics(metric: Metric) dict[source]#

Compute metrics using the provided metric object.

Parameters:

metric (Metric) – Metric object.

Returns:

Computed metrics.

Return type:

dict

forward(inputs: OTXDataBatch, async_inference: bool = True) OTXPredBatch[source]#

Perform forward pass of the model.

Parameters:
  • inputs (OTXDataBatch) – Input data batch.

  • async_inference (bool) – Whether to use asynchronous inference.

Returns:

Model predictions.

Return type:

OTXPredBatch

get_dummy_input(batch_size: int = 1) OTXDataBatch[source]#

Generate a dummy input for the model.

Parameters:

batch_size (int) – Batch size for the dummy input.

Returns:

Dummy input data.

Return type:

OTXDataBatch

optimize(output_dir: Path, data_module: OTXDataModule, ptq_config: dict[str, Any] | None = None, optimized_model_name: str = 'optimized_model') Path[source]#

Optimize the model using NNCF quantization.

Parameters:
  • output_dir (Path) – Directory to save the optimized model.

  • data_module (OTXDataModule) – Data module for training data.

  • ptq_config (dict[str, Any] | None) – PTQ configuration.

  • optimized_model_name (str) – Name of the optimized model.

Returns:

Path to the optimized model.

Return type:

Path

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) MetricInput[source]#

Prepare inputs for metric computation.

Parameters:
  • preds (OTXPredBatch) – Predicted batch entity.

  • inputs (OTXDataBatch) – Input batch entity.

Returns:

Dictionary containing predictions and targets.

Return type:

MetricInput

transform_fn(data_batch: OTXDataBatch) array[source]#

Transform data for PTQ.

Parameters:

data_batch (OTXDataBatch) – Input data batch.

Returns:

Transformed data.

Return type:

np.array

property label_info: LabelInfo#

Get label information of the model.

Returns:

Label information.

Return type:

LabelInfo

property model_adapter_parameters: dict#

Get model parameters for export.

Returns:

Model parameters.

Return type:

dict

property task: OTXTaskType | None#

Get the task type of the model.

Returns:

Task type.

Return type:

OTXTaskType | None
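
Example

A hedged sketch of the documented get_dummy_input/forward flow, shown with the multiclass classification subclass; the IR path is hypothetical.

>>> from otx.models import OVMulticlassClassificationModel
>>> ov_model = OVMulticlassClassificationModel(
...     model_path="classifier.xml",  # hypothetical path to an OpenVINO IR file
... )
>>> dummy_batch = ov_model.get_dummy_input(batch_size=1)
>>> preds = ov_model.forward(dummy_batch)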

class otx.models.OVMulticlassClassificationModel(model_path: PathLike, model_type: str = 'Classification', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = False, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _multi_class_cls_metric_callable>)[source]#

Bases: OVModel

Classification model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX classification model compatible for OTX testing pipeline.

Initialize the OVMulticlassClassificationModel.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model or model name from Intel OMZ.

  • model_type (str) – Type of the model. Defaults to “Classification”.

  • async_inference (bool) – Whether to enable asynchronous inference. Defaults to True.

  • max_num_requests (int | None) – Maximum number of inference requests. Defaults to None.

  • use_throughput_mode (bool) – Whether to use throughput mode. Defaults to False.

  • model_api_configuration (dict[str, Any] | None) – Configuration for the model API. Defaults to None.

  • metric (MetricCallable) – Metric callable for evaluation. Defaults to MultiClassClsMetricCallable.

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Prepare inputs for metric computation.

Converts prediction and input entities into a format suitable for metric evaluation.

Parameters:
  • preds (OTXPredBatch) – Predicted batch containing predicted labels and other metadata.

  • inputs (OTXDataBatch) – Input batch containing ground truth labels and other metadata.

Returns:

A dictionary containing ‘preds’ and ‘target’ keys corresponding to predicted and target labels.

Return type:

MetricInput

class otx.models.OVMultilabelClassificationModel(model_path: PathLike, model_type: str = 'Classification', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _multi_label_cls_metric_callable>, **kwargs)[source]#

Bases: OVModel

Multilabel classification model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX classification model compatible for OTX testing pipeline.

Initialize the multilabel classification model.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model or model name from Intel OMZ.

  • model_type (str) – Type of the model. Defaults to “Classification”.

  • async_inference (bool) – Whether to use asynchronous inference. Defaults to True.

  • max_num_requests (int | None) – Maximum number of inference requests. Defaults to None.

  • use_throughput_mode (bool) – Whether to use throughput mode. Defaults to True.

  • model_api_configuration (dict[str, Any] | None) – Configuration for the model API. Defaults to None.

  • metric (MetricCallable) – Metric callable for evaluation. Defaults to MultiLabelClsMetricCallable.

  • **kwargs – Additional keyword arguments.

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Prepare inputs for metric computation.

Converts prediction and input entities to a format suitable for metric evaluation.

Parameters:
  • preds (OTXPredBatch) – The predicted batch entity containing predicted labels and scores.

  • inputs (OTXDataBatch) – The input batch entity containing ground truth labels.

Returns:

A dictionary containing ‘preds’ and ‘target’ keys corresponding to the predicted and target labels for metric evaluation.

Return type:

MetricInput

class otx.models.OVSegmentationModel(model_path: PathLike, model_type: str = 'Segmentation', async_inference: bool = True, max_num_requests: int | None = None, use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = <function _segm_callable>, **kwargs)[source]#

Bases: OVModel

Semantic segmentation model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository and create the OTX segmentation model compatible for OTX testing pipeline.

Initialize the OVSegmentationModel.

Parameters:
  • model_path (PathLike) – Path to the OpenVINO IR model.

  • model_type (str) – Type of the model (default: “Segmentation”).

  • async_inference (bool) – Whether to enable asynchronous inference (default: True).

  • max_num_requests (int | None) – Maximum number of inference requests (default: None).

  • use_throughput_mode (bool) – Whether to use throughput mode (default: True).

  • model_api_configuration (dict[str, Any] | None) – Configuration for the model API (default: None).

  • metric (MetricCallable) – Metric callable for evaluation (default: SegmCallable).

  • **kwargs – Additional keyword arguments.

prepare_metric_inputs(preds: OTXPredBatch, inputs: OTXDataBatch) dict[str, Any] | list[dict[str, Any]] | dict[str, list[dict[str, Any]]][source]#

Prepare inputs for metric computation.

Converts predictions and ground truth inputs into a format suitable for metric evaluation.

Parameters:
  • preds (OTXPredBatch) – Predicted segmentation batch containing masks.

  • inputs (OTXDataBatch) – Input batch containing ground truth masks.

Returns:

A list of dictionaries with ‘preds’ and ‘target’ keys for metric evaluation.

Return type:

MetricInput

Raises:

ValueError – If predicted or ground truth masks are not provided.

class otx.models.Padim(data_input_params: DataInputParams, label_info: LabelInfoTypes = AnomalyLabelInfo(label_names=['Normal', 'Anomaly'], label_ids=['0', '1'], label_groups=[['Normal', 'Anomaly']]), backbone: str = 'resnet18', layers: list[str] = ['layer1', 'layer2', 'layer3'], pre_trained: bool = True, n_features: int | None = None, task: Literal[OTXTaskType.ANOMALY, OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION] = OTXTaskType.ANOMALY_CLASSIFICATION)[source]#

Bases: AnomalyMixin, Padim, OTXAnomaly

OTX Padim model.

Parameters:
  • backbone (str, optional) – Feature extractor backbone. Defaults to “resnet18”.

  • layers (list[str], optional) – Feature extractor layers. Defaults to [“layer1”, “layer2”, “layer3”].

  • pre_trained (bool, optional) – Pretrained backbone. Defaults to True.

  • n_features (int | None, optional) – Number of features. Defaults to None.

  • task (Literal[OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION], optional) – Task type of the anomaly task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (256, 256)
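
Example

A minimal construction sketch, assuming data_input_params accepts a dict of input size, mean, and std as elsewhere in this module; all values below are illustrative.

>>> from otx.models import Padim
>>> model = Padim(
...     data_input_params={"input_size": (256, 256),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     backbone="resnet18",
...     layers=["layer1", "layer2", "layer3"],
... )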

class otx.models.RTDETR(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['rtdetr_18', 'rtdetr_50', 'rtdetr_101'] = 'rtdetr_50', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXDetectionModel

OTX Detection model class for RTDETR.

pretrained_weights#

Dictionary containing URLs for pretrained weights.

Type:

ClassVar[dict[str, str]]

input_size_multiplier#

Multiplier for the input size.

Type:

int

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (Literal, optional) – Name of the model to use. Defaults to “rtdetr_50”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable.

  • multi_scale (bool, optional) – Whether to use multi-scale training. Defaults to False.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

configure_optimizers() tuple[list[Optimizer], list[dict[str, Any]]][source]#

Configure an optimizer and learning-rate schedulers.

Configure an optimizer and learning-rate schedulers from the optimizer and scheduler (or scheduler-list) callables given in the constructor. Generally, there are two LR schedulers: a linear warmup scheduler and the main scheduler that takes over after the warmup period.

Returns:

Two lists: the former contains the optimizer, and the latter contains the LR scheduler configurations in dictionary format.
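
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count and input values below are illustrative.

>>> from otx.models import RTDETR
>>> model = RTDETR(
...     label_info=3,
...     data_input_params={"input_size": (640, 640),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="rtdetr_50",
... )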

class otx.models.RTMDet(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['rtmdet_tiny'] = 'rtmdet_tiny', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXDetectionModel

OTX Detection model class for RTMDet.

pretrained_weights#

Dictionary containing URLs for pretrained weights.

Type:

ClassVar[dict[str, str]]

input_size_multiplier#

Multiplier for the input size.

Type:

int

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (str, optional) – Name of the model to use. Defaults to “rtmdet_tiny”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None
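
Example

A minimal construction sketch, assuming a dict is accepted for data_input_params as described in the base initializer; the label count, input size, and normalization values are illustrative.

>>> from otx.models import RTMDet
>>> model = RTMDet(
...     label_info=3,
...     data_input_params={"input_size": (640, 640),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="rtmdet_tiny",
... )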

class otx.models.RTMDetInst(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['rtmdet_inst_tiny'] = 'rtmdet_inst_tiny', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _rle_mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXInstanceSegModel

Implementation of RTMDetInst for instance segmentation.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels used in the model.

  • data_input_params (DataInputParams) – Parameters for the data input.

  • model_name (str, optional) – Name of the model. Defaults to “rtmdet_inst_tiny”.

  • optimizer (OptimizerCallable, optional) – Optimizer for the model. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Scheduler for the model. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Metric for evaluating the model. Defaults to MaskRLEMeanAPFMeasureCallable.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

  • explain_mode (bool, optional) – Whether to enable explainable AI mode. Defaults to False.

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

forward_for_tracing(inputs: Tensor) tuple[Tensor, ...][source]#

Forward function for export.

NOTE: RTMDetInst uses explain_mode unlike other models.

class otx.models.RTMPose(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['rtmpose_tiny'] = 'rtmpose_tiny', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _pck_measure_callable>, torch_compile: bool = False)[source]#

Bases: OTXKeypointDetectionModel

RTMPose Model.

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None
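
Example (a hedged instantiation sketch; the keypoint count, input size, and normalization values are illustrative assumptions):

>>> model = RTMPose(
...     label_info=17,  # e.g., 17 keypoints as in COCO; illustrative
...     data_input_params={"input_size": (256, 192),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="rtmpose_tiny",
... )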

class otx.models.SSD(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['ssd_mobilenetv2'] = 'ssd_mobilenetv2', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mean_ap_f_measure_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXDetectionModel

OTX Detection model class for SSD.

pretrained_weights#

Dictionary containing URLs for pretrained weights.

Type:

ClassVar[dict[str, str]]

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (str, optional) – Name of the model to use. Defaults to “ssd_mobilenetv2”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable.

  • torch_compile (bool, optional) – Whether to use torch compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict[source]#

Load a checkpoint from OTX v1 and adapt it to the OTX 2.0 format.

load_state_dict_pre_hook(state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) → None[source]#

Modify the input state_dict according to class-name matching; this is used for incremental learning.

on_load_checkpoint(checkpoint: dict[str, Any]) → None[source]#

Callback on load checkpoint.

setup(stage: str) → None[source]#

Callback to set up the OTX SSD model.

The OTX SSD model requires automatic anchor generation with respect to the training dataset for better accuracy; this callback provides the training dataset to the model's anchor generator.

Parameters:

trainer (Trainer) – Lightning trainer containing the OTXLitModule and OTXDatamodule.
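
Example (a hedged instantiation sketch; the class count, input size, and normalization values are illustrative, and anchor statistics are adapted from the training dataset during setup as described above):

>>> model = SSD(
...     label_info=5,  # illustrative number of object classes
...     data_input_params={"input_size": (640, 640),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="ssd_mobilenetv2",
... )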

class otx.models.SegNext(label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal['segnext_tiny', 'segnext_small', 'segnext_base'] = 'segnext_small', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _segm_callable>, torch_compile: bool = False, tile_config: TileConfig = TileConfig(enable_tiler=False, enable_adaptive_tiling=True, tile_size=(400, 400), overlap=0.2, iou_threshold=0.45, max_num_instances=1500, object_tile_ratio=0.03, sampling_ratio=1.0, with_full_img=False))[source]#

Bases: OTXSegmentationModel

SegNext Model.

Parameters:
  • label_info (LabelInfoTypes) – Information about the labels.

  • data_input_params (DataInputParams) – Parameters for data input.

  • model_name (Literal, optional) – Name of the model. Defaults to “segnext_small”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to SegmCallable.

  • torch_compile (bool, optional) – Flag to indicate whether to use torch.compile. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Initialize the base model with the given parameters.

Parameters:
  • label_info (LabelInfoTypes | int | Sequence) – Information about the labels used in the model. If int is given, label info will be constructed from number of classes, if Sequence is given, label info will be constructed from the sequence of label names.

  • data_input_params (DataInputParams | dict) – Parameters of the input data such as input size, mean, and std.

  • model_name (str, optional) – Name of the model. Defaults to “OTXModel”.

  • optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – Callable for the metric. Defaults to NullMetricCallable.

  • torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.

  • tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False).

Returns:

None

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.model.') → dict[source]#

Load a checkpoint from OTX v1 and adapt it to the OTX 2.0 format.
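
Example (a hedged instantiation sketch; the class count, input size, and normalization values are illustrative):

>>> model = SegNext(
...     label_info=21,  # illustrative number of semantic classes
...     data_input_params={"input_size": (512, 512),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="segnext_small",
... )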

class otx.models.Stfpm(data_input_params: DataInputParams, label_info: LabelInfoTypes = AnomalyLabelInfo(label_names=['Normal', 'Anomaly'], label_ids=['0', '1'], label_groups=[['Normal', 'Anomaly']]), layers: Sequence[str] = ['layer1', 'layer2', 'layer3'], backbone: str = 'resnet18', task: Literal[OTXTaskType.ANOMALY, OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION] = OTXTaskType.ANOMALY_CLASSIFICATION, **kwargs)[source]#

Bases: AnomalyMixin, Stfpm, OTXAnomaly

OTX STFPM model.

Parameters:
  • layers (Sequence[str]) – Feature extractor layers.

  • backbone (str, optional) – Feature extractor backbone. Defaults to “resnet18”.

  • task (Literal[OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION], optional) – Task type of the anomaly task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (256, 256)
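
Example (a hedged instantiation sketch; the normalization values are illustrative, the dict form of data_input_params is assumed to be accepted as for the other model classes, and label_info keeps its anomaly default):

>>> model = Stfpm(
...     data_input_params={"input_size": (256, 256),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     layers=["layer1", "layer2", "layer3"],
...     backbone="resnet18",
... )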

class otx.models.TVModel(label_info: LabelInfoTypes, data_input_params: DataInputParams, task: Literal['multi_class', 'multi_label', 'h_label'] = 'multi_class', model_name: str = 'efficientnet_v2_s', freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False)[source]#

Bases: object

Factory class for Torch Vision models.

Factory to create TorchVision models based on the task type.

This class allows users to create models for multi-class, multi-label, or hierarchical label classification by specifying the task parameter. You can select any model available in the TorchVision library (over 40 models as of 2025) by providing its name to the model_name parameter. To explore all available models, use torchvision.models.list_models() or TVModel.list_models().

Parameters:
  • label_info (LabelInfoTypes) – The label information.

  • data_input_params (DataInputParams | dict) – The data input parameters, consisting of input size, mean, and std.

  • freeze_backbone (bool, optional) – Whether to freeze the backbone during training. Note: only multiclass classification supports this argument. Defaults to False.

  • model_name (str, optional) – The model name. Defaults to “efficientnet_v2_s”.

  • task (Literal["multi_class", "multi_label", "h_label"], optional) – The task type. Can be “multi_class”, “multi_label”, or “h_label”. Defaults to “multi_class”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using torch.compile. Defaults to False.

Examples

>>> # Basic usage
>>> model = TVModel(
...     task="multi_class",
...     label_info=10,
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="efficientnet_v2_s",
... )
>>> # Multi-label classification
>>> model = TVModel(
...     task="multi_label",
...     model_name="mobilenet_v3_small",
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     label_info=[1, 5, 10]  # Multi-label setup
... )
static list_models() → list[str][source]#

List available Torch Vision models.
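
A hedged usage sketch for browsing the available backbone names before constructing a TVModel (the substring filter is only an illustration):

>>> names = TVModel.list_models()
>>> efficientnets = [name for name in names if "efficientnet" in name]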

class otx.models.TimmModel(label_info: LabelInfoTypes, data_input_params: DataInputParams, task: Literal['multi_class', 'multi_label', 'h_label'] = 'multi_class', model_name: str = 'tf_efficientnetv2_s.in21k', freeze_backbone: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False)[source]#

Bases: object

Factory class for TimmModel models.

Factory method to create Timm models based on the task type.

This class allows users to create models for multi-class, multi-label, or hierarchical label classification by specifying the task parameter. Users can select any model available in the Timm library (over 900 models as of 2025) by providing its name to the model_name parameter. To explore all available models, use timm.list_models() or TimmModel.list_models().

Note: If you wish to use Vision Transformer (ViT) models, it is recommended to use the VisionTransformer implementation provided by OTX for better integration and support.

Parameters:
  • label_info (LabelInfoTypes) – The label information.

  • data_input_params (DataInputParams | dict) – The data input parameters, consisting of input size, mean, and std.

  • freeze_backbone (bool, optional) – Whether to freeze the backbone during training. Note: only multiclass classification supports this argument. Defaults to False.

  • model_name (str, optional) – The model name. Defaults to “tf_efficientnetv2_s.in21k”. You can find all available models via timm.list_models() or TimmModel.list_models().

  • task (Literal["multi_class", "multi_label", "h_label"], optional) – The task type. Can be “multi_class”, “multi_label”, or “h_label”. Defaults to “multi_class”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using torch.compile. Defaults to False.

Examples

>>> # Basic usage
>>> model = TimmModel(
...     task="multi_class",
...     label_info=10,
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="tf_efficientnetv2_s.in21k",
... )
>>> # Multi-label classification
>>> model = TimmModel(
...     task="multi_label",
...     model_name="tf_efficientnetv2_s.in21k",
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     label_info=[1, 5, 10]  # Multi-label setup
... )
static list_models() → list[str][source]#

List available Timm models.

class otx.models.Uflow(data_input_params: DataInputParams, label_info: LabelInfoTypes = AnomalyLabelInfo(label_names=['Normal', 'Anomaly'], label_ids=['0', '1'], label_groups=[['Normal', 'Anomaly']]), backbone: str = 'resnet18', flow_steps: int = 4, affine_clamp: float = 2.0, affine_subnet_channels_ratio: float = 1.0, permute_soft: bool = False, task: Literal[OTXTaskType.ANOMALY, OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION] = OTXTaskType.ANOMALY_CLASSIFICATION)[source]#

Bases: AnomalyMixin, Uflow, OTXAnomaly

OTX UFlow model.

Parameters:
  • label_info (LabelInfoTypes, optional) – Label information. Defaults to AnomalyLabelInfo().

  • backbone (str, optional) – Feature extractor backbone. Defaults to “resnet18”.

  • flow_steps (int, optional) – Number of flow steps. Defaults to 4.

  • affine_clamp (float, optional) – Affine clamp. Defaults to 2.0.

  • affine_subnet_channels_ratio (float, optional) – Affine subnet channels ratio. Defaults to 1.0.

  • permute_soft (bool, optional) – Whether to use soft permutation. Defaults to False.

  • task (Literal[OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION], optional) – Task type of the anomaly task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (256, 256)
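
Example (a hedged instantiation sketch; the normalization values are illustrative, the dict form of data_input_params is assumed to be accepted as for the other model classes, and the remaining arguments keep their defaults):

>>> model = Uflow(
...     data_input_params={"input_size": (256, 256),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     backbone="resnet18",
...     flow_steps=4,
... )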

class otx.models.VisionTransformer(label_info: LabelInfoTypes, data_input_params: DataInputParams, task: Literal['multi_class', 'multi_label', 'h_label'] = 'multi_class', model_name: Literal['vit-tiny', 'vit-small', 'vit-base', 'vit-large', 'dinov2-small', 'dinov2-base', 'dinov2-large', 'dinov2-giant'] = 'vit-tiny', freeze_backbone: bool = False, lora: bool = False, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False)[source]#

Bases: object

Factory class for VisionTransformer models.

Factory to create VisionTransformer models based on the task type.

This class supports multi-class, multi-label, and hierarchical label classification tasks. It provides ViT backbones (tiny to large) and DINOv2 backbones (small to giant).

Parameters:
  • label_info (LabelInfoTypes) – The label information.

  • data_input_params (DataInputParams | dict) – The data input parameters, consisting of input size, mean, and std.

  • freeze_backbone (bool, optional) – Whether to freeze the backbone during training. Note: only multiclass classification supports this argument. Defaults to False.

  • (Literal["vit-tiny" (model_name) – “dinov2-small”, “dinov2-base”, “dinov2-large”, “dinov2-giant”], optional): The model name. Defaults to “vit-tiny”.

  • "vit-small" – “dinov2-small”, “dinov2-base”, “dinov2-large”, “dinov2-giant”], optional): The model name. Defaults to “vit-tiny”.

  • "vit-base" – “dinov2-small”, “dinov2-base”, “dinov2-large”, “dinov2-giant”], optional): The model name. Defaults to “vit-tiny”.

  • "vit-large" – “dinov2-small”, “dinov2-base”, “dinov2-large”, “dinov2-giant”], optional): The model name. Defaults to “vit-tiny”.

:param“dinov2-small”, “dinov2-base”, “dinov2-large”, “dinov2-giant”], optional):

The model name. Defaults to “vit-tiny”.

  • task (Literal["multi_class", "multi_label", "h_label"], optional) – The task type. Can be “multi_class”, “multi_label”, or “h_label”. Defaults to “multi_class”.

  • optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.

  • metric (MetricCallable, optional) – The metric callable. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using torch.compile. Defaults to False.

Examples

>>> # Basic usage
>>> model = VisionTransformer(
...     task="multi_class",
...     label_info=10,
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     model_name="vit-tiny",
... )
>>> # Multi-label classification
>>> model = VisionTransformer(
...     task="multi_label",
...     model_name="vit-small",
...     data_input_params={"input_size": (224, 224),
...                        "mean": [123.675, 116.28, 103.53],
...                        "std": [58.395, 57.12, 57.375]},
...     label_info=[1, 5, 10]  # Multi-label setup
... )