# Source code for gunz_ml.management.core

"""
Gunz ML - Experiment Management Core Logic.
Standardized discovery, monitoring, and cleanup for HPO experiments.
"""
import os
import re
import subprocess
from pathlib import Path
from typing import Dict, List, Optional, Any

import pandas as pd
from omegaconf import OmegaConf
from loguru import logger

class ProjectMetadata:
    """Extracts project-level metadata from Hydra configurations.

    Attributes:
        project_root: Nearest directory at or above ``start_path`` that contains
            a ``configs/`` directory; the resolved ``start_path`` if none does.
        config_dir: ``project_root / "configs"``.
        base_exp_path: ``config_dir / "base_exp.yaml"``.
        prefix: ``study_prefix`` read from ``base_exp.yaml`` (``"GUNZ"`` fallback).
        db_configs: Parsed database sections keyed by backend name; only
            ``"optuna"`` is populated, from ``configs/db/optuna/juno.yaml``.
    """

    def __init__(self, start_path: Path):
        """Resolve the project layout and eagerly read prefix and DB configs.

        Args:
            start_path: Any path inside the project; the root is searched upwards.
        """
        self.project_root = self._find_project_root(start_path)
        self.config_dir = self.project_root / "configs"
        self.base_exp_path = self.config_dir / "base_exp.yaml"
        self.prefix = self._extract_prefix()
        self.db_configs = self._resolve_db_configs()

    def _find_project_root(self, start_path: Path) -> Path:
        """Walk up to find the project root (the first directory containing ``configs/``).

        Returns the resolved ``start_path`` itself when no ancestor qualifies.
        """
        current = start_path.resolve()
        for parent in [current, *current.parents]:
            if (parent / "configs").exists():
                return parent
        return current

    def _extract_prefix(self) -> str:
        """Read ``study_prefix`` from ``base_exp.yaml``; ``"GUNZ"`` if unavailable."""
        if not self.base_exp_path.exists():
            return "GUNZ"
        try:
            conf = OmegaConf.load(self.base_exp_path)
            prefix = conf.get("study_prefix", "GUNZ")
            return str(prefix)
        except Exception as exc:
            # Best effort: an unreadable base config must not block discovery,
            # but the fallback should be visible.
            logger.warning(
                f"Could not read study_prefix from {self.base_exp_path} ({exc}); "
                "falling back to 'GUNZ'"
            )
            return "GUNZ"

    def _resolve_db_configs(self) -> Dict[str, Any]:
        """Locate and parse database configurations.

        Reads the ``db.optuna`` mapping of ``configs/db/optuna/juno.yaml``.
        Interpolations (e.g. ``${oc.env:...}``) are resolved when possible so
        the values can be used directly in a connection URL.

        Returns:
            ``{"optuna": {...}}`` when the file holds a ``db.optuna`` mapping,
            otherwise ``{}``.
        """
        # Priority: juno.yaml for remote HPO
        optuna_db_path = self.config_dir / "db" / "optuna" / "juno.yaml"
        db_configs: Dict[str, Any] = {}
        if not optuna_db_path.exists():
            return db_configs

        conf = OmegaConf.load(optuna_db_path)
        # select() yields None for a missing path rather than a plain dict,
        # which OmegaConf.to_container() would reject with ValueError.
        section = OmegaConf.select(conf, "db.optuna")
        if section is None or not OmegaConf.is_dict(section):
            logger.warning(f"No 'db.optuna' mapping in {optuna_db_path}; Optuna storage is unconfigured")
            return db_configs

        try:
            db_configs["optuna"] = OmegaConf.to_container(section, resolve=True)
        except Exception as exc:
            # Interpolations pointing outside this file cannot be resolved standalone.
            logger.warning(f"Could not resolve interpolations in {optuna_db_path} ({exc}); using raw values")
            db_configs["optuna"] = OmegaConf.to_container(section, resolve=False)
        return db_configs

    def get_optuna_storage_url(self) -> Optional[str]:
        """Construct the Optuna (SQLAlchemy) storage URL.

        Username and password are percent-encoded so characters such as ``@``,
        ``:``, ``/`` or ``%`` do not corrupt the URL.

        Returns:
            ``"<driver>://<user>:<pass>@<host>:<port>/<db_name>"``, or ``None``
            when no Optuna DB config is available or a required key is missing.
        """
        from urllib.parse import quote

        opt = self.db_configs.get("optuna")
        if not opt:
            return None
        try:
            driver, user, password = opt["driver"], opt["user"], opt["pass"]
            host, port, db_name = opt["host"], opt["port"], opt["db_name"]
        except KeyError as exc:
            logger.warning(f"Optuna DB config is missing key {exc}; cannot build storage URL")
            return None
        # safe="" also escapes '/' and '+', which round-trip through both
        # unquote() and unquote_plus().
        user_q = quote(str(user), safe="")
        pass_q = quote(str(password), safe="")
        return f"{driver}://{user_q}:{pass_q}@{host}:{port}/{db_name}"
