# Source code for gunz_ml.plots.pareto_by_category

"""
Plotting function for Pareto frontiers grouped by a categorical hyperparameter.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import itertools
import logging
import pathlib
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from omegaconf import DictConfig
from optuna.study import StudyDirection

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ..optimization.pareto import find_pareto_frontier_fast
from .utils import DEFAULT_PLOT_CONFIG

def _add_category_traces(
    fig: go.Figure,
    df_cat: pd.DataFrame,
    pareto_df: pd.DataFrame,
    category: t.Any,
    color: str,
    metric_x: str,
    metric_y: str,
) -> None:
    """Add one category's traces to *fig*.

    Two traces are added: a faint scatter of every trial in the category
    (hidden from the legend) and the category's Pareto frontier drawn as a
    line with star markers. Both share a legend group, so toggling the
    visible "(Pareto)" legend entry also toggles the "(all)" scatter.

    The DataFrame index of each subset is attached as ``customdata`` and
    shown as the trial identifier in the hover text.
    """
    # Identical for both traces; built once instead of twice.
    hovertemplate = (
        "<b>Trial</b>: %{customdata}<br>"
        f"<b>{metric_x}</b>: %{{x}}<br>"
        f"<b>{metric_y}</b>: %{{y}}<extra></extra>"
    )
    group = str(category)

    fig.add_trace(go.Scatter(
        x=df_cat[metric_x],
        y=df_cat[metric_y],
        mode='markers',
        marker=dict(color=color, opacity=0.3, size=5),
        name=f'{category} (all)',
        legendgroup=group,
        showlegend=False,
        customdata=df_cat.index,
        hovertemplate=hovertemplate,
    ))
    fig.add_trace(go.Scatter(
        x=pareto_df[metric_x],
        y=pareto_df[metric_y],
        mode='lines+markers',
        line=dict(color=color, width=2),
        marker=dict(color=color, size=8, symbol='star'),
        name=f'{category} (Pareto)',
        legendgroup=group,
        customdata=pareto_df.index,
        hovertemplate=hovertemplate,
    ))


def plot_pareto_frontiers_by_category(
    #? --- Data Inputs ---
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    #? --- Plotting Configuration ---
    objective_metrics: t.List[str],
    study_directions: t.List[StudyDirection],
    output_path: pathlib.Path,
    categorical_params_list: list,
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate Pareto plots for objective pairs, grouped by a categorical parameter.

    For every unordered pair of objectives and every categorical parameter,
    one interactive HTML figure is written. Each category value gets its own
    colour, a faint scatter of all its trials, and its own Pareto frontier.

    Parameters
    ----------
    metric_df:
        Objective values, one column per name in ``objective_metrics``.
        Its index is shown as the trial identifier in hover text.
    param_df:
        Hyperparameter values. It is concatenated column-wise with
        ``metric_df``, so rows are aligned on the index.
    objective_metrics:
        Objective column names. Every 2-combination is plotted.
    study_directions:
        Optimisation direction for each objective, paired positionally with
        ``objective_metrics``.
    output_path:
        Base directory. Figures go to ``output_path / "pareto_front_categorical"``.
    categorical_params_list:
        Column names to group by. Names absent from the data are skipped
        with a warning.
    plot_cfg:
        Overrides merged on top of ``DEFAULT_PLOT_CONFIG``. The keys read are
        ``"palette"`` (a ``plotly.express.colors.qualitative`` name; an
        unknown name falls back to ``Vivid``) and ``"template"``.

    Raises
    ------
    KeyError
        If an objective has no paired direction, or an objective column is
        missing from ``metric_df``.
    """
    logging.info("--- Starting Categorical Pareto Frontier Plot Generation ---")

    if not categorical_params_list:
        logging.warning("No categorical parameters found for Pareto plotting.")
        return
    if len(objective_metrics) < 2:
        logging.warning(
            "At least two objective metrics are required for Pareto plotting; got %d.",
            len(objective_metrics),
        )
        return

    # Create the directory only once we know something will be written.
    plot_output_path = output_path / "pareto_front_categorical"
    plot_output_path.mkdir(parents=True, exist_ok=True)

    cfg = {**DEFAULT_PLOT_CONFIG, **(plot_cfg or {})}

    if len(objective_metrics) != len(study_directions):
        # zip() silently truncates; surface the mismatch instead of hiding it.
        logging.warning(
            "Got %d objective metrics but %d study directions.",
            len(objective_metrics), len(study_directions),
        )
    directions_map = dict(zip(objective_metrics, study_directions))

    combined_df = pd.concat([metric_df, param_df], axis=1)

    palette_name = cfg.get("palette", "Plotly")
    colors = getattr(px.colors.qualitative, palette_name, px.colors.qualitative.Vivid)

    n_written = 0
    for metric_x, metric_y in itertools.combinations(objective_metrics, 2):
        dir_x = directions_map[metric_x]
        dir_y = directions_map[metric_y]

        for cat_param in categorical_params_list:
            if cat_param not in combined_df.columns:
                logging.warning(
                    "Categorical parameter '%s' not found in data; skipping.", cat_param
                )
                continue

            fig = go.Figure()
            # `== NaN` never matches, so a NaN "category" would select nothing
            # while still consuming a palette colour; drop it up front.
            categories = combined_df[cat_param].dropna().unique()

            for i, category in enumerate(categories):
                df_cat = combined_df[combined_df[cat_param] == category]
                # Trials lacking either objective value cannot be ranked, and
                # NaN comparisons could otherwise leak them into the frontier.
                df_cat = df_cat.dropna(subset=[metric_x, metric_y])
                if df_cat.empty:
                    continue

                color = colors[i % len(colors)]
                pareto_df = find_pareto_frontier_fast(
                    df_cat, metric_x, metric_y, dir_x, dir_y
                )
                # The frontier is drawn as a connected line, so order it by x.
                # Defensive: a no-op if the helper already returns sorted rows.
                pareto_df = pareto_df.sort_values(metric_x, kind="mergesort")

                _add_category_traces(
                    fig, df_cat, pareto_df, category, color, metric_x, metric_y
                )

            fig.update_layout(
                title=f"Pareto: {metric_x} vs {metric_y}<br><sub>Grouped by {cat_param}</sub>",
                xaxis_title=f"{metric_x} ({dir_x.name})",
                yaxis_title=f"{metric_y} ({dir_y.name})",
                template=cfg['template'],
                legend_title=cat_param,
                hovermode='closest',
            )

            # Path separators in metric/param names would otherwise create
            # unintended subdirectories (or fail) on write.
            fname = (
                f"pareto_{metric_x}_vs_{metric_y}_by_{cat_param}.html"
                .replace('/', '_')
                .replace('\\', '_')
            )
            fig.write_html(plot_output_path / fname)
            n_written += 1

    logging.info(
        "Wrote %d categorical Pareto plot(s) to %s", n_written, plot_output_path
    )