from loguru import logger
import os
import typing as t
import tempfile as tmp
import numpy as np
import pandas as pd
import mlflow
[docs]
def mlf_run_to_dict(
    run: mlflow.entities.run.Run,
    add_metric: bool = False,
    add_params: bool = False,
    add_tags: bool = False,
) -> dict:
    """
    Flatten selected sections of an MLflow run into one dictionary.

    Notes
    -----
    - A section is copied only when its flag is truthy. With every flag left
      at ``False`` the result is an empty dictionary and ``run`` is not
      accessed at all.
    - Sections are merged into a single flat mapping in the order metrics,
      params, tags. Keys are not prefixed, so a key that appears in more
      than one section keeps the value of the section merged last.

    Parameters
    ----------
    run : mlflow.entities.run.Run
        The run whose ``data.metrics`` / ``data.params`` / ``data.tags``
        mappings are read.
    add_metric : bool, optional
        Merge ``run.data.metrics`` into the result. Defaults to False.
    add_params : bool, optional
        Merge ``run.data.params`` into the result. Defaults to False.
    add_tags : bool, optional
        Merge ``run.data.tags`` into the result. Defaults to False.

    Returns
    -------
    dict
        A new dictionary holding the entries of every requested section.
    """
    #? (section attribute on run.data, requested?) -- order defines merge precedence
    requested_sections = (
        ("metrics", add_metric),
        ("params", add_params),
        ("tags", add_tags),
    )

    flat_record = {}
    for section_name, is_requested in requested_sections:
        if is_requested:
            flat_record.update(getattr(run.data, section_name))
    return flat_record
[docs]
def retrieve_mlf_exps(
    tracking_uri: str,
    exp_name: t.Optional[str] = None,
    filter_string_list: t.Optional[t.List[str]] = None,
) -> t.List[mlflow.entities.experiment.Experiment]:
    """
    Retrieves MLflow experiments by exact name or by a list of filter clauses.

    Notes
    -----
    - Either `exp_name` or `filter_string_list` must be provided.
    - If both are provided, only `exp_name` is used.
    - The entries of `filter_string_list` are combined with " AND "; an empty
      list results in an empty filter string.
    - If `exp_name` does not match any experiment, an empty list is returned
      (the client reports an unknown name as ``None``).

    Parameters
    ----------
    tracking_uri : str
        The URI of the MLflow tracking server.
    exp_name : Optional[str], optional
        The name of the experiment to retrieve. Defaults to None.
    filter_string_list : Optional[List[str]], optional
        Filter clauses passed (AND-joined) to ``search_experiments``.
        Defaults to None.

    Returns
    -------
    List[mlflow.entities.experiment.Experiment]
        The matching experiments. For the filter path this is whatever
        ``MlflowClient.search_experiments`` returns for a single call.

    Raises
    ------
    AssertionError
        If neither `exp_name` nor `filter_string_list` is given.
    """
    #? Kept as an assertion so the exception type seen by callers is unchanged
    assert exp_name is not None or filter_string_list is not None, \
        "Either `exp_name` or `filter_string_list` must be provided."

    client = mlflow.tracking.MlflowClient(
        tracking_uri=tracking_uri
    )

    if exp_name is not None:
        mlf_exp = client.get_experiment_by_name(exp_name)
        #? Do not hand back `[None]` when the name is unknown
        return [] if mlf_exp is None else [mlf_exp]

    filter_str = " AND ".join(filter_string_list or [])
    return client.search_experiments(
        filter_string=filter_str
    )
[docs]
def retrieve_mlf_runs(
    tracking_uri: str,
    exp_ids: t.List[int],
    run_name: t.Optional[str] = None,
    filter_string_list: t.Optional[t.List[str]] = None,
) -> t.List[mlflow.entities.run.Run]:
    """
    Retrieves every MLflow run of the given experiments that matches the filters.

    Notes
    -----
    - If `run_name` is provided, the clause ``attribute.run_name = '<run_name>'``
      is added to the filter. The name is interpolated verbatim (no quoting
      or escaping is applied).
    - All filter clauses are combined with " AND ".
    - The caller's `filter_string_list` is never modified.
    - Results are collected page by page: the search is repeated with the
      returned page token until no token is left.

    Parameters
    ----------
    tracking_uri : str
        The URI of the MLflow tracking server.
    exp_ids : List[int]
        A list of experiment IDs to search within.
    run_name : Optional[str], optional
        The name of the run to filter by. Defaults to None.
    filter_string_list : Optional[List[str]], optional
        A list of additional filter clauses to apply. Defaults to None.

    Returns
    -------
    List[mlflow.entities.run.Run]
        A list of MLflow run objects that match the search criteria.
    """
    client = mlflow.tracking.MlflowClient(
        tracking_uri=tracking_uri
    )

    #? Work on a private copy: appending to the caller's list would make the
    #? run-name clause pile up when the same list is reused across calls
    filter_clauses = [] if filter_string_list is None else list(filter_string_list)
    if run_name is not None:
        filter_clauses.append(f"attribute.run_name = '{run_name}'")
    filter_str = " AND ".join(filter_clauses)

    mlf_runs = []
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=exp_ids,
            filter_string=filter_str,
            page_token=page_token,
        )
        mlf_runs.extend(page)
        page_token = page.token
        if not page_token:
            break
    return mlf_runs
