# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for keypoint detection model entity used in OTX."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import torch
from otx.backend.openvino.models.base import OVModel
from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
from otx.metrics import MetricCallable, MetricInput
from otx.metrics.pck import PCKMeasureCallable
from otx.types.task import OTXTaskType
if TYPE_CHECKING:
from model_api.models.result import DetectedKeypoints
from torchmetrics import Metric
from otx.types import PathLike
[docs]
class OVKeypointDetectionModel(OVModel):
"""Keypoint detection model compatible for OpenVINO IR inference.
It can consume OpenVINO IR model path or model name from Intel OMZ repository
and create the OTX keypoint detection model compatible for OTX testing pipeline.
"""
def __init__(
self,
model_path: PathLike,
model_type: str = "keypoint_detection",
async_inference: bool = True,
max_num_requests: int | None = None,
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
metric: MetricCallable = PCKMeasureCallable,
) -> None:
"""Initialize the keypoint detection model.
Args:
model_path (PathLike): Path to the OpenVINO IR model.
model_type (str): Type of the model. Defaults to "keypoint_detection".
async_inference (bool): Whether to enable asynchronous inference. Defaults to True.
max_num_requests (int | None): Maximum number of inference requests. Defaults to None.
use_throughput_mode (bool): Whether to enable throughput mode. Defaults to True.
model_api_configuration (dict[str, Any] | None): Configuration for the model API. Defaults to None.
metric (MetricCallable): Metric callable for evaluation. Defaults to PCKMeasureCallable.
"""
super().__init__(
model_path=model_path,
model_type=model_type,
async_inference=async_inference,
max_num_requests=max_num_requests,
use_throughput_mode=use_throughput_mode,
model_api_configuration=model_api_configuration,
metric=metric,
)
self._task = OTXTaskType.KEYPOINT_DETECTION
def _customize_outputs(
self,
outputs: list[DetectedKeypoints],
inputs: OTXDataBatch,
) -> OTXPredBatch:
"""Customize the outputs of the model for keypoint detection.
Args:
outputs (list[DetectedKeypoints]): List of detected keypoints from the model.
inputs (OTXDataBatch): Input batch containing images and metadata.
Returns:
OTXPredBatch: A batch containing processed keypoints, scores, and other metadata.
"""
keypoints = []
scores = []
# default visibility threshold
visibility_threshold = 0.5
for output in outputs:
kps = torch.as_tensor(output.keypoints)
score = torch.as_tensor(output.scores)
visible_keypoints = torch.cat([kps, score.unsqueeze(1) > visibility_threshold], dim=1)
keypoints.append(visible_keypoints)
scores.append(score)
return OTXPredBatch(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
keypoints=keypoints,
scores=scores,
bboxes=[],
labels=[],
)
[docs]
def compute_metrics(self, metric: Metric) -> dict:
"""Compute evaluation metrics for the keypoint detection model.
Args:
metric (Metric): Metric object used for evaluation.
Returns:
dict: A dictionary containing computed metric values.
"""
metric.input_size = (self.model.h, self.model.w)
return super()._compute_metrics(metric)