"""
Gunz ML - Experiment Management Core Logic.
Standardized discovery, monitoring, and cleanup for HPO experiments.
"""
import os
import re
import subprocess
from pathlib import Path
from typing import Dict, List, Optional, Any
import pandas as pd
from omegaconf import OmegaConf
from loguru import logger
class SlurmTracker:
    """Interface to SLURM to track active jobs.

    The queue is snapshotted once, at construction time, by running ``squeue``.

    Attributes:
        user: User whose jobs are listed (the ``user`` argument, else ``$USER``).
            When neither is set, jobs of all users are listed.
        active_jobs: Snapshot ``{job_id: state}``. Array tasks (``"123_4"``)
            are indexed under both the full id and the base id (``"123"``).
        available: False when ``squeue`` could not be run or exited with an
            error, i.e. the snapshot is empty because job status is *unknown*,
            not because nothing is running.
    """

    def __init__(self, user: Optional[str] = None):
        self.user = user or os.environ.get("USER")
        self.available = False  # updated by _fetch_active_jobs()
        self.active_jobs = self._fetch_active_jobs()

    @staticmethod
    def _parse_squeue_output(stdout: str) -> Dict[str, str]:
        """Parse ``squeue --format=%i|%T --noheader`` output into ``{job_id: state}``.

        Blank lines and lines that are not exactly ``<id>|<state>`` are ignored.
        """
        jobs: Dict[str, str] = {}
        for line in stdout.splitlines():
            parts = line.strip().split("|")
            if len(parts) != 2:
                continue
            job_id, state = parts
            # Index array tasks under the base id as well, so a trial that
            # recorded either form of the id is recognised.
            jobs[job_id.split("_")[0]] = state
            jobs[job_id] = state
        return jobs

    def _fetch_active_jobs(self) -> Dict[str, str]:
        """Fetches active jobs from squeue and returns {job_id: state}.

        Sets ``self.available`` to whether the query succeeded. On failure a
        warning is logged and an empty dict is returned.
        """
        cmd = ["squeue", "--format=%i|%T", "--noheader"]
        if self.user:
            # Without a user, list every job; ids are cluster-wide anyway.
            cmd += ["-u", self.user]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        except (OSError, subprocess.SubprocessError) as e:
            # squeue missing, controller unreachable, non-zero exit, ...
            logger.warning(f"Could not query SLURM queue; job status unknown: {e}")
            self.available = False
            return {}
        self.available = True
        return self._parse_squeue_output(result.stdout)

    def is_job_active(self, job_id: str) -> bool:
        """Return whether ``job_id`` appears in the queue snapshot.

        ``job_id`` is compared as a string, so integer ids also match. If the
        queue could not be queried (``available`` is False), the status is
        unknown and this returns True. Callers that fail or delete trials of
        inactive jobs therefore never act on possibly-live work.
        """
        if not self.available:
            return True
        return str(job_id) in self.active_jobs
class StudyDiscovery:
    """Finds and filters Optuna studies named ``<prefix>-...`` or ``<prefix>_...``."""

    def __init__(self, storage_url: str, prefix: str):
        self.storage_url = storage_url
        self.prefix = prefix
        # Escape the prefix so characters such as "." or "+" match literally.
        self.prefix_pattern = re.compile(rf"^{re.escape(prefix)}[-_].*")

    def get_all_studies(self) -> List[str]:
        """Return the names of all studies in ``storage_url`` that carry the prefix.

        If the storage cannot be read, a warning is logged and ``[]`` is returned.
        """
        import optuna
        try:
            summaries = optuna.get_all_study_summaries(self.storage_url)
        except Exception as e:
            # Do not log the storage URL itself: it may embed DB credentials.
            logger.warning(f"Could not list Optuna studies: {e}")
            return []
        return [s.study_name for s in summaries if self.prefix_pattern.match(s.study_name)]

    def filter_studies(self, pattern: Optional[str]) -> List[str]:
        """Return prefixed studies whose name contains a match for regex ``pattern``.

        An empty or None ``pattern`` returns all prefixed studies. An invalid
        regex raises ``re.error``.
        """
        all_studies = self.get_all_studies()
        if not pattern:
            return all_studies
        regex = re.compile(pattern)
        return [s for s in all_studies if regex.search(s)]
