"""
Plotting function for Pareto frontiers grouped by a categorical hyperparameter.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import itertools
import logging
import pathlib
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from omegaconf import DictConfig
from optuna.study import StudyDirection
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ..optimization.pareto import find_pareto_frontier_fast
from .utils import DEFAULT_PLOT_CONFIG
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not part of the program)
def plot_pareto_frontiers_by_category(
    #? --- Data Inputs ---
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    #? --- Plotting Configuration ---
    objective_metrics: t.List[str],
    study_directions: t.List[StudyDirection],
    output_path: pathlib.Path,
    categorical_params_list: list,
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """
    Generates Pareto plots for objectives, grouped by a categorical parameter.

    For every pair of objective metrics and every categorical parameter, one
    interactive HTML figure is written to
    ``output_path / "pareto_front_categorical"``. Each figure contains, per
    category value, a faint scatter of all rows of that category and a
    highlighted line/star trace of the rows returned by
    ``find_pareto_frontier_fast`` for that category.

    Args:
        metric_df: Frame holding the objective metric columns.
        param_df: Frame holding the hyperparameter columns. It is concatenated
            column-wise with ``metric_df``, so both are aligned on their index.
            The index value is shown as "Trial" in the hover text.
        objective_metrics: Names of the metric columns to plot pairwise.
        study_directions: Optimization direction of each metric, paired
            positionally with ``objective_metrics``.
        output_path: Base directory; the plot sub-directory is created in it.
        categorical_params_list: Column names to group by. Names that are not
            columns of the combined frame are skipped with a warning.
        plot_cfg: Overrides for ``DEFAULT_PLOT_CONFIG``. The keys read here
            are ``"palette"`` (attribute name in ``plotly.express.colors.qualitative``,
            default ``"Plotly"``, falling back to ``Vivid`` if unknown) and
            ``"template"`` (required after merging with the defaults).

    Raises:
        KeyError: If a metric has no matching entry in ``study_directions``,
            a metric column is absent from the data, or the merged
            configuration has no ``"template"`` key.
    """
    logging.info("--- Starting Categorical Pareto Frontier Plot Generation ---")
    plot_output_path = output_path / "pareto_front_categorical"
    plot_output_path.mkdir(parents=True, exist_ok=True)
    # User-supplied settings override the defaults key by key.
    cfg = {**DEFAULT_PLOT_CONFIG, **plot_cfg}

    if not categorical_params_list:
        logging.warning("No categorical parameters found for Pareto plotting.")
        return
    if len(objective_metrics) < 2:
        # itertools.combinations(..., 2) would be empty; say so instead of
        # finishing silently without producing any plot.
        logging.warning(
            "At least two objective metrics are required for Pareto plotting; got %d.",
            len(objective_metrics),
        )
        return

    combined_df = pd.concat([metric_df, param_df], axis=1)
    directions_map = dict(zip(objective_metrics, study_directions))
    palette_name = cfg.get("palette", "Plotly")
    colors = getattr(px.colors.qualitative, palette_name, px.colors.qualitative.Vivid)

    # Drop unknown grouping columns once, up front, so that a single bad name
    # does not abort all remaining plots with a KeyError.
    plottable_params = []
    for cat_param in categorical_params_list:
        if cat_param in combined_df.columns:
            plottable_params.append(cat_param)
        else:
            logging.warning(
                "Categorical parameter '%s' not found in the data; skipping.",
                cat_param,
            )

    for metric_x, metric_y in itertools.combinations(objective_metrics, 2):
        # Everything below depends only on the metric pair, so compute it once
        # instead of once per category.
        direction_x = directions_map[metric_x]
        direction_y = directions_map[metric_y]
        hover_template = f'<b>Trial</b>: %{{customdata}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>'

        for cat_param in plottable_params:
            fig = go.Figure()
            categories = combined_df[cat_param].unique()

            for i, category in enumerate(categories):
                df_cat = combined_df[combined_df[cat_param] == category]
                # An equality mask never matches NaN, so a NaN "category"
                # yields an empty frame and is skipped here.
                if df_cat.empty:
                    continue

                # Cycle through the palette when there are more categories
                # than colors.
                color = colors[i % len(colors)]
                pareto_df = find_pareto_frontier_fast(
                    df_cat, metric_x, metric_y, direction_x, direction_y
                )

                # Background cloud: every row of this category, hidden from
                # the legend but toggled together with its Pareto trace via
                # the shared legendgroup.
                fig.add_trace(go.Scatter(
                    x=df_cat[metric_x], y=df_cat[metric_y], mode='markers',
                    marker=dict(color=color, opacity=0.3, size=5), name=f'{category} (all)',
                    legendgroup=str(category), showlegend=False, customdata=df_cat.index,
                    hovertemplate=hover_template,
                ))
                # Foreground: the rows returned by the Pareto helper.
                fig.add_trace(go.Scatter(
                    x=pareto_df[metric_x], y=pareto_df[metric_y], mode='lines+markers',
                    line=dict(color=color, width=2), marker=dict(color=color, size=8, symbol='star'),
                    name=f'{category} (Pareto)', legendgroup=str(category), customdata=pareto_df.index,
                    hovertemplate=hover_template,
                ))

            fig.update_layout(
                title=f"Pareto: {metric_x} vs {metric_y}<br><sub>Grouped by {cat_param}</sub>",
                xaxis_title=f"{metric_x} ({direction_x.name})",
                yaxis_title=f"{metric_y} ({direction_y.name})",
                template=cfg['template'], legend_title=cat_param, hovermode='closest'
            )
            # Metric or parameter names may contain '/', which would otherwise
            # be interpreted as a sub-directory in the output path.
            fname = f"pareto_{metric_x}_vs_{metric_y}_by_{cat_param}.html".replace('/', '_')
            fig.write_html(plot_output_path / fname)