# -*- coding: utf-8 -*-
"""
Juno Tracking SDK for MLflow and Optuna.
Provides a high-level interface for searching, ranking, and retrieving
experiment and trial data from the tracking backends.
"""
#? Metadata
__author__ = "Gemini CLI"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import re
import typing as t
from dataclasses import dataclass
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import mlflow
import optuna
import pandas as pd
from loguru import logger
# =============================================================================
# CLASSES
# =============================================================================
@dataclass
class RunInfo:
    """Flat snapshot of a single MLflow run.

    Instances are plain value containers; ``TrackingManager.get_best_run``
    fills them from an MLflow run entity.
    """

    # MLflow run identifier.
    run_id: str
    # Identifier of the experiment the run belongs to.
    experiment_id: str
    # Display name of the run.
    name: str
    # Metric values keyed by metric name.
    metrics: t.Dict[str, float]
    # Run parameters keyed by parameter name.
    params: t.Dict[str, t.Any]
    # Run tags keyed by tag name.
    tags: t.Dict[str, str]
class TrackingManager:
    """Unified interface for MLflow and Optuna discovery.

    Holds an MLflow client bound to ``mlflow_uri`` and an Optuna storage URL
    (``optuna_db``) that is passed to Optuna for study lookups.
    """

    def __init__(
        self,
        #? --- Connection ---
        mlflow_uri: str = "http://juno.tnt.uni-hannover.de:5200",
        optuna_db: str = "mysql+pymysql://optuna:optuna@juno:3311/optuna",
    ):
        """Store connection settings and create the MLflow client.

        Args:
            mlflow_uri: MLflow tracking server URI.
            optuna_db: Storage URL handed to Optuna (an SQLAlchemy-style DSN).
        """
        # NOTE(review): the default ``optuna_db`` embeds a username and
        # password in source; consider reading it from the environment.
        self.mlflow_uri = mlflow_uri
        self.optuna_db = optuna_db
        #? Initialize MLflow
        # The global URI is still set for backward compatibility (the fluent
        # ``mlflow.*`` API reads it). The client is bound explicitly so it
        # keeps targeting ``mlflow_uri`` even if the global URI changes later.
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = mlflow.tracking.MlflowClient(tracking_uri=self.mlflow_uri)

    def find_experiment(self, name_pattern: str) -> t.List[mlflow.entities.Experiment]:
        """Find MLflow experiments whose name matches a regex pattern.

        The pattern is applied case-insensitively with ``re.search``, so it
        may match anywhere in the name. ``search_experiments`` is called with
        its default view type, and every result page is fetched.

        Args:
            name_pattern: Regular expression applied to experiment names.

        Returns:
            Matching experiments, in the order the server returned them.

        Raises:
            re.error: If ``name_pattern`` is not a valid regular expression.
        """
        # Compile first so an invalid pattern fails before any network call.
        pattern = re.compile(name_pattern, re.IGNORECASE)
        matches = []
        page_token = None
        # search_experiments returns a single page bounded by its default
        # max_results; follow the page token until the server has no more.
        while True:
            page = self.client.search_experiments(page_token=page_token)
            matches.extend(e for e in page if pattern.search(e.name))
            page_token = page.token
            if not page_token:
                return matches

    def get_best_run(
        self,
        experiment_name: str,
        metric_name: str = "metric-spearmann_r",
        maximize: bool = True
    ) -> t.Optional[RunInfo]:
        """Return the single best run of an experiment, ranked by one metric.

        Args:
            experiment_name: Exact MLflow experiment name.
            metric_name: Metric key used for ranking.
            maximize: ``True`` picks the highest value, ``False`` the lowest.

        Returns:
            A ``RunInfo`` for the top-ranked run. Returns ``None`` when:

            - the experiment does not exist;
            - the experiment has no runs;
            - the top-ranked run has no usable (present, non-NaN) value for
              ``metric_name``.
        """
        import math  # local: only needed for the NaN check below

        exp = self.client.get_experiment_by_name(experiment_name)
        if not exp:
            logger.warning(f"Experiment {experiment_name} not found.")
            return None
        direction = "DESC" if maximize else "ASC"
        # Backticks quote metric keys containing characters such as '-'.
        runs = self.client.search_runs(
            experiment_ids=[exp.experiment_id],
            order_by=[f"metrics.`{metric_name}` {direction}"],
            max_results=1,
        )
        if not runs:
            return None
        best = runs[0]
        # order_by does not filter out runs that never logged the metric
        # (e.g. a misspelled key), so the top run may lack it; such a run,
        # or one whose value is NaN, is not a meaningful "best".
        value = best.data.metrics.get(metric_name)
        if value is None or math.isnan(value):
            logger.warning(
                f"Metric {metric_name} has no usable value on the top run "
                f"of experiment {experiment_name}."
            )
            return None
        return RunInfo(
            run_id=best.info.run_id,
            experiment_id=best.info.experiment_id,
            name=best.data.tags.get("mlflow.runName", "unnamed"),
            metrics=best.data.metrics,
            params=best.data.params,
            tags=best.data.tags,
        )

    def get_best_run_id(
        self,
        experiment_name: str,
        metric_name: str = "metric-spearmann_r",
        maximize: bool = True
    ) -> t.Optional[str]:
        """Return the run ID of ``get_best_run``'s result, or ``None``.

        Takes the same arguments as ``get_best_run``.
        """
        run = self.get_best_run(experiment_name, metric_name, maximize)
        return run.run_id if run else None

    def search_by_metric(
        self,
        experiment_name: str,
        metric_name: str,
        threshold: float,
        mode: t.Literal[">", "<", ">=", "<="] = ">=",
        max_results: int = 100
    ) -> pd.DataFrame:
        """Search runs whose metric compares to a threshold.

        Builds the filter ``metrics.`<metric_name>` <mode> <threshold>`` and
        delegates to ``search_runs``.

        Args:
            experiment_name: Exact MLflow experiment name.
            metric_name: Metric key to filter on.
            threshold: Value the metric is compared against.
            mode: Comparison operator; interpolated verbatim into the filter.
            max_results: Maximum number of rows to return.

        Returns:
            The DataFrame produced by ``mlflow.search_runs``.
        """
        filter_string = f"metrics.`{metric_name}` {mode} {threshold}"
        return self.search_runs(experiment_name, filter_string, max_results)

    def search_runs(
        self,
        experiment_name: str,
        filter_string: t.Optional[str] = None,
        max_results: int = 100
    ) -> pd.DataFrame:
        """Search runs of one experiment and return them as a DataFrame.

        Args:
            experiment_name: Exact MLflow experiment name.
            filter_string: Optional MLflow search filter expression.
            max_results: Maximum number of rows to return.

        Returns:
            The DataFrame produced by ``mlflow.search_runs``.
        """
        # The fluent API reads the process-global tracking URI; re-assert this
        # manager's URI so another instance cannot redirect the query.
        mlflow.set_tracking_uri(self.mlflow_uri)
        return mlflow.search_runs(
            experiment_names=[experiment_name],
            filter_string=filter_string,
            max_results=max_results,
        )

    def list_studies(self, pattern: t.Optional[str] = None) -> t.List[optuna.study.StudySummary]:
        """List Optuna studies, optionally filtered by a name regex.

        Args:
            pattern: Regular expression applied case-insensitively with
                ``re.search`` to study names; a falsy value disables
                filtering.

        Returns:
            Study summaries from ``optuna.get_all_study_summaries``.

        Raises:
            re.error: If ``pattern`` is not a valid regular expression.
        """
        # Compile first so an invalid pattern fails before touching storage.
        regex = re.compile(pattern, re.IGNORECASE) if pattern else None
        summaries = optuna.get_all_study_summaries(storage=self.optuna_db)
        if regex is not None:
            summaries = [s for s in summaries if regex.search(s.study_name)]
        return summaries