# Source code for gunz_ml.plots.enhanced_scatter

"""
Plotting function for visualizing a continuous hyperparameter against two
objective metrics.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import itertools
import logging
import pathlib
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import pandas as pd
import plotly.graph_objects as go
from omegaconf import DictConfig

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .utils import DEFAULT_PLOT_CONFIG

def plot_enhanced_scatter(
    #? --- Data Inputs ---
    metric_df: pd.DataFrame,
    param_df: pd.DataFrame,
    #? --- Plotting Configuration ---
    objective_metrics: t.List[str],
    output_path: pathlib.Path,
    continuous_params_map: t.Dict[str, str],
    plot_cfg: t.Optional[t.Union[DictConfig, t.Dict]],
) -> None:
    """
    Generates scatter plots of 1 continuous parameter vs. 2 objectives.

    For every continuous parameter and every unordered pair of objective
    metrics ``(metric_y, metric_color)`` produced by
    ``itertools.combinations(objective_metrics, 2)``, one interactive HTML
    scatter plot is written. The parameter is on the x-axis, ``metric_y`` is
    on the y-axis, and the markers are colored by ``metric_color``.

    Files are written to
    ``output_path / "continuous_scatter" / <param_name> / <file>.html``.

    - If ``objective_metrics`` holds exactly two metrics (one pair), the file
      is named ``<metric_y>.html``.
    - With more metrics, several pairs share the same ``metric_y``. The file
      is then named ``<metric_y>_colored_by_<metric_color>.html`` so that one
      plot does not overwrite another.
    - ``/`` in metric names is replaced by ``_`` in file names.

    Args:
        metric_df: Objective values. Must contain a column for every entry of
            ``objective_metrics``.
        param_df: Parameter values. Must contain a column for every key of
            ``continuous_params_map``. Its index is shown as the trial id in
            the hover text.
        objective_metrics: Names of the objective columns in ``metric_df``.
            With fewer than two entries no plots are produced and a warning
            is logged.
        output_path: Root directory for the generated plots.
        continuous_params_map: Maps each parameter name to its scale type.
            The value ``'log'`` selects a logarithmic x-axis; anything else
            selects a linear one.
        plot_cfg: Overrides merged on top of ``DEFAULT_PLOT_CONFIG``. Its
            ``"template"`` entry is used as the plotly layout template.
            ``None`` means "use the defaults".

    Raises:
        KeyError: If a requested parameter or metric column is missing.
    """
    logging.info("--- Starting Enhanced Scatter Plot Generation ---")
    plot_output_path = output_path / "continuous_scatter"
    # `plot_cfg or {}` lets callers pass None (or an empty config) safely.
    cfg = {**DEFAULT_PLOT_CONFIG, **(plot_cfg or {})}

    metric_pairs = list(itertools.combinations(objective_metrics, 2))
    if not metric_pairs:
        logging.warning(
            "Enhanced scatter plots need at least two objective metrics; "
            "got %d (%s). Skipping.",
            len(objective_metrics), objective_metrics,
        )
        return

    # With a single pair, "<metric_y>.html" is unambiguous, so the original
    # name is kept. With several pairs the same metric_y recurs, so the color
    # metric must be part of the file name or later plots silently overwrite
    # earlier ones.
    include_color_in_name = len(metric_pairs) > 1

    for param_name, param_type in continuous_params_map.items():
        axis_type = 'log' if param_type == 'log' else 'linear'
        param_output_path = plot_output_path / param_name

        for metric_y, metric_color in metric_pairs:
            # NOTE(review): x comes from param_df, while y and color come from
            # metric_df. plotly pairs them by position, so both frames are
            # assumed to list trials in the same row order. Confirm with
            # callers.
            fig = go.Figure(data=go.Scatter(
                x=param_df[param_name],
                y=metric_df[metric_y],
                mode='markers',
                marker=dict(
                    color=metric_df[metric_color],
                    colorscale='Viridis',
                    showscale=True,
                    colorbar_title=metric_color.replace('/', '_'),
                ),
                customdata=param_df.index,
                # Doubled braces emit literal %{...} plotly placeholders.
                hovertemplate=(f'<b>Trial</b>: %{{customdata}}<br>'
                               f'<b>{param_name}</b>: %{{x}}<br>'
                               f'<b>{metric_y}</b>: %{{y}}<extra></extra>'),
            ))
            fig.update_layout(
                title=f'{param_name} vs. {metric_y}<br><sub>Colored by {metric_color}</sub>',
                xaxis_title=f"{param_name} ({axis_type} scale)",
                yaxis_title=metric_y,
                xaxis_type=axis_type,
                template=cfg["template"],
            )

            file_stem = metric_y.replace('/', '_')
            if include_color_in_name:
                file_stem += f"_colored_by_{metric_color.replace('/', '_')}"

            param_output_path.mkdir(parents=True, exist_ok=True)
            fig.write_html(param_output_path / f"{file_stem}.html")