# Source code for gunz_ml.plots.optuna

import typing as t
import itertools
import logging
import os
import pathlib

import numpy as np
import pandas as pd
from scipy.interpolate import griddata

import optuna
from optuna.study import StudyDirection
from omegaconf import DictConfig

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

from ..integrations import optuna as ml_helpers

def _is_minimize(direction: t.Any) -> bool:
    """Return True when *direction* denotes minimization.

    Accepts either an enum member exposing ``.name`` (e.g. ``StudyDirection``)
    or a plain string; the comparison is case-insensitive. Every other value
    is treated as "maximize".
    """
    label = getattr(direction, "name", direction)
    return str(label).strip().upper() == "MINIMIZE"


def find_pareto_frontier_fast(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    x_dir: "t.Union[str, StudyDirection]",
    y_dir: "t.Union[str, StudyDirection]",
) -> pd.DataFrame:
    """Find the 2-D Pareto frontier in O(n log n) using a sort and one scan.

    Rows are sorted so that the best ``x_col`` value comes first (ties broken
    by the best ``y_col`` value). A row is on the frontier when its ``y_col``
    value is strictly better than every row that precedes it in that order.

    Args:
        df: Data holding both objective columns.
        x_col: Name of the first objective column.
        y_col: Name of the second objective column.
        x_dir: Optimization direction of ``x_col``; a ``StudyDirection``
            member or the string ``"minimize"`` / ``"maximize"``.
        y_dir: Optimization direction of ``y_col`` (same accepted values).

    Returns:
        The frontier rows of ``df``, ordered from best to worst ``x_col``.
        An empty frame with ``df``'s columns is returned when ``df`` is empty
        or no row has both objectives present.
    """
    if df.empty:
        return pd.DataFrame(columns=df.columns)

    # A missing objective cannot be ranked; left in place, a NaN on the first
    # sorted row would make every later comparison evaluate to False.
    candidates = df.dropna(subset=[x_col, y_col])
    if candidates.empty:
        return pd.DataFrame(columns=df.columns)

    x_minimize = _is_minimize(x_dir)
    y_minimize = _is_minimize(y_dir)
    ordered = candidates.sort_values(
        by=[x_col, y_col], ascending=[x_minimize, y_minimize]
    )

    # Compare each row with the best y seen strictly before it.
    y_values = ordered[y_col]
    if y_minimize:
        improves = y_values < y_values.cummin().shift()
    else:
        improves = y_values > y_values.cummax().shift()

    on_frontier = improves.to_numpy(copy=True)
    on_frontier[0] = True  # the row with the best x is always on the frontier

    # Positional selection keeps exactly one row per frontier point even when
    # index labels repeat.
    return ordered[on_frontier]
