# Source code for gunz_ml.callbacks.prediction_dynamics

"""
Callback for monitoring prediction dynamics (entropy, confidence).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch

class PredictionDynamicsMonitor(L.Callback):
    """
    Monitors how sharp or flat the model's predictions are during training.

    Notes
    -----
    A forward hook is registered on the LightningModule when training starts
    and removed when training ends or raises. The hook only observes calls
    made through the LightningModule itself (``pl_module(x)``).

    Every ``log_every_n_steps`` global steps, while the module is in training
    mode, the hook turns the output logits into probabilities and logs:

    - `dynamics/pred_entropy`: Mean per-sample entropy of the predicted
      distribution. Values near 0 mean every prediction is almost one-hot
      (over-confidence, and a possible symptom of mode collapse); values near
      ``log(num_classes)`` mean predictions are close to uniform (random
      guessing).
    - `dynamics/pred_confidence`: Mean of the per-sample max probability.

    The entropy is the *average of per-sample entropies*, not the entropy of
    the batch-averaged distribution, so on its own it cannot distinguish
    "confident and diverse" from "confident and always the same class".

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging in steps. Defaults to 50. Must be >= 1.
    class_dim : int
        Axis of the logits tensor that enumerates classes. Defaults to 1,
        which suits ``(batch, classes)`` and ``(batch, classes, ...)``
        outputs; use ``-1`` for ``(batch, seq, classes)`` outputs.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(self, log_every_n_steps: int = 50, class_dim: int = 1):
        super().__init__()
        #? Fail early: a zero interval would otherwise raise ZeroDivisionError
        #? from inside the model's forward pass.
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.class_dim = class_dim
        self.hook = None
        self._pl_module_ref = None

    def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Attach the forward hook to ``pl_module`` for the training run."""
        #? Drop any hook left over from an earlier run so repeated fits never
        #? stack several hooks on the same module.
        self._remove_hook()
        self._pl_module_ref = pl_module
        #? Hook the module to capture outputs
        self.hook = pl_module.register_forward_hook(self.forward_hook)

    def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """Detach the forward hook once training finishes normally."""
        self._remove_hook()

    def on_exception(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        exception: BaseException,
    ):
        """Detach the forward hook when the run is interrupted by an error."""
        #? `on_train_end` is not reached on failure; without this the hook
        #? would keep firing on later forward calls of the module.
        self._remove_hook()

    def _remove_hook(self) -> None:
        """Remove the registered hook (if any) and drop the module reference."""
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
        self._pl_module_ref = None

    @staticmethod
    def _extract_logits(output: t.Any) -> t.Optional[torch.Tensor]:
        """Return the logits tensor found in ``output``, or None if unusable."""
        if isinstance(output, dict):
            #? Try common keys used in HF or custom models
            candidate = output.get(
                "logits", output.get("out", output.get("preds"))
            )
        elif isinstance(output, (list, tuple)):
            #? Empty sequences carry nothing to inspect.
            candidate = output[0] if output else None
        else:
            candidate = output

        if not isinstance(candidate, torch.Tensor):
            return None
        #? Sanity check: Needs to be a non-empty batch of float vectors
        if candidate.ndim < 2 or not candidate.is_floating_point():
            return None
        if candidate.numel() == 0:
            return None
        return candidate

    def forward_hook(self, module, input, output):
        """
        Forward hook that logs prediction entropy and confidence.

        Always returns None so the module's output is never replaced.
        """
        pl_module = self._pl_module_ref
        if pl_module is None:
            return

        #? Recent Lightning versions raise RuntimeError from the `trainer`
        #? property when no Trainer is attached; treat that like "no trainer".
        try:
            if not pl_module.trainer:
                return
        except RuntimeError:
            return

        #? Skip eval-mode forwards (e.g. validation inside `fit`): global_step
        #? is frozen there, so every eval batch would log under the train key.
        if not pl_module.training:
            return

        #? Check logging frequency
        if (pl_module.global_step % self.log_every_n_steps) != 0:
            return

        logits = self._extract_logits(output)
        if logits is None:
            return

        #? The configured class axis must exist on this tensor.
        dim = self.class_dim
        if not -logits.ndim <= dim < logits.ndim:
            return

        #? Detach to avoid affecting gradients
        with torch.no_grad():
            logits = logits.detach()
            #? Half-precision softmax can underflow to exact zeros; compute
            #? the statistics in float32 instead.
            if logits.dtype in (torch.float16, torch.bfloat16):
                logits = logits.float()
            probs = torch.softmax(logits, dim=dim)
            #? Entropy = - sum(p * log(p)); `entr` defines 0 * log(0) as 0, so
            #? zero probabilities cannot produce NaN.
            entropy = torch.special.entr(probs).sum(dim=dim).mean()
            #? Confidence = max(p)
            confidence = probs.max(dim=dim).values.mean()

        pl_module.log(
            "dynamics/pred_entropy", entropy, on_step=True, on_epoch=False
        )
        pl_module.log(
            "dynamics/pred_confidence", confidence, on_step=True, on_epoch=False
        )