Skip to content

dataenginex.ml

ML training, model registry, drift detection, serving, scheduling, metrics, vectorstore & LLM.

Public API::

from dataenginex.ml import (
    BaseTrainer, SklearnTrainer, TrainingResult,
    ModelRegistry, ModelArtifact, ModelStage,
    MLflowModelRegistry, MLflowRegistryError,
    DriftDetector, DriftReport,
    DriftScheduler, DriftMonitorConfig, DriftCheckResult,
    ModelServer, PredictionRequest, PredictionResponse,
    model_prediction_total, model_prediction_latency_seconds,
    model_drift_psi, model_drift_alerts_total,
    # Vector store (Issue #94)
    VectorStoreBackend, InMemoryBackend, ChromaDBBackend,
    Document, SearchResult, RAGPipeline,
    # LLM (Issue #95)
    LLMProvider, OllamaProvider, OpenAICompatibleProvider, MockProvider,
    LLMConfig, LLMResponse, ChatMessage,
    get_llm_provider,
    llm_request_latency_seconds, llm_tokens_total,
)

DriftDetector

Detect distribution drift between a reference and current dataset.

PSI thresholds (industry standard): < 0.10 — no significant drift; 0.10–0.25 — moderate drift; > 0.25 — significant drift.

Parameters

psi_threshold: PSI value above which drift is flagged (default 0.20). n_bins: Number of histogram bins for PSI calculation (default 10).

Source code in src/dataenginex/ml/drift.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
class DriftDetector:
    """Detect distribution drift between a reference and current dataset.

    PSI thresholds (industry standard):
        < 0.10  — no significant drift
        0.10-0.25 — moderate drift
        > 0.25  — significant drift

    Parameters
    ----------
    psi_threshold:
        PSI value above which drift is flagged (default 0.20).
    n_bins:
        Number of histogram bins for PSI calculation (default 10).
    """

    def __init__(self, psi_threshold: float = 0.20, n_bins: int = 10) -> None:
        self.psi_threshold = psi_threshold
        self.n_bins = n_bins

    def check_feature(
        self,
        feature_name: str,
        reference: list[float],
        current: list[float],
    ) -> DriftReport:
        """Check drift for a single numeric feature."""
        if not reference or not current:
            return DriftReport(
                feature_name=feature_name,
                psi=0.0,
                drift_detected=False,
                severity="none",
                details={"error": "Empty distribution(s)"},
            )

        import statistics

        ref_mean = statistics.mean(reference)
        cur_mean = statistics.mean(current)
        ref_std = statistics.stdev(reference) if len(reference) > 1 else 0.0
        cur_std = statistics.stdev(current) if len(current) > 1 else 0.0

        psi = self._compute_psi(reference, current)
        severity = self._classify_severity(psi)

        return DriftReport(
            feature_name=feature_name,
            psi=psi,
            drift_detected=psi > self.psi_threshold,
            severity=severity,
            reference_mean=round(ref_mean, 4),
            current_mean=round(cur_mean, 4),
            reference_std=round(ref_std, 4),
            current_std=round(cur_std, 4),
            details={"n_bins": self.n_bins, "threshold": self.psi_threshold},
        )

    def check_dataset(
        self,
        reference: dict[str, list[float]],
        current: dict[str, list[float]],
    ) -> list[DriftReport]:
        """Check drift across all shared features in two datasets.

        Parameters
        ----------
        reference:
            Mapping of ``feature_name → values`` for the reference period.
        current:
            Mapping of ``feature_name → values`` for the current period.
        """
        reports: list[DriftReport] = []
        shared = set(reference.keys()) & set(current.keys())
        for feat in sorted(shared):
            reports.append(self.check_feature(feat, reference[feat], current[feat]))
        return reports

    # -- internal ------------------------------------------------------------

    def _compute_psi(self, reference: list[float], current: list[float]) -> float:
        """Compute PSI using equal-width binning."""
        all_vals = reference + current
        lo = min(all_vals)
        hi = max(all_vals)
        if lo == hi:
            return 0.0

        bin_width = (hi - lo) / self.n_bins
        eps = 1e-6

        ref_counts = [0] * self.n_bins
        cur_counts = [0] * self.n_bins

        for v in reference:
            idx = min(int((v - lo) / bin_width), self.n_bins - 1)
            ref_counts[idx] += 1
        for v in current:
            idx = min(int((v - lo) / bin_width), self.n_bins - 1)
            cur_counts[idx] += 1

        ref_total = len(reference)
        cur_total = len(current)

        psi = 0.0
        for r, c in zip(ref_counts, cur_counts, strict=True):
            ref_pct = (r / ref_total) + eps
            cur_pct = (c / cur_total) + eps
            psi += (cur_pct - ref_pct) * math.log(cur_pct / ref_pct)

        return max(0.0, psi)

    @staticmethod
    def _classify_severity(psi: float) -> str:
        if psi < 0.10:
            return "none"
        if psi < 0.25:
            return "moderate"
        return "severe"

check_feature(feature_name, reference, current)

Check drift for a single numeric feature.

Source code in src/dataenginex/ml/drift.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def check_feature(
    self,
    feature_name: str,
    reference: list[float],
    current: list[float],
) -> DriftReport:
    """Check drift for a single numeric feature."""
    if not reference or not current:
        # Nothing to compare — return a clean report noting the problem.
        return DriftReport(
            feature_name=feature_name,
            psi=0.0,
            drift_detected=False,
            severity="none",
            details={"error": "Empty distribution(s)"},
        )

    import statistics

    psi = self._compute_psi(reference, current)

    mean_ref = statistics.mean(reference)
    mean_cur = statistics.mean(current)
    # stdev requires at least two samples; report 0.0 for singletons.
    std_ref = statistics.stdev(reference) if len(reference) > 1 else 0.0
    std_cur = statistics.stdev(current) if len(current) > 1 else 0.0

    return DriftReport(
        feature_name=feature_name,
        psi=psi,
        drift_detected=psi > self.psi_threshold,
        severity=self._classify_severity(psi),
        reference_mean=round(mean_ref, 4),
        current_mean=round(mean_cur, 4),
        reference_std=round(std_ref, 4),
        current_std=round(std_cur, 4),
        details={"n_bins": self.n_bins, "threshold": self.psi_threshold},
    )

check_dataset(reference, current)

Check drift across all shared features in two datasets.

Parameters

reference: Mapping of feature_name → values for the reference period. current: Mapping of feature_name → values for the current period.

Source code in src/dataenginex/ml/drift.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def check_dataset(
    self,
    reference: dict[str, list[float]],
    current: dict[str, list[float]],
) -> list[DriftReport]:
    """Check drift across all shared features in two datasets.

    Parameters
    ----------
    reference:
        Mapping of ``feature_name → values`` for the reference period.
    current:
        Mapping of ``feature_name → values`` for the current period.
    """
    # Compare only features present on both sides; sorted order keeps the
    # resulting report list deterministic.
    return [
        self.check_feature(name, reference[name], current[name])
        for name in sorted(reference.keys() & current.keys())
    ]

DriftReport dataclass

Outcome of a drift check for a single feature.

Attributes:

Name Type Description
feature_name str

Name of the feature that was checked.

psi float

Population Stability Index value.

drift_detected bool

Whether drift exceeds the configured threshold.

severity str

Drift severity — "none", "moderate", or "severe".

reference_mean float | None

Mean of the reference distribution.

current_mean float | None

Mean of the current distribution.

reference_std float | None

Standard deviation of reference distribution.

current_std float | None

Standard deviation of current distribution.

details dict[str, Any]

Extra context (bins, threshold, etc.).

checked_at datetime

Timestamp of the drift check.

