"""
Runner for the Deep Learning Analysis mode.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import os
import json
import typing as t
import logging
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import torch
import numpy as np
import hydra
import matplotlib.pyplot as plt
import pandas as pd
from omegaconf import DictConfig, OmegaConf
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .internals import count_parameters, layer_stats, estimate_flops
from .errors import compute_confusion_matrix, get_top_k_errors
from .dynamics import compute_gradient_norms
logger = logging.getLogger(__name__)
class AnalysisRunner:
    """
    Orchestrates the analysis workflow.

    Given a Hydra/OmegaConf config with an ``analysis`` section, :meth:`run`
    instantiates the model and datamodule, optionally loads a checkpoint and
    then runs whichever of the three analysis stages are enabled in
    ``cfg.analysis`` (``enable_internals``, ``enable_dynamics``,
    ``enable_errors``). Results are collected in a dict and, when
    ``save_json`` is set, written to ``<output_dir>/analysis_results.json``.
    """

    def __init__(self, cfg: "DictConfig"):
        """
        Store the config and create the output directory.

        Args:
            cfg: Full run configuration. Must contain an ``analysis`` section
                that provides at least ``output_dir``.

        Raises:
            ValueError: If ``cfg`` has no ``analysis`` section.
        """
        self.cfg = cfg
        if "analysis" not in cfg:
            raise ValueError("Configuration must contain an 'analysis' section.")
        self.analysis_cfg = cfg.analysis
        self.output_dir = self.analysis_cfg.output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def run(self):
        """
        Execute every enabled analysis stage and persist the results.

        Steps, in order: validate the config, instantiate model and
        datamodule, load the checkpoint (if the path exists), put the model in
        eval mode on the selected device, run the internals / dynamics /
        error stages that are enabled, and finally dump the collected results
        to JSON when ``save_json`` is set.

        Returns:
            None. Results are only written to disk.

        Raises:
            ValueError: If the config lacks a ``model`` section, or lacks both
                ``data`` and ``datamodule`` sections.
            Exception: Any error raised while loading the checkpoint is logged
                and re-raised unchanged.
        """
        logger.info("Starting Analysis Mode...")

        #? 1. Instantiate Data & Model
        if "model" not in self.cfg:
            raise ValueError("Config must have a 'model' section for instantiation.")
        if "data" not in self.cfg and "datamodule" not in self.cfg:
            raise ValueError("Config must have a 'data' or 'datamodule' section.")

        logger.info("Instantiating Model...")
        model = hydra.utils.instantiate(self.cfg.model)
        data_key = "data" if "data" in self.cfg else "datamodule"
        logger.info("Instantiating DataModule...")
        datamodule = hydra.utils.instantiate(self.cfg[data_key])

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        #? 2. Load Checkpoint
        self._load_checkpoint(model, device)

        #? No Trainer is involved here, so mode and device placement are manual.
        model.eval()
        model.to(device)

        results = {}

        #? 3. Internals Analysis
        if self.analysis_cfg.enable_internals:
            logger.info("Running Internals Analysis...")
            results['internals'] = {
                'params': count_parameters(model),
                'stats': layer_stats(model)
            }
            #? FLOP estimation needs a concrete input batch.
            batch = self._first_train_batch(datamodule)
            if batch is not None:
                inputs = self._extract_inputs(batch, device)
                if inputs is not None:
                    results['internals']['flops'] = estimate_flops(model, inputs)

        #? 4. Dynamics Analysis (needs a batch and a loss for the backward pass)
        if self.analysis_cfg.enable_dynamics:
            logger.info("Running Dynamics Analysis...")
            batch = self._first_train_batch(datamodule)
            if batch is not None:
                criterion = self._resolve_criterion(model)
                if criterion is not None:
                    batch_on_device = self._move_batch_to_device(batch, device)
                    results['dynamics'] = compute_gradient_norms(
                        model, batch_on_device, criterion, device
                    )

        #? 5. Error Analysis
        if self.analysis_cfg.enable_errors:
            logger.info("Running Error Analysis...")
            loader = self._evaluation_loader(datamodule)
            if loader is not None:
                preds, targets, probs = self._run_validation_loop(model, loader, device)
                if preds.numel() == 0:
                    #? Nothing to score: max() on empty tensors would raise below.
                    logger.warning(
                        "Evaluation loader produced no labelled batches. "
                        "Skipping error analysis."
                    )
                else:
                    #? Determine num_classes: model, then datamodule, then infer.
                    num_classes = getattr(model, "num_classes", None)
                    if num_classes is None:
                        num_classes = getattr(datamodule, "num_classes", None)
                    if num_classes is None:
                        num_classes = int(max(targets.max().item(), preds.max().item())) + 1

                    cm = compute_confusion_matrix(preds, targets, num_classes)
                    top_k = get_top_k_errors(preds, targets, probs)
                    results['errors'] = {
                        'confusion_matrix': cm.tolist(),
                        'top_k_misclassified': top_k
                    }
                    if self.analysis_cfg.save_plots:
                        self._plot_confusion_matrix(cm, num_classes)

        #? Save Results
        if self.analysis_cfg.save_json:
            out_path = os.path.join(self.output_dir, "analysis_results.json")
            #? Serialize before opening the file so a serialization error
            #? cannot leave a truncated JSON file behind.
            payload = json.dumps(results, indent=2, default=self._json_default)
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(payload)
            logger.info(f"Results saved to {out_path}")

    def _load_checkpoint(self, model, device):
        """Load ``analysis_cfg.ckpt_path`` into ``model``; warn and keep the
        initialized weights when the path is empty or does not exist."""
        ckpt_path = self.analysis_cfg.ckpt_path
        if not ckpt_path or not os.path.exists(ckpt_path):
            logger.warning(f"Checkpoint path '{ckpt_path}' not found. Using initialized model weights.")
            return

        logger.info(f"Loading checkpoint from {ckpt_path}...")
        try:
            #? NOTE(review): depending on the installed PyTorch's `weights_only`
            #? default, torch.load may unpickle arbitrary objects. Only load
            #? checkpoints from trusted sources.
            checkpoint = torch.load(ckpt_path, map_location=device)
            #? Checkpoints that wrap the weights under "state_dict" are
            #? unwrapped; anything else is passed to load_state_dict as-is.
            if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise

    @staticmethod
    def _has_batches(loader):
        """Return True when ``loader`` is not None and is not known to be empty.

        Plain truthiness (``if loader:``) falls back to ``len(loader)``, which
        raises for loaders that have no length (e.g. a DataLoader over an
        iterable-style dataset without ``__len__``). Such loaders are treated
        as usable here.
        """
        if loader is None:
            return False
        try:
            return len(loader) > 0
        except (TypeError, NotImplementedError):
            return True

    def _first_train_batch(self, datamodule):
        """Set up the 'fit' stage and return the first training batch, or
        None when there is no train loader or it yields nothing."""
        datamodule.setup(stage="fit")
        loader = datamodule.train_dataloader()
        if not self._has_batches(loader):
            return None
        #? Default of None covers unsized loaders that turn out to be empty.
        return next(iter(loader), None)

    def _resolve_criterion(self, model):
        """Return the loss from ``cfg.loss`` if configured, else the model's
        ``loss_fn`` attribute; log a warning and return None when neither
        exists."""
        if "loss" in self.cfg:
            return hydra.utils.instantiate(self.cfg.loss)
        criterion = getattr(model, "loss_fn", None)
        if criterion is None:
            logger.warning("No 'loss' config found and model has no 'loss_fn'. Skipping dynamics.")
        return criterion

    def _evaluation_loader(self, datamodule):
        """Return the validation loader, falling back to the test loader;
        None when neither provides batches."""
        datamodule.setup(stage="validate")
        loader = datamodule.val_dataloader()
        if not self._has_batches(loader):
            datamodule.setup(stage="test")
            loader = datamodule.test_dataloader()
        return loader if self._has_batches(loader) else None

    @staticmethod
    def _json_default(obj):
        """``json`` fallback: convert torch tensors and NumPy arrays/scalars
        to plain Python values; reject anything else with TypeError."""
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _extract_inputs(self, batch, device):
        """
        Pull the model inputs out of ``batch`` and move them to ``device``.

        A tuple/list batch yields its first element, a dict batch yields the
        value under ``"input"`` (or ``"x"`` as fallback), anything else is
        treated as the inputs themselves.

        Returns:
            The inputs on ``device``, or None when the batch holds no inputs
            (empty sequence, dict without either key, or a None value).
        """
        if isinstance(batch, (tuple, list)):
            inputs = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            inputs = batch.get("input", batch.get("x"))
        else:
            inputs = batch
        if inputs is None:
            return None
        return inputs.to(device)

    def _move_batch_to_device(self, batch, device):
        """
        Move every tensor in ``batch`` to ``device``.

        Tuple/list batches come back as a list and dict batches as a dict,
        with non-tensor entries passed through untouched. Any other batch is
        moved with its own ``.to(device)``.
        """
        if isinstance(batch, (tuple, list)):
            return [
                item.to(device) if isinstance(item, torch.Tensor) else item
                for item in batch
            ]
        if isinstance(batch, dict):
            return {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
        return batch.to(device)

    def _run_validation_loop(self, model, loader, device):
        """
        Run ``model`` over ``loader`` without gradients and collect predictions.

        Targets are taken from the second element of a tuple/list batch or
        from the ``"target"`` / ``"y"`` key of a dict batch. Batches without
        inputs or targets are skipped. A dict model output is unwrapped via
        its ``"logits"`` (or ``"out"``) entry. Softmax and argmax are taken
        over ``dim=1``.

        Returns:
            ``(preds, targets, probs)`` as CPU tensors concatenated over all
            batches, where ``probs`` holds the maximum softmax probability of
            each sample. Three empty tensors are returned when no batch could
            be evaluated.

        Raises:
            ValueError: If the model returns a dict with neither ``"logits"``
                nor ``"out"``.
        """
        preds_list = []
        targets_list = []
        probs_list = []

        with torch.no_grad():
            for batch in loader:
                #? Resolve targets first so unlabeled batches are skipped
                #? before any data is moved to the device.
                if isinstance(batch, (tuple, list)) and len(batch) > 1:
                    targets = batch[1]
                elif isinstance(batch, dict):
                    targets = batch.get("target", batch.get("y"))
                else:
                    continue
                if targets is None:
                    continue

                inputs = self._extract_inputs(batch, device)
                if inputs is None:
                    continue

                logits = model(inputs)
                #? Handle if model returns dict
                if isinstance(logits, dict):
                    logits = logits.get("logits", logits.get("out"))
                    if logits is None:
                        raise ValueError(
                            "Model output dict has neither 'logits' nor 'out'."
                        )

                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)

                preds_list.append(preds.cpu())
                #? Targets are only accumulated, never fed to the model, so
                #? they go straight to CPU.
                targets_list.append(targets.cpu())
                probs_list.append(probs.max(dim=1).values.cpu())

        if not preds_list:
            #? torch.cat raises on an empty list; hand back empty results.
            return (
                torch.empty(0, dtype=torch.long),
                torch.empty(0, dtype=torch.long),
                torch.empty(0),
            )
        return torch.cat(preds_list), torch.cat(targets_list), torch.cat(probs_list)

    def _plot_confusion_matrix(self, cm, num_classes):
        """
        Render ``cm`` as a heatmap and save it to
        ``<output_dir>/confusion_matrix.png``.

        Args:
            cm: Confusion matrix; a torch tensor is converted to a NumPy
                array before plotting.
            num_classes: Number of tick marks drawn on each axis.
        """
        if isinstance(cm, torch.Tensor):
            cm = cm.detach().cpu().numpy()

        fig = plt.figure(figsize=(10, 8))
        try:
            plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
            plt.title('Confusion Matrix')
            plt.colorbar()
            tick_marks = np.arange(num_classes)
            plt.xticks(tick_marks, tick_marks)
            plt.yticks(tick_marks, tick_marks)
            #? Labeling
            plt.xlabel('Predicted Label')
            plt.ylabel('True Label')
            #? Save
            plt.savefig(os.path.join(self.output_dir, "confusion_matrix.png"))
        finally:
            #? Always release the figure, even if drawing or saving fails.
            plt.close(fig)