import typing as t
import itertools
import logging
import os
import pathlib
import numpy as np
import pandas as pd
from scipy.interpolate import griddata
import optuna
from optuna.study import StudyDirection
from omegaconf import DictConfig
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from ..integrations import optuna as ml_helpers
[docs]
def _direction_is_minimize(direction: t.Any) -> bool:
    """Return True if *direction* means "minimize".

    Accepts enum members such as optuna's ``StudyDirection`` (matched by their
    ``name``) as well as plain strings like ``"minimize"`` (case-insensitive).
    Anything else, including ``StudyDirection.NOT_SET``, counts as "maximize".
    """
    name = getattr(direction, "name", direction)
    return isinstance(name, str) and name.lower() == "minimize"


def find_pareto_frontier_fast(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    x_dir: t.Union["StudyDirection", str],
    y_dir: t.Union["StudyDirection", str],
) -> pd.DataFrame:
    """Return the 2-D Pareto frontier of ``df`` over ``x_col`` and ``y_col``.

    Rows are sorted so the best ``x`` comes first (ties broken by best ``y``).
    A row is on the frontier when its ``y`` is strictly better than the ``y`` of
    every row sorted before it. The work is one sort plus a vectorised running
    min/max, i.e. O(n log n).

    Args:
        df: Table of trials. Rows with NaN in ``x_col`` or ``y_col`` cannot be
            compared and are ignored.
        x_col: Column used for the x objective.
        y_col: Column used for the y objective.
        x_dir: Direction for ``x_col``: an optuna ``StudyDirection`` or the
            string ``"minimize"``/``"maximize"`` (case-insensitive). Anything
            that is not "minimize" is treated as "maximize".
        y_dir: Direction for ``y_col``, same conventions as ``x_dir``.

    Returns:
        The frontier rows of ``df`` (all columns), ordered from best to worst
        ``x``. Empty, with ``df``'s columns, when no comparable rows exist.
    """
    comparable = df.dropna(subset=[x_col, y_col])
    if comparable.empty:
        return df.iloc[:0]

    minimize_x = _direction_is_minimize(x_dir)
    minimize_y = _direction_is_minimize(y_dir)
    ordered = comparable.sort_values(
        by=[x_col, y_col],
        ascending=[minimize_x, minimize_y],
    )

    y_values = ordered[y_col].to_numpy(dtype=float)
    if minimize_y:
        best_so_far = np.minimum.accumulate(y_values)
        improves = y_values[1:] < best_so_far[:-1]
    else:
        best_so_far = np.maximum.accumulate(y_values)
        improves = y_values[1:] > best_so_far[:-1]
    # The first row has the best x, so nothing can dominate it.
    on_frontier = np.concatenate(([True], improves))
    # Select by boolean position, not by label, so duplicate index labels
    # cannot multiply rows in the result.
    return ordered[on_frontier]
