Source code for otx.backend.native.cli.utils
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""OTX APIs for User-friendliness."""
from __future__ import annotations
import fnmatch
import importlib
import inspect
import textwrap
from pathlib import Path
from otx.types.task import OTXTaskType
[docs]
def get_otx_root_path() -> Path:
    """Return the root path of the otx module.

    The root path is the parent directory of the file reported by
    :func:`inspect.getfile` for the imported ``otx`` package.

    Returns:
        Path: The root path of the otx module.

    Raises:
        ModuleNotFoundError: If the otx module is not found.
    """
    try:
        otx_module = importlib.import_module("otx")
    except ModuleNotFoundError as err:
        # Only rewrap when ``otx`` itself is missing; a missing dependency that
        # ``otx`` imports should surface with its own, more precise message.
        if err.name != "otx":
            raise
        msg = "Cannot find otx."
        raise ModuleNotFoundError(msg, name="otx") from err
    return Path(inspect.getfile(otx_module)).parent
# Directory that ``list_models`` searches for recipe ``*.yaml`` files; resolved once at module import.
RECIPE_PATH = get_otx_root_path().joinpath("recipe")
[docs]
def list_models(task: OTXTaskType | None = None, pattern: str | None = None, print_table: bool = False) -> list[str]:
    """Returns a list of available models for training.

    Recipes are the ``*.yaml`` files found under ``RECIPE_PATH``; any file whose path
    contains a ``_base_`` component is skipped. A model name is the recipe file name
    without its suffix.

    Args:
        task (OTXTaskType | None, optional): Recipe filter by task. Only recipes located
            under a directory named ``OTXTaskType(task).name.lower()`` are listed.
            Defaults to None, which lists the recipes of every task.
        pattern (str | None, optional): A shell-style (``fnmatch``) pattern used to filter
            the recipes. It is wrapped as ``*{pattern}*`` and matched against the full
            recipe path. Defaults to None.
        print_table (bool, optional): Also print the matching recipes (task, model name,
            recipe path) as a Rich table. This is primarily used for `otx find` in the CLI.
            Defaults to False.

    Returns:
        list[str]: The sorted, de-duplicated names of the available models.

    Example:
        # Return all available model list.
        >>> models = list_models()
        >>> models
        ['atss_mobilenetv2', 'atss_r50_fpn', ...]

        # Return INSTANCE_SEGMENTATION model list.
        >>> models = list_models(task="INSTANCE_SEGMENTATION")
        >>> models
        ['maskrcnn_efficientnetb2b', 'maskrcnn_r50', 'maskrcnn_swint', 'openvino_model']

        # Return all available model list that matches the pattern.
        >>> models = list_models(task="MULTI_CLASS_CLS", pattern="*efficient")
        >>> models
        ['efficientnet_b0', 'efficientnet_v2', ...]

        # Print the recipe information as a Rich.Table (include task, model name, recipe path)
        >>> models = list_models(task="MULTI_CLASS_CLS", pattern="*efficient", print_table=True)
    """
    task_type = OTXTaskType(task).name.lower() if task is not None else "**"
    # The set drops any duplicate hits from the recursive glob; sorting makes both the
    # returned list and the printed table independent of filesystem iteration order.
    recipe_list = sorted(
        {str(recipe) for recipe in RECIPE_PATH.glob(f"**/{task_type}/**/*.yaml") if "_base_" not in recipe.parts},
    )
    if pattern is not None:
        # Always match keys with any prefix and postfix; fnmatch.filter keeps the sorted order.
        recipe_list = fnmatch.filter(recipe_list, f"*{pattern}*")
    if print_table:
        _print_recipe_table(recipe_list)
    return sorted({Path(recipe).stem for recipe in recipe_list})


def _print_recipe_table(recipe_list: list[str]) -> None:
    """Print ``recipe_list`` as a Rich table with task, model name and recipe path columns."""
    from rich.console import Console
    from rich.table import Table

    console = Console()
    table = Table(title="OTX Recipes", show_header=True, header_style="bold magenta")
    table.add_column("Task")
    table.add_column("Model Name")
    table.add_column("Recipe Path")
    # Long paths are wrapped to half the console width; clamp to >= 1 because
    # textwrap.fill raises ValueError for a width of 0.
    max_path_width = max(console.width // 2, 1)
    for recipe in recipe_list:
        recipe_path = textwrap.fill(recipe, width=max_path_width) if len(recipe) > max_path_width else recipe
        table.add_row(
            # The task is the recipe's parent directory name; Path.parent.name also
            # handles Windows separators, unlike splitting on "/".
            Path(recipe).parent.name.upper(),
            Path(recipe).stem,
            recipe_path,
        )
    console.print(table, width=console.width, justify="center")