Source code in src/dataenginex/ml/drift.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@dataclass
class DriftReport:
    """Outcome of a drift check for a single feature.

    Attributes:
        feature_name: Name of the feature that was checked.
        psi: Population Stability Index value.
        drift_detected: Whether drift exceeds the configured threshold.
        severity: Drift severity — ``"none"``, ``"moderate"``, or ``"severe"``.
        reference_mean: Mean of the reference distribution.
        current_mean: Mean of the current distribution.
        reference_std: Standard deviation of reference distribution.
        current_std: Standard deviation of current distribution.
        details: Extra context (bins, threshold, etc.).
        checked_at: Timestamp of the drift check.
    """

    feature_name: str
    psi: float  # Population Stability Index
    drift_detected: bool
    severity: str  # "none", "moderate", "severe"
    reference_mean: float | None = None
    current_mean: float | None = None
    reference_std: float | None = None
    current_std: float | None = None
    details: dict[str, Any] = field(default_factory=dict)
    checked_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the drift report to a plain dictionary."""
        return {
            "feature_name": self.feature_name,
            "psi": round(self.psi, 6),
            "drift_detected": self.drift_detected,
            "severity": self.severity,
            "reference_mean": self.reference_mean,
            "current_mean": self.current_mean,
            "reference_std": self.reference_std,
            "current_std": self.current_std,
            "details": self.details,
            "checked_at": self.checked_at.isoformat(),
        }

to_dict()

Serialize the drift report to a plain dictionary.

Source code in src/dataenginex/ml/drift.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def to_dict(self) -> dict[str, Any]:
    """Serialize the drift report to a plain dictionary."""
    # PSI is rounded to 6 places; the timestamp is emitted as ISO-8601.
    payload: dict[str, Any] = {
        "feature_name": self.feature_name,
        "psi": round(self.psi, 6),
        "drift_detected": self.drift_detected,
        "severity": self.severity,
    }
    payload["reference_mean"] = self.reference_mean
    payload["current_mean"] = self.current_mean
    payload["reference_std"] = self.reference_std
    payload["current_std"] = self.current_std
    payload["details"] = self.details
    payload["checked_at"] = self.checked_at.isoformat()
    return payload

ChatMessage dataclass

Single chat message.

Source code in src/dataenginex/ml/llm.py
67
68
69
70
71
72
@dataclass
class ChatMessage:
    """One turn in a chat conversation."""

    # Expected roles: "system" | "user" | "assistant"
    role: str
    content: str

LLMConfig dataclass

Configuration for an LLM provider.

Source code in src/dataenginex/ml/llm.py
75
76
77
78
79
80
81
82
83
84
@dataclass
class LLMConfig:
    """Tunable settings shared by all LLM providers."""

    model: str = "llama3.1:8b"
    temperature: float = 0.7  # sampling temperature
    max_tokens: int = 2048  # generation cap (mapped to num_predict for Ollama)
    top_p: float = 0.9  # nucleus-sampling cutoff
    system_prompt: str = "You are a helpful data engineering assistant."
    timeout_seconds: int = 120  # HTTP request timeout

LLMProvider

Bases: ABC

Abstract LLM provider interface.

Source code in src/dataenginex/ml/llm.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class LLMProvider(abc.ABC):
    """Abstract LLM provider interface."""

    def __init__(self, config: LLMConfig | None = None) -> None:
        # Fall back to default settings when no config is supplied.
        self.config = config or LLMConfig()

    @abc.abstractmethod
    def generate(self, prompt: str) -> LLMResponse:
        """Generate text from a single prompt string."""

    @abc.abstractmethod
    def chat(self, messages: list[ChatMessage]) -> LLMResponse:
        """Generate a response from a chat conversation."""

    @abc.abstractmethod
    def is_available(self) -> bool:
        """Check whether the provider is reachable."""

    def generate_with_context(
        self,
        question: str,
        context: str,
        system_prompt: str | None = None,
    ) -> LLMResponse:
        """RAG-style generation: inject *context* before the *question*.

        Args:
            question: User question.
            context: Retrieved context documents.
            system_prompt: Optional override for the system prompt.

        Returns:
            LLM response with augmented generation.
        """
        # Build the retrieval-augmented prompt, then delegate to chat().
        prompt_body = (
            f"Use the following context to answer the question.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {question}\n\n"
            f"Answer:"
        )
        conversation = [
            ChatMessage(role="system", content=system_prompt or self.config.system_prompt),
            ChatMessage(role="user", content=prompt_body),
        ]
        return self.chat(conversation)

generate(prompt) abstractmethod

Generate text from a single prompt string.

Source code in src/dataenginex/ml/llm.py
111
112
113
@abc.abstractmethod
def generate(self, prompt: str) -> LLMResponse:
    """Generate text from a single prompt string."""
    # Abstract hook — concrete providers (Ollama, OpenAI-compatible, mock)
    # supply the actual implementation.

chat(messages) abstractmethod

Generate a response from a chat conversation.

Source code in src/dataenginex/ml/llm.py
115
116
117
@abc.abstractmethod
def chat(self, messages: list[ChatMessage]) -> LLMResponse:
    """Generate a response from a chat conversation."""
    # Abstract hook — implementations receive the full message history.

is_available() abstractmethod

Check whether the provider is reachable.

Source code in src/dataenginex/ml/llm.py
119
120
121
@abc.abstractmethod
def is_available(self) -> bool:
    """Check whether the provider is reachable."""
    # Abstract hook — see OllamaProvider.is_available, which returns False
    # on connection errors rather than raising.

generate_with_context(question, context, system_prompt=None)

RAG-style generation: inject context before the question.

Parameters:

Name Type Description Default
question str

User question.

required
context str

Retrieved context documents.

required
system_prompt str | None

Optional override for the system prompt.

None

Returns:

Type Description
LLMResponse

LLM response with augmented generation.

Source code in src/dataenginex/ml/llm.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def generate_with_context(
    self,
    question: str,
    context: str,
    system_prompt: str | None = None,
) -> LLMResponse:
    """RAG-style generation: inject *context* before the *question*.

    Args:
        question: User question.
        context: Retrieved context documents.
        system_prompt: Optional override for the system prompt.

    Returns:
        LLM response with augmented generation.
    """
    # Build the retrieval-augmented prompt, then delegate to chat().
    prompt_body = (
        f"Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    return self.chat(
        [
            ChatMessage(role="system", content=system_prompt or self.config.system_prompt),
            ChatMessage(role="user", content=prompt_body),
        ]
    )

LLMResponse dataclass

Response from an LLM generation call.

Source code in src/dataenginex/ml/llm.py
87
88
89
90
91
92
93
94
95
96
97
@dataclass
class LLMResponse:
    """Result of a single LLM generation call."""

    text: str  # generated text
    model: str = ""  # model that produced the response
    finish_reason: str = "stop"  # "stop" (completed) or "length" (hit cap)
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    metadata: dict[str, Any] = field(default_factory=dict)

MockProvider

Bases: LLMProvider

Deterministic mock LLM provider for testing.

Returns canned responses that include the prompt in the output for assertion convenience.

Source code in src/dataenginex/ml/llm.py
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
class MockProvider(LLMProvider):
    """Deterministic mock LLM provider for testing.

    Returns canned responses that include the prompt in the output
    for assertion convenience.
    """

    def __init__(
        self,
        config: LLMConfig | None = None,
        default_response: str = "This is a mock LLM response.",
    ) -> None:
        super().__init__(config or LLMConfig(model="mock-model"))
        self.default_response = default_response
        # Every generate/chat invocation is recorded here so tests can
        # assert on call patterns.
        self.call_history: list[dict[str, Any]] = []

    def generate(self, prompt: str) -> LLMResponse:
        self.call_history.append({"type": "generate", "prompt": prompt})
        word_count = len(prompt.split())
        return LLMResponse(
            text=f"{self.default_response} (prompt_length={len(prompt)})",
            model=self.config.model,
            prompt_tokens=word_count,
            completion_tokens=10,
            total_tokens=word_count + 10,
        )

    def chat(self, messages: list[ChatMessage]) -> LLMResponse:
        self.call_history.append({"type": "chat", "messages": len(messages)})
        input_tokens = sum(len(m.content.split()) for m in messages)
        return LLMResponse(
            text=f"{self.default_response} (messages={len(messages)})",
            model=self.config.model,
            prompt_tokens=input_tokens,
            completion_tokens=10,
            total_tokens=input_tokens + 10,
        )

    def is_available(self) -> bool:
        # The mock is always reachable.
        return True

OllamaProvider

Bases: LLMProvider

Ollama local LLM provider.

Talks to a local Ollama server via its REST API.

Parameters:

Name Type Description Default
model str

Ollama model name (e.g. llama3.1:8b).

'llama3.1:8b'
base_url str

Ollama server URL.

'http://localhost:11434'
config LLMConfig | None

LLM configuration overrides.

None
Source code in src/dataenginex/ml/llm.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
class OllamaProvider(LLMProvider):
    """Ollama local LLM provider.

    Talks to a local Ollama server via its REST API.

    Args:
        model: Ollama model name (e.g. ``llama3.1:8b``).
        base_url: Ollama server URL.
        config: LLM configuration overrides.
    """

    def __init__(
        self,
        model: str = "llama3.1:8b",
        base_url: str = "http://localhost:11434",
        config: LLMConfig | None = None,
    ) -> None:
        cfg = config or LLMConfig(model=model)
        super().__init__(cfg)
        self.base_url = base_url.rstrip("/")
        self._api_generate = f"{self.base_url}/api/generate"
        self._api_chat = f"{self.base_url}/api/chat"
        self._api_tags = f"{self.base_url}/api/tags"
        logger.info("ollama provider initialised", model=cfg.model, url=self.base_url)

    # -- internal helpers ----------------------------------------------------

    @staticmethod
    def _require_httpx():
        """Import and return httpx, raising a clear error if it is missing."""
        try:
            import httpx
        except ImportError as exc:
            msg = "httpx is required for OllamaProvider — install with: uv add httpx"
            raise ImportError(msg) from exc
        return httpx

    def _options(self) -> dict[str, Any]:
        """Sampling options shared by the generate and chat payloads."""
        return {
            "temperature": self.config.temperature,
            "num_predict": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

    def _record_metrics(self, method: str, result: LLMResponse, elapsed: float) -> None:
        """Emit latency and token-count metrics for one completed request."""
        labels = {"provider": "ollama", "model": self.config.model}
        llm_request_latency_seconds.labels(method=method, **labels).observe(elapsed)
        llm_tokens_total.labels(direction="input", **labels).inc(result.prompt_tokens)
        llm_tokens_total.labels(direction="output", **labels).inc(result.completion_tokens)

    def generate(self, prompt: str) -> LLMResponse:
        """Generate text via Ollama ``/api/generate``.

        Raises:
            ImportError: If httpx is not installed.
            ConnectionError: If the server is unreachable or returns an
                HTTP error status.
        """
        httpx = self._require_httpx()

        payload = {
            "model": self.config.model,
            "prompt": prompt,
            "stream": False,
            "options": self._options(),
        }

        start = time.monotonic()
        try:
            resp = httpx.post(
                self._api_generate,
                json=payload,
                timeout=self.config.timeout_seconds,
            )
            resp.raise_for_status()
            data = resp.json()

            result = LLMResponse(
                text=data.get("response", ""),
                model=data.get("model", self.config.model),
                # Ollama sets done=True on normal completion.
                finish_reason="stop" if data.get("done") else "length",
                prompt_tokens=data.get("prompt_eval_count", 0),
                completion_tokens=data.get("eval_count", 0),
                total_tokens=(data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
                metadata={
                    "total_duration_ns": data.get("total_duration", 0),
                    "load_duration_ns": data.get("load_duration", 0),
                },
            )
            self._record_metrics("generate", result, time.monotonic() - start)
            return result
        except httpx.ConnectError as exc:
            logger.error("ollama server not reachable", url=self.base_url)
            msg = f"Ollama server not reachable at {self.base_url}"
            raise ConnectionError(msg) from exc
        except httpx.HTTPStatusError as exc:
            logger.error("ollama http error", status=exc.response.status_code)
            msg = f"Ollama returned HTTP {exc.response.status_code}"
            raise ConnectionError(msg) from exc

    def chat(self, messages: list[ChatMessage]) -> LLMResponse:
        """Generate via Ollama ``/api/chat``.

        Raises:
            ImportError: If httpx is not installed.
            ConnectionError: If the server is unreachable or returns an
                HTTP error status.
        """
        httpx = self._require_httpx()

        payload = {
            "model": self.config.model,
            "messages": [{"role": m.role, "content": m.content} for m in messages],
            "stream": False,
            "options": self._options(),
        }

        start = time.monotonic()
        try:
            resp = httpx.post(
                self._api_chat,
                json=payload,
                timeout=self.config.timeout_seconds,
            )
            resp.raise_for_status()
            data = resp.json()

            # The assistant reply lives under the "message" key.
            message = data.get("message", {})
            result = LLMResponse(
                text=message.get("content", ""),
                model=data.get("model", self.config.model),
                finish_reason="stop" if data.get("done") else "length",
                prompt_tokens=data.get("prompt_eval_count", 0),
                completion_tokens=data.get("eval_count", 0),
                total_tokens=(data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
            )
            self._record_metrics("chat", result, time.monotonic() - start)
            return result
        except httpx.ConnectError as exc:
            logger.error("ollama server not reachable", url=self.base_url)
            msg = f"Ollama server not reachable at {self.base_url}"
            raise ConnectionError(msg) from exc
        except httpx.HTTPStatusError as exc:
            logger.error("ollama http error", status=exc.response.status_code)
            msg = f"Ollama returned HTTP {exc.response.status_code}"
            raise ConnectionError(msg) from exc

    def is_available(self) -> bool:
        """Check if Ollama server is running and the model is loaded.

        Returns False when httpx is not installed, the server is unreachable,
        the request times out, or the model is not in the tag list.
        """
        try:
            import httpx
        except ImportError:
            # Bug fix: the missing-httpx case must be handled before the
            # network except clause below — evaluating httpx.ConnectError
            # with `httpx` unbound previously raised NameError.
            return False
        try:
            resp = httpx.get(self._api_tags, timeout=5)
            if resp.status_code != 200:
                return False
            models = resp.json().get("models", [])
            available = [m.get("name", "") for m in models]
            # Substring match so "llama3.1:8b" matches qualified tag names.
            return any(self.config.model in name for name in available)
        except (httpx.ConnectError, httpx.TimeoutException):
            return False

    def list_models(self) -> list[str]:
        """List models available on the Ollama server.

        Returns an empty list when httpx is not installed or the request
        fails (connect error, timeout, or HTTP error status).
        """
        try:
            import httpx
        except ImportError:
            # Same import-guard fix as is_available().
            logger.warning("could not list ollama models")
            return []
        try:
            resp = httpx.get(self._api_tags, timeout=5)
            resp.raise_for_status()
            models = resp.json().get("models", [])
            return [m.get("name", "") for m in models]
        except (httpx.ConnectError, httpx.TimeoutException):
            logger.warning("could not list ollama models")
            return []
        except httpx.HTTPStatusError as exc:
            logger.warning("ollama http error listing models", status=exc.response.status_code)
            return []

generate(prompt)

Generate text via Ollama /api/generate.

Source code in src/dataenginex/ml/llm.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
def generate(self, prompt: str) -> LLMResponse:
    """Generate text via Ollama ``/api/generate``.

    Raises:
        ImportError: If httpx is not installed.
        ConnectionError: If the server is unreachable or returns an
            HTTP error status.
    """
    # httpx is an optional dependency, so import lazily per call.
    try:
        import httpx
    except ImportError as exc:
        msg = "httpx is required for OllamaProvider — install with: uv add httpx"
        raise ImportError(msg) from exc

    # Non-streaming request; sampling options come from LLMConfig.
    payload = {
        "model": self.config.model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": self.config.temperature,
            "num_predict": self.config.max_tokens,
            "top_p": self.config.top_p,
        },
    }

    start = time.monotonic()
    try:
        resp = httpx.post(
            self._api_generate,
            json=payload,
            timeout=self.config.timeout_seconds,
        )
        resp.raise_for_status()
        data = resp.json()

        result = LLMResponse(
            text=data.get("response", ""),
            model=data.get("model", self.config.model),
            # done=True marks normal completion; otherwise generation
            # presumably hit the num_predict cap.
            finish_reason="stop" if data.get("done") else "length",
            prompt_tokens=data.get("prompt_eval_count", 0),
            completion_tokens=data.get("eval_count", 0),
            total_tokens=(data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
            metadata={
                "total_duration_ns": data.get("total_duration", 0),
                "load_duration_ns": data.get("load_duration", 0),
            },
        )

        # Record Prometheus latency and token-count metrics for this call.
        elapsed = time.monotonic() - start
        labels = {"provider": "ollama", "model": self.config.model}
        llm_request_latency_seconds.labels(method="generate", **labels).observe(elapsed)
        llm_tokens_total.labels(direction="input", **labels).inc(result.prompt_tokens)
        llm_tokens_total.labels(direction="output", **labels).inc(result.completion_tokens)

        return result
    except httpx.ConnectError as exc:
        # Surface transport failures to callers as ConnectionError.
        logger.error("ollama server not reachable", url=self.base_url)
        msg = f"Ollama server not reachable at {self.base_url}"
        raise ConnectionError(msg) from exc
    except httpx.HTTPStatusError as exc:
        logger.error("ollama http error", status=exc.response.status_code)
        msg = f"Ollama returned HTTP {exc.response.status_code}"
        raise ConnectionError(msg) from exc

chat(messages)

Generate via Ollama /api/chat.

Source code in src/dataenginex/ml/llm.py
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def chat(self, messages: list[ChatMessage]) -> LLMResponse:
    """Generate via Ollama ``/api/chat``.

    Raises:
        ImportError: If httpx is not installed.
        ConnectionError: If the server is unreachable or returns an
            HTTP error status.
    """
    # httpx is an optional dependency, so import lazily per call.
    try:
        import httpx
    except ImportError as exc:
        msg = "httpx is required for OllamaProvider — install with: uv add httpx"
        raise ImportError(msg) from exc

    # The full conversation history is sent on every non-streaming request.
    payload = {
        "model": self.config.model,
        "messages": [{"role": m.role, "content": m.content} for m in messages],
        "stream": False,
        "options": {
            "temperature": self.config.temperature,
            "num_predict": self.config.max_tokens,
            "top_p": self.config.top_p,
        },
    }

    start = time.monotonic()
    try:
        resp = httpx.post(
            self._api_chat,
            json=payload,
            timeout=self.config.timeout_seconds,
        )
        resp.raise_for_status()
        data = resp.json()

        # The assistant reply lives under the "message" key.
        msg = data.get("message", {})
        result = LLMResponse(
            text=msg.get("content", ""),
            model=data.get("model", self.config.model),
            finish_reason="stop" if data.get("done") else "length",
            prompt_tokens=data.get("prompt_eval_count", 0),
            completion_tokens=data.get("eval_count", 0),
            total_tokens=(data.get("prompt_eval_count", 0) + data.get("eval_count", 0)),
        )

        # Record Prometheus latency and token-count metrics for this call.
        elapsed = time.monotonic() - start
        labels = {"provider": "ollama", "model": self.config.model}
        llm_request_latency_seconds.labels(method="chat", **labels).observe(elapsed)
        llm_tokens_total.labels(direction="input", **labels).inc(result.prompt_tokens)
        llm_tokens_total.labels(direction="output", **labels).inc(result.completion_tokens)

        return result
    except httpx.ConnectError as exc:
        # Surface transport failures to callers as ConnectionError.
        logger.error("ollama server not reachable", url=self.base_url)
        msg = f"Ollama server not reachable at {self.base_url}"
        raise ConnectionError(msg) from exc
    except httpx.HTTPStatusError as exc:
        logger.error("ollama http error", status=exc.response.status_code)
        msg = f"Ollama returned HTTP {exc.response.status_code}"
        raise ConnectionError(msg) from exc

is_available()

Check if Ollama server is running and the model is loaded.

Source code in src/dataenginex/ml/llm.py
296
297
298
299
300
301
302
303
304
305
306
307
308
def is_available(self) -> bool:
    """Check if Ollama server is running and the model is loaded.

    Returns False when httpx is not installed, the server is unreachable,
    the request times out, or the configured model is not in the tag list.
    """
    try:
        import httpx
    except ImportError:
        # Bug fix: handle a missing httpx *before* the network except
        # clause below — evaluating httpx.ConnectError with `httpx`
        # unbound previously raised NameError instead of returning False.
        return False
    try:
        resp = httpx.get(self._api_tags, timeout=5)
        if resp.status_code != 200:
            return False
        models = resp.json().get("models", [])
        available = [m.get("name", "") for m in models]
        # Substring match so "llama3.1:8b" matches qualified tag names.
        return any(self.config.model in name for name in available)
    except (httpx.ConnectError, httpx.TimeoutException):
        return False

list_models()

List models available on the Ollama server.

Source code in src/dataenginex/ml/llm.py
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def list_models(self) -> list[str]:
    """List models available on the Ollama server.

    Returns:
        Model names reported by ``GET <base>/api/tags``, or an empty
        list when httpx is missing or the server is unreachable/errors.
    """
    # Import separately: catching ImportError together with httpx.*
    # exceptions raised NameError when httpx was not installed, because
    # ``httpx.ConnectError`` in the except tuple referenced an unbound name.
    try:
        import httpx
    except ImportError:
        logger.warning("could not list ollama models")
        return []
    try:
        resp = httpx.get(self._api_tags, timeout=5)
        resp.raise_for_status()
        models = resp.json().get("models", [])
        return [m.get("name", "") for m in models]
    except (httpx.ConnectError, httpx.TimeoutException):
        logger.warning("could not list ollama models")
        return []
    except httpx.HTTPStatusError as exc:
        logger.warning("ollama http error listing models", status=exc.response.status_code)
        return []

OpenAICompatibleProvider

Bases: LLMProvider

OpenAI-compatible API provider (supports OpenAI, Groq, Together, etc.).

Uses the /v1/chat/completions endpoint with httpx.

Parameters:

Name Type Description Default
api_key str

API key for authentication. Never logged.

required
base_url str

API base URL (default: OpenAI).

'https://api.openai.com'
model str

Model name.

'gpt-4o-mini'
config LLMConfig | None

LLM configuration overrides.

None
Source code in src/dataenginex/ml/llm.py
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
class OpenAICompatibleProvider(LLMProvider):
    """OpenAI-compatible API provider (supports OpenAI, Groq, Together, etc.).

    Uses the ``/v1/chat/completions`` endpoint with httpx.

    Args:
        api_key: API key for authentication. Never logged.
        base_url: API base URL (default: OpenAI).
        model: Model name.
        config: LLM configuration overrides.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com",
        model: str = "gpt-4o-mini",
        config: LLMConfig | None = None,
    ) -> None:
        cfg = config or LLMConfig(model=model)
        super().__init__(cfg)
        self._api_key = api_key
        self.base_url = base_url.rstrip("/")
        self._chat_url = f"{self.base_url}/v1/chat/completions"
        # Never log the API key
        logger.info("openai-compatible provider initialised", model=cfg.model, url=self.base_url)

    def _headers(self) -> dict[str, str]:
        """Build auth/content headers; the key stays in memory only."""
        return {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str) -> LLMResponse:
        """Generate text via a single-turn chat completion."""
        messages = [
            ChatMessage(role="system", content=self.config.system_prompt),
            ChatMessage(role="user", content=prompt),
        ]
        return self.chat(messages)

    def chat(self, messages: list[ChatMessage]) -> LLMResponse:
        """Generate a response via ``/v1/chat/completions``.

        Raises:
            ImportError: If httpx is not installed.
            ConnectionError: If the server is unreachable or returns an
                HTTP error status.
        """
        try:
            import httpx
        except ImportError as exc:
            msg = "httpx is required for OpenAICompatibleProvider — install with: uv add httpx"
            raise ImportError(msg) from exc

        payload = {
            "model": self.config.model,
            "messages": [{"role": m.role, "content": m.content} for m in messages],
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        start = time.monotonic()
        try:
            resp = httpx.post(
                self._chat_url,
                json=payload,
                headers=self._headers(),
                timeout=self.config.timeout_seconds,
            )
            resp.raise_for_status()
            data = resp.json()

            # Guard against a present-but-empty "choices" list, which would
            # otherwise raise IndexError on a malformed server response.
            choice = (data.get("choices") or [{}])[0]
            usage = data.get("usage", {})

            result = LLMResponse(
                text=choice.get("message", {}).get("content", ""),
                model=data.get("model", self.config.model),
                finish_reason=choice.get("finish_reason", "stop"),
                prompt_tokens=usage.get("prompt_tokens", 0),
                completion_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
            )

            # Record latency and token counts for Prometheus metrics.
            elapsed = time.monotonic() - start
            labels = {"provider": "openai_compatible", "model": self.config.model}
            llm_request_latency_seconds.labels(method="chat", **labels).observe(elapsed)
            llm_tokens_total.labels(direction="input", **labels).inc(result.prompt_tokens)
            llm_tokens_total.labels(direction="output", **labels).inc(result.completion_tokens)

            return result
        except httpx.ConnectError as exc:
            logger.error("openai-compatible server not reachable", url=self.base_url)
            msg = f"OpenAI-compatible server not reachable at {self.base_url}"
            raise ConnectionError(msg) from exc
        except httpx.HTTPStatusError as exc:
            logger.error("openai-compatible http error", status=exc.response.status_code)
            msg = f"OpenAI-compatible API returned HTTP {exc.response.status_code}"
            raise ConnectionError(msg) from exc

    def is_available(self) -> bool:
        """Check if the API is reachable (GET request to ``/v1/models``)."""
        # Import separately so a missing httpx cannot trigger NameError while
        # evaluating ``httpx.ConnectError`` in the except tuple below.
        try:
            import httpx
        except ImportError:
            return False
        try:
            resp = httpx.get(
                f"{self.base_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            return resp.status_code == 200
        except (httpx.ConnectError, httpx.TimeoutException):
            return False

generate(prompt)

Generate text via a single-turn chat completion.

Source code in src/dataenginex/ml/llm.py
410
411
412
413
414
415
416
def generate(self, prompt: str) -> LLMResponse:
    """Run a one-shot chat completion: the configured system prompt
    followed by a single user turn containing *prompt*."""
    system_turn = ChatMessage(role="system", content=self.config.system_prompt)
    user_turn = ChatMessage(role="user", content=prompt)
    return self.chat([system_turn, user_turn])

chat(messages)

Generate a response via /v1/chat/completions.

Source code in src/dataenginex/ml/llm.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
def chat(self, messages: list[ChatMessage]) -> LLMResponse:
    """Generate a response via ``/v1/chat/completions``.

    Raises:
        ImportError: If httpx is not installed.
        ConnectionError: If the server is unreachable or returns an HTTP
            error status.
    """
    try:
        import httpx
    except ImportError as exc:
        msg = "httpx is required for OpenAICompatibleProvider — install with: uv add httpx"
        raise ImportError(msg) from exc

    payload = {
        "model": self.config.model,
        "messages": [{"role": m.role, "content": m.content} for m in messages],
        "temperature": self.config.temperature,
        "max_tokens": self.config.max_tokens,
        "top_p": self.config.top_p,
    }

    start = time.monotonic()
    try:
        resp = httpx.post(
            self._chat_url,
            json=payload,
            headers=self._headers(),
            timeout=self.config.timeout_seconds,
        )
        resp.raise_for_status()
        data = resp.json()

        # Guard against a present-but-empty "choices" list, which would
        # otherwise raise IndexError on a malformed server response.
        choice = (data.get("choices") or [{}])[0]
        usage = data.get("usage", {})

        result = LLMResponse(
            text=choice.get("message", {}).get("content", ""),
            model=data.get("model", self.config.model),
            finish_reason=choice.get("finish_reason", "stop"),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
        )

        # Record latency and token counts for Prometheus metrics.
        elapsed = time.monotonic() - start
        labels = {"provider": "openai_compatible", "model": self.config.model}
        llm_request_latency_seconds.labels(method="chat", **labels).observe(elapsed)
        llm_tokens_total.labels(direction="input", **labels).inc(result.prompt_tokens)
        llm_tokens_total.labels(direction="output", **labels).inc(result.completion_tokens)

        return result
    except httpx.ConnectError as exc:
        logger.error("openai-compatible server not reachable", url=self.base_url)
        msg = f"OpenAI-compatible server not reachable at {self.base_url}"
        raise ConnectionError(msg) from exc
    except httpx.HTTPStatusError as exc:
        logger.error("openai-compatible http error", status=exc.response.status_code)
        msg = f"OpenAI-compatible API returned HTTP {exc.response.status_code}"
        raise ConnectionError(msg) from exc

is_available()

Check if the API is reachable (GET request to the /v1/models endpoint).

Source code in src/dataenginex/ml/llm.py
473
474
475
476
477
478
479
480
481
482
483
484
485
def is_available(self) -> bool:
    """Check if the API is reachable (GET request to ``/v1/models``)."""
    # Import separately: the previous combined handler evaluated
    # ``httpx.ConnectError`` in the except tuple, raising NameError when
    # the import itself had failed (httpx never bound).
    try:
        import httpx
    except ImportError:
        return False
    try:
        resp = httpx.get(
            f"{self.base_url}/v1/models",
            headers=self._headers(),
            timeout=5,
        )
        return resp.status_code == 200
    except (httpx.ConnectError, httpx.TimeoutException):
        return False

MLflowModelRegistry

MLflow-backed model registry compatible with ModelRegistry.

Parameters

tracking_uri: MLflow tracking server URI. Defaults to MLFLOW_TRACKING_URI env var or http://localhost:5000.

Source code in src/dataenginex/ml/mlflow_registry.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
class MLflowModelRegistry:
    """MLflow-backed model registry compatible with ``ModelRegistry``.

    Stage transitions are modelled with MLflow registered-model *aliases*
    (via ``_ALIAS_MAP`` / ``_REVERSE_ALIAS_MAP``) rather than MLflow stages.

    Parameters
    ----------
    tracking_uri:
        MLflow tracking server URI.  Defaults to ``MLFLOW_TRACKING_URI``
        env var or ``http://localhost:5000``.
    """

    def __init__(self, tracking_uri: str = _DEFAULT_TRACKING_URI) -> None:
        self._tracking_uri = tracking_uri
        self._client = _get_client(tracking_uri)
        logger.info("mlflow registry connected", uri=tracking_uri)

    # -- registration --------------------------------------------------------

    def register(self, artifact: ModelArtifact) -> ModelArtifact:
        """Register a model version in MLflow and log its run metadata."""
        import mlflow  # noqa: PLC0415

        # NOTE(review): set_tracking_uri mutates process-global mlflow state —
        # confirm this is acceptable if several registries coexist in-process.
        mlflow.set_tracking_uri(self._tracking_uri)

        try:
            with mlflow.start_run(run_name=f"{artifact.name}_v{artifact.version}") as run:
                mlflow.log_params(artifact.parameters)
                mlflow.log_metrics(artifact.metrics)
                # Mirror DEX metadata as MLflow tags so it survives round-trips.
                mlflow.set_tags(
                    {
                        "dex.version": artifact.version,
                        "dex.description": artifact.description,
                        **{f"dex.tag.{t}": "true" for t in artifact.tags},
                    }
                )

                # Register the model URI (use artifact_path if it's a local path)
                model_uri = (
                    f"runs:/{run.info.run_id}/model"
                    if not artifact.artifact_path
                    else artifact.artifact_path
                )
                mv = mlflow.register_model(model_uri=model_uri, name=artifact.name)

            logger.info(
                "Registered %s v%s in MLflow (version=%s)",
                artifact.name,
                artifact.version,
                mv.version,
            )
        except Exception as exc:
            # Any MLflow failure (network, auth, bad URI) surfaces as our
            # registry-specific error with the original as __cause__.
            raise MLflowRegistryError(
                f"Failed to register {artifact.name!r} in MLflow: {exc}"
            ) from exc

        return artifact

    # -- queries -------------------------------------------------------------

    def get(self, name: str, version: str) -> ModelArtifact | None:
        """Fetch a model version from MLflow."""
        try:
            mv = self._client.get_model_version(name=name, version=version)
        except Exception:  # noqa: BLE001
            # Any client failure (missing model/version, network) maps to None.
            return None

        return self._mv_to_artifact(mv)

    def get_latest(self, name: str) -> ModelArtifact | None:
        """Return the highest version number for *name* regardless of stage."""
        try:
            versions = self._client.search_model_versions(f"name='{name}'")
        except Exception:  # noqa: BLE001
            return None

        if not versions:
            return None

        # MLflow versions are numeric strings — compare as ints, not lexically.
        latest = max(versions, key=lambda v: int(v.version))
        return self._mv_to_artifact(latest)

    def get_production(self, name: str) -> ModelArtifact | None:
        """Return the model currently aliased as ``production``."""
        try:
            mv = self._client.get_model_version_by_alias(name, "production")
        except Exception:  # noqa: BLE001
            return None

        return self._mv_to_artifact(mv)

    def list_models(self) -> list[str]:
        """Return all registered model names."""
        try:
            registered = self._client.search_registered_models()
            return [m.name for m in registered]
        except Exception as exc:  # noqa: BLE001
            raise MLflowRegistryError(f"Failed to list models: {exc}") from exc

    def list_versions(self, name: str) -> list[str]:
        """Return all version strings for *name*."""
        try:
            versions = self._client.search_model_versions(f"name='{name}'")
            return [v.version for v in versions]
        except Exception as exc:  # noqa: BLE001
            raise MLflowRegistryError(f"Failed to list versions for {name!r}: {exc}") from exc

    # -- promotion -----------------------------------------------------------

    def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
        """Transition a model version to the target stage via MLflow aliases."""
        try:
            if target_stage == ModelStage.DEVELOPMENT:
                # Remove any DEX-managed aliases from this version
                mv = self._client.get_model_version(name, version)
                # list() copies aliases so deletion cannot disturb iteration.
                for alias in list(getattr(mv, "aliases", [])):
                    if alias in _REVERSE_ALIAS_MAP:
                        self._client.delete_registered_model_alias(name, alias)
            else:
                alias = _ALIAS_MAP[target_stage]
                self._client.set_registered_model_alias(name, alias, version)
        except Exception as exc:
            raise MLflowRegistryError(
                f"Failed to promote {name!r} v{version} to {target_stage}: {exc}"
            ) from exc

        # Re-fetch so the returned artifact reflects the new alias state.
        mv = self._client.get_model_version(name, version)
        logger.info("model promoted", name=name, version=version, stage=str(target_stage))
        return self._mv_to_artifact(mv)

    # -- helpers -------------------------------------------------------------

    def _mv_to_artifact(self, mv: Any) -> ModelArtifact:
        """Convert an MLflow ModelVersion object to a ``ModelArtifact``."""
        aliases: list[str] = list(getattr(mv, "aliases", []))

        # Resolve stage from aliases: production > staging > archived > development
        stage = ModelStage.DEVELOPMENT
        for alias in aliases:
            candidate = _REVERSE_ALIAS_MAP.get(alias)
            if candidate == ModelStage.PRODUCTION:
                stage = ModelStage.PRODUCTION
                break
            if candidate == ModelStage.STAGING:
                stage = ModelStage.STAGING
            elif candidate == ModelStage.ARCHIVED and stage == ModelStage.DEVELOPMENT:
                stage = ModelStage.ARCHIVED

        # MLflow timestamps are epoch milliseconds; convert to aware UTC.
        creation_ts = getattr(mv, "creation_timestamp", None)
        created_at = (
            datetime.fromtimestamp(creation_ts / 1000, tz=UTC)
            if creation_ts
            else datetime.now(tz=UTC)
        )

        last_updated_ts = getattr(mv, "last_updated_timestamp", None)
        promoted_at = (
            datetime.fromtimestamp(last_updated_ts / 1000, tz=UTC) if last_updated_ts else None
        )

        return ModelArtifact(
            name=mv.name,
            version=mv.version,
            stage=stage,
            artifact_path=getattr(mv, "source", ""),
            description=getattr(mv, "description", "") or "",
            tags=[t.key for t in getattr(mv, "tags", [])],
            created_at=created_at,
            promoted_at=promoted_at,
        )

register(artifact)

Register a model version in MLflow and log its run metadata.

Source code in src/dataenginex/ml/mlflow_registry.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def register(self, artifact: ModelArtifact) -> ModelArtifact:
    """Register a model version in MLflow and log its run metadata."""
    import mlflow  # noqa: PLC0415

    # NOTE(review): set_tracking_uri mutates process-global mlflow state —
    # confirm this is acceptable if several registries coexist in-process.
    mlflow.set_tracking_uri(self._tracking_uri)

    try:
        with mlflow.start_run(run_name=f"{artifact.name}_v{artifact.version}") as run:
            mlflow.log_params(artifact.parameters)
            mlflow.log_metrics(artifact.metrics)
            # Mirror DEX metadata as MLflow tags so it survives round-trips.
            mlflow.set_tags(
                {
                    "dex.version": artifact.version,
                    "dex.description": artifact.description,
                    **{f"dex.tag.{t}": "true" for t in artifact.tags},
                }
            )

            # Register the model URI (use artifact_path if it's a local path)
            model_uri = (
                f"runs:/{run.info.run_id}/model"
                if not artifact.artifact_path
                else artifact.artifact_path
            )
            mv = mlflow.register_model(model_uri=model_uri, name=artifact.name)

        logger.info(
            "Registered %s v%s in MLflow (version=%s)",
            artifact.name,
            artifact.version,
            mv.version,
        )
    except Exception as exc:
        # Any MLflow failure (network, auth, bad URI) surfaces as our
        # registry-specific error with the original as __cause__.
        raise MLflowRegistryError(
            f"Failed to register {artifact.name!r} in MLflow: {exc}"
        ) from exc

    return artifact

get(name, version)

Fetch a model version from MLflow.

Source code in src/dataenginex/ml/mlflow_registry.py
117
118
119
120
121
122
123
124
def get(self, name: str, version: str) -> ModelArtifact | None:
    """Fetch a model version from MLflow."""
    try:
        mv = self._client.get_model_version(name=name, version=version)
    except Exception:  # noqa: BLE001
        return None

    return self._mv_to_artifact(mv)

get_latest(name)

Return the highest version number for name regardless of stage.

Source code in src/dataenginex/ml/mlflow_registry.py
126
127
128
129
130
131
132
133
134
135
136
137
def get_latest(self, name: str) -> ModelArtifact | None:
    """Return the highest version number for *name* regardless of stage."""
    try:
        versions = self._client.search_model_versions(f"name='{name}'")
    except Exception:  # noqa: BLE001
        return None

    if not versions:
        return None

    latest = max(versions, key=lambda v: int(v.version))
    return self._mv_to_artifact(latest)

get_production(name)

Return the model currently aliased as production.

Source code in src/dataenginex/ml/mlflow_registry.py
139
140
141
142
143
144
145
146
def get_production(self, name: str) -> ModelArtifact | None:
    """Return the model currently aliased as ``production``."""
    try:
        mv = self._client.get_model_version_by_alias(name, "production")
    except Exception:  # noqa: BLE001
        return None

    return self._mv_to_artifact(mv)

list_models()

Return all registered model names.

Source code in src/dataenginex/ml/mlflow_registry.py
148
149
150
151
152
153
154
def list_models(self) -> list[str]:
    """Return the names of every registered model."""
    try:
        registered = self._client.search_registered_models()
        names = [entry.name for entry in registered]
    except Exception as exc:  # noqa: BLE001
        raise MLflowRegistryError(f"Failed to list models: {exc}") from exc
    return names

list_versions(name)

Return all version strings for name.

Source code in src/dataenginex/ml/mlflow_registry.py
156
157
158
159
160
161
162
def list_versions(self, name: str) -> list[str]:
    """Return every version string registered for *name*."""
    try:
        found = self._client.search_model_versions(f"name='{name}'")
        out = [mv.version for mv in found]
    except Exception as exc:  # noqa: BLE001
        raise MLflowRegistryError(f"Failed to list versions for {name!r}: {exc}") from exc
    return out

promote(name, version, target_stage)

Transition a model version to the target stage via MLflow aliases.

Source code in src/dataenginex/ml/mlflow_registry.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
    """Transition a model version to the target stage via MLflow aliases.

    Raises:
        MLflowRegistryError: If any MLflow client call fails.
    """
    try:
        if target_stage == ModelStage.DEVELOPMENT:
            # Remove any DEX-managed aliases from this version
            mv = self._client.get_model_version(name, version)
            # list() copies aliases so deletion cannot disturb iteration.
            for alias in list(getattr(mv, "aliases", [])):
                if alias in _REVERSE_ALIAS_MAP:
                    self._client.delete_registered_model_alias(name, alias)
        else:
            alias = _ALIAS_MAP[target_stage]
            self._client.set_registered_model_alias(name, alias, version)
    except Exception as exc:
        raise MLflowRegistryError(
            f"Failed to promote {name!r} v{version} to {target_stage}: {exc}"
        ) from exc

    # Re-fetch so the returned artifact reflects the new alias state.
    mv = self._client.get_model_version(name, version)
    logger.info("model promoted", name=name, version=version, stage=str(target_stage))
    return self._mv_to_artifact(mv)

MLflowRegistryError

Bases: RuntimeError

Raised when the MLflow server is unreachable or returns an error.

Source code in src/dataenginex/ml/mlflow_registry.py
45
46
class MLflowRegistryError(RuntimeError):
    """Raised when the MLflow server is unreachable or returns an error."""

ModelArtifact dataclass

Registry entry for a model version.

Attributes:

Name Type Description
name str

Model name (e.g. "churn_classifier").

version str

Semantic version string.

stage ModelStage

Current lifecycle stage.

artifact_path str

File path to the serialised model.

metrics dict[str, float]

Training/evaluation metrics.

parameters dict[str, Any]

Hyper-parameters used for training.

description str

Free-text description.

created_at datetime

When the artifact was registered.

promoted_at datetime | None

When the artifact was last promoted.

tags list[str]

Arbitrary labels for filtering.

Source code in src/dataenginex/ml/registry.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@dataclass
class ModelArtifact:
    """Registry entry for a model version.

    Attributes:
        name: Model name (e.g. ``"churn_classifier"``).
        version: Semantic version string.
        stage: Current lifecycle stage.
        artifact_path: File path to the serialised model.
        metrics: Training/evaluation metrics.
        parameters: Hyper-parameters used for training.
        description: Free-text description.
        created_at: When the artifact was registered.
        promoted_at: When the artifact was last promoted.
        tags: Arbitrary labels for filtering.
    """

    name: str
    version: str
    stage: ModelStage = ModelStage.DEVELOPMENT
    artifact_path: str = ""
    # default_factory avoids the shared-mutable-default pitfall.
    metrics: dict[str, float] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    description: str = ""
    # Timezone-aware UTC timestamp captured at construction time.
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))
    promoted_at: datetime | None = None
    tags: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the model artifact metadata to a plain dictionary."""
        d = asdict(self)
        # Replace non-JSON-native values with serialisable equivalents.
        d["stage"] = self.stage.value
        d["created_at"] = self.created_at.isoformat()
        d["promoted_at"] = self.promoted_at.isoformat() if self.promoted_at else None
        return d

to_dict()

Serialize the model artifact metadata to a plain dictionary.

Source code in src/dataenginex/ml/registry.py
65
66
67
68
69
70
71
def to_dict(self) -> dict[str, Any]:
    """Serialize the model artifact metadata to a plain dictionary."""
    d = asdict(self)
    d["stage"] = self.stage.value
    d["created_at"] = self.created_at.isoformat()
    d["promoted_at"] = self.promoted_at.isoformat() if self.promoted_at else None
    return d

ModelRegistry

JSON-file-backed model registry.

Parameters

persist_path: Path to a JSON file for persistence (optional).

Source code in src/dataenginex/ml/registry.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class ModelRegistry:
    """JSON-file-backed model registry.

    Registration and promotion are serialised with an internal lock;
    persistence (when ``persist_path`` is set) happens inside that lock.

    Parameters
    ----------
    persist_path:
        Path to a JSON file for persistence (optional).
    """

    def __init__(self, persist_path: str | Path | None = None) -> None:
        # name → version → artifact
        self._models: dict[str, dict[str, ModelArtifact]] = {}
        self._persist_path = Path(persist_path) if persist_path else None
        self._lock = threading.Lock()
        if self._persist_path and self._persist_path.exists():
            self._load()

    # -- registration --------------------------------------------------------

    def register(self, artifact: ModelArtifact) -> ModelArtifact:
        """Register a new model version.

        Raises:
            ValueError: If this name/version pair is already registered.
        """
        with self._lock:
            versions = self._models.setdefault(artifact.name, {})
            if artifact.version in versions:
                raise ValueError(
                    f"Model {artifact.name!r} version {artifact.version} already registered"
                )
            versions[artifact.version] = artifact
            logger.info(
                "Registered model %s v%s (stage=%s)",
                artifact.name,
                artifact.version,
                artifact.stage.value,
            )
            # Persist while still holding the lock so concurrent registers
            # cannot interleave between the insert and the save.
            self._save()
        return artifact

    # -- queries -------------------------------------------------------------

    def get(self, name: str, version: str) -> ModelArtifact | None:
        """Return the artifact for *name* at *version*, or ``None``."""
        return self._models.get(name, {}).get(version)

    def get_latest(self, name: str) -> ModelArtifact | None:
        """Return the most recently registered version of *name*."""
        versions = self._models.get(name)
        if not versions:
            return None
        # dicts preserve insertion order, so the last value is the newest.
        return list(versions.values())[-1]

    def get_production(self, name: str) -> ModelArtifact | None:
        """Return the model currently in production stage."""
        for art in self._models.get(name, {}).values():
            if art.stage == ModelStage.PRODUCTION:
                return art
        return None

    def list_models(self) -> list[str]:
        """Return all registered model names."""
        return list(self._models.keys())

    def list_versions(self, name: str) -> list[str]:
        """Return all version strings registered for *name*."""
        return list(self._models.get(name, {}).keys())

    # -- promotion -----------------------------------------------------------

    def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
        """Promote a model version to a new stage.

        If promoting to ``production``, any existing production model is
        automatically archived.

        Raises:
            ValueError: If the name/version pair is not registered.
        """
        with self._lock:
            artifact = self.get(name, version)
            if artifact is None:
                raise ValueError(f"Model {name!r} version {version} not found")

            if target_stage == ModelStage.PRODUCTION:
                # Archive the current production model
                current = self.get_production(name)
                if current and current.version != version:
                    current.stage = ModelStage.ARCHIVED
                    logger.info("model archived", name=name, version=current.version)

            artifact.stage = target_stage
            artifact.promoted_at = datetime.now(tz=UTC)
            logger.info("model promoted", name=name, version=version, stage=target_stage.value)
            self._save()
        return artifact

    # -- persistence ---------------------------------------------------------

    def _save(self) -> None:
        """Write the registry to ``persist_path`` as JSON (no-op when unset)."""
        if not self._persist_path:
            return
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        data: dict[str, list[dict[str, Any]]] = {}
        for name, versions in self._models.items():
            data[name] = [v.to_dict() for v in versions.values()]
        self._persist_path.write_text(json.dumps(data, indent=2, default=str))

    def _load(self) -> None:
        """Restore the registry from ``persist_path``.

        Fix: ``created_at``/``promoted_at`` are now parsed back from the
        ISO strings written by ``ModelArtifact.to_dict`` instead of being
        discarded, which previously reset every artifact's timestamps on
        each save/load round-trip.
        """
        if not self._persist_path or not self._persist_path.exists():
            return
        raw = json.loads(self._persist_path.read_text())
        for name, versions in raw.items():
            self._models[name] = {}
            for v in versions:
                created_at = v.pop("created_at", None)
                promoted_at = v.pop("promoted_at", None)
                v["stage"] = ModelStage(v.get("stage", "development"))
                artifact = ModelArtifact(**v)
                if created_at:
                    artifact.created_at = datetime.fromisoformat(created_at)
                if promoted_at:
                    artifact.promoted_at = datetime.fromisoformat(promoted_at)
                self._models[name][artifact.version] = artifact
        logger.info("models loaded", count=len(self._models), path=str(self._persist_path))

register(artifact)

Register a new model version.

Source code in src/dataenginex/ml/registry.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def register(self, artifact: ModelArtifact) -> ModelArtifact:
    """Register a new model version.

    Raises:
        ValueError: If this name/version pair is already registered.
    """
    with self._lock:
        versions = self._models.setdefault(artifact.name, {})
        if artifact.version in versions:
            raise ValueError(
                f"Model {artifact.name!r} version {artifact.version} already registered"
            )
        versions[artifact.version] = artifact
        logger.info(
            "Registered model %s v%s (stage=%s)",
            artifact.name,
            artifact.version,
            artifact.stage.value,
        )
        # Persist while still holding the lock so concurrent registers
        # cannot interleave between the insert and the save.
        self._save()
    return artifact

get(name, version)

Return the artifact for name at version, or None.

Source code in src/dataenginex/ml/registry.py
113
114
115
def get(self, name: str, version: str) -> ModelArtifact | None:
    """Return the artifact for *name* at *version*, or ``None``."""
    return self._models.get(name, {}).get(version)

get_latest(name)

Return the most recently registered version of name.

Source code in src/dataenginex/ml/registry.py
117
118
119
120
121
122
def get_latest(self, name: str) -> ModelArtifact | None:
    """Return the most recently registered version of *name*.

    Returns ``None`` when the model name is unknown or has no versions.
    """
    versions = self._models.get(name)
    if not versions:
        return None
    # Dicts preserve insertion order, so the last value is the newest.
    return next(reversed(versions.values()))

get_production(name)

Return the model currently in production stage.

Source code in src/dataenginex/ml/registry.py
124
125
126
127
128
129
def get_production(self, name: str) -> ModelArtifact | None:
    """Return the model currently in the production stage, or ``None``."""
    in_production = (
        candidate
        for candidate in self._models.get(name, {}).values()
        if candidate.stage == ModelStage.PRODUCTION
    )
    return next(in_production, None)

list_models()

Return all registered model names.

Source code in src/dataenginex/ml/registry.py
131
132
133
def list_models(self) -> list[str]:
    """Return the names of every registered model."""
    return [*self._models]

list_versions(name)

Return all version strings registered for name.

Source code in src/dataenginex/ml/registry.py
135
136
137
def list_versions(self, name: str) -> list[str]:
    """Return every version string registered under *name* (empty if unknown)."""
    return [*self._models.get(name, {})]

promote(name, version, target_stage)

Promote a model version to a new stage.

If promoting to production, any existing production model is automatically archived.

Source code in src/dataenginex/ml/registry.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def promote(self, name: str, version: str, target_stage: ModelStage) -> ModelArtifact:
    """Promote a model version to a new stage.

    If promoting to ``production``, any existing production model is
    automatically archived.

    Parameters
    ----------
    name:
        Registered model name.
    version:
        Version string to promote.
    target_stage:
        Lifecycle stage to move the version into.

    Raises
    ------
    ValueError:
        If the name/version pair is not registered.
    """
    with self._lock:
        # ``get`` does not acquire the (non-reentrant) lock itself, so
        # calling it while holding ``self._lock`` is safe.
        artifact = self.get(name, version)
        if artifact is None:
            raise ValueError(f"Model {name!r} version {version} not found")

        if target_stage == ModelStage.PRODUCTION:
            # Archive the current production model so at most one version
            # per model is in production. Re-promoting the version that is
            # already in production archives nothing.
            current = self.get_production(name)
            if current and current.version != version:
                current.stage = ModelStage.ARCHIVED
                logger.info("model archived", name=name, version=current.version)

        artifact.stage = target_stage
        artifact.promoted_at = datetime.now(tz=UTC)
        logger.info("model promoted", name=name, version=version, stage=target_stage.value)
        # Persist while still holding the lock for a consistent snapshot.
        self._save()
    return artifact

ModelStage

Bases: StrEnum

Model lifecycle stages.

Source code in src/dataenginex/ml/registry.py
28
29
30
31
32
33
34
class ModelStage(StrEnum):
    """Model lifecycle stages.

    ``StrEnum`` members compare equal to their string values, so stages
    round-trip cleanly through JSON persistence.
    """

    DEVELOPMENT = "development"  # fallback stage when deserializing (see _load)
    STAGING = "staging"
    PRODUCTION = "production"  # at most one version per model (enforced by promote)
    ARCHIVED = "archived"  # set automatically when displaced from production

DriftCheckResult dataclass

Aggregated result of a drift check across all features of a model.

Attributes:

Name Type Description
model_name str

Name of the model checked.

reports list[DriftReport]

Per-feature drift reports.

drift_detected bool

True if any feature exceeded the PSI threshold.

max_psi float

Highest PSI score across all features.

checked_at datetime

Timestamp of the check.

Source code in src/dataenginex/ml/scheduler.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
@dataclass
class DriftCheckResult:
    """Aggregated result of a drift check across all features of a model.

    Attributes:
        model_name: Name of the model checked.
        reports: Per-feature drift reports.
        drift_detected: ``True`` if any feature exceeded the PSI threshold.
        max_psi: Highest PSI score across all features.
        checked_at: Timestamp of the check (timezone-aware UTC by default).
    """

    model_name: str
    reports: list[DriftReport]
    drift_detected: bool
    max_psi: float
    # Default captures the construction time in UTC.
    checked_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary (PSI rounded to 6 decimal places)."""
        return {
            "model_name": self.model_name,
            "drift_detected": self.drift_detected,
            "max_psi": round(self.max_psi, 6),
            "checked_at": self.checked_at.isoformat(),
            "features": [r.to_dict() for r in self.reports],
        }

