# Source code for otx.backend.native.models.anomaly.uflow

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OTX UFlow model."""

# mypy: ignore-errors

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from anomalib.models.image import Uflow as AnomalibUflow

from otx.backend.native.models.anomaly.base import AnomalyMixin, OTXAnomaly
from otx.backend.native.models.base import DataInputParams
from otx.types.label import AnomalyLabelInfo
from otx.types.task import OTXTaskType

if TYPE_CHECKING:
    from otx.types.label import LabelInfoTypes


class Uflow(AnomalyMixin, AnomalibUflow, OTXAnomaly):
    """OTX UFlow model.

    Combines anomalib's ``Uflow`` implementation with OTX's ``AnomalyMixin`` and
    ``OTXAnomaly`` bases. The UFlow-specific hyper-parameters are passed as
    keyword arguments to ``super().__init__()``.

    Args:
        data_input_params (DataInputParams): Data input parameters. The object is
            stored as ``self.data_input_params``, and its ``input_size`` attribute
            is stored as ``self.input_size`` (presumably in height, width order;
            confirm against ``DataInputParams``).
        label_info (LabelInfoTypes, optional): Label information. It is accepted
            for interface compatibility but is not stored or forwarded by this
            constructor. Defaults to AnomalyLabelInfo().
        backbone (str, optional): Feature extractor backbone. Defaults to "resnet18".
        flow_steps (int, optional): Number of flow steps. Defaults to 4.
        affine_clamp (float, optional): Affine clamp. Defaults to 2.0.
        affine_subnet_channels_ratio (float, optional): Affine subnet channels
            ratio. Defaults to 1.0.
        permute_soft (bool, optional): Whether to use soft permutation.
            Defaults to False.
        task (Literal[OTXTaskType.ANOMALY, OTXTaskType.ANOMALY_CLASSIFICATION,
            OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION],
            optional): Task type of the anomaly task. It is normalized with
            ``OTXTaskType(task)`` and stored as ``self.task``.
            Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
    """

    def __init__(
        self,
        data_input_params: DataInputParams,
        # NOTE: this default is evaluated once, when ``__init__`` is defined, and
        # is shared by every call. That is harmless here because ``label_info``
        # is never read in this constructor.
        label_info: LabelInfoTypes = AnomalyLabelInfo(),
        backbone: str = "resnet18",
        flow_steps: int = 4,
        affine_clamp: float = 2.0,
        affine_subnet_channels_ratio: float = 1.0,
        permute_soft: bool = False,
        task: Literal[
            OTXTaskType.ANOMALY,
            OTXTaskType.ANOMALY_CLASSIFICATION,
            OTXTaskType.ANOMALY_DETECTION,
            OTXTaskType.ANOMALY_SEGMENTATION,
        ] = OTXTaskType.ANOMALY_CLASSIFICATION,
    ) -> None:
        # These attributes are assigned before ``super().__init__()``. Presumably
        # one of the parent initializers reads them, so confirm before reordering.
        self.data_input_params = data_input_params
        self.input_size = data_input_params.input_size
        # Normalize the incoming value to an ``OTXTaskType`` member.
        self.task = OTXTaskType(task)
        super().__init__(
            backbone=backbone,
            flow_steps=flow_steps,
            affine_clamp=affine_clamp,
            affine_subnet_channels_ratio=affine_subnet_channels_ratio,
            permute_soft=permute_soft,
        )