"""
Analysis modules for training dynamics (gradients, logs).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import torch
import torch.nn as nn
# =============================================================================
# FUNCTIONS
# =============================================================================
def compute_gradient_norms(
    model: nn.Module,
    batch: t.Any,
    criterion: nn.Module,
    device: torch.device,
) -> dict[str, float]:
    """
    Computes the L2 norm of gradients for a single batch.

    Runs one forward and one backward pass on ``batch``. For every parameter
    that received a gradient, it reports the L2 norm of that gradient. The
    model's train/eval mode is restored before returning, and parameter
    gradients are cleared afterwards (see Notes).

    Parameters
    ----------
    model : nn.Module
        The PyTorch model. It is assumed to already live on ``device``; this
        function does not move it.
    batch : Any
        A batch of data. Supported layouts:

        * ``(inputs, targets)`` tuple or list with exactly two elements;
        * dict with ``"input"`` and ``"target"`` keys;
        * anything else, which is passed to the model as-is with no targets.
    criterion : nn.Module
        Loss function, called as ``criterion(outputs, targets)``.
    device : torch.device
        Device that tensor inputs and targets are moved to.

    Returns
    -------
    dict[str, float]
        A dictionary mapping parameter names to their gradient L2 norms.
        Parameters that received no gradient are omitted.

    Raises
    ------
    ValueError
        If a tuple/list batch does not have exactly two elements.
    KeyError
        If a dict batch lacks the ``"input"`` or ``"target"`` key.

    Notes
    -----
    * Gradients are computed even when called under ``torch.no_grad()``
      (e.g. from a validation loop), because the passes run inside
      ``torch.enable_grad()``.
    * Gradients already stored on the model's parameters are discarded:
      ``model.zero_grad()`` is called before the backward pass and again on
      exit. Do not call this in the middle of gradient accumulation.
    * The forward pass runs in training mode, so dropout is active, and
      layers such as BatchNorm may update their running statistics.
    * When there are no targets, ``outputs.sum()`` stands in for the loss.
      The resulting norms then only reflect the sensitivity of the summed
      output.
    """
    # Resolve inputs/targets and move them to the device *before* touching the
    # model. A malformed batch then raises without changing model mode or grads.
    # This might need adaptation depending on DataModule specifics.
    if isinstance(batch, (tuple, list)):
        inputs, targets = batch
    elif isinstance(batch, dict):
        inputs = batch["input"]  # key names specific to some implementations
        targets = batch["target"]
    else:
        # Fallback: the whole object is the model input, with no targets.
        inputs, targets = batch, None

    if isinstance(inputs, torch.Tensor):
        inputs = inputs.to(device)
    if isinstance(targets, torch.Tensor):
        targets = targets.to(device)

    was_training = model.training
    try:
        # Train mode reproduces training-time layer behaviour (dropout,
        # BatchNorm batch statistics). It does NOT by itself enable autograd.
        model.train()
        # enable_grad() lets this work when the caller is under no_grad().
        # Otherwise loss.backward() would fail because the outputs carry no
        # grad_fn.
        with torch.enable_grad():
            model.zero_grad()
            outputs = model(inputs)
            if targets is not None:
                loss = criterion(outputs, targets)
            else:
                # No targets: sum the outputs as a placeholder scalar loss so
                # that backward() has something to differentiate.
                loss = outputs.sum()
            loss.backward()

        norms = {
            name: float(param.grad.norm(2).item())
            for name, param in model.named_parameters()
            if param.grad is not None
        }
    finally:
        # Always restore the caller's mode. Also clear the diagnostic
        # gradients so they don't leak into the caller's next optimizer step.
        model.zero_grad()
        model.train(was_training)
    return norms