Source code for gunz_ml.management.artifacts

# -*- coding: utf-8 -*-
"""
Juno Artifact SDK for S3/MinIO.

Handles discovery, retrieval, and smart extraction of artifacts from 
Juno's artifact storage.
"""

#? Metadata
__author__ = "Gemini CLI"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import os
import pathlib
import tempfile
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import mlflow
import numpy as np
from loguru import logger

# =============================================================================
# CLASSES
# =============================================================================

class ArtifactStore:
    """Manager for Juno S3 artifacts (MinIO backend).

    Wraps an ``mlflow.tracking.MlflowClient`` bound to the Juno tracking
    server and sets the environment variables MLflow's S3 artifact
    repository uses, so artifact downloads go to the MinIO endpoint.
    """

    # Artifact file names searched by ``fetch_points``, highest priority first.
    _POINT_ARTIFACTS: t.Tuple[str, ...] = ("points.npz", "trajectory.npz")

    def __init__(
        self,
        mlflow_uri: str = "http://juno.tnt.uni-hannover.de:5200",
        s3_endpoint: str = "http://juno.tnt.uni-hannover.de:9000",
        access_key: str = "minioadmin",
        secret_key: str = "minioadmin",
    ):
        """Configure MLflow tracking and the MinIO/S3 credentials.

        Args:
            mlflow_uri: MLflow tracking server URI.
            s3_endpoint: S3-compatible (MinIO) endpoint holding the artifacts.
            access_key: S3 access key.
            secret_key: S3 secret key.

        Note:
            This has process-wide side effects. It overwrites the
            ``MLFLOW_S3_ENDPOINT_URL``, ``AWS_ACCESS_KEY_ID`` and
            ``AWS_SECRET_ACCESS_KEY`` environment variables. It also sets the
            global MLflow tracking URI.
        """
        self.mlflow_uri = mlflow_uri
        self.s3_endpoint = s3_endpoint

        #? Configure environment for MinIO. MLflow's S3 artifact repository
        #? (boto3 under the hood) reads the endpoint and credentials from here.
        os.environ["MLFLOW_S3_ENDPOINT_URL"] = s3_endpoint
        os.environ["AWS_ACCESS_KEY_ID"] = access_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        # NOTE(review): the credential defaults are MinIO's factory defaults;
        # consider sourcing them from configuration instead of hard-coding.

        mlflow.set_tracking_uri(self.mlflow_uri)
        # Bind the client to this URI explicitly so it keeps pointing at Juno
        # even if other code later changes the global tracking URI.
        self.client = mlflow.tracking.MlflowClient(tracking_uri=self.mlflow_uri)

    def list_recursive(self, run_id: str, path: str = "") -> t.List[str]:
        """Recursively list all artifact file paths for a run.

        Args:
            run_id: MLflow run ID.
            path: Artifact sub-directory to start from ("" is the artifact root).

        Returns:
            Paths of all files (directories are descended into, not returned)
            below *path*, in listing order. This is best effort: a listing
            failure is logged, and whatever was collected up to that point is
            returned.
        """
        paths: t.List[str] = []
        try:
            for item in self.client.list_artifacts(run_id, path=path):
                if item.is_dir:
                    paths.extend(self.list_recursive(run_id, item.path))
                else:
                    paths.append(item.path)
        except Exception as e:
            # Deliberate best effort: callers treat "nothing listed" as "not found".
            logger.error(f"Failed to list artifacts for run {run_id}: {e}")
        return paths

    def download_file(
        self,
        run_id: str,
        artifact_path: str,
        local_dir: t.Union[str, pathlib.Path],
    ) -> pathlib.Path:
        """Download a single artifact of a run into *local_dir*.

        Args:
            run_id: MLflow run ID.
            artifact_path: Artifact path relative to the run's artifact root.
            local_dir: Existing local directory to download into.

        Returns:
            Local path of the downloaded artifact.
        """
        local_path = self.client.download_artifacts(run_id, artifact_path, str(local_dir))
        return pathlib.Path(local_path)

    def fetch_points(self, run_id: str) -> t.Optional[np.ndarray]:
        """Smart-fetch point coordinates from a run.

        Searches the run's artifacts for a file named exactly ``points.npz``
        (preferred) or ``trajectory.npz``. Matching is on the file name, so
        e.g. ``checkpoints.npz`` is not mistaken for a points artifact. If
        several files share the winning name, the first in listing order is
        used.

        Args:
            run_id: MLflow run ID.

        Returns:
            The ``"points"`` array of ``points.npz``. For ``trajectory.npz``,
            the ``"trajectory"`` array; if it is 3-D (by convention epochs
            first), only its last entry along axis 0 is returned. Returns
            ``None`` when neither artifact exists.

        Raises:
            KeyError: If the selected archive lacks the expected array key.
        """
        all_paths = self.list_recursive(run_id)

        #? Priority: points.npz > trajectory.npz
        target = None
        for name in self._POINT_ARTIFACTS:
            target = next(
                (p for p in all_paths if pathlib.PurePosixPath(p).name == name),
                None,
            )
            if target is not None:
                break

        if target is None:
            logger.warning(f"No point artifacts found in run {run_id}")
            return None

        with tempfile.TemporaryDirectory() as tmp:
            local = self.download_file(run_id, target, tmp)
            # NpzFile keeps the archive open; close it before the temp dir is
            # removed (an open handle leaks and breaks cleanup on Windows).
            # Arrays read via data[...] are fully loaded, so they survive close.
            with np.load(local) as data:
                if target.endswith("trajectory.npz"):
                    # Usually (epochs, N, 3): take the last epoch.
                    pts = data["trajectory"]
                    return pts[-1] if pts.ndim == 3 else pts
                return data["points"]