class SlurmTracker:
    """Interface to SLURM to track active jobs.

    ``squeue`` is queried once, at construction. Callers use
    :meth:`is_job_active` to decide whether a trial is orphaned and may be
    failed or deleted. If the query fails, the tracker therefore answers
    conservatively ("active") instead of reporting every job as gone.

    Attributes:
        user: User whose jobs are listed; ``$USER`` by default. When neither is
            set, jobs of all users are listed (SLURM job ids are cluster-unique).
        available: ``True`` when ``squeue`` succeeded and ``active_jobs`` is reliable.
        active_jobs: ``{job_id: state}``. Array tasks are indexed both by
            ``"<base>_<task>"`` and by ``"<base>"``.
    """

    def __init__(self, user: Optional[str] = None):
        self.user = user or os.environ.get("USER")
        # Flipped to True by _fetch_active_jobs only on a successful query.
        self.available = False
        self.active_jobs = self._fetch_active_jobs()

    @staticmethod
    def _parse_squeue(stdout: str) -> Dict[str, str]:
        """Parse ``squeue --format=%i|%T`` output into ``{job_id: state}``.

        Lines that are not exactly ``<id>|<state>`` are ignored. Array task ids
        (``"<base>_<task>"``) are also recorded under their base job id.
        """
        jobs: Dict[str, str] = {}
        for raw_line in stdout.splitlines():
            parts = raw_line.strip().split("|")
            if len(parts) != 2:
                continue
            job_id, state = parts
            jobs[job_id.split("_")[0]] = state
            jobs[job_id] = state
        return jobs

    def _fetch_active_jobs(self) -> Dict[str, str]:
        """Fetch active jobs from ``squeue`` and return ``{job_id: state}``.

        Returns ``{}`` and leaves ``available`` False when ``squeue`` cannot be
        run or exits non-zero.
        """
        cmd = ["squeue"]
        if self.user:
            cmd += ["-u", self.user]
        cmd += ["--format=%i|%T", "--noheader"]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        except (OSError, subprocess.SubprocessError) as exc:
            logger.warning(f"squeue query failed ({exc}); treating all SLURM jobs as active")
            return {}
        self.available = True
        return self._parse_squeue(result.stdout)

    def is_job_active(self, job_id: str) -> bool:
        """Return True if ``job_id`` is queued or running.

        Int ids are accepted and compared as strings. When the SLURM state is
        unknown (the ``squeue`` query failed) this returns True, so that
        destructive callers leave the trial alone.
        """
        if not self.available:
            return True
        return str(job_id) in self.active_jobs
