Source code for otx.backend.native.models.detection.ssd

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Copyright (c) OpenMMLab. All rights reserved.

"""SSD object detector for the OTX detection.

Implementation modified from mmdet.models.detectors.single_stage.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/mmdet/models/detectors/single_stage.py
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np
from datumaro.components.annotation import Bbox

from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.common.utils.assigners import MaxIoUAssigner
from otx.backend.native.models.common.utils.coders import DeltaXYWHBBoxCoder
from otx.backend.native.models.detection.base import OTXDetectionModel
from otx.backend.native.models.detection.detectors import SingleStageDetector
from otx.backend.native.models.detection.heads import SSDHead
from otx.backend.native.models.detection.losses import SSDCriterion
from otx.backend.native.models.detection.utils.prior_generators import SSDAnchorGeneratorClustered
from otx.backend.native.models.utils.support_otx_v1 import OTXv1Helper
from otx.backend.native.models.utils.utils import load_checkpoint
from otx.config.data import TileConfig
from otx.metrics.fmeasure import MeanAveragePrecisionFMeasureCallable

if TYPE_CHECKING:
    import torch
    from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
    from torch import nn

    from otx.backend.native.schedulers import LRSchedulerListCallable
    from otx.data.dataset.base import OTXDataset
    from otx.metrics import MetricCallable
    from otx.types.label import LabelInfoTypes


logger = logging.getLogger()


class SSD(OTXDetectionModel):
    """OTX Detection model class for SSD.

    Attributes:
        pretrained_weights (ClassVar[dict[str, str]]): Dictionary containing URLs for pretrained weights.

    Args:
        label_info (LabelInfoTypes): Information about the labels.
        data_input_params (DataInputParams): Parameters for data input.
        model_name (str, optional): Name of the model to use. Defaults to "ssd_mobilenetv2".
        optimizer (OptimizerCallable, optional): Callable for the optimizer. Defaults to DefaultOptimizerCallable.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Callable for the learning rate scheduler.
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Callable for the metric. Defaults to MeanAveragePrecisionFMeasureCallable.
        torch_compile (bool, optional): Whether to use torch compile. Defaults to False.
        tile_config (TileConfig, optional): Configuration for tiling. Defaults to TileConfig(enable_tiler=False).
    """

    pretrained_weights: ClassVar[dict[str, str]] = {
        "ssd_mobilenetv2": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/"
        "object_detection/v2/mobilenet_v2-2s_ssd-992x736.pth",
    }

    def __init__(
        self,
        label_info: LabelInfoTypes,
        data_input_params: DataInputParams,
        model_name: Literal["ssd_mobilenetv2"] = "ssd_mobilenetv2",
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = MeanAveragePrecisionFMeasureCallable,
        torch_compile: bool = False,
        tile_config: TileConfig = TileConfig(enable_tiler=False),
    ) -> None:
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            tile_config=tile_config,
        )

    def _create_model(self, num_classes: int | None = None) -> SingleStageDetector:
        num_classes = num_classes if num_classes is not None else self.num_classes
        train_cfg = {
            "assigner": MaxIoUAssigner(
                min_pos_iou=0.0,
                ignore_iof_thr=-1,
                gt_max_assign_all=False,
                pos_iou_thr=0.4,
                neg_iou_thr=0.4,
            ),
            "allowed_border": -1,
            "pos_weight": -1,
            "debug": False,
            "use_giou": False,
            "use_focal": False,
        }
        test_cfg = {
            "nms": {"type": "nms", "iou_threshold": 0.45},
            "min_bbox_size": 0,
            "score_thr": 0.02,
            "max_per_img": 200,
        }
        backbone = self._build_backbone(model_name=self.model_name)
        bbox_head = SSDHead(
            model_name=self.model_name,
            num_classes=num_classes,
            anchor_generator=SSDAnchorGeneratorClustered(
                strides=[16, 32],
                widths=[
                    [38.641007923271076, 92.49516032784699, 271.4234764938237, 141.53469410876247],
                    [206.04136086566515, 386.6542727907841, 716.9892752215089, 453.75609561761405, 788.4629155558277],
                ],
                heights=[
                    [48.9243877087132, 147.73088476194903, 158.23569788707474, 324.14510379107367],
                    [587.6216059488938, 381.60024152086544, 323.5988913027747, 702.7486097568518, 741.4865860938451],
                ],
            ),
            bbox_coder=DeltaXYWHBBoxCoder(
                target_means=(0.0, 0.0, 0.0, 0.0),
                target_stds=(0.1, 0.1, 0.2, 0.2),
            ),
            init_cfg={
                "type": "Xavier",
                "layer": "Conv2d",
                "distribution": "uniform",
            },  # TODO (sungchul, kirill): remove
            train_cfg=train_cfg,  # TODO (sungchul, kirill): remove
            test_cfg=test_cfg,  # TODO (sungchul, kirill): remove
        )
        criterion = SSDCriterion(
            num_classes=num_classes,
            bbox_coder=DeltaXYWHBBoxCoder(
                target_means=(0.0, 0.0, 0.0, 0.0),
                target_stds=(0.1, 0.1, 0.2, 0.2),
            ),
        )
        model = SingleStageDetector(
            backbone=backbone,
            bbox_head=bbox_head,
            criterion=criterion,
            train_cfg=train_cfg,  # TODO (sungchul, kirill): remove
            test_cfg=test_cfg,  # TODO (sungchul, kirill): remove
        )
        model.init_weights()
        load_checkpoint(model, self.pretrained_weights[self.model_name], map_location="cpu")
        return model

    def _build_backbone(self, model_name: str) -> nn.Module:
        if "mobilenetv2" in model_name:
            from otx.backend.native.models.common.backbones import build_model_including_pytorchcv

            return build_model_including_pytorchcv(
                cfg={
                    "type": "mobilenetv2_w1",
                    "out_indices": [4, 5],
                    "frozen_stages": -1,
                    "norm_eval": False,
                    "pretrained": True,
                },
            )
        msg = f"Unknown backbone name: {model_name}"
        raise ValueError(msg)
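    # A worked sketch of the anchor layout configured in _create_model (numbers are
    # illustrative, assuming the 992x736 input implied by the pretrained weight name):
    # strides [16, 32] correspond to the two MobileNetV2 feature maps selected by
    # out_indices [4, 5], carrying 4 and 5 clustered anchors per location respectively.
    #
    #   level 0: (992 / 16) * (736 / 16) * 4 = 62 * 46 * 4 = 11408 anchors
    #   level 1: (992 / 32) * (736 / 32) * 5 = 31 * 23 * 5 =  3565 anchors
    #
    # Each anchor yields 4 box deltas (decoded by DeltaXYWHBBoxCoder) plus class
    # scores; the NMS settings in test_cfg then keep at most max_per_img = 200 boxes.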
    def setup(self, stage: str) -> None:
        """Callback to set up the OTX SSD model.

        SSD requires automatic anchor generation w.r.t. the training dataset for better
        accuracy. This callback provides the training dataset to the model's anchor generator.

        Args:
            stage (str): Lightning stage for which setup is being run, e.g. "fit".
        """
        super().setup(stage=stage)

        if stage == "fit":
            anchor_generator = self.model.bbox_head.anchor_generator
            dataset = self.trainer.datamodule.train_dataloader().dataset
            new_anchors = self._get_new_anchors(dataset, anchor_generator)
            if new_anchors is not None:
                logger.warning("Anchors will be updated based on the dataset's statistics")
                logger.warning(f"{anchor_generator.widths} -> {new_anchors[0]}")
                logger.warning(f"{anchor_generator.heights} -> {new_anchors[1]}")
                anchor_generator.widths = new_anchors[0]
                anchor_generator.heights = new_anchors[1]
                anchor_generator.gen_base_anchors()
                self.hparams["ssd_anchors"] = {
                    "heights": anchor_generator.heights,
                    "widths": anchor_generator.widths,
                }
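    # The update in setup() is equivalent to the following manual sketch (widths and
    # heights below are hypothetical; the real values come from KMeans clustering of
    # the training boxes in _get_anchor_boxes):
    #
    #   gen = model.model.bbox_head.anchor_generator
    #   gen.widths = [[40.0, 90.0, 150.0, 270.0], [210.0, 390.0, 450.0, 720.0, 790.0]]
    #   gen.heights = [[50.0, 150.0, 160.0, 320.0], [380.0, 320.0, 590.0, 700.0, 740.0]]
    #   gen.gen_base_anchors()
    #
    # Persisting the result in self.hparams["ssd_anchors"] lets on_load_checkpoint()
    # rebuild the same base anchors when training is resumed.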
    def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: SSDAnchorGeneratorClustered) -> tuple | None:
        """Get new anchors for SSD from OTXDataset."""
        from torchvision.transforms.v2._container import Compose

        from otx.data.transform_libs.torchvision import Resize

        target_wh = None
        if isinstance(dataset.transforms, Compose):
            for transform in dataset.transforms.transforms:
                if isinstance(transform, Resize):
                    target_wh = transform.scale
        if target_wh is None:
            target_wh = list(reversed(self.data_input_params.input_size))  # type: ignore[assignment]
            msg = f"Cannot get target_wh from the dataset. Assign it with the default value: {target_wh}"
            logger.warning(msg)
        group_as = [len(width) for width in anchor_generator.widths]
        wh_stats = self._get_sizes_from_dataset_entity(dataset, list(target_wh))  # type: ignore[arg-type]
        if len(wh_stats) < sum(group_as):
            logger.warning(
                f"There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be "
                f"at least {sum(group_as)}. Anchor box clustering was skipped.",
            )
            return None
        return self._get_anchor_boxes(wh_stats, group_as)

    @staticmethod
    def _get_sizes_from_dataset_entity(dataset: OTXDataset, target_wh: list[int]) -> list[tuple[int, int]]:
        """Get the width and height of each annotated instance in an OTXDataset.

        Args:
            dataset (OTXDataset): OTXDataset from which to gather statistics.
            target_wh (list[int]): Target width and height of the dataset.

        Returns:
            list[tuple[int, int]]: Tuples with the width and height of each instance.
        """
        wh_stats: list[tuple[int, int]] = []
        for item in dataset.dm_subset:
            for ann in item.annotations:
                if isinstance(ann, Bbox):
                    x1, y1, x2, y2 = ann.points
                    x1 = x1 / item.media.size[1] * target_wh[0]
                    y1 = y1 / item.media.size[0] * target_wh[1]
                    x2 = x2 / item.media.size[1] * target_wh[0]
                    y2 = y2 / item.media.size[0] * target_wh[1]
                    wh_stats.append((x2 - x1, y2 - y1))
        return wh_stats

    @staticmethod
    def _get_anchor_boxes(wh_stats: list[tuple[int, int]], group_as: list[int]) -> tuple:
        """Get new anchor box widths & heights using KMeans."""
        from sklearn.cluster import KMeans

        kmeans = KMeans(init="k-means++", n_clusters=sum(group_as), random_state=0).fit(wh_stats)
        centers = kmeans.cluster_centers_

        # Sort cluster centers by (geometric-mean) box size so anchors are grouped
        # from small to large before being split across feature-map levels.
        areas = np.sqrt(np.prod(centers, axis=1))
        idx = np.argsort(areas)
        widths = centers[idx, 0]
        heights = centers[idx, 1]

        group_as = np.cumsum(group_as[:-1])
        widths, heights = np.split(widths, group_as), np.split(heights, group_as)
        widths = [width.tolist() for width in widths]
        heights = [height.tolist() for height in heights]
        return widths, heights

    def _identify_classification_layers(  # type: ignore[override]
        self,
        prefix: str = "model.",
    ) -> dict[str, dict[str, bool | int]]:
        """Return classification layer names by comparing two models with different numbers of classes.

        Args:
            prefix (str): Prefix of the model param name. Normally it is "model." since OTXModel
                sets its nn.Module model as self.model.

        Returns:
            dict[str, dict[str, bool | int]]: A dictionary containing each classification layer's name
                and information. `use_bg` indicates whether SSD uses a background class: it is True if
                SSD uses softmax loss, and False if SSD uses cross-entropy loss. `num_anchors` is the
                number of anchors of the layer. SSD classifies per anchor, so every anchor has to be updated.
        """
        sample_model_dict = self._create_model(num_classes=3).state_dict()
        incremental_model_dict = self._create_model(num_classes=4).state_dict()

        classification_layers = {}
        for key in sample_model_dict:
            if sample_model_dict[key].shape != incremental_model_dict[key].shape:
                sample_model_dim = sample_model_dict[key].shape[0]
                if sample_model_dim % 3 != 0:
                    use_bg = True
                    num_anchors = int(sample_model_dim / 4)
                    classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors}
                else:
                    use_bg = False
                    num_anchors = int(sample_model_dim / 3)
                    classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors}
        return classification_layers
    def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None:
        """Modify the input state_dict according to class name matching. It is used for incremental learning."""
        model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes)

        for param_name, info in self._identify_classification_layers().items():
            model_param = self.state_dict()[param_name].clone()
            ckpt_param = state_dict[prefix + param_name]
            use_bg = info["use_bg"]
            num_anchors = info["num_anchors"]
            if use_bg:
                num_ckpt_classes = len(self.ckpt_classes) + 1
                num_model_classes = len(self.model_classes) + 1
            else:
                num_ckpt_classes = len(self.ckpt_classes)
                num_model_classes = len(self.model_classes)

            for anchor_idx in range(num_anchors):
                for model_dst, ckpt_dst in enumerate(model2ckpt):
                    if ckpt_dst >= 0:
                        # Copy only the matched weight rows
                        model_param[anchor_idx * num_model_classes + model_dst].copy_(
                            ckpt_param[anchor_idx * num_ckpt_classes + ckpt_dst],
                        )
                if use_bg:
                    model_param[anchor_idx * num_model_classes + num_model_classes - 1].copy_(
                        ckpt_param[anchor_idx * num_ckpt_classes + num_ckpt_classes - 1],
                    )

            # Replace the checkpoint weights with the mixed weights
            state_dict[prefix + param_name] = model_param
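    # Row layout assumed by the copy above: classification rows are grouped per
    # anchor, so row index = anchor_idx * classes_per_anchor + class_idx, with the
    # background row (when use_bg) last within each anchor group. For example, with
    # 2 checkpoint classes + bg and 3 model classes + bg, anchor 0 occupies ckpt
    # rows [0: cls0, 1: cls1, 2: bg] and model rows [0..2: cls, 3: bg];
    # map_class_names decides which checkpoint row feeds which model row, and
    # unmatched model rows keep their freshly initialized weights.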
    @property
    def _exporter(self) -> OTXModelExporter:
        """Creates OTXModelExporter object that can export the model."""
        return OTXNativeModelExporter(
            task_level_export_parameters=self._export_parameters,
            data_input_params=self.data_input_params,
            resize_mode="standard",
            pad_value=0,
            swap_rgb=False,
            via_onnx=True,  # Currently SSD should be exported through ONNX
            onnx_export_configuration={
                "input_names": ["image"],
                "output_names": ["boxes", "labels"],
                "dynamic_axes": {
                    "image": {0: "batch"},
                    "boxes": {0: "batch", 1: "num_dets"},
                    "labels": {0: "batch", 1: "num_dets"},
                },
                "autograd_inlining": False,
            },
            output_names=["bboxes", "labels", "feature_vector", "saliency_map"] if self.explain_mode else None,
        )
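    # Export sketch: given the dynamic_axes above, the exported ONNX graph takes an
    # "image" input whose leading axis is the batch, and returns "boxes" and
    # "labels" outputs whose leading axes are (batch, num_dets). Whether "boxes"
    # carries a confidence score as an extra column depends on SSDHead's export
    # path; treat any layout beyond the named dynamic axes as an assumption, not a
    # contract of this exporter configuration.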
    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Callback on load checkpoint."""
        if (hparams := checkpoint.get("hyper_parameters")) and (anchors := hparams.get("ssd_anchors", None)):
            anchor_generator = self.model.bbox_head.anchor_generator
            anchor_generator.widths = anchors["widths"]
            anchor_generator.heights = anchors["heights"]
            anchor_generator.gen_base_anchors()
        return super().on_load_checkpoint(checkpoint)
    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0."""
        return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix)
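# Minimal end-to-end usage sketch (hypothetical values; the integer shorthand for
# label_info and the exact DataInputParams signature are assumptions, not part of
# this module):
#
#   model = SSD(
#       label_info=3,  # or a full LabelInfoTypes object describing the classes
#       data_input_params=DataInputParams(
#           input_size=(992, 736),  # matches the pretrained mobilenet_v2 SSD weights
#           mean=(0.0, 0.0, 0.0),
#           std=(255.0, 255.0, 255.0),
#       ),
#   )
#
# During trainer.fit(model, datamodule), setup("fit") re-clusters the anchors to the
# training dataset, and the clustered anchors are checkpointed under "ssd_anchors".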