Source code for gunz_ml.callbacks.gradient_monitor

"""
Callback for monitoring gradient norms during training.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch

class GradientMonitor(L.Callback):
    """
    Logs the L2 norm of gradients for model parameters.

    This callback inspects the gradients after the backward pass and logs
    their norms through the module's ``log_dict``. This is useful for
    detecting vanishing or exploding gradients.

    Logged keys are ``grad_norm/total`` (always) and ``grad_norm/<name>``
    for every named parameter that has a gradient (only when
    ``log_per_layer`` is True).

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging in steps. Must be at least 1. Defaults to 50.
    log_per_layer : bool
        If True, logs the norm for every parameter. If False, logs only
        the total norm of the entire model. Defaults to True.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is less than 1.
    """

    def __init__(
        self,
        log_every_n_steps: int = 50,
        log_per_layer: bool = True
    ):
        super().__init__()
        #? Fail fast: 0 would raise ZeroDivisionError in the modulo check
        #? of on_after_backward, and negative values are never meaningful.
        if log_every_n_steps < 1:
            raise ValueError(
                "log_every_n_steps must be a positive integer, "
                f"got {log_every_n_steps!r}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.log_per_layer = log_per_layer

    def on_after_backward(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule
    ) -> None:
        """
        Called after loss.backward() and before optimizers take a step.

        Computes the L2 norm of each parameter gradient and the global L2
        norm over all gradients, then logs them via ``pl_module.log_dict``.
        Does nothing unless ``trainer.global_step + 1`` is a multiple of
        ``log_every_n_steps``.

        Parameters
        ----------
        trainer : L.Trainer
            The running trainer; only ``global_step`` is read.
        pl_module : L.LightningModule
            The module whose parameter gradients are inspected and whose
            ``log_dict`` is used for logging.

        Notes
        -----
        NOTE(review): per the Lightning hook documentation, gradients are
        not yet unscaled at this hook when native AMP is used, so the
        logged norms would include the loss-scale factor -- confirm
        whether that matters for this project.
        """
        if (trainer.global_step + 1) % self.log_every_n_steps != 0:
            return

        squared_sum = 0.0
        norms: t.Dict[str, float] = {}

        for name, param in pl_module.named_parameters():
            grad = param.grad
            #? Parameters without a gradient are skipped
            if grad is None:
                continue

            #? Compute L2 norm (``.item()`` yields a Python float)
            param_norm = grad.norm(2).item()
            if self.log_per_layer:
                norms[f"grad_norm/{name}"] = param_norm
            squared_sum += param_norm ** 2

        #? Global L2 norm = sqrt of the sum of squared per-parameter norms;
        #? this is 0.0 when no parameter has a gradient.
        norms["grad_norm/total"] = squared_sum ** 0.5

        #? Log using the module's logger
        #? on_step=True ensures it's logged at the current step
        pl_module.log_dict(norms, on_step=True, on_epoch=False)