Skip to content

dataenginex.ml

Classical ML — training, registry, drift, serving, metrics.

LLM and vector-store functionality live in dataenginex.ai; scheduling lives in dataenginex.orchestration.

Public API::

from dataenginex.ml import (
    BaseTrainer, SklearnTrainer, TrainingResult,
    ModelRegistry, ModelArtifact, ModelStage,
    MLflowModelRegistry, MLflowRegistryError,
    DriftDetector, DriftReport,
    ModelServer, PredictionRequest, PredictionResponse,
    model_prediction_total, model_prediction_latency_seconds,
    model_drift_psi, model_drift_alerts_total,
)

DriftDetector

Detect distribution drift between a reference and current dataset.

PSI thresholds (industry standard): below 0.10 — no significant drift; 0.10–0.25 — moderate drift; above 0.25 — significant drift.

Parameters

psi_threshold: PSI value above which drift is flagged (default 0.20).
n_bins: Number of histogram bins for the PSI calculation (default 10).

Source code in src/dataenginex/ml/drift.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
class DriftDetector:
    """Detect distribution drift between a reference and current dataset.

    Drift is scored with the Population Stability Index (PSI), computed over
    ``n_bins`` equal-width histogram bins that span the combined range of both
    samples.

    PSI thresholds (industry standard), used for ``DriftReport.severity``:
        < 0.10  — no significant drift  (``"none"``)
        0.10-0.25 — moderate drift      (``"moderate"``)
        >= 0.25 — significant drift     (``"severe"``)

    ``DriftReport.drift_detected`` is decided separately, by comparing the PSI
    against ``psi_threshold``.

    Parameters
    ----------
    psi_threshold:
        PSI value above which drift is flagged (default 0.20).
    n_bins:
        Number of histogram bins for PSI calculation (default 10).  Must be
        at least 1.

    Raises
    ------
    ValueError
        If *n_bins* is less than 1.
    """

    def __init__(self, psi_threshold: float = 0.20, n_bins: int = 10) -> None:
        # Fail fast: a non-positive bin count would otherwise only surface later
        # as a ZeroDivisionError / empty histogram inside ``_compute_psi``.
        if n_bins < 1:
            raise ValueError(f"n_bins must be a positive integer, got {n_bins!r}")
        self.psi_threshold = psi_threshold
        self.n_bins = n_bins

    def check_feature(
        self,
        feature_name: str,
        reference: list[float],
        current: list[float],
    ) -> DriftReport:
        """Check drift for a single numeric feature.

        Returns a ``DriftReport`` with the PSI, its severity, whether the PSI
        exceeds ``psi_threshold``, and means/standard deviations rounded to 4
        decimals.  If either sample is empty, a zero-PSI report carrying an
        ``"error"`` entry in ``details`` is returned instead.
        """
        if not reference or not current:
            return DriftReport(
                feature_name=feature_name,
                psi=0.0,
                drift_detected=False,
                severity="none",
                details={"error": "Empty distribution(s)"},
            )

        import statistics

        ref_mean = statistics.mean(reference)
        cur_mean = statistics.mean(current)
        # statistics.stdev needs at least two points; a single point has no spread.
        ref_std = statistics.stdev(reference) if len(reference) > 1 else 0.0
        cur_std = statistics.stdev(current) if len(current) > 1 else 0.0

        psi = self._compute_psi(reference, current)
        severity = self._classify_severity(psi)

        return DriftReport(
            feature_name=feature_name,
            psi=psi,
            drift_detected=psi > self.psi_threshold,
            severity=severity,
            reference_mean=round(ref_mean, 4),
            current_mean=round(cur_mean, 4),
            reference_std=round(ref_std, 4),
            current_std=round(cur_std, 4),
            details={"n_bins": self.n_bins, "threshold": self.psi_threshold},
        )

    def check_dataset(
        self,
        reference: dict[str, list[float]],
        current: dict[str, list[float]],
    ) -> list[DriftReport]:
        """Check drift across all shared features in two datasets.

        Features present in only one mapping are ignored; reports are returned
        in sorted feature-name order.

        Parameters
        ----------
        reference:
            Mapping of ``feature_name → values`` for the reference period.
        current:
            Mapping of ``feature_name → values`` for the current period.
        """
        shared = set(reference.keys()) & set(current.keys())
        return [
            self.check_feature(feat, reference[feat], current[feat])
            for feat in sorted(shared)
        ]

    # -- internal ------------------------------------------------------------

    def _compute_psi(self, reference: list[float], current: list[float]) -> float:
        """Compute PSI using equal-width binning.

        Both samples must be non-empty (``check_feature`` guarantees this).
        A small epsilon is added to each bin share so that empty bins do not
        cause ``log(0)`` or division by zero.
        """
        # Combined range, without materialising ``reference + current``.
        lo = min(min(reference), min(current))
        hi = max(max(reference), max(current))
        if lo == hi:
            # Every observation is identical: the distributions coincide.
            return 0.0

        ref_counts = self._histogram(reference, lo, hi)
        cur_counts = self._histogram(current, lo, hi)

        eps = 1e-6
        ref_total = len(reference)
        cur_total = len(current)

        psi = 0.0
        for r, c in zip(ref_counts, cur_counts, strict=True):
            ref_pct = (r / ref_total) + eps
            cur_pct = (c / cur_total) + eps
            psi += (cur_pct - ref_pct) * math.log(cur_pct / ref_pct)

        # Each term is non-negative; the clamp only guards float round-off.
        return max(0.0, psi)

    def _histogram(self, values: list[float], lo: float, hi: float) -> list[int]:
        """Count *values* into ``n_bins`` equal-width bins spanning ``[lo, hi]``."""
        bin_width = (hi - lo) / self.n_bins
        last_bin = self.n_bins - 1
        counts = [0] * self.n_bins
        for v in values:
            # ``hi`` itself maps to index n_bins; fold it into the last bin.
            counts[min(int((v - lo) / bin_width), last_bin)] += 1
        return counts

    @staticmethod
    def _classify_severity(psi: float) -> str:
        """Map a PSI value to ``"none"`` (< 0.10), ``"moderate"`` (< 0.25) or ``"severe"``."""
        if psi < 0.10:
            return "none"
        if psi < 0.25:
            return "moderate"
        return "severe"

check_feature(feature_name, reference, current)

Check drift for a single numeric feature.

Source code in src/dataenginex/ml/drift.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def check_feature(
    self,
    feature_name: str,
    reference: list[float],
    current: list[float],
) -> DriftReport:
    """Measure drift for one numeric feature.

    Parameters
    ----------
    feature_name:
        Label copied into the resulting report.
    reference, current:
        Baseline and recent samples of the feature.

    Returns
    -------
    DriftReport
        PSI, severity from ``_classify_severity``, ``drift_detected`` when the
        PSI exceeds ``self.psi_threshold``, and means/stdevs rounded to 4
        decimals.  When either sample is empty, a zero-PSI report with an
        ``"error"`` entry in ``details`` is returned instead.
    """
    if not (reference and current):
        return DriftReport(
            feature_name=feature_name,
            psi=0.0,
            drift_detected=False,
            severity="none",
            details={"error": "Empty distribution(s)"},
        )

    import statistics

    def spread(sample: list[float]) -> float:
        # statistics.stdev needs >= 2 points; treat a single point as no spread.
        return statistics.stdev(sample) if len(sample) > 1 else 0.0

    means = (statistics.mean(reference), statistics.mean(current))
    stdevs = (spread(reference), spread(current))

    score = self._compute_psi(reference, current)
    label = self._classify_severity(score)

    return DriftReport(
        feature_name=feature_name,
        psi=score,
        drift_detected=score > self.psi_threshold,
        severity=label,
        reference_mean=round(means[0], 4),
        current_mean=round(means[1], 4),
        reference_std=round(stdevs[0], 4),
        current_std=round(stdevs[1], 4),
        details={"n_bins": self.n_bins, "threshold": self.psi_threshold},
    )

check_dataset(reference, current)

Check drift across all shared features in two datasets.

