"""
Callback for monitoring dead neurons in activation layers.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import lightning.pytorch as L
import torch
import torch.nn as nn
class DeadNeuronMonitor(L.Callback):
    """
    Monitors the fraction of 'dead' neurons in activation layers.

    A neuron is considered dead if its output is at or below ``threshold``
    for every example in a batch. High rates of dead neurons may indicate
    issues with initialization, learning rate, or optimizers.

    Each layer output is flattened to ``(batch, -1)``, so every position after
    the batch dimension (e.g. each channel/spatial or time/feature position)
    is counted as a separate neuron.

    Parameters
    ----------
    log_every_n_steps : int
        Frequency of logging, in ``trainer.global_step`` units. Must be >= 1.
        Defaults to 100.
    threshold : float
        Activation threshold at or below which a neuron is considered
        inactive. Defaults to 0.0 (strict zero for ReLU).
    activation_types : tuple of nn.Module subclasses
        Module types whose outputs are monitored. Defaults to
        ``(nn.ReLU, nn.LeakyReLU, nn.GELU, nn.Sigmoid, nn.Tanh)``.

        With the default threshold of 0.0, Sigmoid outputs (always > 0) are
        never counted as dead. Tanh, LeakyReLU and GELU units whose outputs are
        merely non-positive across the batch are counted as dead. Adjust
        ``threshold`` or this tuple if that is not desired.

    Raises
    ------
    ValueError
        If ``log_every_n_steps`` is less than 1.
    """

    def __init__(
        self,
        log_every_n_steps: int = 100,
        threshold: float = 0.0,
        activation_types: t.Tuple[t.Type[nn.Module], ...] = (
            nn.ReLU,
            nn.LeakyReLU,
            nn.GELU,
            nn.Sigmoid,
            nn.Tanh,
        ),
    ):
        super().__init__()
        # A value < 1 would make the modulo check in the hook meaningless
        # (0 raises ZeroDivisionError), so fail fast at construction.
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.threshold = threshold
        # isinstance() needs a tuple; accept any iterable of types.
        self.activation_types = tuple(activation_types)
        # Handles returned by register_forward_hook, kept so they can be removed.
        self.hooks = []
        self._pl_module_ref = None

    def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """
        Attach forward hooks to every submodule matching ``activation_types``.

        Hooks left over from a previous run are removed first, so repeated
        ``fit`` calls never register the same layer twice.
        """
        self._remove_hooks()
        self._pl_module_ref = pl_module

        for name, module in pl_module.named_modules():
            if not isinstance(module, self.activation_types):
                continue
            # Module paths use "." separators; "/" groups keys in most loggers.
            clean_name = name.replace(".", "/")
            handle = module.register_forward_hook(
                self._make_hook(trainer, pl_module, clean_name)
            )
            self.hooks.append(handle)

    def _make_hook(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        name: str,
    ):
        """Build a forward hook that logs ``dead_neurons/<name>`` as a percentage."""

        def forward_hook(module, inputs, output):
            # Validation loops run inside fit() with global_step frozen; without
            # this guard their (eval-mode) statistics would be logged under the
            # same key whenever the step happens to be a logging step.
            if not trainer.training:
                return
            if trainer.global_step % self.log_every_n_steps != 0:
                return
            if not isinstance(output, torch.Tensor) or output.ndim < 2:
                return
            # Empty batch -> max(dim=0) raises; empty feature dim -> mean is NaN.
            if output.numel() == 0:
                return

            # Pure statistic: detach so no autograd graph is built for it.
            flattened = output.detach().flatten(1)
            # A neuron is dead if even its largest activation over the batch
            # does not exceed the threshold.
            max_activations = flattened.max(dim=0).values
            is_dead = (max_activations <= self.threshold).float()
            dead_pct = is_dead.mean().item() * 100.0
            pl_module.log(
                f"dead_neurons/{name}", dead_pct, on_step=True, on_epoch=False
            )

        return forward_hook

    def _remove_hooks(self):
        """Detach every registered hook and drop the stored module reference."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
        self._pl_module_ref = None

    def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule):
        """
        Clean up hooks.
        """
        self._remove_hooks()

    def on_exception(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        exception: BaseException,
    ):
        """
        Clean up hooks when fitting raises.

        This keeps hooks from staying attached to the model (and calling
        ``log`` outside a training loop) if training aborts before
        ``on_train_end`` runs.
        """
        self._remove_hooks()