class StudyDiscovery:
    """Finds and filters Optuna studies.

    Only studies named ``<prefix>-...`` or ``<prefix>_...`` are considered.
    The prefix is matched literally, not as a regular expression.
    """

    def __init__(self, storage_url: str, prefix: str):
        self.storage_url = storage_url
        self.prefix = prefix
        # Escape: prefixes such as "v1.0" must not be interpreted as regex.
        self.prefix_pattern = re.compile(rf"^{re.escape(prefix)}[-_].*")

    def get_all_studies(self) -> List[str]:
        """Return the names of all prefix-matching studies in the storage.

        Returns ``[]`` (and logs the error) if the storage cannot be queried.
        """
        import optuna

        try:
            summaries = optuna.get_all_study_summaries(self.storage_url)
        except Exception as exc:
            logger.error(f"Failed to list Optuna studies: {exc}")
            return []
        return [s.study_name for s in summaries if self.prefix_pattern.match(s.study_name)]

    def filter_studies(self, pattern: Optional[str]) -> List[str]:
        """Return prefix-matching studies whose name contains a match of ``pattern``.

        An empty or ``None`` pattern returns every prefix-matching study.
        ``pattern`` is a regular expression applied with ``re.search``.
        """
        all_studies = self.get_all_studies()
        if not pattern:
            return all_studies
        regex = re.compile(pattern)
        return [s for s in all_studies if regex.search(s)]
class StudyMonitor:
    """Aggregates metrics and status across multiple studies."""

    # Column order of the DataFrame returned by get_study_stats.
    _COLUMNS = ["Study Name", "Total", "Complete", "Running", "Failed", "Best Value"]

    def __init__(self, storage_url: str):
        self.storage_url = storage_url

    def get_study_stats(self, study_names: List[str]) -> pd.DataFrame:
        """Summarize trial counts and best value for each study.

        Args:
            study_names: Names of studies in ``storage_url``.

        Returns:
            One row per study that could be loaded, with the columns in
            ``_COLUMNS``. "Best Value" is formatted to 4 decimals, or "N/A" when
            there is no single best value (no completed trial, or a
            multi-objective study). Studies that fail to load are logged and
            skipped. An empty input yields an empty frame that still has the columns.
        """
        if not study_names:
            return pd.DataFrame(columns=self._COLUMNS)

        import optuna
        from collections import Counter

        rows = []
        for name in study_names:
            try:
                study = optuna.load_study(study_name=name, storage=self.storage_url)
                trials = study.get_trials(deepcopy=False)
            except Exception as exc:
                logger.warning(f"Skipping study '{name}': {exc}")
                continue

            counts = Counter(t.state.name for t in trials)
            try:
                best_value = study.best_value
            except (ValueError, RuntimeError):
                # ValueError: no completed trials; RuntimeError: multi-objective.
                best_value = None

            rows.append({
                "Study Name": name,
                "Total": len(trials),
                "Complete": counts.get("COMPLETE", 0),
                "Running": counts.get("RUNNING", 0),
                "Failed": counts.get("FAIL", 0),
                "Best Value": f"{best_value:.4f}" if best_value is not None else "N/A",
            })
        return pd.DataFrame(rows, columns=self._COLUMNS)
class StudyInitializer:
    """Orchestrates mass-initialization of Hydra-based experiments.

    Runs ``<project_root>/<script_name>`` as a subprocess with
    ``+exp=<config_name> args.mode=init`` plus any extra Hydra overrides.
    """

    def __init__(self, script_name: str, project_root: Path):
        self.script_path = project_root / script_name

    def initialize(self, config_name: str, overrides: Optional[List[str]] = None) -> bool:
        """Initialize one experiment config.

        Args:
            config_name: Name passed as ``+exp=<config_name>``.
            overrides: Extra Hydra override strings appended verbatim.

        Returns:
            True if the script exited with status 0, False otherwise.
        """
        import sys

        # sys.executable runs the script in this interpreter/venv rather than
        # whatever "python" happens to resolve to on PATH.
        cmd = [sys.executable, str(self.script_path), f"+exp={config_name}", "args.mode=init"]
        cmd.extend(overrides or [])
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as exc:
            logger.error(f"Initialization of '{config_name}' failed with exit code {exc.returncode}")
            return False
        return True
