gunz_ml.integrations package
Submodules
gunz_ml.integrations.hydra module
- gunz_ml.integrations.hydra.init_hydra_and_check_config(cfg: omegaconf.DictConfig, script_name: str | None = None, allow_unresolved_keys: bool = False, create_run_dir: bool = True) → omegaconf.DictConfig
Initializes Hydra configuration and performs sanity checks.
- Parameters:
cfg (DictConfig) – The configuration object provided by Hydra.
script_name (str, optional) – The name of the script being run, for logging purposes.
allow_unresolved_keys (bool, optional) – If False, raises an error if any ‘???’ values exist in the config. Defaults to False.
create_run_dir (bool, optional) – If True, ensures the Hydra run/output directory exists. Defaults to True.
- Returns:
The validated Hydra configuration.
- Return type:
DictConfig
- Raises:
RuntimeError – If unresolved keys are found and allow_unresolved_keys is False.
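The unresolved-key check described above can be pictured with the following stdlib-only sketch. It assumes the configuration has already been turned into plain nested containers (for example with `OmegaConf.to_container(cfg)`, which by default keeps missing values as the literal string '???'). The helper names `find_unresolved_keys` and `check_no_unresolved_keys` are hypothetical, and this is an illustration rather than the library's implementation.

```python
from typing import Any, List

MISSING = "???"  # OmegaConf's marker for a mandatory value that was never set


def find_unresolved_keys(node: Any, prefix: str = "") -> List[str]:
    """Return the dotted path of every value equal to '???' (hypothetical helper)."""
    if isinstance(node, dict):
        items = [(str(k), v) for k, v in node.items()]
    elif isinstance(node, list):
        items = [(str(i), v) for i, v in enumerate(node)]
    else:
        return []
    found: List[str] = []
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, str) and value == MISSING:
            found.append(path)
        else:
            # Recurse into nested dicts/lists; scalars return an empty list.
            found.extend(find_unresolved_keys(value, path))
    return found


def check_no_unresolved_keys(cfg: dict, allow_unresolved_keys: bool = False) -> dict:
    """Raise RuntimeError if '???' values remain, unless they are explicitly allowed."""
    missing = find_unresolved_keys(cfg)
    if missing and not allow_unresolved_keys:
        raise RuntimeError(f"Unresolved config keys: {', '.join(missing)}")
    return cfg


cfg = {"model": {"name": "resnet", "lr": "???"}, "data": {"paths": ["a", "???"]}}
print(find_unresolved_keys(cfg))  # ['model.lr', 'data.paths.1']
```

Reporting dotted paths rather than a bare boolean makes the error message point directly at the keys that still need a value.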
gunz_ml.integrations.lightning module
Provides helper functions and constants for PyTorch Lightning integration.
This module includes utilities for managing common PyTorch Lightning warnings and defines standardized status enums for logging purposes.
- class gunz_ml.integrations.lightning.FinalizeStatus(value)
Bases: StrEnum
Enumeration for the final status of a run or trial.
- FAILED = 'failed'
- FINISHED = 'finished'
- SUCCESS = 'success'
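Because the enum derives from StrEnum, each member is also a str and compares equal to its value. The sketch below re-declares an equivalent enum locally so that it runs without gunz_ml installed (a stand-in for illustration only); the fallback base class is only needed on Python versions older than 3.11, where `enum.StrEnum` does not exist.

```python
from enum import Enum

try:
    from enum import StrEnum  # available from Python 3.11
except ImportError:  # older interpreters: a str-mixin Enum behaves the same for comparisons
    class StrEnum(str, Enum):
        pass


class FinalizeStatus(StrEnum):
    """Local stand-in mirroring the documented members."""
    FAILED = "failed"
    FINISHED = "finished"
    SUCCESS = "success"


status = FinalizeStatus("success")        # look up a member by its value
print(status is FinalizeStatus.SUCCESS)   # True
print(status == "success")                # True: members are str instances
print(status.value)                       # success
```

Since members are plain strings, they can be written directly into log records or tags without calling `.value`.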
- gunz_ml.integrations.lightning.ignore_pl_warnings(dataloader_num_workers: bool = True, slurm_srun: bool = True, mixed_precision: bool = True)
Suppresses common, often noisy, warnings from PyTorch Lightning globally.
- Parameters:
dataloader_num_workers (bool, optional) – If True, suppresses the warning about using a small number of workers in the DataLoader. Defaults to True.
slurm_srun (bool, optional) – If True, suppresses the warning about the srun command being available on the system. Defaults to True.
mixed_precision (bool, optional) – If True, suppresses the historical usage warning for 16-bit mixed precision. Defaults to True.
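A minimal sketch of how such global suppression can be done with the standard `warnings` module. The message regexes below are assumptions that approximate the Lightning warning texts, which can change between releases, and `ignore_pl_warnings_sketch` is a hypothetical name rather than the library function.

```python
import warnings

# Approximate message patterns (assumptions); exact Lightning wording may differ by version.
_PL_WARNING_PATTERNS = {
    "dataloader_num_workers": r".*does not have many workers.*",
    "slurm_srun": r".*srun.*is available.*",
    "mixed_precision": r".*16 is supported for historical reasons.*",
}


def ignore_pl_warnings_sketch(dataloader_num_workers: bool = True,
                              slurm_srun: bool = True,
                              mixed_precision: bool = True) -> None:
    """Install process-wide 'ignore' filters for the selected warnings."""
    enabled = {
        "dataloader_num_workers": dataloader_num_workers,
        "slurm_srun": slurm_srun,
        "mixed_precision": mixed_precision,
    }
    for name, on in enabled.items():
        if on:
            # `message` is a case-insensitive regex matched at the start of the warning text.
            warnings.filterwarnings("ignore", message=_PL_WARNING_PATTERNS[name])
```

Because `warnings.filterwarnings` edits the process-wide filter list, the filters remain active for the rest of the process.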
- gunz_ml.integrations.lightning.suppress_pl_warnings(dataloader_num_workers: bool = True, slurm_srun: bool = True, mixed_precision: bool = True)
A context manager to temporarily suppress common PyTorch Lightning warnings.
- Parameters:
dataloader_num_workers (bool, optional) – If True, suppresses the warning about using a small number of workers in the DataLoader. Defaults to True.
slurm_srun (bool, optional) – If True, suppresses the warning about the srun command being available on the system. Defaults to True.
mixed_precision (bool, optional) – If True, suppresses the historical usage warning for 16-bit mixed precision. Defaults to True.
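For the scoped variant, a sketch along these lines wraps the same kind of filters in `warnings.catch_warnings()`, which restores the previous filter list when the block exits. The patterns and the name `suppress_pl_warnings_sketch` are again illustrative assumptions.

```python
import contextlib
import warnings
from typing import Iterator

# Approximate message patterns (assumptions); exact Lightning wording may differ by version.
_PATTERNS = {
    "dataloader_num_workers": r".*does not have many workers.*",
    "slurm_srun": r".*srun.*is available.*",
    "mixed_precision": r".*16 is supported for historical reasons.*",
}


@contextlib.contextmanager
def suppress_pl_warnings_sketch(dataloader_num_workers: bool = True,
                                slurm_srun: bool = True,
                                mixed_precision: bool = True) -> Iterator[None]:
    """Ignore the selected warnings only inside the `with` block."""
    enabled = {
        "dataloader_num_workers": dataloader_num_workers,
        "slurm_srun": slurm_srun,
        "mixed_precision": mixed_precision,
    }
    # catch_warnings() snapshots the global filter list and restores it on exit.
    with warnings.catch_warnings():
        for name, on in enabled.items():
            if on:
                warnings.filterwarnings("ignore", message=_PATTERNS[name])
        yield
```

The Python documentation notes that `catch_warnings` is not thread-safe, because it swaps process-global state.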
gunz_ml.integrations.mlflow module
- gunz_ml.integrations.mlflow.load_mlf_exp(tracking_uri: str, exp_name: str | None = None, run_name: str | None = None, search_experiments_kwargs: dict | None = None, search_runs_kwargs: dict | None = None, num_runs_returned: int | None = None, ret_client: bool = False, ret_runs_as_dataframe: bool = False, ret_run_as_dict: bool = False) → Tuple[mlflow.tracking.MlflowClient, List[mlflow.entities.run.Run]]
Loads MLflow experiments and runs based on the provided criteria.
Notes
If pareto_cond is provided, the function will compute the Pareto front of the runs.
If num_runs_returned is set to None, all runs will be returned.
If num_runs_returned is set to a positive integer, only the specified number of runs will be returned.
If ret_best_pareto is True, only the Pareto optimal runs will be returned.
If ret_client is True, the MLflow client will be returned along with the runs.
- Parameters:
tracking_uri (str) – The URI of the MLflow tracking server.
search_experiments_kwargs (dict, optional) – Keyword arguments for the search_experiments method of the MLflow client. Defaults to None.
search_runs_kwargs (dict, optional) – Keyword arguments for the search_runs method of the MLflow client. Defaults to None.
pareto_cond (dict, optional) – A dictionary mapping metric/attribute keys to a boolean: True if lower values are better (ascending order), False if higher values are better. If set to None, no Pareto front is computed. Defaults to None.
num_runs_returned (int, optional) – The number of MLflow runs to return. If set to None or 0, all runs are returned. Defaults to None.
ret_best_pareto (bool, optional) – If True, returns only the runs that are Pareto optimal. Defaults to True.
ret_client (bool, optional) – If True, returns the MLflow client along with the runs. Defaults to False.
- Returns:
A tuple containing the MLflow client and the selected runs.
- Return type:
Tuple[mlflow.tracking.MlflowClient, List[mlflow.entities.run.Run]]
Examples
Example for pareto_cond:
pareto_cond = {
    "metrics.loss1": True,   # Means lower is better
    "metrics.loss2": False,  # Means higher is better
}
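The Pareto filtering that pareto_cond describes can be sketched in plain Python. This assumes each run has already been flattened into a dictionary whose keys match the pareto_cond keys (such as "metrics.loss1"); `dominates` and `pareto_front` are hypothetical helpers, and the quadratic all-pairs dominance check is chosen for clarity, not taken from the library.

```python
from typing import Any, Dict, List


def dominates(a: Dict[str, Any], b: Dict[str, Any], cond: Dict[str, bool]) -> bool:
    """True if `a` is no worse than `b` on every key and strictly better on at least one."""
    strictly_better = False
    for key, lower_is_better in cond.items():
        # Negate "higher is better" keys so that smaller is always better.
        va = a[key] if lower_is_better else -a[key]
        vb = b[key] if lower_is_better else -b[key]
        if va > vb:
            return False
        if va < vb:
            strictly_better = True
    return strictly_better


def pareto_front(runs: List[Dict[str, Any]], cond: Dict[str, bool]) -> List[Dict[str, Any]]:
    """Keep the runs that no other run dominates (O(n^2) comparisons)."""
    return [r for r in runs if not any(dominates(o, r, cond) for o in runs if o is not r)]


pareto_cond = {"metrics.loss1": True, "metrics.loss2": False}
runs = [
    {"run": "a", "metrics.loss1": 0.10, "metrics.loss2": 0.50},
    {"run": "b", "metrics.loss1": 0.20, "metrics.loss2": 0.90},
    {"run": "c", "metrics.loss1": 0.30, "metrics.loss2": 0.40},  # worse than "a" on both keys
]
print([r["run"] for r in pareto_front(runs, pareto_cond)])  # ['a', 'b']
```

Runs "a" and "b" each win on one objective, so neither dominates the other and both stay on the front.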
- gunz_ml.integrations.mlflow.load_numpy_from_mlf_run(client: mlflow.tracking.client.MlflowClient, mlf_run: mlflow.entities.run.Run, rel_artifact_path: str) → ndarray
Loads a NumPy array stored as an artifact of an MLflow run.
Notes
The function uses a temporary directory to download and load the artifact.
The MLFLOW_TRACKING_URI environment variable is set to the client’s tracking URI during the download process.
- Parameters:
client (mlflow.tracking.client.MlflowClient) – The MLflow client object.
mlf_run (mlflow.entities.run.Run) – The MLflow run object containing the artifact.
rel_artifact_path (str) – The path of the artifact to download, relative to the run's artifact root directory.
- Returns:
The loaded numpy object.
- Return type:
np.ndarray
Examples
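The temporary-directory and environment-variable handling mentioned in the Notes can be sketched with the standard library alone. In this illustration, `download` and `loader` are injected callables, a hypothetical design chosen so the example runs without MLflow or NumPy; in practice they might wrap `MlflowClient.download_artifacts(run_id, path, dst_path)` and `numpy.load`. The names `temporary_env` and `load_artifact_sketch` are likewise hypothetical.

```python
import contextlib
import os
import tempfile
from typing import Any, Callable, Iterator, Optional


@contextlib.contextmanager
def temporary_env(name: str, value: str) -> Iterator[None]:
    """Set an environment variable inside the block, then restore the previous state."""
    previous: Optional[str] = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


def load_artifact_sketch(download: Callable[[str, str], str], rel_artifact_path: str,
                         tracking_uri: str, loader: Callable[[str], Any]) -> Any:
    """Download into a throwaway directory, load the file, then clean everything up."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with temporary_env("MLFLOW_TRACKING_URI", tracking_uri):
            local_path = download(rel_artifact_path, tmp_dir)
        # Load while the temporary directory still exists.
        return loader(local_path)


def fake_download(rel_path: str, dst_dir: str) -> str:
    # Stand-in for an artifact download: writes the tracking URI it saw into a file.
    path = os.path.join(dst_dir, os.path.basename(rel_path))
    with open(path, "w") as fh:
        fh.write(os.environ["MLFLOW_TRACKING_URI"])
    return path


def read_text(path: str) -> str:
    with open(path) as fh:
        return fh.read()


print(load_artifact_sketch(fake_download, "arrays/x.npy", "http://localhost:5000", read_text))
# http://localhost:5000
```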
- gunz_ml.integrations.mlflow.mlf_run_to_dict(run: mlflow.entities.run.Run, add_metric: bool = False, add_params: bool = False, add_tags: bool = False) → dict
Converts an MLflow run object to a dictionary.
Notes
If add_metric, add_params, or add_tags are set to False, the corresponding data will not be included in the dictionary.
- Parameters:
run (mlflow.entities.run.Run) – The MLflow run object to convert.
add_metric (bool, optional) – Whether to include metrics in the dictionary. Defaults to False.
add_params (bool, optional) – Whether to include parameters in the dictionary. Defaults to False.
add_tags (bool, optional) – Whether to include tags in the dictionary. Defaults to False.
- Returns:
A dictionary representation of the run, including its metrics, parameters, and tags when the corresponding flags are set.
- Return type:
dict
Examples
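The following sketch flattens a run into a single dictionary. It uses `types.SimpleNamespace` as a stand-in for `mlflow.entities.Run`, which exposes `info.run_id` and the `data.metrics`, `data.params`, and `data.tags` dictionaries. The `metrics.` / `params.` / `tags.` key prefixes are an assumption borrowed from MLflow's search-result column naming, not necessarily what `mlf_run_to_dict` produces, and `run_to_dict_sketch` is a hypothetical name.

```python
from types import SimpleNamespace
from typing import Any, Dict


def run_to_dict_sketch(run: Any, add_metric: bool = False, add_params: bool = False,
                       add_tags: bool = False) -> Dict[str, Any]:
    """Flatten an MLflow-like run into one dict with prefixed keys (prefixes are assumed)."""
    out: Dict[str, Any] = {"run_id": run.info.run_id}
    sections = (
        ("metrics", add_metric, run.data.metrics),
        ("params", add_params, run.data.params),
        ("tags", add_tags, run.data.tags),
    )
    for prefix, enabled, values in sections:
        if enabled:
            out.update({f"{prefix}.{key}": value for key, value in values.items()})
    return out


# SimpleNamespace stands in for mlflow.entities.Run so the sketch runs without MLflow.
fake_run = SimpleNamespace(
    info=SimpleNamespace(run_id="abc123"),
    data=SimpleNamespace(metrics={"loss": 0.25}, params={"lr": "0.001"}, tags={"stage": "dev"}),
)
print(run_to_dict_sketch(fake_run, add_metric=True, add_params=True))
# {'run_id': 'abc123', 'metrics.loss': 0.25, 'params.lr': '0.001'}
```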
- gunz_ml.integrations.mlflow.retrieve_mlf_exps(tracking_uri: str, exp_name: str | None = None, filter_string_list: List[str] | None = None) → List[mlflow.entities.experiment.Experiment]
Retrieves MLflow experiments based on the provided experiment name or filter strings.
Notes
Either exp_name or filter_string_list must be provided.
If both exp_name and filter_string_list are provided, only exp_name will be used.
If filter_string_list is provided, the filter strings are combined with “AND” logic.
- Parameters:
- Returns:
A list of MLflow Experiment objects that match the provided criteria.
- Return type:
List[mlflow.entities.experiment.Experiment]
Examples
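A minimal sketch of the documented precedence rules, written as a pure function that builds a `filter_string` for `MlflowClient.search_experiments`. The function name is hypothetical, raising `ValueError` when neither argument is given is an assumption, and quotes inside names are not escaped.

```python
from typing import List, Optional


def build_experiment_filter(exp_name: Optional[str] = None,
                            filter_string_list: Optional[List[str]] = None) -> str:
    """Return a single search_experiments filter string (hypothetical helper)."""
    if exp_name is not None:
        # exp_name takes precedence; filter_string_list is ignored in this case.
        return f"name = '{exp_name}'"
    if filter_string_list:
        return " AND ".join(filter_string_list)
    raise ValueError("Either exp_name or filter_string_list must be provided.")


print(build_experiment_filter(exp_name="resnet-sweep"))
# name = 'resnet-sweep'
print(build_experiment_filter(filter_string_list=["name LIKE 'resnet%'", "tags.team = 'vision'"]))
# name LIKE 'resnet%' AND tags.team = 'vision'
```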
- gunz_ml.integrations.mlflow.retrieve_mlf_runs(tracking_uri: str, exp_ids: List[int], run_name: str | None = None, filter_string_list: List[str] | None = None) → List[mlflow.entities.run.Run]
Retrieves MLflow runs based on the provided tracking URI, experiment IDs, and optional filters.
Notes
If run_name is provided, it will be added to the filter string.
If filter_string_list is not provided, it defaults to an empty list.
The function returns at most 10,000 runs. Retrieving more requires raising the max_results value used for the underlying search, which is not exposed as a parameter of this function.
- Parameters:
tracking_uri (str) – The URI of the MLflow tracking server.
exp_ids (List[int]) – A list of experiment IDs to search within.
run_name (Optional[str], optional) – The name of the run to filter by. Defaults to None.
filter_string_list (Optional[List[str]], optional) – A list of additional filter strings to apply. Defaults to None.
- Returns:
A list of MLflow run objects that match the search criteria.
- Return type:
List[mlflow.entities.run.Run]
Examples
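A sketch of the filter assembly described in the Notes, producing a string that could be passed as `filter_string` to `MlflowClient.search_runs`. The `attributes.run_name` clause form and the helper name `build_run_filter` are assumptions for illustration.

```python
from typing import List, Optional


def build_run_filter(run_name: Optional[str] = None,
                     filter_string_list: Optional[List[str]] = None) -> str:
    """Combine an optional run-name clause with extra clauses using AND."""
    clauses = list(filter_string_list or [])  # None behaves like an empty list; input is not mutated
    if run_name is not None:
        # Assumed clause form for filtering on the run name.
        clauses.append(f"attributes.run_name = '{run_name}'")
    return " AND ".join(clauses)


print(build_run_filter("trial-7", ["metrics.loss < 0.5"]))
# metrics.loss < 0.5 AND attributes.run_name = 'trial-7'
print(repr(build_run_filter()))
# ''
```

An empty filter string matches every run in the selected experiments, which is why the no-filter case returns "".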
- gunz_ml.integrations.mlflow.safe_set_experiment(experiment_name: str, tracking_uri: str | None = None, artifact_location: str | None = None) → str
Safely sets or creates an MLflow experiment, handling race conditions.
This function should be used during the “Initialization” phase of a study to ensure the experiment exists before launching parallel workers.
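The race handling can be illustrated with a look-up, create, re-look-up pattern. The sketch is self-contained: `InMemoryExperimentStore` and `AlreadyExistsError` are hypothetical stand-ins for a tracking server, which with MLflow would roughly correspond to `MlflowClient.get_experiment_by_name`, `MlflowClient.create_experiment`, and an `MlflowException` on name conflicts. Whether `safe_set_experiment` works exactly this way is an assumption.

```python
import itertools
import threading
from typing import Dict, List, Optional


class AlreadyExistsError(Exception):
    """Hypothetical stand-in for a 'resource already exists' error from the server."""


class InMemoryExperimentStore:
    """Toy, thread-safe experiment registry standing in for a tracking server."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._ids: Dict[str, str] = {}
        self._counter = itertools.count(1)

    def get_id(self, name: str) -> Optional[str]:
        with self._lock:
            return self._ids.get(name)

    def create(self, name: str) -> str:
        with self._lock:
            if name in self._ids:
                raise AlreadyExistsError(name)
            self._ids[name] = str(next(self._counter))
            return self._ids[name]


def get_or_create_experiment(store: InMemoryExperimentStore, name: str) -> str:
    """Look up, create if absent, and fall back to a second lookup if another worker won."""
    existing = store.get_id(name)
    if existing is not None:
        return existing
    try:
        return store.create(name)
    except AlreadyExistsError:
        # Another worker created the experiment between our lookup and our create call.
        winner = store.get_id(name)
        if winner is None:  # cannot happen here: the toy store never deletes entries
            raise
        return winner


store = InMemoryExperimentStore()
results: List[str] = []


def worker() -> None:
    results.append(get_or_create_experiment(store, "study-42"))  # list.append is thread-safe in CPython


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(set(results))  # {'1'}
```

Treating "already exists" as success rather than failure is what makes concurrent initialization safe: every worker ends up with the same experiment ID.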
gunz_ml.integrations.optuna module
Provides helper functions for working with Optuna studies and configurations.
gunz_ml.integrations.timm module
gunz_ml.integrations.toml module
Provides helper functions for working with TOML configurations.
This module includes utilities for loading TOML files with the built-in tomllib (Python 3.11+) and for converting them into OmegaConf DictConfig objects.
- gunz_ml.integrations.toml.load_toml(path: str | Path) → dict[str, Any]
Loads a TOML configuration file into a standard Python dictionary.
- Parameters:
path (t.Union[str, Path]) – The path to the TOML file.
- Returns:
The configuration loaded from the TOML file.
- Return type:
dict[str, Any]
- Raises:
FileNotFoundError – If the file does not exist.
tomllib.TOMLDecodeError – If the file is not a valid TOML file.
- gunz_ml.integrations.toml.load_toml_as_dictconfig(path: str | Path) → omegaconf.DictConfig
Loads a TOML configuration file and converts it to an OmegaConf DictConfig.
This is useful for integrating TOML-based configs with existing Hydra/OmegaConf-based codebases.
- Parameters:
path (t.Union[str, Path]) – The path to the TOML file.
- Returns:
The configuration as a DictConfig object.
- Return type:
DictConfig