Source code for gunz_ml.schemas.mlflow_cfg

import mlflow
from omegaconf.errors import ConfigKeyError
from omegaconf import DictConfig, OmegaConf
from ..consts import OptunaType
from ._base_cfg import BaseConfig

class MLFlowConfig(BaseConfig):
    """
    Configuration wrapper for MLFlow runs.

    This class handles parameter suggestion based on an MLFlow run's
    recorded parameters: parameters whose configuration is an Optuna-style
    search-space mapping are read back from the params logged on ``exp``,
    while any other configuration value is used as-is.

    Parameters
    ----------
    exp : mlflow.entities.Run
        The MLFlow run object containing experiment data.
    params_cfg : dict
        Configuration specifying parameter types and options.
    """

    def __init__(
        self,
        exp: mlflow.entities.Run,
        params_cfg: dict,
    ):
        self.exp = exp
        self.params_cfg = params_cfg
        # Every value returned by ``suggest_param``, keyed by parameter name.
        self.hparams = {}

    def suggest_param(self, key: str):
        """
        Suggest a value for a parameter based on the MLFlow run data.

        The chosen value is also recorded in ``self.hparams[key]``.

        Parameters
        ----------
        key : str
            The key identifying the parameter to suggest.

        Returns
        -------
        Any
            The suggested parameter value. For an Optuna-style entry this is
            the value logged on the MLFlow run, converted according to the
            entry's ``type`` (``int`` for INT, ``float`` for FLOAT/LOG, the
            raw logged string for CATEGORICAL). For any other entry it is the
            configured value itself.

        Raises
        ------
        KeyError
            If the key is not found in the parameter configuration, or the
            parameter is Optuna-style but was not logged on the MLFlow run.
        RuntimeError
            If an unknown parameter type is encountered.

        Notes
        -----
        An unrecognised ``type`` string presumably makes ``OptunaType(...)``
        raise ``ValueError`` first (standard ``Enum`` behaviour) -- confirm
        against ``OptunaType``'s definition.
        """
        try:
            vals = self.params_cfg[key]
        except (ConfigKeyError, KeyError) as err:
            raise KeyError(f"Key not found:{key}") from err

        # A mapping is an Optuna search-space spec, so the concrete value
        # must come from the MLFlow run. DictConfig is the OmegaConf form; a
        # plain dict is accepted for a resolved Hydra config.
        if isinstance(vals, (dict, DictConfig)):
            val_type = OptunaType(vals['type'])
            try:
                raw = self.exp.data.params[key]
            except KeyError as err:
                raise KeyError(
                    f"Parameter '{key}' was not logged on the MLFlow run"
                ) from err

            if val_type == OptunaType.INT:
                val = int(raw)
            elif val_type in (OptunaType.FLOAT, OptunaType.LOG):
                val = float(raw)
            elif val_type == OptunaType.CATEGORICAL:
                # Categorical choices may be of any type, so no conversion is
                # attempted. NOTE(review): MLFlow logs params as strings, so
                # non-string choices (e.g. ints) come back as their str() form.
                val = raw
            else:
                raise RuntimeError(
                    "Should never be reached because the data type is already checked."
                )
        else:
            # Non-mapping entries are fixed values: use them directly.
            val = vals

        self.hparams[key] = val
        # Return the resolved value (not the raw config entry ``vals``), so
        # the result matches what was recorded in ``self.hparams``.
        return val