Source code for gunz_ml.callbacks.dead_neuron_monitor

"""
Callback for monitoring dead neurons in activation layers.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch
import torch.nn as nn

class DeadNeuronMonitor(L.Callback):
    """
    Monitors the fraction of 'dead' neurons in activation layers.

    A neuron is considered dead if its activation is at or below
    ``threshold`` for every example in a batch. High rates of dead neurons
    may indicate issues with initialization, learning rate, or optimizers.

    Forward hooks are attached to every ``nn.ReLU``, ``nn.LeakyReLU``,
    ``nn.GELU``, ``nn.Sigmoid`` and ``nn.Tanh`` submodule when training
    starts and removed when training ends or aborts with an exception.
    For each hooked layer the percentage (0-100) of dead neurons is logged
    under ``dead_neurons/<module/path>``.

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging in global steps. Defaults to 100. Expected to
        be a positive integer (it is used as a modulus).
    threshold : float
        Activation threshold at or below which a neuron is considered
        inactive. Defaults to 0.0 (strict zero for ReLU). For activations
        whose "off" state is not zero (e.g. Sigmoid), the default is
        unlikely to flag anything; choose a threshold suited to the layer.
    """

    #? Activation layer types that receive a monitoring hook.
    _MONITORED_TYPES = (nn.ReLU, nn.LeakyReLU, nn.GELU, nn.Sigmoid, nn.Tanh)

    def __init__(
        self,
        log_every_n_steps: int = 100,
        threshold: float = 0.0
    ):
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        self.threshold = threshold
        #? Handles returned by register_forward_hook, kept for later removal.
        self.hooks: t.List[torch.utils.hooks.RemovableHandle] = []
        self._pl_module_ref = None

    def on_train_start(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule
    ) -> None:
        """
        Attaches forward hooks to all monitored activation layers.

        Any hooks left over from a previous run are removed first, so
        calling this more than once never registers duplicate hooks.
        """
        self._remove_hooks()
        self._pl_module_ref = pl_module

        for name, module in pl_module.named_modules():
            #? Target common activation functions
            if isinstance(module, self._MONITORED_TYPES):
                #? Use a simplified name for logging
                clean_name = name.replace(".", "/")
                handle = module.register_forward_hook(
                    self._make_forward_hook(pl_module, clean_name)
                )
                self.hooks.append(handle)

    def on_train_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule
    ) -> None:
        """
        Clean up hooks.
        """
        self._remove_hooks()

    def on_exception(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        exception: BaseException
    ) -> None:
        """
        Clean up hooks when the run aborts, so they do not outlive training.
        """
        self._remove_hooks()

    def _remove_hooks(self) -> None:
        """Detach every registered forward hook and reset the handle list."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

    @staticmethod
    def _trainer_attached(pl_module: L.LightningModule) -> bool:
        """Return True if ``pl_module`` currently has a trainer attached."""
        #? NOTE(review): some Lightning versions raise RuntimeError from the
        #? `trainer` property instead of returning None when detached;
        #? treat both as "not attached".
        try:
            return bool(pl_module.trainer)
        except RuntimeError:
            return False

    def _make_forward_hook(self, pl_module: L.LightningModule, name: str):
        """Build the forward hook that logs the dead percentage for ``name``."""

        def forward_hook(module, inputs, output):
            #? Only monitor training passes. Validation runs inside fit do
            #? not advance global_step, so without this guard every eval
            #? batch would log under the same training key.
            if not pl_module.training:
                return
            if not self._trainer_attached(pl_module):
                return
            if (pl_module.global_step % self.log_every_n_steps) != 0:
                return

            #? Need a (Batch, ...) tensor with at least one element:
            #? max over an empty batch raises, and the mean over zero
            #? features would be NaN.
            if not isinstance(output, torch.Tensor):
                return
            if output.ndim < 2 or output.numel() == 0:
                return

            #? Flatten to (Batch, Features); detach so the statistics stay
            #? out of the autograd graph.
            flattened = output.detach().flatten(1)

            #? Max activation across the batch for each neuron
            max_activations = flattened.max(dim=0).values

            #? Dead if max activation <= threshold
            is_dead = (max_activations <= self.threshold).float()

            #? Percentage of dead neurons in this layer
            dead_pct = is_dead.mean().item() * 100.0

            pl_module.log(
                f"dead_neurons/{name}",
                dead_pct,
                on_step=True,
                on_epoch=False
            )

        return forward_hook