Parameters

reference: Mapping of feature_name → values for the reference period.
current: Mapping of feature_name → values for the current period.

Source code in src/dataenginex/ml/drift.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def check_dataset(
    self,
    reference: dict[str, list[float]],
    current: dict[str, list[float]],
) -> list[DriftReport]:
    """Run ``check_feature`` for every feature present in both datasets.

    Features that appear in only one mapping are skipped; the reports come
    back ordered by sorted feature name.

    Parameters
    ----------
    reference:
        Mapping of ``feature_name → values`` for the reference period.
    current:
        Mapping of ``feature_name → values`` for the current period.
    """
    common_features = sorted(set(reference) & set(current))
    return [
        self.check_feature(feature, reference[feature], current[feature])
        for feature in common_features
    ]

DriftReport dataclass

Outcome of a drift check for a single feature.

Attributes:

Name Type Description
feature_name str

Name of the feature that was checked.

psi float

Population Stability Index value.

drift_detected bool

Whether drift exceeds the configured threshold.

severity str

Drift severity — "none", "moderate", or "severe".

reference_mean float | None

Mean of the reference distribution.

current_mean float | None

Mean of the current distribution.

reference_std float | None

Standard deviation of reference distribution.

current_std float | None

Standard deviation of current distribution.

details dict[str, Any]

Extra context (bins, threshold, etc.).

checked_at datetime

Timestamp of the drift check.

