import socket
import logging
from typing import Optional, Union
from urllib.parse import urlparse
# Optional imports to avoid hard dependencies if not needed
try:
import sqlalchemy
from sqlalchemy import create_engine, text
HAS_SQLALCHEMY = True
except ImportError:
HAS_SQLALCHEMY = False
try:
import requests
HAS_REQUESTS = True
except ImportError:
HAS_REQUESTS = False
from ..schemas.db import DatabaseConfig
logger = logging.getLogger(__name__)
[docs]
def build_db_uri(cfg: Union[DatabaseConfig, dict]) -> str:
    """
    Constructs a database URI from a configuration object or dictionary.

    Parameters
    ----------
    cfg : Union[DatabaseConfig, dict]
        The configuration object or dictionary containing database connection details.
        Expected keys/attributes: driver, user, password, host, port, db_name.
        Fields that are absent are treated as ``None``.

    Returns
    -------
    str
        A formatted SQLAlchemy or HTTP URI string. For SQL drivers the
        ``user`` and ``password`` parts are percent-encoded so that reserved
        characters (``@``, ``/``, ``:``, ...) cannot corrupt the URI.

    Raises
    ------
    ValueError
        If the database driver is not specified in the config.
    """
    # Imported locally so the function does not depend on edits to the
    # module-level import block.
    from urllib.parse import quote

    def get(key, default=None):
        # Attribute access first (config objects), then mapping-style ``get``.
        # An object that has neither simply yields the default instead of
        # raising AttributeError on a missing ``get`` method.
        if hasattr(cfg, key):
            return getattr(cfg, key, default)
        getter = getattr(cfg, 'get', None)
        return getter(key, default) if callable(getter) else default

    driver = get('driver')
    if not driver:
        raise ValueError("Database driver must be specified")

    host = get('host')
    port = get('port')

    # HTTP/HTTPS (e.g. MLflow): credentials and database name are not used.
    if driver in ('http', 'https'):
        uri = f"{driver}://{host}"
        if port:
            uri += f":{port}"
        return uri

    db_name = get('db_name')

    # SQLite: file path, or an in-memory database when no name is given.
    if driver == 'sqlite':
        path = db_name if db_name else ':memory:'
        return f"sqlite:///{path}"

    # Standard SQL (MySQL, Postgres, etc.)
    user = get('user')
    password = get('password')

    credentials = ""
    if user:
        # ``safe=""`` also escapes "/" and encodes spaces as %20 (not "+"),
        # which decodes correctly with plain ``unquote``.
        credentials = quote(str(user), safe="")
        if password:
            credentials += ":" + quote(str(password), safe="")
        credentials += "@"

    uri = f"{driver}://{credentials}{host}"
    if port:
        uri += f":{port}"
    if db_name:
        uri += f"/{db_name}"
    return uri
[docs]
def check_db_connection(cfg: Union[DatabaseConfig, dict], timeout: int = 5) -> bool:
    """
    Verifies connectivity to the database tracker (MLflow or Optuna).

    HTTP/HTTPS drivers are probed with a GET request, other drivers with a
    ``SELECT 1`` through SQLAlchemy. When SQLAlchemy is not installed the
    check falls back to a raw TCP connection attempt.

    Parameters
    ----------
    cfg : Union[DatabaseConfig, dict]
        The configuration object or dictionary containing connection details.
    timeout : int, optional
        The connection timeout in seconds, by default 5.

    Returns
    -------
    bool
        True if the connection was successfully established, False otherwise.
        Also True when the driver is HTTP/HTTPS but ``requests`` is not
        installed, because the check is skipped in that case.

    Raises
    ------
    ValueError
        Propagated from ``build_db_uri`` when no driver is configured.
    """
    def _field(key):
        # Look the field up lazily: ``getattr(cfg, key, cfg.get(key))`` would
        # evaluate ``cfg.get`` even when the attribute exists and crash for
        # config objects that have no ``get`` method.
        if hasattr(cfg, key):
            return getattr(cfg, key)
        getter = getattr(cfg, 'get', None)
        return getter(key) if callable(getter) else None

    uri = build_db_uri(cfg)
    driver = _field('driver')
    logger.info(f"Checking connection to {driver} at {_field('host')}...")

    # 1. HTTP Check (MLflow)
    if driver in ('http', 'https'):
        if not HAS_REQUESTS:
            logger.warning("`requests` not installed. Skipping HTTP connection check.")
            return True
        try:
            response = requests.get(uri, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            logger.error(f"HTTP connection failed: {e}")
            return False
        logger.info("HTTP connection successful.")
        return True

    # 2. SQL Check (Optuna)
    if not HAS_SQLALCHEMY:
        # Fallback: Simple TCP Ping
        return _check_tcp_port(cfg, timeout)

    # Only the MySQL and PostgreSQL branches pass a connect timeout; other
    # drivers get no extra connect arguments.
    if 'mysql' in driver or 'postgresql' in driver:
        connect_args = {'connect_timeout': timeout}
    else:
        connect_args = {}

    engine = None
    try:
        engine = create_engine(uri, connect_args=connect_args)
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        logger.info("SQL connection successful.")
        return True
    except Exception as e:
        # Deliberately broad: any driver/dialect failure means "not reachable"
        # for the purpose of this health check.
        logger.error(f"SQL connection failed: {e}")
        return False
    finally:
        # Release the connection pool created just for this probe.
        if engine is not None:
            engine.dispose()
def _check_tcp_port(cfg: Union[DatabaseConfig, dict], timeout: int) -> bool:
    """
    Performs a raw TCP ping to the specified host and port.

    Parameters
    ----------
    cfg : Union[DatabaseConfig, dict]
        The configuration containing host and port.
    timeout : int
        The timeout in seconds.

    Returns
    -------
    bool
        True if the port is reachable, or if no port is configured (nothing
        to probe). False if the connection fails or the port value is not a
        valid port number.
    """
    def _field(key):
        # Look the field up lazily: ``getattr(cfg, key, cfg.get(key))`` would
        # evaluate ``cfg.get`` even when the attribute exists and crash for
        # config objects that have no ``get`` method.
        if hasattr(cfg, key):
            return getattr(cfg, key)
        getter = getattr(cfg, 'get', None)
        return getter(key) if callable(getter) else None

    host = _field('host')
    port = _field('port')

    # Without a port there is nothing to connect to; treat as reachable.
    if not port:
        return True

    try:
        with socket.create_connection((host, int(port)), timeout=timeout):
            logger.info(f"TCP connection to {host}:{port} successful.")
            return True
    except (OSError, ValueError, TypeError, OverflowError) as e:
        # OSError covers timeouts and refused connections; the others cover a
        # non-numeric or out-of-range port value.
        logger.error(f"TCP connection to {host}:{port} failed: {e}")
        return False