Source code for gunz_ml.integrations.mlflow

from loguru import logger
import os
import typing as t
import tempfile as tmp
import numpy as np
import pandas as pd
import mlflow

def mlf_run_to_dict(
    run: mlflow.entities.run.Run,
    add_metric: bool = False,
    add_params: bool = False,
    add_tags: bool = False,
) -> dict:
    """
    Flattens selected sections of an MLflow run into a single dictionary.

    Notes
    -----
    - Sections are merged in the order metrics, params, tags. A key present
      in more than one enabled section keeps the value from the section
      merged last.
    - With all flags left at their default (False), an empty dict is
      returned and `run.data` is never accessed.

    Parameters
    ----------
    run : mlflow.entities.run.Run
        The MLflow run to read from.
    add_metric : bool, optional
        Merge `run.data.metrics` into the result. Defaults to False.
    add_params : bool, optional
        Merge `run.data.params` into the result. Defaults to False.
    add_tags : bool, optional
        Merge `run.data.tags` into the result. Defaults to False.

    Returns
    -------
    dict
        A new dictionary holding the key/value pairs of every enabled section.
    """
    # (enabled?, attribute of run.data) in merge order.
    selected_sections = (
        (add_metric, "metrics"),
        (add_params, "params"),
        (add_tags, "tags"),
    )
    flattened = {}
    for wanted, section_attr in selected_sections:
        if not wanted:
            continue
        flattened.update(getattr(run.data, section_attr))
    return flattened
def retrieve_mlf_exps(
    tracking_uri: str,
    exp_name: t.Optional[str] = None,
    filter_string_list: t.Optional[t.List[str]] = None,
) -> t.List[mlflow.entities.experiment.Experiment]:
    """
    Retrieves MLflow experiments by name or by filter strings.

    Notes
    -----
    - Either `exp_name` or `filter_string_list` must be provided.
    - If both are provided, only `exp_name` is used.
    - Filter strings are combined with "AND" logic.
    - When searching with filters, every result page returned by the server
      is fetched, so the result is not truncated to a single page.

    Parameters
    ----------
    tracking_uri : str
        The URI of the MLflow tracking server.
    exp_name : Optional[str], optional
        The name of the experiment to retrieve. Defaults to None.
    filter_string_list : Optional[List[str]], optional
        Filter clauses for `MlflowClient.search_experiments`, joined with
        " AND ". An empty list searches without a filter. Defaults to None.

    Returns
    -------
    List[mlflow.entities.experiment.Experiment]
        The matching experiments. Empty when `exp_name` does not name an
        existing experiment or when no experiment matches the filters.

    Raises
    ------
    ValueError
        If neither `exp_name` nor `filter_string_list` is provided.
    """
    if exp_name is None and filter_string_list is None:
        raise ValueError(
            "Either `exp_name` or `filter_string_list` must be provided."
        )

    client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)

    if exp_name is not None:
        mlf_exp = client.get_experiment_by_name(exp_name)
        # get_experiment_by_name returns None for unknown names; returning
        # [None] would make callers fail later with a confusing AttributeError.
        return [] if mlf_exp is None else [mlf_exp]

    # " AND ".join([]) == "" -> no filter, same as before.
    filter_str = " AND ".join(filter_string_list)

    # search_experiments is paged like search_runs; follow the tokens so all
    # matching experiments are returned.
    mlf_exps = []
    page_token = None
    while True:
        page = client.search_experiments(
            filter_string=filter_str,
            page_token=page_token,
        )
        mlf_exps.extend(page)
        page_token = page.token
        if not page_token:
            break
    return mlf_exps
def retrieve_mlf_runs(
    tracking_uri: str,
    exp_ids: t.List[int],
    run_name: t.Optional[str] = None,
    filter_string_list: t.Optional[t.List[str]] = None,
) -> t.List[mlflow.entities.run.Run]:
    """
    Retrieves MLflow runs from the given experiments, with optional filters.

    Notes
    -----
    - If `run_name` is provided, the clause
      ``attribute.run_name = '<run_name>'`` is added to the filters.
    - Filter clauses are combined with "AND" logic.
    - The caller's `filter_string_list` is never modified.
    - All result pages are fetched by following the `search_runs` page
      token, so the number of returned runs is not capped by a single page.

    Parameters
    ----------
    tracking_uri : str
        The URI of the MLflow tracking server.
    exp_ids : List[int]
        Experiment IDs to search within.
    run_name : Optional[str], optional
        The name of the run to filter by. Defaults to None.
    filter_string_list : Optional[List[str]], optional
        Additional filter clauses to apply. Defaults to None.

    Returns
    -------
    List[mlflow.entities.run.Run]
        All runs matching the search criteria.
    """
    client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)

    # Work on a copy: appending the run-name clause to the caller's list would
    # mutate it and make the clause accumulate across repeated calls.
    filters = [] if filter_string_list is None else list(filter_string_list)
    if run_name is not None:
        # NOTE(review): run_name is interpolated verbatim; a name containing a
        # single quote will produce an invalid filter string.
        filters.append(f"attribute.run_name = '{run_name}'")
    # An empty list joins to "", i.e. no filter.
    filter_str = " AND ".join(filters)

    mlf_runs = []
    page_token = None
    while True:
        new_runs = client.search_runs(
            experiment_ids=exp_ids,
            filter_string=filter_str,
            page_token=page_token,
        )
        mlf_runs.extend(new_runs)
        page_token = new_runs.token
        if not page_token:
            break
    return mlf_runs