def apply_plot_filters(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    filter_cfg: t.List[t.Dict[str, t.Any]],
) -> None:
    """Apply metric-based filters to ``metric_df`` and ``param_df`` in place.

    Each filter entry provides ``metric``, ``direction`` (``"maximize"`` or
    ``"minimize"``, case-insensitive) and ``threshold``. A trial is kept when
    every filter passes: ``>= threshold`` for maximize, ``<= threshold`` for
    minimize. Rejected trials are removed from both frames by index label.

    Args:
        metric_df: Metric values, one row per trial; mutated in place.
        param_df: Parameter values for the same trials; mutated in place.
        filter_cfg: Filter entries; an empty or ``None`` value is a no-op.

    Raises:
        ValueError: A filter names a metric that is not a column of
            ``metric_df`` (after ``sanitize_name``) or has an invalid direction.
        RuntimeError: No trial survives the filters.
    """
    if not filter_cfg:
        return

    missing = [
        rule["metric"]
        for rule in filter_cfg
        if ml_helpers.sanitize_name(rule["metric"]) not in metric_df.columns
    ]
    if missing:
        raise ValueError(f"Metrics specified in `plots.filter` not present: {missing}")

    mask = pd.Series(True, index=metric_df.index)
    for rule in filter_cfg:
        metric = ml_helpers.sanitize_name(rule["metric"])
        direction = rule["direction"].lower()
        threshold = rule["threshold"]

        if direction not in {"maximize", "minimize"}:
            raise ValueError(f"Invalid direction `{direction}` for metric `{metric}`.")

        if direction == "maximize":
            mask &= metric_df[metric] >= threshold
        else:
            mask &= metric_df[metric] <= threshold
        logging.info(f"Filter applied: metric=`{metric}`, direction=`{direction}`, threshold={threshold}. Remaining trials: {mask.sum()}")

    # Resolve the rejected trials to index labels once, then drop them by label
    # from both frames. Applying the boolean mask positionally to param_df
    # would remove the wrong rows whenever its row order or length differs
    # from metric_df.
    rejected = metric_df.index[~mask.to_numpy()]
    metric_df.drop(index=rejected, inplace=True)
    param_df.drop(index=param_df.index[param_df.index.isin(rejected)], inplace=True)

    if metric_df.empty:
        raise RuntimeError("All trials were filtered out by `plots.filter`.")
