Skip to content

dataenginex.ml

ML training, model registry, drift detection, model serving, scheduling, and metrics.

Public API::

from dataenginex.ml import (
    BaseTrainer, SklearnTrainer, TrainingResult,
    ModelRegistry, ModelArtifact, ModelStage,
    DriftDetector, DriftReport,
    DriftScheduler, DriftMonitorConfig, DriftCheckResult,
    ModelServer, PredictionRequest, PredictionResponse,
    model_prediction_total, model_prediction_latency_seconds,
    model_drift_psi, model_drift_alerts_total,
)

DriftDetector

Detect distribution drift between a reference and current dataset.

PSI thresholds (industry standard): below 0.10 — no significant drift; 0.10–0.25 — moderate drift; above 0.25 — significant drift.

Parameters

psi_threshold: PSI value above which drift is flagged (default 0.20). n_bins: Number of histogram bins for PSI calculation (default 10).

Source code in packages/dataenginex/src/dataenginex/ml/drift.py
class DriftDetector:
    """Detect distribution drift between a reference and current dataset.

    PSI thresholds (industry standard):
        < 0.10  — no significant drift
        0.10-0.25 — moderate drift
        > 0.25  — significant drift

    Parameters
    ----------
    psi_threshold:
        PSI value above which drift is flagged (default 0.20).
    n_bins:
        Number of histogram bins for PSI calculation (default 10).
    """

    def __init__(self, psi_threshold: float = 0.20, n_bins: int = 10) -> None:
        self.psi_threshold = psi_threshold
        self.n_bins = n_bins

    def check_feature(
        self,
        feature_name: str,
        reference: list[float],
        current: list[float],
    ) -> DriftReport:
        """Check drift for a single numeric feature.

        Returns a neutral report (psi=0.0, severity ``"none"``) when either
        distribution is empty, rather than raising.
        """
        if not reference or not current:
            return DriftReport(
                feature_name=feature_name,
                psi=0.0,
                drift_detected=False,
                severity="none",
                details={"error": "Empty distribution(s)"},
            )

        # Local import: stdlib statistics is only needed on this path.
        import statistics

        ref_mean = statistics.mean(reference)
        cur_mean = statistics.mean(current)
        # stdev requires >= 2 samples; report 0.0 for a singleton sample.
        ref_std = statistics.stdev(reference) if len(reference) > 1 else 0.0
        cur_std = statistics.stdev(current) if len(current) > 1 else 0.0

        psi = self._compute_psi(reference, current)
        severity = self._classify_severity(psi)

        return DriftReport(
            feature_name=feature_name,
            psi=psi,
            # Flagging uses the configurable threshold (default 0.20),
            # independent of the fixed 0.10/0.25 severity bands.
            drift_detected=psi > self.psi_threshold,
            severity=severity,
            reference_mean=round(ref_mean, 4),
            current_mean=round(cur_mean, 4),
            reference_std=round(ref_std, 4),
            current_std=round(cur_std, 4),
            details={"n_bins": self.n_bins, "threshold": self.psi_threshold},
        )

    def check_dataset(
        self,
        reference: dict[str, list[float]],
        current: dict[str, list[float]],
    ) -> list[DriftReport]:
        """Check drift across all shared features in two datasets.

        Features present in only one dataset are silently skipped; reports
        are returned in sorted feature-name order.

        Parameters
        ----------
        reference:
            Mapping of ``feature_name → values`` for the reference period.
        current:
            Mapping of ``feature_name → values`` for the current period.
        """
        reports: list[DriftReport] = []
        shared = set(reference.keys()) & set(current.keys())
        for feat in sorted(shared):
            reports.append(self.check_feature(feat, reference[feat], current[feat]))
        return reports

    # -- internal ------------------------------------------------------------

    def _compute_psi(self, reference: list[float], current: list[float]) -> float:
        """Compute PSI using equal-width binning."""
        # Bin edges span the combined range so both samples share bins.
        all_vals = reference + current
        lo = min(all_vals)
        hi = max(all_vals)
        if lo == hi:
            # Every value identical in both samples — no measurable shift.
            return 0.0

        bin_width = (hi - lo) / self.n_bins
        # Epsilon keeps empty bins from producing log(0) or division by zero.
        eps = 1e-6

        ref_counts = [0] * self.n_bins
        cur_counts = [0] * self.n_bins

        for v in reference:
            # min() clamps v == hi into the last bin instead of overflowing.
            idx = min(int((v - lo) / bin_width), self.n_bins - 1)
            ref_counts[idx] += 1
        for v in current:
            idx = min(int((v - lo) / bin_width), self.n_bins - 1)
            cur_counts[idx] += 1

        ref_total = len(reference)
        cur_total = len(current)

        psi = 0.0
        for r, c in zip(ref_counts, cur_counts, strict=True):
            ref_pct = (r / ref_total) + eps
            cur_pct = (c / cur_total) + eps
            psi += (cur_pct - ref_pct) * math.log(cur_pct / ref_pct)

        # The epsilon can push the sum slightly negative; clamp at zero.
        return max(0.0, psi)

    @staticmethod
    def _classify_severity(psi: float) -> str:
        # Boundary behavior: psi == 0.10 is "moderate", psi == 0.25 is "severe".
        if psi < 0.10:
            return "none"
        if psi < 0.25:
            return "moderate"
        return "severe"

check_feature(feature_name, reference, current)

Check drift for a single numeric feature.

Source code in packages/dataenginex/src/dataenginex/ml/drift.py
def check_feature(
    self,
    feature_name: str,
    reference: list[float],
    current: list[float],
) -> DriftReport:
    """Run a drift check on one numeric feature and build its report.

    An empty reference or current sample yields a neutral report
    (psi=0.0, severity ``"none"``) instead of raising.
    """
    if not reference or not current:
        return DriftReport(
            feature_name=feature_name,
            psi=0.0,
            drift_detected=False,
            severity="none",
            details={"error": "Empty distribution(s)"},
        )

    import statistics

    def _std(values: list[float]) -> float:
        # stdev needs at least two samples; singletons report 0.0.
        return statistics.stdev(values) if len(values) > 1 else 0.0

    psi_value = self._compute_psi(reference, current)

    return DriftReport(
        feature_name=feature_name,
        psi=psi_value,
        drift_detected=psi_value > self.psi_threshold,
        severity=self._classify_severity(psi_value),
        reference_mean=round(statistics.mean(reference), 4),
        current_mean=round(statistics.mean(current), 4),
        reference_std=round(_std(reference), 4),
        current_std=round(_std(current), 4),
        details={"n_bins": self.n_bins, "threshold": self.psi_threshold},
    )

check_dataset(reference, current)

Check drift across all shared features in two datasets.

Parameters

reference: Mapping of feature_name → values for the reference period. current: Mapping of feature_name → values for the current period.

