# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Dataclasses for label information."""
from __future__ import annotations
import copy
import json
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any
from datumaro.components.annotation import GroupType
if TYPE_CHECKING:
from datumaro import Label, LabelCategories
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "LabelInfo",
    "HLabelInfo",
    "SegLabelInfo",
    "NullLabelInfo",
    "AnomalyLabelInfo",
    "LabelInfoTypes",
]
@dataclass
class LabelInfo:
    """Object to represent label information.

    Attributes:
        label_names: Name of each label.
        label_ids: Identifier of each label, positionally aligned with ``label_names``.
        label_groups: Groups of label names. When no grouping information is
            available, a single group containing every label is used
            (single-label classification).
    """

    label_names: list[str]
    label_ids: list[str]
    label_groups: list[list[str]]

    @property
    def num_classes(self) -> int:
        """Return number of labels, i.e. the length of ``label_names``."""
        return len(self.label_names)

    @classmethod
    def from_num_classes(cls, num_classes: int) -> LabelInfo:
        """Create this object from the number of classes.

        Args:
            num_classes: Number of classes.

        Returns:
            ``NullLabelInfo()`` when ``num_classes <= 0``, otherwise
            LabelInfo(
                label_names=["label_0", ...],
                label_groups=[["label_0", ...]],
                label_ids=["0", ...],
            )
        """
        if num_classes <= 0:
            return NullLabelInfo()

        label_names = [f"label_{idx}" for idx in range(num_classes)]
        label_ids = [str(i) for i in range(num_classes)]

        return cls(
            label_names=label_names,
            label_groups=[label_names],
            label_ids=label_ids,
        )

    @classmethod
    def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> LabelInfo:
        """Create this object from the datumaro label groups.

        Args:
            dm_label_categories (LabelCategories): The label category information from Datumaro.

        Returns:
            LabelInfo(
                label_names=["Heart_King", "Heart_Queen", "Spade_King", "Spade_Jack"],
                label_groups=[["Heart_King", "Heart_Queen"], ["Spade_King", "Spade_Jack"]],
                label_ids=["0", "1", "2", "3"],
            )
        """
        label_names = [item.name for item in dm_label_categories.items]
        label_groups = [label_group.labels for label_group in dm_label_categories.label_groups]
        if len(label_groups) == 0:  # Single-label classification
            label_groups = [label_names]

        # NOTE: always builds a plain ``LabelInfo``, even when called on a subclass.
        return LabelInfo(
            label_names=label_names,
            label_groups=label_groups,
            label_ids=[str(i) for i in range(len(label_names))],
        )

    @classmethod
    def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> LabelInfo:
        """Overload to support datumaro's arrow format.

        Each category's ``name`` is treated as the label ID. The label name is
        read from the category attribute that starts with ``"__name__"``.
        Labels inside the label groups are translated from IDs to names (unknown
        entries are kept as-is).

        ``dm_label_categories`` is left untouched, so repeated calls on the same
        categories give the same result.

        Args:
            dm_label_categories (LabelCategories): The label category information from Datumaro.

        Returns:
            LabelInfo whose ``label_names`` come from the ``__name__`` attributes and
            whose ``label_ids`` are the category names.

        Raises:
            ValueError: If some category has no ``__name__``-prefixed attribute.
        """
        name_prefix = "__name__"
        label_names = []
        for item in dm_label_categories.items:
            for attr in item.attributes:
                if attr.startswith(name_prefix):
                    label_names.append(attr[len(name_prefix) :])
                    break

        if len(label_names) != len(dm_label_categories.items):
            msg = "Wrong arrow format: can not extract label names from attributes"
            raise ValueError(msg)

        label_ids = [item.name for item in dm_label_categories.items]
        id_to_name_mapping = dict(zip(label_ids, label_names))

        # Build translated copies rather than rewriting ``label_group.labels`` in place.
        # Mutating the caller's categories made a second call re-translate names that
        # were already translated. That is wrong whenever a name equals another label's ID.
        label_groups = [
            [id_to_name_mapping.get(label, label) for label in label_group.labels]
            for label_group in dm_label_categories.label_groups
        ]
        if len(label_groups) == 0:  # Single-label classification
            label_groups = [label_names]

        return LabelInfo(
            label_names=label_names,
            label_groups=label_groups,
            label_ids=label_ids,
        )

    def as_dict(self, normalize_label_names: bool = False) -> dict[str, Any]:
        """Return a dictionary including all params.

        Args:
            normalize_label_names: If True, every string found in the result is
                rewritten with spaces replaced by underscores. This covers strings
                nested in lists, tuples and dicts (keys included), and applies to
                every field, not only ``label_names``.
        """
        result = asdict(self)

        if normalize_label_names:

            def normalize_fn(node: str | list | tuple | dict | int) -> str | list | tuple | dict | int:
                """Normalizes the label names stored in various nested structures."""
                if isinstance(node, str):
                    return node.replace(" ", "_")
                if isinstance(node, list):
                    return [normalize_fn(item) for item in node]
                if isinstance(node, tuple):
                    return tuple(normalize_fn(item) for item in node)
                if isinstance(node, dict):
                    return {normalize_fn(key): normalize_fn(value) for key, value in node.items()}
                return node

            for k in result:
                result[k] = normalize_fn(result[k])

        return result

    def to_json(self) -> str:
        """Return JSON serialized string (label names are not normalized)."""
        return json.dumps(self.as_dict())

    @classmethod
    def from_json(cls, serialized: str) -> LabelInfo:
        """Reconstruct it from the JSON serialized string.

        The decoded JSON object's keys are passed to the constructor as keyword arguments.
        """
        return cls(**json.loads(serialized))