def plot_enhanced_scatter(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    objective_metrics: t.List[str],
    output_path: t.Union[str, pathlib.Path],
    continuous_params_map: t.Dict[str, str],
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate scatter plots of one continuous parameter vs. two objectives.

    For every continuous parameter and every pair of objectives, the parameter
    is drawn on the x axis, the first objective on the y axis and the second
    objective as marker color. Each figure is written as HTML below
    ``<output_path>/continuous_scatter/<param_name>/``.

    Args:
        metric_df: Objective values, one row per trial.
        param_df: Parameter values for the same trials; its index is shown as
            the trial identifier in the hover text.
        objective_metrics: Names of the objective columns in ``metric_df``.
        output_path: Root directory for generated plots.
        continuous_params_map: Parameter name -> scale type; ``'log'`` selects
            a logarithmic x axis, anything else a linear one.
        plot_cfg: Plot options; only ``template`` is read here
            (default ``"plotly_dark"``).
    """
    logging.info("--- Starting Enhanced Scatter Plot Generation ---")
    plot_output_path = os.path.join(output_path, "continuous_scatter")
    template = plot_cfg.get("template", "plotly_dark")

    # With more than two objectives the same y-metric is paired with several
    # color metrics; naming the file after the y-metric alone would make those
    # plots overwrite each other. The historical name is kept when no
    # collision is possible.
    disambiguate = len(objective_metrics) > 2

    for param_name, param_type in continuous_params_map.items():
        is_log = param_type == 'log'
        param_output_path = os.path.join(plot_output_path, param_name)

        for metric_y, metric_color in itertools.combinations(objective_metrics, 2):
            fig = go.Figure(data=go.Scatter(
                x=param_df[param_name],
                y=metric_df[metric_y],
                mode='markers',
                marker=dict(
                    color=metric_df[metric_color],
                    colorscale='Viridis',
                    showscale=True,
                    colorbar_title=metric_color.replace('/', '_'),
                ),
                customdata=param_df.index,
                hovertemplate=f'<b>Trial</b>: %{{customdata}}<br><b>{param_name}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>',
            ))
            fig.update_layout(
                title=f'{param_name} vs. {metric_y}<br><sub>Colored by {metric_color}</sub>',
                xaxis_title=f"{param_name} ({'log scale' if is_log else 'linear scale'})",
                yaxis_title=metric_y,
                xaxis_type='log' if is_log else 'linear',
                template=template,
            )

            stem = f"{metric_y}_colored_by_{metric_color}" if disambiguate else metric_y
            # Metric names may contain '/', which must not create sub-paths;
            # this matches the sanitizing done by the other plotters.
            fname = f"{stem}.html".replace('/', '_')
            os.makedirs(param_output_path, exist_ok=True)
            fig.write_html(os.path.join(param_output_path, fname))
def plot_pareto_continuous_color(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    objective_metrics: t.List[str],
    study_directions: t.List[StudyDirection],
    output_path: t.Union[str, pathlib.Path],
    continuous_params_map: t.Dict[str, str],
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate Pareto fronts of two objectives, colored by a continuous parameter.

    For every pair of objectives and every continuous parameter, all trials
    are drawn as faint grey markers and the Pareto frontier as a line whose
    markers are colored by the parameter value. Each figure is written as HTML
    below ``<output_path>/continuous_pareto/<param_name>/``.

    Args:
        metric_df: Objective values, one row per trial.
        param_df: Parameter values for the same trials.
        objective_metrics: Names of the objective columns in ``metric_df``.
        study_directions: Optimization direction of each objective, in the same
            order as ``objective_metrics``.
        output_path: Root directory for generated plots.
        continuous_params_map: Continuous parameter names (keys); nothing is
            plotted when it is empty.
        plot_cfg: Plot options; only ``template`` is read here
            (default ``"plotly_dark"``).
    """
    if not continuous_params_map:
        return

    logging.info("--- Starting Continuous Pareto Plot Generation ---")
    plot_output_path = os.path.join(output_path, "continuous_pareto")
    template = plot_cfg.get("template", "plotly_dark")
    combined_df = pd.concat([metric_df, param_df], axis=1)
    directions_map = dict(zip(objective_metrics, study_directions))

    for metric_x, metric_y in itertools.combinations(objective_metrics, 2):
        # The frontier and the file name depend only on the metric pair, so
        # they are computed once instead of once per parameter.
        pareto_df = find_pareto_frontier_fast(
            combined_df,
            metric_x,
            metric_y,
            directions_map[metric_x],
            directions_map[metric_y],
        )
        fname = f"{metric_x}_vs_{metric_y}.html".replace('/', '_')

        for param_name in continuous_params_map:
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=combined_df[metric_x],
                y=combined_df[metric_y],
                mode='markers',
                marker=dict(color='grey', opacity=0.2, size=4),
                name='All Trials',
                customdata=combined_df.index,
                hovertemplate=f'<b>Trial</b>: %{{customdata}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>',
            ))
            fig.add_trace(go.Scatter(
                x=pareto_df[metric_x],
                y=pareto_df[metric_y],
                mode='lines+markers',
                marker=dict(
                    color=pareto_df[param_name],
                    colorscale='Cividis',
                    showscale=True,
                    colorbar_title=param_name.replace('/', '_'),
                    size=10,
                ),
                line=dict(width=3),
                name='Pareto Frontier',
                customdata=pareto_df.index,
                hovertemplate=f'<b>Trial</b>: %{{customdata}}<br><b>{param_name}</b>: %{{marker.color}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>',
            ))
            fig.update_layout(
                title=f'Pareto: {metric_x} vs. {metric_y}<br><sub>Colored by {param_name}</sub>',
                xaxis_title=f"{metric_x} ({directions_map[metric_x].name})",
                yaxis_title=f"{metric_y} ({directions_map[metric_y].name})",
                template=template,
            )

            param_output_path = os.path.join(plot_output_path, param_name)
            os.makedirs(param_output_path, exist_ok=True)
            fig.write_html(os.path.join(param_output_path, fname))
