# -*- coding: utf-8 -*-
"""
Juno Artifact SDK for S3/MinIO.
Handles discovery, retrieval, and smart extraction of artifacts from
Juno's artifact storage.
"""
#? Metadata
__author__ = "Gemini CLI"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import os
import pathlib
import tempfile
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import mlflow
import numpy as np
from loguru import logger
# =============================================================================
# CLASSES
# =============================================================================
# NOTE: a "[docs]" source-link marker from the Sphinx HTML export stood here; it is not code.
class ArtifactStore:
    """Manager for Juno S3 artifacts (MinIO backend).

    Wraps an ``MlflowClient`` to list, download and load artifacts that
    belong to an MLflow run.
    """

    def __init__(
        self,
        mlflow_uri: str = "http://juno.tnt.uni-hannover.de:5200",
        s3_endpoint: str = "http://juno.tnt.uni-hannover.de:9000",
        access_key: str = "minioadmin",
        secret_key: str = "minioadmin"
    ):
        """Configure the MLflow / S3 environment and create the client.

        Args:
            mlflow_uri: URI of the MLflow tracking server.
            s3_endpoint: URL of the S3-compatible (MinIO) endpoint.
            access_key: S3 access key id.
            secret_key: S3 secret access key.

        Note:
            Construction has process-wide side effects: it overwrites the
            ``MLFLOW_S3_ENDPOINT_URL``, ``AWS_ACCESS_KEY_ID`` and
            ``AWS_SECRET_ACCESS_KEY`` environment variables and changes the
            MLflow tracking URI via ``mlflow.set_tracking_uri``.
        """
        self.mlflow_uri = mlflow_uri
        self.s3_endpoint = s3_endpoint
        #? Configure environment for MinIO
        # NOTE(review): the default credentials are hard-coded in the
        # signature; pass real credentials explicitly outside of test setups.
        os.environ["MLFLOW_S3_ENDPOINT_URL"] = s3_endpoint
        os.environ["AWS_ACCESS_KEY_ID"] = access_key
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = mlflow.tracking.MlflowClient()

    def list_recursive(self, run_id: str, path: str = "") -> t.List[str]:
        """Recursively list all artifact file paths for a run.

        Args:
            run_id: Identifier of the MLflow run.
            path: Artifact sub-directory to start from ("" lists from the
                run's artifact root).

        Returns:
            File paths in listing order; directories are descended into and
            are not themselves included. Listing is best-effort: if an error
            occurs it is logged and the paths collected so far at this level
            are returned.
        """
        paths = []
        try:
            for item in self.client.list_artifacts(run_id, path=path):
                if item.is_dir:
                    paths.extend(self.list_recursive(run_id, item.path))
                else:
                    paths.append(item.path)
        except Exception as e:
            # Deliberately broad: a failed listing must not abort the caller.
            logger.error(f"Failed to list artifacts for run {run_id}: {e}")
        return paths

    def download_file(self, run_id: str, artifact_path: str, local_dir: t.Union[str, pathlib.Path]) -> pathlib.Path:
        """Download a specific artifact file.

        Args:
            run_id: Identifier of the MLflow run.
            artifact_path: Path of the artifact relative to the run.
            local_dir: Local destination directory.

        Returns:
            The local path reported by the client, as a ``pathlib.Path``.
        """
        local_path = self.client.download_artifacts(run_id, artifact_path, str(local_dir))
        return pathlib.Path(local_path)

    @staticmethod
    def _select_artifact(paths: t.Sequence[str], filename: str) -> t.Optional[str]:
        """Pick the artifact path that best matches ``filename``.

        An exact basename match wins (so ``checkpoints.npz`` is not mistaken
        for ``points.npz``); otherwise the first path that merely ends with
        ``filename`` is returned. ``None`` means nothing matched.
        """
        suffix_match = None
        for path in paths:
            if pathlib.PurePosixPath(path).name == filename:
                return path
            if suffix_match is None and path.endswith(filename):
                suffix_match = path
        return suffix_match

    def fetch_points(self, run_id: str) -> t.Optional[np.ndarray]:
        """Smart-fetch point coordinates from a run.

        Searches the run's artifacts for 'points.npz' or 'trajectory.npz'
        (in that priority order), downloads the match into a temporary
        directory and loads it.

        Args:
            run_id: Identifier of the MLflow run.

        Returns:
            The ``"points"`` array of a points archive, or for a trajectory
            archive the last entry along the first axis of ``"trajectory"``
            when that array is 3-D (the array unchanged otherwise). ``None``
            if no matching artifact exists.

        Raises:
            KeyError: If the archive lacks the expected array name.
        """
        all_paths = self.list_recursive(run_id)
        #? Priority: points.npz > trajectory.npz
        target = None
        for filename in ("points.npz", "trajectory.npz"):
            target = self._select_artifact(all_paths, filename)
            if target is not None:
                break
        if target is None:
            logger.warning(f"No point artifacts found in run {run_id}")
            return None
        with tempfile.TemporaryDirectory() as tmp:
            local = self.download_file(run_id, target, tmp)
            # np.load returns a lazy archive that holds the file open; close
            # it before the temporary directory is removed.
            with np.load(local) as data:
                if target.endswith("trajectory.npz"):
                    # Usually (epochs, N, 3), take last epoch
                    pts = data["trajectory"]
                    return pts[-1] if pts.ndim == 3 else pts
                return data["points"]