Source code for datumaro.experimental.schema

# Copyright (C) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT
"""
Schema definitions for the dataset system.
"""

import copy
from dataclasses import dataclass, field
from enum import Flag, auto
from typing import TYPE_CHECKING, Any, Dict, Optional

import polars as pl

if TYPE_CHECKING:
    from .categories import Categories


class Semantic(Flag):
    """
    Semantic tag used to tell apart several fields of the same type.

    Members:
        Default: For fields that do not need disambiguation.
        Left: Left-hand field in stereo vision scenarios.
        Right: Right-hand field in stereo vision scenarios.
    """

    Default = auto()
    Left = auto()
    Right = auto()
class Field:
    """
    Base class for fields with semantic tags and Polars type mapping.

    Defines the interface shared by all field types for converting between
    Python objects and Polars DataFrame representations. Subclasses must
    override ``to_polars_schema``; ``to_polars`` and ``from_polars`` come with
    default implementations that use a single column named after the field.

    Attributes:
        semantic: Semantic tags for disambiguation (Default, Left, Right)
    """

    semantic: Semantic

    def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
        """
        Generate the Polars schema definition for this field.

        Args:
            name: The column name for this field

        Returns:
            Dictionary mapping column names to Polars data types

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override this method
        """
        # The message names the method that actually has to be overridden.
        raise NotImplementedError("Subclasses must implement the to_polars_schema method.")

    def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
        """
        Convert the field value to Polars-compatible format.

        Args:
            name: The column name for this field
            value: The value to convert

        Returns:
            Dictionary with a single entry mapping ``name`` to a one-element
            Polars Series holding ``value``
        """
        return {name: pl.Series(name, [value])}

    def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type) -> Any:
        """
        Convert from Polars-compatible format back to the field's value.

        Args:
            name: The column name for this field
            row_index: The row index to extract
            df: The source DataFrame
            target_type: The target type to convert to; it is called with the
                cell read from column ``name`` at ``row_index``

        Returns:
            The converted value in the target type
        """
        return target_type(df[name][row_index])
@dataclass
class AttributeInfo:
    """
    Bundle of the type and field annotation describing a single attribute.

    Categories are optional and default to ``None``.
    """

    # Type associated with the attribute.
    type: type
    # Field instance annotating the attribute.
    annotation: Field
    # Categories attached to the attribute, or None when none are set.
    categories: Optional["Categories"] = None
@dataclass
class Schema:
    """
    Dataset schema: attribute definitions keyed by attribute name.

    On construction the schema verifies that no two attributes share both the
    same field class and the same semantic tag.

    Attributes:
        attributes: Mapping from attribute name to its AttributeInfo
    """

    attributes: dict[str, AttributeInfo] = field(default_factory=dict[str, AttributeInfo])

    def __post_init__(self):
        """Reject schemas in which two attributes share a field class and semantic."""
        # Maps each (field class, semantic) pair to the first attribute using it.
        owners: dict[tuple[type[Field], Semantic], str] = {}
        for attr_name, info in self.attributes.items():
            field_cls = type(info.annotation)
            semantic = info.annotation.semantic
            slot = (field_cls, semantic)
            if slot not in owners:
                owners[slot] = attr_name
                continue
            raise ValueError(
                f"Duplicate field type {field_cls} for semantic {semantic} in schema. "
                f"Fields '{owners[slot]}' and '{attr_name}' conflict."
            )

    def with_categories(self, categories: Dict[str, "Categories"]) -> "Schema":
        """
        Build a copy of this schema with categories attached to some attributes.

        Args:
            categories: Dictionary mapping attribute names to categories

        Returns:
            A new Schema instance with the categories applied; this schema and
            its AttributeInfo objects are left unmodified

        Raises:
            ValueError: If an attribute name is not found in the schema
        """
        # Shallow-copy the schema itself, then give the copy its own
        # AttributeInfo objects so the assignments below never reach the
        # originals held by this schema.
        derived = copy.copy(self)
        cloned_attributes = {}
        for name, attr_info in self.attributes.items():
            cloned_attributes[name] = copy.copy(attr_info)
        derived.attributes = cloned_attributes

        # Attach each category, failing on the first unknown attribute name.
        for attr_name, category in categories.items():
            if attr_name not in derived.attributes:
                raise ValueError(f"Attribute '{attr_name}' not found in schema")
            derived.attributes[attr_name].categories = category

        return derived