"""
Provides helper functions for working with Optuna studies and configurations.
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.1.0"
import copy
import typing as t
import math
import logging
import warnings
import pandas as pd
import optuna
from optuna.trial import FrozenTrial, TrialState
from omegaconf import DictConfig, ListConfig, OmegaConf
from .hydra import resolve_cfg
from ..consts import OptunaType
from ..schemas.optuna_cfg import OptunaConfig
# Placeholder prefixes recognised by ``optimize_cfg``: a string config value that
# starts with OPT_VAL_PREFIX is replaced by ``trial_cfg.suggest_param(<rest>)``, and
# one that starts with OPT_DICT_PREFIX by ``trial_cfg.suggest_dict_params(<rest>)``,
# where <rest> is the string with the prefix removed.
OPT_VAL_PREFIX = "#OPTVAL-"
OPT_DICT_PREFIX = "#OPTDICT-"
[docs]
def trials_to_dataframes(trials, metric_names=None, target_state=TrialState.COMPLETE):
    """Split Optuna trials into a metrics DataFrame and a parameters DataFrame.

    Args:
        trials: Iterable of trials (e.g. ``FrozenTrial``) exposing ``state``,
            ``value``/``values``, ``number`` and ``params``.
        metric_names: Names assigned, in order, to each trial's objective values.
            If falsy, the metrics frame carries no metric columns.
        target_state: Keep only trials in this state; ``None`` keeps every trial.

    Returns:
        ``(metrics_df, params_df)``, both indexed by ``trial_number``. When no
        trial matches, both frames are empty but still indexed by
        ``trial_number`` (the metrics frame keeps the metric columns).
    """
    # Materialize once so a one-shot iterator is not consumed by the first trial.
    names = list(metric_names) if metric_names else []
    if target_state is not None:
        trials = [trial for trial in trials if trial.state == target_state]

    metrics_rows = []
    params_rows = []
    for trial in trials:
        try:
            values = [trial.value]
        except RuntimeError:
            # ``FrozenTrial.value`` raises for multi-objective trials; use ``values``.
            values = trial.values
        metrics_row = dict(zip(names, values)) if names else {}
        metrics_row["trial_number"] = trial.number
        metrics_rows.append(metrics_row)

        params_row = trial.params.copy()
        params_row["trial_number"] = trial.number
        params_rows.append(params_row)

    return (
        _indexed_by_trial_number(metrics_rows, names),
        _indexed_by_trial_number(params_rows, []),
    )


def _indexed_by_trial_number(rows, empty_columns):
    """Build a DataFrame from ``rows`` indexed by ``trial_number``, even if ``rows`` is empty."""
    if not rows:
        # pd.DataFrame([]) has no "trial_number" column, so set_index would raise KeyError.
        return pd.DataFrame(columns=[*empty_columns, "trial_number"]).set_index("trial_number")
    return pd.DataFrame(rows).set_index("trial_number")
[docs]
def create_study(study_kwargs, objs_cfg, dry_run=False):
    """Create an Optuna study from (resolvable) configuration objects.

    Args:
        study_kwargs: Keyword arguments for ``optuna.create_study``; passed
            through ``resolve_cfg`` first.
        objs_cfg: Mapping of objective name -> direction (passed through
            ``resolve_cfg``). Its values become ``direction`` when there is one
            objective and ``directions`` when there are several.
        dry_run: If True, ``storage`` is forced to ``None`` (Optuna then uses
            in-memory storage).

    Returns:
        The study returned by ``optuna.create_study``.

    Raises:
        RuntimeError: If either configuration cannot be resolved.
    """
    try:
        study_kwargs = resolve_cfg(study_kwargs)
        objs_cfg = resolve_cfg(objs_cfg)
    except Exception as e:
        raise RuntimeError(f"Failed to resolve configs: {e}") from e

    study_args = study_kwargs.copy()
    if dry_run:
        study_args["storage"] = None

    directions = list(objs_cfg.values())
    if len(directions) > 1:
        study_args["directions"] = directions
    elif len(directions) == 1:
        study_args["direction"] = directions[0]

    storage_url = study_args.get("storage")
    # Prefix match: a substring test would also catch e.g. "postgresql://h/sqlite_db"
    # and append a connect argument that driver does not understand.
    is_sqlite = bool(storage_url) and str(storage_url).startswith("sqlite")
    if is_sqlite and "timeout" not in str(storage_url):
        # pysqlite accepts ``timeout`` (seconds to wait on a locked database) as a
        # URL query argument; presumably added for concurrent workers sharing one file.
        sep = "&" if "?" in str(storage_url) else "?"
        study_args["storage"] = f"{storage_url}{sep}timeout=60"

    study = optuna.create_study(**study_args)

    if is_sqlite:
        _enable_sqlite_wal(study_args["storage"])
    return study


def _enable_sqlite_wal(storage_url):
    """Best-effort switch of the SQLite database at ``storage_url`` to WAL journaling."""
    engine = None
    try:
        from sqlalchemy import create_engine

        engine = create_engine(storage_url)
        with engine.connect() as conn:
            conn.exec_driver_sql("PRAGMA journal_mode=WAL;")
    except Exception as e:
        # Deliberately non-fatal: the study already exists and works without WAL.
        logging.warning("Could not set WAL mode: %s", e)
    finally:
        if engine is not None:
            # Close the pooled connection now instead of leaving it to garbage collection.
            engine.dispose()
[docs]
def load_study(study_kwargs):
    """Load an existing Optuna study described by ``study_kwargs``.

    ``study_kwargs`` is copied (the caller's mapping is not modified).
    ``load_if_exists`` is dropped from the copy, because that is a
    ``create_study`` option, and the rest is forwarded to ``optuna.load_study``.

    Raises:
        AssertionError: ``"Study not found"`` when ``storage`` is ``None``.
    """
    load_kwargs = study_kwargs.copy()
    # NOTE(review): ``assert`` is stripped under ``python -O``; consider an explicit
    # raise if this check must always run.
    assert load_kwargs["storage"] is not None, "Study not found"
    if "load_if_exists" in load_kwargs:
        del load_kwargs["load_if_exists"]
    return optuna.load_study(**load_kwargs)
[docs]
def select_objective_values(evaluation_results, objectives_spec):
    """Pick the objective values to report to Optuna from ``evaluation_results``.

    Args:
        evaluation_results: Mapping of metric name -> value.
        objectives_spec: Mapping whose keys name the objectives, in order; its
            values (directions) are not used here.

    Returns:
        The bare value when there is exactly one objective. Otherwise, a list of
        values in ``objectives_spec`` order (an empty list for no objectives).

    Raises:
        KeyError: If an objective is missing from ``evaluation_results``. Every
            key is checked before any value is read.
    """
    for name in objectives_spec:
        if name not in evaluation_results:
            raise KeyError(f"Metric {name} not found")
    picked = [evaluation_results[name] for name in objectives_spec]
    if len(picked) == 1:
        return picked[0]
    return picked
[docs]
def optimize_cfg(cfg, trial_cfg):
    """Return ``cfg`` with Optuna placeholder strings replaced by trial suggestions.

    ``cfg`` is passed through ``resolve_cfg`` and deep-copied. Substitutions are
    applied to that copy. A placeholder string is substituted wherever it appears
    as a mapping value or as a list element, at any list/dict nesting:

    * ``OPT_VAL_PREFIX + name``  -> ``trial_cfg.suggest_param(name)``
    * ``OPT_DICT_PREFIX + name`` -> ``trial_cfg.suggest_dict_params(name)``

    All other values are kept unchanged.

    Args:
        cfg: Configuration mapping accepted by ``resolve_cfg``.
        trial_cfg: Object providing ``suggest_param`` and ``suggest_dict_params``
            (presumably a trial wrapper -- verify against callers).

    Returns:
        The resolved copy with placeholders substituted.
    """
    optimized = copy.deepcopy(resolve_cfg(cfg))
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for key, val in optimized.items():
        optimized[key] = _substitute_placeholders(val, trial_cfg)
    return optimized


def _substitute_placeholders(value, trial_cfg):
    """Replace a placeholder in ``value``, recursing into nested dicts and lists."""
    if isinstance(value, str):
        if value.startswith(OPT_VAL_PREFIX):
            return trial_cfg.suggest_param(value.removeprefix(OPT_VAL_PREFIX))
        if value.startswith(OPT_DICT_PREFIX):
            return trial_cfg.suggest_dict_params(value.removeprefix(OPT_DICT_PREFIX))
        return value
    if isinstance(value, dict):
        # Route nested dicts through optimize_cfg so they receive the same
        # resolve_cfg + deepcopy treatment as before.
        return optimize_cfg(value, trial_cfg)
    if isinstance(value, list):
        # Previously only dict elements were processed; placeholder strings and
        # nested lists inside lists were passed through unsubstituted.
        return [_substitute_placeholders(item, trial_cfg) for item in value]
    return value
[docs]
def check_pruning_thresholds(metrics, pruning_cfg, key, trial_number=None, seed=None):
    """Prune the current trial when a monitored metric violates its thresholds.

    ``pruning_cfg[key]`` selects a pruning level. If that level is missing, empty
    or not ``enabled``, nothing is checked. Otherwise, each entry of the level's
    ``metrics`` mapping that is also present in ``metrics`` is checked in order:

    * ``check_nonfinite`` (default True): prune if the value is not a real
      number or is NaN/inf. Any ``numbers.Real`` counts, including NumPy scalars
      such as ``np.float32``, which are not ``float`` subclasses.
    * ``min`` / ``max`` (optional): prune if the value is below / above the bound.

    Args:
        metrics: Mapping of metric name -> value.
        pruning_cfg: Mapping of level key -> level config (``None`` disables pruning).
        key: Level to check.
        trial_number: Accepted for interface compatibility; not used here.
        seed: Accepted for interface compatibility; not used here.

    Raises:
        optuna.exceptions.TrialPruned: On the first violated threshold.
    """
    import numbers  # numbers.Real also covers NumPy integer/floating scalar types

    level_cfg = (pruning_cfg or {}).get(key) or {}
    if not level_cfg.get("enabled", False):
        return
    threshold_cfg = level_cfg.get("metrics") or {}
    for name, config in threshold_cfg.items():
        if name not in metrics:
            continue
        config = config or {}
        val = metrics[name]
        if config.get("check_nonfinite", True) and not (
            isinstance(val, numbers.Real) and math.isfinite(val)
        ):
            raise optuna.exceptions.TrialPruned(f"{name} is non-finite")
        lower = config.get("min")
        if lower is not None and val < lower:
            raise optuna.exceptions.TrialPruned(f"{name} < min")
        upper = config.get("max")
        if upper is not None and val > upper:
            raise optuna.exceptions.TrialPruned(f"{name} > max")
[docs]
def sanitize_name(name):
    """Return ``name`` with every ``/`` replaced by ``-``.

    Example:
        >>> sanitize_name("train/loss")
        'train-loss'
    """
    return "-".join(name.split("/"))
[docs]
def prepare_parameter_data(optuna_params_cfg, param_df):
    """Classify sampled parameters and normalise their columns in ``param_df``.

    Only entries of ``optuna_params_cfg`` whose name is a column of ``param_df``,
    whose spec is a mapping, and whose ``type`` is truthy are considered. For each
    of them, the column is renamed in place to ``sanitize_name(name)``; other
    entries are skipped. Then:

    * ``type`` ``"float"`` / ``"logarithm"`` -> recorded in ``continuous``;
    * ``type`` ``"categorical"`` -> recorded in ``categorical``, and missing values
      in that column are filled in place with the string ``"None"``;
    * any other type -> the column is still renamed but not classified.

    Note: ``param_df`` is modified in place.

    Returns:
        ``(continuous, categorical)``: dict of sanitized name -> type, and list of
        sanitized categorical names.
    """
    continuous = {}
    categorical = []
    for raw_name, spec in optuna_params_cfg.items():
        if raw_name not in param_df.columns:
            continue
        if not isinstance(spec, (DictConfig, dict)):
            continue
        param_type = spec.get("type")
        if not param_type:
            continue
        column = sanitize_name(raw_name)
        if column != raw_name:
            param_df.rename(columns={raw_name: column}, inplace=True)
        if param_type in ("float", "logarithm"):
            continuous[column] = param_type
        elif param_type == "categorical":
            categorical.append(column)
    for column in categorical:
        param_df[column] = param_df[column].fillna("None")
    return continuous, categorical
[docs]
def sanitize_metric_names(metric_names):
    """Sanitize metric names and make them unique.

    Each name is converted with ``str`` and passed through ``sanitize_name``. A
    name that was already emitted gets the smallest ``_<k>`` suffix (``k >= 1``)
    not yet taken, so the result never contains duplicates. Example:
    ``["a", "a", "a_1"]`` -> ``["a", "a_1", "a_1_1"]``. The previous counter only
    tracked base names and produced ``["a", "a_1", "a_1"]`` here.

    Args:
        metric_names: Iterable of raw metric names (any objects).

    Returns:
        List of sanitized, unique names in input order.
    """
    cleaned = []
    taken = set()
    next_suffix = {}  # base name -> first suffix worth trying next time
    for raw in metric_names:
        base = sanitize_name(str(raw))
        name = base
        if name in taken:
            suffix = next_suffix.get(base, 1)
            while f"{base}_{suffix}" in taken:
                suffix += 1
            name = f"{base}_{suffix}"
            next_suffix[base] = suffix + 1
        taken.add(name)
        cleaned.append(name)
    return cleaned
[docs]
def sanitize_metric_dataframe(metric_df, metric_names):
    """Rename the metric columns of ``metric_df`` to sanitized, unique names.

    Args:
        metric_df: DataFrame whose columns include (some of) ``metric_names``.
        metric_names: Original metric names in objective order. Any iterable is
            accepted; ``None`` is treated as no metrics.

    Returns:
        ``(renamed_df, sanitized_names)``. ``renamed_df`` comes from
        ``DataFrame.rename`` (the input frame is not modified), and
        ``sanitized_names`` is the output of ``sanitize_metric_names``.
    """
    # Materialize once: the names are iterated twice (sanitizing and zip), and a
    # one-shot iterator would otherwise leave zip() with nothing to rename.
    names = list(metric_names) if metric_names is not None else []
    sanitized = sanitize_metric_names(names)
    rename_map = {old: new for old, new in zip(names, sanitized) if old in metric_df.columns}
    return metric_df.rename(columns=rename_map), sanitized