[docs]
def apply_plot_filters(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    filter_cfg: t.List[t.Dict[str, t.Any]],
) -> None:
    """Apply a list of metric-based filters to ``metric_df`` and ``param_df`` in place.

    Each entry of ``filter_cfg`` is a mapping with the keys:

    * ``metric``: metric name. It is passed through ``ml_helpers.sanitize_name``
      before being looked up in ``metric_df``.
    * ``direction``: ``"maximize"`` keeps values ``>= threshold``;
      ``"minimize"`` keeps values ``<= threshold`` (case-insensitive).
    * ``threshold``: the cut-off value.

    A trial is kept only if it passes every filter. A NaN metric value never
    passes a comparison, so such trials are removed.

    Args:
        metric_df: Metric table, one row per trial. Modified in place.
        param_df: Parameter table. Rows are dropped by the same index labels
            as ``metric_df``. Modified in place.
        filter_cfg: Filter specifications. Empty or ``None`` means no-op.

    Raises:
        ValueError: If a filter names a metric that is not in ``metric_df``, or
            uses a direction other than maximize/minimize.
        RuntimeError: If no trial survives the filters. Both frames are left
            unmodified in that case.
    """
    if not filter_cfg:
        return

    missing = [
        f["metric"]
        for f in filter_cfg
        if ml_helpers.sanitize_name(f["metric"]) not in metric_df.columns
    ]
    if missing:
        raise ValueError(f"Metrics specified in `plots.filter` not present: {missing}")

    mask = pd.Series(True, index=metric_df.index)
    for f in filter_cfg:
        metric = ml_helpers.sanitize_name(f["metric"])
        direction = f["direction"].lower()
        threshold = f["threshold"]
        if direction not in {"maximize", "minimize"}:
            raise ValueError(f"Invalid direction `{direction}` for metric `{metric}`.")
        if direction == "maximize":
            mask &= metric_df[metric] >= threshold
        else:
            mask &= metric_df[metric] <= threshold
        logging.info(f"Filter applied: metric=`{metric}`, direction=`{direction}`, threshold={threshold}. Remaining trials: {mask.sum()}")

    # Fail before mutating, so callers never see half-filtered frames.
    if not mask.any():
        raise RuntimeError("All trials were filtered out by `plots.filter`.")

    # Drop by label rather than by position, so param_df rows are matched to
    # metric_df rows by trial index even if the two frames are ordered
    # differently.
    dropped = metric_df.index[~mask.to_numpy()]
    metric_df.drop(index=dropped, inplace=True)
    param_df.drop(index=dropped, inplace=True, errors="ignore")