Source code in src/dataenginex/ml/drift.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@dataclass
class DriftReport:
    """Outcome of a drift check for a single feature.

    Attributes:
        feature_name: Name of the feature that was checked.
        psi: Population Stability Index value.
        drift_detected: Whether drift exceeds the configured threshold.
        severity: Drift severity — ``"none"``, ``"moderate"``, or ``"severe"``.
        reference_mean: Mean of the reference distribution.
        current_mean: Mean of the current distribution.
        reference_std: Standard deviation of reference distribution.
        current_std: Standard deviation of current distribution.
        details: Extra context (bins, threshold, etc.).
        checked_at: Timestamp of the drift check.
    """

    feature_name: str
    psi: float  # Population Stability Index
    drift_detected: bool
    severity: str  # "none", "moderate", "severe"
    reference_mean: float | None = None
    current_mean: float | None = None
    reference_std: float | None = None
    current_std: float | None = None
    details: dict[str, Any] = field(default_factory=dict)
    checked_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the drift report to a plain dictionary."""
        return {
            "feature_name": self.feature_name,
            "psi": round(self.psi, 6),
            "drift_detected": self.drift_detected,
            "severity": self.severity,
            "reference_mean": self.reference_mean,
            "current_mean": self.current_mean,
            "reference_std": self.reference_std,
            "current_std": self.current_std,
            "details": self.details,
            "checked_at": self.checked_at.isoformat(),
        }

to_dict()

Serialize the drift report to a plain dictionary.

Source code in src/dataenginex/ml/drift.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def to_dict(self) -> dict[str, Any]:
    """Serialize the drift report to a plain dictionary.

    ``psi`` is rounded to 6 decimal places and ``checked_at`` becomes an
    ISO-8601 string.  ``details`` is deep-copied so that mutating the returned
    dictionary cannot change the report itself.
    """
    from copy import deepcopy  # noqa: PLC0415

    return {
        "feature_name": self.feature_name,
        "psi": round(self.psi, 6),
        "drift_detected": self.drift_detected,
        "severity": self.severity,
        "reference_mean": self.reference_mean,
        "current_mean": self.current_mean,
        "reference_std": self.reference_std,
        "current_std": self.current_std,
        "details": deepcopy(self.details),
        "checked_at": self.checked_at.isoformat(),
    }

MLflowModelRegistry

MLflow-backed model registry compatible with ModelRegistry.

Parameters

tracking_uri: MLflow tracking server URI. Defaults to MLFLOW_TRACKING_URI env var or http://localhost:5000.

Source code in src/dataenginex/ml/mlflow_registry.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
class MLflowModelRegistry:
    """MLflow-backed model registry compatible with ``ModelRegistry``.

    Lifecycle stages are stored as MLflow registered-model aliases (via
    ``_ALIAS_MAP`` / ``_REVERSE_ALIAS_MAP``); a version carrying none of the
    DEX-managed aliases is reported as ``ModelStage.DEVELOPMENT``.

    Parameters
    ----------
    tracking_uri:
        MLflow tracking server URI.  Defaults to ``MLFLOW_TRACKING_URI``
        env var or ``http://localhost:5000``.
    """

    def __init__(self, tracking_uri: str = _DEFAULT_TRACKING_URI) -> None:
        self._tracking_uri = tracking_uri
        self._client = _get_client(tracking_uri)
        logger.info("mlflow registry connected", uri=tracking_uri)

    # -- registration --------------------------------------------------------

    def register(self, artifact: ModelArtifact) -> ModelArtifact:
        """Register a model version in MLflow and log its run metadata.

        Starts a new MLflow run that records ``artifact.parameters``,
        ``artifact.metrics`` and ``dex.*`` tags, then registers the model URI
        under ``artifact.name``.  The input *artifact* is returned unchanged.

        Raises
        ------
        MLflowRegistryError
            If any MLflow call fails.
        """
        import mlflow  # noqa: PLC0415

        mlflow.set_tracking_uri(self._tracking_uri)

        try:
            with mlflow.start_run(run_name=f"{artifact.name}_v{artifact.version}") as run:
                mlflow.log_params(artifact.parameters)
                mlflow.log_metrics(artifact.metrics)
                mlflow.set_tags(
                    {
                        "dex.version": artifact.version,
                        "dex.description": artifact.description,
                        **{f"dex.tag.{t}": "true" for t in artifact.tags},
                    }
                )

                # Prefer an explicit artifact_path; otherwise point at the
                # run's ``model`` artifact directory.
                model_uri = artifact.artifact_path or f"runs:/{run.info.run_id}/model"
                mv = mlflow.register_model(model_uri=model_uri, name=artifact.name)
        except Exception as exc:
            raise MLflowRegistryError(
                f"Failed to register {artifact.name!r} in MLflow: {exc}"
            ) from exc

        # Keyword fields, matching every other log call in this module (the
        # previous %-style positional args were not interpolated by this logger).
        logger.info(
            "model registered",
            name=artifact.name,
            version=artifact.version,
            mlflow_version=mv.version,
        )
        return artifact

    # -- queries -------------------------------------------------------------

    def get(self, name: str, version: str) -> ModelArtifact | None:
        """Fetch a model version from MLflow.

        Returns ``None`` if the MLflow lookup raises for any reason.
        """
        try:
            mv = self._client.get_model_version(name=name, version=version)
        except Exception:  # noqa: BLE001
            return None

        return self._mv_to_artifact(mv)

    def get_latest(self, name: str) -> ModelArtifact | None:
        """Return the highest version number for *name* regardless of stage.

        Returns ``None`` if the search fails or finds no versions.
        """
        # NOTE(review): *name* is interpolated unescaped; a name containing a
        # single quote produces an invalid filter string.
        try:
            versions = self._client.search_model_versions(f"name='{name}'")
        except Exception:  # noqa: BLE001
            return None

        if not versions:
            return None

        # Versions are numeric strings; compare as ints so "10" beats "9".
        latest = max(versions, key=lambda v: int(v.version))
        return self._mv_to_artifact(latest)

    def get_production(self, name: str) -> ModelArtifact | None:
        """Return the model currently aliased as ``production``, or ``None``."""
        try:
            mv = self._client.get_model_version_by_alias(name, "production")
        except Exception:  # noqa: BLE001
            return None

        return self._mv_to_artifact(mv)

    def list_models(self) -> list[str]:
        """Return all registered model names.

        ``MlflowClient.search_registered_models`` returns one page per call,
        so follow the page token until the last page instead of silently
        truncating large registries.

        Raises
        ------
        MLflowRegistryError
            If any MLflow search call fails.
        """
        names: list[str] = []
        try:
            page = self._client.search_registered_models()
            while True:
                names.extend(m.name for m in page)
                token = getattr(page, "token", None)
                if not token:
                    break
                page = self._client.search_registered_models(page_token=token)
        except Exception as exc:  # noqa: BLE001
            raise MLflowRegistryError(f"Failed to list models: {exc}") from exc
        return names

    def list_versions(self, name: str) -> list[str]:
        """Return all version strings for *name*.

        Raises
        ------
        MLflowRegistryError
            If the MLflow search fails.
        """
        try:
            versions = self._client.search_model_versions(f"name='{name}'")
            return [v.version for v in versions]
        except Exception as exc:  # noqa: BLE001
            raise MLflowRegistryError(f"Failed to list versions for {name!r}: {exc}") from exc

    # -- promotion -----------------------------------------------------------

    def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
        """Transition a model version to the target stage via MLflow aliases.

        ``DEVELOPMENT`` removes every DEX-managed alias from the version; any
        other stage sets the matching alias on it.

        Raises
        ------
        MLflowRegistryError
            If any MLflow call fails, including re-reading the updated version.
        """
        try:
            if target_stage == ModelStage.DEVELOPMENT:
                # Remove any DEX-managed aliases from this version
                current = self._client.get_model_version(name, version)
                for alias in list(getattr(current, "aliases", [])):
                    if alias in _REVERSE_ALIAS_MAP:
                        self._client.delete_registered_model_alias(name, alias)
            else:
                alias = _ALIAS_MAP[target_stage]
                self._client.set_registered_model_alias(name, alias, version)
            # Re-read inside the try so a failure here is also reported as
            # MLflowRegistryError rather than a raw client exception.
            mv = self._client.get_model_version(name, version)
        except Exception as exc:
            raise MLflowRegistryError(
                f"Failed to promote {name!r} v{version} to {target_stage}: {exc}"
            ) from exc

        logger.info("model promoted", name=name, version=version, stage=str(target_stage))
        return self._mv_to_artifact(mv)

    # -- helpers -------------------------------------------------------------

    def _mv_to_artifact(self, mv: Any) -> ModelArtifact:
        """Convert an MLflow ModelVersion object to a ``ModelArtifact``."""
        aliases: list[str] = list(getattr(mv, "aliases", []))

        # Resolve stage from aliases: production > staging > archived > development
        stage = ModelStage.DEVELOPMENT
        for alias in aliases:
            candidate = _REVERSE_ALIAS_MAP.get(alias)
            if candidate == ModelStage.PRODUCTION:
                stage = ModelStage.PRODUCTION
                break
            if candidate == ModelStage.STAGING:
                stage = ModelStage.STAGING
            elif candidate == ModelStage.ARCHIVED and stage == ModelStage.DEVELOPMENT:
                stage = ModelStage.ARCHIVED

        # MLflow timestamps are epoch milliseconds.
        creation_ts = getattr(mv, "creation_timestamp", None)
        created_at = (
            datetime.fromtimestamp(creation_ts / 1000, tz=UTC)
            if creation_ts
            else datetime.now(tz=UTC)
        )

        last_updated_ts = getattr(mv, "last_updated_timestamp", None)
        promoted_at = (
            datetime.fromtimestamp(last_updated_ts / 1000, tz=UTC) if last_updated_ts else None
        )

        # MLflow's ModelVersion.tags is a {key: value} dict, whose iteration
        # yields plain strings (``t.key`` would raise AttributeError).  Also
        # accept an iterable of tag entities exposing ``.key``.
        raw_tags = getattr(mv, "tags", None) or {}
        if isinstance(raw_tags, dict):
            tags = list(raw_tags)
        else:
            tags = [t.key for t in raw_tags]

        return ModelArtifact(
            name=mv.name,
            version=mv.version,
            stage=stage,
            artifact_path=getattr(mv, "source", "") or "",
            description=getattr(mv, "description", "") or "",
            tags=tags,
            created_at=created_at,
            promoted_at=promoted_at,
        )

register(artifact)

Register a model version in MLflow and log its run metadata.

Source code in src/dataenginex/ml/mlflow_registry.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def register(self, artifact: ModelArtifact) -> ModelArtifact:
    """Register a model version in MLflow and log its run metadata.

    Starts a new MLflow run that records ``artifact.parameters``,
    ``artifact.metrics`` and ``dex.*`` tags, then registers the model URI under
    ``artifact.name``.  The input *artifact* is returned unchanged.

    Raises
    ------
    MLflowRegistryError
        If any MLflow call fails.
    """
    import mlflow  # noqa: PLC0415

    mlflow.set_tracking_uri(self._tracking_uri)

    try:
        with mlflow.start_run(run_name=f"{artifact.name}_v{artifact.version}") as run:
            mlflow.log_params(artifact.parameters)
            mlflow.log_metrics(artifact.metrics)
            mlflow.set_tags(
                {
                    "dex.version": artifact.version,
                    "dex.description": artifact.description,
                    **{f"dex.tag.{t}": "true" for t in artifact.tags},
                }
            )

            # Prefer an explicit artifact_path; otherwise point at the run's
            # ``model`` artifact directory.
            model_uri = artifact.artifact_path or f"runs:/{run.info.run_id}/model"
            mv = mlflow.register_model(model_uri=model_uri, name=artifact.name)
    except Exception as exc:
        raise MLflowRegistryError(
            f"Failed to register {artifact.name!r} in MLflow: {exc}"
        ) from exc

    # Keyword fields, matching the module's other log calls (this logger takes
    # kwargs, so %-style positional args are not interpolated).
    logger.info(
        "model registered",
        name=artifact.name,
        version=artifact.version,
        mlflow_version=mv.version,
    )
    return artifact

get(name, version)

Fetch a model version from MLflow.

Source code in src/dataenginex/ml/mlflow_registry.py
117
118
119
120
121
122
123
124
def get(self, name: str, version: str) -> ModelArtifact | None:
    """Look up a single model version in MLflow.

    Any exception raised by the MLflow client is swallowed and reported as
    ``None``; otherwise the version is converted with ``_mv_to_artifact``.
    """
    try:
        found = self._client.get_model_version(name=name, version=version)
    except Exception:  # noqa: BLE001
        return None
    else:
        return self._mv_to_artifact(found)

get_latest(name)

Return the highest version number for name regardless of stage.

Source code in src/dataenginex/ml/mlflow_registry.py
126
127
128
129
130
131
132
133
134
135
136
137
def get_latest(self, name: str) -> ModelArtifact | None:
    """Pick the numerically highest version of *name*, ignoring stage.

    Returns ``None`` when the search raises or yields no versions.
    """
    try:
        candidates = self._client.search_model_versions(f"name='{name}'")
    except Exception:  # noqa: BLE001
        return None

    if candidates:
        # Version identifiers are numeric strings; compare them as integers.
        newest = max(candidates, key=lambda mv: int(mv.version))
        return self._mv_to_artifact(newest)
    return None

get_production(name)

Return the model currently aliased as production.

Source code in src/dataenginex/ml/mlflow_registry.py
139
140
141
142
143
144
145
146
def get_production(self, name: str) -> ModelArtifact | None:
    """Resolve the version of *name* that holds the ``production`` alias.

    Returns ``None`` when the alias lookup raises (e.g. no such alias).
    """
    try:
        prod_version = self._client.get_model_version_by_alias(name, "production")
    except Exception:  # noqa: BLE001
        return None
    else:
        return self._mv_to_artifact(prod_version)

list_models()

Return all registered model names.

Source code in src/dataenginex/ml/mlflow_registry.py
148
149
150
151
152
153
154
def list_models(self) -> list[str]:
    """Return all registered model names.

    ``MlflowClient.search_registered_models`` returns one page per call, so
    follow the page token until the last page instead of silently truncating
    large registries.

    Raises
    ------
    MLflowRegistryError
        If any MLflow search call fails.
    """
    names: list[str] = []
    try:
        page = self._client.search_registered_models()
        while True:
            names.extend(m.name for m in page)
            token = getattr(page, "token", None)
            if not token:
                break
            page = self._client.search_registered_models(page_token=token)
    except Exception as exc:  # noqa: BLE001
        raise MLflowRegistryError(f"Failed to list models: {exc}") from exc
    return names

list_versions(name)

Return all version strings for name.

Source code in src/dataenginex/ml/mlflow_registry.py
156
157
158
159
160
161
162
def list_versions(self, name: str) -> list[str]:
    """List every version string MLflow knows for *name*.

    Raises
    ------
    MLflowRegistryError
        If the MLflow search (or iterating its result) fails.
    """
    query = f"name='{name}'"
    try:
        return [mv.version for mv in self._client.search_model_versions(query)]
    except Exception as exc:  # noqa: BLE001
        raise MLflowRegistryError(f"Failed to list versions for {name!r}: {exc}") from exc

promote(name, version, target_stage)

Transition a model version to the target stage via MLflow aliases.

Source code in src/dataenginex/ml/mlflow_registry.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
    """Transition a model version to the target stage via MLflow aliases.

    ``DEVELOPMENT`` removes every DEX-managed alias from the version; any other
    stage sets the matching alias from ``_ALIAS_MAP`` on it.

    Raises
    ------
    MLflowRegistryError
        If any MLflow call fails, including re-reading the updated version.
    """
    try:
        if target_stage == ModelStage.DEVELOPMENT:
            # Remove any DEX-managed aliases from this version
            current = self._client.get_model_version(name, version)
            for alias in list(getattr(current, "aliases", [])):
                if alias in _REVERSE_ALIAS_MAP:
                    self._client.delete_registered_model_alias(name, alias)
        else:
            alias = _ALIAS_MAP[target_stage]
            self._client.set_registered_model_alias(name, alias, version)
        # Re-read inside the try so a failure here is also reported as
        # MLflowRegistryError rather than a raw client exception.
        mv = self._client.get_model_version(name, version)
    except Exception as exc:
        raise MLflowRegistryError(
            f"Failed to promote {name!r} v{version} to {target_stage}: {exc}"
        ) from exc

    logger.info("model promoted", name=name, version=version, stage=str(target_stage))
    return self._mv_to_artifact(mv)

MLflowRegistryError

Bases: RuntimeError

Raised when the MLflow server is unreachable or returns an error.

Source code in src/dataenginex/ml/mlflow_registry.py
45
46
class MLflowRegistryError(RuntimeError):
    """Signal a failed interaction with the MLflow server.

    Used when the server is unreachable or an MLflow call returns an error;
    call sites typically chain the original exception with ``raise ... from``.
    """

    pass

ModelArtifact dataclass

Registry entry for a model version.

Attributes:

Name Type Description
name str

Model name (e.g. "churn_classifier").

version str

Semantic version string.

stage ModelStage

Current lifecycle stage.

artifact_path str

File path to the serialised model.

metrics dict[str, float]

Training/evaluation metrics.

parameters dict[str, Any]

Hyper-parameters used for training.

description str

Free-text description.

created_at datetime

When the artifact was registered.

promoted_at datetime | None

When the artifact was last promoted.

tags list[str]

Arbitrary labels for filtering.

Source code in src/dataenginex/ml/registry.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@dataclass
class ModelArtifact:
    """One registered version of a model together with its lifecycle metadata.

    Attributes:
        name: Model name (e.g. ``"churn_classifier"``).
        version: Version string identifying this artifact within *name*.
        stage: Current lifecycle stage (``ModelStage.DEVELOPMENT`` by default).
        artifact_path: File path to the serialised model (empty when unset).
        metrics: Training/evaluation metrics.
        parameters: Hyper-parameters used for training.
        description: Free-text description.
        created_at: When the artifact was registered (UTC ``now`` by default).
        promoted_at: When the artifact was last promoted, if ever.
        tags: Arbitrary labels for filtering.
    """

    name: str
    version: str
    stage: ModelStage = ModelStage.DEVELOPMENT
    artifact_path: str = ""
    metrics: dict[str, float] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    description: str = ""
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))
    promoted_at: datetime | None = None
    tags: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly copy: enum stage as its value, datetimes as ISO strings."""
        payload = asdict(self)
        payload.update(
            stage=self.stage.value,
            created_at=self.created_at.isoformat(),
            promoted_at=None if self.promoted_at is None else self.promoted_at.isoformat(),
        )
        return payload

