# Source code for gunz_ml.schemas.optuna_cfg

"""
Provides configuration handling specifically for Optuna trials.

This module defines the `OptunaConfig` class, which wraps an Optuna Trial
object to suggest hyperparameters based on a provided configuration dictionary.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.2.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import copy
import typing as t
from enum import StrEnum

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import optuna
from optuna.study import StudyDirection
from omegaconf import DictConfig, OmegaConf
from pydantic import validate_call

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ._base_cfg import BaseConfig
from ..consts import OptunaType

[docs] class ObjectiveDirection(StrEnum): """ Enumeration for optimization directions in Optuna. Attributes ---------- MAXIMIZE : str Direction for maximizing the objective. MINIMIZE : str Direction for minimizing the objective. """ MAXIMIZE = "maximize" MINIMIZE = "minimize" @classmethod def _missing_(cls, value): """ Handles conversion of legacy or non-standard string values. Parameters ---------- value : any The value to convert to an ObjectiveDirection. Returns ------- ObjectiveDirection The corresponding enum member. Raises ------ ValueError If the value cannot be converted. """ # Handle conversion from StudyDirection enum if isinstance(value, StudyDirection): if value == StudyDirection.MAXIMIZE: return cls.MAXIMIZE elif value == StudyDirection.MINIMIZE: return cls.MINIMIZE else: raise ValueError(f"Cannot convert {value} to {cls.__name__}") if not isinstance(value, str): return super()._missing_(value) value_lower = value.lower() #? Handle general case-insensitive matching. for member in cls: if member.value == value_lower: return member #? If no match is found, raise an error. valid_options = ", ".join(m.value for m in cls) raise ValueError(f"'{value}' is not a valid {cls.__name__}. Please use one of: {valid_options}")
@validate_call(config=dict(arbitrary_types_allowed=True))
def optuna_suggest(
    trial: optuna.Trial,
    type_tag: OptunaType,
    name: str,
    low: t.Optional[t.Union[int, float]] = None,
    high: t.Optional[t.Union[int, float]] = None,
    step: t.Optional[t.Union[int, float]] = None,
    vals: t.Optional[t.Sequence[t.Any]] = None,
) -> t.Any:
    """
    Dispatches to the appropriate Optuna `suggest_*` method.

    Parameters
    ----------
    trial : optuna.Trial
        The Optuna trial object to suggest a value from.
    type_tag : OptunaType
        An enum value that determines which `suggest_*` method to call.
    name : str
        The name of the hyperparameter to suggest.
    low : int or float, optional
        The lower bound for numerical suggestions. Required for the INT,
        FLOAT and LOGARITHM types. Defaults to None.
    high : int or float, optional
        The upper bound for numerical suggestions. Required for the INT,
        FLOAT and LOGARITHM types. Defaults to None.
    step : int or float, optional
        The step size for numerical suggestions. For INT it falls back to 1
        when omitted. For LOGARITHM it is forwarded unchanged; Optuna's
        documentation states that `step` and `log` cannot be combined, so a
        non-None value is left for Optuna to reject. Defaults to None.
    vals : sequence, optional
        A sequence of choices for categorical suggestions. Defaults to None.

    Returns
    -------
    any
        The value suggested by the Optuna trial.

    Raises
    ------
    ValueError
        If `low` or `high` is missing for a numerical type, if `vals` is
        missing for a categorical type, or if `type_tag` is not supported.
    """
    #? Numerical suggestions need both bounds. Fail early with a clear message
    #? instead of letting `None` reach Optuna's distribution code.
    numeric_types = (OptunaType.INT, OptunaType.FLOAT, OptunaType.LOGARITHM)
    if type_tag in numeric_types and (low is None or high is None):
        raise ValueError(
            f"`low` and `high` must be provided for {type_tag} suggestions "
            f"(parameter '{name}')."
        )

    if type_tag == OptunaType.INT:
        #? Default step to 1 for integers if not specified.
        if step is None:
            step = 1
        return trial.suggest_int(name, low, high, step=step)

    if type_tag == OptunaType.FLOAT:
        return trial.suggest_float(name, low, high, step=step)

    if type_tag == OptunaType.LOGARITHM:
        #? Same as FLOAT, but sampled in the log domain.
        return trial.suggest_float(name, low, high, step=step, log=True)

    if type_tag == OptunaType.CATEGORICAL:
        if vals is None:
            raise ValueError("`vals` must be provided for categorical suggestions.")
        return trial.suggest_categorical(name, vals)

    raise ValueError(f"Unsupported OptunaType: {type_tag}")
class OptunaConfig(BaseConfig):
    """
    A wrapper for an Optuna Trial object to streamline hyperparameter suggestion.

    This class implements the `BaseConfig` interface to provide a consistent
    way of accessing hyperparameters during an Optuna optimization study.

    Attributes
    ----------
    trial : optuna.Trial
        The active Optuna trial object.
    params_cfg : dict
        The configuration dictionary defining the search space.
    hparams : dict
        A dictionary that logs the actual suggested hyperparameters for the trial.
    """

    def __init__(
        self,
        trial: optuna.Trial,
        params_cfg: dict,
    ):
        """
        Initializes the OptunaConfig instance.

        Parameters
        ----------
        trial : optuna.Trial
            The Optuna trial object for the current run.
        params_cfg : dict
            A dictionary defining the hyperparameter search space. Each entry
            is either a constant value or a mapping (plain dict or
            `DictConfig`) describing an Optuna suggestion.
        """
        self.trial = trial
        self.params_cfg = params_cfg
        #? Filled by `suggest_param` with every value handed out for this trial.
        self.hparams: t.Dict[str, t.Any] = {}

    @validate_call(config=dict(arbitrary_types_allowed=True))
    def suggest_param(
        self,
        name: str,
        def_val: t.Any = None,
        ena_def_val: bool = False,
    ) -> t.Any:
        """
        Suggests a value for a hyperparameter based on the configuration.

        This method reads the configuration for the given parameter `name`.
        A mapping entry is treated as an Optuna suggestion: its `type` key
        selects the suggestion kind and `low`, `high`, `step` and `vals`
        (or its alias `choices`) are forwarded to `optuna_suggest`; any other
        keys are ignored. A non-mapping entry is returned as a constant.
        If the parameter is not found and `ena_def_val` is True, `def_val`
        is returned instead. Every returned value is recorded in `hparams`.

        Parameters
        ----------
        name : str
            The name of the parameter in the configuration.
        def_val : any, optional
            The default value to return if the parameter is not found.
            Defaults to None.
        ena_def_val : bool, optional
            If True, enables the default value mechanism. Defaults to False.

        Returns
        -------
        any
            The suggested, fixed, or default value for the parameter.

        Raises
        ------
        KeyError
            If `name` is not configured and `ena_def_val` is False.
        ValueError
            If the parameter's mapping has no `type` key, or if the `type`
            value or the suggestion arguments are rejected downstream.
        """
        try:
            param_config = self.params_cfg[name]
        except KeyError:
            if not ena_def_val:
                #? The lookup failure is fully described by this message, so the
                #? original exception is suppressed to keep the traceback clean.
                raise KeyError(
                    f'There is no parameter "{name}" in the Optuna parameters config.'
                ) from None
            #? If default value is enabled, log and return it.
            self.hparams[name] = def_val
            return def_val

        #? Anything that is not a mapping is a constant value.
        if not isinstance(param_config, (dict, DictConfig)):
            self.hparams[name] = param_config
            return param_config

        #? Make a copy to avoid modifying the original config.
        kwargs = dict(param_config)

        #? Only the key lookup is guarded; an invalid type value keeps raising
        #? whatever `OptunaType` raises for it.
        try:
            raw_type = kwargs.pop('type')
        except KeyError:
            raise ValueError(
                f"The configuration for parameter '{name}' is missing a 'type' key."
            ) from None
        val_type = OptunaType(raw_type)

        #? --- Alias Support: 'choices' -> 'vals' ---
        if "choices" in kwargs and "vals" not in kwargs:
            kwargs["vals"] = kwargs.pop("choices")

        #? --- Safety: Filter kwargs to only allowed arguments ---
        #? This prevents Pydantic validation errors if the config contains extra metadata.
        allowed_args = {"low", "high", "step", "vals"}
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}

        val = optuna_suggest(self.trial, val_type, name=name, **filtered_kwargs)

        #? Log the suggested hyperparameter in a dictionary for easy access.
        self.hparams[name] = val
        return val
def _to_plain_container(node: t.Any) -> t.Any:
    """Returns a fresh plain-Python copy of `node`, converting any OmegaConf configs found in it."""
    if OmegaConf.is_config(node):
        #? Let OmegaConf build the plain container and resolve interpolations.
        return OmegaConf.to_container(node, resolve=True)
    if isinstance(node, dict):
        return {key: _to_plain_container(val) for key, val in node.items()}
    if isinstance(node, list):
        return [_to_plain_container(item) for item in node]
    if isinstance(node, tuple):
        return tuple(_to_plain_container(item) for item in node)
    #? Leaf value: copy it so the result shares nothing mutable with the input.
    return copy.deepcopy(node)


@validate_call(config=dict(arbitrary_types_allowed=True))
def optimize_preproc_pipeline(
    opt_cfg: OptunaConfig,
    pipeline_configs: t.Dict,
) -> t.Dict:
    """
    Optimizes hyperparameters within a preprocessing pipeline configuration.

    This function searches for a special `_OPTIMIZE_` tag in the pipeline
    configuration and replaces it with a value suggested by Optuna. The
    configuration is expected to be nested three levels deep:
    preprocessor name -> argument group -> parameter name -> value.

    Parameters
    ----------
    opt_cfg : OptunaConfig
        The OptunaConfig instance for suggesting parameters.
    pipeline_configs : dict
        The preprocessing pipeline configuration. It may be a plain dict, an
        OmegaConf config, or a plain dict holding OmegaConf configs.

    Returns
    -------
    dict
        A new dictionary of plain Python containers with the `_OPTIMIZE_`
        tags replaced by suggested values. The input is not modified.

    Raises
    ------
    KeyError
        Propagated from `opt_cfg.suggest_param` when a tagged parameter has
        no `preproc-<preprocessor>-<parameter>` entry in the Optuna config.
    """
    #? Work on an independent plain copy. `OmegaConf.to_container` alone
    #? rejects inputs that are not OmegaConf configs, so plain dicts are
    #? handled explicitly by the helper.
    processed_configs = _to_plain_container(pipeline_configs)

    for preproc_name, preproc_details in processed_configs.items():
        #? Preprocessors without a mapping of argument groups have nothing to tune.
        if not isinstance(preproc_details, dict):
            continue
        for arg_values in preproc_details.values():
            #? Empty (None) or non-mapping argument groups carry no tags.
            if not isinstance(arg_values, dict):
                continue
            for param_name, param_value in arg_values.items():
                if param_value == '_OPTIMIZE_':
                    #? Construct the full key name expected in the Optuna config.
                    optuna_key = f"preproc-{preproc_name}-{param_name}"
                    #? Overwriting an existing key is safe while iterating.
                    arg_values[param_name] = opt_cfg.suggest_param(optuna_key)

    return processed_configs