[docs]
def safe_set_experiment(
    experiment_name: str,
    tracking_uri: t.Optional[str] = None,
    artifact_location: t.Optional[str] = None,
) -> str:
    """
    Safely sets or creates an MLflow experiment, handling race conditions.

    This function should be used during the "Initialization" phase of a study
    to ensure the experiment exists before launching parallel workers.

    Notes
    -----
    - When `tracking_uri` is truthy it is installed globally through
      ``mlflow.set_tracking_uri`` before the client is created.
    - A creation failure whose message contains ``RESOURCE_ALREADY_EXISTS``
      or "already exists" (case-insensitive) is treated as a concurrent
      creation and answered with a second lookup by name.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment.
    tracking_uri : Optional[str], optional
        The MLflow tracking URI.
    artifact_location : Optional[str], optional
        The location to store artifacts. Only used when the experiment is
        created by this call.

    Returns
    -------
    str
        The experiment ID.

    Raises
    ------
    Exception
        Any error raised by ``create_experiment`` that is not an
        "already exists" conflict, or such a conflict when the follow-up
        lookup still finds no experiment, is re-raised unchanged.
    """
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)
    client = mlflow.tracking.MlflowClient()

    #? Attempt to get existing experiment
    exp = client.get_experiment_by_name(experiment_name)
    if exp is not None:
        logger.debug(f"Found existing MLflow experiment: {experiment_name}")
        return exp.experiment_id

    #? Attempt to create if not found; only this call can hit the race
    logger.info(f"Creating new MLflow experiment: {experiment_name}")
    try:
        return client.create_experiment(
            name=experiment_name,
            artifact_location=artifact_location
        )
    except Exception as e:
        err_msg = str(e)
        #? Anything other than an "already exists" conflict is a real failure
        if "RESOURCE_ALREADY_EXISTS" not in err_msg and "already exists" not in err_msg.lower():
            raise
        #? Handle race condition where another process created it between get and create
        logger.warning(f"Experiment {experiment_name} was created concurrently.")
        exp = client.get_experiment_by_name(experiment_name)
        if exp is None:
            #? Still not visible: surface the creation error instead of an
            #? AttributeError on None
            raise
        return exp.experiment_id
[docs]
def _normalize_search_kwargs(search_kwargs: t.Optional[dict]) -> dict:
    """Return a private copy of `search_kwargs` whose list/tuple ``filter_string`` is AND-joined."""
    normalized = {} if search_kwargs is None else dict(search_kwargs)
    filter_string = normalized.get('filter_string')
    if isinstance(filter_string, (list, tuple)):
        normalized['filter_string'] = " AND ".join(filter_string)
    return normalized


