gunz_ml.integrations package

Submodules

gunz_ml.integrations.hydra module

gunz_ml.integrations.hydra.init_hydra_and_check_config(cfg: omegaconf.DictConfig, script_name: str | None = None, allow_unresolved_keys: bool = False, create_run_dir: bool = True) → omegaconf.DictConfig

Initializes Hydra configuration and performs sanity checks.

Parameters:
  • cfg (DictConfig) – The configuration object provided by Hydra.

  • script_name (str, optional) – The name of the script being run, for logging purposes.

  • allow_unresolved_keys (bool, optional) – If False, raises an error if any ‘???’ values exist in the config. Defaults to False.

  • create_run_dir (bool, optional) – If True, ensures the Hydra run/output directory exists. Defaults to True.

Returns:

The validated Hydra configuration.

Return type:

DictConfig

Raises:

RuntimeError – If unresolved keys are found and allow_unresolved_keys is False.
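The '???' check can be illustrated without Hydra. The sketch below assumes the configuration has already been converted to plain nested dicts and lists (for example with OmegaConf.to_container, which leaves '???' in place by default); find_missing_keys and check_config are hypothetical helpers, not gunz_ml functions. On a real DictConfig, recent OmegaConf releases offer OmegaConf.missing_keys(cfg) for the same lookup.

```python
from typing import Any

MISSING = "???"  # OmegaConf's marker for a mandatory value that has not been set


def find_missing_keys(cfg: Any, prefix: str = "") -> list[str]:
    """Return the dotted path of every value still equal to '???' in nested dicts/lists."""
    missing: list[str] = []
    if isinstance(cfg, dict):
        items = cfg.items()
    elif isinstance(cfg, list):
        items = enumerate(cfg)
    else:
        return missing
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, (dict, list)):
            # Recurse into nested sections and collect their missing paths.
            missing.extend(find_missing_keys(value, path))
        elif value == MISSING:
            missing.append(path)
    return missing


def check_config(cfg: dict, allow_unresolved_keys: bool = False) -> dict:
    """Raise RuntimeError if '???' values remain and they are not allowed."""
    missing = find_missing_keys(cfg)
    if missing and not allow_unresolved_keys:
        raise RuntimeError(f"Unresolved config keys: {', '.join(missing)}")
    return cfg
```

Collecting every missing path before raising lets the error report all unset keys at once instead of failing on the first one.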

gunz_ml.integrations.hydra.resolve_cfg(cfg: omegaconf.DictConfig | None, default_to_empty_dict: bool = False) → dict | None

Resolves an OmegaConf DictConfig object into a standard Python dictionary.

Parameters:
  • cfg (DictConfig, optional) – The configuration object to resolve.

  • default_to_empty_dict (bool, optional) – If True, returns an empty dict if cfg is None. If False, returns None. Defaults to False.

Returns:

The resolved configuration as a dictionary, or None.

Return type:

dict, optional

gunz_ml.integrations.lightning module

Provides helper functions and constants for PyTorch Lightning integration.

This module includes utilities for managing common PyTorch Lightning warnings and defines standardized status enums for logging purposes.

class gunz_ml.integrations.lightning.FinalizeStatus(value)

Bases: StrEnum

Enumeration for the final status of a run or trial.

FAILED = 'failed'
FINISHED = 'finished'
SUCCESS = 'success'

gunz_ml.integrations.lightning.ignore_pl_warnings(dataloader_num_workers: bool = True, slurm_srun: bool = True, mixed_precision: bool = True)

Suppresses common, often noisy, warnings from PyTorch Lightning globally.

Parameters:
  • dataloader_num_workers (bool, optional) – If True, suppresses the warning about using a small number of workers in the DataLoader. Defaults to True.

  • slurm_srun (bool, optional) – If True, suppresses the warning about the srun command being available on the system. Defaults to True.

  • mixed_precision (bool, optional) – If True, suppresses the historical usage warning for 16-bit mixed precision. Defaults to True.
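Global suppression of this kind can be done with the standard warnings module. A minimal sketch follows; the message regexes are assumptions modelled on Lightning's wording and may need adjusting for the installed version, and ignore_pl_warnings_sketch is a hypothetical stand-in, not the gunz_ml implementation.

```python
import warnings

# Hypothetical message regexes; warnings matches them (case-insensitively)
# against the start of the warning text, hence the leading ".*".
_PL_WARNING_PATTERNS = {
    "dataloader_num_workers": r".*does not have many workers.*",
    "slurm_srun": r".*The `srun` command is available on your system.*",
    "mixed_precision": r".*16 is supported for historical reasons.*",
}


def ignore_pl_warnings_sketch(
    dataloader_num_workers: bool = True,
    slurm_srun: bool = True,
    mixed_precision: bool = True,
) -> None:
    """Install process-wide 'ignore' filters for the selected warnings."""
    enabled = {
        "dataloader_num_workers": dataloader_num_workers,
        "slurm_srun": slurm_srun,
        "mixed_precision": mixed_precision,
    }
    for name, pattern in _PL_WARNING_PATTERNS.items():
        if enabled[name]:
            # filterwarnings prepends by default, so this filter takes
            # precedence over filters installed earlier.
            warnings.filterwarnings("ignore", message=pattern)
```

Since the filters are global and never removed, this suits script entry points rather than library code.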

gunz_ml.integrations.lightning.suppress_pl_warnings(dataloader_num_workers: bool = True, slurm_srun: bool = True, mixed_precision: bool = True)

A context manager to temporarily suppress common PyTorch Lightning warnings.

Parameters:
  • dataloader_num_workers (bool, optional) – If True, suppresses the warning about using a small number of workers in the DataLoader. Defaults to True.

  • slurm_srun (bool, optional) – If True, suppresses the warning about the srun command being available on the system. Defaults to True.

  • mixed_precision (bool, optional) – If True, suppresses the historical usage warning for 16-bit mixed precision. Defaults to True.
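A scoped variant can wrap warnings.catch_warnings, which snapshots the filter list and restores it on exit. A minimal sketch under the same assumptions as before: the regexes are illustrative guesses at Lightning's wording, and suppress_pl_warnings_sketch is hypothetical.

```python
import contextlib
import warnings
from typing import Iterator

# Hypothetical message regexes; adjust them to the Lightning version in use.
_PATTERNS = {
    "dataloader_num_workers": r".*does not have many workers.*",
    "slurm_srun": r".*The `srun` command is available on your system.*",
    "mixed_precision": r".*16 is supported for historical reasons.*",
}


@contextlib.contextmanager
def suppress_pl_warnings_sketch(
    dataloader_num_workers: bool = True,
    slurm_srun: bool = True,
    mixed_precision: bool = True,
) -> Iterator[None]:
    """Ignore the selected warnings only inside the `with` block."""
    enabled = {
        "dataloader_num_workers": dataloader_num_workers,
        "slurm_srun": slurm_srun,
        "mixed_precision": mixed_precision,
    }
    # catch_warnings restores the previous filters when the block exits,
    # so the ignore filters below do not leak out.
    with warnings.catch_warnings():
        for name, pattern in _PATTERNS.items():
            if enabled[name]:
                warnings.filterwarnings("ignore", message=pattern)
        yield
```

Usage: `with suppress_pl_warnings_sketch(): trainer.fit(model)`, where trainer and model are whatever Lightning objects the script already has. Note that catch_warnings modifies process-wide state and is not thread-safe.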

gunz_ml.integrations.mlflow module

gunz_ml.integrations.mlflow.load_mlf_exp(tracking_uri: str, exp_name: str | None = None, run_name: str | None = None, search_experiments_kwargs: dict | None = None, search_runs_kwargs: dict | None = None, num_runs_returned: int | None = None, ret_client: bool = False, ret_runs_as_dataframe: bool = False, ret_run_as_dict: bool = False) → Tuple[mlflow.tracking.MlflowClient, List[mlflow.entities.run.Run]]

Loads MLflow experiments and runs based on the provided criteria.

Notes

  • If pareto_cond is provided, the function will compute the Pareto front of the runs.

  • If num_runs_returned is set to None, all runs will be returned.

  • If num_runs_returned is set to a positive integer, only the specified number of runs will be returned.

  • If ret_best_pareto is True, only the Pareto optimal runs will be returned.

  • If ret_client is True, the MLflow client will be returned along with the runs.

Parameters:
  • tracking_uri (str) – The URI of the MLflow tracking server.

  • search_experiments_kwargs (dict, optional) – Keyword arguments for the search_experiments method of the MLflow client.

  • search_runs_kwargs (dict, optional) – Keyword arguments for the search_runs method of the MLflow client. Defaults to None.

  • pareto_cond (dict, optional) – A dictionary mapping metric or attribute names to a boolean: True if lower values are better (sort ascending), False if higher values are better. If set to None, no Pareto front computation is performed. Defaults to None.

  • num_runs_returned (int, optional) – The number of MLflow runs to return. If set to None or 0, all runs are returned. Defaults to None.

  • ret_best_pareto (bool, optional) – If True, returns only the runs that are Pareto optimal. Defaults to True.

  • ret_client (bool, optional) – If True, returns the MLflow client along with the runs. Defaults to False.

Returns:

A tuple containing the MLflow client and the selected runs.

Return type:

Tuple[mlflow.tracking.MlflowClient, List[mlflow.entities.run.Run]]

Examples

Example for pareto_cond:

pareto_cond = {
    "metrics.loss1": True,  # Means lower is better
    "metrics.loss2": False,  # Means higher is better
}
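The pareto_cond semantics can be illustrated with a small, self-contained sketch that treats each run as a flat dict keyed like the columns returned by mlflow.search_runs ("metrics.loss1", ...). The helpers dominates and pareto_front are hypothetical; gunz_ml may compute the front differently (for example on a DataFrame).

```python
from typing import Any, Mapping, Sequence


def dominates(a: Mapping[str, Any], b: Mapping[str, Any], cond: Mapping[str, bool]) -> bool:
    """True if `a` is at least as good as `b` on every key and strictly better on one.

    cond[key] is True when lower values are better, False when higher is better.
    """
    strictly_better = False
    for key, lower_is_better in cond.items():
        # Negate "higher is better" keys so that smaller always means better.
        x, y = (a[key], b[key]) if lower_is_better else (-a[key], -b[key])
        if x > y:
            return False  # `a` is worse on this key
        if x < y:
            strictly_better = True
    return strictly_better


def pareto_front(runs: Sequence[Mapping[str, Any]], cond: Mapping[str, bool]) -> list:
    """Return the runs not dominated by any other run (O(n^2) pairwise check)."""
    return [
        r for i, r in enumerate(runs)
        if not any(dominates(o, r, cond) for j, o in enumerate(runs) if j != i)
    ]


pareto_cond = {"metrics.loss1": True, "metrics.loss2": False}
runs = [
    {"run": "a", "metrics.loss1": 0.10, "metrics.loss2": 0.50},
    {"run": "b", "metrics.loss1": 0.20, "metrics.loss2": 0.90},
    {"run": "c", "metrics.loss1": 0.30, "metrics.loss2": 0.40},  # dominated by a and b
]
print([r["run"] for r in pareto_front(runs, pareto_cond)])  # ['a', 'b']
```

Runs "a" and "b" trade off the two metrics, so neither dominates the other and both stay on the front.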

gunz_ml.integrations.mlflow.load_numpy_from_mlf_run(client: mlflow.tracking.client.MlflowClient, mlf_run: mlflow.entities.run.Run, rel_artifact_path: str) → ndarray

Loads a numpy object from an MLflow run.

Notes

  • The function uses a temporary directory to download and load the artifact.

  • The MLFLOW_TRACKING_URI environment variable is set to the client’s tracking URI during the download process.

Parameters:
  • client (mlflow.tracking.client.MlflowClient) – The MLflow client object.

  • mlf_run (mlflow.entities.run.Run) – The MLflow run object containing the artifact.

  • rel_artifact_path (str) – The path of the artifact to download, relative to the root artifact directory of the MLflow run.

Returns:

The loaded numpy object.

Return type:

np.ndarray

Examples
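The second note describes a common pattern: setting MLFLOW_TRACKING_URI only while the artifact is downloaded, then restoring the previous value. A minimal stdlib sketch of that pattern; temporary_env is a hypothetical helper, and the download and np.load steps themselves are omitted.

```python
import contextlib
import os
from typing import Iterator, Optional


@contextlib.contextmanager
def temporary_env(name: str, value: str) -> Iterator[None]:
    """Set an environment variable for the duration of the block, then restore it."""
    previous: Optional[str] = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)  # the variable did not exist before
        else:
            os.environ[name] = previous
```

Inside such a block, the artifact could be downloaded into a tempfile.TemporaryDirectory and read with np.load before the directory is cleaned up.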

gunz_ml.integrations.mlflow.mlf_run_to_dict(run: mlflow.entities.run.Run, add_metric: bool = False, add_params: bool = False, add_tags: bool = False) → dict

Converts an MLflow run object to a dictionary.

Notes

  • If add_metric, add_params, or add_tags are set to False, the corresponding data will not be included in the dictionary.

Parameters:
  • run (mlflow.entities.run.Run) – The MLflow run object to convert.

  • add_metric (bool, optional) – Whether to include metrics in the dictionary. Defaults to False.

  • add_params (bool, optional) – Whether to include parameters in the dictionary. Defaults to False.

  • add_tags (bool, optional) – Whether to include tags in the dictionary. Defaults to False.

Returns:

A dictionary describing the run; metrics, parameters, and tags are included only when the corresponding flag is True.

Return type:

dict

Examples
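A minimal sketch of the flag-driven conversion. The "metrics.", "params.", and "tags." key prefixes are an assumption borrowed from the column names of mlflow.search_runs, and which RunInfo fields gunz_ml keeps is unknown; run_to_dict_sketch is hypothetical. A SimpleNamespace stands in for mlflow.entities.Run, which exposes .info (RunInfo) and .data (RunData with metrics, params, and tags dicts).

```python
from types import SimpleNamespace
from typing import Any


def run_to_dict_sketch(run: Any, add_metric: bool = False, add_params: bool = False,
                       add_tags: bool = False) -> dict:
    """Flatten a run-like object into a dict with prefixed metric/param/tag keys."""
    out: dict = {"run_id": run.info.run_id, "status": run.info.status}
    sections = (
        ("metrics", add_metric, run.data.metrics),
        ("params", add_params, run.data.params),
        ("tags", add_tags, run.data.tags),
    )
    for prefix, enabled, values in sections:
        if enabled:  # skip the whole section when its flag is False
            out.update({f"{prefix}.{k}": v for k, v in values.items()})
    return out


# Stand-in for an mlflow.entities.Run object.
run = SimpleNamespace(
    info=SimpleNamespace(run_id="abc123", status="FINISHED"),
    data=SimpleNamespace(metrics={"loss": 0.1}, params={"lr": "0.001"}, tags={"stage": "dev"}),
)
print(run_to_dict_sketch(run, add_metric=True))
# {'run_id': 'abc123', 'status': 'FINISHED', 'metrics.loss': 0.1}
```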

gunz_ml.integrations.mlflow.retrieve_mlf_exps(tracking_uri: str, exp_name: str | None = None, filter_string_list: List[str] | None = None) → List[mlflow.entities.experiment.Experiment]

Retrieves MLflow experiments based on the provided experiment name or filter strings.

Notes

  • Either exp_name or filter_string_list must be provided.

  • If both exp_name and filter_string_list are provided, only exp_name will be used.

  • If filter_string_list is provided, the filter strings are combined with “AND” logic.

Parameters:
  • tracking_uri (str) – The URI of the MLflow tracking server.

  • exp_name (Optional[str], optional) – The name of the experiment to retrieve. Defaults to None.

  • filter_string_list (Optional[List[str]], optional) – A list of filter strings to apply to the experiment search. Defaults to None.

Returns:

A list of MLflow Experiment objects that match the provided criteria.

Return type:

List[mlflow.entities.experiment.Experiment]

Examples
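The precedence and AND-combination rules in the notes can be sketched as a filter-string builder. build_experiment_filter is hypothetical; raising ValueError when neither argument is given is an assumption, and quotes inside names are not escaped in this sketch.

```python
from typing import List, Optional


def build_experiment_filter(exp_name: Optional[str] = None,
                            filter_string_list: Optional[List[str]] = None) -> str:
    """Build a filter string for an MLflow experiment search (hypothetical helper)."""
    if exp_name is not None:
        # exp_name takes precedence; any filter_string_list is ignored.
        return f"name = '{exp_name}'"
    if not filter_string_list:
        raise ValueError("Either exp_name or filter_string_list must be provided.")
    return " AND ".join(filter_string_list)


print(build_experiment_filter(filter_string_list=["name LIKE 'sweep_%'", "tags.team = 'vision'"]))
# name LIKE 'sweep_%' AND tags.team = 'vision'
```

The resulting string would be passed as the filter_string argument of MlflowClient.search_experiments.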

gunz_ml.integrations.mlflow.retrieve_mlf_runs(tracking_uri: str, exp_ids: List[int], run_name: str | None = None, filter_string_list: List[str] | None = None) → List[mlflow.entities.run.Run]

Retrieves MLflow runs based on the provided tracking URI, experiment IDs, and optional filters.

Notes

  • If run_name is provided, it will be added to the filter string.

  • If filter_string_list is not provided, it defaults to an empty list.

  • The function can return up to 10,000 runs. If more runs are needed, the max_results parameter should be adjusted.

Parameters:
  • tracking_uri (str) – The URI of the MLflow tracking server.

  • exp_ids (List[int]) – A list of experiment IDs to search within.

  • run_name (Optional[str], optional) – The name of the run to filter by. Defaults to None.

  • filter_string_list (Optional[List[str]], optional) – A list of additional filter strings to apply. Defaults to None.

Returns:

A list of MLflow run objects that match the search criteria.

Return type:

List[mlflow.entities.run.Run]

Examples
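A minimal sketch of how run_name can be folded into the filter list. build_run_filter is hypothetical, and the attributes.run_name field assumes an MLflow version whose search syntax supports it (older versions filter on the mlflow.runName tag instead).

```python
from typing import List, Optional


def build_run_filter(run_name: Optional[str] = None,
                     filter_string_list: Optional[List[str]] = None) -> str:
    """Combine an optional run-name condition with extra filters using AND."""
    filters = list(filter_string_list or [])  # copy, so the caller's list is not mutated
    if run_name is not None:
        filters.append(f"attributes.run_name = '{run_name}'")
    return " AND ".join(filters)  # an empty string means "no filtering"
```

The result would be passed as the filter_string argument of MlflowClient.search_runs, together with the experiment IDs and a max_results value.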

gunz_ml.integrations.mlflow.safe_set_experiment(experiment_name: str, tracking_uri: str | None = None, artifact_location: str | None = None) → str

Safely sets or creates an MLflow experiment, handling race conditions.

This function should be used during the “Initialization” phase of a study to ensure the experiment exists before launching parallel workers.

Parameters:
  • experiment_name (str) – The name of the experiment.

  • tracking_uri (Optional[str], optional) – The MLflow tracking URI.

  • artifact_location (Optional[str], optional) – The location to store artifacts.

Returns:

The experiment ID.

Return type:

str
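Race-safe creation usually follows a get-or-create pattern: look the name up, try to create it, and if another worker won the race, look it up again. The sketch below shows that pattern against a toy in-memory store; InMemoryStore, AlreadyExistsError, and safe_get_or_create are hypothetical. With MLflow, the analogous calls would be MlflowClient.get_experiment_by_name and MlflowClient.create_experiment, with the conflict surfacing as an MlflowException.

```python
import threading
from typing import Dict, Optional


class AlreadyExistsError(Exception):
    """Stand-in for the error a tracking backend raises on a duplicate name."""


class InMemoryStore:
    """Toy experiment store mapping names to string IDs."""

    def __init__(self) -> None:
        self._ids: Dict[str, str] = {}
        self._lock = threading.Lock()

    def get_id(self, name: str) -> Optional[str]:
        return self._ids.get(name)

    def create(self, name: str) -> str:
        with self._lock:  # the backend enforces name uniqueness atomically
            if name in self._ids:
                raise AlreadyExistsError(name)
            self._ids[name] = str(len(self._ids))
            return self._ids[name]


def safe_get_or_create(store: InMemoryStore, name: str) -> Optional[str]:
    """Return the experiment ID, creating it if needed and tolerating a concurrent create."""
    existing = store.get_id(name)
    if existing is not None:
        return existing
    try:
        return store.create(name)
    except AlreadyExistsError:
        # Another worker created it between our lookup and our create call.
        return store.get_id(name)
```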

gunz_ml.integrations.optuna module

Provides helper functions for working with Optuna studies and configurations.

gunz_ml.integrations.optuna.check_pruning_thresholds(metrics, pruning_cfg, key, trial_number=None, seed=None)
gunz_ml.integrations.optuna.create_study(study_kwargs, objs_cfg, dry_run=False)
gunz_ml.integrations.optuna.load_study(study_kwargs)
gunz_ml.integrations.optuna.optimize_cfg(cfg, trial_cfg)
gunz_ml.integrations.optuna.prepare_parameter_data(optuna_params_cfg, param_df)
gunz_ml.integrations.optuna.sanitize_metric_dataframe(metric_df, metric_names)
gunz_ml.integrations.optuna.sanitize_metric_names(metric_names)
gunz_ml.integrations.optuna.sanitize_name(name)
gunz_ml.integrations.optuna.select_objective_values(evaluation_results, objectives_spec)
gunz_ml.integrations.optuna.trials_to_dataframes(trials, metric_names=None, target_state=optuna.trial.TrialState.COMPLETE)

gunz_ml.integrations.timm module

gunz_ml.integrations.toml module

Provides helper functions for working with TOML configurations.

This module includes utilities for loading TOML files using the built-in tomllib and for converting them into OmegaConf DictConfig objects.

gunz_ml.integrations.toml.load_toml(path: str | Path) → dict[str, Any]

Loads a TOML configuration file into a standard Python dictionary.

Parameters:

path (t.Union[str, Path]) – The path to the TOML file.

Returns:

The configuration loaded from the TOML file.

Return type:

dict

Raises:

gunz_ml.integrations.toml.load_toml_as_dictconfig(path: str | Path) → omegaconf.DictConfig

Loads a TOML configuration file and converts it to an OmegaConf DictConfig.

This is useful for integrating TOML-based configs with existing Hydra/OmegaConf-based codebases.

Parameters:

path (t.Union[str, Path]) – The path to the TOML file.

Returns:

The configuration as a DictConfig object.

Return type:

DictConfig