Source code for gunz_ml.analysis.errors

"""
Analysis modules for error analysis (confusion matrix, top-k errors).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
import torch

# =============================================================================
# FUNCTIONS
# =============================================================================
def compute_confusion_matrix(
    preds: t.Union[torch.Tensor, np.ndarray],
    targets: t.Union[torch.Tensor, np.ndarray],
    num_classes: int,
) -> np.ndarray:
    """
    Computes the confusion matrix.

    Parameters
    ----------
    preds : torch.Tensor or numpy.ndarray
        Predicted class indices. Any shape; flattened internally.
    targets : torch.Tensor or numpy.ndarray
        True class indices. Any shape; flattened internally. Must contain
        the same number of elements as ``preds``.
    num_classes : int
        Total number of classes.

    Returns
    -------
    numpy.ndarray
        Integer confusion matrix of shape (num_classes, num_classes).
        Rows represent true classes, columns represent predicted classes.

    Raises
    ------
    ValueError
        If ``preds`` and ``targets`` do not contain the same number of
        elements.

    Notes
    -----
    Samples whose target *or* prediction falls outside
    ``[0, num_classes)`` are skipped (e.g. an ignore index such as -100).
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    #? Ensure flat arrays
    preds = np.asarray(preds).reshape(-1)
    targets = np.asarray(targets).reshape(-1)

    if preds.size != targets.size:
        raise ValueError(
            "preds and targets must contain the same number of elements, "
            f"got {preds.size} and {targets.size}."
        )

    #? Only consider samples where BOTH labels are valid class indices.
    #? An out-of-range prediction would otherwise be counted in the wrong
    #? cell, overflow the flat index, or make bincount reject a negative.
    mask = (
        (targets >= 0)
        & (targets < num_classes)
        & (preds >= 0)
        & (preds < num_classes)
    )

    #? bincount requires integer input; cast so float-typed labels work too
    valid_targets = targets[mask].astype(np.int64)
    valid_preds = preds[mask].astype(np.int64)

    #? Compute confusion matrix using bincount for speed:
    #? each (true, pred) pair maps to the flat cell index true * C + pred
    cm = np.bincount(
        num_classes * valid_targets + valid_preds,
        minlength=num_classes**2,
    )
    return cm.reshape(num_classes, num_classes)
def get_top_k_errors(
    preds: torch.Tensor,
    targets: torch.Tensor,
    probs: torch.Tensor,
    k: int = 5,
) -> t.List[t.Dict[str, t.Any]]:
    """
    Identifies the top-k misclassified examples based on confidence.

    Parameters
    ----------
    preds : torch.Tensor
        Predicted class indices. Array-likes accepted by
        ``torch.as_tensor`` (e.g. numpy arrays) also work.
    targets : torch.Tensor
        True class indices, with the same number of elements as ``preds``.
    probs : torch.Tensor
        Probabilities of the predicted classes (confidence), one value per
        prediction.
    k : int, optional
        Number of examples to retrieve. Defaults to 5. A non-positive
        value yields an empty list.

    Returns
    -------
    list of dict
        One dictionary per selected error, ordered from most to least
        confident, with keys ``"index"`` (position in the flattened
        input), ``"true_class"``, ``"pred_class"`` and ``"confidence"``.

    Raises
    ------
    ValueError
        If ``preds``, ``targets`` and ``probs`` do not all contain the
        same number of elements.
    """
    #? A negative k would slice "all but the last |k|" errors; treat any
    #? non-positive k as "nothing requested" instead.
    if k <= 0:
        return []

    #? Normalise to flat CPU tensors so mixed devices or (N,) vs (N, 1)
    #? inputs cannot silently broadcast or fail at indexing time.
    preds = torch.as_tensor(preds).detach().cpu().reshape(-1)
    targets = torch.as_tensor(targets).detach().cpu().reshape(-1)
    probs = torch.as_tensor(probs).detach().cpu().reshape(-1)

    if not (preds.numel() == targets.numel() == probs.numel()):
        raise ValueError(
            "preds, targets and probs must contain the same number of "
            f"elements, got {preds.numel()}, {targets.numel()} and "
            f"{probs.numel()}."
        )

    #? Identify misclassifications
    incorrect_indices = torch.nonzero(preds != targets, as_tuple=True)[0]
    if incorrect_indices.numel() == 0:
        return []

    #? Sort by confidence (descending) - high confidence errors are "worst"
    order = torch.argsort(probs[incorrect_indices], descending=True)
    top_k_idx = incorrect_indices[order[:k]]

    #? Pull each column out in one bulk conversion rather than calling
    #? .item() three times per selected example.
    return [
        {
            "index": int(index),
            "true_class": int(true_class),
            "pred_class": int(pred_class),
            "confidence": float(confidence),
        }
        for index, true_class, pred_class, confidence in zip(
            top_k_idx.tolist(),
            targets[top_k_idx].tolist(),
            preds[top_k_idx].tolist(),
            probs[top_k_idx].tolist(),
        )
    ]