Source code in packages/dataenginex/src/dataenginex/ml/drift.py
def check_dataset(
    self,
    reference: dict[str, list[float]],
    current: dict[str, list[float]],
) -> list[DriftReport]:
    """Run per-feature drift checks on every feature present in both datasets.

    Features that appear in only one dataset are skipped; reports come
    back in sorted feature-name order.

    Parameters
    ----------
    reference:
        Mapping of ``feature_name → values`` for the reference period.
    current:
        Mapping of ``feature_name → values`` for the current period.
    """
    common_features = reference.keys() & current.keys()
    return [
        self.check_feature(feat_name, reference[feat_name], current[feat_name])
        for feat_name in sorted(common_features)
    ]

DriftReport dataclass

Outcome of a drift check for a single feature.

Attributes:

Name Type Description
feature_name str

Name of the feature that was checked.

psi float

Population Stability Index value.

drift_detected bool

Whether drift exceeds the configured threshold.

severity str

Drift severity — "none", "moderate", or "severe".

reference_mean float | None

Mean of the reference distribution.

current_mean float | None

Mean of the current distribution.

reference_std float | None

Standard deviation of reference distribution.

current_std float | None

Standard deviation of current distribution.

details dict[str, Any]

Extra context (bins, threshold, etc.).

checked_at datetime

Timestamp of the drift check.

Source code in packages/dataenginex/src/dataenginex/ml/drift.py
@dataclass
class DriftReport:
    """Outcome of a drift check for a single feature.

    Attributes:
        feature_name: Name of the feature that was checked.
        psi: Population Stability Index value.
        drift_detected: Whether drift exceeds the configured threshold.
        severity: Drift severity — ``"none"``, ``"moderate"``, or ``"severe"``.
        reference_mean: Mean of the reference distribution.
        current_mean: Mean of the current distribution.
        reference_std: Standard deviation of reference distribution.
        current_std: Standard deviation of current distribution.
        details: Extra context (bins, threshold, etc.).
        checked_at: Timestamp of the drift check.
    """

    feature_name: str
    psi: float  # Population Stability Index
    drift_detected: bool
    # DriftDetector._classify_severity emits only "none", "moderate", or
    # "severe" — a "minor" level (listed here previously) is never produced.
    severity: str  # "none", "moderate", "severe"
    reference_mean: float | None = None
    current_mean: float | None = None
    reference_std: float | None = None
    current_std: float | None = None
    details: dict[str, Any] = field(default_factory=dict)
    # Timezone-aware UTC timestamp captured at construction time.
    checked_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the drift report to a plain dictionary.

        PSI is rounded to 6 decimal places and ``checked_at`` becomes an
        ISO-8601 string, so the result is JSON-serializable.
        """
        return {
            "feature_name": self.feature_name,
            "psi": round(self.psi, 6),
            "drift_detected": self.drift_detected,
            "severity": self.severity,
            "reference_mean": self.reference_mean,
            "current_mean": self.current_mean,
            "reference_std": self.reference_std,
            "current_std": self.current_std,
            "details": self.details,
            "checked_at": self.checked_at.isoformat(),
        }

to_dict()

Serialize the drift report to a plain dictionary.

Source code in packages/dataenginex/src/dataenginex/ml/drift.py
def to_dict(self) -> dict[str, Any]:
    """Serialize the drift report to a JSON-friendly plain dictionary."""
    payload: dict[str, Any] = {
        "feature_name": self.feature_name,
        "psi": round(self.psi, 6),
        "drift_detected": self.drift_detected,
        "severity": self.severity,
    }
    # Summary statistics pass through unchanged (may be None).
    for stat in ("reference_mean", "current_mean", "reference_std", "current_std"):
        payload[stat] = getattr(self, stat)
    payload["details"] = self.details
    payload["checked_at"] = self.checked_at.isoformat()
    return payload

ModelArtifact dataclass

Registry entry for a model version.

Attributes:

Name Type Description
name str

Model name (e.g. "job_classifier").

version str

Semantic version string.

stage ModelStage

Current lifecycle stage.

artifact_path str

File path to the serialised model.

metrics dict[str, float]

Training/evaluation metrics.

parameters dict[str, Any]

Hyper-parameters used for training.

description str

Free-text description.

created_at datetime

When the artifact was registered.

promoted_at datetime | None

When the artifact was last promoted.

tags list[str]

Arbitrary labels for filtering.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
@dataclass
class ModelArtifact:
    """Registry entry for a model version.

    Attributes:
        name: Model name (e.g. ``"job_classifier"``).
        version: Semantic version string.
        stage: Current lifecycle stage.
        artifact_path: File path to the serialised model.
        metrics: Training/evaluation metrics.
        parameters: Hyper-parameters used for training.
        description: Free-text description.
        created_at: When the artifact was registered.
        promoted_at: When the artifact was last promoted.
        tags: Arbitrary labels for filtering.
    """

    name: str
    version: str
    stage: ModelStage = ModelStage.DEVELOPMENT
    artifact_path: str = ""
    metrics: dict[str, float] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    description: str = ""
    # Timezone-aware UTC timestamp captured at construction time.
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))
    promoted_at: datetime | None = None
    tags: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the model artifact metadata to a plain dictionary.

        The enum and datetime fields are replaced with their string forms
        so the result round-trips through JSON.
        """
        # asdict() converts nested dataclass fields; the non-dataclass
        # fields below are then overwritten with JSON-friendly values.
        d = asdict(self)
        d["stage"] = self.stage.value
        d["created_at"] = self.created_at.isoformat()
        d["promoted_at"] = self.promoted_at.isoformat() if self.promoted_at else None
        return d

to_dict()

Serialize the model artifact metadata to a plain dictionary.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def to_dict(self) -> dict[str, Any]:
    """Serialize the model artifact metadata to a plain dictionary."""
    d = asdict(self)
    d["stage"] = self.stage.value
    d["created_at"] = self.created_at.isoformat()
    d["promoted_at"] = self.promoted_at.isoformat() if self.promoted_at else None
    return d

ModelRegistry

JSON-file-backed model registry.

Parameters

