# Source code for gunz_ml.management.tracking

# -*- coding: utf-8 -*-
"""
Juno Tracking SDK for MLflow and Optuna.

Provides a high-level interface for searching, ranking, and retrieving 
experiment and trial data from the tracking backends.
"""

#? Metadata
__author__ = "Gemini CLI"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import re
import typing as t
from dataclasses import dataclass

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import mlflow
import optuna
import pandas as pd
from loguru import logger

# =============================================================================
# CLASSES
# =============================================================================

@dataclass
class RunInfo:
    """Plain record bundling the identity and logged data of one run.

    Instances are ordinary (mutable) dataclass objects; equality and
    ``repr`` are the ones generated by ``@dataclass``.
    """

    #? --- Identity ---
    # Identifier of the run itself.
    run_id: str
    # Identifier of the experiment the run belongs to.
    experiment_id: str
    # Display name of the run.
    name: str

    #? --- Logged data ---
    # Metric name -> value.
    metrics: t.Dict[str, float]
    # Parameter name -> value.
    params: t.Dict[str, t.Any]
    # Tag key -> tag value.
    tags: t.Dict[str, str]
class TrackingManager:
    """Unified interface for MLflow and Optuna discovery.

    MLflow queries go to ``mlflow_uri``; Optuna queries go to the storage
    URL ``optuna_db``. Both are stored on the instance so they can be
    inspected after construction.
    """

    #? Comparison operators accepted by ``search_by_metric``.
    _METRIC_MODES: t.Tuple[str, ...] = (">", "<", ">=", "<=")

    def __init__(
        self,
        #? --- Connection ---
        mlflow_uri: str = "http://juno.tnt.uni-hannover.de:5200",
        optuna_db: str = "mysql+pymysql://optuna:optuna@juno:3311/optuna",
    ):
        """Store the connection settings and create the MLflow client.

        Args:
            mlflow_uri: Tracking URI of the MLflow server.
            optuna_db: Storage URL handed to Optuna when listing studies.
        """
        self.mlflow_uri = mlflow_uri
        self.optuna_db = optuna_db

        #? Initialize MLflow
        # Kept for backward compatibility: code that uses the fluent
        # ``mlflow.*`` API after constructing a manager relies on this.
        mlflow.set_tracking_uri(self.mlflow_uri)
        # Bind the client to this instance's URI explicitly so a later
        # change of the process-wide URI cannot redirect its queries.
        self.client = mlflow.tracking.MlflowClient(tracking_uri=self.mlflow_uri)

    def find_experiment(self, name_pattern: str) -> t.List[mlflow.entities.Experiment]:
        """Finds MLflow experiments matching a regex pattern.

        Args:
            name_pattern: Regular expression searched (case-insensitively)
                anywhere in the experiment name.

        Returns:
            Every experiment returned by the client whose name matches.

        Raises:
            re.error: If ``name_pattern`` is not a valid regular expression.
        """
        # Compile first so an invalid pattern fails before any network call.
        pattern = re.compile(name_pattern, re.IGNORECASE)

        matches: t.List[mlflow.entities.Experiment] = []
        page_token: t.Optional[str] = None
        while True:
            # search_experiments is paginated; follow the tokens so that
            # experiments beyond the first page are not silently dropped.
            page = self.client.search_experiments(page_token=page_token)
            matches.extend(e for e in page if pattern.search(e.name))
            page_token = getattr(page, "token", None)
            if not page_token:
                break
        return matches

    def get_best_run(
        self,
        experiment_name: str,
        metric_name: str = "metric-spearmann_r",
        maximize: bool = True,
    ) -> t.Optional[RunInfo]:
        """Returns the single best run for an experiment based on a metric.

        Args:
            experiment_name: Exact name of the MLflow experiment.
            metric_name: Metric used for ranking.
            maximize: Rank descending (largest first) when true, ascending
                otherwise.

        Returns:
            A ``RunInfo`` for the top-ranked run, or ``None`` when the
            experiment does not exist, has no runs, or its top-ranked run
            never logged ``metric_name``.
        """
        exp = self.client.get_experiment_by_name(experiment_name)
        if not exp:
            logger.warning(f"Experiment {experiment_name} not found.")
            return None

        direction = "DESC" if maximize else "ASC"
        runs = self.client.search_runs(
            experiment_ids=[exp.experiment_id],
            order_by=[f"metrics.`{metric_name}` {direction}"],
            max_results=1,
        )
        if not runs:
            return None

        best = runs[0]
        if metric_name not in best.data.metrics:
            # A run that never logged the metric cannot be "best" by it;
            # returning it would hand the caller an arbitrary run.
            logger.warning(
                f"No run in experiment {experiment_name} has metric {metric_name}."
            )
            return None

        return RunInfo(
            run_id=best.info.run_id,
            experiment_id=best.info.experiment_id,
            name=best.data.tags.get("mlflow.runName", "unnamed"),
            metrics=best.data.metrics,
            params=best.data.params,
            tags=best.data.tags,
        )

    def get_best_run_id(
        self,
        experiment_name: str,
        metric_name: str = "metric-spearmann_r",
        maximize: bool = True,
    ) -> t.Optional[str]:
        """High-level query to get the ID of the best run.

        Takes the same arguments as ``get_best_run`` and returns only the
        run ID, or ``None`` when ``get_best_run`` finds nothing.
        """
        run = self.get_best_run(experiment_name, metric_name, maximize)
        return run.run_id if run else None

    def search_by_metric(
        self,
        experiment_name: str,
        metric_name: str,
        threshold: float,
        mode: t.Literal[">", "<", ">=", "<="] = ">=",
        max_results: int = 100,
    ) -> pd.DataFrame:
        """Searches runs based on a metric threshold.

        Args:
            experiment_name: Exact name of the MLflow experiment.
            metric_name: Metric compared against ``threshold``.
            threshold: Right-hand side of the comparison.
            mode: Comparison operator placed between metric and threshold.
            max_results: Upper bound on the number of returned runs.

        Returns:
            The DataFrame produced by ``search_runs`` for the built filter.

        Raises:
            ValueError: If ``mode`` is not one of the supported operators.
        """
        # ``Literal`` is only checked statically; validate at runtime so a
        # bad operator is rejected here instead of being spliced into the
        # filter expression sent to the server.
        if mode not in self._METRIC_MODES:
            raise ValueError(
                f"mode must be one of {self._METRIC_MODES}, got {mode!r}"
            )
        filter_string = f"metrics.`{metric_name}` {mode} {threshold}"
        return self.search_runs(experiment_name, filter_string, max_results)

    def search_runs(
        self,
        experiment_name: str,
        filter_string: t.Optional[str] = None,
        max_results: int = 100,
    ) -> pd.DataFrame:
        """Searches runs and returns a pandas DataFrame.

        Args:
            experiment_name: Exact name of the MLflow experiment.
            filter_string: MLflow search filter; ``None`` means no filter.
            max_results: Upper bound on the number of returned runs.

        Returns:
            Whatever ``mlflow.search_runs`` returns for these arguments.
        """
        # The fluent API reads the process-wide tracking URI. Pin it to this
        # manager's server for the duration of the call and restore it
        # afterwards, so the query neither depends on nor disturbs whatever
        # other code has configured.
        previous_uri = mlflow.get_tracking_uri()
        mlflow.set_tracking_uri(self.mlflow_uri)
        try:
            return mlflow.search_runs(
                experiment_names=[experiment_name],
                # The fluent API declares a ``str`` filter; never forward None.
                filter_string=filter_string or "",
                max_results=max_results,
            )
        finally:
            mlflow.set_tracking_uri(previous_uri)

    def list_studies(self, pattern: t.Optional[str] = None) -> t.List[optuna.study.StudySummary]:
        """Lists Optuna studies matching an optional pattern.

        Args:
            pattern: Regular expression searched (case-insensitively)
                anywhere in the study name. A falsy value disables filtering.

        Returns:
            Summaries of all studies in ``self.optuna_db``, filtered by
            ``pattern`` when one is given.
        """
        summaries = optuna.get_all_study_summaries(storage=self.optuna_db)
        if pattern:
            p = re.compile(pattern, re.IGNORECASE)
            summaries = [s for s in summaries if p.search(s.study_name)]
        return summaries