Source code for otx.backend.native.callbacks.batchsize_finder
"""Callback that finds the optimal batch size."""
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import TYPE_CHECKING
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers.logger import DummyLogger
if TYPE_CHECKING:
from lightning import LightningModule
from lightning.pytorch.trainer import Trainer

class BatchSizeFinder(Callback):
    """Callback that makes the trainer run a fixed number of iterations and then exit.

    Args:
        steps_per_trial: Number of steps to run with a given batch size.
            Ideally one step should be enough to detect an OOM error, but in practice a few are needed.
    """

    def __init__(
        self,
        steps_per_trial: int = 3,
    ) -> None:
        self._steps_per_trial = steps_per_trial

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None) -> None:
        """Check that the current stage is ``fit``."""
        if stage != "fit":
            msg = "Adaptive batch size supports only training."
            raise RuntimeError(msg)

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Run ``steps_per_trial`` iterations and exit."""
        _scale_batch_reset_params(trainer, self._steps_per_trial)
        _try_loop_run(trainer)


def _try_loop_run(trainer: Trainer) -> None:
    """Run the trainer's active loop from a clean (non-restarting) state."""
    loop = trainer._active_loop  # noqa: SLF001
    if loop is None:
        msg = "There is no active loop."
        raise RuntimeError(msg)
    loop.restarting = False
    loop.run()


def _scale_batch_reset_params(trainer: Trainer, steps_per_trial: int) -> None:
    """Reconfigure the trainer so that fitting runs only a short trial."""
    # Silence logging and drop all callbacks so the trial run leaves no artifacts behind.
    trainer.logger = DummyLogger() if trainer.logger is not None else None
    trainer.callbacks = []
    loop = trainer._active_loop  # noqa: SLF001
    if loop is None:
        msg = "There is no active loop."
        raise RuntimeError(msg)
    if trainer.fit_loop.epoch_loop.max_steps == -1:  # epoch-based loop: cap one epoch at `steps_per_trial` batches
        trainer.fit_loop.max_epochs = 1
        trainer.limit_train_batches = steps_per_trial
    else:  # iteration-based loop: cap the total number of steps directly
        trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
        trainer.limit_train_batches = 1.0
    if trainer.limit_val_batches != 0:
        trainer.limit_val_batches = steps_per_trial
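

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original module): probing a
# single candidate batch size by fitting a throwaway Trainer that carries this
# callback. `_ToyModule` and the random dataset are hypothetical stand-ins for
# a real LightningModule and datamodule; everything else is the public
# Lightning/PyTorch API.

import torch
from lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset


class _ToyModule(LightningModule):
    """Minimal model for the sketch: one linear layer trained with cross-entropy."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.parameters(), lr=0.1)


def _batch_size_fits(batch_size: int) -> bool:
    """Return True if a few training steps complete without raising (e.g. OOM)."""
    data = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
    trainer = Trainer(
        max_epochs=1,
        enable_checkpointing=False,
        callbacks=[BatchSizeFinder(steps_per_trial=3)],
    )
    try:
        # `on_fit_start` reconfigures the trainer and runs the short trial loop.
        trainer.fit(_ToyModule(), DataLoader(data, batch_size=batch_size))
    except RuntimeError:  # a CUDA OOM surfaces as a RuntimeError
        return False
    return True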