class ExperimentDiscovery:
    """Finds and filters MLflow experiments.

    Only experiments named ``<prefix>-...`` or ``<prefix>_...`` are considered.
    The prefix is matched literally, not as a regular expression.
    """

    def __init__(self, tracking_uri: str, prefix: str):
        self.tracking_uri = tracking_uri
        self.prefix = prefix
        # Escape: prefixes such as "v1.0" must not be interpreted as regex.
        self.prefix_pattern = re.compile(rf"^{re.escape(prefix)}[-_].*")

    def get_experiments(self, pattern: Optional[str] = None) -> List[Dict[str, Any]]:
        """List prefix-matching experiments, optionally filtered by ``pattern``.

        Sets the global MLflow tracking URI to ``tracking_uri`` as a side effect.
        All result pages are fetched, not just the first.

        Args:
            pattern: Optional regex; only names containing a match are kept.

        Returns:
            Dicts with keys "ID", "Name" and "Stage". Returns ``[]`` (and logs
            the error) on a query failure or an invalid ``pattern``.
        """
        import mlflow

        mlflow.set_tracking_uri(self.tracking_uri)
        client = mlflow.tracking.MlflowClient()
        results: List[Dict[str, Any]] = []
        try:
            name_regex = re.compile(pattern) if pattern else None
            page_token = None
            while True:
                page = client.search_experiments(page_token=page_token)
                for exp in page:
                    if not self.prefix_pattern.match(exp.name):
                        continue
                    if name_regex is not None and not name_regex.search(exp.name):
                        continue
                    results.append({
                        "ID": exp.experiment_id,
                        "Name": exp.name,
                        "Stage": exp.lifecycle_stage,
                    })
                page_token = page.token
                if not page_token:
                    break
        except Exception as exc:
            logger.error(f"Failed to list MLflow experiments: {exc}")
            return []
        return results
class StudyCleaner:
    """Detects and fails orphaned trials based on Slurm status.

    A RUNNING trial is orphaned when its ``slurm_job_id`` user attribute names
    a job that the given tracker does not report as active.
    """

    def __init__(self, storage_url: str, slurm: SlurmTracker):
        self.storage_url = storage_url
        self.slurm = slurm

    def cleanup_study(self, study_name: str, dry_run: bool = True) -> int:
        """Mark orphaned RUNNING trials of ``study_name`` as FAIL.

        Args:
            study_name: Optuna study to inspect.
            dry_run: If True, only count orphans; nothing is written.

        Returns:
            Number of orphaned trials found (dry run) or successfully failed.
            Returns 0 if the study cannot be loaded.
        """
        import optuna
        from optuna.trial import TrialState

        try:
            study = optuna.load_study(study_name=study_name, storage=self.storage_url)
            running = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))
        except Exception as exc:
            logger.error(f"Could not load study '{study_name}': {exc}")
            return 0

        orphaned = 0
        for trial in running:
            job_id = trial.user_attrs.get("slurm_job_id")
            # The attribute may be stored as an int; the tracker keys on strings.
            if not job_id or self.slurm.is_job_active(str(job_id)):
                continue
            if not dry_run:
                try:
                    # Public API: Optuna's Study exposes no `storage` attribute.
                    # skip_if_finished tolerates a trial that finished meanwhile.
                    study.tell(trial.number, state=TrialState.FAIL, skip_if_finished=True)
                except Exception as exc:
                    logger.error(f"Could not fail trial {trial.number} of '{study_name}': {exc}")
                    continue
            orphaned += 1
        return orphaned