to_dict()

Serialize to a plain dictionary.

Source code in src/dataenginex/ml/scheduler.py
 96
 97
 98
 99
100
101
102
103
104
def to_dict(self) -> dict[str, Any]:
    """Serialize to a plain dictionary (PSI rounded to 6 decimal places)."""
    payload: dict[str, Any] = {"model_name": self.model_name}
    payload["drift_detected"] = self.drift_detected
    payload["max_psi"] = round(self.max_psi, 6)
    payload["checked_at"] = self.checked_at.isoformat()
    payload["features"] = [report.to_dict() for report in self.reports]
    return payload

DriftMonitorConfig dataclass

Configuration for monitoring a single model's data drift.

Attributes:

Name Type Description
model_name str

Name of the model being monitored.

reference_data dict[str, list[float]]

Mapping of feature_name → reference distribution values.

psi_threshold float

PSI value above which drift is flagged (default 0.20).

check_interval_seconds float

Seconds between consecutive checks (default 300).

n_bins int

Number of histogram bins for PSI calculation (default 10).

Source code in src/dataenginex/ml/scheduler.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@dataclass
class DriftMonitorConfig:
    """Configuration for monitoring a single model's data drift.

    Attributes:
        model_name: Name of the model being monitored.
        reference_data: Mapping of feature_name → reference distribution values.
            Must be non-empty (enforced by ``DriftScheduler.register``).
        psi_threshold: PSI value above which drift is flagged (default 0.20).
        check_interval_seconds: Seconds between consecutive checks (default 300).
        n_bins: Number of histogram bins for PSI calculation (default 10).
    """

    model_name: str
    reference_data: dict[str, list[float]]
    psi_threshold: float = 0.20
    check_interval_seconds: float = 300.0
    n_bins: int = 10