def safe_set_experiment(
    experiment_name: str,
    tracking_uri: t.Optional[str] = None,
    artifact_location: t.Optional[str] = None,
) -> str:
    """
    Safely gets or creates an MLflow experiment, handling race conditions.

    This function should be used during the "Initialization" phase of a
    study to ensure the experiment exists before launching parallel workers.

    Notes
    -----
    - If `tracking_uri` is truthy, it is installed globally via
      `mlflow.set_tracking_uri` before the client is created.
    - If another process creates the experiment between the lookup and the
      create call, the "already exists" error is caught and the existing
      experiment's ID is returned instead.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment.
    tracking_uri : Optional[str], optional
        The MLflow tracking URI.
    artifact_location : Optional[str], optional
        The location to store artifacts (only used when creating).

    Returns
    -------
    str
        The experiment ID.

    Raises
    ------
    Exception
        Any error from `create_experiment` that is not an "already exists"
        error, or that error itself if the experiment still cannot be looked
        up afterwards.
    """
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)
    client = mlflow.tracking.MlflowClient()

    #? Attempt to get existing experiment
    exp = client.get_experiment_by_name(experiment_name)
    if exp is not None:
        logger.debug(f"Found existing MLflow experiment: {experiment_name}")
        return exp.experiment_id

    #? Attempt to create if not found
    logger.info(f"Creating new MLflow experiment: {experiment_name}")
    try:
        return client.create_experiment(
            name=experiment_name,
            artifact_location=artifact_location,
        )
    except Exception as e:
        #? Handle race condition where another process created it between get and create
        message = str(e)
        already_exists = (
            getattr(e, "error_code", None) == "RESOURCE_ALREADY_EXISTS"
            or "RESOURCE_ALREADY_EXISTS" in message
            or "already exists" in message.lower()
        )
        if not already_exists:
            raise
        logger.warning(f"Experiment {experiment_name} was created concurrently.")
        exp = client.get_experiment_by_name(experiment_name)
        if exp is None:
            # Could not resolve the concurrently-created experiment; surface
            # the original create error rather than an AttributeError on None.
            raise
        return exp.experiment_id
def _join_mlf_filter(
    filter_string: t.Union[str, t.Sequence[str], None],
) -> t.Optional[str]:
    """Normalize a filter given as one string or a list of clauses into one string."""
    # A plain string must be passed through unchanged: " AND ".join("abc")
    # would split it into single characters.
    if filter_string is None or isinstance(filter_string, str):
        return filter_string
    return " AND ".join(filter_string)


def _mlf_run_name(run: mlflow.entities.run.Run) -> t.Optional[str]:
    """Return a run's name from the `mlflow.runName` tag, falling back to `run.info.run_name`."""
    name = run.data.tags.get("mlflow.runName")
    if name is None:
        name = getattr(run.info, "run_name", None)
    return name