persist_path: Path to a JSON file for persistence (optional).

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
class ModelRegistry:
    """JSON-file-backed model registry.

    Stores artifacts in memory as ``name → version → ModelArtifact`` and
    mirrors every mutation to a JSON file when *persist_path* is given.

    Parameters
    ----------
    persist_path:
        Path to a JSON file for persistence (optional).
    """

    def __init__(self, persist_path: str | Path | None = None) -> None:
        # name → version → artifact (insertion order == registration order)
        self._models: dict[str, dict[str, ModelArtifact]] = {}
        self._persist_path = Path(persist_path) if persist_path else None
        if self._persist_path and self._persist_path.exists():
            self._load()

    # -- registration --------------------------------------------------------

    def register(self, artifact: ModelArtifact) -> ModelArtifact:
        """Register a new model version.

        Raises
        ------
        ValueError:
            If this (name, version) pair is already registered.
        """
        versions = self._models.setdefault(artifact.name, {})
        if artifact.version in versions:
            raise ValueError(
                f"Model {artifact.name!r} version {artifact.version} already registered"
            )
        versions[artifact.version] = artifact
        logger.info(
            "Registered model %s v%s (stage=%s)",
            artifact.name,
            artifact.version,
            artifact.stage.value,
        )
        self._save()
        return artifact

    # -- queries -------------------------------------------------------------

    def get(self, name: str, version: str) -> ModelArtifact | None:
        """Return the artifact for *name* at *version*, or ``None``."""
        return self._models.get(name, {}).get(version)

    def get_latest(self, name: str) -> ModelArtifact | None:
        """Return the most recently registered version of *name*."""
        versions = self._models.get(name)
        if not versions:
            return None
        # Dicts preserve insertion order, so the last entry is the newest.
        return next(reversed(versions.values()))

    def get_production(self, name: str) -> ModelArtifact | None:
        """Return the model currently in production stage, or ``None``."""
        for art in self._models.get(name, {}).values():
            if art.stage == ModelStage.PRODUCTION:
                return art
        return None

    def list_models(self) -> list[str]:
        """Return all registered model names."""
        return list(self._models)

    def list_versions(self, name: str) -> list[str]:
        """Return all version strings registered for *name*."""
        return list(self._models.get(name, {}))

    # -- promotion -----------------------------------------------------------

    def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
        """Promote a model version to a new stage.

        If promoting to ``production``, any existing production model is
        automatically archived.

        Raises
        ------
        ValueError:
            If the (name, version) pair is not registered.
        """
        artifact = self.get(name, version)
        if artifact is None:
            raise ValueError(f"Model {name!r} version {version} not found")

        if target_stage == ModelStage.PRODUCTION:
            # Only one version may hold the production stage at a time.
            current = self.get_production(name)
            if current and current.version != version:
                current.stage = ModelStage.ARCHIVED
                logger.info("Archived %s v%s", name, current.version)

        artifact.stage = target_stage
        artifact.promoted_at = datetime.now(tz=UTC)
        # BUG FIX: the format string was "Promoted %s v%s%s", which glued the
        # stage to the version (e.g. "Promoted foo v1.0production").
        logger.info("Promoted %s v%s to %s", name, version, target_stage.value)
        self._save()
        return artifact

    # -- persistence ---------------------------------------------------------

    def _save(self) -> None:
        """Write the registry to the JSON file, if persistence is enabled."""
        if not self._persist_path:
            return
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        data: dict[str, list[dict[str, Any]]] = {
            name: [v.to_dict() for v in versions.values()]
            for name, versions in self._models.items()
        }
        self._persist_path.write_text(json.dumps(data, indent=2, default=str))

    def _load(self) -> None:
        """Rebuild the in-memory registry from the JSON file."""
        if not self._persist_path or not self._persist_path.exists():
            return
        raw = json.loads(self._persist_path.read_text())
        for name, versions in raw.items():
            self._models[name] = {}
            for v in versions:
                # BUG FIX: timestamps were popped and discarded, so every
                # reload reset created_at/promoted_at. to_dict() serializes
                # them via isoformat(), so parse them back instead.
                created = v.pop("created_at", None)
                promoted = v.pop("promoted_at", None)
                v["stage"] = ModelStage(v.get("stage", "development"))
                art = ModelArtifact(**v)
                if created:
                    art.created_at = datetime.fromisoformat(created)
                if promoted:
                    art.promoted_at = datetime.fromisoformat(promoted)
                self._models[name][v["version"]] = art
        logger.info("Loaded %d models from %s", len(self._models), self._persist_path)

register(artifact)

Register a new model version.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def register(self, artifact: ModelArtifact) -> ModelArtifact:
    """Record a new model version, rejecting duplicates, and persist.

    Raises
    ------
    ValueError:
        If this (name, version) pair is already registered.
    """
    known_versions = self._models.setdefault(artifact.name, {})
    if artifact.version in known_versions:
        msg = f"Model {artifact.name!r} version {artifact.version} already registered"
        raise ValueError(msg)
    known_versions[artifact.version] = artifact
    logger.info(
        "Registered model %s v%s (stage=%s)",
        artifact.name,
        artifact.version,
        artifact.stage.value,
    )
    self._save()
    return artifact

get(name, version)

Return the artifact for name at version, or None.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def get(self, name: str, version: str) -> ModelArtifact | None:
    """Fetch the artifact registered at (*name*, *version*), or ``None``."""
    versions = self._models.get(name)
    if not versions:
        return None
    return versions.get(version)

get_latest(name)

Return the most recently registered version of name.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def get_latest(self, name: str) -> ModelArtifact | None:
    """Return the newest (last-registered) version of *name*, if any."""
    versions = self._models.get(name)
    # Registration order is dict insertion order, so the last value wins.
    return list(versions.values())[-1] if versions else None

get_production(name)

Return the model currently in production stage.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def get_production(self, name: str) -> ModelArtifact | None:
    """Return the artifact currently in the production stage, if any."""
    candidates = self._models.get(name, {}).values()
    return next(
        (art for art in candidates if art.stage == ModelStage.PRODUCTION),
        None,
    )

list_models()

Return all registered model names.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def list_models(self) -> list[str]:
    """Return the name of every registered model."""
    return [model_name for model_name in self._models]

list_versions(name)

Return all version strings registered for name.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def list_versions(self, name: str) -> list[str]:
    """Return every version string registered under *name* (may be empty)."""
    return [ver for ver in self._models.get(name, {})]

promote(name, version, target_stage)

Promote a model version to a new stage.

If promoting to production, any existing production model is automatically archived.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
    """Promote a model version to a new stage.

    If promoting to ``production``, any existing production model is
    automatically archived.

    Raises
    ------
    ValueError:
        If the (name, version) pair is not registered.
    """
    artifact = self.get(name, version)
    if artifact is None:
        raise ValueError(f"Model {name!r} version {version} not found")

    if target_stage == ModelStage.PRODUCTION:
        # Only one version may hold the production stage at a time.
        current = self.get_production(name)
        if current and current.version != version:
            current.stage = ModelStage.ARCHIVED
            logger.info("Archived %s v%s", name, current.version)

    artifact.stage = target_stage
    artifact.promoted_at = datetime.now(tz=UTC)
    # BUG FIX: the format string was "Promoted %s v%s%s", which glued the
    # stage to the version (e.g. "Promoted foo v1.0production").
    logger.info("Promoted %s v%s to %s", name, version, target_stage.value)
    self._save()
    return artifact

ModelStage

Bases: StrEnum

Model lifecycle stages.

Source code in packages/dataenginex/src/dataenginex/ml/registry.py
class ModelStage(StrEnum):
    """Model lifecycle stages.

    ``ModelRegistry.promote`` moves versions between stages; when a version
    is promoted to production, the previous production version is archived.
    """

    DEVELOPMENT = "development"  # default stage for new artifacts
    STAGING = "staging"
    PRODUCTION = "production"  # at most one version per model (see promote())
    ARCHIVED = "archived"  # former production version that was replaced