class FastGC:
    """High-performance MLflow GC using direct DB and S3 access.

    Bypasses the MLflow API. It deletes every object under
    ``<experiment_id>/`` in ``bucket`` and removes the experiment's rows from
    the MLflow tracking database with SQL. This assumes the artifact root is
    the bucket root (TODO confirm for this deployment).
    """

    # Run-scoped tables keyed by ``run_uuid``; cleared before ``runs``.
    _RUN_TABLES = ("metrics", "params", "tags", "latest_metrics")

    def __init__(self, db_uri: str, s3_client: Any, bucket: str):
        self.db_uri = db_uri
        self.s3 = s3_client
        self.bucket = bucket

    def purge_experiment(self, experiment_id: str, dry_run: bool = True) -> bool:
        """Permanently delete an experiment's artifacts and tracking rows.

        Args:
            experiment_id: MLflow experiment id.
            dry_run: If True, nothing is touched (neither S3 nor the DB).

        Returns:
            True if every step succeeded (always True for a dry run), False if
            any step failed. Failures are logged.
        """
        if dry_run:
            return True
        s3_ok = self._purge_s3(f"{experiment_id}/")
        db_ok = self._purge_db(experiment_id)
        return s3_ok and db_ok

    def _purge_s3(self, prefix: str) -> bool:
        """Delete every object under ``prefix``; return False on any error."""
        ok = True
        try:
            paginator = self.s3.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
                # A list page holds at most 1000 keys, which is also the
                # per-request limit of delete_objects.
                keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
                if not keys:
                    continue
                resp = self.s3.delete_objects(Bucket=self.bucket, Delete={"Objects": keys})
                errors = resp.get("Errors") if isinstance(resp, dict) else None
                if errors:
                    logger.warning(f"S3 refused to delete {len(errors)} object(s) under '{prefix}'")
                    ok = False
        except Exception as exc:
            logger.error(f"S3 purge of '{prefix}' failed: {exc}")
            return False
        return ok

    def _purge_db(self, experiment_id: str) -> bool:
        """Delete the experiment and its dependent rows in one transaction.

        Only tables present in the schema are touched, so this works across
        MLflow schema versions. Tables not handled here that still reference
        the experiment (e.g. traces in newer MLflow) make the transaction roll
        back; that failure is logged.
        """
        from sqlalchemy import create_engine, inspect, text

        run_ids = "SELECT run_uuid FROM runs WHERE experiment_id = :eid"
        params = {"eid": experiment_id}
        engine = create_engine(self.db_uri)
        try:
            with engine.begin() as conn:
                existing = set(inspect(conn).get_table_names())
                for table in self._RUN_TABLES:
                    if table in existing:
                        conn.execute(text(f"DELETE FROM {table} WHERE run_uuid IN ({run_ids})"), params)
                # Run inputs live in `inputs` (destination_id = run_uuid) and `input_tags`.
                if "inputs" in existing:
                    if "input_tags" in existing:
                        conn.execute(
                            text(
                                "DELETE FROM input_tags WHERE input_uuid IN "
                                f"(SELECT input_uuid FROM inputs WHERE destination_id IN ({run_ids}))"
                            ),
                            params,
                        )
                    conn.execute(text(f"DELETE FROM inputs WHERE destination_id IN ({run_ids})"), params)
                conn.execute(text("DELETE FROM runs WHERE experiment_id = :eid"), params)
                for table in ("datasets", "experiment_tags"):
                    if table in existing:
                        conn.execute(text(f"DELETE FROM {table} WHERE experiment_id = :eid"), params)
                conn.execute(text("DELETE FROM experiments WHERE experiment_id = :eid"), params)
        except Exception as exc:
            logger.error(f"MLflow DB purge of experiment {experiment_id} failed: {exc}")
            return False
        finally:
            engine.dispose()
        return True