[docs]
def plot_enhanced_scatter(
    metric_df: pd.DataFrame, param_df: pd.DataFrame, objective_metrics: t.List[str],
    output_path: t.Union[str, pathlib.Path], continuous_params_map: t.Dict[str, str], plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate scatter plots of one continuous parameter against two objectives.

    For every continuous parameter and every unordered pair
    ``(metric_y, metric_color)`` of objectives, one plot is drawn: the parameter
    on x, ``metric_y`` on y, and points colored by ``metric_color``. Each plot is
    written to
    ``<output_path>/continuous_scatter/<param_name>/<metric_y>_colored_by_<metric_color>.html``
    (``/`` in the file name replaced by ``_``).

    Args:
        metric_df: Metric table, one row per trial.
        param_df: Parameter table, row-aligned with ``metric_df``.
        objective_metrics: Objective column names in ``metric_df``.
        output_path: Root directory for plot output.
        continuous_params_map: Maps parameter name to its scale. ``"log"``
            selects a log x-axis; any other value selects a linear one.
        plot_cfg: Plot options. Only ``template`` (default ``"plotly_dark"``)
            is read here.
    """
    logging.info("--- Starting Enhanced Scatter Plot Generation ---")
    plot_output_path = os.path.join(output_path, "continuous_scatter")
    template = plot_cfg.get("template", "plotly_dark")
    for param_name, param_type in continuous_params_map.items():
        if param_name not in param_df.columns:
            logging.warning(f"Parameter '{param_name}' not found in trial data. Skipping scatter plots.")
            continue
        is_log = param_type == 'log'
        param_output_path = os.path.join(plot_output_path, param_name)
        for metric_y, metric_color in itertools.combinations(objective_metrics, 2):
            fig = go.Figure(data=go.Scatter(
                x=param_df[param_name], y=metric_df[metric_y], mode='markers',
                marker=dict(color=metric_df[metric_color], colorscale='Viridis', showscale=True,
                            colorbar_title=metric_color.replace('/', '_')),
                customdata=param_df.index,
                hovertemplate=f'<b>Trial</b>: %{{customdata}}<br><b>{param_name}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>'
            ))
            fig.update_layout(
                title=f'{param_name} vs. {metric_y}<br><sub>Colored by {metric_color}</sub>',
                xaxis_title=f"{param_name} ({'log scale' if is_log else 'linear scale'})",
                yaxis_title=metric_y,
                xaxis_type='log' if is_log else 'linear',
                template=template,
            )
            # The color metric must be part of the file name. With 3+ objectives
            # the same metric_y is paired with several color metrics, and a name
            # based on metric_y alone made later plots overwrite earlier ones.
            fname = f"{metric_y}_colored_by_{metric_color}.html".replace('/', '_')
            os.makedirs(param_output_path, exist_ok=True)
            fig.write_html(os.path.join(param_output_path, fname))
[docs]
def plot_pareto_continuous_color(
    metric_df: pd.DataFrame, param_df: pd.DataFrame, objective_metrics: t.List[str],
    study_directions: t.List[StudyDirection], output_path: t.Union[str, pathlib.Path],
    continuous_params_map: t.Dict[str, str], plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate 2-objective Pareto fronts whose points are colored by a continuous parameter.

    For every unordered pair of objectives ``(metric_x, metric_y)``:

    * the Pareto frontier is computed once with ``find_pareto_frontier_fast``;
    * one figure is drawn per continuous parameter, showing all trials in grey
      and the frontier colored by that parameter.

    Output goes to
    ``<output_path>/continuous_pareto/<param_name>/<metric_x>_vs_<metric_y>.html``.

    Args:
        metric_df: Metric table, one row per trial.
        param_df: Parameter table, joined column-wise with ``metric_df`` on the index.
        objective_metrics: Objective column names, parallel to ``study_directions``.
        study_directions: Direction of each objective. ``.name`` is used in the
            axis titles.
        output_path: Root directory for plot output.
        continuous_params_map: Continuous parameter names (keys). The scale
            values are not used by this plot.
        plot_cfg: Plot options. Only ``template`` (default ``"plotly_dark"``)
            is read here.
    """
    if not continuous_params_map:
        return
    logging.info("--- Starting Continuous Pareto Plot Generation ---")
    plot_output_path = os.path.join(output_path, "continuous_pareto")
    template = plot_cfg.get("template", "plotly_dark")
    combined_df = pd.concat([metric_df, param_df], axis=1)
    directions_map = dict(zip(objective_metrics, study_directions))

    missing_params = [p for p in continuous_params_map if p not in combined_df.columns]
    if missing_params:
        logging.warning(f"Continuous parameter(s) not found in trial data, skipping: {missing_params}")
    param_names = [p for p in continuous_params_map if p in combined_df.columns]

    for metric_x, metric_y in itertools.combinations(objective_metrics, 2):
        dir_x, dir_y = directions_map[metric_x], directions_map[metric_y]
        # The frontier depends only on the two objectives, so compute it once per
        # objective pair instead of once per parameter.
        pareto_df = find_pareto_frontier_fast(combined_df, metric_x, metric_y, dir_x, dir_y)
        fname = f"{metric_x}_vs_{metric_y}.html".replace('/', '_')
        for param_name in param_names:
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=combined_df[metric_x], y=combined_df[metric_y], mode='markers',
                marker=dict(color='grey', opacity=0.2, size=4), name='All Trials',
                customdata=combined_df.index,
                hovertemplate=f'<b>Trial</b>: %{{customdata}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>'))
            fig.add_trace(go.Scatter(
                x=pareto_df[metric_x], y=pareto_df[metric_y], mode='lines+markers',
                marker=dict(color=pareto_df[param_name], colorscale='Cividis', showscale=True,
                            colorbar_title=param_name.replace('/', '_'), size=10),
                line=dict(width=3), name='Pareto Frontier', customdata=pareto_df.index,
                hovertemplate=f'<b>Trial</b>: %{{customdata}}<br><b>{param_name}</b>: %{{marker.color}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>'))
            fig.update_layout(
                title=f'Pareto: {metric_x} vs. {metric_y}<br><sub>Colored by {param_name}</sub>',
                xaxis_title=f"{metric_x} ({dir_x.name})",
                yaxis_title=f"{metric_y} ({dir_y.name})",
                template=template,
            )
            param_output_path = os.path.join(plot_output_path, param_name)
            os.makedirs(param_output_path, exist_ok=True)
            fig.write_html(os.path.join(param_output_path, fname))
