# import os
import typing as t
from gunz_ml import integrations as ml_helpers
from lightning.pytorch.callbacks import (
StochasticWeightAveraging,
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
DeviceStatsMonitor
)
# from .early_stopping import EarlyStopping
from .optuna import PyTorchLightningPruningCallback, OptunaPruningCallback
from .early_stopping_optuna import EarlyStoppingOptuna
from .gradient_monitor import GradientMonitor
from .dead_neuron_monitor import DeadNeuronMonitor
from .prediction_dynamics import PredictionDynamicsMonitor
from .weight_update_monitor import WeightUpdateMonitor
from omegaconf import DictConfig, OmegaConf
#? Default registry mapping a callback name (as written in the configuration)
#? to the class that implements it. It is assembled from three themed groups
#? and merged in the order listed, so iteration order is: Lightning, Optuna,
#? then analysis callbacks.
#? NOTE(review): EarlyStoppingOptuna is imported by this module but is not
#? registered below -- confirm whether that omission is intentional.

#? Lightning specific
_LIGHTNING_CALLBACKS = {
    "EarlyStopping": EarlyStopping,
    "LearningRateMonitor": LearningRateMonitor,
    "ModelCheckpoint": ModelCheckpoint,
    "StochasticWeightAveraging": StochasticWeightAveraging,
    "DeviceStatsMonitor": DeviceStatsMonitor,
}

#? Optuna specific
_OPTUNA_CALLBACKS = {
    "PyTorchLightningPruningCallback": PyTorchLightningPruningCallback,
    "OptunaPruningCallback": OptunaPruningCallback,
}

#? Analysis specific
_ANALYSIS_CALLBACKS = {
    "GradientMonitor": GradientMonitor,
    "DeadNeuronMonitor": DeadNeuronMonitor,
    "PredictionDynamicsMonitor": PredictionDynamicsMonitor,
    "WeightUpdateMonitor": WeightUpdateMonitor,
}

DEF_CALLBACKS = {
    **_LIGHTNING_CALLBACKS,
    **_OPTUNA_CALLBACKS,
    **_ANALYSIS_CALLBACKS,
}
[docs]
def init_callbacks(
    cfg: t.Optional[t.Mapping[str, t.Any]],
    callback_cls_dict: t.Dict[str, t.Callable[..., t.Any]],
    tmp_dir: t.Optional[str] = None,
) -> t.List[t.Any]:
    """Instantiate the callbacks listed in a configuration mapping.

    Parameters
    ----------
    cfg : DictConfig or Mapping or None
        Maps each callback name to the keyword arguments used to build it.
        A value of None means "no keyword arguments". If ``cfg`` itself is
        None, no callbacks are created.
    callback_cls_dict : Dict
        Maps callback names to the class (or any callable) that builds the
        callback.
    tmp_dir : Optional[str], optional
        Value written into the ``dirpath`` keyword argument of the entry
        named ``"ModelCheckpoint"``. It always replaces a ``dirpath`` given
        in ``cfg``, including when ``tmp_dir`` is None. Defaults to None.

    Returns
    -------
    List
        Callback instances, in the iteration order of ``cfg``.

    Raises
    ------
    ValueError
        If a name in ``cfg`` is not a key of ``callback_cls_dict``.

    Notes
    -----
    Keyword arguments supplied as a plain ``dict`` are shallow-copied before
    use, so injecting ``dirpath`` never modifies the caller's configuration.
    ``DictConfig`` values are converted with ``ml_helpers.resolve_cfg``
    (presumably into plain Python containers -- its implementation is not
    visible from this module).
    """
    callbacks: t.List[t.Any] = []
    if cfg is None:
        return callbacks

    for callback_name, callback_kwargs in cfg.items():
        if callback_name not in callback_cls_dict:
            raise ValueError(f"Invalid callback:{callback_name}")

        if callback_kwargs is None:
            #? By default (if None) give a dictionary
            callback_kwargs = {}
        elif isinstance(callback_kwargs, dict):
            #? Work on a copy: the "dirpath" injection below must not leak
            #? back into the dictionary owned by the caller.
            callback_kwargs = dict(callback_kwargs)
        elif isinstance(callback_kwargs, DictConfig):
            #? Convert Hydra config to dict and list
            callback_kwargs = ml_helpers.resolve_cfg(callback_kwargs)

        if callback_name == 'ModelCheckpoint':
            #? The checkpoint location is dictated by the caller, not by cfg.
            callback_kwargs['dirpath'] = tmp_dir

        callbacks.append(callback_cls_dict[callback_name](**callback_kwargs))

    return callbacks