DriftScheduler

Background scheduler for periodic model drift checks.

Runs a daemon thread that iterates registered monitors and invokes DriftDetector when each monitor's interval has elapsed. Results are published to Prometheus gauges and counters.

Parameters

tick_seconds: How often the scheduler loop wakes up to check deadlines (default 5.0). Lower values give more precise timing at the cost of CPU.

Source code in src/dataenginex/ml/scheduler.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
class DriftScheduler:
    """Background scheduler for periodic model drift checks.

    Runs a daemon thread that iterates registered monitors and
    invokes ``DriftDetector`` when each monitor's interval has elapsed.
    Results are published to Prometheus gauges and counters.

    Thread-safety: ``_monitors`` and ``_results`` are only accessed while
    holding ``self._lock``; the checks themselves run outside the lock so
    slow ``data_fn`` calls never block registration.

    Parameters
    ----------
    tick_seconds:
        How often the scheduler loop wakes up to check deadlines
        (default ``5.0``).  Lower values give more precise timing
        at the cost of CPU.
    """

    def __init__(self, tick_seconds: float = 5.0) -> None:
        # model_name → (config, data_fn, last_check monotonic timestamp)
        self._monitors: dict[str, _MonitorEntry] = {}
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._tick = tick_seconds
        self._results: dict[str, DriftCheckResult] = {}

    # -- public API ----------------------------------------------------------

    def register(
        self,
        config: DriftMonitorConfig,
        data_fn: DataProvider,
    ) -> None:
        """Register a model for periodic drift monitoring.

        Parameters
        ----------
        config:
            Monitor configuration (thresholds, interval, reference data).
        data_fn:
            Callable returning current feature data as
            ``dict[str, list[float]]``.

        Raises
        ------
        ValueError:
            If ``config.reference_data`` is empty.
        """
        if not config.reference_data:
            msg = f"reference_data must not be empty for model {config.model_name!r}"
            raise ValueError(msg)

        with self._lock:
            # last_check=0.0 makes the first check run on the next tick.
            self._monitors[config.model_name] = (config, data_fn, 0.0)
        logger.info(
            "drift monitor registered",
            model=config.model_name,
            interval=config.check_interval_seconds,
            features=len(config.reference_data),
        )

    def unregister(self, model_name: str) -> None:
        """Remove a model from drift monitoring.

        Parameters
        ----------
        model_name:
            Name of the model to unregister.

        Raises
        ------
        KeyError:
            If the model is not registered.
        """
        with self._lock:
            if model_name not in self._monitors:
                msg = f"Model {model_name!r} is not registered for drift monitoring"
                raise KeyError(msg)
            del self._monitors[model_name]
            self._results.pop(model_name, None)
        logger.info("drift monitor unregistered", model=model_name)

    def start(self) -> None:
        """Start the background monitoring thread.

        Raises
        ------
        RuntimeError:
            If the scheduler is already running.
        """
        if self._thread is not None and self._thread.is_alive():
            msg = "DriftScheduler is already running"
            raise RuntimeError(msg)

        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._run_loop,
            name="drift-scheduler",
            daemon=True,
        )
        self._thread.start()
        logger.info("drift scheduler started", tick=self._tick)

    def stop(self, timeout: float = 10.0) -> None:
        """Stop the background monitoring thread.

        Parameters
        ----------
        timeout:
            Seconds to wait for the thread to join (default ``10.0``).
        """
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=timeout)
            self._thread = None
        logger.info("drift scheduler stopped")

    @property
    def is_running(self) -> bool:
        """Whether the scheduler thread is alive."""
        return self._thread is not None and self._thread.is_alive()

    @property
    def registered_models(self) -> list[str]:
        """Names of all registered models."""
        with self._lock:
            return list(self._monitors.keys())

    def get_last_result(self, model_name: str) -> DriftCheckResult | None:
        """Return the most recent drift check result for a model."""
        # Fix: read under the lock — ``unregister`` and ``_execute_check``
        # mutate ``_results`` from other threads.
        with self._lock:
            return self._results.get(model_name)

    def run_check(self, model_name: str) -> DriftCheckResult:
        """Manually trigger a drift check for one model.

        Parameters
        ----------
        model_name:
            Name of a registered model to check.

        Raises
        ------
        KeyError:
            If the model is not registered.

        Returns
        -------
        DriftCheckResult:
            Aggregated result with per-feature reports.
        """
        with self._lock:
            entry = self._monitors.get(model_name)
            if entry is None:
                msg = f"Model {model_name!r} is not registered for drift monitoring"
                raise KeyError(msg)
            config, data_fn, _ = entry

        # Runs outside the lock; _execute_check re-acquires it briefly
        # to store the result.
        return self._execute_check(config, data_fn)

    # -- internal ------------------------------------------------------------

    def _run_loop(self) -> None:
        """Background loop — check each monitor when its interval elapses."""
        logger.debug("drift scheduler loop entered")
        while not self._stop_event.is_set():
            now = time.monotonic()
            with self._lock:
                snapshot = list(self._monitors.items())

            for name, (config, data_fn, last_check) in snapshot:
                if now - last_check >= config.check_interval_seconds:
                    try:
                        self._execute_check(config, data_fn)
                    except Exception:
                        logger.exception("drift check failed", model=name)
                    # Update last_check regardless of success/failure
                    with self._lock:
                        if name in self._monitors:
                            old = self._monitors[name]
                            self._monitors[name] = (old[0], old[1], time.monotonic())

            self._stop_event.wait(timeout=self._tick)

    def _execute_check(
        self,
        config: DriftMonitorConfig,
        data_fn: DataProvider,
    ) -> DriftCheckResult:
        """Run a single drift check and publish metrics.

        Called without ``self._lock`` held; acquires it only to store
        the result.
        """
        detector = DriftDetector(
            psi_threshold=config.psi_threshold,
            n_bins=config.n_bins,
        )

        current_data = data_fn()
        reports = detector.check_dataset(config.reference_data, current_data)

        drift_detected = any(r.drift_detected for r in reports)
        max_psi = max((r.psi for r in reports), default=0.0)

        result = DriftCheckResult(
            model_name=config.model_name,
            reports=reports,
            drift_detected=drift_detected,
            max_psi=max_psi,
        )

        # Publish to Prometheus
        for report in reports:
            model_drift_psi.labels(
                model=config.model_name,
                feature=report.feature_name,
            ).set(report.psi)

            if report.drift_detected:
                model_drift_alerts_total.labels(
                    model=config.model_name,
                    severity=report.severity,
                ).inc()

        # Fix: store under the lock for consistency with ``unregister``
        # and ``get_last_result``.
        with self._lock:
            self._results[config.model_name] = result

        if drift_detected:
            logger.warning(
                "drift detected",
                model=config.model_name,
                max_psi=round(max_psi, 4),
                features_drifted=sum(1 for r in reports if r.drift_detected),
                features_total=len(reports),
            )
        else:
            logger.info(
                "drift check ok",
                model=config.model_name,
                max_psi=round(max_psi, 4),
                features=len(reports),
            )

        return result