@dataclass
class HLabelInfo(LabelInfo):
    """The label information represents the hierarchy.

    All params should be kept since they're also used at the Model API side.

    Attributes:
        num_multiclass_heads: the number of the multiclass heads.
        num_multilabel_classes: the number of multilabel classes.
        head_idx_to_logits_range: the logit range of each head, keyed by the head index as a string.
        num_single_label_classes: the number of single label classes.
        class_to_group_idx: represents the head index and label index.
        all_groups: represents information of all groups.
        label_to_idx: index of each label.
        label_tree_edges: ``[child, parent]`` pairs of the label hierarchy.
        empty_multiclass_head_indices: the index of heads that don't include any label
            due to the label removing.

    i.e.
    Single-selection group information (Multiclass, Exclusive)
    {
        "Shape": ["Rigid", "Non-Rigid"],
        "Rigid": ["Rectangle", "Triangle"],
        "Non-Rigid": ["Circle"]
    }

    Multi-selection group information (Multilabel)
    {
        "Animal": ["Lion", "Panda"]
    }

    In the case above, HLabelInfo will be generated as below.
    NOTE, If there was only one label in the multiclass group, it will be handled as multilabel (Circle).

        num_multiclass_heads: 2  (Shape, Rigid)
        num_multilabel_classes: 3 (Circle, Lion, Panda)
        head_idx_to_logits_range: {'0': (0, 2), '1': (2, 4)} (Each multiclass head has 2 labels)
        num_single_label_classes: 4 (Rigid, Non-Rigid, Rectangle, Triangle)
        class_to_group_idx: {
            'Non-Rigid': (0, 0), 'Rigid': (0, 1),
            'Rectangle': (1, 0), 'Triangle': (1, 1),
            'Circle': (2, 0), 'Lion': (2, 1), 'Panda': (2, 2)
        } (head index, label index for each head)
        all_groups: [['Non-Rigid', 'Rigid'], ['Rectangle', 'Triangle'], ['Circle'], ['Lion'], ['Panda']]
        label_to_idx: {
            'Non-Rigid': 0, 'Rigid': 1,
            'Rectangle': 2, 'Triangle': 3,
            'Circle': 4, 'Lion': 5, 'Panda': 6
        } (the key order of class_to_group_idx)
        label_tree_edges: [
            ["Rectangle", "Rigid"], ["Triangle", "Rigid"], ["Circle", "Non-Rigid"],
        ] # NOTE, label_tree_edges format could be changed.
        empty_multiclass_head_indices: []

    All of the member variables should be considered for the Model API.
    https://github.com/open-edge-platform/training_extensions/blob/develop/src/otx/algorithms/classification/utils/cls_utils.py#L97
    """

    num_multiclass_heads: int
    num_multilabel_classes: int
    head_idx_to_logits_range: dict[str, tuple[int, int]]
    num_single_label_classes: int
    class_to_group_idx: dict[str, tuple[int, int]]
    all_groups: list[list[str]]
    label_to_idx: dict[str, int]
    label_tree_edges: list[list[str]]
    empty_multiclass_head_indices: list[int]

    @classmethod
    def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInfo:
        """Generate HLabelData from the Datumaro LabelCategories.

        Groups with more than one label become multiclass heads, in group order.
        Groups with exactly one label become multilabel classes that follow the
        multiclass heads. ``dm_label_categories`` itself is not modified.

        Args:
            dm_label_categories (LabelCategories): the label categories of datumaro.
        """

        def get_exclusive_group_info(exclusive_groups: list[Label | list[Label]]) -> dict[str, Any]:
            """Get exclusive group information: one head per group, with consecutive logit ranges."""
            last_logits_pos = 0
            num_single_label_classes = 0
            head_idx_to_logits_range = {}
            class_to_idx = {}

            for i, group in enumerate(exclusive_groups):
                head_idx_to_logits_range[str(i)] = (last_logits_pos, last_logits_pos + len(group))
                last_logits_pos += len(group)
                for j, c in enumerate(group):
                    class_to_idx[c] = (i, j)
                    num_single_label_classes += 1

            return {
                "num_multiclass_heads": len(exclusive_groups),
                "head_idx_to_logits_range": head_idx_to_logits_range,
                "class_to_idx": class_to_idx,
                "num_single_label_classes": num_single_label_classes,
            }

        def get_single_label_group_info(
            single_label_groups: list[Label | list[Label]],
            num_exclusive_groups: int,
        ) -> dict[str, Any]:
            """Get single label group information; they all share head index ``num_exclusive_groups``."""
            class_to_idx = {}

            for i, group in enumerate(single_label_groups):
                class_to_idx[group[0]] = (num_exclusive_groups, i)

            return {
                "num_multilabel_classes": len(single_label_groups),
                "class_to_idx": class_to_idx,
            }

        def merge_class_to_idx(
            exclusive_ctoi: dict[str, tuple[int, int]],
            single_label_ctoi: dict[str, tuple[int, int]],
        ) -> dict[str, tuple[int, int]]:
            """Merge the class_to_idx information from exclusive and single_label groups.

            Exclusive entries come first. On a duplicated label the single-label value wins
            but the key keeps its original position, as with a dict update.
            """
            return {**exclusive_ctoi, **single_label_ctoi}

        def get_label_tree_edges(dm_label_items: list[LabelCategories.Category]) -> list[list[str]]:
            """Get label tree edges information. Each edges represent [child, parent]."""
            return [[item.name, item.parent] for item in dm_label_items if item.parent != ""]

        def convert_labels_if_needed(
            dm_label_categories: LabelCategories,
            label_names: list[str],
        ) -> list[list[str]]:
            """Return copies of the label groups, mapped to category names when needed.

            If the first label of some group is not a category name, every group label
            is mapped through ``{<__name__ attribute value>: <category name>}``.
            Labels missing from that mapping are kept as-is. The categories are not modified.
            """
            groups = [list(label_group.labels) for label_group in dm_label_categories.label_groups]
            known_names = set(label_names)
            if not any(group and group[0] not in known_names for group in groups):
                return groups

            name_to_id_mapping = {
                attr[len("__name__") :]: category.name
                for category in dm_label_categories.items
                for attr in category.attributes
                if attr.startswith("__name__")
            }
            if not name_to_id_mapping:
                return groups
            return [[name_to_id_mapping.get(label, label) for label in group] for group in groups]

        label_names = [item.name for item in dm_label_categories.items]
        all_groups = convert_labels_if_needed(dm_label_categories, label_names)

        exclusive_groups = [g for g in all_groups if len(g) > 1]
        exclusive_group_info = get_exclusive_group_info(exclusive_groups)
        single_label_groups = [g for g in all_groups if len(g) == 1]
        single_label_group_info = get_single_label_group_info(
            single_label_groups,
            exclusive_group_info["num_multiclass_heads"],
        )

        merged_class_to_idx = merge_class_to_idx(
            exclusive_group_info["class_to_idx"],
            single_label_group_info["class_to_idx"],
        )
        # Label indices follow the key order of class_to_group_idx (exclusive groups first).
        label_to_idx = {lbl: i for i, lbl in enumerate(merged_class_to_idx.keys())}

        return HLabelInfo(
            label_names=label_names,
            label_groups=exclusive_groups + single_label_groups,
            num_multiclass_heads=exclusive_group_info["num_multiclass_heads"],
            num_multilabel_classes=single_label_group_info["num_multilabel_classes"],
            head_idx_to_logits_range=exclusive_group_info["head_idx_to_logits_range"],
            num_single_label_classes=exclusive_group_info["num_single_label_classes"],
            class_to_group_idx=merged_class_to_idx,
            all_groups=exclusive_groups + single_label_groups,
            label_to_idx=label_to_idx,
            label_tree_edges=get_label_tree_edges(dm_label_categories.items),
            empty_multiclass_head_indices=[],  # consider the label removing case
            label_ids=[str(i) for i in range(len(label_names))],
        )

    @classmethod
    def from_dm_label_groups_arrow(cls, dm_label_categories: LabelCategories) -> HLabelInfo:
        """Generate HLabelData from the Datumaro LabelCategories. Arrow-specific implementation.

        Works on a deep copy of ``dm_label_categories``:
          * groups of type ``GroupType.RESTRICTED`` are dropped;
          * the label named by the first entry of the last restricted group is removed;
          * category names (label IDs) are replaced by their ``__name__``-prefixed
            attribute values, in items, parents and groups;
          * the result is passed to ``from_dm_label_groups``.
        The original category names are kept as ``label_ids``.

        Args:
            dm_label_categories (LabelCategories): the label categories of datumaro.

        Raises:
            ValueError: If some category has no ``__name__``-prefixed attribute.
        """
        dm_label_categories = copy.deepcopy(dm_label_categories)

        empty_label_name = None
        for label_group in dm_label_categories.label_groups:
            if label_group.group_type == GroupType.RESTRICTED:
                empty_label_name = label_group.labels[0]

        dm_label_categories.label_groups = [
            group for group in dm_label_categories.label_groups if group.group_type != GroupType.RESTRICTED
        ]

        empty_label_id = None
        label_names = []
        for item in dm_label_categories.items:
            for attr in item.attributes:
                if attr.startswith("__name__"):
                    name = attr[len("__name__") :]
                    if name == empty_label_name:
                        empty_label_id = item.name
                    label_names.append(name)
                    break

        if len(label_names) != len(dm_label_categories.items):
            msg = "Wrong arrow file: can not extract label names from attributes"
            raise ValueError(msg)

        if empty_label_name is not None:
            # NOTE(review): list.remove raises ValueError if no category carries this name.
            label_names.remove(empty_label_name)
            dm_label_categories.items = [item for item in dm_label_categories.items if item.name != empty_label_id]
        label_ids = [item.name for item in dm_label_categories.items]

        id_to_name_mapping = {item.name: label_names[i] for i, item in enumerate(dm_label_categories.items)}

        for i, item in enumerate(dm_label_categories.items):
            item.name = label_names[i]
            item.parent = id_to_name_mapping.get(item.parent, item.parent)

        for label_group in dm_label_categories.label_groups:
            label_group.labels = [id_to_name_mapping.get(label, label) for label in label_group.labels]

        obj = cls.from_dm_label_groups(dm_label_categories)
        obj.label_ids = label_ids
        return obj

    def as_head_config_dict(self) -> dict[str, Any]:
        """Return a dictionary including params needed to configure the HLabel MMPretrained head network."""
        return {
            "num_classes": self.num_classes,
            "num_multiclass_heads": self.num_multiclass_heads,
            "num_multilabel_classes": self.num_multilabel_classes,
            "head_idx_to_logits_range": self.head_idx_to_logits_range,
            "num_single_label_classes": self.num_single_label_classes,
            "empty_multiclass_head_indices": self.empty_multiclass_head_indices,
        }

    @classmethod
    def from_json(cls, serialized: str) -> HLabelInfo:
        """Reconstruct it from the JSON serialized string.

        JSON has no tuple type, so the ``(start, end)`` / ``(head, index)`` pairs are
        converted back from lists to tuples.
        """
        loaded = json.loads(serialized)
        # List to tuple
        loaded["head_idx_to_logits_range"] = {
            key: tuple(value) for key, value in loaded["head_idx_to_logits_range"].items()
        }
        loaded["class_to_group_idx"] = {key: tuple(value) for key, value in loaded["class_to_group_idx"].items()}
        return cls(**loaded)