class StudyMonitor:
    """Aggregates metrics and status across multiple studies."""

    # Columns, in order, of the table returned by get_study_stats().
    COLUMNS = ["Study Name", "Total", "Complete", "Running", "Failed", "Best Value"]

    def __init__(self, storage_url: str):
        self.storage_url = storage_url

    def get_study_stats(self, study_names: List[str]) -> pd.DataFrame:
        """Summarise trial states and best value for each study.

        Args:
            study_names: Names of studies stored in ``storage_url``.

        Returns:
            One row per study that could be loaded, with the columns in
            ``COLUMNS``. "Best Value" is formatted to 4 decimals, or "N/A" when
            the study has no single best value. Studies that fail to load are
            logged and skipped. An empty input yields an empty table that still
            has every column, so callers can index columns unconditionally.
        """
        if not study_names:
            return pd.DataFrame(columns=self.COLUMNS)
        import optuna
        from collections import Counter
        rows = []
        for name in study_names:
            try:
                study = optuna.load_study(study_name=name, storage=self.storage_url)
                # Read-only use: skip the per-trial deep copy done by `study.trials`.
                trials = study.get_trials(deepcopy=False)
                counts = Counter(t.state.name for t in trials)
                try:
                    best_value = study.best_value
                except (ValueError, RuntimeError):
                    # e.g. no completed trial yet; multi-objective studies have
                    # no single best value either.
                    best_value = None
                rows.append({
                    "Study Name": name,
                    "Total": len(trials),
                    "Complete": counts["COMPLETE"],
                    "Running": counts["RUNNING"],
                    "Failed": counts["FAIL"],
                    "Best Value": f"{best_value:.4f}" if best_value is not None else "N/A",
                })
            except Exception as e:
                logger.warning(f"Skipping study '{name}': {e}")
        return pd.DataFrame(rows, columns=self.COLUMNS)
class StudyInitializer:
    """Orchestrates mass-initialization of Hydra-based experiments."""

    def __init__(self, script_name: str, project_root: Path):
        self.script_path = project_root / script_name

    def initialize(self, config_name: str, overrides: Optional[List[str]] = None) -> bool:
        """Run the experiment script in init mode for one Hydra experiment config.

        Executes ``<python> <script> +exp=<config_name> args.mode=init [overrides...]``.

        Args:
            config_name: Config name within the ``exp`` group.
            overrides: Extra Hydra overrides appended to the command line.

        Returns:
            True if the script exited with status 0. False if it exited with a
            non-zero status or could not be launched.
        """
        import sys
        # Run the script with the interpreter executing this code so it sees the
        # same environment; a bare "python" from PATH may be another install
        # or may not exist at all.
        python = sys.executable or "python"
        cmd = [python, str(self.script_path), f"+exp={config_name}", "args.mode=init"]
        cmd.extend(overrides or [])
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError:
            # The child already reported its own error on stdout/stderr.
            return False
        except OSError as e:
            logger.error(f"Could not launch {self.script_path}: {e}")
            return False
        return True
class ExperimentDiscovery:
    """Finds and filters MLflow experiments named ``<prefix>-...`` or ``<prefix>_...``."""

    def __init__(self, tracking_uri: str, prefix: str):
        self.tracking_uri = tracking_uri
        self.prefix = prefix
        # Escape the prefix so characters such as "." or "+" match literally.
        self.prefix_pattern = re.compile(rf"^{re.escape(prefix)}[-_].*")

    def get_experiments(self, pattern: Optional[str] = None) -> List[Dict[str, Any]]:
        """List prefixed experiments, optionally filtered by regex ``pattern``.

        Sets the global MLflow tracking URI to ``tracking_uri``. Every page of
        ``search_experiments`` is read, so large servers are fully covered.

        Returns:
            Dicts with keys "ID", "Name" and "Stage". Returns ``[]`` (and logs
            a warning) if the server cannot be queried. An invalid ``pattern``
            raises ``re.error``.
        """
        import mlflow
        mlflow.set_tracking_uri(self.tracking_uri)
        client = mlflow.tracking.MlflowClient()
        regex = re.compile(pattern) if pattern else None
        results = []
        try:
            page_token = None
            while True:
                page = client.search_experiments(page_token=page_token)
                for e in page:
                    if not self.prefix_pattern.match(e.name):
                        continue
                    if regex is not None and not regex.search(e.name):
                        continue
                    results.append({
                        "ID": e.experiment_id,
                        "Name": e.name,
                        "Stage": e.lifecycle_stage,
                    })
                page_token = page.token
                if not page_token:
                    break
        except Exception as e:
            logger.warning(f"Could not list MLflow experiments: {e}")
            return []
        return results
