# Source code for otx.backend.native.tools.adaptive_bs.runner

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Algorithm to find a proper batch size which is fit to current GPU device."""

from __future__ import annotations

import logging
import os
from functools import partial
from math import sqrt
from typing import TYPE_CHECKING, Any

from lightning import Callback
from torch.cuda import is_available as is_cuda_available

from otx.backend.native.callbacks import BatchSizeFinder
from otx.utils.device import is_xpu_available

from .algorithm import BsSearchAlgo

if TYPE_CHECKING:
    from otx.backend.native.engine import OTXEngine

logger = logging.getLogger(__name__)


def adapt_batch_size(
    engine: OTXEngine,
    not_increase: bool = True,
    callbacks: list[Callback] | Callback | None = None,
    **train_args,
) -> None:
    """Adapt the training batch size to the memory of the current accelerator.

    When ``not_increase`` is True, the current batch size is only checked for
    feasibility and decreased if it does not fit on the device. When it is
    False, the batch size is grown to use most of the available memory.

    Args:
        engine (OTXEngine): engine instance to adapt.
        not_increase (bool): Whether to forbid growing past the default batch size.
        callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.

    Raises:
        RuntimeError: if neither CUDA nor XPU is available.
    """
    if not (is_cuda_available() or is_xpu_available()):
        msg = "Adaptive batch size supports CUDA or XPU."
        raise RuntimeError(msg)

    engine.model.patch_optimizer_and_scheduler_for_adaptive_bs()
    initial_bs = engine.datamodule.train_subset.batch_size

    dist_bs = os.environ.get("ADAPTIVE_BS_FOR_DIST")
    if dist_bs is not None:
        # Non-main process of distributed training: the main process already
        # ran the search and published its result through the environment.
        found_bs = int(dist_bs)
        if found_bs != initial_bs:
            _apply_new_batch_size(engine, found_bs)
        return

    trainer_fn = partial(_train_model, engine=engine, callbacks=callbacks, **_adjust_train_args(train_args))
    num_samples = len(engine.datamodule.subsets[engine.datamodule.train_subset.subset_name])
    search = BsSearchAlgo(
        train_func=trainer_fn,
        default_bs=initial_bs,
        max_bs=num_samples // engine.device.devices,
    )
    found_bs = search.auto_decrease_batch_size() if not_increase else search.find_big_enough_batch_size()

    if engine.device.devices != 1:
        # Publish the result so the other distributed processes skip the search.
        os.environ["ADAPTIVE_BS_FOR_DIST"] = str(found_bs)

    if found_bs == initial_bs:
        logger.info("Adapting batch size is done. Batch size isn't changed.")
        return

    origin_lr = engine.model.optimizer_callable.optimizer_kwargs["lr"]  # type: ignore[attr-defined]
    _apply_new_batch_size(engine, found_bs)
    logger.info(
        "Adapting batch size is done.\n"
        f"Batch size is adapted : {initial_bs} -> {found_bs}\n"
        f"learning rate is adapted : {origin_lr} -> {engine.model.optimizer_callable.optimizer_kwargs['lr']}",  # type: ignore[attr-defined]
    )
def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]: train_args.update(train_args.pop("kwargs", {})) train_args.pop("self", None) train_args.pop("adaptive_bs") return train_args def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callback | None = None, **train_args) -> None: if bs <= 0: msg = f"Batch size should be greater than 0, but {bs} is given." raise ValueError(msg) if engine.device.devices != 1: # TODO(Eunwoo): Need to change after device api is updated engine._cache.update(devices=1) # noqa: SLF001 engine.datamodule.train_subset.batch_size = bs engine.train(callbacks=_register_callback(callbacks), **train_args) def _register_callback(callbacks: list[Callback] | Callback | None = None) -> list[Callback]: if isinstance(callbacks, Callback): callbacks = [callbacks] elif callbacks is None: callbacks = [] callbacks.append(BatchSizeFinder()) return callbacks def _apply_new_batch_size(engine: OTXEngine, new_batch_size: int) -> None: origin_bs = engine.datamodule.train_subset.batch_size if new_batch_size == origin_bs: return engine.datamodule.train_subset.batch_size = new_batch_size engine.model.optimizer_callable.optimizer_kwargs["lr"] *= sqrt(new_batch_size / origin_bs) # type: ignore[attr-defined]