is_running property

Whether the scheduler thread is alive.

registered_models property

Names of all registered models.

register(config, data_fn)

Register a model for periodic drift monitoring.

Parameters

config: Monitor configuration (thresholds, interval, reference data). data_fn: Callable returning current feature data as dict[str, list[float]].

Raises

ValueError: If config.reference_data is empty.

Source code in src/dataenginex/ml/scheduler.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def register(
    self,
    config: DriftMonitorConfig,
    data_fn: DataProvider,
) -> None:
    """Register a model for periodic drift monitoring.

    Parameters
    ----------
    config:
        Monitor configuration (thresholds, interval, reference data).
    data_fn:
        Callable returning current feature data as
        ``dict[str, list[float]]``.

    Raises
    ------
    ValueError:
        If ``config.reference_data`` is empty.
    """
    if not config.reference_data:
        msg = f"reference_data must not be empty for model {config.model_name!r}"
        raise ValueError(msg)

    with self._lock:
        self._monitors[config.model_name] = (config, data_fn, 0.0)
    logger.info(
        "drift monitor registered",
        model=config.model_name,
        interval=config.check_interval_seconds,
        features=len(config.reference_data),
    )

unregister(model_name)

Remove a model from drift monitoring.

Parameters

model_name: Name of the model to unregister.

Raises

KeyError: If the model is not registered.

Source code in src/dataenginex/ml/scheduler.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def unregister(self, model_name: str) -> None:
    """Remove a model from drift monitoring.

    Parameters
    ----------
    model_name:
        Name of the model to unregister.

    Raises
    ------
    KeyError:
        If the model is not registered.
    """
    with self._lock:
        if model_name not in self._monitors:
            msg = f"Model {model_name!r} is not registered for drift monitoring"
            raise KeyError(msg)
        del self._monitors[model_name]
        self._results.pop(model_name, None)
    logger.info("drift monitor unregistered", model=model_name)

start()

Start the background monitoring thread.

Raises

RuntimeError: If the scheduler is already running.

Source code in src/dataenginex/ml/scheduler.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def start(self) -> None:
    """Start the background monitoring thread.

    Raises
    ------
    RuntimeError:
        If the scheduler is already running.
    """
    if self._thread is not None and self._thread.is_alive():
        msg = "DriftScheduler is already running"
        raise RuntimeError(msg)

    self._stop_event.clear()
    self._thread = threading.Thread(
        target=self._run_loop,
        name="drift-scheduler",
        daemon=True,
    )
    self._thread.start()
    logger.info("drift scheduler started", tick=self._tick)

stop(timeout=10.0)

Stop the background monitoring thread.

Parameters

timeout: Seconds to wait for the thread to join (default 10.0).

Source code in src/dataenginex/ml/scheduler.py
211
212
213
214
215
216
217
218
219
220
221
222
223
def stop(self, timeout: float = 10.0) -> None:
    """Stop the background monitoring thread.

    Parameters
    ----------
    timeout:
        Seconds to wait for the thread to join (default ``10.0``).
    """
    self._stop_event.set()
    if self._thread is not None:
        self._thread.join(timeout=timeout)
        self._thread = None
    logger.info("drift scheduler stopped")

get_last_result(model_name)

Return the most recent drift check result for a model.

Source code in src/dataenginex/ml/scheduler.py
236
237
238
def get_last_result(self, model_name: str) -> DriftCheckResult | None:
    """Return the most recent drift check result for a model."""
    return self._results.get(model_name)

run_check(model_name)

Manually trigger a drift check for one model.

Parameters

model_name: Name of a registered model to check.

Raises

KeyError: If the model is not registered.

Returns

DriftCheckResult: Aggregated result with per-feature reports.

Source code in src/dataenginex/ml/scheduler.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def run_check(self, model_name: str) -> DriftCheckResult:
    """Manually trigger a drift check for one model.

    Parameters
    ----------
    model_name:
        Name of a registered model to check.

    Raises
    ------
    KeyError:
        If the model is not registered.

    Returns
    -------
    DriftCheckResult:
        Aggregated result with per-feature reports.
    """
    with self._lock:
        entry = self._monitors.get(model_name)
        if entry is None:
            msg = f"Model {model_name!r} is not registered for drift monitoring"
            raise KeyError(msg)
        config, data_fn, _ = entry

    return self._execute_check(config, data_fn)

ModelServer

Registry-aware model server.

Loads a model from the ModelRegistry and serves predictions via the predict method.

Parameters

registry: A ModelRegistry instance (from dataenginex.ml.registry).