class StudyCleaner:
    """Detects and fails orphaned trials based on Slurm status.

    A trial is an orphan when it is RUNNING in Optuna, carries a truthy
    ``slurm_job_id`` user attribute, and the SLURM tracker reports that job
    as not active.
    """

    def __init__(self, storage_url: str, slurm: SlurmTracker):
        self.storage_url = storage_url
        self.slurm = slurm

    def cleanup_study(self, study_name: str, dry_run: bool = True) -> int:
        """Mark the orphaned trials of ``study_name`` as FAIL.

        Args:
            study_name: Optuna study to inspect.
            dry_run: If True (default), only count orphans; nothing is modified.

        Returns:
            Number of orphans found (dry run) or actually marked FAIL.
            Returns 0 if the study could not be loaded.
        """
        import optuna
        TrialState = optuna.trial.TrialState
        try:
            study = optuna.load_study(study_name=study_name, storage=self.storage_url)
            running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))
        except Exception as e:
            logger.error(f"Could not load study '{study_name}': {e}")
            return 0
        failed_count = 0
        for trial in running_trials:
            job_id = trial.user_attrs.get("slurm_job_id")
            # The id may have been stored as an int; tracker lookups use strings.
            if not job_id or self.slurm.is_job_active(str(job_id)):
                continue
            if not dry_run:
                try:
                    # Study.tell is the public API for finishing a trial;
                    # storage-level setters are internal to Optuna.
                    study.tell(trial.number, state=TrialState.FAIL)
                except Exception as e:
                    # e.g. the trial finished between listing and this update.
                    logger.warning(f"Could not fail trial {trial.number} of '{study_name}': {e}")
                    continue
            failed_count += 1
        return failed_count
class FastGC:
    """High-performance MLflow GC using direct DB and S3 access.

    Bypasses the MLflow API. Artifacts are removed straight from the bucket
    (every key under ``<experiment_id>/``). Rows are removed straight from the
    MLflow backend database at ``db_uri``.
    """

    # Run-scoped tables keyed by ``run_uuid``; purged before ``runs`` itself.
    # Tables that do not exist in the connected schema are skipped, so one
    # missing table does not abort the whole transaction.
    RUN_TABLES = ("metrics", "params", "tags", "latest_metrics", "run_inputs")

    def __init__(self, db_uri: str, s3_client: Any, bucket: str):
        self.db_uri = db_uri
        # Presumably a boto3 S3 client (get_paginator / delete_objects are used).
        self.s3 = s3_client
        self.bucket = bucket

    def purge_experiment(self, experiment_id: str, dry_run: bool = True) -> Dict[str, int]:
        """Permanently delete an experiment's artifacts and database rows.

        Args:
            experiment_id: MLflow experiment id.
            dry_run: If True (default), nothing is deleted; the returned counts
                describe what would be deleted.

        Returns:
            ``{"s3": objects deleted (or found), "runs": runs deleted (or found)}``.
            A failing part is logged and reported as 0.
        """
        from sqlalchemy import create_engine, inspect, text
        stats = {"s3": self._purge_s3_prefix(f"{experiment_id}/", dry_run), "runs": 0}
        params = {"eid": experiment_id}
        engine = create_engine(self.db_uri)
        try:
            existing = set(inspect(engine).get_table_names())
            # Single transaction: rolled back as a unit if any statement fails.
            with engine.begin() as conn:
                n_runs = conn.execute(
                    text("SELECT COUNT(*) FROM runs WHERE experiment_id = :eid"), params
                ).scalar()
                if not dry_run:
                    run_subquery = "SELECT run_uuid FROM runs WHERE experiment_id = :eid"
                    for table in self.RUN_TABLES:
                        if table in existing:
                            conn.execute(
                                text(f"DELETE FROM {table} WHERE run_uuid IN ({run_subquery})"),
                                params,
                            )
                    conn.execute(text("DELETE FROM runs WHERE experiment_id = :eid"), params)
                    if "experiment_tags" in existing:
                        # These rows reference the experiment; remove them first
                        # so the delete below succeeds where FKs are enforced.
                        conn.execute(
                            text("DELETE FROM experiment_tags WHERE experiment_id = :eid"), params
                        )
                    conn.execute(text("DELETE FROM experiments WHERE experiment_id = :eid"), params)
            stats["runs"] = int(n_runs or 0)
        except Exception as e:
            logger.error(f"DB purge of experiment {experiment_id} failed: {e}")
        finally:
            engine.dispose()
        if dry_run:
            logger.info(f"[DRY RUN] Would purge experiment {experiment_id}: {stats}")
        return stats

    def _purge_s3_prefix(self, prefix: str, dry_run: bool) -> int:
        """Delete (or, in a dry run, only count) all objects under ``prefix``; return the count."""
        if self.s3 is None:
            return 0
        count = 0
        try:
            paginator = self.s3.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
                keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
                if not keys:
                    continue
                if dry_run:
                    count += len(keys)
                    continue
                resp = self.s3.delete_objects(Bucket=self.bucket, Delete={"Objects": keys})
                # delete_objects reports per-key failures in the response
                # instead of raising.
                errors = resp.get("Errors", [])
                for err in errors:
                    logger.warning(f"S3 delete failed for {err.get('Key')}: {err.get('Message')}")
                count += len(keys) - len(errors)
        except Exception as e:
            logger.error(f"S3 purge of '{prefix}' failed: {e}")
        return count