class StudyTruncator:
    """Truncates Optuna studies and MLflow experiments to a target count.

    Trials with ``number >= target_count`` are deleted from the Optuna RDB via
    SQL. MLflow runs whose run name is a trial number ``>= target_count`` are
    deleted through the MLflow client.
    """

    # Optuna tables holding per-trial rows; cleared before ``trials`` itself.
    _TRIAL_CHILD_TABLES = (
        "trial_params",
        "trial_values",
        "trial_system_attributes",
        "trial_user_attributes",
        "trial_intermediate_values",
        "trial_heartbeats",
    )

    def __init__(self, optuna_storage: str, mlflow_uri: str):
        self.optuna_storage = optuna_storage
        self.mlflow_uri = mlflow_uri

    def truncate(self, study_name: str, target_count: int, dry_run: bool = True) -> bool:
        """Keep only trials ``0 .. target_count-1`` in both Optuna and MLflow.

        Args:
            study_name: Optuna study name, also used as the MLflow experiment name.
            target_count: Number of leading trials to keep.
            dry_run: If True, only look up what would be deleted.

        Returns:
            False if the study or experiment is missing or any step fails
            (errors are logged), True otherwise.
        """
        if not self._truncate_optuna(study_name, target_count, dry_run):
            return False
        return self._truncate_mlflow(study_name, target_count, dry_run)

    @staticmethod
    def _run_name_at_or_above(run_name: Any, target_count: int) -> bool:
        """True if ``run_name`` is a decimal trial number ``>= target_count``.

        Non-strings (e.g. None, or NaN from missing tags) are never matched.
        """
        return isinstance(run_name, str) and run_name.isdecimal() and int(run_name) >= target_count

    def _truncate_optuna(self, study_name: str, target_count: int, dry_run: bool) -> bool:
        """Delete Optuna trials numbered ``>= target_count`` in one transaction."""
        from sqlalchemy import bindparam, create_engine, text

        engine = create_engine(self.optuna_storage)
        try:
            with engine.begin() as conn:
                row = conn.execute(
                    text("SELECT study_id FROM studies WHERE study_name = :name"),
                    {"name": study_name},
                ).fetchone()
                if row is None:
                    logger.warning(f"Optuna study '{study_name}' not found")
                    return False
                result = conn.execute(
                    text("SELECT trial_id FROM trials WHERE study_id = :study_id AND number >= :target"),
                    {"study_id": row[0], "target": target_count},
                )
                trial_ids = [r[0] for r in result.fetchall()]
                if trial_ids and not dry_run:
                    for table in (*self._TRIAL_CHILD_TABLES, "trials"):
                        # expanding=True renders the list portably as IN (...).
                        stmt = text(f"DELETE FROM {table} WHERE trial_id IN :ids").bindparams(
                            bindparam("ids", expanding=True)
                        )
                        conn.execute(stmt, {"ids": trial_ids})
        except Exception as exc:
            logger.error(f"Optuna truncation of '{study_name}' failed: {exc}")
            return False
        finally:
            engine.dispose()
        return True

    def _truncate_mlflow(self, study_name: str, target_count: int, dry_run: bool) -> bool:
        """Delete MLflow runs whose run name is a trial number ``>= target_count``."""
        import mlflow

        try:
            mlflow.set_tracking_uri(self.mlflow_uri)
            exp = mlflow.get_experiment_by_name(study_name)
            if not exp:
                logger.warning(f"MLflow experiment '{study_name}' not found")
                return False
            client = mlflow.tracking.MlflowClient()
            runs_to_delete = []
            page_token = None
            # Page through every run; a fixed max_results would silently miss some.
            while True:
                page = client.search_runs([exp.experiment_id], page_token=page_token)
                for run in page:
                    if self._run_name_at_or_above(run.data.tags.get("mlflow.runName"), target_count):
                        runs_to_delete.append(run.info.run_id)
                page_token = page.token
                if not page_token:
                    break
            if runs_to_delete and not dry_run:
                for run_id in runs_to_delete:
                    client.delete_run(run_id)
        except Exception as exc:
            logger.error(f"MLflow truncation of '{study_name}' failed: {exc}")
            return False
        return True
