"""
Callback for monitoring gradient norms during training.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch
class GradientMonitor(L.Callback):
    """
    Logs the L2 norm of gradients for model parameters.

    This callback inspects the gradients after the backward pass and logs
    their norms to the configured logger(s). This is useful for detecting
    vanishing or exploding gradients.

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging in steps. Must be at least 1. Defaults to 50.
    log_per_layer : bool
        If True, logs the norm for every parameter in addition to the total
        norm. If False, logs only the total norm of the entire model.
        Defaults to True.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(
        self,
        log_every_n_steps: int = 50,
        log_per_layer: bool = True
    ) -> None:
        super().__init__()
        #? Fail fast: the value is used as a modulo divisor in
        #? on_after_backward, so 0 would otherwise only surface as a
        #? ZeroDivisionError in the middle of training.
        if log_every_n_steps < 1:
            raise ValueError(
                "log_every_n_steps must be a positive integer, "
                f"got {log_every_n_steps!r}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.log_per_layer = log_per_layer

    def on_after_backward(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule
    ) -> None:
        """
        Called after loss.backward() and before optimizers take a step.

        On steps where ``trainer.global_step + 1`` is a multiple of
        ``log_every_n_steps``, computes the L2 norm of every available
        parameter gradient and logs the results via ``pl_module.log_dict``.

        Logged keys
        -----------
        grad_norm/total
            Square root of the sum of squared per-parameter norms. This is
            0.0 when no parameter has a gradient.
        grad_norm/<parameter name>
            Per-parameter norm, only when ``log_per_layer`` is True.

        Parameters
        ----------
        trainer : L.Trainer
            Trainer whose ``global_step`` decides whether this step is logged.
        pl_module : L.LightningModule
            Module whose ``named_parameters()`` are inspected and whose
            ``log_dict`` receives the metrics.
        """
        #? NOTE(review): with mixed-precision loss scaling the gradients may
        #? still be scaled at this hook, which would inflate the logged norms.
        #? Confirm against the Lightning hook documentation.
        if (trainer.global_step + 1) % self.log_every_n_steps != 0:
            return

        squared_sum = 0.0
        norms: t.Dict[str, float] = {}
        for name, p in pl_module.named_parameters():
            #? Parameters without a gradient (e.g. frozen or unused in this
            #? step) contribute nothing.
            if p.grad is None:
                continue
            #? Compute L2 norm; .item() converts it to a Python float
            param_norm = p.grad.norm(2).item()
            if self.log_per_layer:
                norms[f"grad_norm/{name}"] = param_norm
            squared_sum += param_norm ** 2

        #? The total norm is the L2 norm of the per-parameter norms.
        norms["grad_norm/total"] = squared_sum ** 0.5

        #? Log using the module's logger
        #? on_step=True ensures it's logged at the current step
        pl_module.log_dict(norms, on_step=True, on_epoch=False)