DriftCheckResult dataclass

Aggregated result of a drift check across all features of a model.

Attributes:

Name Type Description
model_name str

Name of the model checked.

reports list[DriftReport]

Per-feature drift reports.

drift_detected bool

True if any feature exceeded the PSI threshold.

max_psi float

Highest PSI score across all features.

checked_at datetime

Timestamp of the check.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
@dataclass
class DriftCheckResult:
    """Aggregated result of a drift check across all features of a model.

    Attributes:
        model_name: Name of the model checked.
        reports: Per-feature drift reports.
        drift_detected: ``True`` if any feature exceeded the PSI threshold.
        max_psi: Highest PSI score across all features.
        checked_at: Timestamp of the check.
    """

    model_name: str
    reports: list[DriftReport]
    drift_detected: bool
    max_psi: float
    # Timezone-aware UTC timestamp captured at construction time.
    checked_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary.

        ``max_psi`` is rounded to 6 decimal places; per-feature reports are
        serialized recursively under the ``"features"`` key.
        """
        return {
            "model_name": self.model_name,
            "drift_detected": self.drift_detected,
            "max_psi": round(self.max_psi, 6),
            "checked_at": self.checked_at.isoformat(),
            "features": [r.to_dict() for r in self.reports],
        }

to_dict()

Serialize to a plain dictionary.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def to_dict(self) -> dict[str, Any]:
    """Serialize the aggregated result to a JSON-friendly dictionary."""
    feature_payloads = [report.to_dict() for report in self.reports]
    return {
        "model_name": self.model_name,
        "drift_detected": self.drift_detected,
        "max_psi": round(self.max_psi, 6),
        "checked_at": self.checked_at.isoformat(),
        "features": feature_payloads,
    }

DriftMonitorConfig dataclass

Configuration for monitoring a single model's data drift.

Attributes:

Name Type Description
model_name str

Name of the model being monitored.

reference_data dict[str, list[float]]

Mapping of feature_name → reference distribution values.

psi_threshold float

PSI value above which drift is flagged (default 0.20).

check_interval_seconds float

Seconds between consecutive checks (default 300).

n_bins int

Number of histogram bins for PSI calculation (default 10).

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
@dataclass
class DriftMonitorConfig:
    """Configuration for monitoring a single model's data drift.

    Attributes:
        model_name: Name of the model being monitored.
        reference_data: Mapping of feature_name → reference distribution values.
        psi_threshold: PSI value above which drift is flagged (default 0.20).
        check_interval_seconds: Seconds between consecutive checks (default 300).
        n_bins: Number of histogram bins for PSI calculation (default 10).
    """

    model_name: str
    # Must be non-empty; DriftScheduler.register rejects empty mappings.
    reference_data: dict[str, list[float]]
    psi_threshold: float = 0.20
    check_interval_seconds: float = 300.0
    n_bins: int = 10

DriftScheduler

Background scheduler for periodic model drift checks.

Runs a daemon thread that iterates registered monitors and invokes DriftDetector when each monitor's interval has elapsed. Results are published to Prometheus gauges and counters.

Parameters

tick_seconds: How often the scheduler loop wakes up to check deadlines (default 5.0). Lower values give more precise timing at the cost of CPU.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
class DriftScheduler:
    """Background scheduler for periodic model drift checks.

    Runs a daemon thread that iterates registered monitors and
    invokes ``DriftDetector`` when each monitor's interval has elapsed.
    Results are published to Prometheus gauges and counters.

    Parameters
    ----------
    tick_seconds:
        How often the scheduler loop wakes up to check deadlines
        (default ``5.0``).  Lower values give more precise timing
        at the cost of CPU.
    """

    def __init__(self, tick_seconds: float = 5.0) -> None:
        # model_name → (config, data_fn, last-check time.monotonic() stamp)
        self._monitors: dict[str, _MonitorEntry] = {}
        # Guards _monitors: mutated from caller threads, read by the loop.
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._tick = tick_seconds
        # Latest DriftCheckResult per model name.
        self._results: dict[str, DriftCheckResult] = {}

    # -- public API ----------------------------------------------------------

    def register(
        self,
        config: DriftMonitorConfig,
        data_fn: DataProvider,
    ) -> None:
        """Register a model for periodic drift monitoring.

        Parameters
        ----------
        config:
            Monitor configuration (thresholds, interval, reference data).
        data_fn:
            Callable returning current feature data as
            ``dict[str, list[float]]``.

        Raises
        ------
        ValueError:
            If ``config.reference_data`` is empty.
        """
        if not config.reference_data:
            msg = f"reference_data must not be empty for model {config.model_name!r}"
            raise ValueError(msg)

        with self._lock:
            # last_check of 0.0 makes the first check run on the next tick.
            self._monitors[config.model_name] = (config, data_fn, 0.0)
        logger.info(
            "Drift monitor registered: model=%s interval=%ss features=%d",
            config.model_name,
            config.check_interval_seconds,
            len(config.reference_data),
        )

    def unregister(self, model_name: str) -> None:
        """Remove a model from drift monitoring.

        Parameters
        ----------
        model_name:
            Name of the model to unregister.

        Raises
        ------
        KeyError:
            If the model is not registered.
        """
        with self._lock:
            if model_name not in self._monitors:
                msg = f"Model {model_name!r} is not registered for drift monitoring"
                raise KeyError(msg)
            del self._monitors[model_name]
            # Drop the cached result so stale data isn't served later.
            self._results.pop(model_name, None)
        logger.info("Drift monitor unregistered: model=%s", model_name)

    def start(self) -> None:
        """Start the background monitoring thread.

        Raises
        ------
        RuntimeError:
            If the scheduler is already running.
        """
        if self._thread is not None and self._thread.is_alive():
            msg = "DriftScheduler is already running"
            raise RuntimeError(msg)

        self._stop_event.clear()
        # Daemon thread: does not block interpreter shutdown.
        self._thread = threading.Thread(
            target=self._run_loop,
            name="drift-scheduler",
            daemon=True,
        )
        self._thread.start()
        logger.info("DriftScheduler started (tick=%ss)", self._tick)

    def stop(self, timeout: float = 10.0) -> None:
        """Stop the background monitoring thread.

        Parameters
        ----------
        timeout:
            Seconds to wait for the thread to join (default ``10.0``).
        """
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=timeout)
            self._thread = None
        logger.info("DriftScheduler stopped")

    @property
    def is_running(self) -> bool:
        """Whether the scheduler thread is alive."""
        return self._thread is not None and self._thread.is_alive()

    @property
    def registered_models(self) -> list[str]:
        """Names of all registered models."""
        with self._lock:
            return list(self._monitors.keys())

    def get_last_result(self, model_name: str) -> DriftCheckResult | None:
        """Return the most recent drift check result for a model."""
        # NOTE(review): read without self._lock while the scheduler thread
        # writes _results in _execute_check — presumably relying on the GIL's
        # atomic dict operations; confirm this is intentional.
        return self._results.get(model_name)

    def run_check(self, model_name: str) -> DriftCheckResult:
        """Manually trigger a drift check for one model.

        Parameters
        ----------
        model_name:
            Name of a registered model to check.

        Raises
        ------
        KeyError:
            If the model is not registered.

        Returns
        -------
        DriftCheckResult:
            Aggregated result with per-feature reports.
        """
        with self._lock:
            entry = self._monitors.get(model_name)
            if entry is None:
                msg = f"Model {model_name!r} is not registered for drift monitoring"
                raise KeyError(msg)
            config, data_fn, _ = entry

        # Note: a manual check does not reset the monitor's last_check time.
        return self._execute_check(config, data_fn)

    # -- internal ------------------------------------------------------------

    def _run_loop(self) -> None:
        """Background loop — check each monitor when its interval elapses."""
        logger.debug("Drift scheduler loop entered")
        while not self._stop_event.is_set():
            now = time.monotonic()
            # Snapshot under the lock so the checks themselves run lock-free.
            with self._lock:
                snapshot = list(self._monitors.items())

            for name, (config, data_fn, last_check) in snapshot:
                if now - last_check >= config.check_interval_seconds:
                    try:
                        self._execute_check(config, data_fn)
                    except Exception:
                        logger.exception("Drift check failed for model=%s", name)
                    # Update last_check regardless of success/failure
                    # (prevents a failing monitor from re-firing every tick).
                    with self._lock:
                        if name in self._monitors:
                            old = self._monitors[name]
                            self._monitors[name] = (old[0], old[1], time.monotonic())

            # Event-based wait so stop() interrupts the sleep promptly.
            self._stop_event.wait(timeout=self._tick)

    def _execute_check(
        self,
        config: DriftMonitorConfig,
        data_fn: DataProvider,
    ) -> DriftCheckResult:
        """Run a single drift check and publish metrics."""
        # A fresh detector per check keeps per-monitor thresholds isolated.
        detector = DriftDetector(
            psi_threshold=config.psi_threshold,
            n_bins=config.n_bins,
        )

        current_data = data_fn()
        reports = detector.check_dataset(config.reference_data, current_data)

        drift_detected = any(r.drift_detected for r in reports)
        max_psi = max((r.psi for r in reports), default=0.0)

        result = DriftCheckResult(
            model_name=config.model_name,
            reports=reports,
            drift_detected=drift_detected,
            max_psi=max_psi,
        )

        # Publish to Prometheus
        for report in reports:
            model_drift_psi.labels(
                model=config.model_name,
                feature=report.feature_name,
            ).set(report.psi)

            if report.drift_detected:
                model_drift_alerts_total.labels(
                    model=config.model_name,
                    severity=report.severity,
                ).inc()

        self._results[config.model_name] = result

        if drift_detected:
            logger.warning(
                "Drift detected: model=%s max_psi=%.4f features_drifted=%d/%d",
                config.model_name,
                max_psi,
                sum(1 for r in reports if r.drift_detected),
                len(reports),
            )
        else:
            logger.info(
                "Drift check OK: model=%s max_psi=%.4f features=%d",
                config.model_name,
                max_psi,
                len(reports),
            )

        return result

is_running property

Whether the scheduler thread is alive.

registered_models property

Names of all registered models.

register(config, data_fn)

Register a model for periodic drift monitoring.

Parameters

config: Monitor configuration (thresholds, interval, reference data). data_fn: Callable returning current feature data as dict[str, list[float]].

Raises

ValueError: If config.reference_data is empty.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def register(
    self,
    config: DriftMonitorConfig,
    data_fn: DataProvider,
) -> None:
    """Add a model to the set of periodically drift-checked models.

    Parameters
    ----------
    config:
        Monitor configuration (thresholds, interval, reference data).
    data_fn:
        Callable returning current feature data as
        ``dict[str, list[float]]``.

    Raises
    ------
    ValueError:
        If ``config.reference_data`` is empty.
    """
    if not config.reference_data:
        raise ValueError(
            f"reference_data must not be empty for model {config.model_name!r}"
        )

    # Store (config, provider, last-check timestamp) under the lock;
    # a fresh registration starts with timestamp 0.0.
    with self._lock:
        self._monitors[config.model_name] = (config, data_fn, 0.0)
    logger.info(
        "Drift monitor registered: model=%s interval=%ss features=%d",
        config.model_name,
        config.check_interval_seconds,
        len(config.reference_data),
    )

unregister(model_name)

Remove a model from drift monitoring.

Parameters

model_name: Name of the model to unregister.

Raises

KeyError: If the model is not registered.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def unregister(self, model_name: str) -> None:
    """Stop drift monitoring for *model_name* and drop its cached result.

    Parameters
    ----------
    model_name:
        Name of the model to unregister.

    Raises
    ------
    KeyError:
        If the model is not registered.
    """
    with self._lock:
        try:
            del self._monitors[model_name]
        except KeyError:
            raise KeyError(
                f"Model {model_name!r} is not registered for drift monitoring"
            ) from None
        # Cached result may not exist yet; pop tolerates that.
        self._results.pop(model_name, None)
    logger.info("Drift monitor unregistered: model=%s", model_name)

start()

Start the background monitoring thread.

Raises

RuntimeError: If the scheduler is already running.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def start(self) -> None:
    """Start the background monitoring thread.

    Raises
    ------
    RuntimeError:
        If the scheduler is already running.
    """
    current = self._thread
    if current is not None and current.is_alive():
        raise RuntimeError("DriftScheduler is already running")

    self._stop_event.clear()
    worker = threading.Thread(
        target=self._run_loop,
        name="drift-scheduler",
        daemon=True,  # daemon thread: never blocks interpreter shutdown
    )
    self._thread = worker
    worker.start()
    logger.info("DriftScheduler started (tick=%ss)", self._tick)

stop(timeout=10.0)

Stop the background monitoring thread.

Parameters

timeout: Seconds to wait for the thread to join (default 10.0).

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def stop(self, timeout: float = 10.0) -> None:
    """Signal the monitoring thread to exit and wait for it to finish.

    Parameters
    ----------
    timeout:
        Seconds to wait for the thread to join (default ``10.0``).
    """
    self._stop_event.set()
    worker = self._thread
    if worker is not None:
        worker.join(timeout=timeout)
        self._thread = None
    logger.info("DriftScheduler stopped")

get_last_result(model_name)

Return the most recent drift check result for a model.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def get_last_result(self, model_name: str) -> DriftCheckResult | None:
    """Return the most recent drift check result for *model_name*, or ``None``."""
    last = self._results.get(model_name)
    return last

run_check(model_name)

Manually trigger a drift check for one model.

Parameters

model_name: Name of a registered model to check.

Raises

KeyError: If the model is not registered.

Returns

DriftCheckResult: Aggregated result with per-feature reports.

Source code in packages/dataenginex/src/dataenginex/ml/scheduler.py
def run_check(self, model_name: str) -> DriftCheckResult:
    """Manually trigger a drift check for one model.

    Parameters
    ----------
    model_name:
        Name of a registered model to check.

    Raises
    ------
    KeyError:
        If the model is not registered.

    Returns
    -------
    DriftCheckResult:
        Aggregated result with per-feature reports.
    """
    # Copy the entry out under the lock; run the check itself outside it.
    with self._lock:
        try:
            config, data_fn, _last = self._monitors[model_name]
        except KeyError:
            raise KeyError(
                f"Model {model_name!r} is not registered for drift monitoring"
            ) from None

    return self._execute_check(config, data_fn)

ModelServer

Registry-aware model server.

Loads a model from the ModelRegistry and serves predictions via the predict method.

Parameters

registry: A ModelRegistry instance (from dataenginex.ml.registry).

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
class ModelServer:
    """Registry-aware model server.

    Loads a model from the ``ModelRegistry`` and serves predictions via
    the ``predict`` method.

    Parameters
    ----------
    registry:
        A ``ModelRegistry`` instance (from ``dataenginex.ml.registry``).
    """

    def __init__(self, registry: Any = None) -> None:
        self._registry = registry
        # Maps "name:version" to the in-memory model object.
        self._loaded: dict[str, Any] = {}

    def load_model(self, name: str, version: str, model: Any) -> None:
        """Make *model* available for serving under ``name:version``.

        Parameters
        ----------
        name:
            Model name matching registry entries.
        version:
            Model version.
        model:
            Any object with a ``predict(X)`` method.
        """
        serving_key = f"{name}:{version}"
        self._loaded[serving_key] = model
        logger.info("Loaded model %s for serving", serving_key)

    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Run inference for *request* and return a timed response."""
        started = time.perf_counter()

        version = request.version or self._resolve_production_version(request.model_name)
        serving_key = f"{request.model_name}:{version}"

        model = self._loaded.get(serving_key)
        if model is None:
            raise RuntimeError(
                f"Model {request.model_name} v{version} not loaded — call load_model() first"
            )

        # Convert the feature dicts into the row format the model expects.
        rows = self._features_to_array(request.features)
        raw = model.predict(rows)

        elapsed_ms = (time.perf_counter() - started) * 1000

        # Objects exposing tolist() (e.g. array-likes) are converted via it;
        # everything else is coerced with list().
        outputs: list[Any] = raw.tolist() if hasattr(raw, "tolist") else list(raw)

        return PredictionResponse(
            model_name=request.model_name,
            version=version,
            predictions=outputs,
            latency_ms=elapsed_ms,
            request_id=request.request_id,
        )

    def list_loaded(self) -> list[str]:
        """Return the ``"name:version"`` keys of every loaded model."""
        return [key for key in self._loaded]

    # -- helpers -------------------------------------------------------------

    def _resolve_production_version(self, name: str) -> str:
        """Pick a version for *name*: registry production first, then latest loaded."""
        registry = self._registry
        if registry is not None:
            prod = registry.get_production(name)
            if prod:
                return prod.version  # type: ignore[no-any-return]
        # Fallback: scan loaded keys from most recently inserted backwards.
        prefix = f"{name}:"
        for key in reversed(list(self._loaded)):
            if key.startswith(prefix):
                return key.split(":")[1]
        raise RuntimeError(f"No version found for model {name!r}")

    @staticmethod
    def _features_to_array(features: list[dict[str, Any]]) -> list[list[Any]]:
        """Convert list-of-dicts to a 2D list suitable for sklearn ``predict``."""
        if not features:
            return []
        # Column order is taken from the first sample; missing keys become None.
        columns = list(features[0])
        return [[sample.get(col) for col in columns] for sample in features]

