"""
Callback for monitoring weight update ratios.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch
class WeightUpdateMonitor(L.Callback):
    """
    Monitors the ratio of parameter updates to parameter magnitude.

    Every ``log_every_n_steps`` optimizer steps, for each parameter of the
    LightningModule that currently has a gradient, the callback logs

        Ratio = (Learning Rate * Gradient Norm) / Parameter Norm

    under the key ``update_ratio/<name>``, where ``<name>`` is the
    parameter's qualified name with ``.`` replaced by ``/``. Norms are L2
    norms over all elements of the tensor. The learning rate is taken from
    the optimizer param group that holds the parameter, so models trained
    with per-group learning rates get a per-group-correct ratio.

    Note that this is the size of a plain SGD step relative to the weights;
    for adaptive optimizers (e.g. Adam) it is only a proxy for the true
    update size.

    Rough interpretation:

    - Extremely small values (< 1e-5): Model convergence or stuck.
    - Extremely large values (> 1e-2): Instability.

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging, measured in ``trainer.global_step``.
        Must be >= 1. Defaults to 100.
    eps : float
        Parameters whose L2 norm is not strictly greater than ``eps`` are
        skipped to avoid dividing by (nearly) zero. Defaults to 1e-9.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(self, log_every_n_steps: int = 100, eps: float = 1e-9):
        super().__init__()
        #? 0 would raise ZeroDivisionError on the first optimizer step and a
        #? negative interval is meaningless, so reject both up front.
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.eps = eps

    @staticmethod
    def _lr_by_param(optimizer) -> t.Dict[int, float]:
        """
        Map ``id(param)`` to the learning rate of the param group holding it.

        Groups without an ``'lr'`` entry contribute 0.0, matching the
        fallback used for the first group.
        """
        lr_lookup: t.Dict[int, float] = {}
        for group in optimizer.param_groups:
            group_lr = group.get('lr', 0.0)
            for param in group.get('params', ()):
                lr_lookup[id(param)] = group_lr
        return lr_lookup

    def on_before_optimizer_step(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        optimizer,
    ):
        """
        Log per-parameter update ratios every ``log_every_n_steps`` steps.

        Parameters
        ----------
        trainer : L.Trainer
            The running trainer; only ``global_step`` is read.
        pl_module : L.LightningModule
            Module whose ``named_parameters()`` are inspected and through
            whose ``log`` the ratios are recorded.
        optimizer
            The optimizer about to step; its ``param_groups`` supply the
            learning rates.
        """
        if (trainer.global_step % self.log_every_n_steps) != 0:
            return
        #? No param groups -> no learning rate to work with
        if not optimizer.param_groups:
            return

        lr_lookup = self._lr_by_param(optimizer)
        #? Parameters not found in any group (e.g. the optimizer holds copies
        #? or wrappers of the module's tensors) keep the previous behaviour of
        #? using the first group's learning rate.
        default_lr = optimizer.param_groups[0].get('lr', 0.0)

        #? Norms are diagnostics only: keep them out of the autograd graph
        with torch.no_grad():
            for name, param in pl_module.named_parameters():
                if param.grad is None:
                    continue

                param_norm = param.norm(2).item()
                #? Written as `not >` so NaN norms are skipped, like the original check
                if not param_norm > self.eps:
                    continue
                grad_norm = param.grad.norm(2).item()

                lr = lr_lookup.get(id(param), default_lr)
                ratio = (lr * grad_norm) / param_norm

                #? Slashes group the metrics hierarchically in most loggers
                clean_name = name.replace(".", "/")
                pl_module.log(
                    f"update_ratio/{clean_name}",
                    ratio,
                    on_step=True,
                    on_epoch=False,
                )