"""
Analysis modules for error analysis (confusion matrix, top-k errors).
"""
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
import torch
# =============================================================================
# FUNCTIONS
# =============================================================================
def compute_confusion_matrix(
    preds: t.Union[torch.Tensor, np.ndarray],
    targets: t.Union[torch.Tensor, np.ndarray],
    num_classes: int,
) -> np.ndarray:
    """
    Computes the confusion matrix.

    Samples whose target *or* prediction lies outside ``[0, num_classes)``
    are ignored (this covers "ignore index" labels such as ``-100``).

    Parameters
    ----------
    preds : torch.Tensor or numpy.ndarray
        Predicted class indices. Any shape; flattened internally.
    targets : torch.Tensor or numpy.ndarray
        True class indices. Any shape; flattened internally. Must contain
        the same number of elements as ``preds``.
    num_classes : int
        Total number of classes.

    Returns
    -------
    numpy.ndarray
        Confusion matrix of shape (num_classes, num_classes).
        Rows represent true classes, columns represent predicted classes.

    Raises
    ------
    ValueError
        If ``preds`` and ``targets`` do not contain the same number of
        elements.

    Notes
    -----
    Inputs are cast to ``int64`` before indexing, so non-integer values are
    truncated towards zero.
    """
    #? detach() so tensors that still track gradients can be converted
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    #? Ensure flat int64 arrays. The widening matters: with narrow dtypes
    #? (e.g. uint8 labels) ``num_classes * targets`` would silently wrap.
    preds = np.asarray(preds).reshape(-1).astype(np.int64, copy=False)
    targets = np.asarray(targets).reshape(-1).astype(np.int64, copy=False)

    if preds.shape != targets.shape:
        raise ValueError(
            "preds and targets must have the same number of elements, "
            f"got {preds.size} and {targets.size}"
        )

    #? Only consider samples where both the target and the prediction are
    #? valid class indices; an out-of-range prediction would otherwise be
    #? counted in the wrong cell or overflow the flattened matrix.
    valid = (
        (targets >= 0)
        & (targets < num_classes)
        & (preds >= 0)
        & (preds < num_classes)
    )

    #? Compute confusion matrix using bincount for speed: each (true, pred)
    #? pair maps to a unique slot of the row-major flattened matrix.
    flat_index = num_classes * targets[valid] + preds[valid]
    cm = np.bincount(flat_index, minlength=num_classes**2)
    return cm.reshape(num_classes, num_classes)
def get_top_k_errors(
    preds: torch.Tensor,
    targets: torch.Tensor,
    probs: torch.Tensor,
    k: int = 5,
) -> t.List[t.Dict[str, t.Any]]:
    """
    Identifies the top-k misclassified examples based on confidence.

    Parameters
    ----------
    preds : torch.Tensor
        Predicted class indices. Any shape; flattened internally.
    targets : torch.Tensor
        True class indices. Any shape; flattened internally.
    probs : torch.Tensor
        Probabilities of the predicted classes (confidence), one value per
        sample. Any shape; flattened internally.
    k : int, optional
        Number of examples to retrieve. Defaults to 5. Values ``<= 0``
        yield an empty list.

    Returns
    -------
    list of dict
        List of dictionaries containing index, target, prediction, and
        confidence (keys ``"index"``, ``"true_class"``, ``"pred_class"``,
        ``"confidence"``), ordered from most to least confident error.
        ``"index"`` refers to the position in the flattened inputs.

    Raises
    ------
    ValueError
        If ``preds``, ``targets`` and ``probs`` do not all contain the same
        number of elements.
    """
    #? Flatten everything so that e.g. (N,) vs (N, 1) inputs are compared
    #? element-wise instead of broadcasting into an (N, N) mask.
    preds = torch.as_tensor(preds).detach().reshape(-1)
    targets = torch.as_tensor(targets).detach().reshape(-1)
    probs = torch.as_tensor(probs).detach().reshape(-1)

    if not (preds.numel() == targets.numel() == probs.numel()):
        raise ValueError(
            "preds, targets and probs must have the same number of elements, "
            f"got {preds.numel()}, {targets.numel()} and {probs.numel()}"
        )

    #? A negative k would otherwise slice "all but the last |k|" errors
    if k <= 0:
        return []

    #? Identify misclassifications
    incorrect_indices = torch.nonzero(preds != targets, as_tuple=True)[0]
    if incorrect_indices.numel() == 0:
        return []

    #? Sort by confidence (descending) - high confidence errors are "worst"
    order = torch.argsort(probs[incorrect_indices], descending=True)
    top_k_idx = incorrect_indices[order[:k]]

    #? Convert each field with a single tolist() call rather than one
    #? .item() per element (each .item() is a separate host transfer).
    return [
        {
            "index": int(index),
            "true_class": int(true_class),
            "pred_class": int(pred_class),
            "confidence": float(confidence),
        }
        for index, true_class, pred_class, confidence in zip(
            top_k_idx.tolist(),
            targets[top_k_idx].tolist(),
            preds[top_k_idx].tolist(),
            probs[top_k_idx].tolist(),
        )
    ]