Source code for otx.backend.native.models.keypoint_detection.rtmpose

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""RTMPose model implementations."""

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Literal

from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.detection.backbones import CSPNeXt
from otx.backend.native.models.keypoint_detection.base import OTXKeypointDetectionModel
from otx.backend.native.models.keypoint_detection.detectors.topdown import TopdownPoseEstimator
from otx.backend.native.models.keypoint_detection.heads.rtmcc_head import RTMCCHead
from otx.backend.native.models.keypoint_detection.losses.kl_discret_loss import KLDiscretLoss
from otx.backend.native.models.utils.utils import load_checkpoint
from otx.metrics.pck import PCKMeasureCallable

if TYPE_CHECKING:
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
    from torch import nn

    from otx.backend.native.exporter.base import OTXModelExporter
    from otx.backend.native.schedulers import LRSchedulerListCallable
    from otx.metrics import MetricCallable
    from otx.types.export import TaskLevelExportParameters
    from otx.types.label import LabelInfoTypes


[docs] class RTMPose(OTXKeypointDetectionModel): """RTMPose Model.""" pretrained_weights: ClassVar[dict[str, str]] = { "rtmpose_tiny": "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth", } def __init__( self, label_info: LabelInfoTypes, data_input_params: DataInputParams, model_name: Literal["rtmpose_tiny"] = "rtmpose_tiny", optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = PCKMeasureCallable, torch_compile: bool = False, ) -> None: super().__init__( label_info=label_info, data_input_params=data_input_params, model_name=model_name, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, ) def _create_model(self, num_classes: int | None = None) -> nn.Module: num_classes = num_classes if num_classes is not None else self.num_classes backbone = CSPNeXt(model_name=self.model_name) head = RTMCCHead( out_channels=num_classes, in_channels=384, input_size=self.data_input_params.input_size, in_featuremap_size=(self.data_input_params.input_size[0] // 32, self.data_input_params.input_size[1] // 32), simcc_split_ratio=2.0, final_layer_kernel_size=7, loss=KLDiscretLoss(use_target_weight=True, beta=10.0, label_softmax=True), decoder_cfg={ "input_size": self.data_input_params.input_size, "simcc_split_ratio": 2.0, "sigma": (4.9, 5.66), "normalize": False, "use_dark": False, }, gau_cfg={ "num_token": num_classes, "in_token_dims": 256, "out_token_dims": 256, "s": 128, "expansion_factor": 2, "dropout_rate": 0.0, "drop_path": 0.0, "act_fn": "SiLU", "use_rel_bias": False, "pos_enc": False, }, ) model = TopdownPoseEstimator( backbone=backbone, head=head, ) model.init_weights() load_checkpoint(model, self.pretrained_weights[self.model_name], map_location="cpu") return model @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" if self.explain_mode: msg = "Export with explain is not supported for RTMPose model." raise ValueError(msg) return OTXNativeModelExporter( task_level_export_parameters=self._export_parameters, data_input_params=self.data_input_params, resize_mode="fit_to_window", pad_value=0, swap_rgb=False, via_onnx=True, onnx_export_configuration={ "input_names": ["image"], "dynamic_axes": { "image": {0: "batch"}, "pred_x": {0: "batch"}, "pred_y": {0: "batch"}, }, "autograd_inlining": False, }, output_names=["pred_x", "pred_y"], ) @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" return super()._export_parameters.wrap(optimization_config={"preset": "mixed"})