[docs]
def plot_contour(
    metric_df: pd.DataFrame, param_df: pd.DataFrame, objective_metrics: t.List[str],
    output_path: t.Union[str, pathlib.Path], continuous_params_map: t.Dict[str, str], plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate contour plots of two continuous parameters against one objective.

    For every pair of continuous parameters and every objective:

    * the objective is interpolated onto a 100x100 grid with
      ``scipy.interpolate.griddata(method='cubic')``. Log-scaled parameters are
      interpolated in log10 space;
    * the sampled trials are overlaid as white markers.

    Trials whose parameter or objective value is non-finite (NaN, e.g. an
    unsampled conditional parameter, or a non-positive value on a log axis) are
    left out of the interpolation. A plot is skipped when fewer than the
    minimum number of trials remain.

    Output goes to
    ``<output_path>/continuous_contour/<param_x>_vs_<param_y>/<metric_z>.html``.
    Failures for a single plot are logged as warnings and do not stop the others.

    Args:
        metric_df: Metric table. Rows are matched positionally with ``param_df``.
        param_df: Parameter table.
        objective_metrics: Objective column names in ``metric_df``.
        output_path: Root directory for plot output.
        continuous_params_map: Maps parameter name to its scale. ``"log"``
            selects log10 interpolation and a log axis.
        plot_cfg: Plot options. Reads ``template`` (default ``"plotly_dark"``)
            and ``contour_min_trials`` (default 15).
    """
    logging.info("--- Starting Contour Plot Generation ---")
    min_trials = plot_cfg.get("contour_min_trials", 15)
    if len(metric_df) < min_trials:
        logging.warning(f"Skipping contour plots: not enough trials (< {min_trials}) for reliable interpolation.")
        return
    if len(continuous_params_map) < 2:
        logging.warning("Skipping contour plots: need at least 2 continuous parameters.")
        return
    plot_output_path = os.path.join(output_path, "continuous_contour")
    template = plot_cfg.get("template", "plotly_dark")
    for param_x, param_y in itertools.combinations(continuous_params_map.keys(), 2):
        is_x_log = continuous_params_map[param_x] == 'log'
        is_y_log = continuous_params_map[param_y] == 'log'
        for metric_z in objective_metrics:
            try:
                grid_points = param_df[[param_x, param_y]].to_numpy(dtype=float, na_value=np.nan)
                values = metric_df[metric_z].to_numpy(dtype=float, na_value=np.nan)
                # Interpolate log-scaled params in log space so the grid is evenly
                # spaced on the displayed (log) axis. Non-positive values become
                # -inf/NaN here and are filtered out just below.
                with np.errstate(divide="ignore", invalid="ignore"):
                    if is_x_log:
                        grid_points[:, 0] = np.log10(grid_points[:, 0])
                    if is_y_log:
                        grid_points[:, 1] = np.log10(grid_points[:, 1])
                # NaN/inf inputs break both the grid bounds and the triangulation,
                # so keep only fully finite samples.
                finite = np.isfinite(grid_points).all(axis=1) & np.isfinite(values)
                n_finite = int(finite.sum())
                if n_finite < min_trials:
                    logging.warning(
                        f"Skipping contour for ({param_x}, {param_y}) vs {metric_z}: "
                        f"only {n_finite} trial(s) with finite values (< {min_trials})."
                    )
                    continue
                grid_points, values = grid_points[finite], values[finite]

                gx, gy = grid_points[:, 0], grid_points[:, 1]
                grid_x, grid_y = np.mgrid[gx.min():gx.max():100j, gy.min():gy.max():100j]
                grid_z = griddata(grid_points, values, (grid_x, grid_y), method='cubic')
                x_coords = 10 ** grid_x[:, 0] if is_x_log else grid_x[:, 0]
                y_coords = 10 ** grid_y[0, :] if is_y_log else grid_y[0, :]

                # grid_z is indexed [x, y] (mgrid 'ij' layout); Contour expects
                # z[y][x], hence the transpose.
                fig = go.Figure(data=[go.Contour(
                    x=x_coords, y=y_coords, z=grid_z.T, colorscale='Turbo',
                    contours=dict(coloring='heatmap', showlabels=True, labelfont=dict(size=10, color='white')))])
                fig.add_trace(go.Scatter(
                    x=param_df[param_x], y=param_df[param_y], mode='markers',
                    marker=dict(color='rgba(255,255,255,0.5)', size=4), showlegend=False,
                    hovertemplate=f'<b>{param_x}</b>: %{{x}}<br><b>{param_y}</b>: %{{y}}<extra></extra>'))
                fig.update_layout(
                    title=f'Contour: {param_x} vs. {param_y}<br><sub>on {metric_z}</sub>',
                    xaxis_title=f"{param_x} ({'log scale' if is_x_log else 'linear'})",
                    yaxis_title=f"{param_y} ({'log scale' if is_y_log else 'linear'})",
                    xaxis_type='log' if is_x_log else 'linear',
                    yaxis_type='log' if is_y_log else 'linear',
                    template=template,
                )
                param_name = f"{param_x}_vs_{param_y}"
                fname = f"{metric_z}.html".replace('/', '_')
                param_output_path = os.path.join(plot_output_path, param_name)
                os.makedirs(param_output_path, exist_ok=True)
                fig.write_html(os.path.join(param_output_path, fname))
            except Exception as e:
                # Best-effort: one degenerate pair (e.g. collinear samples) must
                # not abort the remaining contour plots.
                logging.warning(f"Could not generate contour for ({param_x}, {param_y}) vs {metric_z}: {e}")
[docs]
def plot_pareto_frontiers_by_category(
    metric_df: pd.DataFrame, param_df: pd.DataFrame, objective_metrics: t.List[str],
    study_directions: t.List[StudyDirection], output_path: t.Union[str, pathlib.Path],
    categorical_params_list: list, plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate Pareto frontier plots per objective pair, grouped by each categorical hyperparameter.

    For every unordered objective pair and every categorical parameter, one
    figure is drawn. Each category value contributes:

    * a faint scatter of all its trials;
    * its own Pareto frontier, computed with ``find_pareto_frontier_fast``.

    Trials where the parameter is missing (NaN) form their own group instead of
    being dropped. Output goes to
    ``<output_path>/pareto_front/pareto_<metric_x>_vs_<metric_y>_by_<param>.html``.

    Args:
        metric_df: Metric table, one row per trial.
        param_df: Parameter table, joined column-wise with ``metric_df`` on the index.
        objective_metrics: Objective column names, parallel to ``study_directions``.
        study_directions: Direction of each objective. ``.name`` is used in the
            axis titles.
        output_path: Root directory for plot output.
        categorical_params_list: Categorical parameter column names to group by.
        plot_cfg: Plot options. Reads ``template`` (default ``"plotly_dark"``)
            and ``palette``, a ``plotly.express.colors.qualitative`` name
            (default ``"Plotly"``; unknown names fall back to ``Vivid``).
    """
    logging.info("--- Starting Pareto Frontier Plot Generation ---")
    if not categorical_params_list:
        logging.warning("No categorical parameters found to group Pareto frontiers by.")
        return
    plot_output_path = os.path.join(output_path, "pareto_front")
    os.makedirs(plot_output_path, exist_ok=True)
    combined_df = pd.concat([metric_df, param_df], axis=1)
    directions_map = dict(zip(objective_metrics, study_directions))
    template = plot_cfg.get("template", "plotly_dark")
    palette_name = plot_cfg.get("palette", "Plotly")
    colors = getattr(px.colors.qualitative, palette_name, px.colors.qualitative.Vivid)
    for metric_x, metric_y in itertools.combinations(objective_metrics, 2):
        dir_x, dir_y = directions_map[metric_x], directions_map[metric_y]
        hover = f'<b>Trial</b>: %{{customdata}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>'
        for cat_param in categorical_params_list:
            if cat_param not in combined_df.columns:
                logging.warning(f"Parameter '{cat_param}' not found in trial data. Skipping Pareto plot.")
                continue
            logging.info(f"Plotting Pareto for ({metric_x} vs {metric_y}) by '{cat_param}'")
            fig = go.Figure()
            column = combined_df[cat_param]
            for i, category in enumerate(column.unique()):
                # `column == NaN` is never True, so the "parameter not set" group
                # is matched with isna(). Otherwise those trials would silently
                # disappear from the plot.
                in_category = column.isna() if pd.isna(category) else column == category
                df_cat = combined_df[in_category]
                if df_cat.empty:
                    continue
                color = colors[i % len(colors)]
                fig.add_trace(go.Scatter(
                    x=df_cat[metric_x], y=df_cat[metric_y], mode='markers',
                    marker=dict(color=color, opacity=0.3, size=5), name=f'{category} (all trials)',
                    legendgroup=str(category), showlegend=False, customdata=df_cat.index,
                    hovertemplate=hover,
                ))
                pareto_df = find_pareto_frontier_fast(
                    df=df_cat, x_col=metric_x, y_col=metric_y,
                    x_dir=dir_x, y_dir=dir_y,
                )
                fig.add_trace(go.Scatter(
                    x=pareto_df[metric_x], y=pareto_df[metric_y], mode='lines+markers',
                    line=dict(color=color, width=2), marker=dict(color=color, size=8, symbol='star'),
                    name=f'{category} (Pareto Frontier)', legendgroup=str(category), customdata=pareto_df.index,
                    hovertemplate=hover,
                ))
            fig.update_layout(
                title=f"Pareto Frontier: {metric_x} vs {metric_y}<br><sub>Grouped by {cat_param}</sub>",
                xaxis_title=f"{metric_x} ({dir_x.name})",
                yaxis_title=f"{metric_y} ({dir_y.name})",
                template=template, legend_title=cat_param, hovermode='closest',
            )
            fname = f"pareto_{metric_x}_vs_{metric_y}_by_{cat_param}.html".replace('/', '_')
            fig.write_html(os.path.join(plot_output_path, fname))
    logging.info("--- Pareto Frontier Plot Generation Complete ---")
[docs]
def _range_menu_buttons(y_data: pd.Series, axis_key: str) -> t.List[t.Dict[str, t.Any]]:
    """Build y-axis zoom presets (full, whisker, 5th-95th percentile, IQR) for one subplot axis.

    Args:
        y_data: Non-empty, NaN-free metric values shown on that axis.
        axis_key: Plotly layout key of the axis (``"yaxis"``, ``"yaxis2"``, ...).
    """
    q1, q3 = y_data.quantile(0.25), y_data.quantile(0.75)
    p05, p95 = y_data.quantile(0.05), y_data.quantile(0.95)
    iqr = q3 - q1
    # 1.5 * IQR whiskers, clipped to the observed data range.
    lower_whisker = max(y_data.min(), q1 - 1.5 * iqr)
    upper_whisker = min(y_data.max(), q3 + 1.5 * iqr)
    return [
        dict(label="Full Range", method="relayout", args=[{f"{axis_key}.autorange": True}]),
        dict(label="Whisker Range", method="relayout", args=[{f"{axis_key}.range": [lower_whisker, upper_whisker]}]),
        dict(label="5th-95th Percentile", method="relayout", args=[{f"{axis_key}.range": [p05, p95]}]),
        dict(label="Interquartile (IQR)", method="relayout", args=[{f"{axis_key}.range": [q1, q3]}]),
    ]


def plot_categorical_distributions(
    metric_df: pd.DataFrame, param_df: pd.DataFrame, objective_metrics: t.List[str],
    output_path: t.Union[str, pathlib.Path], categorical_params_list: list, plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate and save violin-plot distributions of every objective per categorical hyperparameter.

    One HTML file per categorical parameter is written to
    ``<output_path>/distribution/<param_name>.html`` (``/`` in the name replaced
    by ``_``). Each file contains one subplot per objective metric, plus:

    * a dropdown that zooms that subplot's y-axis;
    * optional horizontal reference lines.

    Args:
        metric_df: Metric table, one row per trial.
        param_df: Parameter table, joined column-wise with ``metric_df`` on the index.
        objective_metrics: Objective column names. One subplot each.
        output_path: Root directory for plot output.
        categorical_params_list: Categorical parameter column names to plot.
        plot_cfg: Plot options. Reads ``height_per_metric`` (400), ``template``
            ("plotly_dark"), ``vertical_spacing`` (0.05), ``show_points``
            ("all"), ``point_opacity`` (0.5), ``point_jitter`` (0.1),
            ``color_by_category`` (True), ``palette`` ("Plotly"),
            ``show_legend`` (True), and ``reference_lines``: a mapping from
            metric name to ``{value, style, text, position}``.
    """
    logging.info("--- Starting Categorical Distribution Plot Generation ---")
    plot_output_path = os.path.join(output_path, "distribution")
    logging.info(f"Plots will be saved to: {os.path.abspath(plot_output_path)}")
    combined_df = pd.concat([metric_df, param_df], axis=1)
    if not categorical_params_list:
        logging.warning("No categorical parameters found in 'optuna.parameters' config to plot.")
        return
    num_metrics = len(objective_metrics)
    if num_metrics == 0:
        # make_subplots cannot build a figure with zero rows.
        logging.warning("No objective metrics available; skipping categorical distribution plots.")
        return
    logging.info(f"Found categorical parameters to plot: {categorical_params_list}")
    height_per_metric = plot_cfg.get("height_per_metric", 400)
    template = plot_cfg.get("template", "plotly_dark")
    vertical_spacing = plot_cfg.get("vertical_spacing", 0.05)
    show_points = plot_cfg.get("show_points", "all")
    point_opacity = plot_cfg.get("point_opacity", 0.5)
    point_jitter = plot_cfg.get("point_jitter", 0.1)
    color_by_category = plot_cfg.get("color_by_category", True)
    palette_name = plot_cfg.get("palette", "Plotly")
    colors = getattr(px.colors.qualitative, palette_name, px.colors.qualitative.Vivid)
    show_legend = plot_cfg.get("show_legend", True)
    ref_lines_cfg = plot_cfg.get("reference_lines", {})

    # The output directory must exist before any file is written into it.
    os.makedirs(plot_output_path, exist_ok=True)
    # Fraction of the figure height taken by each subplot. Used to pin each
    # subplot's dropdown to the top of that subplot.
    subplot_height = (1.0 - vertical_spacing * (num_metrics - 1)) / num_metrics

    for param_name in categorical_params_list:
        if param_name not in combined_df.columns:
            logging.warning(f"Parameter '{param_name}' from config not found in trial data. Skipping.")
            continue
        logging.info(f"Generating plot for parameter: '{param_name}'...")
        fig = make_subplots(
            rows=num_metrics, cols=1,
            subplot_titles=objective_metrics, vertical_spacing=vertical_spacing,
        )

        # The per-category row subsets do not depend on the metric, so build
        # them once per parameter. The enumeration index fixes each category's color.
        category_groups = []
        if color_by_category:
            column = combined_df[param_name]
            for j, category in enumerate(column.unique()):
                df_cat = combined_df[column == category]
                if not df_cat.empty:
                    category_groups.append((j, category, df_cat))

        updatemenus = []
        for i, metric_name in enumerate(objective_metrics):
            row = i + 1
            if color_by_category:
                for j, category, df_cat in category_groups:
                    fig.add_trace(go.Violin(
                        x=df_cat[param_name], y=df_cat[metric_name], name=str(category),
                        # One legend entry per category (from the first subplot).
                        # The shared legendgroup toggles the category in every subplot.
                        legendgroup=str(category), showlegend=(i == 0),
                        line_color=colors[j % len(colors)], points=show_points,
                        jitter=point_jitter, pointpos=0, box_visible=True,
                        meanline_visible=True, marker=dict(opacity=point_opacity)), row=row, col=1)
            else:
                fig.add_trace(go.Violin(
                    x=combined_df[param_name], y=combined_df[metric_name], name=metric_name,
                    points=show_points, jitter=point_jitter, pointpos=0,
                    box_visible=True, meanline_visible=True,
                    marker=dict(opacity=point_opacity)), row=row, col=1)
            if metric_name in ref_lines_cfg:
                line_cfg = ref_lines_cfg[metric_name]
                if "value" in line_cfg:
                    fig.add_hline(y=line_cfg["value"],
                                  line_dash=line_cfg.get("style", "dot"),
                                  annotation_text=line_cfg.get("text", f'Target: {line_cfg["value"]}'),
                                  annotation_position=line_cfg.get("position", "bottom right"),
                                  row=row, col=1)
            y_data = combined_df[metric_name].dropna()
            if not y_data.empty:
                # Plotly's layout names the first y axis "yaxis" and later ones
                # "yaxis2", "yaxis3", ...
                axis_key = "yaxis" if i == 0 else f"yaxis{row}"
                updatemenus.append(dict(
                    buttons=_range_menu_buttons(y_data, axis_key), direction="down",
                    pad={"r": 10, "t": 10}, showactive=True, x=0.01, xanchor="left",
                    y=1.0 - i * (subplot_height + vertical_spacing), yanchor="top"))
        fig.update_layout(
            title_text=f"Distribution of Objective Metrics by '{param_name}'",
            height=height_per_metric * num_metrics,
            showlegend=(color_by_category and show_legend),
            template=template,
            updatemenus=updatemenus,
            margin=dict(t=100, l=80))
        output_file = os.path.join(plot_output_path, f"{param_name}.html".replace('/', '_'))
        try:
            fig.write_html(output_file)
            logging.info(f"Successfully saved plot to {output_file}")
        except Exception as e:
            # Best-effort: a failure for one parameter must not abort the others.
            logging.error(f"Failed to save plot for {param_name}: {e}")
[docs]
def generate_plots(
    study: optuna.study.Study,
    plot_cfg: DictConfig,
    exp_params: DictConfig,
    output_path: t.Union[str, pathlib.Path],
):
    """
    Run the complete plotting pipeline for an Optuna study.

    Steps:
      1. Turn the study's trials into a metric table and a parameter table.
      2. Split the parameters into continuous and categorical ones.
      3. Sanitize the metric names.
      4. Optionally apply ``plot_cfg.filter`` (in place on both tables).
      5. Render, in order: categorical distributions, Pareto frontiers by
         category, continuous scatter plots, continuous-colored Pareto fronts,
         and contour plots.

    This mirrors the PLOT mode of the run function.
    """
    logging.info("Starting PLOT mode.")
    raw_metric_names = study.metric_names
    directions = study.directions
    all_trials = study.get_trials(deepcopy=False)
    metric_df, param_df = ml_helpers.trials_to_dataframes(all_trials, raw_metric_names)
    continuous_params_map, categorical_params_list = ml_helpers.prepare_parameter_data(exp_params, param_df)
    metric_df, objective_metrics = ml_helpers.sanitize_metric_dataframe(metric_df, raw_metric_names)

    filters = plot_cfg.get("filter", [])
    if not filters:
        logging.info("No plot filters defined - using all trials.")
    else:
        apply_plot_filters(metric_df, param_df, filters)
        logging.info(f"After filtering, {len(metric_df)} trial(s) remain for plotting.")

    # Arguments shared by every plotting routine below.
    common = dict(
        metric_df=metric_df,
        param_df=param_df,
        objective_metrics=objective_metrics,
        output_path=output_path,
    )
    plot_categorical_distributions(
        **common, categorical_params_list=categorical_params_list, plot_cfg=plot_cfg,
    )
    plot_pareto_frontiers_by_category(
        **common, study_directions=directions,
        categorical_params_list=categorical_params_list, plot_cfg=plot_cfg,
    )
    plot_enhanced_scatter(
        **common, continuous_params_map=continuous_params_map, plot_cfg=plot_cfg,
    )
    plot_pareto_continuous_color(
        **common, study_directions=directions,
        continuous_params_map=continuous_params_map, plot_cfg=plot_cfg,
    )
    plot_contour(
        **common, continuous_params_map=continuous_params_map, plot_cfg=plot_cfg,
    )
    logging.info("Plot generation complete.")