gunz_ml.callbacks package

gunz_ml.callbacks.init_callbacks(cfg: omegaconf.DictConfig, callback_cls_dict: Dict, tmp_dir=None)[source]

Initialize callbacks based on the provided configuration.

Parameters:
  • cfg (DictConfig) – Configuration containing callback names and their corresponding parameters.

  • callback_cls_dict (Dict) – Dictionary mapping callback names to their respective classes.

  • tmp_dir (Optional[str], optional) – Temporary directory path to save model checkpoints. If None, model checkpoints are saved to the current directory. Defaults to None.

Returns:

List of initialized callback instances.

Return type:

List

Raises:

ValueError – If an invalid callback name is encountered in the configuration.

Notes

Each callback named in cfg is looked up in callback_cls_dict and instantiated with the parameters specified for it.
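
A minimal sketch of the name-to-class dispatch described above, using a plain dict in place of the DictConfig so that it runs without dependencies. The cfg layout (callback name mapped to keyword arguments), the two stand-in classes, and the helper name init_callbacks_sketch are assumptions for illustration. The tmp_dir handling is left out because the documentation does not say how it is passed on.

```python
from typing import Any, Dict, List


class PrintLoss:
    """Hypothetical stand-in for a callback class."""

    def __init__(self, every_n_steps: int = 10) -> None:
        self.every_n_steps = every_n_steps


class StopAfter:
    """Hypothetical stand-in for a callback class."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs


def init_callbacks_sketch(cfg: Dict[str, Dict[str, Any]],
                          callback_cls_dict: Dict[str, type]) -> List[Any]:
    """Instantiate one callback per config entry, rejecting unknown names."""
    callbacks = []
    for name, params in cfg.items():
        if name not in callback_cls_dict:
            raise ValueError(f"Invalid callback name: {name!r}")
        # An empty or missing parameter block falls back to the class defaults.
        callbacks.append(callback_cls_dict[name](**(params or {})))
    return callbacks


registry = {"print_loss": PrintLoss, "stop_after": StopAfter}
cfg = {"print_loss": {"every_n_steps": 5}, "stop_after": {"max_epochs": 3}}
callbacks = init_callbacks_sketch(cfg, registry)
```

An unknown name such as "typo_callback" in cfg would raise ValueError in this sketch, matching the documented failure mode.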

Submodules

gunz_ml.callbacks.dead_neuron_monitor module

Callback for monitoring dead neurons in activation layers.

class gunz_ml.callbacks.dead_neuron_monitor.DeadNeuronMonitor(*args: Any, **kwargs: Any)[source]

Bases: Callback

Monitors the fraction of ‘dead’ neurons in activation layers.

A neuron is considered dead if it outputs zero (or below a threshold) for all examples in a batch. High rates of dead neurons may indicate issues with initialization, learning rate, or optimizers.

Parameters:
  • log_every_n_steps (int) – Frequency of logging in steps. Defaults to 100.

  • threshold (float) – Activation threshold below which a neuron is considered inactive. Defaults to 0.0 (strict zero for ReLU).
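
The dead-neuron definition above can be sketched in plain Python as follows. The function name and the use of nested lists instead of tensors are assumptions made to keep the example dependency-free, and the comparison is assumed to be "less than or equal to threshold" so that the default of 0.0 counts exact zeros.

```python
from typing import Sequence


def dead_neuron_fraction(activations: Sequence[Sequence[float]],
                         threshold: float = 0.0) -> float:
    """Fraction of neurons that stay at or below threshold for every example.

    activations holds one row per example and one column per neuron.
    """
    n_neurons = len(activations[0])
    dead = 0
    for j in range(n_neurons):
        # A neuron is dead only if no example in the batch activates it.
        if all(row[j] <= threshold for row in activations):
            dead += 1
    return dead / n_neurons


batch = [
    [0.0, 1.2, 0.0, 0.3],
    [0.0, 0.0, 0.0, 0.7],
    [0.0, 0.5, 0.0, 0.0],
]
fraction = dead_neuron_fraction(batch)
print(fraction)  # 0.5: neurons 0 and 2 never fire in this batch
```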

on_train_end(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule)[source]

Clean up hooks.

on_train_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule)[source]

Attaches hooks to all ReLU/LeakyReLU/GELU layers.

gunz_ml.callbacks.early_stopping module

Early Stopping

Monitor a metric and stop training when it stops improving.

class gunz_ml.callbacks.early_stopping.EarlyStopping(monitor: str, min_delta: float = 0.0, patience: int = 3, verbose: bool = False, mode: str = 'min', strict: bool = True, check_finite: bool = True, stopping_threshold: float | None = None, divergence_threshold: float | None = None, check_on_train_epoch_end: bool | None = None, log_rank_zero_only: bool = False)[source]

Bases: Callback

Monitor a metric and stop training when it stops improving.

Parameters:
  • monitor (str) – Quantity to be monitored.

  • min_delta (float) – Minimum change in the monitored quantity to qualify as an improvement. An absolute change of less than or equal to min_delta counts as no improvement.

  • patience (int) – Number of checks with no improvement after which training will be stopped. Under the default configuration, one check happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example check_val_every_n_epoch and val_check_interval.

  • verbose (bool) – Verbosity mode.

  • mode (str) – One of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored has stopped decreasing and in 'max' mode it will stop when the quantity monitored has stopped increasing.

  • strict (bool) – Whether to crash the training if monitor is not found in the validation metrics.

  • check_finite (bool) – When set True, stops training when the monitor becomes NaN or infinite.

  • stopping_threshold (float, optional) – Stop training immediately once the monitored quantity reaches this threshold.

  • divergence_threshold (float, optional) – Stop training as soon as the monitored quantity becomes worse than this threshold.

  • check_on_train_epoch_end (bool, optional) – Whether to run early stopping at the end of the training epoch. If this is False, then the check runs at the end of the validation.

  • log_rank_zero_only (bool) – When set True, logs the status of the early stopping callback only for rank 0 process.

Raises:
  • MisconfigurationException – If mode is neither "min" nor "max".

  • RuntimeError – If the metric monitor is not available.

Examples

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])

Notes

The patience parameter counts the number of validation checks with no improvement, and not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped.
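
The arithmetic in this note can be checked with a small simulation. This is a sketch of the counting rule only, not the callback's actual implementation: the function name, the MODE_OPS table, and the list of scores are hypothetical, and the first check is assumed to establish the baseline score.

```python
import operator
from typing import Optional, Sequence

# Comparison used to decide whether a new score beats the best one so far.
MODE_OPS = {"min": operator.lt, "max": operator.gt}


def first_stopping_epoch(scores: Sequence[float], patience: int,
                         check_val_every_n_epoch: int = 1,
                         mode: str = "min",
                         min_delta: float = 0.0) -> Optional[int]:
    """Return the epoch at which training would stop, or None if it never does."""
    improved = MODE_OPS[mode]
    # The score must beat the best one by more than min_delta to count.
    signed_delta = -min_delta if mode == "min" else min_delta
    best: Optional[float] = None
    wait_count = 0
    for check_index, score in enumerate(scores, start=1):
        if best is None or improved(score - signed_delta, best):
            best = score
            wait_count = 0
        else:
            wait_count += 1
            if wait_count >= patience:
                return check_index * check_val_every_n_epoch
    return None


# One validation check every 10 epochs, and the metric never improves.
flat_scores = [1.0, 1.0, 1.0, 1.0]
stop_epoch = first_stopping_epoch(flat_scores, patience=3,
                                  check_val_every_n_epoch=10)
print(stop_epoch)  # 40
```

The first check at epoch 10 only sets the baseline, so the three non-improving checks fall at epochs 20, 30, and 40.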

Tip

Saving and restoring multiple early stopping callbacks at the same time is supported as long as they differ in at least one of the following arguments: monitor, mode.

load_state_dict(state_dict: Dict[str, Any]) → None[source]

Load the callback state from a dictionary.

Parameters:

state_dict (Dict[str, Any]) – The state dictionary to restore from.

mode_dict = {'max': <built-in method gt of type object>, 'min': <built-in method lt of type object>}
property monitor_op: Callable

Get the comparison operator based on the mode.

Returns:

Comparison function (e.g., torch.lt or torch.gt).

Return type:

Callable

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) → None[source]

Called when the train epoch ends.

Parameters:
  • trainer (pl.Trainer) – The Trainer instance.

  • pl_module (pl.LightningModule) – The LightningModule instance.

on_validation_end(trainer: Trainer, pl_module: LightningModule) → None[source]

Called when the validation loop ends.

Parameters:
  • trainer (pl.Trainer) – The Trainer instance.

  • pl_module (pl.LightningModule) – The LightningModule instance.

order_dict = {'max': '>', 'min': '<'}
setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None[source]

Called when fit or test begins.

Parameters:
  • trainer (pl.Trainer) – The Trainer instance.

  • pl_module (pl.LightningModule) – The LightningModule instance.

  • stage (str) – The stage (e.g., ‘fit’, ‘validate’, ‘test’, ‘predict’).

state_dict() → Dict[str, Any][source]

Get the callback state dictionary.

Returns:

The state dictionary containing wait count, stopped epoch, best score, and patience.

Return type:

Dict[str, Any]
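
A sketch of how state_dict and load_state_dict might round-trip the four documented fields. The class StopperState and the key names are assumptions chosen to mirror the description above; the real callback may store its state under different keys.

```python
from typing import Any, Dict, Optional


class StopperState:
    """Hypothetical holder for the fields state_dict() is described as saving."""

    def __init__(self, patience: int = 3) -> None:
        self.wait_count = 0
        self.stopped_epoch = 0
        self.best_score: Optional[float] = None
        self.patience = patience

    def state_dict(self) -> Dict[str, Any]:
        # Key names are assumed, not taken from the library.
        return {
            "wait_count": self.wait_count,
            "stopped_epoch": self.stopped_epoch,
            "best_score": self.best_score,
            "patience": self.patience,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.wait_count = state_dict["wait_count"]
        self.stopped_epoch = state_dict["stopped_epoch"]
        self.best_score = state_dict["best_score"]
        self.patience = state_dict["patience"]


# Simulate a run that was interrupted after two checks without improvement.
original = StopperState(patience=5)
original.wait_count = 2
original.best_score = 0.31

# A fresh instance picks up where the interrupted one left off.
restored = StopperState()
restored.load_state_dict(original.state_dict())
```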

property state_key: str

Generate a unique identifier for the callback state.

Returns:

The state key based on the monitored metric and mode.

Return type:

str

gunz_ml.callbacks.early_stopping_optuna module

A PyTorch Lightning callback for Optuna pruning.

class gunz_ml.callbacks.early_stopping_optuna.EarlyStoppingOptuna(*args: Any, **kwargs: Any)[source]

Bases: EarlyStopping

A PyTorch Lightning callback that prunes an Optuna trial if the early stopping condition is met.

This callback inherits from pytorch_lightning.callbacks.EarlyStopping and overrides the _run_early_stopping_check method. When the early stopping condition is fulfilled, it raises optuna.exceptions.TrialPruned to signal to Optuna that the trial should be pruned.

Parameters:
  • monitor (str) – The quantity to be monitored (e.g., ‘val_loss’).

  • min_delta (float, optional) – Minimum change in the monitored quantity to qualify as an improvement, by default 0.0.

  • patience (int, optional) – Number of checks with no improvement after which training will be stopped, by default 3.

  • verbose (bool, optional) – If True, prints a message for each pruning, by default False.

  • mode (str, optional) – One of {‘min’, ‘max’}. In ‘min’ mode, training will stop when the quantity monitored has stopped decreasing; in ‘max’ mode it will stop when the quantity monitored has stopped increasing, by default “min”.

  • strict (bool, optional) – Whether to crash the training if monitor is not found in the validation metrics, by default True.

  • check_finite (bool, optional) – When set to True, stops training when the monitor becomes NaN or infinite, by default True.

  • stopping_threshold (float, optional) – Stop training immediately once the monitored quantity reaches this threshold, by default None.

  • divergence_threshold (float, optional) – Stop training as soon as the monitored quantity becomes worse than this threshold, by default None.

  • check_on_train_epoch_end (bool, optional) – Whether to run early stopping at the end of the training epoch. If this is False, then the check runs at the end of the validation epoch, by default None.

gunz_ml.callbacks.gradient_monitor module

Callback for monitoring gradient norms during training.

class gunz_ml.callbacks.gradient_monitor.GradientMonitor(*args: Any, **kwargs: Any)[source]

Bases: Callback

Logs the L2 norm of gradients for model parameters.

This callback inspects the gradients after the backward pass and logs their norms to the configured logger(s). This is useful for detecting vanishing or exploding gradients.

Parameters:
  • log_every_n_steps (int) – Frequency of logging in steps. Defaults to 50.

  • log_per_layer (bool) – If True, logs the norm for every parameter. If False, logs only the total norm of the entire model. Defaults to True.

on_after_backward(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule)[source]

Called after loss.backward() and before optimizers take a step.
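
The difference between per-parameter norms and the total norm can be illustrated with a dependency-free sketch. Gradients are given as flattened lists of floats rather than tensors, and the function name and the "grad_norm/..." metric names are hypothetical.

```python
import math
from typing import Dict, Sequence


def gradient_norms(grads: Dict[str, Sequence[float]],
                   log_per_layer: bool = True) -> Dict[str, float]:
    """Compute L2 gradient norms from flattened per-parameter gradients."""
    per_param = {
        name: math.sqrt(sum(g * g for g in values))
        for name, values in grads.items()
    }
    # The total norm is the L2 norm of all gradients concatenated together,
    # which equals the L2 norm of the per-parameter norms.
    total = math.sqrt(sum(norm ** 2 for norm in per_param.values()))
    metrics = {"grad_norm/total": total}
    if log_per_layer:
        for name, norm in per_param.items():
            metrics[f"grad_norm/{name}"] = norm
    return metrics


grads = {"layer1.weight": [3.0, 4.0], "layer1.bias": [0.0, 12.0]}
metrics = gradient_norms(grads)
# Per-parameter norms are 5 and 12, so the total norm is 13.
```

A total norm that keeps shrinking toward zero hints at vanishing gradients, while one that grows by orders of magnitude hints at exploding gradients.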

gunz_ml.callbacks.optuna module

class gunz_ml.callbacks.optuna.OptunaPruningCallback(*args: Any, **kwargs: Any)[source]

Bases: Callback

An Optuna pruning callback for PyTorch Lightning running on validation end.

Parameters:
  • trial (optuna.trial.Trial) – The Optuna trial object.

  • metric (str) – The metric to monitor for pruning.

  • sub_metric (str, optional) – The sub-metric to monitor for pruning, if metric is a dictionary. Defaults to None.

  • start_epoch (int, optional) – The epoch to start pruning. Defaults to 1.

Attributes:
  • _trial (optuna.trial.Trial) – The Optuna trial object.

  • metric (str) – The metric to monitor for pruning.

  • sub_metric (str, optional) – The sub-metric to monitor for pruning.

  • start_epoch (int) – The epoch to start pruning.

on_validation_end(trainer: lightning.Trainer, pl_module: lightning.LightningModule) → None[source]

Called when the validation loop ends.

Parameters:
  • trainer (L.Trainer) – The PyTorch Lightning trainer instance.

  • pl_module (L.LightningModule) – The PyTorch Lightning module.

Raises:

optuna.exceptions.TrialPruned – If the trial should be pruned.
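
A sketch of the check that such a callback might run at the end of validation. To stay dependency-free it uses stand-ins: TrialPruned mimics optuna.exceptions.TrialPruned, and FakeTrial exposes the report() and should_prune() methods of optuna.trial.Trial with a made-up pruning rule. The helper name check_pruning and the "skip epochs before start_epoch" behavior are assumptions based on the parameter descriptions above.

```python
from typing import Any, Dict, Optional


class TrialPruned(Exception):
    """Stand-in for optuna.exceptions.TrialPruned."""


class FakeTrial:
    """Stand-in with the report()/should_prune() pair of optuna.trial.Trial."""

    def __init__(self, prune_below: float) -> None:
        self.prune_below = prune_below
        self.reported: Dict[int, float] = {}
        self._last: Optional[float] = None

    def report(self, value: float, step: int) -> None:
        self.reported[step] = value
        self._last = value

    def should_prune(self) -> bool:
        # Made-up rule: prune when the last reported value is too low.
        return self._last is not None and self._last < self.prune_below


def check_pruning(trial: FakeTrial, metrics: Dict[str, Any], epoch: int,
                  metric: str, sub_metric: Optional[str] = None,
                  start_epoch: int = 1) -> None:
    """Report the monitored value and raise TrialPruned if the trial says so."""
    if epoch < start_epoch:
        return  # too early to prune
    value = metrics[metric]
    if sub_metric is not None:
        value = value[sub_metric]  # metric is a dictionary of sub-metrics
    trial.report(float(value), step=epoch)
    if trial.should_prune():
        raise TrialPruned(f"Trial pruned at epoch {epoch}.")


trial = FakeTrial(prune_below=0.5)
logged = {"val": {"acc": 0.2}}
check_pruning(trial, logged, epoch=0, metric="val", sub_metric="acc")  # skipped
```

Calling check_pruning again with epoch=1 would report 0.2 and raise TrialPruned under the stand-in rule.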

class gunz_ml.callbacks.optuna.PyTorchLightningPruningCallback(*args: Any, **kwargs: Any)[source]

Bases: Callback

A PyTorch Lightning callback for Optuna pruning.

Parameters:
  • trial (optuna.trial.Trial) – The Optuna trial object.

  • monitor (str) – The metric to monitor for pruning.

Attributes:
  • _trial (optuna.trial.Trial) – The Optuna trial object.

  • monitor (str) – The metric to monitor for pruning.

  • is_ddp_backend (bool) – Whether to use DDP backend. Defaults to False.

on_train_epoch_end(trainer: lightning.Trainer, pl_module: lightning.LightningModule) → None[source]

Called when the training epoch ends.

Parameters:
  • trainer (L.Trainer) – The PyTorch Lightning trainer instance.

  • pl_module (L.LightningModule) – The PyTorch Lightning module.

Raises:

optuna.exceptions.TrialPruned – If the trial should be pruned based on the intermediate value.

gunz_ml.callbacks.prediction_dynamics module

Callback for monitoring prediction dynamics (entropy, confidence).

class gunz_ml.callbacks.prediction_dynamics.PredictionDynamicsMonitor(*args: Any, **kwargs: Any)[source]

Bases: Callback

Monitors the distribution of model predictions to detect mode collapse or random guessing.

Notes

This callback logs the following metrics:

  • dynamics/pred_entropy: Average entropy of the prediction distribution. Low entropy (~0) indicates mode collapse (always predicts same class), while high entropy suggests random guessing.

  • dynamics/pred_confidence: Average max probability.

Parameters:

log_every_n_steps (int) – Frequency of logging in steps. Defaults to 50.

forward_hook(module, input, output)[source]
on_train_end(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule)[source]
on_train_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule)[source]
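
The two logged metrics can be sketched in plain Python. This assumes the hooked output is a batch of classification logits, that entropy is computed per example with the natural logarithm and then averaged, and that confidence is the mean of the per-example maximum probability. The function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one example's logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prediction_dynamics(batch_logits: Sequence[Sequence[float]]) -> Dict[str, float]:
    """Average prediction entropy and average max probability over a batch."""
    entropies, confidences = [], []
    for logits in batch_logits:
        probs = softmax(logits)
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0.0))
        confidences.append(max(probs))
    n = len(batch_logits)
    return {
        "dynamics/pred_entropy": sum(entropies) / n,
        "dynamics/pred_confidence": sum(confidences) / n,
    }


# Uniform logits: entropy is at its maximum, log(4), and confidence is 1/4.
guessing = prediction_dynamics([[0.0, 0.0, 0.0, 0.0]])
# One dominant logit: entropy is near 0 and confidence is near 1.
collapsed = prediction_dynamics([[100.0, 0.0, 0.0, 0.0]])
```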

gunz_ml.callbacks.weight_update_monitor module

Callback for monitoring weight update ratios.

class gunz_ml.callbacks.weight_update_monitor.WeightUpdateMonitor(*args: Any, **kwargs: Any)[source]

Bases: Callback

Monitors the ratio of parameter updates to parameter magnitude.

Ratio = (Learning Rate * Gradient Norm) / Parameter Norm

  • Extremely small values (< 1e-5): The model has converged or is stuck.

  • Extremely large values (> 1e-2): Instability.

Parameters:

log_every_n_steps (int) – Frequency of logging in steps. Defaults to 100.

on_before_optimizer_step(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, optimizer)[source]
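
A worked instance of the ratio formula and the two thresholds given above, using flattened lists of floats in place of tensors. The function names, the returned labels, and the small eps added to the denominator to avoid division by zero are assumptions for illustration.

```python
import math
from typing import Sequence


def update_ratio(lr: float, grad: Sequence[float], param: Sequence[float],
                 eps: float = 1e-12) -> float:
    """Ratio = (learning rate * gradient norm) / parameter norm."""
    grad_norm = math.sqrt(sum(g * g for g in grad))
    param_norm = math.sqrt(sum(p * p for p in param))
    return lr * grad_norm / (param_norm + eps)


def describe(ratio: float) -> str:
    """Map a ratio onto the ranges listed in the class description."""
    if ratio < 1e-5:
        return "converged or stuck"
    if ratio > 1e-2:
        return "unstable"
    return "ok"


# Gradient norm 5, parameter norm 10, learning rate 0.01: ratio of about 0.005.
ratio = update_ratio(lr=0.01, grad=[3.0, 4.0], param=[6.0, 8.0])
status = describe(ratio)
```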