def plot_contour(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    objective_metrics: t.List[str],
    output_path: t.Union[str, pathlib.Path],
    continuous_params_map: t.Dict[str, str],
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate contour plots of two continuous parameters vs. one objective.

    For every pair of continuous parameters and every objective, the objective
    is interpolated (cubic ``griddata``) on a 100 x 100 grid spanning the
    observed parameter range and drawn as a contour map with the sampled trials
    overlaid. Parameters marked ``'log'`` are interpolated in log10 space and
    shown on a logarithmic axis. Figures are written as HTML below
    ``<output_path>/continuous_contour/<param_x>_vs_<param_y>/``.

    Plot generation is best effort: a failure for one combination is logged
    as a warning and the remaining combinations are still attempted.

    Args:
        metric_df: Objective values, one row per trial.
        param_df: Parameter values for the same trials.
        objective_metrics: Names of the objective columns in ``metric_df``.
        output_path: Root directory for generated plots.
        continuous_params_map: Parameter name -> scale type (``'log'`` or other).
        plot_cfg: Plot options; only ``template`` is read here
            (default ``"plotly_dark"``).
    """
    logging.info("--- Starting Contour Plot Generation ---")
    min_trials = 15
    if len(metric_df) < min_trials:
        logging.warning("Skipping contour plots: not enough trials (< 15) for reliable interpolation.")
        return
    if len(continuous_params_map) < 2:
        logging.warning("Skipping contour plots: need at least 2 continuous parameters.")
        return

    plot_output_path = os.path.join(output_path, "continuous_contour")
    template = plot_cfg.get("template", "plotly_dark")

    for param_x, param_y in itertools.combinations(continuous_params_map.keys(), 2):
        is_x_log = continuous_params_map[param_x] == 'log'
        is_y_log = continuous_params_map[param_y] == 'log'

        for metric_z in objective_metrics:
            try:
                points = param_df[[param_x, param_y]].to_numpy(dtype=float)
                values = metric_df[metric_z].to_numpy(dtype=float)

                # Interpolate only from trials with finite coordinates and a
                # finite objective; NaN/inf inputs would otherwise leak into
                # the interpolated surface. Log axes additionally need
                # strictly positive coordinates.
                usable = np.isfinite(points).all(axis=1) & np.isfinite(values)
                if is_x_log:
                    usable &= points[:, 0] > 0
                if is_y_log:
                    usable &= points[:, 1] > 0
                if usable.sum() < min_trials:
                    logging.warning(
                        f"Skipping contour for ({param_x}, {param_y}) vs {metric_z}: "
                        f"only {int(usable.sum())} usable trial(s) (< {min_trials})."
                    )
                    continue

                points_for_grid = points[usable]  # boolean indexing returns a copy
                values = values[usable]
                if is_x_log:
                    points_for_grid[:, 0] = np.log10(points_for_grid[:, 0])
                if is_y_log:
                    points_for_grid[:, 1] = np.log10(points_for_grid[:, 1])

                x_min, x_max = points_for_grid[:, 0].min(), points_for_grid[:, 0].max()
                y_min, y_max = points_for_grid[:, 1].min(), points_for_grid[:, 1].max()
                grid_x, grid_y = np.mgrid[x_min:x_max:100j, y_min:y_max:100j]
                grid_z = griddata(points_for_grid, values, (grid_x, grid_y), method='cubic')

                # Map the grid axes back to the original parameter scale.
                x_coords = 10**grid_x[:, 0] if is_x_log else grid_x[:, 0]
                y_coords = 10**grid_y[0, :] if is_y_log else grid_y[0, :]

                fig = go.Figure(data=[go.Contour(
                    x=x_coords,
                    y=y_coords,
                    z=grid_z.T,
                    colorscale='Turbo',
                    contours=dict(
                        coloring='heatmap',
                        showlabels=True,
                        labelfont=dict(size=10, color='white'),
                    ),
                )])
                fig.add_trace(go.Scatter(
                    x=param_df[param_x],
                    y=param_df[param_y],
                    mode='markers',
                    marker=dict(color='rgba(255,255,255,0.5)', size=4),
                    showlegend=False,
                    hovertemplate=f'<b>{param_x}</b>: %{{x}}<br><b>{param_y}</b>: %{{y}}<extra></extra>',
                ))
                fig.update_layout(
                    title=f'Contour: {param_x} vs. {param_y}<br><sub>on {metric_z}</sub>',
                    xaxis_title=f"{param_x} ({'log scale' if is_x_log else 'linear'})",
                    yaxis_title=f"{param_y} ({'log scale' if is_y_log else 'linear'})",
                    xaxis_type='log' if is_x_log else 'linear',
                    yaxis_type='log' if is_y_log else 'linear',
                    template=template,
                )

                param_name = f"{param_x}_vs_{param_y}"
                fname = f"{metric_z}.html".replace('/', '_')
                param_output_path = os.path.join(plot_output_path, param_name)
                os.makedirs(param_output_path, exist_ok=True)
                fig.write_html(os.path.join(param_output_path, fname))
            except Exception as e:
                # Deliberately broad: one failing combination (e.g. a degenerate
                # point set) must not abort the remaining plots.
                logging.warning(f"Could not generate contour for ({param_x}, {param_y}) vs {metric_z}: {e}")
def plot_pareto_frontiers_by_category(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    objective_metrics: t.List[str],
    study_directions: t.List[StudyDirection],
    output_path: t.Union[str, pathlib.Path],
    categorical_params_list: list,
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate Pareto frontier plots per objective pair, grouped by category.

    For every pair of objectives and every categorical parameter, one figure is
    produced in which each category value gets its own color: all of its trials
    as faint markers plus its own Pareto frontier as a starred line. Figures
    are written as HTML into ``<output_path>/pareto_front/``.

    Args:
        metric_df: Objective values, one row per trial.
        param_df: Parameter values for the same trials.
        objective_metrics: Names of the objective columns in ``metric_df``.
        study_directions: Optimization direction of each objective, in the same
            order as ``objective_metrics``.
        output_path: Root directory for generated plots.
        categorical_params_list: Names of the categorical parameters to group
            by; names missing from the trial data are skipped with a warning.
        plot_cfg: Plot options; ``template`` (default ``"plotly_dark"``) and
            ``palette`` (default ``"Plotly"``; unknown names fall back to the
            ``Vivid`` qualitative palette) are read here.
    """
    logging.info("--- Starting Pareto Frontier Plot Generation ---")
    output_path = os.path.join(output_path, "pareto_front")
    os.makedirs(output_path, exist_ok=True)
    combined_df = pd.concat([metric_df, param_df], axis=1)

    if not categorical_params_list:
        logging.warning("No categorical parameters found to group Pareto frontiers by.")
        return

    # Skip configured parameters that never appeared in any trial instead of
    # failing with a KeyError, as the categorical distribution plotter does.
    plottable_params = []
    for cat_param in categorical_params_list:
        if cat_param in combined_df.columns:
            plottable_params.append(cat_param)
        else:
            logging.warning(f"Parameter '{cat_param}' from config not found in trial data. Skipping.")

    directions_map = dict(zip(objective_metrics, study_directions))
    template = plot_cfg.get("template", "plotly_dark")
    palette_name = plot_cfg.get("palette", "Plotly")
    colors = getattr(px.colors.qualitative, palette_name, px.colors.qualitative.Vivid)

    for metric_x, metric_y in itertools.combinations(objective_metrics, 2):
        hover = f'<b>Trial</b>: %{{customdata}}<br><b>{metric_x}</b>: %{{x}}<br><b>{metric_y}</b>: %{{y}}<extra></extra>'

        for cat_param in plottable_params:
            logging.info(f"Plotting Pareto for ({metric_x} vs {metric_y}) by '{cat_param}'")
            fig = go.Figure()

            categories = combined_df[cat_param].unique()
            for i, category in enumerate(categories):
                df_cat = combined_df[combined_df[cat_param] == category]
                if df_cat.empty:
                    continue

                # Colors cycle through the palette by category position.
                color = colors[i % len(colors)]

                fig.add_trace(go.Scatter(
                    x=df_cat[metric_x],
                    y=df_cat[metric_y],
                    mode='markers',
                    marker=dict(color=color, opacity=0.3, size=5),
                    name=f'{category} (all trials)',
                    legendgroup=str(category),
                    showlegend=False,
                    customdata=df_cat.index,
                    hovertemplate=hover,
                ))

                pareto_df = find_pareto_frontier_fast(
                    df=df_cat,
                    x_col=metric_x,
                    y_col=metric_y,
                    x_dir=directions_map[metric_x],
                    y_dir=directions_map[metric_y],
                )
                fig.add_trace(go.Scatter(
                    x=pareto_df[metric_x],
                    y=pareto_df[metric_y],
                    mode='lines+markers',
                    line=dict(color=color, width=2),
                    marker=dict(color=color, size=8, symbol='star'),
                    name=f'{category} (Pareto Frontier)',
                    legendgroup=str(category),
                    customdata=pareto_df.index,
                    hovertemplate=hover,
                ))

            fig.update_layout(
                title=f"Pareto Frontier: {metric_x} vs {metric_y}<br><sub>Grouped by {cat_param}</sub>",
                xaxis_title=f"{metric_x} ({directions_map[metric_x].name})",
                yaxis_title=f"{metric_y} ({directions_map[metric_y].name})",
                template=template,
                legend_title=cat_param,
                hovermode='closest',
            )

            fname = f"pareto_{metric_x}_vs_{metric_y}_by_{cat_param}.html".replace('/', '_')
            output_file = os.path.join(output_path, fname)
            fig.write_html(output_file)

    logging.info("--- Pareto Frontier Plot Generation Complete ---")
def plot_categorical_distributions(
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    objective_metrics: t.List[str],
    output_path: t.Union[str, pathlib.Path],
    categorical_params_list: list,
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate and save distribution plots for categorical hyperparameters.

    One HTML figure per categorical parameter is written into
    ``<output_path>/distribution/``. Each figure stacks one violin subplot per
    objective metric and adds, per subplot, a dropdown that switches the y-axis
    between the full range, the whisker range, the 5th-95th percentile range
    and the interquartile range.

    Args:
        metric_df: Objective values, one row per trial.
        param_df: Parameter values for the same trials.
        objective_metrics: Names of the objective columns in ``metric_df``.
        output_path: Root directory for generated plots.
        categorical_params_list: Names of the categorical parameters to plot;
            names missing from the trial data are skipped with a warning.
        plot_cfg: Plot options. Keys read here (with defaults):
            ``height_per_metric`` (400), ``template`` ("plotly_dark"),
            ``vertical_spacing`` (0.05), ``show_points`` ("all"),
            ``point_opacity`` (0.5), ``point_jitter`` (0.1),
            ``color_by_category`` (True), ``palette`` ("Plotly"),
            ``show_legend`` (True) and ``reference_lines`` (a mapping of
            metric name to a dict with ``value`` and optional ``style``,
            ``text``, ``position``).
    """
    logging.info("--- Starting Categorical Distribution Plot Generation ---")
    plot_output_path = os.path.join(output_path, "distribution")
    logging.info(f"Plots will be saved to: {os.path.abspath(plot_output_path)}")

    combined_df = pd.concat([metric_df, param_df], axis=1)

    if not categorical_params_list:
        logging.warning("No categorical parameters found in 'optuna.parameters' config to plot.")
        return

    logging.info(f"Found categorical parameters to plot: {categorical_params_list}")

    # The target directory must exist before any figure is written; without
    # this every save below fails and is merely logged as an error.
    os.makedirs(plot_output_path, exist_ok=True)

    height_per_metric = plot_cfg.get("height_per_metric", 400)
    template = plot_cfg.get("template", "plotly_dark")
    vertical_spacing = plot_cfg.get("vertical_spacing", 0.05)
    show_points = plot_cfg.get("show_points", "all")
    point_opacity = plot_cfg.get("point_opacity", 0.5)
    point_jitter = plot_cfg.get("point_jitter", 0.1)
    color_by_category = plot_cfg.get("color_by_category", True)
    palette_name = plot_cfg.get("palette", "Plotly")
    colors = getattr(px.colors.qualitative, palette_name, px.colors.qualitative.Vivid)
    show_legend = plot_cfg.get("show_legend", True)
    # `or {}` also covers a key that is present but explicitly set to null.
    ref_lines_cfg = plot_cfg.get("reference_lines", {}) or {}

    for param_name in categorical_params_list:
        if param_name not in combined_df.columns:
            logging.warning(f"Parameter '{param_name}' from config not found in trial data. Skipping.")
            continue

        logging.info(f"Generating plot for parameter: '{param_name}'...")

        num_metrics = len(objective_metrics)
        fig = make_subplots(
            rows=num_metrics,
            cols=1,
            subplot_titles=objective_metrics,
            vertical_spacing=vertical_spacing,
        )

        updatemenus = []
        # Fraction of the figure height occupied by one subplot; used to place
        # each dropdown at the top of its subplot.
        total_subplot_height = 1.0 - (vertical_spacing * (num_metrics - 1))
        subplot_height = total_subplot_height / num_metrics if num_metrics > 0 else 1.0

        for i, metric_name in enumerate(objective_metrics):
            if color_by_category:
                categories = combined_df[param_name].unique()
                for j, category in enumerate(categories):
                    df_cat = combined_df[combined_df[param_name] == category]
                    if df_cat.empty:
                        continue
                    fig.add_trace(
                        go.Violin(
                            x=df_cat[param_name],
                            y=df_cat[metric_name],
                            name=str(category),
                            line_color=colors[j % len(colors)],
                            points=show_points,
                            jitter=point_jitter,
                            pointpos=0,
                            box_visible=True,
                            meanline_visible=True,
                            marker=dict(opacity=point_opacity),
                        ),
                        row=i + 1,
                        col=1,
                    )
            else:
                fig.add_trace(
                    go.Violin(
                        x=combined_df[param_name],
                        y=combined_df[metric_name],
                        name=metric_name,
                        points=show_points,
                        jitter=point_jitter,
                        pointpos=0,
                        box_visible=True,
                        meanline_visible=True,
                        marker=dict(opacity=point_opacity),
                    ),
                    row=i + 1,
                    col=1,
                )

            if metric_name in ref_lines_cfg:
                line_cfg = ref_lines_cfg[metric_name]
                if "value" in line_cfg:
                    fig.add_hline(
                        y=line_cfg["value"],
                        line_dash=line_cfg.get("style", "dot"),
                        annotation_text=line_cfg.get("text", f'Target: {line_cfg["value"]}'),
                        annotation_position=line_cfg.get("position", "bottom right"),
                        row=i + 1,
                        col=1,
                    )

            y_data = combined_df[metric_name].dropna()
            if not y_data.empty:
                q1, q3 = y_data.quantile(0.25), y_data.quantile(0.75)
                p05, p95 = y_data.quantile(0.05), y_data.quantile(0.95)
                iqr = q3 - q1
                # Whiskers at 1.5 * IQR, clamped to the observed data range.
                lower_whisker = max(y_data.min(), q1 - 1.5 * iqr)
                upper_whisker = min(y_data.max(), q3 + 1.5 * iqr)

                # NOTE(review): the axis key is built as `yaxis{i+1}`, i.e.
                # `yaxis1` for the first subplot - confirm plotly's relayout
                # accepts that alias for the first axis.
                buttons = [
                    dict(label="Full Range", method="relayout", args=[{f"yaxis{i+1}.autorange": True}]),
                    dict(label="Whisker Range", method="relayout", args=[{f"yaxis{i+1}.range": [lower_whisker, upper_whisker]}]),
                    dict(label="5th-95th Percentile", method="relayout", args=[{f"yaxis{i+1}.range": [p05, p95]}]),
                    dict(label="Interquartile (IQR)", method="relayout", args=[{f"yaxis{i+1}.range": [q1, q3]}]),
                ]

                menu_y_position = 1.0 - (i * (subplot_height + vertical_spacing))
                updatemenus.append(dict(
                    buttons=buttons,
                    direction="down",
                    pad={"r": 10, "t": 10},
                    showactive=True,
                    x=0.01,
                    xanchor="left",
                    y=menu_y_position,
                    yanchor="top",
                ))

        fig.update_layout(
            title_text=f"Distribution of Objective Metrics by '{param_name}'",
            height=height_per_metric * num_metrics,
            showlegend=(color_by_category and show_legend),
            template=template,
            updatemenus=updatemenus,
            margin=dict(t=100, l=80),
        )

        # Parameter names may contain '/', which must not create sub-paths;
        # this matches the sanitizing done by the other plotters.
        output_file = os.path.join(plot_output_path, f"{param_name}.html".replace('/', '_'))
        try:
            fig.write_html(output_file)
            logging.info(f"Successfully saved plot to {output_file}")
        except Exception as e:
            # Best effort: a failed save must not stop the remaining parameters.
            logging.error(f"Failed to save plot for {param_name}: {e}")
def generate_plots(
    study: optuna.study.Study,
    plot_cfg: DictConfig,
    exp_params: DictConfig,
    output_path: t.Union[str, pathlib.Path],
):
    """Generate every Optuna plot for a finished study.

    Encapsulates the PLOT mode of the run function: the study's trials are
    converted to metric/parameter frames, parameters are classified as
    continuous or categorical, metric names are sanitized, optional
    ``plot_cfg["filter"]`` rules are applied, and then each plotter in this
    module is invoked in turn.

    Args:
        study: Study whose trials, metric names and directions are plotted.
        plot_cfg: Plot configuration, passed on to every plotter; its
            ``filter`` entry (if any) selects the trials to plot.
        exp_params: Experiment parameter configuration used to classify the
            parameters.
        output_path: Root directory for all generated plots.
    """
    logging.info("Starting PLOT mode.")

    metric_names = study.metric_names
    study_directions = study.directions

    trials = study.get_trials(deepcopy=False)
    metric_df, param_df = ml_helpers.trials_to_dataframes(trials, metric_names)
    continuous_params_map, categorical_params_list = ml_helpers.prepare_parameter_data(
        exp_params, param_df
    )
    metric_df, metric_names = ml_helpers.sanitize_metric_dataframe(metric_df, metric_names)

    filter_cfg = plot_cfg.get("filter", [])
    if not filter_cfg:
        logging.info("No plot filters defined - using all trials.")
    else:
        # Filtering mutates both frames in place.
        apply_plot_filters(metric_df, param_df, filter_cfg)
        logging.info(f"After filtering, {len(metric_df)} trial(s) remain for plotting.")

    # Arguments shared by every plotter.
    shared = dict(
        metric_df=metric_df,
        param_df=param_df,
        objective_metrics=metric_names,
        output_path=output_path,
        plot_cfg=plot_cfg,
    )

    # Plots grouped by categorical parameters.
    plot_categorical_distributions(
        categorical_params_list=categorical_params_list, **shared
    )
    plot_pareto_frontiers_by_category(
        study_directions=study_directions,
        categorical_params_list=categorical_params_list,
        **shared,
    )

    # Plots driven by continuous parameters.
    plot_enhanced_scatter(continuous_params_map=continuous_params_map, **shared)
    plot_pareto_continuous_color(
        study_directions=study_directions,
        continuous_params_map=continuous_params_map,
        **shared,
    )
    plot_contour(continuous_params_map=continuous_params_map, **shared)

    logging.info("Plot generation complete.")