Source code in src/dataenginex/ml/serving.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class ModelServer:
    """Registry-aware model server.

    Loads a model from the ``ModelRegistry`` and serves predictions via
    the ``predict`` method.

    Parameters
    ----------
    registry:
        A ``ModelRegistry`` instance (from ``dataenginex.ml.registry``).
        May be ``None``; version resolution then falls back to the most
        recently loaded version.
    """

    def __init__(self, registry: Any = None) -> None:
        self._registry = registry
        self._loaded: dict[str, Any] = {}  # "name:version" → model object

    def load_model(self, name: str, version: str, model: Any) -> None:
        """Register a model object for serving.

        Parameters
        ----------
        name:
            Model name matching registry entries.
        version:
            Model version.
        model:
            Any object with a ``predict(X)`` method.
        """
        key = f"{name}:{version}"
        self._loaded[key] = model
        logger.info("loaded model for serving", key=key)

    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Run inference for *request*.

        Resolves the version (production model when ``request.version`` is
        ``None``), runs the loaded model, and wraps the output with latency
        and tracing metadata.

        Raises
        ------
        RuntimeError:
            If the resolved model/version has not been loaded.
        """
        start = time.perf_counter()

        version = request.version or self._resolve_production_version(request.model_name)
        key = f"{request.model_name}:{version}"

        model = self._loaded.get(key)
        if model is None:
            raise RuntimeError(
                f"Model {request.model_name} v{version} not loaded — call load_model() first"
            )

        # Convert features to the format the model expects
        feature_values = self._features_to_array(request.features)
        raw_predictions = model.predict(feature_values)

        latency = (time.perf_counter() - start) * 1000

        # Numpy-style outputs expose tolist(); anything else is coerced
        # through list().
        predictions: list[Any]
        if hasattr(raw_predictions, "tolist"):
            predictions = raw_predictions.tolist()
        else:
            predictions = list(raw_predictions)

        return PredictionResponse(
            model_name=request.model_name,
            version=version,
            predictions=predictions,
            latency_ms=latency,
            request_id=request.request_id,
        )

    def list_loaded(self) -> list[str]:
        """Return keys of all loaded models."""
        return list(self._loaded.keys())

    # -- helpers -------------------------------------------------------------

    def _resolve_production_version(self, name: str) -> str:
        """Resolve which version to serve when the caller did not pin one.

        Prefers the registry's production model; otherwise falls back to
        the most recently loaded version of *name*.

        Raises
        ------
        RuntimeError:
            If no version can be resolved.
        """
        if self._registry is not None:
            prod = self._registry.get_production(name)
            if prod:
                return prod.version  # type: ignore[no-any-return]
        # Fallback: use latest loaded
        loaded_keys: list[str] = list(self._loaded.keys())
        for key in reversed(loaded_keys):
            if key.startswith(f"{name}:"):
                # Fix: split on the first ':' only, so version strings that
                # themselves contain ':' are returned intact.
                version: str = key.split(":", 1)[1]
                return version
        raise RuntimeError(f"No version found for model {name!r}")

    @staticmethod
    def _features_to_array(features: list[dict[str, Any]]) -> list[list[Any]]:
        """Convert list-of-dicts to a 2D list suitable for sklearn ``predict``.

        Column order comes from the first row; rows missing a key get
        ``None`` in that position.
        """
        if not features:
            return []
        keys = list(features[0].keys())
        return [[row.get(k) for k in keys] for row in features]

load_model(name, version, model)

Register a model object for serving.

Parameters

name: Model name matching registry entries. version: Model version. model: Any object with a predict(X) method.

Source code in src/dataenginex/ml/serving.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def load_model(self, name: str, version: str, model: Any) -> None:
    """Make *model* available for serving under ``name``/``version``.

    Parameters
    ----------
    name:
        Model name matching registry entries.
    version:
        Model version.
    model:
        Any object with a ``predict(X)`` method.
    """
    serving_key = ":".join((name, version))
    self._loaded[serving_key] = model
    logger.info("loaded model for serving", key=serving_key)

predict(request)

Run inference for request.

Source code in src/dataenginex/ml/serving.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def predict(self, request: PredictionRequest) -> PredictionResponse:
    """Run inference for *request*."""
    start = time.perf_counter()

    version = request.version or self._resolve_production_version(request.model_name)
    key = f"{request.model_name}:{version}"

    model = self._loaded.get(key)
    if model is None:
        raise RuntimeError(
            f"Model {request.model_name} v{version} not loaded — call load_model() first"
        )

    # Convert features to the format the model expects
    feature_values = self._features_to_array(request.features)
    raw_predictions = model.predict(feature_values)

    latency = (time.perf_counter() - start) * 1000

    predictions: list[Any]
    if hasattr(raw_predictions, "tolist"):
        predictions = raw_predictions.tolist()
    else:
        predictions = list(raw_predictions)

    return PredictionResponse(
        model_name=request.model_name,
        version=version,
        predictions=predictions,
        latency_ms=latency,
        request_id=request.request_id,
    )

list_loaded()

Return keys of all loaded models.

Source code in src/dataenginex/ml/serving.py
139
140
141
def list_loaded(self) -> list[str]:
    """Return the ``"name:version"`` keys of every model loaded for serving."""
    return [*self._loaded]

PredictionRequest dataclass

Input to the serving layer.

Attributes:

Name Type Description
model_name str

Name of the model to invoke.

version str | None

Model version (None resolves to the production version).

features list[dict[str, Any]]

List of feature dicts — each dict is one sample.

request_id str

Caller-provided request ID for tracing.

Source code in src/dataenginex/ml/serving.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@dataclass
class PredictionRequest:
    """Input to the serving layer.

    Attributes:
        model_name: Name of the model to invoke.
        version: Model version (``None`` resolves to the production version,
            falling back to the most recently loaded one).
        features: List of feature dicts — each dict is one sample.
        request_id: Caller-provided request ID for tracing (echoed back
            in the response).
    """

    model_name: str
    version: str | None = None  # None → resolved by the server (production model)
    features: list[dict[str, Any]] = field(default_factory=list)
    request_id: str = ""

PredictionResponse dataclass

Output from the serving layer.

Attributes:

Name Type Description
model_name str

Name of the model that produced predictions.

version str

Version of the model used.

predictions list[Any]

List of prediction values.

latency_ms float

Inference latency in milliseconds.

request_id str

Echoed request ID for tracing.

served_at datetime

Timestamp of the prediction.

Source code in src/dataenginex/ml/serving.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@dataclass
class PredictionResponse:
    """Output from the serving layer.

    Attributes:
        model_name: Name of the model that produced predictions.
        version: Version of the model used.
        predictions: List of prediction values.
        latency_ms: Inference latency in milliseconds.
        request_id: Echoed request ID for tracing.
        served_at: Timestamp of the prediction (timezone-aware UTC by default).
    """

    model_name: str
    version: str
    predictions: list[Any] = field(default_factory=list)
    latency_ms: float = 0.0
    request_id: str = ""
    # Default captures the construction time in UTC.
    served_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the prediction response to a plain dictionary.

        Latency is rounded to 2 decimal places; the timestamp goes out
        as an ISO-8601 string.
        """
        return {
            "model_name": self.model_name,
            "version": self.version,
            "predictions": self.predictions,
            "latency_ms": round(self.latency_ms, 2),
            "request_id": self.request_id,
            "served_at": self.served_at.isoformat(),
        }

to_dict()

Serialize the prediction response to a plain dictionary.

Source code in src/dataenginex/ml/serving.py
62
63
64
65
66
67
68
69
70
71
def to_dict(self) -> dict[str, Any]:
    """Serialize the prediction response to a plain dictionary."""
    serialized: dict[str, Any] = {
        "model_name": self.model_name,
        "version": self.version,
        "predictions": self.predictions,
    }
    # Latency is rounded for readability; timestamp goes out as ISO-8601.
    serialized["latency_ms"] = round(self.latency_ms, 2)
    serialized["request_id"] = self.request_id
    serialized["served_at"] = self.served_at.isoformat()
    return serialized

BaseTrainer

Bases: ABC

Abstract base class for model trainers.

Source code in src/dataenginex/ml/training.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class BaseTrainer(ABC):
    """Abstract base class for model trainers.

    Subclasses implement the full train → evaluate → predict → save/load
    lifecycle (see ``SklearnTrainer`` for a concrete implementation).

    Parameters
    ----------
    model_name:
        Name used to identify the model (e.g. in the model registry).
    version:
        Semantic version string (default ``"1.0.0"``).
    """

    def __init__(self, model_name: str, version: str = "1.0.0") -> None:
        self.model_name = model_name
        self.version = version

    @abstractmethod
    def train(
        self,
        X_train: Any,
        y_train: Any,
        **params: Any,  # noqa: N803
    ) -> TrainingResult:
        """Train the model on *X_train*/*y_train* and return metrics."""
        ...

    @abstractmethod
    def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
        """Evaluate the model on test data and return named metric values."""
        ...

    @abstractmethod
    def predict(self, X: Any) -> Any:  # noqa: N803
        """Generate predictions for *X*."""
        ...

    @abstractmethod
    def save(self, path: str) -> str:
        """Persist the model to *path* and return the artifact path."""
        ...

    @abstractmethod
    def load(
        self,
        path: str,
        *,
        extra_modules: frozenset[str] | None = None,
    ) -> None:
        """Load a previously saved model from *path*.

        ``extra_modules`` semantics are implementation-defined — presumably
        an allow-list for deserialization; confirm against subclasses.
        """
        ...

train(X_train, y_train, **params) abstractmethod

Train the model and return metrics.

Source code in src/dataenginex/ml/training.py
131
132
133
134
135
136
137
138
139
@abstractmethod
def train(
    self,
    X_train: Any,
    y_train: Any,
    **params: Any,  # noqa: N803
) -> TrainingResult:
    """Train the model and return metrics."""
    ...

evaluate(X_test, y_test) abstractmethod

Evaluate the model on test data.

Source code in src/dataenginex/ml/training.py
141
142
143
144
@abstractmethod
def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
    """Evaluate the model on test data."""
    ...

predict(X) abstractmethod

Generate predictions.

Source code in src/dataenginex/ml/training.py
146
147
148
149
@abstractmethod
def predict(self, X: Any) -> Any:  # noqa: N803
    """Generate predictions."""
    ...

save(path) abstractmethod

Persist the model to path and return the artifact path.

Source code in src/dataenginex/ml/training.py
151
152
153
154
@abstractmethod
def save(self, path: str) -> str:
    """Persist the model to *path* and return the artifact path."""
    ...

load(path, *, extra_modules=None) abstractmethod

Load a previously saved model from path.

Source code in src/dataenginex/ml/training.py
156
157
158
159
160
161
162
163
164
@abstractmethod
def load(
    self,
    path: str,
    *,
    extra_modules: frozenset[str] | None = None,
) -> None:
    """Load a previously saved model from *path*."""
    ...

SklearnTrainer

Bases: BaseTrainer

scikit-learn model trainer.

Works with any sklearn estimator (or pipeline) that implements fit, predict, and score.

Parameters

model_name: Name used in model registry. version: Semantic version string. estimator: An sklearn estimator instance (e.g. RandomForestClassifier()).

Source code in src/dataenginex/ml/training.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
class SklearnTrainer(BaseTrainer):
    """scikit-learn model trainer.

    Works with any sklearn estimator (or pipeline) that implements
    ``fit``, ``predict``, and ``score``.

    Parameters
    ----------
    model_name:
        Name used in model registry.
    version:
        Semantic version string.
    estimator:
        An sklearn estimator instance (e.g. ``RandomForestClassifier()``).
    """

    def __init__(
        self,
        model_name: str,
        version: str = "1.0.0",
        estimator: Any = None,
    ) -> None:
        super().__init__(model_name, version)
        self.estimator = estimator
        # Guards evaluate/predict/save — flipped to True by train() or load().
        self._is_fitted = False

    def train(
        self,
        X_train: Any,
        y_train: Any,
        **params: Any,  # noqa: N803
    ) -> TrainingResult:
        """Fit the estimator on *X_train*/*y_train* and return metrics."""
        if self.estimator is None:
            raise RuntimeError("No estimator provided to SklearnTrainer")

        # Apply params
        # Hyper-parameters are pushed into the estimator before fitting,
        # so they are reflected in get_params() below.
        if params:
            self.estimator.set_params(**params)

        start = time.perf_counter()
        self.estimator.fit(X_train, y_train)
        duration = time.perf_counter() - start
        self._is_fitted = True

        # Compute training score
        # NOTE: score() on the training set — a fit-quality signal, not a
        # generalization estimate; use evaluate() for held-out metrics.
        train_score = float(self.estimator.score(X_train, y_train))
        metrics = {"train_score": round(train_score, 4)}

        logger.info(
            "Trained %s v%s in %.2fs (train_score=%.4f)",
            self.model_name,
            self.version,
            duration,
            train_score,
        )

        return TrainingResult(
            model_name=self.model_name,
            version=self.version,
            metrics=metrics,
            parameters=self.estimator.get_params(),
            duration_seconds=duration,
        )

    def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
        """Score the fitted model on *X_test*/*y_test* and return metrics."""
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")

        test_score = float(self.estimator.score(X_test, y_test))
        predictions = self.estimator.predict(X_test)

        metrics: dict[str, float] = {"test_score": round(test_score, 4)}

        # Attempt classification metrics
        # Best-effort: precision/recall/f1 are added only when
        # sklearn.metrics imports cleanly; regression callers still get
        # test_score.  NOTE(review): these calls assume classification
        # targets — on regression data they would raise, not ImportError;
        # confirm callers only hit this path with classifiers.
        try:
            from sklearn.metrics import (  # type: ignore[import-untyped,import-not-found]
                f1_score,
                precision_score,
                recall_score,
            )

            metrics["precision"] = round(
                float(
                    precision_score(
                        y_test,
                        predictions,
                        average="weighted",
                        zero_division=0,
                    ),
                ),
                4,
            )
            metrics["recall"] = round(
                float(
                    recall_score(
                        y_test,
                        predictions,
                        average="weighted",
                        zero_division=0,
                    ),
                ),
                4,
            )
            metrics["f1"] = round(
                float(
                    f1_score(
                        y_test,
                        predictions,
                        average="weighted",
                        zero_division=0,
                    ),
                ),
                4,
            )
        except ImportError:
            logger.debug(
                "sklearn.metrics not available — skipping precision/recall/f1",
            )

        return metrics

    def predict(self, X: Any) -> Any:  # noqa: N803
        """Generate predictions for *X* using the fitted estimator."""
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")
        return self.estimator.predict(X)

    def save(self, path: str) -> str:
        """Pickle the fitted model, write an HMAC signature, and metadata."""
        if not self._is_fitted:
            raise RuntimeError("Model not yet trained")

        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)

        model_bytes = pickle.dumps(self.estimator)
        out.write_bytes(model_bytes)

        # HMAC sidecar — verifies integrity on load
        sig_path = out.with_suffix(".sig")
        sig_path.write_text(_hmac_sign(model_bytes))

        # Save metadata alongside
        # Sidecars share the stem: model.pkl → model.sig / model.json.
        meta = out.with_suffix(".json")
        meta.write_text(
            json.dumps(
                {
                    "model_name": self.model_name,
                    "version": self.version,
                    "saved_at": datetime.now(tz=UTC).isoformat(),
                }
            )
        )

        # NOTE(review): keyword-style log call here vs %-style in train() —
        # assumes a structlog-like logger that accepts both; confirm.
        logger.info("model saved", name=self.model_name, path=str(out))
        return str(out)

    def load(
        self,
        path: str,
        *,
        extra_modules: frozenset[str] | None = None,
    ) -> None:
        """Load a pickled model with HMAC verification and safe unpickling.

        Args:
            path: Filesystem path to the ``.pkl`` artifact.
            extra_modules: Additional top-level module names to allow
                during unpickling (e.g. ``frozenset({"tests"})`` for
                test-only estimators).
        """
        artifact = Path(path)
        data = artifact.read_bytes()

        # Verify HMAC signature if sidecar exists
        # NOTE(review): a missing .sig only warns — an attacker who can
        # replace the .pkl can also delete the sidecar; confirm this
        # best-effort policy is intentional.
        sig_path = artifact.with_suffix(".sig")
        if sig_path.exists():
            expected = sig_path.read_text().strip()
            if not _hmac_verify(data, expected):
                msg = (
                    f"HMAC verification failed for {path}. "
                    "The model file may have been tampered with."
                )
                raise ValueError(msg)
        else:
            logger.warning(
                "No .sig sidecar for %s — skipping HMAC check",
                path,
            )

        # Safe unpickle — restricted to sklearn/numpy namespaces
        self.estimator = _SafeUnpickler(
            io.BytesIO(data),
            extra_modules=extra_modules,
        ).load()
        self._is_fitted = True
        logger.info("model loaded", name=self.model_name, path=str(path))

train(X_train, y_train, **params)

Fit the estimator on X_train/y_train and return metrics.

Source code in src/dataenginex/ml/training.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def train(
    self,
    X_train: Any,
    y_train: Any,
    **params: Any,  # noqa: N803
) -> TrainingResult:
    """Fit the estimator on *X_train*/*y_train* and return metrics."""
    if self.estimator is None:
        raise RuntimeError("No estimator provided to SklearnTrainer")

    # Apply params
    if params:
        self.estimator.set_params(**params)

    start = time.perf_counter()
    self.estimator.fit(X_train, y_train)
    duration = time.perf_counter() - start
    self._is_fitted = True

    # Compute training score
    train_score = float(self.estimator.score(X_train, y_train))
    metrics = {"train_score": round(train_score, 4)}

    logger.info(
        "Trained %s v%s in %.2fs (train_score=%.4f)",
        self.model_name,
        self.version,
        duration,
        train_score,
    )

    return TrainingResult(
        model_name=self.model_name,
        version=self.version,
        metrics=metrics,
        parameters=self.estimator.get_params(),
        duration_seconds=duration,
    )

evaluate(X_test, y_test)

Score the fitted model on X_test/y_test and return metrics.

