# Copyright (C) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT
"""
Field implementations for various data types including tensors, images, and bounding boxes.
This module provides concrete field implementations that handle serialization
to/from Polars DataFrames for different data types commonly used in machine
learning and computer vision applications.
"""
from dataclasses import dataclass
from typing import Any, TypeVar
import numpy as np
import polars as pl
from .schema import Field, Semantic
from .type_registry import from_polars_data, to_numpy
# Generic type variable: links the ``target_type`` argument of the
# ``from_polars`` methods below to the type they return.
T = TypeVar("T")
@dataclass(frozen=True)
class TensorField(Field):
    """
    Represents a tensor field with semantic tags and data type information.

    This field handles n-dimensional tensor data by flattening it for storage
    and preserving shape information separately for reconstruction. Two
    columns are produced per field: ``<name>`` holding the flattened values
    and ``<name>_shape`` holding the original shape.

    Attributes:
        semantic: Semantic tags describing the tensor's purpose
        dtype: Polars data type for tensor elements
    """

    semantic: Semantic
    dtype: Any

    def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
        """Generate Polars schema with separate columns for data and shape.

        Args:
            name: Column name of the field

        Returns:
            Mapping with ``name`` as a list of ``self.dtype`` and
            ``name + "_shape"`` as a list of 32-bit integers
        """
        return {name: pl.List(self.dtype), name + "_shape": pl.List(pl.Int32())}

    def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
        """Convert tensor to flattened data and shape information.

        Args:
            name: Column name of the field
            value: Tensor-like value; converted with ``to_numpy`` using ``self.dtype``

        Returns:
            Mapping with one single-row series for the flattened data and one
            single-row series for the shape
        """
        numpy_value = to_numpy(value, self.dtype)
        shape_name = name + "_shape"
        return {
            name: pl.Series(name, [numpy_value.reshape(-1)]),
            # Pin the dtype to the one declared in to_polars_schema. Without it
            # the element type is inferred, and a 0-d tensor (empty shape)
            # leaves nothing to infer an integer type from.
            shape_name: pl.Series(
                shape_name,
                [list(numpy_value.shape)],
                dtype=pl.List(pl.Int32()),
            ),
        }

    def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type[T]) -> T:
        """Reconstruct tensor from flattened data using stored shape.

        Args:
            name: Column name of the field
            row_index: Row of ``df`` to read
            df: DataFrame holding the ``name`` and ``name + "_shape"`` columns
            target_type: Type handed to ``from_polars_data`` for the final conversion

        Returns:
            The result of ``from_polars_data`` applied to the reshaped array
        """
        flat_data = df[name][row_index]
        # Hand numpy a plain tuple of ints instead of the raw Polars cell so
        # that reshape does not depend on how numpy treats a Polars object.
        shape = tuple(df[name + "_shape"][row_index])
        numpy_data = np.array(flat_data).reshape(shape)
        return from_polars_data(numpy_data, target_type)  # type: ignore
def tensor_field(dtype: Any, semantic: Semantic = Semantic.Default) -> Any:
    """
    Create a TensorField instance with the specified semantic tags and data type.

    Args:
        dtype: Polars data type for tensor elements
        semantic: Semantic tags describing the tensor's purpose (optional)

    Returns:
        TensorField instance configured with the given parameters
    """
    return TensorField(semantic=semantic, dtype=dtype)
@dataclass(frozen=True)
class ImageField(TensorField):
    """
    Represents an image tensor field with format information.

    Extends TensorField to include image-specific metadata such as
    color format (RGB, BGR, etc.). Serialization is inherited unchanged
    from TensorField; this class only adds the ``format`` attribute.

    Attributes:
        format: Image color format (e.g., "RGB", "BGR", "RGBA")
    """

    format: str
def image_field(dtype: Any, format: str = "RGB", semantic: Semantic = Semantic.Default) -> Any:
    """
    Create an ImageField instance with the specified parameters.

    Args:
        dtype: Polars data type for pixel values
        format: Image color format (defaults to "RGB")
        semantic: Semantic tags describing the image's purpose (optional)

    Returns:
        ImageField instance configured with the given parameters
    """
    return ImageField(semantic=semantic, dtype=dtype, format=format)
@dataclass(frozen=True)
class BBoxField(Field):
    """
    Represents a bounding box field with format and normalization options.

    Handles bounding box data with support for different coordinate formats
    and optional normalization to [0,1] range. ``format`` and ``normalize``
    are stored as metadata only; the methods of this class do not convert
    or rescale coordinates.

    Attributes:
        semantic: Semantic tags describing the bounding box purpose
        dtype: Polars data type for coordinate values
        format: Coordinate format (e.g., "x1y1x2y2", "xywh")
        normalize: Whether coordinates are normalized to [0,1] range
    """

    semantic: Semantic
    dtype: Any
    format: str
    normalize: bool

    def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
        """Generate schema for bounding box as list of 4-element arrays.

        Args:
            name: Column name of the field

        Returns:
            Mapping of ``name`` to a list of fixed-size (4) arrays of ``self.dtype``
        """
        return {name: pl.List(pl.Array(self.dtype, 4))}

    def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
        """Convert bounding box tensor to Polars list format.

        Args:
            name: Column name of the field
            value: Box data; converted with ``to_numpy`` using ``self.dtype``

        Returns:
            Mapping of ``name`` to a single-row series holding all boxes
        """
        numpy_value = to_numpy(value, self.dtype)
        # Leading axis of 1 makes the whole set of boxes a single row;
        # the -1 axis is the number of boxes, each with 4 coordinates.
        boxes = numpy_value.reshape(1, -1, 4)
        return {
            name: pl.Series(
                name,
                boxes,
                dtype=pl.List(pl.Array(self.dtype, 4)),
            )
        }

    def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type[T]) -> T:
        """Reconstruct bounding box tensor from Polars data.

        Args:
            name: Column name of the field
            row_index: Row of ``df`` to read
            df: DataFrame holding the ``name`` column
            target_type: Type handed to ``from_polars_data`` for the final conversion

        Returns:
            The result of ``from_polars_data`` applied to the stored cell
        """
        polars_data = df[name][row_index]
        return from_polars_data(polars_data, target_type)  # type: ignore