def load_mlf_exp(
    tracking_uri: str,
    exp_name: t.Optional[str] = None,
    run_name: t.Optional[str] = None,
    search_experiments_kwargs: t.Optional[dict] = None,
    search_runs_kwargs: t.Optional[dict] = None,
    num_runs_returned: t.Optional[int] = None,
    ret_client: bool = False,
    ret_runs_as_dataframe: bool = False,
    ret_run_as_dict: bool = False,
) -> t.Union[
    t.List[mlflow.entities.run.Run],
    t.Tuple[t.List[mlflow.entities.run.Run], mlflow.tracking.MlflowClient],
]:
    """
    Loads MLflow runs from one experiment (by name) or from a searched set of experiments.

    Notes
    -----
    - If `exp_name` is given it takes precedence and `search_experiments_kwargs`
      is ignored.
    - In both kwargs dictionaries, ``filter_string`` may be a single string or
      a list/tuple of clauses; a list/tuple is combined with " AND ". The
      dictionaries passed by the caller are not modified.
    - Runs are collected page by page. Without `run_name`, paging stops as
      soon as `num_runs_returned` runs have been collected.
    - If `run_name` is given, all pages are scanned and exactly one run whose
      ``mlflow.runName`` tag equals `run_name` must exist; the result then
      contains only that run.
    - Pareto-front selection is not performed by this function.

    Parameters
    ----------
    tracking_uri : str
        The URI to the MLflow server.
    exp_name : str, optional
        Name of the experiment to load runs from. Defaults to None.
    run_name : str, optional
        Name of the single run to select. Defaults to None.
    search_experiments_kwargs : dict, optional
        Keyword arguments for the `search_experiments` method of the MLflow
        client. Only used when `exp_name` is None. Defaults to None.
    search_runs_kwargs : dict, optional
        Keyword arguments for the `search_runs` method of the MLflow client.
        Defaults to None.
    num_runs_returned : int, optional
        Maximum number of runs to return; must be a positive integer.
        If None, all runs are returned. Defaults to None.
    ret_client : bool, optional
        If True, returns the MLflow client along with the runs. Defaults to False.
    ret_runs_as_dataframe : bool, optional
        Accepted for backward compatibility; currently not used.
    ret_run_as_dict : bool, optional
        Accepted for backward compatibility; currently not used.

    Returns
    -------
    list of mlflow.entities.run.Run
        The selected runs, when `ret_client` is False.
    tuple
        ``(mlf_runs, client)`` -- the selected runs followed by the MLflow
        client -- when `ret_client` is True.

    Raises
    ------
    AssertionError
        If `num_runs_returned` is invalid, the named experiment does not
        exist, no run is found, or `run_name` matches zero or several runs.
    """
    assert num_runs_returned is None or (isinstance(num_runs_returned, int) and num_runs_returned > 0), \
        f"Invalid value for num_runs_returned: {num_runs_returned}"

    client = mlflow.tracking.MlflowClient(
        tracking_uri=tracking_uri
    )

    if exp_name is not None:
        mlf_exp = client.get_experiment_by_name(exp_name)
        #? The client reports an unknown name as None; fail with a clear message
        assert mlf_exp is not None, \
            f"No experiment with name \"{exp_name}\" is found!"
        exp_ids = [mlf_exp.experiment_id]
    else:
        mlf_exps = client.search_experiments(
            **_normalize_search_kwargs(search_experiments_kwargs)
        )
        exp_ids = [mlf_exp.experiment_id for mlf_exp in mlf_exps]

    #? Same filter_string handling as for experiments: a plain string is
    #? passed through untouched instead of being joined character by character
    run_search_kwargs = _normalize_search_kwargs(search_runs_kwargs)

    #? A named run may sit on any page, so the early stop only applies when
    #? no run_name lookup is requested
    fetch_limit = None if run_name is not None else num_runs_returned

    mlf_runs = []
    page_token = None
    while True:
        page = client.search_runs(
            experiment_ids=exp_ids,
            **run_search_kwargs,
            page_token=page_token,
        )
        mlf_runs.extend(page)
        page_token = page.token
        if fetch_limit is not None and len(mlf_runs) >= fetch_limit:
            break
        if not page_token:
            break

    assert len(mlf_runs), "No experiment found!"

    if run_name is not None:
        #? The run name is read from the run's tags under 'mlflow.runName'
        named_runs = [
            mlf_run for mlf_run in mlf_runs
            if mlf_run.data.tags.get('mlflow.runName') == run_name
        ]
        assert len(named_runs) > 0, \
            f"No run with name \"{run_name}\" in study \"{exp_name}\" is found!"
        assert len(named_runs) == 1, \
            f"Multiple runs with name \"{run_name}\" in study \"{exp_name}\" are found!"
        mlf_runs = named_runs

    if num_runs_returned is not None:
        #? Only reachable when assertions are disabled (python -O)
        if not num_runs_returned > 0:
            raise ValueError(f"Invalid num_runs_returned value: {num_runs_returned}")
        mlf_runs = mlf_runs[:num_runs_returned]

    if ret_client:
        return mlf_runs, client
    return mlf_runs
[docs]
def load_numpy_from_mlf_run(
    client: mlflow.tracking.client.MlflowClient,
    mlf_run: mlflow.entities.run.Run,
    rel_artifact_path: str
) -> np.ndarray:
    """
    Downloads one artifact of an MLflow run and loads it with ``np.load``.

    Notes
    -----
    - The artifact is downloaded into a temporary directory that is removed
      once loading has finished.
    - The `MLFLOW_TRACKING_URI` environment variable is overwritten with the
      client's tracking URI before the download; it is not restored
      afterwards.

    Parameters
    ----------
    client : mlflow.tracking.client.MlflowClient
        The MLflow client whose ``tracking_uri`` is used for the download.
    mlf_run : mlflow.entities.run.Run
        The MLflow run object containing the artifact.
    rel_artifact_path : str
        Path of the artifact relative to the run's artifact root. The same
        relative path is used to locate the downloaded file inside the
        temporary directory.

    Returns
    -------
    np.ndarray
        The object returned by ``np.load`` for the downloaded file.
    """
    with tmp.TemporaryDirectory() as scratch_dir:
        ###! DO NOT CHANGE ###
        os.environ["MLFLOW_TRACKING_URI"] = client.tracking_uri
        ###! DO NOT CHANGE ###
        artifact_path = f"{rel_artifact_path}"
        mlflow.artifacts.download_artifacts(
            run_id=mlf_run.info.run_id,
            tracking_uri=client.tracking_uri,
            artifact_path=artifact_path,
            dst_path=scratch_dir
        )
        #? The download keeps the relative layout below dst_path
        local_file = os.path.join(scratch_dir, rel_artifact_path)
        loaded = np.load(local_file)
    return loaded