Source code in src/dataenginex/ml/training.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
def evaluate(self, X_test: Any, y_test: Any) -> dict[str, float]:  # noqa: N803
    """Score the fitted model on *X_test*/*y_test* and return metrics."""
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")

    test_score = float(self.estimator.score(X_test, y_test))
    predictions = self.estimator.predict(X_test)

    metrics: dict[str, float] = {"test_score": round(test_score, 4)}

    # Attempt classification metrics
    try:
        from sklearn.metrics import (  # type: ignore[import-untyped,import-not-found]
            f1_score,
            precision_score,
            recall_score,
        )

        metrics["precision"] = round(
            float(
                precision_score(
                    y_test,
                    predictions,
                    average="weighted",
                    zero_division=0,
                ),
            ),
            4,
        )
        metrics["recall"] = round(
            float(
                recall_score(
                    y_test,
                    predictions,
                    average="weighted",
                    zero_division=0,
                ),
            ),
            4,
        )
        metrics["f1"] = round(
            float(
                f1_score(
                    y_test,
                    predictions,
                    average="weighted",
                    zero_division=0,
                ),
            ),
            4,
        )
    except ImportError:
        logger.debug(
            "sklearn.metrics not available — skipping precision/recall/f1",
        )

    return metrics

predict(X)

Generate predictions for X using the fitted estimator.

Source code in src/dataenginex/ml/training.py
290
291
292
293
294
def predict(self, X: Any) -> Any:  # noqa: N803
    """Generate predictions for *X* using the fitted estimator."""
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")
    return self.estimator.predict(X)

save(path)

Pickle the fitted model, write an HMAC signature, and metadata.

Source code in src/dataenginex/ml/training.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def save(self, path: str) -> str:
    """Pickle the fitted model, write an HMAC signature, and metadata."""
    if not self._is_fitted:
        raise RuntimeError("Model not yet trained")

    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)

    model_bytes = pickle.dumps(self.estimator)
    out.write_bytes(model_bytes)

    # HMAC sidecar — verifies integrity on load
    sig_path = out.with_suffix(".sig")
    sig_path.write_text(_hmac_sign(model_bytes))

    # Save metadata alongside
    meta = out.with_suffix(".json")
    meta.write_text(
        json.dumps(
            {
                "model_name": self.model_name,
                "version": self.version,
                "saved_at": datetime.now(tz=UTC).isoformat(),
            }
        )
    )

    logger.info("model saved", name=self.model_name, path=str(out))
    return str(out)

load(path, *, extra_modules=None)

Load a pickled model with HMAC verification and safe unpickling.

Parameters:

Name Type Description Default
path str

Filesystem path to the .pkl artifact.

required
extra_modules frozenset[str] | None

Additional top-level module names to allow during unpickling (e.g. frozenset({"tests"}) for test-only estimators).

None
Source code in src/dataenginex/ml/training.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def load(
    self,
    path: str,
    *,
    extra_modules: frozenset[str] | None = None,
) -> None:
    """Load a pickled model with HMAC verification and safe unpickling.

    Args:
        path: Filesystem path to the ``.pkl`` artifact.
        extra_modules: Additional top-level module names to allow
            during unpickling (e.g. ``frozenset({"tests"})`` for
            test-only estimators).
    """
    artifact = Path(path)
    data = artifact.read_bytes()

    # Verify HMAC signature if sidecar exists
    sig_path = artifact.with_suffix(".sig")
    if sig_path.exists():
        expected = sig_path.read_text().strip()
        if not _hmac_verify(data, expected):
            msg = (
                f"HMAC verification failed for {path}. "
                "The model file may have been tampered with."
            )
            raise ValueError(msg)
    else:
        logger.warning(
            "No .sig sidecar for %s — skipping HMAC check",
            path,
        )

    # Safe unpickle — restricted to sklearn/numpy namespaces
    self.estimator = _SafeUnpickler(
        io.BytesIO(data),
        extra_modules=extra_modules,
    ).load()
    self._is_fitted = True
    logger.info("model loaded", name=self.model_name, path=str(path))

TrainingResult dataclass

Outcome of a model training run.

Attributes:

Name Type Description
model_name str

Name of the trained model.

version str

Semantic version of this training run.

metrics dict[str, float]

Training metrics (e.g. {"train_score": 0.95}).

parameters dict[str, Any]

Hyper-parameters used for training.

duration_seconds float

Wall-clock training time.

artifact_path str | None

Path where the model artifact is saved.

trained_at datetime

Timestamp of training completion.

Source code in src/dataenginex/ml/training.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
@dataclass
class TrainingResult:
    """Outcome of a model training run.

    Attributes:
        model_name: Name of the trained model.
        version: Semantic version of this training run.
        metrics: Training metrics (e.g. ``{"train_score": 0.95}``).
        parameters: Hyper-parameters used for training.
        duration_seconds: Wall-clock training time.
        artifact_path: Path where the model artifact is saved.
        trained_at: Timestamp of training completion.
    """

    model_name: str
    version: str
    metrics: dict[str, float] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    duration_seconds: float = 0.0
    artifact_path: str | None = None
    trained_at: datetime = field(default_factory=lambda: datetime.now(tz=UTC))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the training result to a plain dictionary."""
        return {
            "model_name": self.model_name,
            "version": self.version,
            "metrics": self.metrics,
            "parameters": self.parameters,
            "duration_seconds": round(self.duration_seconds, 2),
            "artifact_path": self.artifact_path,
            "trained_at": self.trained_at.isoformat(),
        }

to_dict()

Serialize the training result to a plain dictionary.

Source code in src/dataenginex/ml/training.py
111
112
113
114
115
116
117
118
119
120
121
def to_dict(self) -> dict[str, Any]:
    """Serialize the training result to a plain dictionary."""
    return {
        "model_name": self.model_name,
        "version": self.version,
        "metrics": self.metrics,
        "parameters": self.parameters,
        "duration_seconds": round(self.duration_seconds, 2),
        "artifact_path": self.artifact_path,
        "trained_at": self.trained_at.isoformat(),
    }

ChromaDBBackend

Bases: VectorStoreBackend

ChromaDB-backed vector store (optional dependency).

Falls back to InMemoryBackend if chromadb is not installed.

Parameters:

Name Type Description Default
collection_name str

ChromaDB collection name.

'dex_documents'
persist_directory str | None

Path for local persistence (None = in-memory).

None
dimension int

Embedding dimension hint.

384
Source code in src/dataenginex/ml/vectorstore.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
class ChromaDBBackend(VectorStoreBackend):
    """ChromaDB-backed vector store (optional dependency).

    Falls back to :class:`InMemoryBackend` if ``chromadb`` is not
    installed.

    Args:
        collection_name: ChromaDB collection name.
        persist_directory: Path for local persistence (``None`` = in-memory).
        dimension: Embedding dimension hint.
    """

    def __init__(
        self,
        collection_name: str = "dex_documents",
        persist_directory: str | None = None,
        dimension: int = 384,
    ) -> None:
        self.collection_name = collection_name
        self.dimension = dimension
        self._client: Any = None
        self._collection: Any = None
        # Non-None only when chromadb import failed; every public method
        # checks this first and delegates.
        self._fallback: InMemoryBackend | None = None

        try:
            import chromadb  # type: ignore[import-not-found]

            if persist_directory:
                self._client = chromadb.PersistentClient(path=persist_directory)
            else:
                self._client = chromadb.Client()
            self._collection = self._client.get_or_create_collection(collection_name)
            # NOTE(review): {}-style log call here vs keyword-style below —
            # assumes a loguru/structlog-like logger; confirm it formats both.
            logger.info(
                "ChromaDB backend ready collection={} persist={}",
                collection_name,
                persist_directory,
            )
        except ImportError:
            logger.warning("chromadb not installed — falling back to InMemoryBackend")
            self._fallback = InMemoryBackend(dimension=dimension)

    def upsert(self, documents: list[Document]) -> int:
        # Insert or update documents in the collection; returns how many
        # IDs were sent.
        if self._fallback:
            return self._fallback.upsert(documents)

        ids = [d.id for d in documents]
        embeddings = [d.embedding for d in documents if d.embedding]
        texts = [d.text for d in documents]
        metadatas = [d.metadata for d in documents]

        # Embeddings are only forwarded when EVERY document has one.
        # NOTE(review): a partially-embedded batch silently drops ALL
        # client-side embeddings and takes the else-branch — confirm this
        # (letting chromadb embed server-side) is intended.
        if embeddings and len(embeddings) == len(ids):
            self._collection.upsert(
                ids=ids,
                embeddings=embeddings,
                documents=texts,
                metadatas=metadatas,
            )
        else:
            self._collection.upsert(
                ids=ids,
                documents=texts,
                metadatas=metadatas,
            )
        logger.info("chromadb upserted", count=len(ids))
        return len(ids)

    def query(
        self,
        embedding: list[float],
        top_k: int = 10,
        filter_metadata: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        # Nearest-neighbour search; returns up to top_k SearchResults.
        if self._fallback:
            return self._fallback.query(embedding, top_k, filter_metadata)

        # Clamp n_results to the collection size (chromadb errors when
        # asked for more results than stored); `or 1` guards the empty case.
        kwargs: dict[str, Any] = {
            "query_embeddings": [embedding],
            "n_results": min(top_k, self._collection.count() or 1),
        }
        if filter_metadata:
            kwargs["where"] = filter_metadata

        results = self._collection.query(**kwargs)
        hits: list[SearchResult] = []
        # chromadb returns column-parallel lists, one inner list per query;
        # we sent a single query, hence the [0] indexing throughout.
        if results and results.get("ids"):
            for i, doc_id in enumerate(results["ids"][0]):
                dist = results.get("distances", [[]])[0][i] if results.get("distances") else 0.0
                text = results.get("documents", [[]])[0][i] if results.get("documents") else ""
                meta = results.get("metadatas", [[]])[0][i] if results.get("metadatas") else {}
                hits.append(
                    SearchResult(
                        document=Document(id=doc_id, text=text, metadata=meta),
                        # NOTE(review): 1 - distance assumes a distance in
                        # [0, 1] (e.g. cosine); L2 distances would yield
                        # negative scores — confirm the collection's metric.
                        score=1.0 - dist,  # ChromaDB returns distance, convert to similarity
                    )
                )
        return hits

    def delete(self, ids: list[str]) -> int:
        # Returns len(ids) regardless of how many actually existed.
        if self._fallback:
            return self._fallback.delete(ids)
        self._collection.delete(ids=ids)
        return len(ids)

    def count(self) -> int:
        # Number of documents currently stored.
        if self._fallback:
            return self._fallback.count()
        return int(self._collection.count())

    def clear(self) -> None:
        if self._fallback:
            self._fallback.clear()
            return
        # Re-create collection
        self._client.delete_collection(self.collection_name)
        self._collection = self._client.get_or_create_collection(self.collection_name)

    def get(self, doc_id: str) -> Document | None:
        # Fetch one document by ID; None when absent.  Embedding is not
        # reconstructed — only id/text/metadata come back.
        if self._fallback:
            return self._fallback.get(doc_id)
        result = self._collection.get(ids=[doc_id])
        if result and result.get("ids") and result["ids"]:
            text = result.get("documents", [""])[0] if result.get("documents") else ""
            meta = result.get("metadatas", [{}])[0] if result.get("metadatas") else {}
            return Document(id=doc_id, text=text, metadata=meta)
        return None

Document dataclass

A text document with optional metadata and embedding.

Source code in src/dataenginex/ml/vectorstore.py
54
55
56
57
58
59
60
61
@dataclass
class Document:
    """A text document with optional metadata and embedding."""

    # Random 16-hex-char ID by default (truncated uuid4).
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])
    text: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    # Empty list means "not embedded yet" — backends treat it as absent.
    embedding: list[float] = field(default_factory=list)

InMemoryBackend

Bases: VectorStoreBackend

Brute-force in-memory vector store (testing & prototyping).

Stores all documents in a dict. Queries iterate over all stored vectors and compute cosine similarity.

Parameters:

Name Type Description Default
dimension int

Expected embedding dimension (for validation).

384
Source code in src/dataenginex/ml/vectorstore.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
class InMemoryBackend(VectorStoreBackend):
    """Brute-force in-memory vector store (testing & prototyping).

    Keeps every document in a plain dict keyed by ID.  A query walks
    all stored vectors and ranks them by cosine similarity.

    Args:
        dimension: Expected embedding dimension (for validation).
    """

    def __init__(self, dimension: int = 384) -> None:
        self.dimension = dimension
        self._docs: dict[str, Document] = {}

    def upsert(self, documents: list[Document]) -> int:
        """Insert or update documents."""
        stored = 0
        for document in documents:
            # Skip vectors whose length disagrees with the configured dimension.
            if document.embedding and len(document.embedding) != self.dimension:
                logger.warning(
                    "embedding dimension mismatch",
                    doc_id=document.id,
                    expected=self.dimension,
                    got=len(document.embedding),
                )
                continue
            self._docs[document.id] = document
            stored += 1
        logger.info("in-memory upserted", count=stored, total=len(self._docs))
        return stored

    def query(
        self,
        embedding: list[float],
        top_k: int = 10,
        filter_metadata: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Return top-k nearest documents by cosine similarity."""
        candidates = (
            doc
            for doc in self._docs.values()
            if doc.embedding
            and (
                not filter_metadata
                or self._matches_filter(doc.metadata, filter_metadata)
            )
        )
        ranked = sorted(
            (
                SearchResult(document=doc, score=self._cosine(embedding, doc.embedding))
                for doc in candidates
            ),
            key=lambda result: result.score,
            reverse=True,
        )
        return ranked[:top_k]

    def delete(self, ids: list[str]) -> int:
        """Remove the given IDs; return how many were actually present."""
        removed = 0
        for doc_id in ids:
            if self._docs.pop(doc_id, None) is not None:
                removed += 1
        return removed

    def count(self) -> int:
        """Number of stored documents."""
        return len(self._docs)

    def clear(self) -> None:
        """Drop every stored document."""
        self._docs.clear()

    def get(self, doc_id: str) -> Document | None:
        """Fetch a single document by ID, or ``None`` when absent."""
        return self._docs.get(doc_id)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cosine(a: list[float], b: list[float]) -> float:
        # Zero-magnitude vectors fall back to 1.0 to avoid division by zero.
        numerator = sum(x * y for x, y in zip(a, b, strict=False))
        norm_a = math.sqrt(sum(x * x for x in a)) or 1.0
        norm_b = math.sqrt(sum(y * y for y in b)) or 1.0
        return numerator / (norm_a * norm_b)

    @staticmethod
    def _matches_filter(
        metadata: dict[str, Any],
        filter_metadata: dict[str, Any],
    ) -> bool:
        # Exact-equality AND-match over every requested key.
        for key, expected in filter_metadata.items():
            if metadata.get(key) != expected:
                return False
        return True

upsert(documents)

Insert or update documents.

Source code in src/dataenginex/ml/vectorstore.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def upsert(self, documents: list[Document]) -> int:
    """Insert or update documents."""
    count = 0
    for doc in documents:
        if doc.embedding and len(doc.embedding) != self.dimension:
            logger.warning(
                "embedding dimension mismatch",
                doc_id=doc.id,
                expected=self.dimension,
                got=len(doc.embedding),
            )
            continue
        self._docs[doc.id] = doc
        count += 1
    logger.info("in-memory upserted", count=count, total=len(self._docs))
    return count

query(embedding, top_k=10, filter_metadata=None)

Return top-k nearest documents by cosine similarity.