load_model(name, version, model)

Register a model object for serving.

Parameters

name: Model name matching registry entries. version: Model version. model: Any object with a predict(X) method.

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
def load_model(self, name: str, version: str, model: Any) -> None:
    """Make *model* available for serving under ``name:version``.

    Parameters
    ----------
    name:
        Model name matching registry entries.
    version:
        Model version.
    model:
        Any object with a ``predict(X)`` method.
    """
    serving_key = f"{name}:{version}"
    self._loaded[serving_key] = model
    logger.info("Loaded model %s for serving", serving_key)

predict(request)

Run inference for request.

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
def predict(self, request: PredictionRequest) -> PredictionResponse:
    """Run inference for *request* and return a timed response."""
    started = time.perf_counter()

    version = request.version or self._resolve_production_version(request.model_name)
    serving_key = f"{request.model_name}:{version}"

    model = self._loaded.get(serving_key)
    if model is None:
        raise RuntimeError(
            f"Model {request.model_name} v{version} not loaded — call load_model() first"
        )

    # Convert features to the format the model expects
    rows = self._features_to_array(request.features)
    raw = model.predict(rows)

    elapsed_ms = (time.perf_counter() - started) * 1000

    # Objects with tolist() are converted via it; others via list().
    outputs: list[Any] = raw.tolist() if hasattr(raw, "tolist") else list(raw)

    return PredictionResponse(
        model_name=request.model_name,
        version=version,
        predictions=outputs,
        latency_ms=elapsed_ms,
        request_id=request.request_id,
    )

