"""gunz_ml.callbacks: registry of callback classes and a config-driven callback factory."""

# import os
import typing as t
from gunz_ml import integrations as ml_helpers
from lightning.pytorch.callbacks import (
    StochasticWeightAveraging,
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    DeviceStatsMonitor
)
# from .early_stopping import EarlyStopping
from .optuna import PyTorchLightningPruningCallback, OptunaPruningCallback
from .early_stopping_optuna import EarlyStoppingOptuna
from .gradient_monitor import GradientMonitor
from .dead_neuron_monitor import DeadNeuronMonitor
from .prediction_dynamics import PredictionDynamicsMonitor
from .weight_update_monitor import WeightUpdateMonitor
from omegaconf import DictConfig, OmegaConf

#? Default registry: name used in a callbacks config -> class to instantiate.
#? NOTE(review): EarlyStoppingOptuna is imported by this module but is not
#? registered here; confirm whether that omission is intentional.
DEF_CALLBACKS = dict(
    #? Lightning specific
    EarlyStopping=EarlyStopping,
    LearningRateMonitor=LearningRateMonitor,
    ModelCheckpoint=ModelCheckpoint,
    StochasticWeightAveraging=StochasticWeightAveraging,
    DeviceStatsMonitor=DeviceStatsMonitor,
    #? Optuna specific
    PyTorchLightningPruningCallback=PyTorchLightningPruningCallback,
    OptunaPruningCallback=OptunaPruningCallback,
    #? Analysis specific
    GradientMonitor=GradientMonitor,
    DeadNeuronMonitor=DeadNeuronMonitor,
    PredictionDynamicsMonitor=PredictionDynamicsMonitor,
    WeightUpdateMonitor=WeightUpdateMonitor,
)

def init_callbacks(
    cfg: "t.Optional[DictConfig]",
    callback_cls_dict: t.Dict,
    tmp_dir: t.Optional[str] = None,
) -> t.List[t.Any]:
    """Instantiate the callbacks listed in a configuration.

    Parameters
    ----------
    cfg : DictConfig or None
        Mapping of callback name -> constructor keyword arguments. A value of
        ``None`` means "construct with no arguments"; a ``DictConfig`` value
        is first converted with ``ml_helpers.resolve_cfg``; any other mapping
        is shallow-copied. If ``cfg`` itself is ``None``, no callbacks are
        created.
    callback_cls_dict : Dict
        Mapping of callback name -> callable (usually a class) used to build
        each callback.
    tmp_dir : Optional[str], optional
        If given, overrides the ``dirpath`` argument of a ``ModelCheckpoint``
        entry. If ``None`` (default), no override is applied, so a
        ``dirpath`` set in ``cfg`` is respected.

    Returns
    -------
    List
        The instantiated callbacks, in the iteration order of ``cfg``.

    Raises
    ------
    ValueError
        If ``cfg`` names a callback that is not a key of ``callback_cls_dict``.
    """
    callbacks = []
    if cfg is None:
        return callbacks

    for callback_name, callback_kwargs in cfg.items():
        if callback_name not in callback_cls_dict:
            raise ValueError(f"Invalid callback:{callback_name}")

        #? Normalize the kwargs into a mapping we own, so the dirpath
        #? override below never mutates the caller's configuration.
        if callback_kwargs is None:
            kwargs = {}
        elif isinstance(callback_kwargs, DictConfig):
            #? Convert Hydra config to dict and list
            kwargs = ml_helpers.resolve_cfg(callback_kwargs)
        else:
            kwargs = dict(callback_kwargs)

        #? Only override when a directory was actually supplied: a None
        #? tmp_dir must not clobber a dirpath configured in cfg.
        if callback_name == "ModelCheckpoint" and tmp_dir is not None:
            kwargs["dirpath"] = tmp_dir

        callbacks.append(callback_cls_dict[callback_name](**kwargs))

    return callbacks