# Source code for gunz_ml.loggers.modmlflow

import os
from argparse import Namespace
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Union, Literal

from lightning_fabric.utilities.logger import _convert_params, _flatten_dict
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from mlflow.entities import Param

from ..integrations.mlflow import safe_set_experiment


class ModMLFlowLogger(MLFlowLogger):
    """A modified MLflow logger for PyTorch Lightning, aligned with PL 2.6+.

    Differences from the stock ``MLFlowLogger`` implemented in this class:

    * The experiment is set/created through ``safe_set_experiment`` before the
      parent constructor runs, against the same tracking URI the logger uses.
    * ``log_hyperparams`` logs stringified parameter values in batches of at
      most ``_MAX_PARAMS_PER_BATCH`` without truncating them.
    * Convenience helpers to log hyperparameters together with metrics, and to
      upload a single file or a whole directory as run artifacts.
    """

    # Prefix used to build a local file-store URI from ``save_dir`` when no
    # tracking URI is supplied. NOTE(review): intended to mirror the fallback
    # applied by the parent ``MLFlowLogger`` (PL's ``LOCAL_FILE_URI_PREFIX``,
    # i.e. ``file:<save_dir>``) -- confirm against the installed PL version.
    _LOCAL_FILE_URI_PREFIX = "file:"

    # MLflow limits batch logging to 100 items per batch.
    _MAX_PARAMS_PER_BATCH = 100

    def __init__(
        self,
        experiment_name: str = "lightning_logs",
        run_id: Optional[str] = None,
        run_name: Optional[str] = None,
        # NOTE: evaluated once at import time (same default as the parent
        # logger); later changes to the env var are not picked up.
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        log_model: Literal[True, False, "all"] = False,
        prefix: str = "",
        artifact_location: Optional[str] = None,
    ) -> None:
        """Create the logger, making sure the experiment exists first.

        Args:
            experiment_name: Name of the MLflow experiment to set/create.
            run_id: Forwarded to the parent ``MLFlowLogger``.
            run_name: Forwarded to the parent ``MLFlowLogger``.
            tracking_uri: MLflow tracking URI. If falsy, a local file-store URI
                ``file:<save_dir>`` is used for both the experiment setup and
                the parent logger.
            tags: Forwarded to the parent ``MLFlowLogger``.
            save_dir: Local directory used to build the file-store URI when no
                ``tracking_uri`` is given; also forwarded to the parent.
            log_model: Forwarded to the parent ``MLFlowLogger``.
            prefix: Forwarded to the parent ``MLFlowLogger``.
            artifact_location: Artifact location passed both to
                ``safe_set_experiment`` and to the parent logger.
        """
        # Resolve the store URI up front so that safe_set_experiment and the
        # parent logger target the same backend. Previously a missing URI was
        # handed to safe_set_experiment as None, while the logger itself used
        # a file store derived from save_dir.
        if not tracking_uri:
            tracking_uri = f"{self._LOCAL_FILE_URI_PREFIX}{save_dir}"

        # Safely set/create the experiment to avoid race conditions.
        safe_set_experiment(
            experiment_name=experiment_name,
            tracking_uri=tracking_uri,
            artifact_location=artifact_location,
        )
        super().__init__(
            experiment_name=experiment_name,
            run_name=run_name,
            tracking_uri=tracking_uri,
            tags=tags,
            save_dir=save_dir,
            log_model=log_model,
            prefix=prefix,
            artifact_location=artifact_location,
            run_id=run_id,
        )

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Record hyperparameters to the current MLflow run.

        Args:
            params: Hyperparameters as a mapping or ``argparse.Namespace``.
                They are converted and flattened with Lightning's helpers, and
                every key and value is stringified.
        """
        flat_params = _flatten_dict(_convert_params(params))

        # Convert to MLflow Param entities for log_batch. Values are
        # stringified but deliberately not truncated.
        params_list = [Param(key=str(k), value=str(v)) for k, v in flat_params.items()]

        # Send in chunks that respect MLflow's per-batch limit; an empty
        # params mapping results in no calls at all.
        batch_size = self._MAX_PARAMS_PER_BATCH
        for start in range(0, len(params_list), batch_size):
            self.experiment.log_batch(
                run_id=self.run_id,
                params=params_list[start : start + batch_size],
            )

    @rank_zero_only
    def log_hyperparams_metrics(
        self,
        params: Union[Dict[str, Any], Namespace],
        metrics: Mapping[str, float],
        step: Optional[int] = None,
    ) -> None:
        """Log hyperparameters, then metrics, to the current run.

        Args:
            params: Hyperparameters, passed to ``log_hyperparams``.
            metrics: Metric name -> value mapping, passed to ``log_metrics``.
            step: Optional step index forwarded to ``log_metrics``.
        """
        self.log_hyperparams(params)
        self.log_metrics(metrics, step)

    @rank_zero_only
    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
        """Log a single local file as an artifact of the current run.

        Args:
            local_path: Path of the local file to upload.
            artifact_path: Optional directory within the run's artifact root.
        """
        self.experiment.log_artifact(self.run_id, local_path, artifact_path)

    @rank_zero_only
    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
        """Log the contents of a local directory as artifacts of the current run.

        Args:
            local_dir: Path of the local directory to upload.
            artifact_path: Optional directory within the run's artifact root.
        """
        self.experiment.log_artifacts(self.run_id, local_dir, artifact_path)