import os
import typing as t
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from hydra.core.hydra_config import HydraConfig
[docs]
def init_hydra_and_check_config(
    cfg: DictConfig,
    script_name: t.Optional[str] = None,
    allow_unresolved_keys: bool = False,
    create_run_dir: bool = True,
) -> DictConfig:
    """
    Initializes Hydra configuration and performs sanity checks.

    Parameters
    ----------
    cfg : DictConfig
        The configuration object provided by Hydra.
    script_name : str, optional
        The name of the script being run. Accepted for API compatibility;
        this function does not currently use it.
    allow_unresolved_keys : bool, optional
        If False, raises an error if any mandatory ``'???'`` values remain
        in the config. Defaults to False.
    create_run_dir : bool, optional
        If True and a Hydra run is active, ensures the Hydra run/output
        directory (``hydra.runtime.output_dir``) exists. Defaults to True.

    Returns
    -------
    DictConfig
        The same ``cfg`` object, unchanged, once the checks pass.

    Raises
    ------
    RuntimeError
        If missing (``'???'``) keys are found and ``allow_unresolved_keys``
        is False.
    """
    # HydraConfig.get() raises ValueError when no Hydra run has set the
    # global Hydra config (e.g. plain script or unit test); treat that as
    # "no Hydra" instead of swallowing every possible exception.
    try:
        hydra_cfg = HydraConfig.get()
    except ValueError:
        hydra_cfg = None

    if create_run_dir and hydra_cfg is not None:
        # select() yields None instead of raising when the key is absent or
        # still set to '???', so we only create a directory we actually know.
        run_dir = OmegaConf.select(hydra_cfg, "runtime.output_dir", default=None)
        if run_dir:
            os.makedirs(run_dir, exist_ok=True)

    # Reject mandatory values ('???') that were never filled in.
    if not allow_unresolved_keys:
        missing_keys = OmegaConf.missing_keys(cfg)
        if missing_keys:
            # Sorted so the error message is deterministic across runs.
            raise RuntimeError(
                f"The following keys are missing from the config: {sorted(missing_keys)}"
            )
    return cfg
[docs]
def resolve_cfg(
    cfg: t.Optional[DictConfig],
    default_to_empty_dict: bool = False,
) -> t.Optional[dict]:
    """
    Converts an OmegaConf config into plain Python containers, resolving
    all interpolations.

    Parameters
    ----------
    cfg : DictConfig, optional
        The configuration to convert. May be None.
    default_to_empty_dict : bool, optional
        Controls the result when ``cfg`` is None: an empty dict if True,
        otherwise None. Defaults to False.

    Returns
    -------
    dict, optional
        The output of ``OmegaConf.to_container(cfg, resolve=True)``, or the
        None-fallback described above.
    """
    if cfg is not None:
        return OmegaConf.to_container(cfg, resolve=True)

    # No config supplied: pick the fallback requested by the caller.
    if default_to_empty_dict:
        return {}
    return None