def bbox_field(
    dtype: Any,
    format: str = "x1y1x2y2",
    normalize: bool = False,
    semantic: Semantic = Semantic.Default,
) -> Any:
    """
    Create a BBoxField instance with the specified parameters.

    Args:
        dtype: Polars data type for coordinate values
        format: Coordinate format (defaults to "x1y1x2y2")
        normalize: Whether coordinates are normalized (defaults to False)
        semantic: Semantic tags describing the bounding box purpose (optional)

    Returns:
        BBoxField instance configured with the given parameters
    """
    return BBoxField(semantic=semantic, dtype=dtype, format=format, normalize=normalize)
@dataclass
class ImageInfo:
    """Container for image metadata (width and height).

    Attributes:
        width: Image width
        height: Image height
    """

    width: int
    height: int
@dataclass(frozen=True)
class ImageInfoField(Field):
    """
    Represents image metadata (width, height) as a Polars struct.

    Attributes:
        semantic: Semantic tags describing the image info's purpose
    """

    semantic: Semantic

    def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
        """Generate schema for a struct column with ``width`` and ``height``.

        Args:
            name: Column name of the field

        Returns:
            Mapping of ``name`` to a struct of two 32-bit integer fields
        """
        return {name: pl.Struct([pl.Field("width", pl.Int32()), pl.Field("height", pl.Int32())])}

    def to_polars(self, name: str, value: ImageInfo) -> dict[str, pl.Series]:
        """Convert an ImageInfo into a single-row Polars series.

        Args:
            name: Column name of the field
            value: Image metadata to store

        Returns:
            Mapping of ``name`` to a one-element series holding the width/height record
        """
        return {name: pl.Series(name, [{"width": value.width, "height": value.height}])}

    def from_polars(
        self, name: str, row_index: int, df: pl.DataFrame, target_type: type
    ) -> ImageInfo:
        """Rebuild an ImageInfo from the stored struct value.

        Args:
            name: Column name of the field
            row_index: Row of ``df`` to read
            df: DataFrame holding the ``name`` column
            target_type: Requested result type; must be ImageInfo or a subclass

        Returns:
            ImageInfo populated from the ``width`` and ``height`` entries

        Raises:
            TypeError: If ``target_type`` is not ImageInfo or a subclass of it
        """
        if not issubclass(target_type, ImageInfo):
            raise TypeError(f"Expected target_type to be ImageInfo, got {target_type}")
        struct_val = df[name][row_index]
        return ImageInfo(width=struct_val["width"], height=struct_val["height"])
def image_info_field(semantic: Semantic = Semantic.Default) -> Any:
    """
    Create an ImageInfoField instance for storing image width and height.

    Args:
        semantic: Optional semantic tags for disambiguation (e.g., Semantic.Left)

    Returns:
        ImageInfoField instance configured with the given semantic tags
    """
    return ImageInfoField(semantic=semantic)
@dataclass(frozen=True)
class ImagePathField(Field):
    """
    Represents a field containing the file path to an image on disk.

    This field stores image file paths as strings and is typically used
    as input for lazy loading operations where images are loaded on-demand.

    Attributes:
        semantic: Semantic tags describing the image path's purpose
    """

    semantic: Semantic

    def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
        """Generate schema for string path column.

        Args:
            name: Column name of the field

        Returns:
            Mapping of ``name`` to the Polars string type
        """
        return {name: pl.String()}

    def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
        """Convert path string to Polars series.

        Args:
            name: Column name of the field
            value: Path-like value; stored as ``str(value)``

        Returns:
            Mapping of ``name`` to a one-element string series
        """
        return {name: pl.Series(name, [str(value)])}

    def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type):
        """Extract path string from Polars data.

        Args:
            name: Column name of the field
            row_index: Row of ``df`` to read
            df: DataFrame holding the ``name`` column
            target_type: Callable type used to wrap the stored string

        Returns:
            ``target_type`` called with the stored cell value
        """
        return target_type(df[name][row_index])
def image_path_field(semantic: Semantic = Semantic.Default) -> Any:
    """
    Create an ImagePathField instance with the specified semantic tags.

    Args:
        semantic: Semantic tags describing the image path's purpose (optional)

    Returns:
        ImagePathField instance configured with the given semantic tags
    """
    return ImagePathField(semantic=semantic)