Source code for gunz_ml.analysis.dynamics

"""
Analysis modules for training dynamics (gradients, logs).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import torch
import torch.nn as nn

# =============================================================================
# FUNCTIONS
# =============================================================================
def compute_gradient_norms(
    model: nn.Module,
    batch: t.Any,
    criterion: nn.Module,
    device: torch.device,
) -> dict[str, float]:
    """
    Computes the L2 norm of gradients for a single batch.

    The model is run in training mode with autograd explicitly enabled, so
    this works even when called from inside a ``torch.no_grad()`` block. The
    model's original train/eval mode and any gradients already stored on its
    parameters are restored before returning, including when an exception is
    raised.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model.
    batch : Any
        A batch of data. Supported layouts:

        - a 2-element tuple/list ``(inputs, targets)``;
        - a dict with ``"input"`` and ``"target"`` keys;
        - anything else, which is treated as ``inputs`` with no targets.
    criterion : nn.Module
        Loss function, called as ``criterion(outputs, targets)``. It is
        unused when there are no targets; in that case ``outputs.sum()`` is
        backpropagated instead.
    device : torch.device
        Device to move tensor inputs/targets to.

    Returns
    -------
    dict[str, float]
        A dictionary mapping parameter names to their gradient L2 norms.
        Parameters that receive no gradient (e.g. frozen ones) are omitted.

    Raises
    ------
    ValueError
        If a tuple/list batch does not have exactly two elements.
    KeyError
        If a dict batch lacks the ``"input"`` or ``"target"`` key.
    """
    was_training = model.training
    # Snapshot existing gradients (by reference) so a caller that is, e.g.,
    # mid gradient-accumulation does not lose its state to this diagnostic.
    saved_grads = [(p, p.grad) for p in model.parameters()]
    norms: dict[str, float] = {}

    try:
        #? Unpack batch (assuming standard tuple/list format). This lives
        #? inside the try so a malformed batch still restores model state.
        if isinstance(batch, (tuple, list)):
            inputs, targets = batch
        elif isinstance(batch, dict):
            inputs = batch["input"]  # specific to some implementations
            targets = batch["target"]
        else:
            # Fallback: the whole batch is the model input.
            inputs, targets = batch, None

        if isinstance(inputs, torch.Tensor):
            inputs = inputs.to(device)
        if isinstance(targets, torch.Tensor):
            targets = targets.to(device)

        # Train mode selects train-time behaviour of layers such as dropout
        # and batch norm; it does NOT turn on autograd. NOTE: in train mode,
        # BatchNorm layers update their running statistics during forward.
        model.train()

        # Autograd must be enabled explicitly: callers often invoke analysis
        # code under torch.no_grad(), where backward() would otherwise fail.
        with torch.enable_grad():
            outputs = model(inputs)
            if targets is not None:
                loss = criterion(outputs, targets)
            else:
                #? Without targets, backprop the output sum; the norms then
                #? reflect d(sum(outputs))/d(params), not a supervised loss.
                loss = outputs.sum()

            # set_to_none=True replaces .grad instead of zeroing it in place,
            # so the tensors held in saved_grads are left untouched.
            model.zero_grad(set_to_none=True)
            loss.backward()

        for name, param in model.named_parameters():
            if param.grad is not None:
                norms[name] = param.grad.norm(2).item()
    finally:
        # Put back the caller's gradients (None where there were none) and
        # the original train/eval mode.
        for param, grad in saved_grads:
            param.grad = grad
        model.train(was_training)

    return norms