"""
Callback for monitoring prediction dynamics (entropy, confidence).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch
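# =============================================================================
# CALLBACK
# =============================================================================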
class PredictionDynamicsMonitor(L.Callback):
    """
    Monitor the distribution of model predictions during training.

    A forward hook is attached to the LightningModule for the duration of
    training. Every ``log_every_n_steps`` global steps the module's output is
    turned into class probabilities and two summary statistics are logged,
    which helps to spot mode collapse or random guessing.

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging in global steps. Must be >= 1. Defaults to 50.
    class_dim : int
        Dimension of the logits tensor that indexes the classes, i.e. the
        dimension the softmax is taken over. Defaults to 1, which fits
        ``(batch, classes)`` and ``(batch, classes, ...)`` outputs. Use ``-1``
        for ``(batch, seq, classes)`` outputs.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is smaller than 1.

    Notes
    -----
    This callback logs the following metrics:

    - `dynamics/pred_entropy`: Average entropy of the prediction distribution.
      Low entropy (~0) indicates mode collapse (always predicts same class),
      while high entropy suggests random guessing.
    - `dynamics/pred_confidence`: Average max probability.

    The module output may be a tensor, a dict (the first of the keys
    ``"logits"``, ``"out"``, ``"preds"`` that is present is used) or a
    list/tuple (its first element is used). Outputs that cannot be interpreted
    as a floating point tensor with at least two dimensions are silently
    ignored, so the callback never interrupts training.
    """

    #? Keys probed, in order, when the module returns a dict-like output
    _LOGIT_KEYS = ("logits", "out", "preds")

    def __init__(self, log_every_n_steps: int = 50, class_dim: int = 1):
        super().__init__()
        #? A zero interval would only fail later, as a ZeroDivisionError
        #? inside the forward hook; reject it up front instead
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.class_dim = class_dim
        #? Handle returned by `register_forward_hook` (None while detached)
        self.hook = None
        #? Module the hook is attached to; used for step lookup and logging
        self._pl_module_ref = None

    def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Attach the forward hook to ``pl_module`` when training starts."""
        #? Drop a hook left over from a previous run so hooks never stack up
        self._remove_hook()
        self._pl_module_ref = pl_module
        #? Hook the module to capture outputs
        self.hook = pl_module.register_forward_hook(self.forward_hook)

    def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Detach the forward hook when training finishes."""
        self._remove_hook()

    def on_exception(self, trainer: L.Trainer, pl_module: L.LightningModule, exception: BaseException):
        """Detach the forward hook when training is interrupted by an error."""
        self._remove_hook()

    def _remove_hook(self) -> None:
        """Remove the forward hook (if any) and release the module reference."""
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
        self._pl_module_ref = None

    def forward_hook(self, module, input, output):
        """
        Forward hook that logs prediction entropy and confidence.

        Parameters
        ----------
        module : torch.nn.Module
            Module whose forward pass just ran.
        input : tuple
            Positional inputs of the forward pass (unused).
        output : Any
            Value returned by the forward pass.

        Returns
        -------
        None
            The output of the forward pass is never modified.
        """
        pl_module = self._pl_module_ref
        #? Hook invoked before `on_train_start` / after cleanup: nothing to do
        if pl_module is None:
            return
        #? `global_step` does not advance during evaluation, so the frequency
        #? check below cannot throttle eval batches; only monitor train mode
        if not module.training:
            return
        #? The `trainer` property may raise RuntimeError on a detached module
        #? (Lightning 2.x behaviour) instead of returning None
        try:
            trainer = pl_module.trainer
        except RuntimeError:
            return
        if trainer is None:
            return
        #? Check logging frequency
        if pl_module.global_step % self.log_every_n_steps != 0:
            return

        logits = self._extract_logits(output)
        if logits is None:
            return
        stats = self._compute_dynamics(logits)
        if stats is None:
            return

        entropy, confidence = stats
        pl_module.log("dynamics/pred_entropy", entropy, on_step=True, on_epoch=False)
        pl_module.log("dynamics/pred_confidence", confidence, on_step=True, on_epoch=False)

    @classmethod
    def _extract_logits(cls, output):
        """Return the logits tensor found in ``output``, or None if unusable."""
        candidate = output
        #? Attempt to unwrap standard output formats
        if isinstance(output, dict):
            #? Try common keys used in HF or custom models
            candidate = next(
                (output[key] for key in cls._LOGIT_KEYS if key in output), None
            )
        elif isinstance(output, (list, tuple)):
            #? An empty sequence carries no logits
            candidate = output[0] if output else None

        if not isinstance(candidate, torch.Tensor):
            return None
        #? Sanity check: Needs to be a batch of vectors
        if candidate.ndim < 2 or not candidate.is_floating_point():
            return None
        #? The mean over an empty batch would log NaN
        if candidate.numel() == 0:
            return None
        return candidate

    def _compute_dynamics(self, logits: torch.Tensor):
        """Return ``(mean entropy, mean max-probability)``, or None if ``class_dim`` does not fit."""
        dim = self.class_dim
        if not -logits.ndim <= dim < logits.ndim:
            return None
        #? Detach to avoid affecting gradients
        with torch.no_grad():
            scores = logits.detach()
            #? Half precision cannot represent tiny probabilities (an epsilon
            #? of 1e-8 rounds to 0 in float16, giving 0 * log(0) = NaN)
            if scores.dtype in (torch.float16, torch.bfloat16):
                scores = scores.float()
            probs = torch.softmax(scores, dim=dim)
            #? Entropy = - sum(p * log(p)); `entr` defines the term as 0 at p = 0
            entropy = torch.special.entr(probs).sum(dim=dim).mean()
            #? Confidence = max(p)
            confidence = probs.max(dim=dim).values.mean()
        return entropy, confidence