def load_mlf_exp(
    tracking_uri: str,
    exp_name: t.Optional[str] = None,
    run_name: t.Optional[str] = None,
    search_experiments_kwargs: t.Optional[dict] = None,
    search_runs_kwargs: t.Optional[dict] = None,
    num_runs_returned: t.Optional[int] = None,
    ret_client: bool = False,
    ret_runs_as_dataframe: bool = False,
    ret_run_as_dict: bool = False,
) -> t.Union[
    t.List[mlflow.entities.run.Run],
    t.Tuple[t.List[mlflow.entities.run.Run], mlflow.tracking.MlflowClient],
]:
    """
    Loads MLflow runs from one or more experiments.

    Notes
    -----
    - Experiments are selected by `exp_name` if given, otherwise via
      `MlflowClient.search_experiments(**search_experiments_kwargs)`.
    - All pages of `search_runs` are fetched. Fetching stops early once
      `num_runs_returned` runs have been collected, unless `run_name` is set
      (then every page is scanned so duplicate names can be detected).
    - `filter_string` in either kwargs dict may be a single string or a list
      of clauses; lists are joined with " AND ". The caller's dicts are not
      modified.
    - If `run_name` is given, only runs whose name (the `mlflow.runName` tag,
      falling back to `run.info.run_name`) equals it are kept, and exactly
      one such run must exist.
    - Pareto-front selection (`pareto_cond` / `ret_best_pareto`) is currently
      disabled.

    Parameters
    ----------
    tracking_uri : str
        The URI to the MLflow server.
    exp_name : str, optional
        Name of a single experiment to load runs from. Takes precedence over
        `search_experiments_kwargs`. Defaults to None.
    run_name : str, optional
        If given, return only the single run with this name. Defaults to None.
    search_experiments_kwargs : dict, optional
        Keyword arguments for `MlflowClient.search_experiments`, used when
        `exp_name` is None. Defaults to None.
    search_runs_kwargs : dict, optional
        Keyword arguments for `MlflowClient.search_runs`. A `page_token`
        entry, if present, is used as the starting page. Defaults to None.
    num_runs_returned : int, optional
        Positive maximum number of runs to return. None returns all runs.
        Defaults to None.
    ret_client : bool, optional
        If True, also return the MLflow client. Defaults to False.
    ret_runs_as_dataframe : bool, optional
        Currently ignored; kept for backward compatibility.
    ret_run_as_dict : bool, optional
        Currently ignored; kept for backward compatibility.

    Returns
    -------
    list of mlflow.entities.run.Run, or tuple (list, MlflowClient)
        The selected runs. If `ret_client` is True, the tuple
        ``(runs, client)`` is returned (runs first).

    Raises
    ------
    ValueError
        If `num_runs_returned` is not None or a positive int, if `exp_name`
        names no existing experiment, if no runs are found, or if `run_name`
        matches zero or several runs.
    """
    if num_runs_returned is not None and not (
        isinstance(num_runs_returned, int) and num_runs_returned > 0
    ):
        raise ValueError(f"Invalid value for num_runs_returned: {num_runs_returned}")

    client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)

    #? Resolve the experiment IDs to search
    if exp_name is not None:
        mlf_exp = client.get_experiment_by_name(exp_name)
        if mlf_exp is None:
            raise ValueError(f"No experiment named \"{exp_name}\" is found!")
        exp_ids = [mlf_exp.experiment_id]
    else:
        exp_kwargs = dict(search_experiments_kwargs or {})
        if "filter_string" in exp_kwargs:
            exp_kwargs["filter_string"] = _join_mlf_filter(exp_kwargs["filter_string"])
        exp_ids = [exp.experiment_id for exp in client.search_experiments(**exp_kwargs)]

    #? Copy so the caller's dict is never mutated
    runs_kwargs = dict(search_runs_kwargs or {})
    if "filter_string" in runs_kwargs:
        runs_kwargs["filter_string"] = _join_mlf_filter(runs_kwargs["filter_string"])
    # Pop page_token so it is not passed twice to search_runs below.
    page_token = runs_kwargs.pop("page_token", None)

    mlf_runs = []
    while True:
        new_runs = client.search_runs(
            experiment_ids=exp_ids,
            page_token=page_token,
            **runs_kwargs,
        )
        page_runs = list(new_runs)
        if run_name is not None:
            page_runs = [run for run in page_runs if _mlf_run_name(run) == run_name]
        mlf_runs.extend(page_runs)
        page_token = new_runs.token
        # With run_name set, keep scanning so duplicates are detected.
        if (
            run_name is None
            and num_runs_returned is not None
            and len(mlf_runs) >= num_runs_returned
        ):
            break
        if not page_token:
            break

    if run_name is not None:
        if not mlf_runs:
            raise ValueError(
                f"No run with name \"{run_name}\" in study \"{exp_name}\" is found!"
            )
        if len(mlf_runs) > 1:
            raise ValueError(
                f"Multiple runs with name \"{run_name}\" in study \"{exp_name}\" are found!"
            )
    elif not mlf_runs:
        raise ValueError("No run found in the selected experiment(s)!")

    if num_runs_returned is not None:
        mlf_runs = mlf_runs[:num_runs_returned]

    if ret_client:
        return mlf_runs, client
    return mlf_runs
def load_numpy_from_mlf_run(
    client: mlflow.tracking.client.MlflowClient,
    mlf_run: mlflow.entities.run.Run,
    rel_artifact_path: str
) -> np.ndarray:
    """
    Downloads a NumPy artifact from an MLflow run and loads it with `np.load`.

    Notes
    -----
    - The artifact is downloaded into a fresh temporary directory. The
      directory is removed when the function returns.
    - As a side effect, `os.environ["MLFLOW_TRACKING_URI"]` is set to
      `client.tracking_uri` and is NOT restored afterwards.
    - NOTE(review): for `.npz` artifacts `np.load` returns a lazily-reading
      archive object whose backing file lives in the removed temporary
      directory; `.npy` files are read fully before returning.

    Parameters
    ----------
    client : mlflow.tracking.client.MlflowClient
        Client whose `tracking_uri` is used for the download.
    mlf_run : mlflow.entities.run.Run
        Run that owns the artifact.
    rel_artifact_path : str
        Artifact path relative to the run's artifact root.

    Returns
    -------
    np.ndarray
        The object produced by `np.load` for the downloaded file.
    """
    with tmp.TemporaryDirectory() as download_dir:
        ###! DO NOT CHANGE ###
        os.environ["MLFLOW_TRACKING_URI"] = client.tracking_uri
        ###! DO NOT CHANGE ###
        source_run_id = mlf_run.info.run_id
        mlflow.artifacts.download_artifacts(
            run_id=source_run_id,
            tracking_uri=client.tracking_uri,
            artifact_path=f"{rel_artifact_path}",
            dst_path=download_dir,
        )
        # The download mirrors the relative artifact path under download_dir.
        local_file = os.path.join(download_dir, rel_artifact_path)
        loaded = np.load(local_file)
    return loaded