@dataclass
class SegLabelInfo(LabelInfo):
    """Meta information of Semantic Segmentation.

    Attributes:
        ignore_index: Special label index, 255 by default. It presumably marks
            pixels to be ignored; how it is consumed is decided outside this module.
    """

    ignore_index: int = 255

    @classmethod
    def from_num_classes(cls, num_classes: int) -> LabelInfo:
        """Create this object from the number of classes.

        A single class is treated as binary segmentation, so a ``"background"``
        label is prepended. Any other value is delegated to
        ``LabelInfo.from_num_classes``.

        Args:
            num_classes: Number of classes.

        Returns:
            For ``num_classes == 1``:
                SegLabelInfo(
                    label_names=["background", "label_0"],
                    label_groups=[["background", "label_0"]],
                    label_ids=["0", "1"],
                )
            For ``num_classes > 1``:
                SegLabelInfo(
                    label_names=["label_0", ..., "label_{num_classes - 1}"],
                    label_groups=[["label_0", ..., "label_{num_classes - 1}"]],
                    label_ids=["0", ..., "{num_classes - 1}"],
                )
            For ``num_classes <= 0``: ``NullLabelInfo()``.
        """
        if num_classes == 1:
            # binary segmentation
            label_names = ["background", "label_0"]
            # ``cls`` keeps this branch consistent with the delegated path below.
            return cls(label_names=label_names, label_groups=[label_names], label_ids=["0", "1"])
        return super().from_num_classes(num_classes)
