"""
Plotting function for visualizing distributions of metrics grouped by
categorical hyperparameters.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import logging
import pathlib
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from omegaconf import DictConfig
from plotly.subplots import make_subplots
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .utils import DEFAULT_PLOT_CONFIG
# (Stray "[docs]" source-link marker from the rendered documentation removed.)
def plot_categorical_distributions(
    #? --- Data Inputs ---
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    #? --- Plotting Configuration ---
    objective_metrics: t.List[str],
    output_path: pathlib.Path,
    categorical_params_list: list,
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """
    Generate and save violin plots of metrics grouped by categorical parameters.

    For every parameter in ``categorical_params_list`` one figure is built
    with one subplot row per objective metric, and written to
    ``<output_path>/distribution/<param_name>.html``.

    Args:
        metric_df: Frame holding one column per objective metric.
        param_df: Frame holding one column per hyperparameter. It is
            concatenated column-wise with ``metric_df``, so both frames are
            expected to share the same index.
        objective_metrics: Names of the metric columns to plot. Metrics that
            are not present in the combined data are skipped with a warning.
        output_path: Base directory; a ``distribution`` sub-directory is
            created inside it if necessary.
        categorical_params_list: Names of the parameter columns to plot.
            Parameters that are not present in the combined data are skipped
            with a warning.
        plot_cfg: Overrides applied on top of ``DEFAULT_PLOT_CONFIG``. The
            merged configuration is read for the keys ``palette``,
            ``vertical_spacing``, ``color_by_category``, ``show_points``,
            ``point_jitter``, ``point_opacity``, ``height_per_metric``,
            ``show_legend`` and ``template``. ``None`` means "no overrides".

    Raises:
        KeyError: If one of the required configuration keys is missing from
            the merged configuration.
    """
    logging.info("--- Starting Categorical Distribution Plot Generation ---")
    # Accept plain strings as well as Path objects.
    plot_output_path = pathlib.Path(output_path) / "distribution"
    plot_output_path.mkdir(parents=True, exist_ok=True)

    if not categorical_params_list:
        logging.warning("No categorical parameters found to plot.")
        return

    cfg = {**DEFAULT_PLOT_CONFIG, **(plot_cfg if plot_cfg is not None else {})}
    combined_df = pd.concat([metric_df, param_df], axis=1)

    # Only plot metrics that actually exist; make_subplots() rejects zero rows
    # and indexing a missing column would abort the whole run half-way.
    metrics = []
    for metric_name in objective_metrics:
        if metric_name in combined_df.columns:
            metrics.append(metric_name)
        else:
            logging.warning(
                "Metric '%s' not in trial data. Skipping.", metric_name
            )
    if not metrics:
        logging.warning("No objective metrics found to plot.")
        return
    num_metrics = len(metrics)

    palette_name = cfg.get("palette", "Plotly")
    colors = getattr(
        px.colors.qualitative, palette_name, px.colors.qualitative.Vivid
    )
    color_by_category = cfg['color_by_category']

    # Trace options shared by every violin, independent of parameter/metric.
    violin_style = dict(
        points=cfg['show_points'],
        jitter=cfg['point_jitter'],
        pointpos=0,
        box_visible=True,
        meanline_visible=True,
        marker=dict(opacity=cfg['point_opacity']),
    )

    for param_name in categorical_params_list:
        if param_name not in combined_df.columns:
            logging.warning(
                "Parameter '%s' not in trial data. Skipping.", param_name
            )
            continue
        logging.info("Generating plot for parameter: '%s'...", param_name)

        fig = make_subplots(
            rows=num_metrics, cols=1,
            subplot_titles=metrics,
            vertical_spacing=cfg['vertical_spacing'],
        )

        # The per-category row subsets depend only on the parameter, so they
        # are computed once here instead of once per metric.
        category_subsets = []
        if color_by_category:
            categories = combined_df[param_name].unique()
            for color_idx, category in enumerate(categories):
                subset = combined_df[combined_df[param_name] == category]
                # NaN never compares equal to itself, so a NaN "category"
                # yields an empty subset and is dropped here.
                if subset.empty:
                    continue
                category_subsets.append(
                    (category, colors[color_idx % len(colors)], subset)
                )

        for row, metric_name in enumerate(metrics, start=1):
            if color_by_category:
                for category, color, subset in category_subsets:
                    label = str(category)
                    fig.add_trace(
                        go.Violin(
                            x=subset[param_name],
                            y=subset[metric_name],
                            name=label,
                            line_color=color,
                            # One legend entry per category (taken from the
                            # first row); the group ties together the traces
                            # of that category across all subplot rows.
                            legendgroup=label,
                            showlegend=(row == 1),
                            **violin_style,
                        ),
                        row=row, col=1,
                    )
            else:
                fig.add_trace(
                    go.Violin(
                        x=combined_df[param_name],
                        y=combined_df[metric_name],
                        name=metric_name,
                        **violin_style,
                    ),
                    row=row, col=1,
                )

        fig.update_layout(
            title_text=f"Distribution of Objective Metrics by '{param_name}'",
            height=cfg['height_per_metric'] * num_metrics,
            showlegend=bool(color_by_category and cfg['show_legend']),
            template=cfg['template'],
            margin=dict(t=100, l=80),
        )
        fig.write_html(plot_output_path / f"{param_name}.html")