Source code for gunz_ml.plots.categorical_distributions

"""
Plotting function for visualizing distributions of metrics grouped by
categorical hyperparameters.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import logging
import pathlib
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from omegaconf import DictConfig
from plotly.subplots import make_subplots

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .utils import DEFAULT_PLOT_CONFIG

def _resolve_palette(palette_name: str) -> t.List[str]:
    """Return the plotly qualitative colour list named *palette_name*.

    Falls back to ``px.colors.qualitative.Vivid`` (logging a warning) when
    ``plotly.express.colors.qualitative`` has no colour sequence by that name.
    """
    colors = getattr(px.colors.qualitative, palette_name, None)
    # The module also exposes non-palette helpers (e.g. ``swatches``); only
    # accept real, non-empty colour sequences so ``j % len(colors)`` is safe.
    if not isinstance(colors, (list, tuple)) or not colors:
        logging.warning(
            "Palette '%s' not found in plotly.express.colors.qualitative; "
            "falling back to 'Vivid'.",
            palette_name,
        )
        colors = px.colors.qualitative.Vivid
    return list(colors)


def _safe_vertical_spacing(requested: t.Optional[float], rows: int) -> t.Optional[float]:
    """Return a ``vertical_spacing`` value that ``make_subplots`` will accept.

    ``make_subplots`` raises ``ValueError`` when the spacing exceeds
    ``1 / (rows - 1)``. In that case the spacing is reduced to half of that
    bound (so every panel keeps a non-zero height) and a warning is logged.
    ``None`` is passed through so plotly can apply its own default.
    """
    if requested is None or rows <= 1:
        return requested
    limit = 1.0 / (rows - 1)
    if requested <= limit:
        return requested
    adjusted = limit / 2
    logging.warning(
        "vertical_spacing=%s is too large for %d subplot rows (max %.4f); "
        "using %.4f instead.",
        requested, rows, limit, adjusted,
    )
    return adjusted


def _violin_style(cfg: t.Mapping) -> t.Dict[str, t.Any]:
    """Build the ``go.Violin`` keyword arguments shared by every trace."""
    return dict(
        points=cfg['show_points'],
        jitter=cfg['point_jitter'],
        pointpos=0,
        box_visible=True,
        meanline_visible=True,
        marker=dict(opacity=cfg['point_opacity']),
    )


def plot_categorical_distributions(
    #? --- Data Inputs ---
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    #? --- Plotting Configuration ---
    objective_metrics: t.List[str],
    output_path: pathlib.Path,
    categorical_params_list: list,
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """
    Generates and saves distribution plots for categorical hyperparameters.

    For each name in ``categorical_params_list`` one HTML figure is written to
    ``<output_path>/distribution/<param_name>.html``. The figure has one violin
    subplot row per entry of ``objective_metrics``, with the parameter's values
    on the x axis.

    Args:
        metric_df: Metric values; concatenated column-wise with ``param_df``,
            so both are presumably indexed by trial (confirm with callers).
        param_df: Hyperparameter values, one column per parameter.
        objective_metrics: Metric column names to plot, one subplot row each.
        output_path: Base directory; a ``distribution`` subfolder is created.
        categorical_params_list: Parameter columns to plot, one file each.
            Names missing from the data are skipped with a warning.
        plot_cfg: Overrides merged on top of ``DEFAULT_PLOT_CONFIG``. Keys read
            here: ``palette``, ``vertical_spacing``, ``color_by_category``,
            ``show_points``, ``point_jitter``, ``point_opacity``,
            ``height_per_metric``, ``show_legend``, ``template``. ``None`` is
            treated as "no overrides".
    """
    logging.info("--- Starting Categorical Distribution Plot Generation ---")

    plot_output_path = output_path / "distribution"
    plot_output_path.mkdir(parents=True, exist_ok=True)

    # User settings override the package defaults key by key.
    overrides = plot_cfg if plot_cfg is not None else {}
    cfg = {**DEFAULT_PLOT_CONFIG, **overrides}

    if not categorical_params_list:
        logging.warning("No categorical parameters found to plot.")
        return
    if not objective_metrics:
        # make_subplots(rows=0) would raise; nothing to draw anyway.
        logging.warning("No objective metrics given; nothing to plot.")
        return

    combined_df = pd.concat([metric_df, param_df], axis=1)

    # Everything below is identical for every parameter, so compute it once.
    colors = _resolve_palette(cfg.get("palette", "Plotly"))
    num_metrics = len(objective_metrics)
    vertical_spacing = _safe_vertical_spacing(cfg['vertical_spacing'], num_metrics)
    violin_style = _violin_style(cfg)
    color_by_category = cfg['color_by_category']

    for param_name in categorical_params_list:
        if param_name not in combined_df.columns:
            logging.warning("Parameter '%s' not in trial data. Skipping.", param_name)
            continue

        logging.info("Generating plot for parameter: '%s'...", param_name)

        fig = make_subplots(
            rows=num_metrics,
            cols=1,
            subplot_titles=objective_metrics,
            vertical_spacing=vertical_spacing,
        )

        # Fix the category order (first appearance) once per parameter so a
        # category gets the same colour in every metric row.
        categories = combined_df[param_name].unique()
        # Categories already present in the legend; each is listed once and
        # grouped so toggling it hides/shows it in all rows together.
        in_legend: t.Set[str] = set()

        for row, metric_name in enumerate(objective_metrics, start=1):
            if metric_name not in combined_df.columns:
                logging.warning(
                    "Metric '%s' not in trial data. Leaving its subplot empty.",
                    metric_name,
                )
                continue

            if not color_by_category:
                fig.add_trace(
                    go.Violin(
                        x=combined_df[param_name],
                        y=combined_df[metric_name],
                        name=metric_name,
                        **violin_style,
                    ),
                    row=row,
                    col=1,
                )
                continue

            for j, category in enumerate(categories):
                df_cat = combined_df[combined_df[param_name] == category]
                # e.g. a NaN category never compares equal, giving no rows.
                if df_cat.empty:
                    continue
                label = str(category)
                fig.add_trace(
                    go.Violin(
                        x=df_cat[param_name],
                        y=df_cat[metric_name],
                        name=label,
                        line_color=colors[j % len(colors)],
                        legendgroup=label,
                        showlegend=label not in in_legend,
                        **violin_style,
                    ),
                    row=row,
                    col=1,
                )
                in_legend.add(label)

        fig.update_layout(
            title_text=f"Distribution of Objective Metrics by '{param_name}'",
            height=cfg['height_per_metric'] * num_metrics,
            # A legend only makes sense when traces are split by category.
            showlegend=bool(color_by_category and cfg['show_legend']),
            template=cfg['template'],
            margin=dict(t=100, l=80),
        )

        fig.write_html(plot_output_path / f"{param_name}.html")