import mlflow
from omegaconf.errors import ConfigKeyError
from omegaconf import DictConfig, OmegaConf
from ..consts import OptunaType
from ._base_cfg import BaseConfig
class MLFlowConfig(BaseConfig):
    """
    Configuration wrapper for MLFlow runs.

    This class handles parameter suggestion based on an MLFlow run's recorded
    parameters: rather than sampling a new value, ``suggest_param`` replays the
    value that was logged under the same key in ``exp``.

    Parameters
    ----------
    exp : mlflow.entities.Run
        The MLFlow run object containing experiment data.
    params_cfg : dict
        Configuration specifying parameter types and options. An OmegaConf
        ``DictConfig`` is also accepted. Each entry is either a fixed value or
        a mapping whose ``'type'`` key is understood by ``OptunaType``.
    """

    def __init__(
        self,
        exp: mlflow.entities.Run,
        params_cfg: dict,
    ):
        self.exp = exp
        self.params_cfg = params_cfg
        # Every value resolved by ``suggest_param`` is also recorded here,
        # keyed by parameter name.
        self.hparams = {}

    def suggest_param(self, key: str):
        """
        Suggest a value for a parameter based on the MLFlow run data.

        Parameters
        ----------
        key : str
            The key identifying the parameter to suggest.

        Returns
        -------
        Any
            The suggested parameter value. If the config entry for ``key`` is
            an Optuna-style spec (a mapping with a ``'type'`` key), this is the
            value recorded in the run, converted according to that type (e.g.
            int, float, str). Otherwise it is the configured value itself.
            The same value is stored in ``self.hparams[key]``.

        Raises
        ------
        KeyError
            If the key is not found in the parameter configuration. A
            ``KeyError`` is also propagated if looking ``key`` up in
            ``exp.data.params`` fails.
        RuntimeError
            If an unknown parameter type is encountered.
        """
        try:
            vals = self.params_cfg[key]
        except (ConfigKeyError, KeyError) as err:
            # DictConfig raises ConfigKeyError, a plain dict raises KeyError;
            # normalise both to KeyError while keeping the original cause.
            raise KeyError(f"Key not found:{key}") from err

        # If ``vals`` is an Optuna-style search spec, retrieve the value from
        # the MLFlow run. A plain ``dict`` is accepted as well because a
        # resolved Hydra config is no longer a ``DictConfig``.
        if isinstance(vals, (dict, DictConfig)):
            val_type = OptunaType(vals['type'])
            # MLflow records params as strings, so numeric types are converted
            # back explicitly below.
            val = self.exp.data.params[key]
            if val_type == OptunaType.INT:
                val = int(val)
            elif val_type in (OptunaType.FLOAT, OptunaType.LOG):
                val = float(val)
            elif val_type == OptunaType.CATEGORICAL:
                # Categorical choices may be of any type; pass the recorded
                # value through unchanged.
                pass
            else:
                raise RuntimeError(
                    "Should never be reached because the data type is already checked."
                )
        else:
            # Fixed (non-searched) value: use the configured value directly.
            val = vals

        self.hparams[key] = val
        # Return the resolved value, not the raw config entry ``vals``.
        return val