Source code for gunz_ml.schemas

"""
Configuration factory for initializing the correct trial configuration wrapper.
"""
# =============================================================================
# METADATA
# =============================================================================
# Module-level metadata attributes (plain string constants).
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.1.2"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import optuna
import mlflow

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ._base_cfg import BaseConfig
from .run_cfg import RunConfig
from .mlflow_cfg import MLFlowConfig
from .optuna_cfg import OptunaConfig

# =============================================================================
# FACTORY FUNCTION
# =============================================================================
def init_trial_cfg(
    trial: t.Union[
        optuna.trial.Trial,
        mlflow.entities.Run,
        t.Dict[str, t.Any],
        None,
    ],
    exp_params: t.Dict[str, t.Any],
) -> BaseConfig:
    """
    Initializes the appropriate trial configuration wrapper.

    This function acts as a factory: it inspects the type of `trial` and
    returns the configuration wrapper that matches it.

    Parameters
    ----------
    trial : optuna.trial.Trial, mlflow.entities.Run, dict or None
        - `None` or a `dict`: configure for a single, non-optimized run.
          Only the type of `trial` is checked in this case; the object
          itself is not passed on to the returned configuration.
        - `mlflow.entities.Run`: configure from an MLflow run.
        - `optuna.trial.Trial`: configure for an optimization trial.
    exp_params : dict
        A dictionary of parameters, forwarded unchanged to the selected
        configuration wrapper.
        - For an optimization, this should be the hyperparameter search space.
        - For a single run, this should be the fixed parameter values.

    Returns
    -------
    BaseConfig
        An initialized instance of `RunConfig`, `MLFlowConfig` or
        `OptunaConfig`, depending on the type of `trial`.

    Raises
    ------
    ValueError
        If `trial` is not one of the supported types listed above.
    """
    if trial is None or isinstance(trial, dict):
        #? This is a single, non-optimized run. `exp_params` contains the fixed values.
        return RunConfig(
            params_cfg=exp_params,
        )

    if isinstance(trial, mlflow.entities.Run):
        return MLFlowConfig(trial, exp_params)

    #? Use the public `optuna.trial.Trial` path instead of the private
    #? `optuna.trial._trial` module; both names refer to the same class.
    if isinstance(trial, optuna.trial.Trial):
        return OptunaConfig(trial, exp_params)

    raise ValueError(f"Unsupported type for trial object: {type(trial)}")