"""
Configuration factory for initializing the correct trial configuration wrapper.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.1.2"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import optuna
import mlflow
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ._base_cfg import BaseConfig
from .run_cfg import RunConfig
from .mlflow_cfg import MLFlowConfig
from .optuna_cfg import OptunaConfig
# =============================================================================
# FACTORY FUNCTION
# =============================================================================
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def init_trial_cfg(
    trial: t.Union[
        optuna.trial.Trial,
        mlflow.entities.Run,
        t.Dict[str, t.Any],
        None,
    ],
    exp_params: t.Dict[str, t.Any],
) -> BaseConfig:
    """
    Initializes the configuration wrapper that matches the type of `trial`.

    This function acts as a factory: it inspects the `trial` object and
    returns a `RunConfig`, an `MLFlowConfig`, or an `OptunaConfig`.

    Parameters
    ----------
    trial : optuna.trial.Trial, mlflow.entities.Run, dict, or None
        Selects which wrapper is built.

        - `None` or a `dict`: the function configures for a single,
          non-optimized run. The value of `trial` itself is not forwarded;
          only `exp_params` is passed to `RunConfig`.
        - `mlflow.entities.Run`: the run and `exp_params` are passed to
          `MLFlowConfig`.
        - `optuna.trial.Trial`: the function configures for an optimization
          trial; the trial and `exp_params` are passed to `OptunaConfig`.
    exp_params : dict
        A dictionary of parameters.

        - For an optimization, this should be the hyperparameter search space.
        - For a single run, this should be the fixed parameter values.

    Returns
    -------
    BaseConfig
        An initialized instance of `RunConfig`, `MLFlowConfig`, or
        `OptunaConfig`.

    Raises
    ------
    ValueError
        If `trial` is none of the supported types listed above.
    """
    #? Single, non-optimized run: `exp_params` already holds the fixed values,
    #? so the `trial` argument (None or a dict) is intentionally not forwarded.
    if trial is None or isinstance(trial, dict):
        return RunConfig(
            params_cfg=exp_params,
        )

    #? `mlflow.entities.Run` is the public name of `mlflow.entities.run.Run`.
    if isinstance(trial, mlflow.entities.Run):
        return MLFlowConfig(trial, exp_params)

    #? Use the public `optuna.trial.Trial` (same class as the private
    #? `optuna.trial._trial.Trial`) so the check matches the signature.
    if isinstance(trial, optuna.trial.Trial):
        return OptunaConfig(
            trial,
            exp_params,
        )

    raise ValueError(f"Unsupported type for trial object: {type(trial)}")