"""
Analysis modules for model internals (complexity, parameters, weights).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
import torch
import torch.nn as nn
# =============================================================================
# FUNCTIONS
# =============================================================================
[docs]
def count_parameters(
    model: nn.Module,
) -> dict[str, int]:
    """
    Tallies the scalar elements held by a model's parameters.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to inspect.

    Returns
    -------
    dict[str, int]
        Element counts under three keys: ``"total"`` (every parameter),
        ``"trainable"`` (parameters with ``requires_grad`` set) and
        ``"non_trainable"`` (the difference between the two).
    """
    n_all = 0
    n_trainable = 0
    #? One walk over the parameters feeds both running totals
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size
    return {
        "total": n_all,
        "trainable": n_trainable,
        "non_trainable": n_all - n_trainable,
    }
[docs]
def layer_stats(
    model: nn.Module,
) -> dict[str, dict[str, float]]:
    """
    Summarises each weight parameter with mean, std, min and max.

    Only parameters whose name contains the substring ``"weight"`` are
    reported; every other parameter (biases, for instance) is skipped.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to inspect.

    Returns
    -------
    dict[str, dict[str, float]]
        Maps each selected parameter name to a dictionary with the keys
        ``"mean"``, ``"std"``, ``"min"`` and ``"max"``.
    """
    #? Label -> NumPy reducer, in the order the keys appear in the result
    reducers = (
        ("mean", np.mean),
        ("std", np.std),
        ("min", np.min),
        ("max", np.max),
    )
    summary = {}
    for param_name, tensor in model.named_parameters():
        if "weight" not in param_name:
            continue
        #? Detach from autograd and move to a CPU float32 array for NumPy
        values = tensor.detach().cpu().float().numpy()
        summary[param_name] = {
            label: float(reduce(values)) for label, reduce in reducers
        }
    return summary
[docs]
def estimate_flops(
    model: nn.Module,
    input_sample: torch.Tensor,
) -> float:
    """
    Estimates the arithmetic cost of a single forward pass.

    This is a simplified estimation: only ``nn.Linear`` and ``nn.Conv2d``
    layers contribute, and each multiply-accumulate (MAC) is counted as one
    operation. Bias additions and every other layer type are ignored.

    The model is switched to eval mode and run under ``torch.no_grad()``;
    its original training mode is restored and all temporary hooks are
    removed before returning, even if the forward pass fails.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model.
    input_sample : torch.Tensor
        A sample input tensor (batch size should ideally be 1 for normalized
        counts, but the function calculates the total for the given batch).

    Returns
    -------
    float
        Estimated number of multiply-accumulate operations. If the forward
        pass raises, a ``RuntimeWarning`` is emitted and the count gathered
        up to the failure is returned.
    """
    #? Local import keeps this block self-contained
    import warnings

    flops = 0.0

    def hook(
        module: nn.Module,
        inputs: tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> None:
        nonlocal flops
        if isinstance(module, nn.Linear):
            #? Each output element is a dot product over In_Features, so
            #? MACs = output elements * In_Features. Counting from the output
            #? handles inputs of any rank ((In,), (B, In), (B, S, In), ...)
            flops += output.numel() * module.in_features
        elif isinstance(module, nn.Conv2d):
            #? MACs = Batch * Out_C * Out_H * Out_W * (In_C / Groups) * Kernel_H * Kernel_W
            kernel_ops = (
                module.in_channels // module.groups
            ) * module.kernel_size[0] * module.kernel_size[1]
            flops += output.numel() * kernel_ops

    handles = []
    training = model.training
    try:
        #? Register hooks on the supported leaf modules
        for m in model.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                handles.append(m.register_forward_hook(hook))
        #? Run forward pass
        model.eval()
        with torch.no_grad():
            try:
                model(input_sample)
            except Exception as exc:
                #? Best effort: keep the partial count but make the failure visible
                warnings.warn(
                    f"estimate_flops: forward pass failed ({exc!r}); "
                    "returning a partial estimate.",
                    RuntimeWarning,
                    stacklevel=2,
                )
    finally:
        #? Always remove hooks and restore the original train/eval mode
        for h in handles:
            h.remove()
        model.train(training)
    return flops