class StudyTruncator:
    """Truncates Optuna studies and MLflow experiments to a target count.

    Trials numbered ``>= target_count`` are deleted from the Optuna RDB
    storage. MLflow runs in the experiment named like the study are also
    marked deleted (``MlflowClient.delete_run``) when their ``mlflow.runName``
    tag is such a trial number.
    """

    # Optuna RDB tables holding per-trial rows; cleared before ``trials``.
    TRIAL_CHILD_TABLES = (
        "trial_params", "trial_values", "trial_system_attributes",
        "trial_user_attributes", "trial_intermediate_values", "trial_heartbeats",
    )

    def __init__(self, optuna_storage: str, mlflow_uri: str):
        self.optuna_storage = optuna_storage
        self.mlflow_uri = mlflow_uri

    def truncate(self, study_name: str, target_count: int, dry_run: bool = True) -> bool:
        """Keep only trials ``0 .. target_count - 1`` and their MLflow runs.

        Args:
            study_name: Optuna study name; also the MLflow experiment name.
            target_count: Number of leading trials to keep (0 removes all).
            dry_run: If True (default), nothing is modified.

        Returns:
            True on success (or a successful dry run). False if the study or
            experiment was not found, or a step failed (details are logged).

        Raises:
            ValueError: If ``target_count`` is negative.
        """
        if target_count < 0:
            raise ValueError(f"target_count must be >= 0, got {target_count}")
        from sqlalchemy import bindparam, create_engine, text
        import mlflow

        # 1. Truncate Optuna via SQL, in a single transaction.
        engine = create_engine(self.optuna_storage)
        try:
            with engine.begin() as conn:
                row = conn.execute(
                    text("SELECT study_id FROM studies WHERE study_name = :name"),
                    {"name": study_name},
                ).fetchone()
                if row is None:
                    logger.warning(f"Optuna study '{study_name}' not found")
                    return False
                result = conn.execute(
                    text("SELECT trial_id FROM trials WHERE study_id = :study_id AND number >= :target"),
                    {"study_id": row[0], "target": target_count},
                )
                trial_ids = [r[0] for r in result]
                if trial_ids and not dry_run:
                    for table in (*self.TRIAL_CHILD_TABLES, "trials"):
                        # A plain ":ids" cannot bind a Python list; an expanding
                        # parameter renders one placeholder per element.
                        stmt = text(f"DELETE FROM {table} WHERE trial_id IN :ids").bindparams(
                            bindparam("ids", expanding=True)
                        )
                        conn.execute(stmt, {"ids": trial_ids})
        except Exception as e:
            logger.error(f"Optuna truncation of '{study_name}' failed: {e}")
            return False
        finally:
            engine.dispose()

        # 2. Truncate MLflow (runs are matched by run name == trial number).
        try:
            mlflow.set_tracking_uri(self.mlflow_uri)
            exp = mlflow.get_experiment_by_name(study_name)
            if exp is None:
                logger.warning(f"MLflow experiment '{study_name}' not found")
                return False
            client = mlflow.tracking.MlflowClient()
            runs_to_delete = []
            page_token = None
            # Read every page (no size cap) and collect ids before deleting,
            # so deletions cannot shift later pages.
            while True:
                page = client.search_runs(
                    experiment_ids=[exp.experiment_id], max_results=1000, page_token=page_token
                )
                for run in page:
                    name = run.data.tags.get("mlflow.runName")
                    if isinstance(name, str) and name.isdecimal() and int(name) >= target_count:
                        runs_to_delete.append(run.info.run_id)
                page_token = page.token
                if not page_token:
                    break
            if dry_run:
                logger.info(
                    f"[DRY RUN] Would delete {len(trial_ids)} trials and "
                    f"{len(runs_to_delete)} runs from '{study_name}'"
                )
            else:
                for run_id in runs_to_delete:
                    client.delete_run(run_id)
        except Exception as e:
            logger.error(f"MLflow truncation of '{study_name}' failed: {e}")
            return False
        return True