class PruneService:
    """High-performance service to obliterate broken trials from SQL and S3.

    A trial is "broken" when it is in state FAIL, or when it is RUNNING but its
    SLURM job is no longer active. Pruning does three things:
    - removes the trials from the Optuna RDB via SQL;
    - deletes the matching MLflow runs, matched by run name == trial number;
    - deletes those runs' S3 artifacts under ``<experiment_id>/<run_id>/``.
    That S3 layout is assumed — TODO confirm.
    """

    # Optuna tables holding per-trial rows; cleared before ``trials`` itself.
    _TRIAL_CHILD_TABLES = (
        "trial_params",
        "trial_values",
        "trial_system_attributes",
        "trial_user_attributes",
        "trial_intermediate_values",
        "trial_heartbeats",
    )

    def __init__(self, optuna_storage: str, mlflow_uri: str, s3_client: Any = None, bucket: str = "mlflow"):
        self.optuna_storage = optuna_storage
        self.mlflow_uri = mlflow_uri
        self.s3 = s3_client
        self.bucket = bucket
        from sqlalchemy import create_engine
        self.engine = create_engine(self.optuna_storage)

    @staticmethod
    def _decode_attr(raw: Any) -> Any:
        """Decode an Optuna ``value_json`` cell; return it unchanged if it is not JSON."""
        import json

        if not isinstance(raw, (str, bytes, bytearray)):
            return raw
        try:
            return json.loads(raw)
        except ValueError:
            return raw

    def find_broken_trials(self, study_name: str, orphans: bool = True, failed: bool = True) -> List[int]:
        """Identify trial IDs for deletion based on brokenness.

        Args:
            study_name: Optuna study name.
            orphans: Include RUNNING trials whose ``slurm_job_id`` job is inactive.
            failed: Include trials in state FAIL.

        Returns:
            De-duplicated Optuna ``trial_id`` values (unordered). On a DB error
            the error is logged and whatever was collected so far is returned.
        """
        from sqlalchemy import text

        broken = set()
        try:
            with self.engine.connect() as conn:
                row = conn.execute(
                    text("SELECT study_id FROM studies WHERE study_name = :name"),
                    {"name": study_name},
                ).fetchone()
                if row is None:
                    return []
                study_id = row[0]
                if failed:
                    res = conn.execute(
                        text("SELECT trial_id FROM trials WHERE study_id = :sid AND state = 'FAIL'"),
                        {"sid": study_id},
                    )
                    broken.update(r[0] for r in res.fetchall())
                if orphans:
                    broken.update(self._find_orphans(conn, study_id))
        except Exception as e:
            logger.error(f"Failed to find broken trials: {e}")
        return list(broken)

    def _find_orphans(self, conn: Any, study_id: int) -> List[int]:
        """Return RUNNING trial ids of ``study_id`` whose SLURM job is not active."""
        from sqlalchemy import bindparam, text

        res = conn.execute(
            text("SELECT trial_id FROM trials WHERE study_id = :sid AND state = 'RUNNING'"),
            {"sid": study_id},
        )
        running_ids = [r[0] for r in res.fetchall()]
        if not running_ids:
            return []

        slurm = SlurmTracker()
        # Optuna stores attribute values JSON-encoded in `value_json`. `key` is
        # a reserved word in some dialects, so quote it the dialect's way.
        key_col = conn.dialect.identifier_preparer.quote_identifier("key")
        stmt = text(
            f"SELECT trial_id, value_json FROM trial_user_attributes "
            f"WHERE trial_id IN :ids AND {key_col} = 'slurm_job_id'"
        ).bindparams(bindparam("ids", expanding=True))

        orphaned = []
        for tid, raw in conn.execute(stmt, {"ids": running_ids}).fetchall():
            job_id = self._decode_attr(raw)
            if job_id and not slurm.is_job_active(str(job_id)):
                orphaned.append(tid)
        return orphaned

    def prune(self, study_name: str, trial_ids: List[int], dry_run: bool = True) -> Dict[str, int]:
        """Obliterate specific trials from Optuna, MLflow, and S3.

        Args:
            study_name: Optuna study name, also used as the MLflow experiment name.
            trial_ids: Optuna ``trial_id`` values (not trial numbers).
            dry_run: If True, only report what would be removed.

        Returns:
            Counts keyed "optuna", "mlflow" and "s3". For a dry run these are
            the would-be counts; otherwise they count successful removals.
        """
        stats = {"optuna": 0, "mlflow": 0, "s3": 0}
        if not trial_ids:
            return stats
        ids = list(trial_ids)

        exp_id, run_uuids = self._map_trials_to_runs(study_name, ids)

        if dry_run:
            logger.info(f"[DRY RUN] Would prune {len(ids)} trials from {study_name}")
            return {"optuna": len(ids), "mlflow": len(run_uuids), "s3": len(run_uuids)}

        if self.s3 is not None and run_uuids:
            stats["s3"] = self._purge_s3(exp_id, run_uuids)
        if run_uuids:
            stats["mlflow"] = self._delete_mlflow_runs(run_uuids)
        stats["optuna"] = self._purge_optuna(ids)
        return stats

    def _map_trials_to_runs(self, study_name: str, trial_ids: List[int]) -> "tuple[Optional[str], List[str]]":
        """Return ``(experiment_id, run_ids)`` of MLflow runs named after the trials' numbers."""
        import mlflow
        from sqlalchemy import bindparam, text

        try:
            mlflow.set_tracking_uri(self.mlflow_uri)
            exp = mlflow.get_experiment_by_name(study_name)
            if not exp:
                return None, []
            stmt = text("SELECT number FROM trials WHERE trial_id IN :ids").bindparams(
                bindparam("ids", expanding=True)
            )
            with self.engine.connect() as conn:
                numbers = {str(r[0]) for r in conn.execute(stmt, {"ids": trial_ids}).fetchall()}
            if not numbers:
                return exp.experiment_id, []

            # MLflow's filter grammar has no OR, so match run names client-side.
            client = mlflow.tracking.MlflowClient(tracking_uri=self.mlflow_uri)
            run_uuids: List[str] = []
            page_token = None
            while True:
                page = client.search_runs([exp.experiment_id], page_token=page_token)
                run_uuids.extend(
                    run.info.run_id for run in page if run.data.tags.get("mlflow.runName") in numbers
                )
                page_token = page.token
                if not page_token:
                    break
            return exp.experiment_id, run_uuids
        except Exception as e:
            logger.error(f"Failed to sync MLflow runs: {e}")
            return None, []

    def _purge_s3(self, experiment_id: Optional[str], run_uuids: List[str]) -> int:
        """Delete each run's artifacts; return how many runs were purged without error."""
        purged = 0
        for rid in run_uuids:
            prefix = f"{experiment_id}/{rid}/"
            try:
                paginator = self.s3.get_paginator("list_objects_v2")
                for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
                    keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
                    if keys:
                        self.s3.delete_objects(Bucket=self.bucket, Delete={"Objects": keys})
                purged += 1
            except Exception as exc:
                logger.warning(f"S3 purge of '{prefix}' failed: {exc}")
        return purged

    def _delete_mlflow_runs(self, run_uuids: List[str]) -> int:
        """Delete runs via the MLflow API; return the number deleted.

        ``MlflowClient.delete_run`` marks runs deleted; rows remain in the
        tracking DB until MLflow garbage collection runs.
        """
        import mlflow

        client = mlflow.tracking.MlflowClient(tracking_uri=self.mlflow_uri)
        deleted = 0
        for rid in run_uuids:
            try:
                client.delete_run(rid)
                deleted += 1
            except Exception as exc:
                logger.warning(f"Could not delete MLflow run {rid}: {exc}")
        return deleted

    def _purge_optuna(self, trial_ids: List[int]) -> int:
        """Delete the trials and their child rows in one transaction; return the count (0 on failure)."""
        from sqlalchemy import bindparam, text

        try:
            with self.engine.begin() as conn:
                for table in (*self._TRIAL_CHILD_TABLES, "trials"):
                    # expanding=True renders the list portably as IN (...).
                    stmt = text(f"DELETE FROM {table} WHERE trial_id IN :ids").bindparams(
                        bindparam("ids", expanding=True)
                    )
                    conn.execute(stmt, {"ids": trial_ids})
        except Exception as e:
            logger.error(f"Optuna SQL purge failed: {e}")
            return 0
        return len(trial_ids)