# Copyright (C) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT
"""
Converter system for transforming data between different field representations.
This module provides the foundation for data transformation pipelines,
including converter registration, schema mapping, and automatic conversion
path discovery using graph algorithms.
"""
from __future__ import annotations
import heapq
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from functools import cache
from typing import (
Any,
Callable,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
get_type_hints,
overload,
)
import polars as pl
from typing_extensions import cast, dataclass_transform
from .categories import Categories
from .schema import Field, Schema, Semantic
TField = TypeVar("TField", bound=Field)
[docs]
class ConversionPaths(NamedTuple):
"""
Container for separated batch and lazy conversion paths.
The batch converters can be applied immediately to the entire DataFrame,
while lazy converters must be deferred and applied at sample access time.
"""
batch_converters: List["Converter"]
lazy_converters: List["Converter"]
[docs]
@dataclass(frozen=True)
class AttributeSpec(Generic[TField]):
"""
Specification for an attribute used in converters.
Links an attribute name with its corresponding field type definition,
providing the complete specification needed for converter operations.
Args:
TField: The specific Field type, defaults to Field
Attributes:
name: The attribute name
field: The field type specification
categories: Optional categories information (e.g., LabelCategories, MaskCategories)
"""
name: str
field: TField
categories: Optional[Categories] = None
[docs]
@dataclass_transform()
class Converter(ABC):
"""
Base class for data converters with input/output specifications.
Converters transform data between different field representations by
implementing the convert() method and optionally filtering their
applicability through filter_output_spec().
"""
def __init__(self, **kwargs: Any):
"""
Initialize converter with input and output AttributeSpec instances.
Args:
**kwargs: AttributeSpec instances for converter inputs/outputs
based on input_*/output_* class attributes
"""
# Set all provided kwargs as instance attributes
for key, value in kwargs.items():
setattr(self, key, value)
lazy: bool = False
"""
Whether this converter performs lazy operations.
Lazy converters defer expensive operations (like loading images from disk)
until data is actually accessed. When a lazy converter is in the conversion
path, all dependent converters must also be executed lazily.
"""
[docs]
@classmethod
@cache
def get_from_types(cls) -> dict[str, Type[Field]]:
"""
Extract input field types from input_* class attributes.
Returns:
Dictionary mapping input attribute names to their Field types
"""
from_types: dict[str, Type[Field]] = {}
# Get type hints for the class
hints = get_type_hints(cls)
for attr_name, attr_type in hints.items():
if attr_name.startswith("input_"):
# Extract the Field type from AttributeSpec[FieldType] annotation
if hasattr(attr_type, "__args__") and len(attr_type.__args__) > 0:
# Handle generic types like AttributeSpec[SomeField]
field_type = attr_type.__args__[0]
else:
raise RuntimeError("Attributes must be annotated with AttributeSpec[FieldType]")
from_types[attr_name] = field_type
return from_types
[docs]
@classmethod
@cache
def get_to_types(cls) -> dict[str, Type[Field]]:
"""
Extract output field types from output_* class attributes.
Returns:
Dictionary mapping output attribute names to their Field types
"""
to_types: dict[str, Type[Field]] = {}
# Get type hints for the class
hints = get_type_hints(cls)
for attr_name, attr_type in hints.items():
if attr_name.startswith("output_"):
# Extract the Field type from AttributeSpec[FieldType] annotation
if hasattr(attr_type, "__args__") and len(attr_type.__args__) > 0:
# Handle generic types like AttributeSpec[SomeField]
field_type = attr_type.__args__[0]
else:
raise RuntimeError("Attributes must be annotated with AttributeSpec[FieldType]")
to_types[attr_name] = field_type
return to_types
[docs]
@abstractmethod
def convert(self, df: pl.DataFrame) -> pl.DataFrame:
"""
Convert a DataFrame using the stored AttributeSpec instances.
Args:
df: Input DataFrame
Returns:
Converted DataFrame
"""
pass
[docs]
def filter_output_spec(self) -> bool:
"""
Filter and modify the converter's output specification in-place.
This method allows converters to inspect and modify their output
specifications based on input characteristics. It should return
True if the converter can handle the given input/output combination.
Returns:
True if the converter is applicable, False otherwise
"""
# Default implementation accepts all conversions
# Subclasses should override for sophisticated filtering
return True
[docs]
def get_output_attr_specs(self) -> List[AttributeSpec[Field]]:
"""
Get the current output AttributeSpec instances from output_* attributes.
Returns:
List of output AttributeSpec instances currently configured on the converter
"""
output_attr_specs: List[AttributeSpec[Field]] = []
# Get the output attribute names from class type hints
to_types = self.get_to_types()
for attr_name in to_types.keys():
attr_spec = cast(AttributeSpec[Field], getattr(self, attr_name))
output_attr_specs.append(attr_spec)
return output_attr_specs
[docs]
class ConverterRegistry:
"""
Registry for managing and discovering data converters.
This class maintains a global registry of converter classes and provides
functionality for finding and instantiating appropriate converters for
schema transformations.
"""
_converter_registry: List[Type[Converter]] = []
[docs]
@classmethod
def add_converter(cls, converter: Type[Converter]):
"""Add a converter class to the registry."""
cls._converter_registry.append(converter)
[docs]
@classmethod
def remove_converter(cls, converter: Type[Converter]) -> None:
"""Remove a converter class from the registry.
Args:
converter: The converter class to remove
Raises:
ValueError: If the converter is not found in the registry
"""
cls._converter_registry.remove(converter)
[docs]
@classmethod
def list_converters(cls) -> Sequence[Type[Converter]]:
"""List all registered converter classes as an immutable sequence."""
return cls._converter_registry
@overload
def converter(cls: Type[Converter], /) -> Type[Converter]:
"""Overload for @converter (no parentheses)."""
...
@overload
def converter(*, lazy: bool = False) -> Callable[[Type[Converter]], Type[Converter]]:
"""Overload for @converter() or @converter(lazy=True)."""
...
[docs]
def converter(
cls: Optional[Type[Converter]] = None, /, *, lazy: bool = False
) -> Type[Converter] | Callable[[Type[Converter]], Type[Converter]]:
"""Register a converter class and configure its lazy loading behavior.
This decorator automatically registers converter classes with the global
converter registry and sets their lazy evaluation mode. The converter
class must define at least one output_* attribute with type hints.
Args:
lazy: If True, this converter will only be applied during lazy
evaluation in Dataset.__getitem__. If False, it will be
applied during batch conversion operations. Lazy converters
automatically make all dependent converters lazy as well.
Usage:
@converter
class ImageToTensorConverter(Converter):
input_image: AttributeSpec
output_tensor: AttributeSpec
def convert(self, df: pl.DataFrame) -> pl.DataFrame:
# conversion logic
return df
@converter(lazy=True)
class ImagePathToImageConverter(Converter):
input_path: AttributeSpec
output_image: AttributeSpec
def convert(self, df: pl.DataFrame) -> pl.DataFrame:
# lazy conversion logic
return df
"""
def decorator(cls: Type[Converter]) -> Type[Converter]:
# Validate converter class by checking for required attributes
hints = get_type_hints(cls)
# Ensure at least one output attribute is defined
output_attrs = [name for name in hints if name.startswith("output_")]
if not output_attrs:
raise TypeError(f"{cls.__name__} must define at least one 'output_*' attribute")
# Set the lazy attribute directly on the class
cls.lazy = lazy
# Register with the global converter registry for discovery
ConverterRegistry.add_converter(cls)
return cls
# Handle both @converter and @converter() syntax patterns
if cls is None:
# Called with parentheses: @converter() or @converter(lazy=True)
return decorator
# Called without parentheses: @converter
return decorator(cls)
[docs]
class ConversionError(Exception):
"""Exception raised when conversion fails."""
pass
[docs]
class AttributeRemapperConverter(Converter):
"""
Special converter for renaming/selecting attributes and dropping others.
This converter is not registered with the converter registry but is used
internally by find_conversion_path when attributes need to be renamed or deleted.
It uses .select() to only keep the specified attributes with their new names,
effectively handling both renaming and deletion in a single operation.
"""
def __init__(self, attr_mappings: list[tuple[AttributeSpec, AttributeSpec]]):
"""
Initialize the converter with a list of attribute mappings.
Args:
attr_mappings: List of tuples (from_attr_spec, to_attr_spec) defining
the attribute transformations. Only attributes in this
list will be kept in the output.
"""
self.attr_mappings = attr_mappings
# Calculate column mapping from attribute mappings
self.column_map = {}
for from_attr, to_attr in attr_mappings:
# Get all column names for this field using to_polars_schema
from_columns = list(from_attr.field.to_polars_schema(from_attr.name).keys())
to_columns = list(to_attr.field.to_polars_schema(to_attr.name).keys())
# Map each column from source to target
for from_col, to_col in zip(from_columns, to_columns):
self.column_map[from_col] = to_col
# Dynamically set input_* and output_* attributes for get_from_types/get_to_types
for i, (from_attr, to_attr) in enumerate(attr_mappings):
setattr(self, f"input_{i}", from_attr)
setattr(self, f"output_{i}", to_attr)
super().__init__()
[docs]
@cache
def get_from_types(self) -> dict[str, Type[Field]]:
"""
Extract input field types from input_* class attributes.
Returns:
Dictionary mapping input attribute names to their Field types
"""
from_types: dict[str, Type[Field]] = {}
for i, (input_spec, _) in enumerate(self.attr_mappings):
from_types[f"input_{i}"] = type(input_spec.field)
return from_types
[docs]
@cache
def get_to_types(self) -> dict[str, Type[Field]]:
"""
Extract output field types from output_* class attributes.
Returns:
Dictionary mapping output attribute names to their Field types
"""
to_types: dict[str, Type[Field]] = {}
for i, (_, output_spec) in enumerate(self.attr_mappings):
to_types[f"output_{i}"] = type(output_spec.field)
return to_types
[docs]
def convert(self, df: pl.DataFrame) -> pl.DataFrame:
"""
Select and rename columns according to column_map.
Columns not in the mapping are dropped.
Args:
df: Input DataFrame
Returns:
DataFrame with only the selected/renamed columns
"""
# Build selection expressions for all columns we want to keep
select_exprs = {}
for old_name, new_name in self.column_map.items():
select_exprs[new_name] = pl.col(old_name)
return df.select(**select_exprs)
[docs]
def filter_output_spec(self) -> bool:
"""Always return True as renaming is always applicable."""
return True
@dataclass(frozen=True)
class _SchemaState:
"""Represents a schema state during A* search."""
field_to_attr_spec: dict[
Type[Field], AttributeSpec[Field]
] # Map field types to their AttributeSpec
def __hash__(self):
# Hash only field types and their properties, not names
field_items = []
for field_type, attr_spec in self.field_to_attr_spec.items():
# Hash field type and field properties, but not the attribute name
field_items.append((field_type, attr_spec.field))
return hash(tuple(field_items))
def __eq__(self, other: object) -> bool:
if not isinstance(other, _SchemaState):
return False
# Compare only field types and their properties, not names
if set(self.field_to_attr_spec.keys()) != set(other.field_to_attr_spec.keys()):
return False
for field_type in self.field_to_attr_spec.keys():
if (
self.field_to_attr_spec[field_type].field
!= other.field_to_attr_spec[field_type].field
):
return False
return True
def get_attr_spec_for_field_type(
self, field_type: Type[Field]
) -> Optional[AttributeSpec[Field]]:
"""Get AttributeSpec for a specific field type."""
return self.field_to_attr_spec.get(field_type)
@dataclass
class _SearchNode:
"""Node in the A* search tree."""
state: _SchemaState
path: List[Converter] # Now stores Converter instances directly
g_cost: int # Actual cost from start
h_cost: int # Heuristic cost to goal
@property
def f_cost(self) -> int:
"""Total cost (g + h)."""
return self.g_cost + self.h_cost
def __lt__(self, other: "_SearchNode") -> bool:
return self.f_cost < other.f_cost
def _heuristic_cost(current_state: _SchemaState, target_state: _SchemaState) -> int:
"""
Heuristic function for A* search.
Returns the number of missing target fields plus field differences as a heuristic.
This counts both:
1. Missing field types that need to be created
2. Field differences where the type exists but properties differ (dtype, format, semantic, etc.)
Note: Attribute names are ignored in the heuristic as they can be fixed in post-processing.
"""
cost = 0
current_field_types = set(current_state.field_to_attr_spec.keys())
target_field_types = set(target_state.field_to_attr_spec.keys())
# Count missing field types
missing_field_types = target_field_types - current_field_types
cost += len(missing_field_types)
# Count field differences for types that exist in both states
common_field_types = current_field_types & target_field_types
for field_type in common_field_types:
current_attr_spec = current_state.field_to_attr_spec[field_type]
target_attr_spec = target_state.field_to_attr_spec[field_type]
# Compare field properties (ignoring names) - if they differ, we need conversion
if current_attr_spec.field != target_attr_spec.field:
cost += 1
return cost
def _get_applicable_converters(
semantic: Semantic, state: _SchemaState, target_state: _SchemaState, iteration: int = 0
) -> List[Tuple[Converter, _SchemaState]]:
"""Get all converters that can be applied to the current state along with their resulting states."""
applicable: List[Tuple[Converter, _SchemaState]] = []
# Get available field types
available_field_types = set(state.field_to_attr_spec.keys())
for converter_class in ConverterRegistry.list_converters():
# Check if all required input types are available
from_types = converter_class.get_from_types()
# Collect available input AttributeSpec instances
converter_kwargs = {}
all_inputs_available = True
for attr_name, field_type in from_types.items():
if field_type not in available_field_types:
all_inputs_available = False
break
# Add the input attribute to kwargs for the converter constructor
attr_spec = state.field_to_attr_spec[field_type]
converter_kwargs[attr_name] = attr_spec
# Check if we have the required input types
if not all_inputs_available:
continue
# Collect desired output AttributeSpec instances
to_types = converter_class.get_to_types()
to_attr_specs: List[AttributeSpec[Field]] = []
for field_type in to_types.values():
if field_type in target_state.field_to_attr_spec:
attr_spec = target_state.field_to_attr_spec[field_type]
to_attr_specs.append(attr_spec)
for attr_name, field_type in to_types.items():
# First, check if target state has a matching field type and use its name/field
if field_type in target_state.field_to_attr_spec:
target_attr_spec = target_state.field_to_attr_spec[field_type]
output_name = target_attr_spec.name
output_field = target_attr_spec.field
output_categories = target_attr_spec.categories
else:
# The field does not exist, use a temporary name
output_name = field_type.__name__.lower()
# and create a new instance of the field
output_field = field_type(semantic=semantic)
output_categories = None
# Add the iteration count at the end to ensure uniqueness
# and avoid any conflict with existing attribute names
output_name = f"{output_name}_temp_{iteration}"
output_attr_spec = AttributeSpec(
name=output_name, field=output_field, categories=output_categories
)
converter_kwargs[attr_name] = output_attr_spec
# Create converter instance with all AttributeSpec instances as kwargs
converter_instance = converter_class(**converter_kwargs)
if not converter_instance.filter_output_spec():
continue
# Fetch the updated output AttributeSpec instances after filter_output_spec()
updated_output_attr_specs = converter_instance.get_output_attr_specs()
# Apply converter to get new state
new_field_to_attr_spec = dict(state.field_to_attr_spec)
# Keep old attributes instead of removing them
# (Later we can work on cleanup/rename logic separately)
# Add produced output types using the updated output_attr_specs
for attr_spec in updated_output_attr_specs:
field_type = type(attr_spec.field)
new_field_to_attr_spec[field_type] = attr_spec
new_state = _SchemaState(new_field_to_attr_spec)
applicable.append((converter_instance, new_state))
return applicable
def _group_fields_by_semantic(schema: Schema) -> dict[Semantic, _SchemaState]:
"""
Group schema attributes by their semantic tags and return as SchemaState objects.
Args:
schema: Schema to group
Returns:
Dictionary mapping semantic tags to SchemaState objects
"""
groups: dict[Semantic, dict[Type[Field], AttributeSpec[Field]]] = defaultdict(dict)
for attr_name, attr_info in schema.attributes.items():
semantic = attr_info.annotation.semantic
field_type = type(attr_info.annotation)
attr_spec = AttributeSpec(
name=attr_name, field=attr_info.annotation, categories=attr_info.categories
)
groups[semantic][field_type] = attr_spec
# Convert to SchemaState objects
return {
semantic: _SchemaState(field_to_attr_spec)
for semantic, field_to_attr_spec in groups.items()
}
def _create_post_processing_for_semantic(
final_state: _SchemaState, target_state: _SchemaState
) -> Tuple[List[Converter], _SchemaState]:
"""
Create post-processing converters for a single semantic group.
Args:
final_state: Final state reached after conversions
target_state: Target state for this semantic group
Returns:
Tuple of (list of post-processing converters, updated_state_after_processing)
where updated_state_after_processing reflects the state after renaming/deletion
"""
# Build attribute mappings: include only attributes that should be kept
attr_mappings = []
# Determine if a converter is needed (i.e. at least one attribute has been renamed or deleted)
converter_needed = False
for field_type, target_attr_spec in target_state.field_to_attr_spec.items():
if field_type in final_state.field_to_attr_spec:
final_attr_spec = final_state.field_to_attr_spec[field_type]
if final_attr_spec.name != target_attr_spec.name:
converter_needed = True
# Add the mapping from final to target attribute spec
attr_mappings.append((final_attr_spec, target_attr_spec))
# Check if any fields need to be deleted (exist in final but not in target)
for field_type, final_attr_spec in final_state.field_to_attr_spec.items():
if field_type not in target_state.field_to_attr_spec:
# If the field is not in the target state, it should be deleted
converter_needed = True
# Create a single remapper converter that handles both renaming and deletion
if converter_needed:
# Create the updated state after processing by applying the attr_mappings to final_state
# This preserves the categories from final_state but with the target names/structure
updated_field_to_attr_spec = {}
for final_attr_spec, target_attr_spec in attr_mappings:
# Use the target name and field type, but preserve categories from final state
updated_field_to_attr_spec[type(target_attr_spec.field)] = AttributeSpec(
name=target_attr_spec.name,
field=target_attr_spec.field,
categories=final_attr_spec.categories, # Preserve inferred categories
)
updated_state_after_processing = _SchemaState(updated_field_to_attr_spec)
return [
AttributeRemapperConverter(attr_mappings=attr_mappings)
], updated_state_after_processing
# No converter needed, return the final state as-is
return [], final_state
def _find_conversion_path_for_semantic(
start_state: _SchemaState, target_state: _SchemaState, semantic: Semantic
) -> Tuple[List[Converter], _SchemaState]:
"""
Find conversion path for fields with a specific semantic tag.
Args:
start_state: Source state for this semantic
target_state: Target state for this semantic
semantic: The semantic tag being processed
Returns:
Tuple of (list of converters needed for this semantic group, updated target state)
Raises:
ConversionError: If no conversion path is found for this semantic
"""
# If we already have all required fields, check if we need renaming/deletion
if start_state == target_state:
return _create_post_processing_for_semantic(start_state, target_state)
# Initialize A* search
open_set: List[_SearchNode] = []
closed_set: Set[_SchemaState] = set()
start_node = _SearchNode(
state=start_state,
path=[],
g_cost=0,
h_cost=_heuristic_cost(start_state, target_state),
)
heapq.heappush(open_set, start_node)
while open_set:
current_node = heapq.heappop(open_set)
if current_node.state in closed_set:
continue
closed_set.add(current_node.state)
# Check if we've reached the goal - all target fields must match exactly
if _heuristic_cost(current_node.state, target_state) == 0:
# Add post-processing converters for final renaming and deletion
post_processing, final_state = _create_post_processing_for_semantic(
current_node.state, target_state
)
return current_node.path + post_processing, final_state
# Explore neighbors
for converter, new_state in _get_applicable_converters(
semantic,
current_node.state,
target_state,
current_node.g_cost,
):
if new_state in closed_set:
continue
new_path = current_node.path + [converter]
new_g_cost = current_node.g_cost + 1 # Each converter has cost 1
new_h_cost = _heuristic_cost(new_state, target_state)
new_node = _SearchNode(
state=new_state, path=new_path, g_cost=new_g_cost, h_cost=new_h_cost
)
heapq.heappush(open_set, new_node)
# No path found
missing_fields = set(target_state.field_to_attr_spec.keys()) - set(
start_state.field_to_attr_spec.keys()
)
raise ConversionError(
f"No conversion path found for semantic {semantic}. " f"Missing fields: {missing_fields}"
)
[docs]
def find_conversion_path(
from_schema: Schema, to_schema: Schema
) -> Tuple[ConversionPaths, Dict[str, Categories]]:
"""
Find an optimal sequence of converters using A* search, grouped by semantic.
Fields with the same semantic can be converted between each other, but
conversion across semantic boundaries is not allowed.
Args:
from_schema: Source schema
to_schema: Target schema
Returns:
Tuple of (ConversionPaths with separated batch and lazy converter lists,
dictionary of attribute names to inferred categories)
Raises:
ConversionError: If no conversion path is found
"""
# Group fields by semantic in both schemas
start_groups = _group_fields_by_semantic(from_schema)
target_groups = _group_fields_by_semantic(to_schema)
# Collect all converters needed across all semantic groups
all_converters: List[Converter] = []
attr_mappings = []
# Process each semantic group in the target schema
for semantic, target_state in target_groups.items():
# Get corresponding source state for this semantic (if any)
start_state = start_groups.get(semantic, _SchemaState({}))
# Find conversion path for this semantic group
semantic_converters, updated_target_state = _find_conversion_path_for_semantic(
start_state, target_state, semantic
)
# Update the target state with any inferred categories
target_groups[semantic] = updated_target_state
# Merge all the attribute remappers into a single one. If any, the remapper is always the last step.
if semantic_converters and isinstance(semantic_converters[-1], AttributeRemapperConverter):
attr_mappings += semantic_converters[-1].attr_mappings
semantic_converters = semantic_converters[:-1]
all_converters.extend(semantic_converters)
if attr_mappings:
all_converters.append(AttributeRemapperConverter(attr_mappings=attr_mappings))
# Reconstruct the updated schema with inferred categories
inferred_categories: dict[str, Categories] = {}
for semantic, updated_target_state in target_groups.items():
for attr_spec in updated_target_state.field_to_attr_spec.values():
if attr_spec.categories is not None:
inferred_categories[attr_spec.name] = attr_spec.categories
# Separate batch and lazy converters
conversion_paths = _separate_batch_and_lazy_converters(all_converters)
return conversion_paths, inferred_categories
def _separate_batch_and_lazy_converters(
conversion_path: List[Converter],
) -> ConversionPaths:
"""
Separate converters into batch and lazy lists based on dependencies.
If a converter is lazy, all converters that depend on its output must also be lazy.
Args:
conversion_path: The complete conversion path from A* search
Returns:
ConversionPaths with separated batch and lazy converter lists
"""
if not conversion_path:
return ConversionPaths(batch_converters=[], lazy_converters=[])
# Track which converters must be lazy
lazy_indices: Set[int] = set()
lazy_fields: dict[str, bool] = defaultdict(
bool
) # Maps fields whether they were produced lazily
for i, converter in enumerate(conversion_path):
lazy = False
if converter.lazy:
# Mark all intrinsically lazy converters as lazy
lazy = True
else:
# Check whether the converter depends on a lazy converter
input_specs = converter.get_input_attr_specs()
for attr_spec in input_specs:
if attr_spec.name in lazy_fields:
lazy = True
break
if lazy:
lazy_indices.add(i)
# Mark all output fields as lazy
output_specs = converter.get_output_attr_specs()
for attr_spec in output_specs:
lazy_fields[attr_spec.name] = True
# Separate into batch and lazy lists
batch_converters: List[Converter] = []
lazy_converters: List[Converter] = []
for i, converter in enumerate(conversion_path):
if i in lazy_indices:
lazy_converters.append(converter)
else:
batch_converters.append(converter)
return ConversionPaths(batch_converters=batch_converters, lazy_converters=lazy_converters)