import typing as t
import lightning
import mlflow
from ..loggers.modmlflow import ModMLFlowLogger
[docs]
def init_logger(
    logger_cfg: t.Optional[t.Dict[str, t.Optional[t.Dict[str, t.Any]]]],
    name: t.Optional[str] = None,
    tags: t.Optional[t.Union[t.Dict[str, t.Any], t.List[str]]] = None,
) -> t.Optional[t.Any]:
    """
    Initialize a logger based on the provided configuration.

    Parameters
    ----------
    logger_cfg : dict or None
        Mapping from a logger class name to the keyword arguments used to
        construct it, e.g. ``{"CSVLogger": {"save_dir": "logs"}}``. A value of
        ``None`` for an entry is treated as "no extra arguments". The caller's
        dict is never modified.
    name : str, optional
        Experiment name. Used as ``name`` for CSVLogger and TensorBoardLogger,
        and as ``experiment_name`` for MLFlowLogger and ModMLFlowLogger.
        A non-None value overrides the same key in that logger's config. When
        None, the config's value is used if present. Otherwise CSVLogger and
        the MLflow loggers receive ``None`` explicitly, while TensorBoardLogger
        receives no ``name`` argument at all.
        Defaults to None.
    tags : dict or list of str, optional
        Tags forwarded to MLFlowLogger and ModMLFlowLogger only (same override
        rule as ``name`` with respect to a ``tags`` key in the config).
        Lightning's MLFlowLogger documents ``tags`` as a dict of tag name to
        value. Defaults to None.

    Returns
    -------
    logger : Any or None
        The constructed logger, or None if ``logger_cfg`` is None or empty.
        If ``logger_cfg`` has several entries, each one is constructed in
        order and only the last one is returned.

    Raises
    ------
    ValueError
        If a logger name is not one of the supported loggers.
    NotImplementedError
        If ``WandbLogger`` is requested.

    Notes
    -----
    Supported loggers: DummyLogger, CSVLogger, MLFlowLogger, ModMLFlowLogger
    and TensorBoardLogger. WandbLogger is recognised but not implemented yet.
    """
    if not logger_cfg:
        return None

    logger = None
    for logger_name, logger_kwargs in logger_cfg.items():
        # Work on a copy so the caller's config is never mutated, and accept an
        # empty/None entry (e.g. a bare ``DummyLogger:`` key in a YAML file).
        kwargs = dict(logger_kwargs or {})

        if logger_name == "DummyLogger":
            from lightning.pytorch.loggers.logger import DummyLogger
            logger = DummyLogger()
        elif logger_name == "CSVLogger":
            from lightning.pytorch.loggers import CSVLogger
            _fill_kwarg(kwargs, "name", name)
            logger = CSVLogger(**kwargs)
        # Separate branches because loggers name the experiment differently:
        # "experiment_name" for MLflow, "name" for CSV/TensorBoard.
        elif logger_name == "MLFlowLogger":
            # Imported lazily so lightning's logger module loads only when used.
            # NOTE(review): ``mlflow`` is still imported at module level, so it
            # remains a hard requirement of this module.
            from lightning.pytorch.loggers import MLFlowLogger
            _set_mlflow_tracking_uri(kwargs)
            _fill_kwarg(kwargs, "experiment_name", name)
            _fill_kwarg(kwargs, "tags", tags)
            logger = MLFlowLogger(**kwargs)
        elif logger_name == "ModMLFlowLogger":
            _set_mlflow_tracking_uri(kwargs)
            _fill_kwarg(kwargs, "experiment_name", name)
            _fill_kwarg(kwargs, "tags", tags)
            logger = ModMLFlowLogger(**kwargs)
        elif logger_name == "TensorBoardLogger":
            from lightning.pytorch.loggers import TensorBoardLogger
            # Only override when given; omitting ``name`` keeps
            # TensorBoardLogger's own default instead of passing None.
            if name is not None:
                kwargs["name"] = name
            logger = TensorBoardLogger(**kwargs)
        elif logger_name == "WandbLogger":
            raise NotImplementedError("Not yet completely implemented!")
        else:
            raise ValueError(f"Unsupported logger:{logger_name}")
    return logger


def _fill_kwarg(kwargs: t.Dict[str, t.Any], key: str, value: t.Any) -> None:
    """Set ``kwargs[key] = value`` unless that would replace a config value with None."""
    # Avoids "got multiple values for keyword argument" when the config already
    # contains ``key``: an explicit (non-None) argument wins, otherwise the
    # config value is kept.
    if value is not None or key not in kwargs:
        kwargs[key] = value


def _set_mlflow_tracking_uri(kwargs: t.Dict[str, t.Any]) -> None:
    """Point the global mlflow client at ``kwargs["tracking_uri"]`` if one is configured."""
    # Workaround based on https://github.com/mlflow/mlflow/issues/5852: set the
    # global tracking URI explicitly, in addition to passing ``tracking_uri``
    # through to the logger constructor (it stays in ``kwargs``).
    tracking_uri = kwargs.get("tracking_uri")
    if tracking_uri is not None:
        mlflow.set_tracking_uri(tracking_uri)