Source code for gunz_ml.callbacks.weight_update_monitor

"""
Callback for monitoring weight update ratios.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch

class WeightUpdateMonitor(L.Callback):
    """
    Monitors the ratio of parameter updates to parameter magnitude.

    Ratio = (Learning Rate * Gradient Norm) / Parameter Norm

    - Extremely small values (< 1e-5): Model convergence or stuck.
    - Extremely large values (> 1e-2): Instability.

    The learning rate used for a parameter is the one of the optimizer
    param group that owns it, so models trained with several param groups
    (e.g. per-layer learning rates) report a meaningful ratio for every
    parameter. The ratio is based on the raw gradient, so it equals the
    true relative update only for plain SGD; for optimizers that rescale
    the gradient (momentum, Adam, ...) it is a proxy.

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging in steps. Defaults to 100.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is smaller than 1.
    """

    #? Parameters whose L2 norm is not above this value are skipped, to
    #? avoid dividing by (almost) zero.
    _MIN_PARAM_NORM: float = 1e-9

    def __init__(self, log_every_n_steps: int = 100):
        super().__init__()
        #? Fail at construction time instead of with a ZeroDivisionError
        #? in the middle of the training loop.
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps!r}"
            )
        self.log_every_n_steps = log_every_n_steps

    def on_before_optimizer_step(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        optimizer: torch.optim.Optimizer,
    ) -> None:
        """
        Log ``update_ratio/<param name>`` for every parameter with a gradient.

        Runs only when ``trainer.global_step`` is a multiple of
        ``log_every_n_steps``. Parameters without a gradient, and parameters
        whose norm is (almost) zero, are skipped.

        Parameters
        ----------
        trainer : L.Trainer
            Trainer driving the loop; only ``global_step`` is read.
        pl_module : L.LightningModule
            Module whose named parameters are inspected and which is used
            for logging.
        optimizer : torch.optim.Optimizer
            Optimizer about to step; its param groups supply the learning
            rates.
        """
        if (trainer.global_step % self.log_every_n_steps) != 0:
            return

        param_groups = optimizer.param_groups
        if not param_groups:
            return

        #? Map every optimized tensor to the learning rate of its own group.
        lr_by_param = {
            id(param): group.get('lr', 0.0)
            for group in param_groups
            for param in group.get('params', ())
        }
        #? Parameters the optimizer does not list (e.g. wrapped or sharded
        #? tensors) fall back to the first group's learning rate.
        fallback_lr = param_groups[0].get('lr', 0.0)

        for name, p in pl_module.named_parameters():
            if p.grad is None:
                continue

            #? Detached tensors: no autograd bookkeeping for a pure metric.
            param_norm = p.detach().norm(2).item()

            #? Written as "not >" so that a NaN norm is skipped as well.
            if not param_norm > self._MIN_PARAM_NORM:
                continue

            grad_norm = p.grad.detach().norm(2).item()
            lr = lr_by_param.get(id(p), fallback_lr)

            #? Compute Ratio
            ratio = (lr * grad_norm) / param_norm

            #? Clean name
            clean_name = name.replace(".", "/")
            pl_module.log(f"update_ratio/{clean_name}", ratio, on_step=True, on_epoch=False)