Source code for gunz_ml.utils.connectivity

import socket
import logging
from typing import Optional, Union
from urllib.parse import urlparse

# Optional imports to avoid hard dependencies if not needed
try:
    import sqlalchemy
    from sqlalchemy import create_engine, text
    HAS_SQLALCHEMY = True
except ImportError:
    HAS_SQLALCHEMY = False

try:
    import requests
    HAS_REQUESTS = True
except ImportError:
    HAS_REQUESTS = False

from ..schemas.db import DatabaseConfig

logger = logging.getLogger(__name__)

def build_db_uri(cfg: Union[DatabaseConfig, dict]) -> str:
    """
    Constructs a database URI from a configuration object or dictionary.

    Parameters
    ----------
    cfg : Union[DatabaseConfig, dict]
        The configuration object or dictionary containing database connection
        details. Expected keys/attributes: driver, user, password, host, port,
        db_name. Attribute access is tried first, then ``.get``; missing
        values are treated as ``None``.

    Returns
    -------
    str
        A formatted SQLAlchemy or HTTP URI string:

        - ``http``/``https``: ``driver://host[:port]``
        - ``sqlite`` / ``sqlite+<dbapi>``: ``driver:///db_name`` (or
          ``driver:///:memory:`` when no db_name is given)
        - anything else: ``driver://[user[:password]@]host[:port][/db_name]``
          with user and password percent-encoded, so reserved characters
          (``@``, ``:``, ``/``, ...) cannot corrupt the URI. Credentials are
          expected in plain (not pre-encoded) form.

    Raises
    ------
    ValueError
        If the database driver is not specified in the config.
    """
    from urllib.parse import quote

    def get(key, default=None):
        # Prefer attribute access (config objects), fall back to ``.get``
        # (dicts / mappings); anything else yields ``default`` instead of
        # raising AttributeError.
        if hasattr(cfg, key):
            return getattr(cfg, key)
        getter = getattr(cfg, "get", None)
        return getter(key, default) if callable(getter) else default

    driver = get("driver")
    user = get("user")
    password = get("password")
    host = get("host")
    port = get("port")
    db_name = get("db_name")

    if not driver:
        raise ValueError("Database driver must be specified")

    # Render a missing host as empty rather than the literal text "None".
    host_part = "" if host is None else f"{host}"
    port_part = f":{port}" if port else ""

    # HTTP/HTTPS (e.g. MLflow)
    if driver in ("http", "https"):
        return f"{driver}://{host_part}{port_part}"

    # SQLite: file based, no host or credentials. ``sqlite+<dbapi>`` variants
    # keep their explicit dialect+driver prefix.
    if driver == "sqlite" or str(driver).startswith("sqlite+"):
        path = db_name if db_name else ":memory:"
        return f"{driver}:///{path}"

    # Standard SQL (MySQL, Postgres, etc.).
    # Credentials are percent-encoded with no "safe" characters; SQLAlchemy
    # URL-decodes them when parsing, so the original values round-trip.
    auth = ""
    if user:
        auth = quote(str(user), safe="")
        if password:
            auth += ":" + quote(str(password), safe="")
        auth += "@"
    db_part = f"/{db_name}" if db_name else ""
    return f"{driver}://{auth}{host_part}{port_part}{db_part}"
def check_db_connection(cfg: Union[DatabaseConfig, dict], timeout: int = 5) -> bool:
    """
    Verifies connectivity to the database tracker (MLflow or Optuna).

    HTTP(S) drivers are probed with a GET request (requires ``requests``;
    skipped and reported as reachable if it is not installed). Other drivers
    are probed with ``SELECT 1`` through SQLAlchemy when available, otherwise
    with a raw TCP connection to host/port.

    Parameters
    ----------
    cfg : Union[DatabaseConfig, dict]
        The configuration object or dictionary containing connection details.
    timeout : int, optional
        The connection timeout in seconds, by default 5.

    Returns
    -------
    bool
        True if the connection was successfully established, False otherwise.

    Raises
    ------
    ValueError
        Propagated from ``build_db_uri`` if no driver is configured.
    """

    def get(key):
        # A pattern like ``getattr(cfg, key, cfg.get(key))`` evaluates
        # ``cfg.get`` eagerly and crashes on config objects without ``.get``.
        # Resolve lazily instead: attribute first, then ``.get``.
        if hasattr(cfg, key):
            return getattr(cfg, key)
        getter = getattr(cfg, "get", None)
        return getter(key) if callable(getter) else None

    uri = build_db_uri(cfg)
    driver = get("driver")
    logger.info(f"Checking connection to {driver} at {get('host')}...")

    # 1. HTTP Check (MLflow)
    if driver in ("http", "https"):
        if not HAS_REQUESTS:
            logger.warning("`requests` not installed. Skipping HTTP connection check.")
            return True
        try:
            response = requests.get(uri, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            logger.error(f"HTTP connection failed: {e}")
            return False
        logger.info("HTTP connection successful.")
        return True

    # 2. SQL Check (Optuna)
    if not HAS_SQLALCHEMY:
        # Fallback: Simple TCP Ping
        return _check_tcp_port(cfg, timeout)

    # Only MySQL / PostgreSQL drivers are given ``connect_timeout``; other
    # backends (e.g. SQLite) receive no extra connect arguments.
    driver_name = str(driver)
    if "mysql" in driver_name or "postgresql" in driver_name:
        connect_args = {"connect_timeout": timeout}
    else:
        connect_args = {}

    engine = None
    try:
        engine = create_engine(uri, connect_args=connect_args)
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as e:
        # Deliberately broad: besides SQLAlchemy errors, engine creation can
        # fail with e.g. ImportError when the DBAPI module is not installed.
        logger.error(f"SQL connection failed: {e}")
        return False
    finally:
        if engine is not None:
            # Release the engine's connection pool so repeated health checks
            # do not leave pooled connections open.
            engine.dispose()
    logger.info("SQL connection successful.")
    return True
def _check_tcp_port(cfg: Union[DatabaseConfig, dict], timeout: int) -> bool:
    """
    Performs a raw TCP ping to the specified host and port.

    Parameters
    ----------
    cfg : Union[DatabaseConfig, dict]
        The configuration containing host and port.
    timeout : int
        The timeout in seconds.

    Returns
    -------
    bool
        True if the port is reachable, or if no port is configured (nothing
        to probe, treated as best-effort success). False if the connection
        fails or the port is not a valid integer.
    """

    def get(key):
        # Resolve lazily (attribute first, then ``.get``) so config objects
        # without a ``.get`` method do not raise AttributeError.
        if hasattr(cfg, key):
            return getattr(cfg, key)
        getter = getattr(cfg, "get", None)
        return getter(key) if callable(getter) else None

    host = get("host")
    port = get("port")
    if not port:
        return True

    try:
        with socket.create_connection((host, int(port)), timeout=timeout):
            pass
    except (OSError, ValueError) as e:
        # OSError covers timeouts, refused connections and DNS failures;
        # ValueError covers a non-numeric port passed to int().
        logger.error(f"TCP connection to {host}:{port} failed: {e}")
        return False
    logger.info(f"TCP connection to {host}:{port} successful.")
    return True