# Source code for gunz_ml.analysis.runner

"""
Runner for the Deep Learning Analysis mode.
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import os
import json
import typing as t
import logging

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import torch
import numpy as np
import hydra
import matplotlib.pyplot as plt
import pandas as pd
from omegaconf import DictConfig, OmegaConf

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .internals import count_parameters, layer_stats, estimate_flops
from .errors import compute_confusion_matrix, get_top_k_errors
from .dynamics import compute_gradient_norms

#? Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class AnalysisRunner:
    """
    Orchestrates the analysis workflow.

    The runner instantiates a model and a datamodule from the Hydra config,
    optionally restores a checkpoint, and then runs up to three analysis
    stages (internals, dynamics, errors), each toggled by a flag in
    ``cfg.analysis``. Results are collected in a dict and optionally written
    to ``analysis_results.json`` inside ``cfg.analysis.output_dir``.
    """

    def __init__(self, cfg: "DictConfig"):
        """
        Store the config and make sure the output directory exists.

        Args:
            cfg: Full configuration; must contain an ``analysis`` section
                providing at least ``output_dir``.

        Raises:
            ValueError: If ``cfg`` has no ``analysis`` section.
        """
        self.cfg = cfg
        #? The analysis settings are expected to be part of the main config.
        if "analysis" not in cfg:
            raise ValueError("Configuration must contain an 'analysis' section.")

        self.analysis_cfg = cfg.analysis
        self.output_dir = self.analysis_cfg.output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def run(self):
        """
        Execute every enabled analysis stage and persist the results.

        Raises:
            ValueError: If the config lacks a ``model`` section, or lacks both
                a ``data`` and a ``datamodule`` section.
            Exception: Whatever checkpoint loading raises is logged and
                re-raised unchanged.
        """
        logger.info("Starting Analysis Mode...")

        #? 1. Instantiate Data & Model
        if "model" not in self.cfg:
            raise ValueError("Config must have a 'model' section for instantiation.")
        if "data" not in self.cfg and "datamodule" not in self.cfg:
            raise ValueError("Config must have a 'data' or 'datamodule' section.")

        logger.info("Instantiating Model...")
        model = hydra.utils.instantiate(self.cfg.model)

        data_key = "data" if "data" in self.cfg else "datamodule"
        logger.info("Instantiating DataModule...")
        datamodule = hydra.utils.instantiate(self.cfg[data_key])

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        #? 2. Load Checkpoint
        self._load_checkpoint(model, device)

        #? No Trainer is involved here, so eval mode and device placement
        #? are handled manually.
        model.eval()
        model.to(device)

        results = {}

        #? 3. Internals Analysis
        if self.analysis_cfg.enable_internals:
            self._analyze_internals(model, datamodule, device, results)

        #? 4. Dynamics Analysis
        if self.analysis_cfg.enable_dynamics:
            self._analyze_dynamics(model, datamodule, device, results)

        #? 5. Error Analysis
        if self.analysis_cfg.enable_errors:
            self._analyze_errors(model, datamodule, device, results)

        #? 6. Save Results
        if self.analysis_cfg.save_json:
            out_path = os.path.join(self.output_dir, "analysis_results.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, default=self._json_default)
            logger.info("Results saved to %s", out_path)

    def _load_checkpoint(self, model, device):
        """Load weights from ``analysis_cfg.ckpt_path`` into ``model`` if the file exists."""
        ckpt_path = self.analysis_cfg.ckpt_path
        if not ckpt_path or not os.path.exists(ckpt_path):
            logger.warning(
                "Checkpoint path '%s' not found. Using initialized model weights.",
                ckpt_path,
            )
            return

        logger.info("Loading checkpoint from %s...", ckpt_path)
        try:
            #? NOTE(review): torch.load may unpickle arbitrary objects depending
            #? on the torch version / weights_only default — only point
            #? ckpt_path at trusted files.
            checkpoint = torch.load(ckpt_path, map_location=device)
            #? Some checkpoints wrap the weights under a 'state_dict' key;
            #? otherwise the loaded object is treated as the state dict itself.
            if "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
        except Exception as e:
            logger.error("Failed to load checkpoint: %s", e)
            raise

    @staticmethod
    def _has_batches(loader) -> bool:
        """Return False for ``None`` or a zero-length loader, True otherwise."""
        if loader is None:
            return False
        try:
            return len(loader) > 0
        except TypeError:
            #? Unsized loaders cannot report a length; treat them as usable
            #? instead of crashing on the truthiness check.
            return True

    def _first_train_batch(self, datamodule):
        """Set up the 'fit' stage and return the first training batch, or None."""
        datamodule.setup(stage="fit")
        loader = datamodule.train_dataloader()
        if not self._has_batches(loader):
            return None
        #? The default avoids StopIteration when the loader yields nothing.
        return next(iter(loader), None)

    def _analyze_internals(self, model, datamodule, device, results):
        """Fill ``results['internals']`` with parameter counts, layer stats and FLOPs."""
        logger.info("Running Internals Analysis...")
        results['internals'] = {
            'params': count_parameters(model),
            'stats': layer_stats(model),
        }

        #? Estimating FLOPs requires a sample input.
        batch = self._first_train_batch(datamodule)
        if batch is None:
            return
        inputs = self._extract_inputs(batch, device)
        if inputs is not None:
            results['internals']['flops'] = estimate_flops(model, inputs)

    def _analyze_dynamics(self, model, datamodule, device, results):
        """Fill ``results['dynamics']`` from one training batch and a loss criterion."""
        logger.info("Running Dynamics Analysis...")
        batch = self._first_train_batch(datamodule)
        if batch is None:
            return

        if "loss" in self.cfg:
            criterion = hydra.utils.instantiate(self.cfg.loss)
        else:
            #? Fall back to a loss attached to the model itself.
            criterion = getattr(model, "loss_fn", None)
            if criterion is None:
                logger.warning(
                    "No 'loss' config found and model has no 'loss_fn'. Skipping dynamics."
                )
        if criterion is None:
            return

        batch_on_device = self._move_batch_to_device(batch, device)
        results['dynamics'] = compute_gradient_norms(
            model, batch_on_device, criterion, device
        )

    def _analyze_errors(self, model, datamodule, device, results):
        """Fill ``results['errors']`` with a confusion matrix and top-k errors."""
        logger.info("Running Error Analysis...")
        datamodule.setup(stage="validate")
        loader = datamodule.val_dataloader()

        #? If there is no usable validation loader, try the test loader.
        if not self._has_batches(loader):
            datamodule.setup(stage="test")
            loader = datamodule.test_dataloader()
        if not self._has_batches(loader):
            return

        preds, targets, probs = self._run_validation_loop(model, loader, device)
        if preds.numel() == 0:
            logger.warning("No labelled batches found. Skipping error analysis.")
            return

        #? Determine num_classes: model attribute, then datamodule attribute,
        #? then infer from the largest label seen.
        num_classes = getattr(model, "num_classes", None)
        if num_classes is None:
            num_classes = getattr(datamodule, "num_classes", None)
        if num_classes is None:
            num_classes = int(max(targets.max().item(), preds.max().item())) + 1

        cm = compute_confusion_matrix(preds, targets, num_classes)
        top_k = get_top_k_errors(preds, targets, probs)

        results['errors'] = {
            'confusion_matrix': cm.tolist(),
            'top_k_misclassified': top_k,
        }

        if self.analysis_cfg.save_plots:
            self._plot_confusion_matrix(cm, num_classes)

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback converting tensors and NumPy values to plain Python."""
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def _extract_inputs(self, batch, device):
        """
        Pull the model input out of ``batch`` and move it to ``device``.

        Sequences use their first element, dicts use the ``"input"`` key
        (falling back to ``"x"``), anything else is treated as the input
        itself. Returns ``None`` when no input can be found.
        """
        if isinstance(batch, (tuple, list)):
            return batch[0].to(device) if batch else None
        if isinstance(batch, dict):
            inputs = batch.get("input", batch.get("x"))
            return None if inputs is None else inputs.to(device)
        return batch.to(device)

    def _move_batch_to_device(self, batch, device):
        """Move every tensor in ``batch`` to ``device``; non-tensor items pass through."""
        if isinstance(batch, (tuple, list)):
            return [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        if isinstance(batch, dict):
            return {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
        return batch.to(device)

    def _run_validation_loop(
        self, model, loader, device
    ) -> t.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run ``model`` over ``loader`` without gradients and collect predictions.

        Targets are taken from the second element of a sequence batch or from
        the ``"target"``/``"y"`` key of a dict batch; batches without inputs or
        targets are skipped.

        Returns:
            ``(preds, targets, probs)`` concatenated over all batches on CPU,
            where ``probs`` holds the maximum softmax probability (softmax
            over dim 1) per sample. All three are empty tensors when no
            usable batch was seen.
        """
        preds_list = []
        targets_list = []
        probs_list = []

        with torch.no_grad():
            for batch in loader:
                if isinstance(batch, (tuple, list)) and len(batch) > 1:
                    targets = batch[1]
                elif isinstance(batch, dict):
                    targets = batch.get("target", batch.get("y"))
                else:
                    continue
                if targets is None:
                    continue

                inputs = self._extract_inputs(batch, device)
                if inputs is None:
                    continue

                targets = targets.to(device)
                logits = model(inputs)

                #? Handle models that return a dict of outputs.
                if isinstance(logits, dict):
                    logits = logits.get("logits", logits.get("out"))

                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)

                preds_list.append(preds.cpu())
                targets_list.append(targets.cpu())
                probs_list.append(probs.max(dim=1).values.cpu())

        if not preds_list:
            #? torch.cat rejects an empty list, so return explicit empties.
            return (
                torch.empty(0, dtype=torch.long),
                torch.empty(0, dtype=torch.long),
                torch.empty(0),
            )
        return torch.cat(preds_list), torch.cat(targets_list), torch.cat(probs_list)

    def _plot_confusion_matrix(self, cm, num_classes):
        """Render ``cm`` as a heatmap and save it as ``confusion_matrix.png`` in the output dir."""
        fig = plt.figure(figsize=(10, 8))
        try:
            plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
            plt.title('Confusion Matrix')
            plt.colorbar()

            tick_marks = np.arange(num_classes)
            plt.xticks(tick_marks, tick_marks)
            plt.yticks(tick_marks, tick_marks)

            #? Labeling
            plt.xlabel('Predicted Label')
            plt.ylabel('True Label')

            #? Save
            plt.savefig(os.path.join(self.output_dir, "confusion_matrix.png"))
        finally:
            #? Close this specific figure even if rendering or saving fails.
            plt.close(fig)