@dataclass
class NullLabelInfo(LabelInfo):
    """Label information that carries no labels at all.

    ``num_classes`` is therefore 0, and ``label_groups`` holds a single empty group.
    """

    def __init__(self) -> None:
        no_names: list[str] = []
        no_ids: list[str] = []
        super().__init__(label_names=no_names, label_ids=no_ids, label_groups=[[]])

    @classmethod
    def from_json(cls, _: str) -> LabelInfo:
        """Rebuild an instance from JSON; the payload is ignored since there is nothing to restore."""
        empty_info = cls()
        return empty_info
@dataclass
class AnomalyLabelInfo(LabelInfo):
    """Fixed label information used for Anomaly tasks: the labels ``Normal`` and ``Anomaly``."""

    def __init__(self) -> None:
        super().__init__(label_names=["Normal", "Anomaly"], label_groups=[["Normal", "Anomaly"]], label_ids=["0", "1"])

    @classmethod
    def from_json(cls, _: str) -> AnomalyLabelInfo:
        """Reconstruct it from the JSON serialized string.

        The labels are fixed, so the serialized content is ignored. The inherited
        ``LabelInfo.from_json`` cannot be used: it passes the decoded fields as keyword
        arguments, which this ``__init__`` does not accept (it raised ``TypeError``).
        """
        return cls()
# Type alias for the accepted forms of a ``label_info`` argument.
# Dispatching rules:
# 1. label_info: int => LabelInfo.from_num_classes(label_info)
# 2. label_info: list[str] => LabelInfo(label_names=label_info, label_groups=[label_info])
# 3. label_info: LabelInfo => label_info
# See OTXModel._dispatch_label_info() for more details
# NOTE(review): the dispatching itself lives outside this module; confirm the rules
# above against that implementation when changing either side.
LabelInfoTypes = LabelInfo | int | list[str]