# Source code for gunz_ml.loggers

import typing as t
import lightning
import mlflow
from ..loggers.modmlflow import ModMLFlowLogger

[docs] def init_logger( logger_cfg: t.Dict[str, t.Dict[str, t.Any]], name: str = None, tags: t.List[str] = None, ) -> t.Optional[t.Any]: """ Initialize a logger based on the provided configuration. Parameters ---------- logger_cfg : dict Configuration for the logger. It should contain the logger name as key and its corresponding arguments as value. name : str, optional Name of the experiment. Defaults to None. tags : list of str, optional Tags for the logger. Defaults to None. Returns ------- logger : Any or None An instance of the logger class or None if the logger configuration is None. Notes ----- This function supports multiple types of loggers, including DummyLogger, CSVLogger, MLFlowLogger, ModMLFlowLogger, TensorBoardLogger, and WandbLogger. """ if logger_cfg is None: return None else: for logger_name, logger_kwargs in logger_cfg.items(): if logger_name == "DummyLogger": from lightning.pytorch.loggers.logger import DummyLogger logger = DummyLogger() if logger_name == 'CSVLogger': from lightning.pytorch.loggers import CSVLogger logger = CSVLogger( name=name, **logger_kwargs ) #? Make as if-cases because the arguments are confusing #? "run_name" for mlflow and "exp_name" for tensorflow elif logger_name == "MLFlowLogger": #? Putting here to lossen the requirements of installing MLFlow from lightning.pytorch.loggers import MLFlowLogger #? Based on https://github.com/mlflow/mlflow/issues/5852 tracking_uri = logger_kwargs['tracking_uri'] mlflow.set_tracking_uri(tracking_uri) # Check if 'run_name' is in kwargs, if not use name as run_name fallback # but 'experiment_name' is usually what 'name' refers to in this context logger = MLFlowLogger( experiment_name=name, tags=tags, **logger_kwargs ) elif logger_name == "ModMLFlowLogger": #? 
Based on https://github.com/mlflow/mlflow/issues/5852 tracking_uri = logger_kwargs['tracking_uri'] mlflow.set_tracking_uri(tracking_uri) logger = ModMLFlowLogger( experiment_name=name, tags=tags, **logger_kwargs ) elif logger_name == 'TensorBoardLogger': from lightning.pytorch.loggers import TensorBoardLogger #? Handle name if already in kwargs to avoid duplicate argument error if 'name' in logger_kwargs and name is not None: # Priority given to the 'name' argument passed to init_logger if it's not None logger_kwargs['name'] = name logger = TensorBoardLogger( **logger_kwargs ) elif logger_name == 'WandbLogger': #? Putting here to lossen the requirements of installing MLFlow from lightning.pytorch.loggers import WandbLogger raise NotImplementedError("Not yet completely implemented!") else: raise ValueError(f"Unsupported logger:{logger_name}") return logger