import os
from argparse import Namespace
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Union, Literal
from lightning_fabric.utilities.logger import _convert_params, _flatten_dict
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from mlflow.entities import Param
from ..integrations.mlflow import safe_set_experiment
class ModMLFlowLogger(MLFlowLogger):
    """
    A modified MLFlow logger for PyTorch Lightning aligned with PL 2.6+.

    Compared with a plain ``MLFlowLogger`` this class:

    * creates/selects the experiment through ``safe_set_experiment`` before the
      base logger is initialised (intended to avoid races when several
      processes create the same experiment);
    * logs hyperparameters in batches of at most 100 ``Param`` entities and
      truncates values only at the length limit the installed MLflow client
      enforces;
    * adds ``log_hyperparams_metrics``, ``log_artifact`` and ``log_artifacts``
      convenience methods that write to this logger's run.
    """

    def __init__(
        self,
        experiment_name: str = "lightning_logs",
        run_id: Optional[str] = None,
        run_name: Optional[str] = None,
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        log_model: Literal[True, False, "all"] = False,
        prefix: str = "",
        artifact_location: Optional[str] = None,
    ):
        """
        Args:
            experiment_name: Name of the MLflow experiment to log into.
            run_id: Existing MLflow run id to resume; forwarded to ``MLFlowLogger``.
            run_name: Name for a newly created run; forwarded to ``MLFlowLogger``.
            tracking_uri: MLflow tracking URI. The default is read from the
                ``MLFLOW_TRACKING_URI`` environment variable once, when this
                module is imported (default values are evaluated at definition
                time), not on each instantiation.
            tags: Run tags; forwarded to ``MLFlowLogger``.
            save_dir: Local directory used as a ``file:`` store when no
                ``tracking_uri`` is given.
            log_model: Checkpoint logging mode; forwarded to ``MLFlowLogger``.
            prefix: Metric-key prefix; forwarded to ``MLFlowLogger``.
            artifact_location: Artifact root used when the experiment is created.
        """
        # When no tracking URI is supplied, Lightning's MLFlowLogger falls back
        # to a local file store at ``file:<save_dir>`` (NOTE(review): mirrors PL's
        # LOCAL_FILE_URI_PREFIX fallback - verify against the installed PL).
        # Hand the same URI to safe_set_experiment so the experiment is created
        # in the store the logger will actually use, instead of passing None.
        experiment_uri = tracking_uri
        if not experiment_uri and save_dir:
            experiment_uri = f"file:{save_dir}"

        # Safely set/create the experiment up front to avoid race conditions.
        safe_set_experiment(
            experiment_name=experiment_name,
            tracking_uri=experiment_uri,
            artifact_location=artifact_location,
        )
        super().__init__(
            experiment_name=experiment_name,
            run_name=run_name,
            tracking_uri=tracking_uri,
            tags=tags,
            save_dir=save_dir,
            log_model=log_model,
            prefix=prefix,
            artifact_location=artifact_location,
            run_id=run_id,
        )

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Record hyperparameters to the experiment.

        ``params`` is normalised with Lightning's ``_convert_params`` and
        flattened with ``_flatten_dict``. Each entry becomes an MLflow ``Param``
        (key and value stringified), and the entries are sent through
        ``log_batch`` in chunks of at most 100.

        String values longer than the maximum parameter length accepted by the
        installed MLflow client are truncated, and a single warning lists the
        affected keys. Without this, one oversized value would make MLflow reject
        the whole batch it belongs to.
        """
        import warnings

        try:
            # Use the limit the installed client validates against, so the
            # truncation always matches what this MLflow version accepts.
            from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH as max_len
        except ImportError:  # internal constant; fall back if it ever moves
            max_len = 6000

        flat_params = _flatten_dict(_convert_params(params))

        params_list = []
        truncated_keys = []
        for key, value in flat_params.items():
            text = str(value)
            if len(text) > max_len:
                truncated_keys.append(str(key))
                text = text[:max_len]
            params_list.append(Param(key=str(key), value=text))

        if truncated_keys:
            warnings.warn(
                f"Truncated hyperparameter values to {max_len} characters "
                f"for MLflow: {', '.join(truncated_keys)}",
                stacklevel=2,
            )

        # MLflow limits batch logging to 100 params per request.
        batch_size = 100
        for start in range(0, len(params_list), batch_size):
            self.experiment.log_batch(
                run_id=self.run_id, params=params_list[start : start + batch_size]
            )

    @rank_zero_only
    def log_hyperparams_metrics(
        self,
        params: Union[Dict[str, Any], Namespace],
        metrics: Mapping[str, float],
        step: Optional[int] = None,
    ) -> None:
        """Log ``params`` via ``log_hyperparams``, then ``metrics`` at ``step`` via ``log_metrics``."""
        self.log_hyperparams(params)
        self.log_metrics(metrics, step)

    @rank_zero_only
    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
        """Upload the single file at ``local_path`` to this logger's run.

        ``artifact_path`` optionally names a sub-directory inside the run's
        artifact root.
        """
        self.experiment.log_artifact(self.run_id, local_path, artifact_path)

    @rank_zero_only
    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
        """Upload the contents of the directory ``local_dir`` to this logger's run.

        ``artifact_path`` optionally names a sub-directory inside the run's
        artifact root.
        """
        self.experiment.log_artifacts(self.run_id, local_dir, artifact_path)