Source code for gunz_ml.analysis.internals

"""
Analysis modules for model internals (complexity, parameters, weights).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
import warnings

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# FUNCTIONS
# =============================================================================
def count_parameters(
    model: nn.Module,
) -> dict[str, int]:
    """
    Tallies the number of scalar parameters held by a model.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model whose parameters are counted.

    Returns
    -------
    dict[str, int]
        Counts under the keys ``"total"``, ``"trainable"`` (parameters with
        ``requires_grad`` set) and ``"non_trainable"`` (the remainder).
    """
    n_all = 0
    n_with_grad = 0

    #? Single pass: every parameter adds to the total, and the ones that
    #? require gradients also add to the trainable tally.
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_with_grad += size

    return {
        "total": n_all,
        "trainable": n_with_grad,
        "non_trainable": n_all - n_with_grad,
    }
def layer_stats(
    model: nn.Module,
) -> dict[str, dict[str, float]]:
    """
    Computes statistics (mean, std, min, max) for weight parameters.

    A parameter is analysed when the substring ``"weight"`` occurs anywhere
    in its qualified name as reported by ``model.named_parameters()``; all
    other parameters (e.g. biases) are skipped.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to inspect.

    Returns
    -------
    dict[str, dict[str, float]]
        A dictionary mapping parameter names to a dictionary with the keys
        ``"mean"``, ``"std"``, ``"min"`` and ``"max"``. Parameters with zero
        elements map to NaN for every statistic.
    """
    stats: dict[str, dict[str, float]] = {}

    for name, param in model.named_parameters():
        #? Only analyze weights, skip biases for brevity
        if "weight" not in name:
            continue

        #? Detach and cast to float32 on the CPU so NumPy can read the values
        data = param.detach().cpu().float().numpy()

        #? np.min / np.max raise ValueError on zero-size arrays; report NaN
        #? for empty tensors instead of aborting the whole report.
        if data.size == 0:
            nan = float("nan")
            stats[name] = {"mean": nan, "std": nan, "min": nan, "max": nan}
            continue

        stats[name] = {
            "mean": float(np.mean(data)),
            "std": float(np.std(data)),
            "min": float(np.min(data)),
            "max": float(np.max(data)),
        }

    return stats
def estimate_flops(
    model: nn.Module,
    input_sample: torch.Tensor,
) -> float:
    """
    Estimates the operation count of a single forward pass.

    This is a simplified estimation supporting ``nn.Linear`` and
    ``nn.Conv2d`` layers only; every other module contributes nothing.
    Each multiply-accumulate (MAC) is counted as one operation and bias
    additions are ignored.

    The model is switched to eval mode and run under ``torch.no_grad()``
    for the measurement. Afterwards the forward hooks are removed and the
    training flag of every submodule is restored to its previous value.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model.
    input_sample : torch.Tensor
        A sample input tensor (batch size should ideally be 1 for normalized
        counts, but the function calculates the total for the given batch).

    Returns
    -------
    float
        Estimated number of operations (MACs). If the forward pass raises,
        a ``RuntimeWarning`` is emitted and the count accumulated before
        the failure is returned.
    """
    flops = 0.0

    def hook(
        module: nn.Module,
        inputs: tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> None:
        nonlocal flops
        if isinstance(module, nn.Linear):
            #? One MAC per (output element, input feature) pair. Using the
            #? output element count keeps this correct for inputs with
            #? extra leading dimensions, e.g. (Batch, Seq, In_Features).
            flops += output.numel() * module.in_features
        elif isinstance(module, nn.Conv2d):
            #? MACs = Batch * Out_C * Out_H * Out_W
            #?        * (In_C / Groups) * Kernel_H * Kernel_W
            kernel_ops = (
                (module.in_channels // module.groups)
                * module.kernel_size[0]
                * module.kernel_size[1]
            )
            flops += output.numel() * kernel_ops

    #? Register hooks on the supported leaf modules
    handles = [
        m.register_forward_hook(hook)
        for m in model.modules()
        if isinstance(m, (nn.Linear, nn.Conv2d))
    ]

    #? Remember each submodule's mode: model.train(flag) would overwrite
    #? submodules that were deliberately left in a different mode.
    saved_modes = [(m, m.training) for m in model.modules()]
    model.eval()

    try:
        #? Run forward pass
        with torch.no_grad():
            model(input_sample)
    except Exception as exc:
        #? Best-effort: keep returning the partial count, but say so.
        warnings.warn(
            f"estimate_flops: forward pass failed ({exc!r}); "
            "returning the count accumulated before the failure.",
            RuntimeWarning,
            stacklevel=2,
        )
    finally:
        #? Remove hooks
        for handle in handles:
            handle.remove()
        #? Restore mode
        for m, was_training in saved_modes:
            m.training = was_training

    return flops