Source code for gunz_ml.integrations.optuna

"""
Provides helper functions for working with Optuna studies and configurations.
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.1.0"

import copy
import logging
import math
import numbers
import typing as t
import warnings

import optuna
import pandas as pd
from omegaconf import DictConfig, ListConfig, OmegaConf
from optuna.trial import FrozenTrial, TrialState

from ..consts import OptunaType
from ..schemas.optuna_cfg import OptunaConfig
from .hydra import resolve_cfg

# Marker prefix for a config string value that ``optimize_cfg`` replaces with
# ``trial_cfg.suggest_param(<rest of the string>)``.
OPT_VAL_PREFIX = "#OPTVAL-"
# Marker prefix for a config string value that ``optimize_cfg`` replaces with
# ``trial_cfg.suggest_dict_params(<rest of the string>)``.
OPT_DICT_PREFIX = "#OPTDICT-"

def _frame_indexed_by_trial(rows, columns=()):
    """Build a DataFrame from a list of row dicts, indexed by ``trial_number``.

    ``pd.DataFrame([])`` has no ``trial_number`` column, so calling
    ``set_index`` on it raises ``KeyError``. For empty input this returns an
    empty frame with the expected index name and the given ``columns``.
    """
    if not rows:
        return pd.DataFrame(columns=["trial_number", *columns]).set_index("trial_number")
    return pd.DataFrame(rows).set_index("trial_number")


def trials_to_dataframes(trials, metric_names=None, target_state=TrialState.COMPLETE):
    """Split Optuna trials into a metrics DataFrame and a parameters DataFrame.

    Args:
        trials: Iterable of Optuna trials (e.g. ``study.trials``).
        metric_names: Names assigned, in order, to each trial's objective
            values. If ``None`` or empty, the metrics frame has no value
            columns. Extra names or extra values are dropped, because the two
            are zipped.
        target_state: Keep only trials whose ``state`` equals this value.
            ``None`` keeps every trial.

    Returns:
        ``(metrics_df, params_df)``, both indexed by ``trial_number``. With no
        matching trials, both frames are empty but keep that index name.
    """
    # Materialise once: a generator would otherwise be exhausted by the first
    # trial's ``zip``. Avoiding truthiness also lets array-likes be passed.
    names = list(metric_names) if metric_names is not None else []

    if target_state is not None:
        trials = [trial for trial in trials if trial.state == target_state]

    metrics_rows, params_rows = [], []
    for trial in trials:
        try:
            values = [trial.value]
        except RuntimeError:
            # Optuna raises here for multi-objective trials; read all values.
            values = trial.values
        metrics_row = dict(zip(names, values))
        metrics_row["trial_number"] = trial.number
        metrics_rows.append(metrics_row)

        params_row = dict(trial.params)
        params_row["trial_number"] = trial.number
        params_rows.append(params_row)

    return (
        _frame_indexed_by_trial(metrics_rows, names),
        _frame_indexed_by_trial(params_rows),
    )
def _sqlite_url_with_timeout(storage_url, timeout=60):
    """Return ``storage_url`` with ``timeout=<timeout>`` appended as a query parameter.

    The URL is returned unchanged if it already contains ``timeout``.
    """
    if "timeout" in storage_url:
        return storage_url
    separator = "&" if "?" in storage_url else "?"
    return f"{storage_url}{separator}timeout={timeout}"


def _enable_sqlite_wal(storage_url):
    """Best effort: run ``PRAGMA journal_mode=WAL`` on the SQLite DB at ``storage_url``.

    Any failure is logged as a warning and never raised. The temporary engine
    is always disposed, so its pooled connection does not stay open.
    """
    engine = None
    try:
        from sqlalchemy import create_engine

        engine = create_engine(storage_url)
        with engine.connect() as conn:
            conn.exec_driver_sql("PRAGMA journal_mode=WAL;")
    except Exception as e:
        logging.warning(f"Could not set WAL mode: {e}")
    finally:
        if engine is not None:
            engine.dispose()


def create_study(study_kwargs, objs_cfg, dry_run=False):
    """Create an Optuna study from a study config and an objectives config.

    Args:
        study_kwargs: Keyword arguments for ``optuna.create_study``. They are
            passed through ``resolve_cfg`` first.
        objs_cfg: Mapping of objective name to optimisation direction. Its
            values, in order, become ``direction`` (one objective) or
            ``directions`` (several). It is also passed through ``resolve_cfg``.
        dry_run: If true, ``storage`` is forced to ``None`` so nothing is persisted.

    Returns:
        The created ``optuna.Study``.

    Raises:
        RuntimeError: If either config cannot be resolved.

    For ``sqlite`` storage URLs, ``timeout=60`` is added when no timeout is
    present. After creation, the database is switched to WAL journal mode on a
    best-effort basis.
    """
    try:
        study_kwargs = resolve_cfg(study_kwargs)
        objs_cfg = resolve_cfg(objs_cfg)
    except Exception as e:
        raise RuntimeError(f"Failed to resolve configs: {e}") from e

    study_args = study_kwargs.copy()
    if dry_run:
        study_args["storage"] = None

    directions = list(objs_cfg.values())
    if len(directions) > 1:
        study_args["directions"] = directions
    elif len(directions) == 1:
        study_args["direction"] = directions[0]

    storage_url = study_args.get("storage")
    url_text = str(storage_url) if storage_url else ""
    # Prefix check: SQLAlchemy SQLite URLs start with "sqlite" (sqlite://,
    # sqlite+pysqlite://). A substring check also matched other backends.
    is_sqlite = url_text.startswith("sqlite")
    if is_sqlite:
        study_args["storage"] = _sqlite_url_with_timeout(url_text)

    study = optuna.create_study(**study_args)

    if is_sqlite:
        _enable_sqlite_wal(study_args["storage"])
    return study
def load_study(study_kwargs):
    """Load an existing Optuna study described by ``study_kwargs``.

    Args:
        study_kwargs: Mapping of keyword arguments for ``optuna.load_study``.
            It must contain a non-``None`` ``storage``. A ``load_if_exists``
            key, which only ``create_study`` accepts, is dropped. The caller's
            mapping is not modified.

    Returns:
        The loaded ``optuna.Study``.

    Raises:
        KeyError: If ``storage`` is missing from ``study_kwargs``.
        AssertionError: If ``storage`` is ``None``. This is raised explicitly,
            not with ``assert``, so it still fires under ``python -O``.
    """
    study_args = study_kwargs.copy()
    if study_args["storage"] is None:
        raise AssertionError(
            "Cannot load study: 'storage' is None (an in-memory study cannot be reloaded)"
        )
    study_args.pop("load_if_exists", None)
    return optuna.load_study(**study_args)
def select_objective_values(evaluation_results, objectives_spec):
    """Extract the objective value(s) to report to Optuna.

    Args:
        evaluation_results: Mapping of metric name to metric value.
        objectives_spec: Mapping whose keys name the objectives, in order.
            Only the keys are used here.

    Returns:
        The bare value when there is exactly one objective. Otherwise a list
        of values in ``objectives_spec`` key order (empty for no objectives).

    Raises:
        KeyError: For the first objective, in spec order, that is absent from
            ``evaluation_results``.
    """
    names = list(objectives_spec)
    for name in names:
        if name not in evaluation_results:
            raise KeyError(f"Metric {name} not found")
    if len(names) == 1:
        return evaluation_results[names[0]]
    return [evaluation_results[name] for name in names]
def optimize_cfg(cfg, trial_cfg):
    """Return a copy of ``cfg`` with Optuna placeholder strings replaced by suggestions.

    The config is passed through ``resolve_cfg`` and deep-copied, so the
    caller's object is untouched. At each level, every value is handled as
    follows:

    * a string starting with ``OPT_VAL_PREFIX`` becomes
      ``trial_cfg.suggest_param(<remainder>)``;
    * a string starting with ``OPT_DICT_PREFIX`` becomes
      ``trial_cfg.suggest_dict_params(<remainder>)``;
    * a ``dict`` is processed recursively;
    * a ``list`` is rebuilt with its ``dict`` items processed recursively.
      Other list items, including placeholder strings, are kept as-is.

    NOTE(review): ``trial_cfg`` presumably wraps an Optuna trial; only the two
    ``suggest_*`` methods above are used here.
    """
    config = copy.deepcopy(resolve_cfg(cfg))

    # Compute every substitution first, then write them back. Each key is
    # computed only from its own original value, so the result and the order
    # of ``trial_cfg`` calls are the same as assigning in place.
    replacements = {}
    for field, value in config.items():
        if isinstance(value, str):
            if value.startswith(OPT_VAL_PREFIX):
                replacements[field] = trial_cfg.suggest_param(value.removeprefix(OPT_VAL_PREFIX))
            elif value.startswith(OPT_DICT_PREFIX):
                replacements[field] = trial_cfg.suggest_dict_params(value.removeprefix(OPT_DICT_PREFIX))
        elif isinstance(value, dict):
            replacements[field] = optimize_cfg(value, trial_cfg)
        elif isinstance(value, list):
            replacements[field] = [
                optimize_cfg(item, trial_cfg) if isinstance(item, dict) else item
                for item in value
            ]

    for field, value in replacements.items():
        config[field] = value
    return config
[docs] def check_pruning_thresholds(metrics, pruning_cfg, key, trial_number=None, seed=None): level_cfg = pruning_cfg.get(key, {}) if not level_cfg.get("enabled", False): return threshold_cfg = level_cfg.get("metrics", {}) for name, config in threshold_cfg.items(): if name not in metrics: continue val = metrics[name] if config.get("check_nonfinite", True) and (not isinstance(val, (int, float)) or math.isnan(val) or math.isinf(val)): raise optuna.exceptions.TrialPruned(f"{name} is non-finite") if config.get("min") is not None and val < config["min"]: raise optuna.exceptions.TrialPruned(f"{name} < min") if config.get("max") is not None and val > config["max"]: raise optuna.exceptions.TrialPruned(f"{name} > max")
def sanitize_name(name):
    """Return ``name`` with every ``/`` replaced by ``-``.

    Example: ``"train/loss"`` becomes ``"train-loss"``.
    """
    return "-".join(name.split("/"))
def prepare_parameter_data(optuna_params_cfg, param_df):
    """Classify tuned parameters and normalise their columns in ``param_df``.

    A parameter is considered only if its name is a column of ``param_df``,
    its config is a ``dict``/``DictConfig``, and the config has a truthy
    ``type``.

    Side effects on ``param_df`` (in place):
      * a column whose name changes under ``sanitize_name`` is renamed;
      * NaN in categorical columns is replaced with the string ``"None"``.

    Args:
        optuna_params_cfg: Mapping of parameter name to its search-space config.
        param_df: DataFrame with one column per parameter.

    Returns:
        ``(continuous, categorical)``. ``continuous`` maps sanitized name to
        its type (``"float"`` or ``"logarithm"``). ``categorical`` lists the
        sanitized names of ``"categorical"`` parameters. Other types are ignored.
    """
    continuous = {}
    categorical = []
    for param_name, param_cfg in optuna_params_cfg.items():
        is_usable = param_name in param_df.columns and isinstance(param_cfg, (DictConfig, dict))
        if not is_usable:
            continue
        kind = param_cfg.get("type")
        if not kind:
            continue

        column = sanitize_name(param_name)
        # Rename immediately, so later membership checks see the new column name.
        if column != param_name:
            param_df.rename(columns={param_name: column}, inplace=True)

        if kind in ("float", "logarithm"):
            continuous[column] = kind
        elif kind == "categorical":
            categorical.append(column)

    for column in categorical:
        param_df[column] = param_df[column].fillna("None")
    return continuous, categorical
def sanitize_metric_names(metric_names):
    """Sanitize metric names and make them unique.

    Each name is converted with ``str`` and then passed through
    ``sanitize_name``. Repeats get a numeric suffix (``loss``, ``loss_1``,
    ``loss_2``, ...). A suffix is skipped if that name is already taken, so
    the result never contains duplicates. For example, ``["a", "a", "a_1"]``
    becomes ``["a", "a_1", "a_1_1"]`` rather than repeating ``a_1``.

    Args:
        metric_names: Iterable of metric names; ``str`` is applied to each.

    Returns:
        List of unique sanitized names, in the same order and of the same
        length as the input.
    """
    cleaned = []
    taken = set()
    last_suffix = {}
    for raw in metric_names:
        base = sanitize_name(str(raw))
        name = base
        if name in taken:
            # Resume from this base's last suffix and advance past taken names.
            suffix = last_suffix.get(base, 0)
            while name in taken:
                suffix += 1
                name = f"{base}_{suffix}"
            last_suffix[base] = suffix
        taken.add(name)
        cleaned.append(name)
    return cleaned
def sanitize_metric_dataframe(metric_df, metric_names):
    """Rename metric columns to their sanitized, de-duplicated names.

    Args:
        metric_df: DataFrame whose columns may include entries of ``metric_names``.
        metric_names: Original metric names, in order.

    Returns:
        ``(renamed_df, sanitized_names)``. ``renamed_df`` is the result of
        ``DataFrame.rename`` (not in place). Only names present as columns are
        renamed. ``sanitized_names`` has one entry per input name.
    """
    new_names = sanitize_metric_names(metric_names)
    rename_map = {}
    for old, new in zip(metric_names, new_names):
        if old in metric_df.columns:
            rename_map[old] = new
    return metric_df.rename(columns=rename_map), new_names