class PruneService:
    """High-performance service to obliterate broken trials from SQL and S3.

    "Broken" trials are FAIL trials and orphans: RUNNING trials whose
    ``slurm_job_id`` user attribute names a job SLURM does not report as
    active. Pruning does three things:

    - deletes the trials' rows from the Optuna RDB storage;
    - deletes their MLflow runs (matched by run name == trial number);
    - removes the runs' artifacts under ``<experiment_id>/<run_id>/`` in the
      S3 bucket.
    """

    # Optuna RDB tables holding per-trial rows; cleared before ``trials``.
    TRIAL_CHILD_TABLES = (
        "trial_params", "trial_values", "trial_system_attributes",
        "trial_user_attributes", "trial_intermediate_values", "trial_heartbeats",
    )

    def __init__(self, optuna_storage: str, mlflow_uri: str, s3_client: Any = None, bucket: str = "mlflow"):
        self.optuna_storage = optuna_storage
        self.mlflow_uri = mlflow_uri
        self.s3 = s3_client
        self.bucket = bucket
        from sqlalchemy import create_engine
        self.engine = create_engine(self.optuna_storage)

    @staticmethod
    def _in_ids(sql: str):
        """Build ``text(sql)`` whose ``:ids`` parameter binds a list as an IN-list."""
        from sqlalchemy import bindparam, text
        # Plain text() cannot bind a Python list to "IN :ids"; an expanding
        # parameter renders one placeholder per element.
        return text(sql).bindparams(bindparam("ids", expanding=True))

    def find_broken_trials(self, study_name: str, orphans: bool = True, failed: bool = True) -> List[int]:
        """Identifies trial IDs for deletion based on brokenness.

        Args:
            study_name: Optuna study to inspect.
            orphans: Include RUNNING trials whose SLURM job is not active.
            failed: Include trials in state FAIL.

        Returns:
            Unique Optuna ``trial_id`` values, in no particular order. On a
            database error the error is logged and the ids found so far are
            returned.
        """
        import json
        from sqlalchemy import text
        trial_ids = set()
        try:
            with self.engine.connect() as conn:
                row = conn.execute(
                    text("SELECT study_id FROM studies WHERE study_name = :name"), {"name": study_name}
                ).fetchone()
                if row is None:
                    return []
                study_id = row[0]
                if failed:
                    res = conn.execute(
                        text("SELECT trial_id FROM trials WHERE study_id = :sid AND state = 'FAIL'"),
                        {"sid": study_id},
                    )
                    trial_ids.update(r[0] for r in res)
                if orphans:
                    res = conn.execute(
                        text("SELECT trial_id FROM trials WHERE study_id = :sid AND state = 'RUNNING'"),
                        {"sid": study_id},
                    )
                    running_ids = [r[0] for r in res]
                    if running_ids:
                        slurm = SlurmTracker()
                        # "key" is reserved in some dialects (e.g. MySQL); let the
                        # engine's dialect quote it instead of hard-coding backticks.
                        key_col = self.engine.dialect.identifier_preparer.quote("key")
                        # Optuna's RDB storage keeps user attributes JSON-encoded
                        # in the value_json column.
                        res = conn.execute(
                            self._in_ids(
                                "SELECT trial_id, value_json FROM trial_user_attributes "
                                f"WHERE trial_id IN :ids AND {key_col} = 'slurm_job_id'"
                            ),
                            {"ids": running_ids},
                        )
                        for tid, raw in res:
                            try:
                                job_id = json.loads(raw)
                            except (TypeError, ValueError):
                                job_id = raw
                            if job_id and not slurm.is_job_active(str(job_id)):
                                trial_ids.add(tid)
        except Exception as e:
            logger.error(f"Failed to find broken trials: {e}")
        return list(trial_ids)

    def prune(self, study_name: str, trial_ids: List[int], dry_run: bool = True) -> Dict[str, int]:
        """Obliterates specific trials from Optuna, MLflow, and S3.

        Args:
            study_name: Optuna study name; also the MLflow experiment name.
            trial_ids: Optuna ``trial_id`` values (not trial numbers) to delete.
            dry_run: If True (default), nothing is modified; the returned
                counts describe what would be affected.

        Returns:
            ``{"optuna": trials deleted, "mlflow": runs deleted,
            "s3": runs whose artifacts were purged}``.
        """
        stats = {"optuna": 0, "mlflow": 0, "s3": 0}
        if not trial_ids:
            return stats
        trial_ids = list(trial_ids)
        # 1. Map Optuna trials to MLflow runs; must precede the Optuna delete,
        #    which removes the trial numbers used for matching.
        exp_id, run_ids = self._find_mlflow_runs(study_name, trial_ids)
        if dry_run:
            logger.info(f"[DRY RUN] Would prune {len(trial_ids)} trials from {study_name}")
            return {"optuna": len(trial_ids), "mlflow": len(run_ids), "s3": len(run_ids)}
        # 2. S3 purge
        if self.s3 and run_ids:
            stats["s3"] = self._purge_s3_artifacts(exp_id, run_ids)
        # 3. MLflow run deletion
        if run_ids:
            stats["mlflow"] = self._delete_mlflow_runs(run_ids)
        # 4. Optuna SQL purge
        stats["optuna"] = self._delete_optuna_trials(trial_ids)
        return stats

    def _find_mlflow_runs(self, study_name: str, trial_ids: List[int]):
        """Return ``(experiment_id, run_ids)`` for the MLflow runs of the given trials.

        Runs are matched by ``mlflow.runName`` == trial number. Returns
        ``(None, [])`` if the experiment does not exist or the lookup fails.
        """
        import mlflow
        try:
            mlflow.set_tracking_uri(self.mlflow_uri)
            exp = mlflow.get_experiment_by_name(study_name)
            if exp is None:
                return None, []
            with self.engine.connect() as conn:
                res = conn.execute(
                    self._in_ids("SELECT number FROM trials WHERE trial_id IN :ids"), {"ids": trial_ids}
                )
                wanted = {str(r[0]) for r in res}
            if not wanted:
                return exp.experiment_id, []
            # MLflow's search filter grammar has no OR, so scan the experiment's
            # runs page by page and match run names locally.
            client = mlflow.tracking.MlflowClient()
            run_ids: List[str] = []
            page_token = None
            while True:
                page = client.search_runs(
                    experiment_ids=[exp.experiment_id], max_results=1000, page_token=page_token
                )
                run_ids.extend(r.info.run_id for r in page if r.data.tags.get("mlflow.runName") in wanted)
                page_token = page.token
                if not page_token:
                    break
            return exp.experiment_id, run_ids
        except Exception as e:
            logger.error(f"Failed to sync MLflow runs: {e}")
            return None, []

    def _purge_s3_artifacts(self, experiment_id: str, run_ids: List[str]) -> int:
        """Delete all objects under ``<experiment_id>/<run_id>/``; return how many runs succeeded."""
        purged = 0
        for rid in run_ids:
            try:
                paginator = self.s3.get_paginator("list_objects_v2")
                for page in paginator.paginate(Bucket=self.bucket, Prefix=f"{experiment_id}/{rid}/"):
                    keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
                    if keys:
                        self.s3.delete_objects(Bucket=self.bucket, Delete={"Objects": keys})
                purged += 1
            except Exception as e:
                logger.warning(f"S3 purge failed for run {rid}: {e}")
        return purged

    def _delete_mlflow_runs(self, run_ids: List[str]) -> int:
        """Mark each run deleted through the MLflow API; return how many succeeded."""
        import mlflow
        mlflow.set_tracking_uri(self.mlflow_uri)
        client = mlflow.tracking.MlflowClient()
        deleted = 0
        for rid in run_ids:
            try:
                client.delete_run(rid)
                deleted += 1
            except Exception as e:
                # One failure must not stop (or hide) the remaining deletions.
                logger.warning(f"MLflow delete_run failed for {rid}: {e}")
        return deleted

    def _delete_optuna_trials(self, trial_ids: List[int]) -> int:
        """Delete the trials and their child rows in one transaction; return the count (0 on failure)."""
        try:
            with self.engine.begin() as conn:
                for table in (*self.TRIAL_CHILD_TABLES, "trials"):
                    conn.execute(
                        self._in_ids(f"DELETE FROM {table} WHERE trial_id IN :ids"), {"ids": trial_ids}
                    )
        except Exception as e:
            logger.error(f"Optuna SQL purge failed: {e}")
            return 0
        return len(trial_ids)