Source code for otx.backend.native.models.segmentation.segnext

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""SegNext model implementations."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.segmentation.backbones import MSCAN
from otx.backend.native.models.segmentation.base import OTXSegmentationModel
from otx.backend.native.models.segmentation.heads import LightHamHead
from otx.backend.native.models.segmentation.losses import CrossEntropyLossWithIgnore
from otx.backend.native.models.segmentation.segmentors import BaseSegmentationModel
from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
from otx.config.data import TileConfig
from otx.metrics.dice import SegmCallable

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
    from torch import nn

    from otx.backend.native.schedulers import LRSchedulerListCallable
    from otx.metrics import MetricCallable
    from otx.types.label import LabelInfoTypes


class SegNext(OTXSegmentationModel):
    """SegNext Model.

    Args:
        label_info (LabelInfoTypes): Information about the labels used by the model.
        data_input_params (DataInputParams): Parameters for data input.
        model_name (Literal, optional): Name of the model. Defaults to "segnext_small".
        optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate
            scheduler. Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Callable for the metric. Defaults to SegmCallable.
        torch_compile (bool, optional): Flag to indicate whether to use torch.compile. Defaults to False.
        tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False).
    """

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: Literal["segnext_tiny", "segnext_small", "segnext_base"] = "segnext_small",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = SegmCallable,  # type: ignore[assignment]
        torch_compile: bool = False,
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ):
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )

    def _create_model(self, num_classes: int | None = None) -> nn.Module:
        # Build the MSCAN backbone, LightHam decode head, and ignore-aware loss,
        # then assemble them into the segmentor.
        num_classes = num_classes if num_classes is not None else self.num_classes
        backbone = MSCAN(model_name=self.model_name)
        decode_head = LightHamHead(model_name=self.model_name, num_classes=num_classes)
        criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index)  # type: ignore[attr-defined]
        return BaseSegmentationModel(
            backbone=backbone,
            decode_head=decode_head,
            criterion=criterion,
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
        """Load an OTX v1 checkpoint and convert it to the OTX v2 format."""
        return OTXv1Helper.load_seg_segnext_ckpt(state_dict, add_prefix)

    @property
    def _optimization_config(self) -> dict[str, Any]:
        """PTQ config for SegNext."""
        # TODO(Kirill): check PTQ removing hamburger from ignored_scope
        return {
            "ignored_scope": {
                "patterns": ["__module.model.decode_head.hamburger*"],
                "types": [
                    "Add",
                    "MVN",
                    "Divide",
                    "Multiply",
                ],
            },
        }
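
A minimal usage sketch follows (illustrative only, not part of the module above). The DataInputParams field names used here (input_size, mean, std) and the list-of-names form of label_info are assumptions for illustration; check the definitions in otx.backend.native.models.base and otx.types.label for the authoritative signatures.

from otx.backend.native.models.base import DataInputParams
from otx.backend.native.models.segmentation.segnext import SegNext

# Hypothetical label set; normalization values are the common ImageNet statistics.
model = SegNext(
    label_info=["background", "foreground"],
    data_input_params=DataInputParams(  # assumed constructor fields
        input_size=(512, 512),
        mean=(123.675, 116.28, 103.53),
        std=(58.395, 57.12, 57.375),
    ),
    model_name="segnext_small",
)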