Source code in src/dataenginex/ml/vectorstore.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def query(
    self,
    embedding: list[float],
    top_k: int = 10,
    filter_metadata: dict[str, Any] | None = None,
) -> list[SearchResult]:
    """Return top-k nearest documents by cosine similarity."""
    scored: list[SearchResult] = []
    for doc in self._docs.values():
        if not doc.embedding:
            continue
        if filter_metadata and not self._matches_filter(doc.metadata, filter_metadata):
            continue
        sim = self._cosine(embedding, doc.embedding)
        scored.append(SearchResult(document=doc, score=sim))

    scored.sort(key=lambda r: r.score, reverse=True)
    return scored[:top_k]

RAGPipeline

Retrieve-Augment-Generate pipeline orchestrator.

Combines a vector-store backend with an embedding provider to support document ingestion and semantic retrieval. When an LLM adapter is attached, the generate method augments the prompt with retrieved context.

Parameters:

Name Type Description Default
store VectorStoreBackend | None

Vector-store backend to use.

None
embed_fn Any | None

Callable that maps text → embedding vector. If None, uses a simple hash-based fallback.

None
dimension int

Embedding dimension.

384
Source code in src/dataenginex/ml/vectorstore.py
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
class RAGPipeline:
    """Retrieve-Augment-Generate pipeline orchestrator.

    Pairs a vector-store backend with an embedding provider so raw text
    can be ingested and later retrieved by semantic similarity.  When an
    LLM adapter is supplied, :meth:`answer` folds retrieved context into
    the prompt before generation.

    Args:
        store: Vector-store backend to use.
        embed_fn: Callable that maps text → embedding vector.
            If ``None``, uses a simple hash-based fallback.
        dimension: Embedding dimension.
    """

    def __init__(
        self,
        store: VectorStoreBackend | None = None,
        embed_fn: Any | None = None,
        dimension: int = 384,
    ) -> None:
        self.dimension = dimension
        # ``or`` deliberately treats any falsy store/embed_fn as "use default".
        self.store = store or InMemoryBackend(dimension=dimension)
        self._embed_fn = embed_fn or self._hash_embed

    def ingest(
        self,
        texts: list[str],
        metadata: list[dict[str, Any]] | None = None,
        ids: list[str] | None = None,
    ) -> int:
        """Embed and store a batch of texts.

        Args:
            texts: Raw text documents.
            metadata: Optional per-document metadata (one dict per text).
            ids: Optional document IDs (auto-generated if omitted).

        Returns:
            Number of documents stored.
        """
        per_doc_meta = metadata or [{} for _ in texts]
        assigned_ids = ids or [uuid.uuid4().hex[:16] for _ in texts]

        # strict=True surfaces any length mismatch between the three lists.
        batch = [
            Document(id=doc_id, text=text, metadata=meta, embedding=self._embed_fn(text))
            for doc_id, text, meta in zip(assigned_ids, texts, per_doc_meta, strict=True)
        ]
        stored = self.store.upsert(batch)
        logger.info("rag ingest complete", texts=len(texts), stored=stored)
        return stored

    def query(
        self,
        question: str,
        top_k: int = 5,
        filter_metadata: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Embed *question* and return its top-k matches from the store."""
        hits = self.store.query(
            self._embed_fn(question), top_k=top_k, filter_metadata=filter_metadata
        )
        logger.info("rag query complete", top_k=top_k, results=len(hits))
        return hits

    def build_context(
        self,
        question: str,
        top_k: int = 5,
        max_context_chars: int = 4000,
    ) -> str:
        """Build an LLM context string from retrieved documents.

        Args:
            question: User question.
            top_k: Number of documents to retrieve.
            max_context_chars: Maximum context length in characters.

        Returns:
            Formatted context string for LLM prompting.
        """
        snippets: list[str] = []
        used = 0
        for hit in self.query(question, top_k=top_k):
            snippet = f"[{hit.document.id}] {hit.document.text}"
            # Stop at the first snippet that would overflow the budget.
            if used + len(snippet) > max_context_chars:
                break
            snippets.append(snippet)
            used += len(snippet)
        return "\n\n".join(snippets)

    def answer(
        self,
        question: str,
        llm: LLMProvider,
        top_k: int = 5,
        max_context_chars: int = 4000,
        system_prompt: str | None = None,
    ) -> LLMResponse:
        """Full RAG loop: retrieve → augment → generate.

        Builds a context with :meth:`build_context`, then hands question
        and context to the provider's
        :meth:`~dataenginex.ml.llm.LLMProvider.generate_with_context`.

        Args:
            question: User question.
            llm: Any :class:`~dataenginex.ml.llm.LLMProvider` instance.
            top_k: Documents to retrieve.
            max_context_chars: Context length cap in characters.
            system_prompt: Optional system-prompt override for the LLM.

        Returns:
            :class:`~dataenginex.ml.llm.LLMResponse` from the provider.
        """
        retrieved = self.build_context(
            question, top_k=top_k, max_context_chars=max_context_chars
        )
        logger.info(
            "rag answer complete", question_len=len(question), context_len=len(retrieved)
        )
        return llm.generate_with_context(question, retrieved, system_prompt=system_prompt)

    # ------------------------------------------------------------------
    # Fallback embedding
    # ------------------------------------------------------------------

    def _hash_embed(self, text: str) -> list[float]:
        """Deterministic hash-based embedding for testing."""
        import hashlib

        digest = hashlib.sha256(text.encode()).hexdigest()
        # Each hex byte-pair becomes one component scaled into [0, 1].
        span = min(len(digest), self.dimension * 2)
        components = [int(digest[pos : pos + 2], 16) / 255.0 for pos in range(0, span, 2)]
        # Zero-pad, clip to the target dimension, then L2-normalise.
        components = (components + [0.0] * self.dimension)[: self.dimension]
        scale = math.sqrt(sum(c * c for c in components)) or 1.0
        return [c / scale for c in components]

ingest(texts, metadata=None, ids=None)

Embed and store a batch of texts.

Parameters:

Name Type Description Default
texts list[str]

Raw text documents.

required
metadata list[dict[str, Any]] | None

Optional per-document metadata.

None
ids list[str] | None

Optional document IDs (auto-generated if omitted).

None

Returns:

Type Description
int

Number of documents stored.

Source code in src/dataenginex/ml/vectorstore.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
def ingest(
    self,
    texts: list[str],
    metadata: list[dict[str, Any]] | None = None,
    ids: list[str] | None = None,
) -> int:
    """Embed and store a batch of texts.

    Args:
        texts: Raw text documents.
        metadata: Optional per-document metadata (one dict per text).
        ids: Optional document IDs (auto-generated if omitted).

    Returns:
        Number of documents stored.
    """
    per_doc_meta = metadata or [{} for _ in texts]
    assigned_ids = ids or [uuid.uuid4().hex[:16] for _ in texts]

    # strict=True surfaces any length mismatch between the three lists.
    batch = [
        Document(id=doc_id, text=text, metadata=meta, embedding=self._embed_fn(text))
        for doc_id, text, meta in zip(assigned_ids, texts, per_doc_meta, strict=True)
    ]
    stored = self.store.upsert(batch)
    logger.info("rag ingest complete", texts=len(texts), stored=stored)
    return stored

query(question, top_k=5, filter_metadata=None)

Retrieve top-k relevant documents for question.

Source code in src/dataenginex/ml/vectorstore.py
433
434
435
436
437
438
439
440
441
442
443
def query(
    self,
    question: str,
    top_k: int = 5,
    filter_metadata: dict[str, Any] | None = None,
) -> list[SearchResult]:
    """Embed *question* and return its top-k matches from the store."""
    hits = self.store.query(
        self._embed_fn(question), top_k=top_k, filter_metadata=filter_metadata
    )
    logger.info("rag query complete", top_k=top_k, results=len(hits))
    return hits

build_context(question, top_k=5, max_context_chars=4000)

Build an LLM context string from retrieved documents.

Parameters:

Name Type Description Default
question str

User question.

required
top_k int

Number of documents to retrieve.

5
max_context_chars int

Maximum context length in characters.

4000

Returns:

Type Description
str

Formatted context string for LLM prompting.

Source code in src/dataenginex/ml/vectorstore.py
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
def build_context(
    self,
    question: str,
    top_k: int = 5,
    max_context_chars: int = 4000,
) -> str:
    """Build an LLM context string from retrieved documents.

    Args:
        question: User question.
        top_k: Number of documents to retrieve.
        max_context_chars: Maximum context length in characters.

    Returns:
        Formatted context string for LLM prompting.
    """
    snippets: list[str] = []
    used = 0
    for hit in self.query(question, top_k=top_k):
        snippet = f"[{hit.document.id}] {hit.document.text}"
        # Stop at the first snippet that would overflow the budget.
        if used + len(snippet) > max_context_chars:
            break
        snippets.append(snippet)
        used += len(snippet)
    return "\n\n".join(snippets)

answer(question, llm, top_k=5, max_context_chars=4000, system_prompt=None)

Full RAG loop: retrieve → augment → generate.

Combines build_context with LLMProvider.generate_with_context into a single call.

Parameters:

Name Type Description Default
question str

User question.

required
llm LLMProvider

Any LLMProvider instance (see dataenginex.ml.llm.LLMProvider).

required
top_k int

Documents to retrieve.

5
max_context_chars int

Context length cap in characters.

4000
system_prompt str | None

Optional system-prompt override for the LLM.

None

Returns:

Type Description
LLMResponse

The LLMResponse returned by the provider (see dataenginex.ml.llm.LLMResponse).

Source code in src/dataenginex/ml/vectorstore.py
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def answer(
    self,
    question: str,
    llm: LLMProvider,
    top_k: int = 5,
    max_context_chars: int = 4000,
    system_prompt: str | None = None,
) -> LLMResponse:
    """Full RAG loop: retrieve → augment → generate.

    Builds a context with :meth:`build_context`, then hands question
    and context to the provider's
    :meth:`~dataenginex.ml.llm.LLMProvider.generate_with_context`.

    Args:
        question: User question.
        llm: Any :class:`~dataenginex.ml.llm.LLMProvider` instance.
        top_k: Documents to retrieve.
        max_context_chars: Context length cap in characters.
        system_prompt: Optional system-prompt override for the LLM.

    Returns:
        :class:`~dataenginex.ml.llm.LLMResponse` from the provider.
    """
    retrieved = self.build_context(
        question, top_k=top_k, max_context_chars=max_context_chars
    )
    logger.info(
        "rag answer complete", question_len=len(question), context_len=len(retrieved)
    )
    return llm.generate_with_context(question, retrieved, system_prompt=system_prompt)

SearchResult dataclass

Single search hit from a vector store query.

Source code in src/dataenginex/ml/vectorstore.py
64
65
66
67
68
69
@dataclass
class SearchResult:
    """Single search hit from a vector store query."""

    # The matched document (text, metadata, and embedding).
    document: Document
    # Similarity score — cosine similarity per the backend contract;
    # higher means more similar.
    score: float

VectorStoreBackend

Bases: ABC

Abstract vector-store backend.

All backends store fixed-dimension vectors keyed by string ID and support nearest-neighbour queries by cosine similarity.

Source code in src/dataenginex/ml/vectorstore.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class VectorStoreBackend(abc.ABC):
    """Abstract vector-store backend.

    All backends store fixed-dimension vectors keyed by string ID and
    support nearest-neighbour queries by cosine similarity.
    """

    @abc.abstractmethod
    def upsert(self, documents: list[Document]) -> int:
        """Insert or update documents by ID. Returns count upserted."""

    @abc.abstractmethod
    def query(
        self,
        embedding: list[float],
        top_k: int = 10,
        filter_metadata: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Return top-k nearest documents by cosine similarity.

        Args:
            embedding: Query vector (same dimension as stored vectors).
            top_k: Maximum number of results to return.
            filter_metadata: Optional metadata constraints; only
                matching documents are considered.
        """

    @abc.abstractmethod
    def delete(self, ids: list[str]) -> int:
        """Delete documents by id. Returns count deleted."""

    @abc.abstractmethod
    def count(self) -> int:
        """Return the number of documents currently in the store."""

    @abc.abstractmethod
    def clear(self) -> None:
        """Delete all documents from the store."""

    @abc.abstractmethod
    def get(self, doc_id: str) -> Document | None:
        """Retrieve a single document by ID, or ``None`` if absent."""

upsert(documents) abstractmethod

Insert or update documents. Returns count upserted.

Source code in src/dataenginex/ml/vectorstore.py
84
85
86
@abc.abstractmethod
def upsert(self, documents: list[Document]) -> int:
    """Insert or update documents by ID. Returns count upserted."""

query(embedding, top_k=10, filter_metadata=None) abstractmethod

Return top-k nearest documents by cosine similarity.

Source code in src/dataenginex/ml/vectorstore.py
88
89
90
91
92
93
94
95
@abc.abstractmethod
def query(
    self,
    embedding: list[float],
    top_k: int = 10,
    filter_metadata: dict[str, Any] | None = None,
) -> list[SearchResult]:
    """Return top-k nearest documents by cosine similarity.

    Args:
        embedding: Query vector (same dimension as stored vectors).
        top_k: Maximum number of results to return.
        filter_metadata: Optional metadata constraints; only matching
            documents are considered.
    """

delete(ids) abstractmethod

Delete documents by id. Returns count deleted.

Source code in src/dataenginex/ml/vectorstore.py
97
98
99
@abc.abstractmethod
def delete(self, ids: list[str]) -> int:
    """Delete documents by id. Returns the number deleted."""

count() abstractmethod

Number of documents in the store.

Source code in src/dataenginex/ml/vectorstore.py
101
102
103
@abc.abstractmethod
def count(self) -> int:
    """Return the number of documents currently in the store."""

clear() abstractmethod

Delete all documents.

Source code in src/dataenginex/ml/vectorstore.py
105
106
107
@abc.abstractmethod
def clear(self) -> None:
    """Delete all documents from the store."""

get(doc_id) abstractmethod

Retrieve a single document by ID.

Source code in src/dataenginex/ml/vectorstore.py
109
110
111
@abc.abstractmethod
def get(self, doc_id: str) -> Document | None:
    """Retrieve a single document by ID, or ``None`` if absent."""

get_llm_provider(provider, **kwargs)

Create an LLM provider by name.

Parameters:

Name Type Description Default
provider str

One of "ollama", "openai", "mock".

required
**kwargs Any

Passed directly to the provider constructor.

{}

Returns:

Type Description
LLMProvider

LLMProvider instance.

Raises:

Type Description
ValueError

If the provider name is unknown.

Example::

llm = get_llm_provider("ollama", model="llama3.1:8b")
llm = get_llm_provider("openai", api_key="sk-...", model="gpt-4o")
llm = get_llm_provider("mock")
Source code in src/dataenginex/ml/llm.py
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
def get_llm_provider(provider: str, **kwargs: Any) -> LLMProvider:
    """Create an LLM provider by name.

    Args:
        provider: One of ``"ollama"``, ``"openai"``, ``"mock"``
            (case-insensitive).
        **kwargs: Passed directly to the provider constructor.

    Returns:
        LLMProvider instance.

    Raises:
        ValueError: If the provider name is unknown.

    Example::

        llm = get_llm_provider("ollama", model="llama3.1:8b")
        llm = get_llm_provider("openai", api_key="sk-...", model="gpt-4o")
        llm = get_llm_provider("mock")
    """
    registry: dict[str, type[LLMProvider]] = {
        "ollama": OllamaProvider,
        "openai": OpenAICompatibleProvider,
        "mock": MockProvider,
    }
    factory = registry.get(provider.lower())
    if factory is None:
        valid = ", ".join(sorted(registry))
        raise ValueError(f"Unknown LLM provider '{provider}'. Valid: {valid}")
    return factory(**kwargs)