to_dict()

Serialize the model artifact metadata to a plain dictionary.

Source code in src/dataenginex/ml/registry.py
65
66
67
68
69
70
71
def to_dict(self) -> dict[str, Any]:
    """Serialize the model artifact metadata to a plain dictionary."""
    d = asdict(self)
    d["stage"] = self.stage.value
    d["created_at"] = self.created_at.isoformat()
    d["promoted_at"] = self.promoted_at.isoformat() if self.promoted_at else None
    return d

ModelRegistry

JSON-file-backed model registry.

Parameters

persist_path: Path to a JSON file for persistence (optional).

Source code in src/dataenginex/ml/registry.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class ModelRegistry:
    """JSON-file-backed model registry.

    Artifacts are held in memory as ``name → version → ModelArtifact``.
    ``register`` and ``promote`` run under an internal lock and, when
    *persist_path* is set, rewrite the JSON file afterwards.

    Parameters
    ----------
    persist_path:
        Path to a JSON file for persistence (optional).  An existing file is
        loaded on construction.
    """

    def __init__(self, persist_path: str | Path | None = None) -> None:
        # name → version → artifact (dict insertion order = registration order)
        self._models: dict[str, dict[str, ModelArtifact]] = {}
        self._persist_path = Path(persist_path) if persist_path else None
        self._lock = threading.Lock()
        if self._persist_path and self._persist_path.exists():
            self._load()

    # -- registration --------------------------------------------------------

    def register(self, artifact: ModelArtifact) -> ModelArtifact:
        """Register a new model version.

        Raises
        ------
        ValueError
            If ``artifact.version`` is already registered for ``artifact.name``.
        """
        with self._lock:
            versions = self._models.setdefault(artifact.name, {})
            if artifact.version in versions:
                raise ValueError(
                    f"Model {artifact.name!r} version {artifact.version} already registered"
                )
            versions[artifact.version] = artifact
            # Keyword fields, consistent with the other log calls in this class.
            logger.info(
                "model registered",
                name=artifact.name,
                version=artifact.version,
                stage=artifact.stage.value,
            )
            self._save()
        return artifact

    # -- queries -------------------------------------------------------------

    def get(self, name: str, version: str) -> ModelArtifact | None:
        """Return the artifact for *name* at *version*, or ``None``."""
        return self._models.get(name, {}).get(version)

    def get_latest(self, name: str) -> ModelArtifact | None:
        """Return the most recently registered version of *name*."""
        versions = self._models.get(name)
        if not versions:
            return None
        # Last inserted value, without copying every value into a list.
        return next(reversed(versions.values()))

    def get_production(self, name: str) -> ModelArtifact | None:
        """Return the model currently in production stage."""
        for art in self._models.get(name, {}).values():
            if art.stage == ModelStage.PRODUCTION:
                return art
        return None

    def list_models(self) -> list[str]:
        """Return all registered model names."""
        return list(self._models.keys())

    def list_versions(self, name: str) -> list[str]:
        """Return all version strings registered for *name*."""
        return list(self._models.get(name, {}).keys())

    # -- promotion -----------------------------------------------------------

    def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
        """Promote a model version to a new stage.

        If promoting to ``production``, any existing production model is
        automatically archived.

        Raises
        ------
        ValueError
            If *name* / *version* is not registered.
        """
        with self._lock:
            artifact = self.get(name, version)
            if artifact is None:
                raise ValueError(f"Model {name!r} version {version} not found")

            if target_stage == ModelStage.PRODUCTION:
                # Archive the current production model
                current = self.get_production(name)
                if current and current.version != version:
                    current.stage = ModelStage.ARCHIVED
                    logger.info("model archived", name=name, version=current.version)

            artifact.stage = target_stage
            artifact.promoted_at = datetime.now(tz=UTC)
            logger.info("model promoted", name=name, version=version, stage=target_stage.value)
            self._save()
        return artifact

    # -- persistence ---------------------------------------------------------

    def _save(self) -> None:
        """Write every artifact to ``persist_path`` as JSON (no-op without a path).

        The JSON goes to a sibling ``.tmp`` file that is then renamed over the
        target, so a failure mid-write leaves the previous file intact instead
        of a truncated one.
        """
        if not self._persist_path:
            return
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        data: dict[str, list[dict[str, Any]]] = {
            name: [v.to_dict() for v in versions.values()]
            for name, versions in self._models.items()
        }
        tmp_path = self._persist_path.with_name(self._persist_path.name + ".tmp")
        tmp_path.write_text(json.dumps(data, indent=2, default=str))
        tmp_path.replace(self._persist_path)

    def _load(self) -> None:
        """Populate ``_models`` from ``persist_path`` (the format ``_save`` writes)."""
        if not self._persist_path or not self._persist_path.exists():
            return
        raw = json.loads(self._persist_path.read_text())
        for name, versions in raw.items():
            self._models[name] = {}
            for v in versions:
                # Timestamps are persisted as ISO-8601 strings by ``to_dict``;
                # restore them rather than resetting to "now" / None on reload.
                created_at = self._parse_timestamp(v.pop("created_at", None))
                promoted_at = self._parse_timestamp(v.pop("promoted_at", None))
                v["stage"] = ModelStage(v.get("stage", "development"))
                artifact = ModelArtifact(**v)
                if created_at is not None:
                    artifact.created_at = created_at
                artifact.promoted_at = promoted_at
                self._models[name][v["version"]] = artifact
        logger.info("models loaded", count=len(self._models), path=str(self._persist_path))

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse an ISO-8601 string; return ``None`` for missing or malformed values."""
        if not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return None

register(artifact)

Register a new model version.

Source code in src/dataenginex/ml/registry.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def register(self, artifact: ModelArtifact) -> ModelArtifact:
    """Add *artifact* as a new version of its model and persist the registry.

    The artifact is stored under ``(artifact.name, artifact.version)`` with
    whatever ``stage`` it already carries. The whole operation runs under
    ``self._lock``.

    Raises
    ------
    ValueError
        If that name/version pair is already registered.
    """
    with self._lock:
        name = artifact.name
        version = artifact.version
        by_version = self._models.setdefault(name, {})
        if version in by_version:
            raise ValueError(f"Model {name!r} version {version} already registered")
        by_version[version] = artifact
        logger.info(
            "Registered model %s v%s (stage=%s)",
            name,
            version,
            artifact.stage.value,
        )
        self._save()
    return artifact

get(name, version)

Return the artifact for name at version, or None.

Source code in src/dataenginex/ml/registry.py
113
114
115
def get(self, name: str, version: str) -> ModelArtifact | None:
    """Look up one artifact by model *name* and *version*.

    Returns ``None`` when either the name or the version is unknown.
    """
    versions = self._models.get(name)
    if versions is None:
        return None
    return versions.get(version)

get_latest(name)

Return the most recently registered version of name.

Source code in src/dataenginex/ml/registry.py
117
118
119
120
121
122
def get_latest(self, name: str) -> ModelArtifact | None:
    """Return the most recently inserted version of *name*, or ``None``.

    "Latest" means last in the per-model dict's insertion order, not the
    highest version string.
    """
    versions = self._models.get(name) or {}
    return next(reversed(versions.values()), None)

get_production(name)

Return the model currently in production stage.

Source code in src/dataenginex/ml/registry.py
124
125
126
127
128
129
def get_production(self, name: str) -> ModelArtifact | None:
    """Return the first version of *name* whose stage is PRODUCTION.

    Versions are scanned in insertion order; ``None`` if none qualifies.
    """
    candidates = self._models.get(name, {}).values()
    return next(
        (art for art in candidates if art.stage == ModelStage.PRODUCTION),
        None,
    )

list_models()

Return all registered model names.

Source code in src/dataenginex/ml/registry.py
131
132
133
def list_models(self) -> list[str]:
    """List every registered model name, in insertion order."""
    return [*self._models]

list_versions(name)

Return all version strings registered for name.

Source code in src/dataenginex/ml/registry.py
135
136
137
def list_versions(self, name: str) -> list[str]:
    """List the version keys stored for *name* (``[]`` if the name is unknown)."""
    known = self._models.get(name)
    return [] if known is None else [*known]

promote(name, version, target_stage)

Promote a model version to a new stage.

If promoting to production, any existing production model is automatically archived.

Source code in src/dataenginex/ml/registry.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
    """Promote a model version to a new stage.

    If promoting to ``production``, any existing production model is
    automatically archived.
    """
    with self._lock:
        artifact = self.get(name, version)
        if artifact is None:
            raise ValueError(f"Model {name!r} version {version} not found")

        if target_stage == ModelStage.PRODUCTION:
            # Archive the current production model
            current = self.get_production(name)
            if current and current.version != version:
                current.stage = ModelStage.ARCHIVED
                logger.info("model archived", name=name, version=current.version)

        artifact.stage = target_stage
        artifact.promoted_at = datetime.now(tz=UTC)
        logger.info("model promoted", name=name, version=version, stage=target_stage.value)
        self._save()
    return artifact

ModelStage

Bases: StrEnum

Model lifecycle stages.

Source code in src/dataenginex/ml/registry.py
28
29
30
31
32
33
34
class ModelStage(StrEnum):
    """Model lifecycle stages.

    Members are ``StrEnum`` values, so they compare equal to their lowercase
    strings. The registry's ``_load`` rebuilds them from those strings and
    defaults a missing stage to ``"development"``.
    """

    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"  # target of ``promote``; a prior production version gets archived
    ARCHIVED = "archived"  # assigned by ``promote`` to the version being replaced in production

ModelServer

Registry-aware model server.

Loads a model from the ModelRegistry and serves predictions via the predict method.

Parameters

registry: A ModelRegistry instance (from dataenginex.ml.registry).

Source code in src/dataenginex/ml/serving.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class ModelServer:
    """Registry-aware model server.

    Loads a model from the ``ModelRegistry`` and serves predictions via
    the ``predict`` method.

    Parameters
    ----------
    registry:
        A ``ModelRegistry`` instance (from ``dataenginex.ml.registry``).
        Optional; it is only consulted (via ``get_production``) when a
        request does not name a version.
    """

    def __init__(self, registry: Any = None) -> None:
        self._registry = registry
        self._loaded: dict[str, Any] = {}  # "name:version" → model object

    def load_model(self, name: str, version: str, model: Any) -> None:
        """Register a model object for serving.

        Loading the same *name*/*version* again replaces the stored object.

        Parameters
        ----------
        name:
            Model name matching registry entries.
        version:
            Model version.
        model:
            Any object with a ``predict(X)`` method.
        """
        key = f"{name}:{version}"
        self._loaded[key] = model
        logger.info("loaded model for serving", key=key)

    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Run inference for *request*.

        A falsy ``request.version`` is resolved by
        ``_resolve_production_version``. ``latency_ms`` covers version
        resolution, feature conversion and the model's ``predict`` call.

        Raises
        ------
        RuntimeError
            If no version can be resolved, or the resolved version was never
            passed to ``load_model``.
        """
        start = time.perf_counter()

        version = request.version or self._resolve_production_version(request.model_name)
        key = f"{request.model_name}:{version}"

        model = self._loaded.get(key)
        if model is None:
            raise RuntimeError(
                f"Model {request.model_name} v{version} not loaded — call load_model() first"
            )

        # Convert features to the format the model expects
        feature_values = self._features_to_array(request.features)
        raw_predictions = model.predict(feature_values)

        latency = (time.perf_counter() - start) * 1000

        # numpy arrays expose ``tolist``; anything else is materialised as a list.
        predictions: list[Any]
        if hasattr(raw_predictions, "tolist"):
            predictions = raw_predictions.tolist()
        else:
            predictions = list(raw_predictions)

        return PredictionResponse(
            model_name=request.model_name,
            version=version,
            predictions=predictions,
            latency_ms=latency,
            request_id=request.request_id,
        )

    def list_loaded(self) -> list[str]:
        """Return keys of all loaded models."""
        return list(self._loaded.keys())

    # -- helpers -------------------------------------------------------------

    def _resolve_production_version(self, name: str) -> str:
        """Pick the version of *name* to serve when a request omits one.

        Prefers the registry's production artifact. Otherwise it falls back to
        the most recently *inserted* ``load_model`` key for *name*. Re-loading
        an existing key does not move it, because assigning to an existing
        ``dict`` key keeps its original position.

        Raises
        ------
        RuntimeError
            If neither the registry nor the loaded models know *name*.
        """
        if self._registry is not None:
            prod = self._registry.get_production(name)
            if prod:
                return prod.version  # type: ignore[no-any-return]
        prefix = f"{name}:"
        for key in reversed(self._loaded):
            if key.startswith(prefix):
                # Strip only the name prefix. ``key.split(":")[1]`` would
                # truncate versions that themselves contain ":".
                return key[len(prefix):]
        raise RuntimeError(f"No version found for model {name!r}")

    @staticmethod
    def _features_to_array(features: list[dict[str, Any]]) -> list[list[Any]]:
        """Convert list-of-dicts to a 2D list suitable for sklearn ``predict``.

        Column order follows the keys of the *first* row. A key missing from
        a later row becomes ``None``, and keys absent from the first row are
        dropped.
        """
        if not features:
            return []
        keys = list(features[0].keys())
        return [[row.get(k) for k in keys] for row in features]

load_model(name, version, model)

Register a model object for serving.

Parameters

name: Model name matching registry entries.
version: Model version.
model: Any object with a predict(X) method.

Source code in src/dataenginex/ml/serving.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def load_model(self, name: str, version: str, model: Any) -> None:
    """Make *model* servable under the ``"name:version"`` key.

    Parameters
    ----------
    name:
        Model name matching registry entries.
    version:
        Model version.
    model:
        Any object with a ``predict(X)`` method. It replaces any object
        previously loaded under the same name and version.
    """
    slot = f"{name}:{version}"
    self._loaded[slot] = model
    logger.info("loaded model for serving", key=slot)

predict(request)

Run inference for request.

Source code in src/dataenginex/ml/serving.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def predict(self, request: PredictionRequest) -> PredictionResponse:
    """Serve one inference call described by *request*.

    A falsy ``request.version`` is resolved via
    ``self._resolve_production_version``. The reported latency (ms) spans
    resolution, feature conversion and ``model.predict``.

    Raises
    ------
    RuntimeError
        If the resolved ``name:version`` key was never loaded.
    """
    t0 = time.perf_counter()

    name = request.model_name
    resolved = request.version or self._resolve_production_version(name)

    model = self._loaded.get(f"{name}:{resolved}")
    if model is None:
        raise RuntimeError(
            f"Model {name} v{resolved} not loaded — call load_model() first"
        )

    raw = model.predict(self._features_to_array(request.features))
    elapsed_ms = (time.perf_counter() - t0) * 1000

    # numpy-style results expose ``tolist``; other iterables are copied into a list.
    preds: list[Any] = raw.tolist() if hasattr(raw, "tolist") else list(raw)

    return PredictionResponse(
        model_name=name,
        version=resolved,
        predictions=preds,
        latency_ms=elapsed_ms,
        request_id=request.request_id,
    )

list_loaded()

Return keys of all loaded models.

Source code in src/dataenginex/ml/serving.py
139
140
141
def list_loaded(self) -> list[str]:
    """List the ``"name:version"`` keys currently loaded, in insertion order."""
    return [*self._loaded]

PredictionRequest dataclass

Input to the serving layer.

Attributes:

Name Type Description
model_name str

Name of the model to invoke.

version str | None

Model version (None resolves to the registry's production version, or otherwise to the most recently loaded version).

features list[dict[str, Any]]

List of feature dicts — each dict is one sample.

request_id str

Caller-provided request ID for tracing.

Source code in src/dataenginex/ml/serving.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@dataclass
class PredictionRequest:
    """Input to the serving layer.

    Attributes:
        model_name: Name of the model to invoke.
        version: Model version. ``None`` (or any falsy value) is resolved by
            ``ModelServer.predict``: to the registry's production version when
            available, otherwise to the most recently loaded version.
        features: List of feature dicts — each dict is one sample. The server
            orders columns by the keys of the first dict.
        request_id: Caller-provided request ID for tracing.
    """

    model_name: str
    version: str | None = None  # None → use production model
    features: list[dict[str, Any]] = field(default_factory=list)
    request_id: str = ""  # echoed back on the response

PredictionResponse dataclass

Output from the serving layer.

Attributes:

Name Type Description
model_name str

Name of the model that produced predictions.

version str

Version of the model used.

predictions list[Any]

List of prediction values.

latency_ms float

Inference latency in milliseconds.

request_id str

Echoed request ID for tracing.

served_at datetime

Timestamp of the prediction.

Source code in src/dataenginex/ml/serving.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@dataclass
class PredictionResponse:
    """Output from the serving layer.

    Attributes:
        model_name: Name of the model that produced predictions.
        version: Version of the model used.
        predictions: List of prediction values.
        latency_ms: Inference latency in milliseconds.
        request_id: Echoed request ID for tracing.
        served_at: Timestamp of the prediction.
    """

    model_name: str
    version: str
    predictions: list[Any] = field(default_factory=list)
    latency_ms: float = 0.0
    request_id: str = ""
    served_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the prediction response to a plain dictionary."""
        return {
            "model_name": self.model_name,
            "version": self.version,
            "predictions": self.predictions,
            "latency_ms": round(self.latency_ms, 2),
            "request_id": self.request_id,
            "served_at": self.served_at.isoformat(),
        }

to_dict()

Serialize the prediction response to a plain dictionary.

Source code in src/dataenginex/ml/serving.py
62
63
64
65
66
67
68
69
70
71
def to_dict(self) -> dict[str, Any]:
    """Return the response as a plain ``dict``.

    ``latency_ms`` is rounded to two decimals and ``served_at`` becomes an
    ISO-8601 string. Other fields are passed through unchanged.
    """
    payload: dict[str, Any] = dict(
        model_name=self.model_name,
        version=self.version,
        predictions=self.predictions,
    )
    payload["latency_ms"] = round(self.latency_ms, 2)
    payload["request_id"] = self.request_id
    payload["served_at"] = self.served_at.isoformat()
    return payload

BaseTrainer

Bases: ABC

Abstract base class for model trainers.

Source code in src/dataenginex/ml/training.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class BaseTrainer(ABC):
    """Abstract base class for model trainers.

    Stores the identifying ``model_name`` and ``version``. Subclasses
    implement train / evaluate / predict / save / load.
    """

    def __init__(self, model_name: str, version: str = "1.0.0") -> None:
        self.model_name = model_name
        self.version = version

    @abstractmethod
    def train(
        self,
        X_train: Any,  # noqa: N803  (N803 is reported on the parameter's own line)
        y_train: Any,
        **params: Any,
    ) -> TrainingResult:
        """Train the model and return metrics."""
        ...

    @abstractmethod
    def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
        """Evaluate the model on test data."""
        ...

    @abstractmethod
    def predict(self, X: Any) -> Any:  # noqa: N803
        """Generate predictions."""
        ...

    @abstractmethod
    def save(self, path: str) -> str:
        """Persist the model to *path* and return the artifact path."""
        ...

    @abstractmethod
    def load(
        self,
        path: str,
        *,
        extra_modules: frozenset[str] | None = None,
    ) -> None:
        """Load a previously saved model from *path*."""
        ...

train(X_train, y_train, **params) abstractmethod

Train the model and return metrics.

Source code in src/dataenginex/ml/training.py
131
132
133
134
135
136
137
138
139
@abstractmethod
def train(
    self,
    X_train: Any,  # noqa: N803  (N803 is reported on the parameter's own line)
    y_train: Any,
    **params: Any,
) -> TrainingResult:
    """Train the model and return metrics.

    Implementations fit on *X_train*/*y_train*, optionally applying extra
    hyper-parameters from ``**params``.
    """
    ...

evaluate(X_test, y_test) abstractmethod

Evaluate the model on test data.

Source code in src/dataenginex/ml/training.py
141
142
143
144
@abstractmethod
def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
    """Score the trained model on held-out data and return named metrics."""

predict(X) abstractmethod

Generate predictions.

Source code in src/dataenginex/ml/training.py
146
147
148
149
@abstractmethod
def predict(self, X: Any) -> Any:  # noqa: N803
    """Produce model outputs for the samples in *X*."""

save(path) abstractmethod

Persist the model to path and return the artifact path.

Source code in src/dataenginex/ml/training.py
151
152
153
154
@abstractmethod
def save(self, path: str) -> str:
    """Write the model to *path*; return where the artifact ended up."""

load(path, *, extra_modules=None) abstractmethod

Load a previously saved model from path.

Source code in src/dataenginex/ml/training.py
156
157
158
159
160
161
162
163
164
@abstractmethod
def load(
    self,
    path: str,
    *,
    extra_modules: frozenset[str] | None = None,
) -> None:
    """Load a previously saved model from *path*."""
    ...

SklearnTrainer

Bases: BaseTrainer

scikit-learn model trainer.

Works with any sklearn estimator (or pipeline) that implements fit, predict, and score.

Parameters

model_name: Name used in the model registry.
version: Semantic version string.
estimator: An sklearn estimator instance (e.g. RandomForestClassifier()).

Source code in src/dataenginex/ml/training.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
class SklearnTrainer(BaseTrainer):
    """scikit-learn model trainer.

    Works with any sklearn estimator (or pipeline) that implements
    ``fit``, ``predict``, and ``score``.

    Parameters
    ----------
    model_name:
        Name used in model registry.
    version:
        Semantic version string.
    estimator:
        An sklearn estimator instance (e.g. ``RandomForestClassifier()``).
    """

    def __init__(
        self,
        model_name: str,
        version: str = "1.0.0",
        estimator: Any = None,
    ) -> None:
        super().__init__(model_name, version)
        self.estimator = estimator
        # Flipped to True by ``train``/``load``; guards evaluate/predict/save.
        self._is_fitted = False

    def train(
        self,
        X_train: Any,  # noqa: N803  (N803 is reported on the parameter's own line)
        y_train: Any,
        **params: Any,
    ) -> TrainingResult:
        """Fit the estimator on *X_train*/*y_train* and return metrics.

        ``**params`` are applied with ``estimator.set_params`` before fitting.
        The result carries the training-set ``score`` rounded to 4 decimals,
        the estimator's ``get_params()``, and the fit duration (the scoring
        pass is not included).

        Raises
        ------
        RuntimeError
            If no estimator was supplied.
        """
        if self.estimator is None:
            raise RuntimeError("No estimator provided to SklearnTrainer")

        # Apply params
        if params:
            self.estimator.set_params(**params)

        start = time.perf_counter()
        self.estimator.fit(X_train, y_train)
        duration = time.perf_counter() - start
        self._is_fitted = True

        # Compute training score
        train_score = float(self.estimator.score(X_train, y_train))
        metrics = {"train_score": round(train_score, 4)}

        logger.info(
            "Trained %s v%s in %.2fs (train_score=%.4f)",
            self.model_name,
            self.version,
            duration,
            train_score,
        )

        return TrainingResult(
            model_name=self.model_name,
            version=self.version,
            metrics=metrics,
            parameters=self.estimator.get_params(),
            duration_seconds=duration,
        )

    def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
        """Score the fitted model on *X_test*/*y_test* and return metrics.

        Always returns ``test_score`` (the estimator's own ``score``). When
        ``sklearn.metrics`` is importable and the targets are classification
        labels, weighted ``precision``, ``recall`` and ``f1`` are added too.
        Values are rounded to 4 decimals.

        Raises
        ------
        RuntimeError
            If the model has not been trained or loaded.
        """
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")

        test_score = float(self.estimator.score(X_test, y_test))
        predictions = self.estimator.predict(X_test)

        metrics: dict[str, float] = {"test_score": round(test_score, 4)}

        try:
            from sklearn.metrics import (  # type: ignore[import-untyped,import-not-found]
                f1_score,
                precision_score,
                recall_score,
            )
        except ImportError:
            logger.debug(
                "sklearn.metrics not available — skipping precision/recall/f1",
            )
            return metrics

        scorers = {
            "precision": precision_score,
            "recall": recall_score,
            "f1": f1_score,
        }
        try:
            extra = {
                key: round(
                    float(
                        scorer(
                            y_test,
                            predictions,
                            average="weighted",
                            zero_division=0,
                        ),
                    ),
                    4,
                )
                for key, scorer in scorers.items()
            }
        except ValueError:
            # sklearn rejects non-classification targets (e.g. continuous
            # regression outputs) with ValueError. These metrics are optional,
            # so keep ``test_score`` rather than failing the whole evaluation.
            logger.debug(
                "classification metrics not applicable to these targets — "
                "skipping precision/recall/f1",
            )
            return metrics

        metrics.update(extra)
        return metrics

    def predict(self, X: Any) -> Any:  # noqa: N803
        """Generate predictions for *X* using the fitted estimator."""
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")
        return self.estimator.predict(X)

    def save(self, path: str) -> str:
        """Pickle the fitted model, write an HMAC signature, and metadata.

        Two sidecars share the stem of *path*: ``.sig`` (HMAC of the pickled
        bytes, checked by ``load``) and ``.json`` (name, version, UTC save
        time). If *path* itself ends in ``.sig`` or ``.json``, the matching
        sidecar overwrites the model file.
        """
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")

        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)

        model_bytes = pickle.dumps(self.estimator)
        out.write_bytes(model_bytes)

        # HMAC sidecar — verifies integrity on load
        sig_path = out.with_suffix(".sig")
        sig_path.write_text(_hmac_sign(model_bytes))

        # Save metadata alongside
        meta = out.with_suffix(".json")
        meta.write_text(
            json.dumps(
                {
                    "model_name": self.model_name,
                    "version": self.version,
                    "saved_at": datetime.now(tz=UTC).isoformat(),
                }
            )
        )

        logger.info("model saved", name=self.model_name, path=str(out))
        return str(out)

    def load(
        self,
        path: str,
        *,
        extra_modules: frozenset[str] | None = None,
    ) -> None:
        """Load a pickled model with HMAC verification and safe unpickling.

        Args:
            path: Filesystem path to the ``.pkl`` artifact.
            extra_modules: Additional top-level module names to allow
                during unpickling (e.g. ``frozenset({"tests"})`` for
                test-only estimators).

        Raises:
            ValueError: If a ``.sig`` sidecar exists and does not match.

        Note: an artifact *without* a ``.sig`` sidecar is still unpickled
        (through the restricted unpickler); only a warning is logged.
        """
        artifact = Path(path)
        data = artifact.read_bytes()

        # Verify HMAC signature if sidecar exists
        sig_path = artifact.with_suffix(".sig")
        if sig_path.exists():
            expected = sig_path.read_text().strip()
            if not _hmac_verify(data, expected):
                msg = (
                    f"HMAC verification failed for {path}. "
                    "The model file may have been tampered with."
                )
                raise ValueError(msg)
        else:
            logger.warning(
                "No .sig sidecar for %s — skipping HMAC check",
                path,
            )

        # Safe unpickle — restricted to sklearn/numpy namespaces
        self.estimator = _SafeUnpickler(
            io.BytesIO(data),
            extra_modules=extra_modules,
        ).load()
        self._is_fitted = True
        logger.info("model loaded", name=self.model_name, path=str(path))

train(X_train, y_train, **params)

Fit the estimator on X_train/y_train and return metrics.

Source code in src/dataenginex/ml/training.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def train(
    self,
    X_train: Any,  # noqa: N803  (N803 is reported on the parameter's own line)
    y_train: Any,
    **params: Any,
) -> TrainingResult:
    """Fit the estimator on *X_train*/*y_train* and return metrics.

    ``**params`` are applied with ``estimator.set_params`` before fitting.
    The result carries the training-set ``score`` rounded to 4 decimals, the
    estimator's ``get_params()``, and the fit duration (the scoring pass is
    not included).

    Raises
    ------
    RuntimeError
        If no estimator was supplied.
    """
    if self.estimator is None:
        raise RuntimeError("No estimator provided to SklearnTrainer")

    # Apply params
    if params:
        self.estimator.set_params(**params)

    start = time.perf_counter()
    self.estimator.fit(X_train, y_train)
    duration = time.perf_counter() - start
    self._is_fitted = True

    # Compute training score
    train_score = float(self.estimator.score(X_train, y_train))
    metrics = {"train_score": round(train_score, 4)}

    logger.info(
        "Trained %s v%s in %.2fs (train_score=%.4f)",
        self.model_name,
        self.version,
        duration,
        train_score,
    )

    return TrainingResult(
        model_name=self.model_name,
        version=self.version,
        metrics=metrics,
        parameters=self.estimator.get_params(),
        duration_seconds=duration,
    )

evaluate(X_test, y_test)

Score the fitted model on X_test/y_test and return metrics.

Source code in src/dataenginex/ml/training.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
    """Score the fitted model on *X_test*/*y_test* and return metrics.

    Always returns ``test_score`` (the estimator's own ``score``). When
    ``sklearn.metrics`` is importable and the targets are classification
    labels, weighted ``precision``, ``recall`` and ``f1`` are added too.
    Values are rounded to 4 decimals.

    Raises
    ------
    RuntimeError
        If the model has not been trained or loaded.
    """
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")

    test_score = float(self.estimator.score(X_test, y_test))
    predictions = self.estimator.predict(X_test)

    metrics: dict[str, float] = {"test_score": round(test_score, 4)}

    try:
        from sklearn.metrics import (  # type: ignore[import-untyped,import-not-found]
            f1_score,
            precision_score,
            recall_score,
        )
    except ImportError:
        logger.debug(
            "sklearn.metrics not available — skipping precision/recall/f1",
        )
        return metrics

    scorers = {
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }
    try:
        extra = {
            key: round(
                float(
                    scorer(
                        y_test,
                        predictions,
                        average="weighted",
                        zero_division=0,
                    ),
                ),
                4,
            )
            for key, scorer in scorers.items()
        }
    except ValueError:
        # sklearn rejects non-classification targets (e.g. continuous
        # regression outputs) with ValueError. These metrics are optional,
        # so keep ``test_score`` rather than failing the whole evaluation.
        logger.debug(
            "classification metrics not applicable to these targets — "
            "skipping precision/recall/f1",
        )
        return metrics

    metrics.update(extra)
    return metrics

predict(X)

Generate predictions for X using the fitted estimator.

Source code in src/dataenginex/ml/training.py
290
291
292
293
294
def predict(self, X: Any) -> Any:  # noqa: N803
    """Return ``self.estimator.predict(X)``.

    Raises ``RuntimeError`` until ``train`` or ``load`` has marked the
    trainer as fitted.
    """
    if self._is_fitted:
        return self.estimator.predict(X)
    raise RuntimeError("Model not yet trained")

save(path)

Pickle the fitted model, write an HMAC signature, and metadata.

Source code in src/dataenginex/ml/training.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def save(self, path: str) -> str:
    """Pickle the fitted model, write an HMAC signature, and metadata."""
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")

    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)

    model_bytes = pickle.dumps(self.estimator)
    out.write_bytes(model_bytes)

    # HMAC sidecar — verifies integrity on load
    sig_path = out.with_suffix(".sig")
    sig_path.write_text(_hmac_sign(model_bytes))

    # Save metadata alongside
    meta = out.with_suffix(".json")
    meta.write_text(
        json.dumps(
            {
                "model_name": self.model_name,
                "version": self.version,
                "saved_at": datetime.now(tz=UTC).isoformat(),
            }
        )
    )

    logger.info("model saved", name=self.model_name, path=str(out))
    return str(out)

load(path, *, extra_modules=None)

Load a pickled model with HMAC verification and safe unpickling.

Parameters:

Name Type Description Default
path str

Filesystem path to the .pkl artifact.

required
extra_modules frozenset[str] | None

Additional top-level module names to allow during unpickling (e.g. frozenset({"tests"}) for test-only estimators).

None
Source code in src/dataenginex/ml/training.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def load(
    self,
    path: str,
    *,
    extra_modules: frozenset[str] | None = None,
) -> None:
    """Load a pickled model with HMAC verification and safe unpickling.

    Args:
        path: Filesystem path to the ``.pkl`` artifact.
        extra_modules: Additional top-level module names to allow
            during unpickling (e.g. ``frozenset({"tests"})`` for
            test-only estimators).
    """
    artifact = Path(path)
    data = artifact.read_bytes()

    # Verify HMAC signature if sidecar exists
    sig_path = artifact.with_suffix(".sig")
    if sig_path.exists():
        expected = sig_path.read_text().strip()
        if not _hmac_verify(data, expected):
            msg = (
                f"HMAC verification failed for {path}. "
                "The model file may have been tampered with."
            )
            raise ValueError(msg)
    else:
        logger.warning(
            "No .sig sidecar for %s — skipping HMAC check",
            path,
        )

    # Safe unpickle — restricted to sklearn/numpy namespaces
    self.estimator = _SafeUnpickler(
        io.BytesIO(data),
        extra_modules=extra_modules,
    ).load()
    self._is_fitted = True
    logger.info("model loaded", name=self.model_name, path=str(path))

TrainingResult dataclass

Outcome of a model training run.

Attributes:

Name Type Description
model_name str

Name of the trained model.

version str

Semantic version of this training run.

metrics dict[str, float]

Training metrics (e.g. {"train_score": 0.95}).

parameters dict[str, Any]

Hyper-parameters used for training.

duration_seconds float

Wall-clock training time.

artifact_path str | None

Path where the model artifact is saved.

trained_at datetime

Timestamp of training completion.

Source code in src/dataenginex/ml/training.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
@dataclass
class TrainingResult:
    """Outcome of a model training run.

    Attributes:
        model_name: Name of the trained model.
        version: Semantic version of this training run.
        metrics: Training metrics (e.g. ``{"train_score": 0.95}``).
        parameters: Hyper-parameters used for training.
        duration_seconds: Wall-clock training time.
        artifact_path: Path where the model artifact is saved.
        trained_at: Timestamp of training completion.
    """

    model_name: str
    version: str
    metrics: dict[str, float] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    duration_seconds: float = 0.0
    artifact_path: str | None = None
    trained_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the training result to a plain dictionary."""
        return {
            "model_name": self.model_name,
            "version": self.version,
            "metrics": self.metrics,
            "parameters": self.parameters,
            "duration_seconds": round(self.duration_seconds, 2),
            "artifact_path": self.artifact_path,
            "trained_at": self.trained_at.isoformat(),
        }

to_dict()

Serialize the training result to a plain dictionary.

Source code in src/dataenginex/ml/training.py
111
112
113
114
115
116
117
118
119
120
121
def to_dict(self) -> dict[str, Any]:
    """Return the training result as a plain ``dict``.

    ``duration_seconds`` is rounded to two decimals and ``trained_at`` becomes
    an ISO-8601 string. Every other field is passed through unchanged.
    """
    out: dict[str, Any] = {}
    out["model_name"] = self.model_name
    out["version"] = self.version
    out["metrics"] = self.metrics
    out["parameters"] = self.parameters
    out["duration_seconds"] = round(self.duration_seconds, 2)
    out["artifact_path"] = self.artifact_path
    out["trained_at"] = self.trained_at.isoformat()
    return out