list_loaded()

Return keys of all loaded models.

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
def list_loaded(self) -> list[str]:
    """Return the ``"name:version"`` keys of every loaded model."""
    return [key for key in self._loaded]

PredictionRequest dataclass

Input to the serving layer.

Attributes:

Name Type Description
model_name str

Name of the model to invoke.

version str | None

Model version (None resolves to the production version).

features list[dict[str, Any]]

List of feature dicts — each dict is one sample.

request_id str

Caller-provided request ID for tracing.

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
@dataclass
class PredictionRequest:
    """Input to the serving layer.

    Attributes:
        model_name: Name of the model to invoke.
        version: Model version (``None`` resolves to the production version).
        features: List of feature dicts — each dict is one sample.
        request_id: Caller-provided request ID for tracing.
    """

    model_name: str
    # None → the server resolves the production version at predict() time.
    version: str | None = None
    features: list[dict[str, Any]] = field(default_factory=list)
    request_id: str = ""

PredictionResponse dataclass

Output from the serving layer.

Attributes:

Name Type Description
model_name str

Name of the model that produced predictions.

version str

Version of the model used.

predictions list[Any]

List of prediction values.

latency_ms float

Inference latency in milliseconds.

request_id str

Echoed request ID for tracing.

served_at datetime

Timestamp of the prediction.

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
@dataclass
class PredictionResponse:
    """Output from the serving layer.

    Attributes:
        model_name: Name of the model that produced predictions.
        version: Version of the model used.
        predictions: List of prediction values.
        latency_ms: Inference latency in milliseconds.
        request_id: Echoed request ID for tracing.
        served_at: Timestamp of the prediction.
    """

    model_name: str
    version: str
    predictions: list[Any] = field(default_factory=list)
    latency_ms: float = 0.0
    request_id: str = ""
    served_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the prediction response to a plain dictionary."""
        return {
            "model_name": self.model_name,
            "version": self.version,
            "predictions": self.predictions,
            "latency_ms": round(self.latency_ms, 2),
            "request_id": self.request_id,
            "served_at": self.served_at.isoformat(),
        }

to_dict()

Serialize the prediction response to a plain dictionary.

Source code in packages/dataenginex/src/dataenginex/ml/serving.py
def to_dict(self) -> dict[str, Any]:
    """Serialize the prediction response to a plain dictionary."""
    return {
        "model_name": self.model_name,
        "version": self.version,
        "predictions": self.predictions,
        "latency_ms": round(self.latency_ms, 2),
        "request_id": self.request_id,
        "served_at": self.served_at.isoformat(),
    }

BaseTrainer

Bases: ABC

Abstract base class for model trainers.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
class BaseTrainer(ABC):
    """Abstract base class for model trainers.

    Concrete subclasses implement the full train/evaluate/predict/save/load
    lifecycle; this base only stores identifying metadata.
    """

    def __init__(self, model_name: str, version: str = "1.0.0") -> None:
        # Identifying metadata for this trainer.
        self.model_name = model_name
        self.version = version

    @abstractmethod
    def train(
        self,
        X_train: Any,
        y_train: Any,
        **params: Any,  # noqa: N803
    ) -> TrainingResult:
        """Fit the model on the given data and return training metrics."""
        ...

    @abstractmethod
    def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
        """Compute evaluation metrics on held-out test data."""
        ...

    @abstractmethod
    def predict(self, X: Any) -> Any:  # noqa: N803
        """Return predictions for the given input."""
        ...

    @abstractmethod
    def save(self, path: str) -> str:
        """Persist the model to *path* and return the artifact path."""
        ...

    @abstractmethod
    def load(self, path: str) -> None:
        """Restore a previously saved model from *path*."""
        ...

train(X_train, y_train, **params) abstractmethod

Train the model and return metrics.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
@abstractmethod
def train(
    self,
    X_train: Any,
    y_train: Any,
    **params: Any,  # noqa: N803
) -> TrainingResult:
    """Train the model and return metrics.

    Parameters
    ----------
    X_train:
        Training features.
    y_train:
        Training targets.
    **params:
        Implementation-specific training hyper-parameters.

    Returns
    -------
    TrainingResult:
        Metrics, parameters, and timing for the run.
    """
    ...

evaluate(X_test, y_test) abstractmethod

Evaluate the model on test data.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
@abstractmethod
def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
    """Evaluate the model on test data.

    Returns a mapping of metric name to value
    (e.g. ``{"test_score": 0.9}``).
    """
    ...

predict(X) abstractmethod

Generate predictions.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
@abstractmethod
def predict(self, X: Any) -> Any:  # noqa: N803
    """Generate predictions.

    *X* holds input samples; the return format is implementation-defined.
    """
    ...

save(path) abstractmethod

Persist the model to path and return the artifact path.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
@abstractmethod
def save(self, path: str) -> str:
    """Persist the model to *path* and return the artifact path.

    Implementations may write additional sidecar files next to *path*.
    """
    ...

load(path) abstractmethod

Load a previously saved model from path.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
@abstractmethod
def load(self, path: str) -> None:
    """Load a previously saved model from *path*.

    After a successful load the trainer is ready to ``predict``.
    """
    ...

SklearnTrainer

Bases: BaseTrainer

scikit-learn model trainer.

Works with any sklearn estimator (or pipeline) that implements fit, predict, and score.

Parameters

model_name: Name used in model registry. version: Semantic version string. estimator: An sklearn estimator instance (e.g. RandomForestClassifier()).

Source code in packages/dataenginex/src/dataenginex/ml/training.py
class SklearnTrainer(BaseTrainer):
    """scikit-learn model trainer.

    Works with any sklearn estimator (or pipeline) that implements
    ``fit``, ``predict``, and ``score``.

    Parameters
    ----------
    model_name:
        Name used in model registry.
    version:
        Semantic version string.
    estimator:
        An sklearn estimator instance (e.g. ``RandomForestClassifier()``).
    """

    def __init__(
        self,
        model_name: str,
        version: str = "1.0.0",
        estimator: Any = None,
    ) -> None:
        super().__init__(model_name, version)
        self.estimator = estimator
        self._is_fitted = False

    def train(
        self,
        X_train: Any,
        y_train: Any,
        **params: Any,  # noqa: N803
    ) -> TrainingResult:
        """Fit the estimator on *X_train*/*y_train* and return metrics.

        Raises
        ------
        RuntimeError:
            If no estimator was supplied at construction time.
        """
        if self.estimator is None:
            raise RuntimeError("No estimator provided to SklearnTrainer")

        # Apply caller-supplied hyper-parameters before fitting.
        if params:
            self.estimator.set_params(**params)

        start = time.perf_counter()
        self.estimator.fit(X_train, y_train)
        duration = time.perf_counter() - start
        self._is_fitted = True

        # Compute training score
        train_score = float(self.estimator.score(X_train, y_train))
        metrics = {"train_score": round(train_score, 4)}

        logger.info(
            "Trained %s v%s in %.2fs (train_score=%.4f)",
            self.model_name,
            self.version,
            duration,
            train_score,
        )

        return TrainingResult(
            model_name=self.model_name,
            version=self.version,
            metrics=metrics,
            parameters=self.estimator.get_params(),
            duration_seconds=duration,
        )

    def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
        """Score the fitted model on *X_test*/*y_test* and return metrics.

        Always includes ``test_score``; adds weighted precision/recall/f1
        when classification metrics apply to the targets.

        Raises
        ------
        RuntimeError:
            If the model has not been trained yet.
        """
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")

        test_score = float(self.estimator.score(X_test, y_test))
        predictions = self.estimator.predict(X_test)

        metrics: dict[str, float] = {"test_score": round(test_score, 4)}
        metrics.update(self._classification_metrics(y_test, predictions))
        return metrics

    @staticmethod
    def _classification_metrics(y_test: Any, predictions: Any) -> dict[str, float]:
        """Best-effort weighted precision/recall/f1; empty dict when inapplicable."""
        try:
            from sklearn.metrics import (  # type: ignore[import-not-found]
                f1_score,
                precision_score,
                recall_score,
            )
        except ImportError:
            return {}

        scorers = {
            "precision": precision_score,
            "recall": recall_score,
            "f1": f1_score,
        }
        results: dict[str, float] = {}
        try:
            for name, scorer in scorers.items():
                results[name] = round(
                    float(
                        scorer(
                            y_test,
                            predictions,
                            average="weighted",
                            zero_division=0,
                        ),
                    ),
                    4,
                )
        except Exception as exc:
            # Non-classification targets (e.g. regression) raise here; skip
            # the extra metrics rather than failing evaluation, but say why.
            logger.debug("Classification metrics skipped: %s", exc)
            return {}
        return results

    def predict(self, X: Any) -> Any:  # noqa: N803
        """Generate predictions for *X* using the fitted estimator."""
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")
        return self.estimator.predict(X)

    def save(self, path: str) -> str:
        """Pickle the fitted model and its metadata to *path*.

        A metadata sidecar (same stem, ``.json`` suffix) is written next
        to the pickle.

        Raises
        ------
        RuntimeError:
            If the model has not been trained yet.
        """
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")

        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_bytes(pickle.dumps(self.estimator))

        # Save metadata alongside; pin UTF-8 so the sidecar is portable
        # across platforms with different default encodings.
        meta = out.with_suffix(".json")
        meta.write_text(
            json.dumps(
                {
                    "model_name": self.model_name,
                    "version": self.version,
                    "saved_at": datetime.now(tz=UTC).isoformat(),
                }
            ),
            encoding="utf-8",
        )

        logger.info("Saved model %s to %s", self.model_name, out)
        return str(out)

    def load(self, path: str) -> None:
        """Load a pickled model from *path* and mark as fitted.

        Only load artifacts produced by a trusted pipeline: unpickling
        executes arbitrary code.
        """
        data = Path(path).read_bytes()
        self.estimator = pickle.loads(data)  # noqa: S301
        self._is_fitted = True
        logger.info("Loaded model %s from %s", self.model_name, path)

train(X_train, y_train, **params)

Fit the estimator on X_train/y_train and return metrics.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
def train(
    self,
    X_train: Any,
    y_train: Any,
    **params: Any,  # noqa: N803
) -> TrainingResult:
    """Fit the estimator on *X_train*/*y_train* and return metrics."""
    if self.estimator is None:
        raise RuntimeError("No estimator provided to SklearnTrainer")

    if params:
        # Forward caller-supplied hyper-parameters to the estimator.
        self.estimator.set_params(**params)

    t0 = time.perf_counter()
    self.estimator.fit(X_train, y_train)
    elapsed = time.perf_counter() - t0
    self._is_fitted = True

    # Score the freshly fitted estimator on its own training data.
    score = float(self.estimator.score(X_train, y_train))

    logger.info(
        "Trained %s v%s in %.2fs (train_score=%.4f)",
        self.model_name,
        self.version,
        elapsed,
        score,
    )

    return TrainingResult(
        model_name=self.model_name,
        version=self.version,
        metrics={"train_score": round(score, 4)},
        parameters=self.estimator.get_params(),
        duration_seconds=elapsed,
    )

evaluate(X_test, y_test)

Score the fitted model on X_test/y_test and return metrics.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
    """Score the fitted model on *X_test*/*y_test* and return metrics."""
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")

    metrics: dict[str, float] = {
        "test_score": round(float(self.estimator.score(X_test, y_test)), 4),
    }
    predictions = self.estimator.predict(X_test)

    # Classification metrics are best-effort: they fail for regression
    # targets or when sklearn.metrics is unavailable.
    try:
        from sklearn.metrics import (  # type: ignore[import-not-found]
            f1_score,
            precision_score,
            recall_score,
        )

        common: dict[str, Any] = {"average": "weighted", "zero_division": 0}
        metrics["precision"] = round(
            float(precision_score(y_test, predictions, **common)),
            4,
        )
        metrics["recall"] = round(
            float(recall_score(y_test, predictions, **common)),
            4,
        )
        metrics["f1"] = round(
            float(f1_score(y_test, predictions, **common)),
            4,
        )
    except Exception:
        pass

    return metrics

predict(X)

Generate predictions for X using the fitted estimator.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
def predict(self, X: Any) -> Any:  # noqa: N803
    """Generate predictions for *X* using the fitted estimator.

    Raises
    ------
    RuntimeError:
        If the model has not been trained (or loaded) yet.
    """
    if self._is_fitted:
        return self.estimator.predict(X)
    raise RuntimeError("Model not yet trained")

save(path)

Pickle the fitted model and its metadata to path.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
def save(self, path: str) -> str:
    """Pickle the fitted model and its metadata to *path*."""
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")

    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_bytes(pickle.dumps(self.estimator))

    # Save metadata alongside
    meta = out.with_suffix(".json")
    meta.write_text(
        json.dumps(
            {
                "model_name": self.model_name,
                "version": self.version,
                "saved_at": datetime.now(tz=UTC).isoformat(),
            }
        )
    )

    logger.info("Saved model %s to %s", self.model_name, out)
    return str(out)

load(path)

Load a pickled model from path and mark as fitted.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
def load(self, path: str) -> None:
    """Load a pickled model from *path* and mark the trainer as fitted.

    NOTE(review): unpickling executes arbitrary code — only load
    artifacts produced by a trusted pipeline.
    """
    payload = Path(path).read_bytes()
    self.estimator = pickle.loads(payload)  # noqa: S301
    self._is_fitted = True
    logger.info("Loaded model %s from %s", self.model_name, path)

TrainingResult dataclass

Outcome of a model training run.

Attributes:

Name Type Description
model_name str

Name of the trained model.

version str

Semantic version of this training run.

metrics dict[str, float]

Training metrics (e.g. {"train_score": 0.95}).

parameters dict[str, Any]

Hyper-parameters used for training.

duration_seconds float

Wall-clock training time.

artifact_path str | None

Path where the model artifact is saved.

trained_at datetime

Timestamp of training completion.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
@dataclass
class TrainingResult:
    """Outcome of a model training run.

    Attributes:
        model_name: Name of the trained model.
        version: Semantic version of this training run.
        metrics: Training metrics (e.g. ``{"train_score": 0.95}``).
        parameters: Hyper-parameters used for training.
        duration_seconds: Wall-clock training time.
        artifact_path: Path where the model artifact is saved.
        trained_at: Timestamp of training completion.
    """

    model_name: str
    version: str
    metrics: dict[str, float] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    duration_seconds: float = 0.0
    artifact_path: str | None = None
    trained_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the training result to a plain dictionary."""
        return {
            "model_name": self.model_name,
            "version": self.version,
            "metrics": self.metrics,
            "parameters": self.parameters,
            "duration_seconds": round(self.duration_seconds, 2),
            "artifact_path": self.artifact_path,
            "trained_at": self.trained_at.isoformat(),
        }

to_dict()

Serialize the training result to a plain dictionary.

Source code in packages/dataenginex/src/dataenginex/ml/training.py
def to_dict(self) -> dict[str, Any]:
    """Serialize the training result to a plain dictionary."""
    return {
        "model_name": self.model_name,
        "version": self.version,
        "metrics": self.metrics,
        "parameters": self.parameters,
        "duration_seconds": round(self.duration_seconds, 2),
        "artifact_path": self.artifact_path,
        "trained_at": self.trained_at.isoformat(),
    }