otx.core.model.keypoint_detection
Class definition for keypoint detection model entity used in OTX.
Classes
OTXKeypointDetectionModel | Base class for the keypoint detection models used in OTX.
- class otx.core.model.keypoint_detection.OTXKeypointDetectionModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _pck_measure_callable>, torch_compile: bool = False)
Bases: OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]
Base class for the keypoint detection models used in OTX.
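The default `metric` is `_pck_measure_callable`, which suggests the model is evaluated with PCK (Percentage of Correct Keypoints). As a rough illustration of what that metric measures, here is a minimal dependency-free sketch. The function name `pck`, the per-instance normalization length `norm` (for example, a bounding-box size), and the default threshold `thr=0.05` are assumptions for this example, not OTX's implementation.

```python
import math
from typing import Sequence

Point = tuple[float, float]


def pck(
    preds: Sequence[Sequence[Point]],
    targets: Sequence[Sequence[Point]],
    visible: Sequence[Sequence[bool]],
    norm: Sequence[float],
    thr: float = 0.05,
) -> float:
    """Percentage of Correct Keypoints over all visible keypoints (illustrative sketch).

    A keypoint counts as correct when its Euclidean distance to the ground truth
    is at most thr * norm[i], where norm[i] is an assumed per-instance
    normalization length.
    """
    correct = 0
    total = 0
    for pred_kpts, gt_kpts, vis, scale in zip(preds, targets, visible, norm):
        for (px, py), (gx, gy), v in zip(pred_kpts, gt_kpts, vis):
            if not v:
                continue  # invisible keypoints are excluded from the score
            total += 1
            if math.hypot(px - gx, py - gy) <= thr * scale:
                correct += 1
    # Return 0.0 when there are no visible keypoints to avoid division by zero.
    return correct / total if total else 0.0
```

With `thr=0.05` and a normalization length of 100 pixels, a predicted keypoint counts as correct only if it lies within 5 pixels of its ground-truth location.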
- forward_for_tracing(image: Tensor) → Tensor | tuple[Tensor]
Model forward function used for tracing the model during export.
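Tracing and export tools generally expect plain tensors (or tuples of tensors) as inputs and outputs, which is presumably why a separate forward path exists alongside the regular `forward()` that works with batch entities. The toy, dependency-free sketch below shows that split, with nested lists standing in for tensors. All class names here (`BatchEntity`, `PredEntity`, `ToyKeypointModel`) are hypothetical and do not mirror OTX internals.

```python
from dataclasses import dataclass, field


@dataclass
class BatchEntity:
    """Hypothetical stand-in for a batch data entity (image data plus metadata)."""

    images: list[list[float]]  # stand-in for a stacked image tensor
    img_infos: list[dict] = field(default_factory=list)


@dataclass
class PredEntity:
    """Hypothetical stand-in for a prediction entity."""

    keypoints: list[list[float]]
    img_infos: list[dict] = field(default_factory=list)


class ToyKeypointModel:
    """Illustrates a regular forward path next to a tracing-friendly one."""

    def _core(self, image: list[list[float]]) -> list[list[float]]:
        # Placeholder for the network itself: here it just doubles every value.
        return [[2.0 * v for v in row] for row in image]

    def forward(self, batch: BatchEntity) -> PredEntity:
        # Regular path: consumes and produces rich entity objects.
        return PredEntity(keypoints=self._core(batch.images), img_infos=batch.img_infos)

    def forward_for_tracing(self, image: list[list[float]]) -> list[list[float]]:
        # Export path: plain data in, plain data out, no entity wrapping,
        # so a tracer would only ever see array-like values.
        return self._core(image)
```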
- get_classification_layers(prefix: str = 'model.') → dict[str, dict[str, int]]
Get information about the final classification layers, for the incremental learning case.
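The return type `dict[str, dict[str, int]]` maps a (prefixed) parameter name to a small dictionary of integer properties. One plausible way to find class-dependent layers is to compare parameter shapes of the same architecture built with two different class counts. The sketch below does that on plain shape dictionaries. The function name `find_classification_layers`, the inner keys `stride` and `num_extra_classes`, and the assumed layout `dim0 == stride * (num_classes + num_extra_classes)` are all illustrative assumptions, not a description of OTX's method.

```python
def find_classification_layers(
    shapes_n: dict[str, tuple[int, ...]],
    shapes_n_plus_1: dict[str, tuple[int, ...]],
    num_classes: int,
    prefix: str = "model.",
) -> dict[str, dict[str, int]]:
    """Locate parameters whose leading dimension depends on the number of classes.

    shapes_n and shapes_n_plus_1 map parameter names to shapes for the same
    (hypothetical) architecture built with num_classes and num_classes + 1
    classes. Assumed layout: dim0 == stride * (num_classes + num_extra_classes).
    """
    layers: dict[str, dict[str, int]] = {}
    for name, shape in shapes_n.items():
        other = shapes_n_plus_1[name]
        if shape == other:
            continue  # shape does not change with the class count
        stride = other[0] - shape[0]  # rows added per additional class
        num_extra_classes = shape[0] // stride - num_classes  # e.g. a background slot
        layers[prefix + name] = {"stride": stride, "num_extra_classes": num_extra_classes}
    return layers
```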
- get_dummy_input(batch_size: int = 1) → KeypointDetBatchDataEntity
Generate a dummy input suitable for running forward() on it.
- Parameters:
batch_size (int, optional) – Number of elements in the dummy input batch. Defaults to 1.
- Returns:
An entity containing randomly generated inference data.
- Return type:
KeypointDetBatchDataEntity
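To make the shape of such a dummy input concrete, here is a dependency-free sketch that builds a batch of random images from `batch_size` and the model's `input_size`. `DummyBatch` and `make_dummy_batch` are hypothetical stand-ins for `KeypointDetBatchDataEntity` and `get_dummy_input`. Interpreting `input_size` as `(height, width)`, using 3 channels, and the `img_infos` keys are assumptions for illustration; a real implementation would return tensors rather than nested lists.

```python
import random
from dataclasses import dataclass


@dataclass
class DummyBatch:
    """Hypothetical minimal stand-in for KeypointDetBatchDataEntity."""

    batch_size: int
    images: list[list[list[list[float]]]]  # indexed as [batch][channel][height][width]
    img_infos: list[dict]


def make_dummy_batch(
    batch_size: int = 1,
    input_size: tuple[int, int] = (4, 4),
    channels: int = 3,
) -> DummyBatch:
    """Build a batch of uniformly random images in [0, 1) plus minimal metadata."""
    height, width = input_size  # assumed (height, width) ordering
    images = [
        [[[random.random() for _ in range(width)] for _ in range(height)] for _ in range(channels)]
        for _ in range(batch_size)
    ]
    infos = [{"img_idx": i, "img_shape": (height, width)} for i in range(batch_size)]
    return DummyBatch(batch_size=batch_size, images=images, img_infos=infos)
```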