gunz_ml.callbacks package
- gunz_ml.callbacks.init_callbacks(cfg: omegaconf.DictConfig, callback_cls_dict: Dict, tmp_dir=None)[source]
Initialize callbacks based on the provided configuration.
- Parameters:
cfg (DictConfig) – Configuration containing callback names and their corresponding parameters.
callback_cls_dict (Dict) – Dictionary mapping callback names to their respective classes.
tmp_dir (Optional[str], optional) – Temporary directory path to save model checkpoints. If None, model checkpoints are saved to the current directory. Defaults to None.
- Returns:
List of initialized callback instances.
- Return type:
List
- Raises:
ValueError – If an invalid callback name is encountered in the configuration.
Notes
Each callback named in the configuration is instantiated with the parameters specified for it.
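The lookup-and-instantiate pattern described above can be sketched as follows. This is an illustration only, not the package's implementation: it assumes cfg maps each callback name to a dictionary of keyword arguments, uses a plain dict in place of DictConfig, and ignores tmp_dir. The names init_callbacks_sketch and PrintEpoch are hypothetical.

```python
from typing import Any, Dict, List, Mapping


class PrintEpoch:
    """Hypothetical stand-in for a real callback class."""

    def __init__(self, every_n: int = 1) -> None:
        self.every_n = every_n


def init_callbacks_sketch(
    cfg: Mapping[str, Mapping[str, Any]],
    callback_cls_dict: Dict[str, type],
) -> List[Any]:
    """Instantiate one callback per cfg entry (assumed layout: name -> kwargs)."""
    callbacks = []
    for name, params in cfg.items():
        if name not in callback_cls_dict:
            # Mirrors the documented behaviour: unknown names raise ValueError.
            raise ValueError(f"Invalid callback name: {name!r}")
        callbacks.append(callback_cls_dict[name](**(params or {})))
    return callbacks


callbacks = init_callbacks_sketch(
    {"print_epoch": {"every_n": 5}},
    {"print_epoch": PrintEpoch},
)
```

Keeping the name-to-class mapping separate from the configuration means a typo in a callback name fails fast with ValueError instead of silently skipping the callback.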
Submodules
gunz_ml.callbacks.dead_neuron_monitor module
Callback for monitoring dead neurons in activation layers.
- class gunz_ml.callbacks.dead_neuron_monitor.DeadNeuronMonitor(*args: Any, **kwargs: Any)[source]
Bases:
Callback
Monitors the fraction of ‘dead’ neurons in activation layers.
A neuron is considered dead if it outputs zero (or below a threshold) for all examples in a batch. High rates of dead neurons may indicate issues with initialization, learning rate, or optimizers.
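The definition above can be illustrated with a framework-free sketch. The function name, the [batch][neuron] list layout, and the default threshold of 0.0 are assumptions made for the example; the callback itself presumably operates on tensors.

```python
from typing import Sequence


def dead_neuron_fraction(
    activations: Sequence[Sequence[float]],
    threshold: float = 0.0,
) -> float:
    """Fraction of neurons whose output is <= threshold for every example.

    activations is laid out as [batch][neuron]; the default threshold
    is an assumption, not taken from the callback.
    """
    if not activations or not activations[0]:
        return 0.0
    num_neurons = len(activations[0])
    dead = sum(
        all(row[j] <= threshold for row in activations)
        for j in range(num_neurons)
    )
    return dead / num_neurons


# Batch of 3 examples and 4 neurons (post-ReLU values).
batch = [
    [0.0, 1.2, 0.0, 0.3],
    [0.0, 0.0, 0.0, 0.9],
    [0.0, 0.4, 0.0, 0.0],
]
fraction = dead_neuron_fraction(batch)  # neurons 0 and 2 never fire
```

A neuron that is zero for some examples but not all (neurons 1 and 3 here) is not counted as dead.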
- Parameters:
gunz_ml.callbacks.early_stopping module
Early Stopping
Monitor a metric and stop training when it stops improving.
- class gunz_ml.callbacks.early_stopping.EarlyStopping(monitor: str, min_delta: float = 0.0, patience: int = 3, verbose: bool = False, mode: str = 'min', strict: bool = True, check_finite: bool = True, stopping_threshold: float | None = None, divergence_threshold: float | None = None, check_on_train_epoch_end: bool | None = None, log_rank_zero_only: bool = False)[source]
Bases:
Callback
Monitor a metric and stop training when it stops improving.
- Parameters:
monitor (str) – Quantity to be monitored.
min_delta (float) – Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta will count as no improvement.
patience (int) – Number of checks with no improvement after which training will be stopped. Under the default configuration, one check happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example check_val_every_n_epoch and val_check_interval.
verbose (bool) – Verbosity mode.
mode (str) – One of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored has stopped decreasing and in 'max' mode it will stop when the quantity monitored has stopped increasing.
strict (bool) – Whether to crash the training if monitor is not found in the validation metrics.
check_finite (bool) – When set to True, stops training when the monitor becomes NaN or infinite.
stopping_threshold (float, optional) – Stop training immediately once the monitored quantity reaches this threshold.
divergence_threshold (float, optional) – Stop training as soon as the monitored quantity becomes worse than this threshold.
check_on_train_epoch_end (bool, optional) – Whether to run early stopping at the end of the training epoch. If this is False, then the check runs at the end of the validation.
log_rank_zero_only (bool) – When set to True, logs the status of the early stopping callback only for the rank 0 process.
- Raises:
MisconfigurationException – If mode is none of "min" or "max".
RuntimeError – If the metric monitor is not available.
Examples
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])
Notes
The patience parameter counts the number of validation checks with no improvement, and not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped.
Tip
Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the following arguments: monitor, mode.
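The interaction between patience and the validation frequency can be checked with a small simulation. This is a simplified sketch of patience counting, not the callback's code: first_stop_epoch is a hypothetical helper, epochs are numbered from 1, and the thresholds, check_finite, and strict options are left out.

```python
import math
from typing import Callable, Optional


def first_stop_epoch(
    metric_at_epoch: Callable[[int], float],
    max_epochs: int,
    check_val_every_n_epoch: int = 1,
    patience: int = 3,
    min_delta: float = 0.0,
    mode: str = "min",
) -> Optional[int]:
    """Return the epoch at which a patience-based check would stop, or None."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = math.inf if mode == "min" else -math.inf
    wait_count = 0
    for epoch in range(1, max_epochs + 1):
        if epoch % check_val_every_n_epoch != 0:
            continue  # no validation check this epoch
        current = metric_at_epoch(epoch)
        # A change of exactly min_delta does not count as an improvement.
        if mode == "min":
            improved = current < best - min_delta
        else:
            improved = current > best + min_delta
        if improved:
            best, wait_count = current, 0
        else:
            wait_count += 1
            if wait_count >= patience:
                return epoch
    return None


# A metric that never improves after the first check at epoch 10.
stop_epoch = first_stop_epoch(
    lambda epoch: 1.0,
    max_epochs=100,
    check_val_every_n_epoch=10,
    patience=3,
)
```

Here the first check at epoch 10 sets the best value, and the checks at epochs 20, 30, and 40 each count as no improvement, so stop_epoch is 40, matching the figure in the note above.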
- load_state_dict(state_dict: Dict[str, Any]) → None[source]
Load the callback state from a dictionary.
- Parameters:
state_dict (Dict[str, Any]) – The state dictionary to restore from.
- mode_dict = {'max': <built-in method gt of type object>, 'min': <built-in method lt of type object>}
- property monitor_op: Callable
Get the comparison operator based on the mode.
- Returns:
Comparison function (e.g., torch.lt or torch.gt).
- Return type:
Callable
- on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) → None[source]
Called when the train epoch ends.
- Parameters:
trainer (pl.Trainer) – The Trainer instance.
pl_module (pl.LightningModule) – The LightningModule instance.
- on_validation_end(trainer: Trainer, pl_module: LightningModule) → None[source]
Called when the validation loop ends.
- Parameters:
trainer (pl.Trainer) – The Trainer instance.
pl_module (pl.LightningModule) – The LightningModule instance.
- order_dict = {'max': '>', 'min': '<'}
- setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None[source]
Called when fit or test begins.
- Parameters:
trainer (pl.Trainer) – The Trainer instance.
pl_module (pl.LightningModule) – The LightningModule instance.
stage (str) – The stage (e.g., ‘fit’, ‘validate’, ‘test’, ‘predict’).
gunz_ml.callbacks.early_stopping_optuna module
A PyTorch Lightning callback for Optuna pruning.
- class gunz_ml.callbacks.early_stopping_optuna.EarlyStoppingOptuna(*args: Any, **kwargs: Any)[source]
Bases:
EarlyStopping
A PyTorch Lightning callback that prunes an Optuna trial if the early stopping condition is met.
This callback inherits from pytorch_lightning.callbacks.EarlyStopping and overrides the _run_early_stopping_check method. When the early stopping condition is fulfilled, it raises optuna.exceptions.TrialPruned to signal to Optuna that the trial should be pruned.
- Parameters:
monitor (str) – The quantity to be monitored (e.g., ‘val_loss’).
min_delta (float, optional) – Minimum change in the monitored quantity to qualify as an improvement, by default 0.0.
patience (int, optional) – Number of checks with no improvement after which training will be stopped, by default 3.
verbose (bool, optional) – If True, prints a message for each pruning, by default False.
mode (str, optional) – One of {‘min’, ‘max’}. In ‘min’ mode, training will stop when the quantity monitored has stopped decreasing; in ‘max’ mode it will stop when the quantity monitored has stopped increasing, by default “min”.
strict (bool, optional) – Whether to crash the training if monitor is not found in the validation metrics, by default True.
check_finite (bool, optional) – When set to True, stops training when the monitor becomes NaN or infinite, by default True.
stopping_threshold (float, optional) – Stop training immediately once the monitored quantity reaches this threshold, by default None.
divergence_threshold (float, optional) – Stop training as soon as the monitored quantity becomes worse than this threshold, by default None.
check_on_train_epoch_end (bool, optional) – Whether to run early stopping at the end of the training epoch. If this is False, then the check runs at the end of the validation epoch, by default None.
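The override pattern described for this class, where a met stopping condition is turned into a raised exception, might look like the sketch below. It runs without Optuna or Lightning: TrialPruned is a local stand-in for optuna.exceptions.TrialPruned, and PatienceStopper, PruningStopper, and should_stop are hypothetical names, whereas the real callback overrides _run_early_stopping_check.

```python
class TrialPruned(Exception):
    """Stand-in for optuna.exceptions.TrialPruned so the sketch runs alone."""


class PatienceStopper:
    """Stand-in for the parent early-stopping logic ('min' mode only)."""

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self.best = float("inf")
        self.wait_count = 0

    def should_stop(self, current: float) -> bool:
        if current < self.best:
            self.best, self.wait_count = current, 0
            return False
        self.wait_count += 1
        return self.wait_count >= self.patience


class PruningStopper(PatienceStopper):
    """Raises instead of returning True, so the caller sees a pruned trial."""

    def should_stop(self, current: float) -> bool:
        if super().should_stop(current):
            raise TrialPruned(f"No improvement in {self.patience} checks.")
        return False


stopper = PruningStopper(patience=2)
pruned_at = None
for step, loss in enumerate([0.9, 0.8, 0.85, 0.81, 0.7]):
    try:
        stopper.should_stop(loss)
    except TrialPruned:
        pruned_at = step
        break
```

Raising rather than returning a stop flag lets the exception propagate out of the training call to the Optuna objective, which is how Optuna learns that the trial was pruned and not completed.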
gunz_ml.callbacks.gradient_monitor module
Callback for monitoring gradient norms during training.
- class gunz_ml.callbacks.gradient_monitor.GradientMonitor(*args: Any, **kwargs: Any)[source]
Bases:
Callback
Logs the L2 norm of gradients for model parameters.
This callback inspects the gradients after the backward pass and logs their norms to the configured logger(s). This is useful for detecting vanishing or exploding gradients.
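As a rough illustration of the quantity being logged, the sketch below computes per-parameter L2 norms and a combined norm from plain lists. The function name and the "total" key are hypothetical, and the source does not say whether the callback logs per-parameter norms, a total norm, or both.

```python
import math
from typing import Dict, Sequence


def grad_l2_norms(grads: Dict[str, Sequence[float]]) -> Dict[str, float]:
    """Return the L2 norm of each named gradient plus the norm over all of them."""
    norms = {
        name: math.sqrt(sum(g * g for g in values))
        for name, values in grads.items()
    }
    # The combined norm is the L2 norm of the per-parameter norms.
    norms["total"] = math.sqrt(sum(n * n for n in norms.values()))
    return norms


norms = grad_l2_norms({"layer1.weight": [3.0, 4.0], "layer1.bias": [0.0, 12.0]})
```

Norms that shrink toward zero across layers suggest vanishing gradients, while norms that grow by orders of magnitude suggest exploding gradients.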
- Parameters:
gunz_ml.callbacks.optuna module
- class gunz_ml.callbacks.optuna.OptunaPruningCallback(*args: Any, **kwargs: Any)[source]
Bases:
Callback
An Optuna pruning callback for PyTorch Lightning running on validation end.
- Parameters:
trial (optuna.trial.Trial) – The Optuna trial object.
metric (str) – The metric to monitor for pruning.
sub_metric (str, optional) – The sub-metric to monitor for pruning, if metric is a dictionary. Defaults to None.
start_epoch (int, optional) – The epoch to start pruning. Defaults to 1.
- Attributes:
_trial (optuna.trial.Trial) – The Optuna trial object.
metric (str) – The metric to monitor for pruning.
sub_metric (str, optional) – The sub-metric to monitor for pruning.
start_epoch (int) – The epoch to start pruning.
- on_validation_end(trainer: lightning.Trainer, pl_module: lightning.LightningModule) → None[source]
Called when the validation loop ends.
- Parameters:
trainer (L.Trainer) – The PyTorch Lightning trainer instance.
pl_module (L.LightningModule) – The PyTorch Lightning module.
- Raises:
optuna.exceptions.TrialPruned – If the trial should be pruned.
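How metric, sub_metric, and start_epoch might interact at validation end is sketched below. This is an assumption-laden illustration, not the callback's code: maybe_prune and FakeTrial are hypothetical, TrialPruned is a local stand-in for optuna.exceptions.TrialPruned, and epochs before start_epoch are assumed to be skipped without reporting. FakeTrial only mimics the report and should_prune methods of optuna.trial.Trial.

```python
from typing import Any, List, Mapping, Optional, Tuple


class TrialPruned(Exception):
    """Stand-in for optuna.exceptions.TrialPruned."""


class FakeTrial:
    """Hypothetical trial that prunes when the last reported value exceeds a limit."""

    def __init__(self, limit: float) -> None:
        self.limit = limit
        self.reported: List[Tuple[int, float]] = []

    def report(self, value: float, step: int) -> None:
        self.reported.append((step, value))

    def should_prune(self) -> bool:
        return self.reported[-1][1] > self.limit


def maybe_prune(
    trial: Any,
    logged_metrics: Mapping[str, Any],
    metric: str,
    epoch: int,
    sub_metric: Optional[str] = None,
    start_epoch: int = 1,
) -> None:
    """Report the monitored value and raise TrialPruned if the trial says so."""
    if epoch < start_epoch:
        return  # assumed: no reporting before start_epoch
    value = logged_metrics.get(metric)
    if value is None:
        return
    if sub_metric is not None:
        value = value[sub_metric]  # metric is a dictionary of sub-metrics
    trial.report(float(value), step=epoch)
    if trial.should_prune():
        raise TrialPruned(f"Trial pruned at epoch {epoch}.")


trial = FakeTrial(limit=1.0)
pruned = False
history = [{"val": {"loss": 2.0}}, {"val": {"loss": 0.5}}, {"val": {"loss": 1.5}}]
for epoch, logged in enumerate(history):
    try:
        maybe_prune(trial, logged, "val", epoch, sub_metric="loss", start_epoch=1)
    except TrialPruned:
        pruned = True
        break
```

In this run the value at epoch 0 is never reported because it precedes start_epoch, so the poor initial loss cannot trigger pruning.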
- class gunz_ml.callbacks.optuna.PyTorchLightningPruningCallback(*args: Any, **kwargs: Any)[source]
Bases:
Callback
A PyTorch Lightning callback for Optuna pruning.
- Parameters:
trial (optuna.trial.Trial) – The Optuna trial object.
monitor (str) – The metric to monitor for pruning.
- Attributes:
_trial (optuna.trial.Trial) – The Optuna trial object.
monitor (str) – The metric to monitor for pruning.
is_ddp_backend (bool) – Whether to use DDP backend. Defaults to False.
- on_train_epoch_end(trainer: lightning.Trainer, pl_module: lightning.LightningModule) → None[source]
Called when the training epoch ends.
- Parameters:
trainer (L.Trainer) – The PyTorch Lightning trainer instance.
pl_module (L.LightningModule) – The PyTorch Lightning module.
- Raises:
optuna.exceptions.TrialPruned – If the trial should be pruned based on the intermediate value.
gunz_ml.callbacks.prediction_dynamics module
Callback for monitoring prediction dynamics (entropy, confidence).
- class gunz_ml.callbacks.prediction_dynamics.PredictionDynamicsMonitor(*args: Any, **kwargs: Any)[source]
Bases:
Callback
Monitors the distribution of model predictions to detect mode collapse or random guessing.
Notes
This callback logs the following metrics:
- dynamics/pred_entropy: Average entropy of the prediction distribution. Low entropy (~0) indicates mode collapse (always predicts same class), while high entropy suggests random guessing.
- dynamics/pred_confidence: Average max probability.
- Parameters:
log_every_n_steps (int) – Frequency of logging in steps. Defaults to 50.
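The two logged quantities can be illustrated on plain probability lists. The sketch assumes that the entropy is computed per example (in nats) and then averaged over the batch, which the source does not state explicitly; prediction_dynamics is a hypothetical name.

```python
import math
from typing import Sequence, Tuple


def prediction_dynamics(probs: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Return (mean entropy in nats, mean max probability) over a batch."""
    # Terms with p == 0 contribute nothing to the entropy and are skipped.
    entropies = [-sum(p * math.log(p) for p in row if p > 0.0) for row in probs]
    confidences = [max(row) for row in probs]
    n = len(probs)
    return sum(entropies) / n, sum(confidences) / n


# Every example puts all mass on class 0: zero entropy, confidence 1.
collapsed = prediction_dynamics([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
# Uniform over 3 classes: entropy log(3), confidence 1/3.
uniform = prediction_dynamics([[1 / 3, 1 / 3, 1 / 3]])
```

For K classes the entropy is bounded by log(K), so values close to that bound correspond to the random-guessing case mentioned above.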
gunz_ml.callbacks.weight_update_monitor module
Callback for monitoring weight update ratios.
- class gunz_ml.callbacks.weight_update_monitor.WeightUpdateMonitor(*args: Any, **kwargs: Any)[source]
Bases:
Callback
Monitors the ratio of parameter updates to parameter magnitude.
Ratio = (Learning Rate * Gradient Norm) / Parameter Norm
Extremely small values (< 1e-5): the model has converged or is stuck.
Extremely large values (> 1e-2): training is unstable.
- Parameters:
log_every_n_steps (int) – Frequency of logging in steps. Defaults to 100.
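The ratio and the two thresholds quoted above can be worked through with plain lists. This is a sketch: update_ratio, diagnose, the returned labels, and the small eps guarding against a zero parameter norm are all assumptions made for the example.

```python
import math
from typing import Sequence


def l2_norm(values: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def update_ratio(
    lr: float,
    grads: Sequence[float],
    params: Sequence[float],
    eps: float = 1e-12,
) -> float:
    """(learning rate * gradient norm) / parameter norm; eps avoids division by zero."""
    return lr * l2_norm(grads) / (l2_norm(params) + eps)


def diagnose(ratio: float) -> str:
    """Map a ratio to the ranges quoted in the class description."""
    if ratio < 1e-5:
        return "converged_or_stuck"
    if ratio > 1e-2:
        return "unstable"
    return "ok"


# lr = 0.01, gradient norm 5, parameter norm 10: ratio is about 0.005.
ratio = update_ratio(0.01, grads=[3.0, 4.0], params=[6.0, 8.0])
status = diagnose(ratio)
```

Note that the formula describes a plain gradient step; with adaptive optimizers the actual update size differs from learning rate times gradient norm, so the ratio is best read as a rough indicator.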