"""
Plotting function for visualizing a continuous hyperparameter against two
objective metrics.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import itertools
import logging
import pathlib
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import pandas as pd
import plotly.graph_objects as go
from omegaconf import DictConfig
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .utils import DEFAULT_PLOT_CONFIG
[docs]
def plot_enhanced_scatter(
    #? --- Data Inputs ---
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    #? --- Plotting Configuration ---
    objective_metrics: t.List[str],
    output_path: pathlib.Path,
    continuous_params_map: t.Dict[str, str],
    plot_cfg: t.Union[DictConfig, t.Dict],
) -> None:
    """Generate scatter plots of 1 continuous parameter vs. 2 objectives.

    For every continuous parameter and every unordered pair
    ``(metric_y, metric_color)`` drawn from ``objective_metrics``, an
    interactive HTML scatter plot is written. The parameter is on the
    x-axis, ``metric_y`` is on the y-axis, and the markers are colored by
    ``metric_color``.

    Args:
        metric_df: Objective values; must contain a column for every entry
            of ``objective_metrics``. Its columns are handed to Plotly as-is,
            so rows are paired with ``param_df`` by position, not by index
            (NOTE(review): assumes both frames list the same trials in the
            same order — confirm with callers).
        param_df: Hyperparameter values; must contain a column for every key
            of ``continuous_params_map``. Its index is shown as the trial
            identifier in the hover text.
        objective_metrics: Names of the objective metrics. At least two are
            required; with fewer, a warning is logged and nothing is written.
        output_path: Root output directory. Plots go to
            ``output_path / "continuous_scatter" / <param_name> / <file>.html``.
        continuous_params_map: Maps each parameter name to its scale type.
            The value ``'log'`` selects a logarithmic x-axis; any other value
            selects a linear one.
        plot_cfg: Options that override ``DEFAULT_PLOT_CONFIG``. The
            ``"template"`` entry is used as the Plotly layout template. A
            missing/empty config falls back to the defaults.

    Output file naming:
        With exactly two objectives there is a single pair, and the file is
        ``<metric_y>.html``. With three or more objectives the same
        ``metric_y`` occurs in several pairs. The file is then named
        ``<metric_y>__colored_by__<metric_color>.html`` so that plots do not
        overwrite each other. ``'/'`` in metric names is replaced by ``'_'``.

    Returns:
        None. Plots are written to disk as a side effect.
    """
    logging.info("--- Starting Enhanced Scatter Plot Generation ---")
    plot_output_path = output_path / "continuous_scatter"
    # User-supplied options override the package defaults.
    cfg = {**DEFAULT_PLOT_CONFIG, **(plot_cfg or {})}

    # Materialize once: combinations() is an iterator, and objective_metrics
    # itself may be a one-shot iterable.
    metric_pairs = list(itertools.combinations(objective_metrics, 2))
    if not metric_pairs:
        logging.warning(
            "Enhanced scatter plots need at least two objective metrics; "
            "nothing to plot."
        )
        return

    # Only with 3+ objectives can one metric_y be paired with several color
    # metrics. In that case the color metric must be part of the file name,
    # otherwise later plots overwrite earlier ones.
    disambiguate = len(metric_pairs) > 1

    for param_name, param_type in continuous_params_map.items():
        is_log = param_type == 'log'
        param_output_path = plot_output_path / param_name
        # The directory depends only on the parameter; create it once.
        param_output_path.mkdir(parents=True, exist_ok=True)

        for metric_y, metric_color in metric_pairs:
            fig = go.Figure(data=go.Scatter(
                x=param_df[param_name],
                y=metric_df[metric_y],
                mode='markers',
                marker=dict(
                    color=metric_df[metric_color],
                    colorscale='Viridis',
                    showscale=True,
                    colorbar_title=metric_color.replace('/', '_')
                ),
                customdata=param_df.index,
                hovertemplate=(f'<b>Trial</b>: %{{customdata}}<br>'
                               f'<b>{param_name}</b>: %{{x}}<br>'
                               f'<b>{metric_y}</b>: %{{y}}<br>'
                               f'<b>{metric_color}</b>: %{{marker.color}}'
                               f'<extra></extra>')
            ))
            fig.update_layout(
                title=f'{param_name} vs. {metric_y}<br><sub>Colored by {metric_color}</sub>',
                xaxis_title=f"{param_name} ({'log' if is_log else 'linear'} scale)",
                yaxis_title=metric_y,
                xaxis_type='log' if is_log else 'linear',
                template=cfg["template"]
            )

            file_stem = metric_y.replace('/', '_')
            if disambiguate:
                file_stem += f"__colored_by__{metric_color.replace('/', '_')}"
            fig.write_html(param_output_path / f"{file_stem}.html")