"""
Provides configuration handling specifically for Optuna trials.
This module defines the `OptunaConfig` class, which wraps an Optuna Trial
object to suggest hyperparameters based on a provided configuration dictionary.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.2.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import copy
import typing as t
from enum import StrEnum
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import optuna
from optuna.study import StudyDirection
from omegaconf import DictConfig, OmegaConf
from pydantic import validate_call
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ._base_cfg import BaseConfig
from ..consts import OptunaType
class ObjectiveDirection(StrEnum):
"""
Enumeration for optimization directions in Optuna.
Attributes
----------
MAXIMIZE : str
Direction for maximizing the objective.
MINIMIZE : str
Direction for minimizing the objective.
"""
MAXIMIZE = "maximize"
MINIMIZE = "minimize"
@classmethod
def _missing_(cls, value):
"""
Handles conversion of legacy or non-standard string values.
Parameters
----------
value : any
The value to convert to an ObjectiveDirection.
Returns
-------
ObjectiveDirection
The corresponding enum member.
Raises
------
ValueError
If the value cannot be converted.
"""
# Handle conversion from StudyDirection enum
if isinstance(value, StudyDirection):
if value == StudyDirection.MAXIMIZE:
return cls.MAXIMIZE
elif value == StudyDirection.MINIMIZE:
return cls.MINIMIZE
else:
raise ValueError(f"Cannot convert {value} to {cls.__name__}")
if not isinstance(value, str):
return super()._missing_(value)
value_lower = value.lower()
#? Handle general case-insensitive matching.
for member in cls:
if member.value == value_lower:
return member
#? If no match is found, raise an error.
valid_options = ", ".join(m.value for m in cls)
raise ValueError(f"'{value}' is not a valid {cls.__name__}. Please use one of: {valid_options}")
@validate_call(config=dict(arbitrary_types_allowed=True))
def optuna_suggest(
    trial: optuna.Trial,
    type_tag: OptunaType,
    name: str,
    low: t.Optional[t.Union[int, float]] = None,
    high: t.Optional[t.Union[int, float]] = None,
    step: t.Optional[t.Union[int, float]] = None,
    vals: t.Optional[t.Sequence[t.Any]] = None,
) -> t.Any:
    """
    Dispatches to the appropriate Optuna `suggest_*` method.

    Parameters
    ----------
    trial : optuna.Trial
        The Optuna trial object to suggest a value from.
    type_tag : OptunaType
        An enum value that determines which `suggest_*` method to call.
    name : str
        The name of the hyperparameter to suggest.
    low : int or float, optional
        The lower bound; required for INT, FLOAT and LOGARITHM.
    high : int or float, optional
        The upper bound; required for INT, FLOAT and LOGARITHM.
    step : int or float, optional
        The step size for numerical suggestions. Defaults to 1 for INT and
        to None (no discretization) otherwise.
    vals : sequence, optional
        The choices; required for CATEGORICAL.

    Returns
    -------
    any
        The value suggested by the Optuna trial.

    Raises
    ------
    ValueError
        If `type_tag` is not supported, if `vals` is missing for a
        categorical suggestion, or if `low`/`high` is missing for a
        numerical suggestion.
    """
    if type_tag == OptunaType.CATEGORICAL:
        if vals is None:
            raise ValueError("`vals` must be provided for categorical suggestions.")
        return trial.suggest_categorical(name, vals)

    if type_tag not in (OptunaType.INT, OptunaType.FLOAT, OptunaType.LOGARITHM):
        raise ValueError(f"Unsupported OptunaType: {type_tag}")

    #? Mirror the categorical check: fail early with a message that names the
    #? parameter instead of letting `None` bounds reach Optuna.
    if low is None or high is None:
        raise ValueError(
            f"`low` and `high` must be provided for numerical suggestions "
            f"(parameter '{name}', type {type_tag})."
        )

    if type_tag == OptunaType.INT:
        #? Default step to 1 for integers if not specified.
        return trial.suggest_int(name, low, high, step=1 if step is None else step)

    if type_tag == OptunaType.LOGARITHM:
        #? Optuna documents that `step` and `log` cannot be combined; a
        #? configured step is still passed through so Optuna reports it.
        return trial.suggest_float(name, low, high, step=step, log=True)

    #? Remaining case: OptunaType.FLOAT.
    return trial.suggest_float(name, low, high, step=step)
class OptunaConfig(BaseConfig):
    """
    A wrapper for an Optuna Trial object to streamline hyperparameter suggestion.

    Each entry of `params_cfg` is either a mapping describing a search space
    (it must contain a `type` key) or any other value, which is treated as a
    constant. Every value handed out is recorded in `hparams`.

    Attributes
    ----------
    trial : optuna.Trial
        The active Optuna trial object.
    params_cfg : dict
        The configuration dictionary defining the search space.
    hparams : dict
        A dictionary that logs the actual value returned for each parameter
        (suggested, constant, or default).
    """

    def __init__(
        self,
        trial: optuna.Trial,
        params_cfg: dict,
    ):
        """
        Initializes the OptunaConfig instance.

        Parameters
        ----------
        trial : optuna.Trial
            The Optuna trial object for the current run.
        params_cfg : dict
            A dictionary defining the hyperparameter search space.
        """
        self.trial = trial
        self.params_cfg = params_cfg
        self.hparams: t.Dict[str, t.Any] = {}

    @validate_call(config=dict(arbitrary_types_allowed=True))
    def suggest_param(
        self,
        name: str,
        def_val: t.Any = None,
        ena_def_val: bool = False,
    ) -> t.Any:
        """
        Suggests a value for a hyperparameter based on the configuration.

        Parameters
        ----------
        name : str
            The name of the parameter in the configuration.
        def_val : any, optional
            The default value to return if the parameter is not found.
            Defaults to None.
        ena_def_val : bool, optional
            If True, enables the default value mechanism. Defaults to False.

        Returns
        -------
        any
            The suggested, fixed, or default value for the parameter.

        Raises
        ------
        KeyError
            If `name` is not configured and `ena_def_val` is False.
        ValueError
            If the entry is a mapping without a `type` key.
        """
        try:
            param_config = self.params_cfg[name]
        except KeyError as err:
            if not ena_def_val:
                #? Chain explicitly so the original lookup failure stays visible.
                raise KeyError(f'There is no parameter "{name}" in the Optuna parameters config.') from err
            #? If default value is enabled, log and return it.
            self.hparams[name] = def_val
            return def_val

        #? Anything that is not a mapping is a constant value.
        if not isinstance(param_config, (dict, DictConfig)):
            self.hparams[name] = param_config
            return param_config

        #? Make a shallow copy to avoid modifying the original config.
        kwargs = dict(param_config)

        #? Check for the key explicitly so that only a genuinely absent
        #? 'type' is reported as missing; errors raised while converting the
        #? value to OptunaType propagate unchanged.
        if "type" not in kwargs:
            raise ValueError(f"The configuration for parameter '{name}' is missing a 'type' key.")
        val_type = OptunaType(kwargs.pop("type"))

        #? --- Alias Support: 'choices' -> 'vals' ---
        if "choices" in kwargs and "vals" not in kwargs:
            kwargs["vals"] = kwargs.pop("choices")

        #? --- Safety: Filter kwargs to only allowed arguments ---
        #? This prevents Pydantic validation errors if the config contains extra metadata.
        allowed_args = {"low", "high", "step", "vals"}
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}

        val = optuna_suggest(self.trial, val_type, name=name, **filtered_kwargs)

        #? Log the suggested hyperparameter in a dictionary for easy access.
        self.hparams[name] = val
        return val
def _to_plain_container(node: t.Any) -> t.Any:
    """Return an independent copy of `node` built from plain Python containers."""
    if OmegaConf.is_config(node):
        #? `to_container` only accepts OmegaConf configs; it returns fresh
        #? dicts/lists with interpolations resolved.
        return OmegaConf.to_container(node, resolve=True)
    if isinstance(node, dict):
        return {key: _to_plain_container(val) for key, val in node.items()}
    if isinstance(node, list):
        return [_to_plain_container(item) for item in node]
    return copy.deepcopy(node)


@validate_call(config=dict(arbitrary_types_allowed=True))
def optimize_preproc_pipeline(
    opt_cfg: OptunaConfig,
    pipeline_configs: t.Dict,
) -> t.Dict:
    """
    Optimizes hyperparameters within a preprocessing pipeline configuration.

    The configuration is expected to be nested as
    `{preproc_name: {arg_group: {param_name: value}}}`. Every `value` equal
    to the string `_OPTIMIZE_` is replaced by
    `opt_cfg.suggest_param("preproc-<preproc_name>-<param_name>")`; the
    argument group name is not part of the key.

    Parameters
    ----------
    opt_cfg : OptunaConfig
        The OptunaConfig instance for suggesting parameters.
    pipeline_configs : dict
        The preprocessing pipeline configuration. Plain dictionaries,
        OmegaConf configs, and plain dictionaries containing OmegaConf nodes
        are all accepted.

    Returns
    -------
    dict
        A new plain dictionary with the `_OPTIMIZE_` tags replaced by
        suggested values. The input object is not modified.
    """
    #? Work on an independent plain-Python copy. `OmegaConf.to_container`
    #? cannot be called directly because it rejects plain dictionaries.
    processed_configs = _to_plain_container(pipeline_configs)

    for preproc_name, preproc_details in processed_configs.items():
        #? Entries without argument groups (e.g. None) have nothing to optimize.
        if not isinstance(preproc_details, dict):
            continue
        for arg_values in preproc_details.values():
            #? Skip empty (None) or scalar argument groups.
            if not isinstance(arg_values, dict):
                continue
            for param_name, param_value in arg_values.items():
                if param_value == '_OPTIMIZE_':
                    #? Construct the full key name expected in the Optuna config.
                    optuna_key = f"preproc-{preproc_name}-{param_name}"
                    #? Assigning to an existing key is safe while iterating.
                    arg_values[param_name] = opt_cfg.suggest_param(optuna_key)

    return processed_configs