10 changes: 10 additions & 0 deletions daft/ai/provider.py
@@ -52,11 +52,21 @@ def load_transformers(name: str | None = None, **options: Any) -> Provider:
        raise ProviderImportError(["torch", "torchvision", "transformers", "Pillow"]) from e


def load_vllm(name: str | None = None, **options: Any) -> Provider:
    try:
        from daft.ai.vllm.provider import VLLMProvider

        return VLLMProvider(name, **options)
    except ImportError as e:
        raise ProviderImportError(["vllm"]) from e


PROVIDERS: dict[str, Callable[..., Provider]] = {
    "lm_studio": load_lm_studio,
    "openai": load_openai,
    "sentence_transformers": load_sentence_transformers,
    "transformers": load_transformers,
    "vllm": load_vllm,
}


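For orientation, here is a minimal sketch of how this registry could be consumed. The `load_provider` helper name and its error handling are assumptions for illustration; only the `PROVIDERS` mapping and `load_vllm` come from the diff above.

```python
from __future__ import annotations

# Hypothetical helper, not part of this PR: resolves a provider key from the
# PROVIDERS registry defined in daft/ai/provider.py above.
from daft.ai.provider import PROVIDERS, Provider


def load_provider(provider: str, name: str | None = None, **options) -> Provider:
    try:
        factory = PROVIDERS[provider]
    except KeyError:
        available = ", ".join(sorted(PROVIDERS))
        raise ValueError(f"Unknown provider {provider!r}; expected one of: {available}")
    # Each loader raises ProviderImportError if its optional dependency is missing,
    # e.g. load_vllm requires the vllm extra.
    return factory(name, **options)


# load_provider("vllm") would call load_vllm(None) and return a VLLMProvider.
```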
Empty file added daft/ai/vllm/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions daft/ai/vllm/protocols/text_embedder.py
@@ -0,0 +1,58 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoConfig
from vllm import LLM

from daft import DataType
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.typing import EmbeddingDimensions, Options

if TYPE_CHECKING:
    from daft.ai.typing import Embedding


@dataclass
class vLLMTextEmbedderDescriptor(TextEmbedderDescriptor):
    model: str
    options: Options

    def get_provider(self) -> str:
        return "vllm"

    def get_model(self) -> str:
        return self.model

    def get_options(self) -> Options:
        return self.options

    def get_dimensions(self) -> EmbeddingDimensions:
        dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
        return EmbeddingDimensions(size=dimensions, dtype=DataType.float32())

    def instantiate(self) -> TextEmbedder:
        return vLLMTextEmbedder(self.model, **self.options)


class vLLMTextEmbedder(TextEmbedder):
    model: LLM
    options: Options  # not currently used, torch hardcoded

    def __init__(self, model_name_or_path: str, **options: Any):
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        max_model_len = getattr(config, "n_ctx", None) or getattr(config, "max_position_embeddings", None)
        # Let vLLM automatically determine the optimal dtype to use based on the model config file.
        self.model = LLM(
            model=model_name_or_path,
            max_num_batched_tokens=max_model_len,
            task="embed",
        )
        self.options = options

    def embed_text(self, text: list[str]) -> list[Embedding]:
        outputs = self.model.embed(text)

Contributor:
That easy huh? 😄

Do you know if we need to do any batching on our own? I'm fine with just firing away to vLLM until we learn more, mostly asking out of curiosity.

Collaborator (Author):
vLLM does batching internally. To be fair, there's a lot more that can be done here, but I want to move on and support endpoints like Gemini (and maybe hack together something for ev).

        embeddings = torch.tensor([o.outputs.embedding for o in outputs])
        return embeddings.cpu().numpy()
Comment on lines +55 to +58

Contributor:
logic: Return type annotation says list[Embedding] but returns numpy array. Should return embeddings.cpu().numpy().tolist() or update annotation.

Suggested change
    def embed_text(self, text: list[str]) -> list[Embedding]:
        outputs = self.model.embed(text)
        embeddings = torch.tensor([o.outputs.embedding for o in outputs])
-       return embeddings.cpu().numpy()
+       return embeddings.cpu().numpy().tolist()

Contributor:
+1

The Embedding type is currently an alias to np.ndarray, but these protocols should return a list[np.ndarray].

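For clarity, a minimal sketch of the fix the reviewers converge on: return one vector per input text so that the annotation (a list of `np.ndarray`) holds. This is an illustration, not the change as ultimately merged.

```python
import numpy as np
import torch

# Sketch of a revised vLLMTextEmbedder.embed_text; `self.model` is the vLLM LLM instance.
def embed_text(self, text: list[str]) -> list[np.ndarray]:
    outputs = self.model.embed(text)  # one EmbeddingRequestOutput per input string
    embeddings = torch.tensor([o.outputs.embedding for o in outputs])
    # Splitting the (num_texts, hidden_size) matrix yields a list of per-text 1-D arrays.
    return list(embeddings.cpu().numpy())
```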
27 changes: 27 additions & 0 deletions daft/ai/vllm/provider.py
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from daft.ai.provider import Provider

if TYPE_CHECKING:
    from daft.ai.protocols import TextEmbedderDescriptor
    from daft.ai.typing import Options


class VLLMProvider(Provider):
    _name: str
    _options: Options

    def __init__(self, name: str | None = None, **options: Any):
        self._name = name if name else "vllm"
        self._options = options

Contributor:
style: The _options are stored but never used in get_text_embedder. Consider passing them to the descriptor like OpenAIProvider does, or remove if truly unused.


    @property
    def name(self) -> str:
        return self._name

    def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
        from daft.ai.vllm.protocols.text_embedder import vLLMTextEmbedderDescriptor

        return vLLMTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", options)
Comment on lines +24 to +27

Contributor:
We should include the provider name and options in the descriptor. Please see OpenAITextEmbedderDescriptor for an example. The provider name answers "who produced this descriptor?", and the options are useful when the provider requires late initialization of 'something', which for OpenAI is the client: the provider takes the client options, but ultimately they must be plumbed into the descriptor->model instantiation.

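As a rough illustration of what the reviewer is describing (modeled on the mention of OpenAITextEmbedderDescriptor, so the extra field names here are assumptions rather than Daft's actual code):

```python
from dataclasses import dataclass, field

from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.typing import Options
from daft.ai.vllm.protocols.text_embedder import vLLMTextEmbedder


@dataclass
class vLLMTextEmbedderDescriptor(TextEmbedderDescriptor):
    provider_name: str           # answers "who produced this descriptor?"
    provider_options: Options    # provider-level options carried for late initialization
    model: str
    options: Options = field(default_factory=dict)

    def get_provider(self) -> str:
        return self.provider_name

    # get_model, get_options, and get_dimensions stay as in the diff above.

    def instantiate(self) -> TextEmbedder:
        # Both provider-level and call-level options reach model construction.
        return vLLMTextEmbedder(self.model, **{**self.provider_options, **self.options})
```

The provider would then pass `self._name` and `self._options` when constructing the descriptor in `get_text_embedder`, mirroring what the comment describes for OpenAI.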
8 changes: 8 additions & 0 deletions docs/install.md
@@ -49,6 +49,14 @@ Depending on your use case, you may need to install Daft with additional depende
</div>
</label>

<label class="checkbox-item">
<input type="checkbox" id="vllm" data-extra="vllm">
<span class="checkmark"></span>
<div class="checkbox-content">
<strong>vLLM</strong> <code>vllm</code>
</div>
</label>

<label class="checkbox-item">
<input type="checkbox" id="ray" data-extra="ray">
<span class="checkmark"></span>
29 changes: 29 additions & 0 deletions docs/modalities/text.md
@@ -122,6 +122,35 @@ model = "text-embedding-nomic-embed-text-v1.5" # Select a text embedding model
)
```

#### Using vLLM

[vLLM](https://docs.vllm.ai/en/latest/) is a fast and easy-to-use library for LLM inference and serving.

First, install the optional vLLM dependency for Daft:

```bash
pip install -U "daft[vllm]"
```

Then use the `vllm` provider with any open model hosted on [Hugging Face](https://huggingface.co/), such as [`Qwen/Qwen3-Embedding-0.6B`](https://huggingface.co/Qwen/Qwen3-Embedding-0.6B).

```python
import daft
from daft.functions.ai import embed_text

provider = "vllm"
model = "Qwen/Qwen3-Embedding-0.6B"

(
    daft.read_huggingface("Open-Orca/OpenOrca")
    .with_column("embedding", embed_text(daft.col("response"), provider=provider, model=model))
    .show()
)
```

### How to work with embeddings

It's common to use embeddings for various tasks like similarity search or retrieval with a vector database.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ homepage = "https://www.daft.ai"
repository = "https://github.com/Eventual-Inc/Daft"

[project.optional-dependencies]
all = ["daft[aws, azure, clickhouse, deltalake, gcp, hudi, huggingface, iceberg, lance, numpy, openai, pandas, ray, sentence-transformers, spark, sql, transformers, turbopuffer, unity]"]
all = ["daft[aws, azure, clickhouse, deltalake, gcp, hudi, huggingface, iceberg, lance, numpy, openai, pandas, ray, sentence-transformers, spark, sql, transformers, turbopuffer, unity, vllm]"]
aws = ["boto3"]
azure = []
clickhouse = ["clickhouse_connect"]
@@ -60,6 +60,7 @@ sql = ["connectorx", "sqlalchemy", "sqlglot"]
turbopuffer = ["turbopuffer"]
unity = ["httpx <= 0.27.2", "unitycatalog"]
viz = []
vllm = ["vllm"]

[dependency-groups]
dev = [
70 changes: 70 additions & 0 deletions tests/ai/test_vllm.py
@@ -0,0 +1,70 @@
from __future__ import annotations

import pytest

pytest.importorskip("vllm")

from unittest.mock import Mock, patch

import numpy as np
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput

from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.vllm.provider import VLLMProvider


@pytest.mark.parametrize(
    "model, embedding_dim",
    [
        ("sentence-transformers/all-MiniLM-L6-v2", 384),
        ("Qwen/Qwen3-Embedding-0.6B", 1024),
    ],
)
def test_vllm_text_embedder(model, embedding_dim):
    text_data = [
        "Either the pipeline had many steps, or Desmond processed the data very slowly, for he had plenty of time as he streamed along to look about him and to wonder what was going to happen next.",
        "First, he tried to peek ahead and see what partition he was coming to, but it was too opaque to see anything; then he looked at the sides of the pipeline and noticed that they were filled with changing schemas and modalities;",
        "here and there he saw potential OOMs and SEGFAULTs pinned upon walls of logs.",
        "He pulled a file from one of the shelves as he passed; it was labelled 'DELTA TABLE', but to his great disappointment it was just an empty Parquet: he did not like to drop the file for fear of corrupting somebody's downstream job, so managed to tuck it back into a catalog as he streamed past it.",
        "'Well!' thought Desmond to himself, 'after such a fall as this, I shall think nothing of GPU inference! How brave they'll all think me at home! Why, I wouldn't say anything about it, even if I fell off the top of production!' (Which was very likely true.)",
        "Down, down, down. Would the pipeline never come to an end? 'I wonder how many terabytes I've shuffled by this time?' he said aloud. 'I must be getting somewhere near the centre of the datalake.'",
        "Let me see: that would be four hundred million rows, I think— (for, you see, Desmond had learnt several things of this sort in his adventures with Daft, and though this was not a very good opportunity for showing off his scaling knowledge, as there was no one to listen to him, still it was good practice to say it over)",
        "—yes, that's about the right size—but then I wonder what cluster I've got to? (Desmond had no idea what a cluster was, but thought it was a nice grand word to say.)",
        "Presently he began again. 'I wonder if I shall fall right through the datalake! How funny it'll seem to come out among the people that walk with their queries upside down! The Eventualites, I think—'",
        "(he was rather glad there was no one listening, this time, as it didn't sound at all the right word) '—but I shall have to ask them what the name of the platform is, you know. Please, Ma'am, is this Ev or Eventual?'",
        "(and he tried to bow politely as he spoke—fancy bowing while streaming through compute nodes! Do you think you could manage it?) 'And what an ignorant fellow they'll think me for asking! No, it'll never do to ask: perhaps I shall see it written up somewhere—on a dashboard, or maybe in the logs.'",
        "Down, down, down. There was nothing else to do, so Desmond soon began talking again. 'I wonder who'll miss me while I'm debugging to-night, I should think! I hope they'll remember to check the metrics.'",
        "'Ah, Daft my dear! I wish you were down here with me! There are no bugs in the air, I'm afraid, but you might catch a straggling morsel, and that's very like a microbatch, you know.'",
        "'But do engines eat morsels, I wonder?' And here Desmond began to get rather sleepy, and went on saying to himself, in a dreamy sort of way, 'Do engines eat morsels? Do morsels eat engines?'",
        "and sometimes, 'Do Eventuals eat Dafts?' for, you see, as he couldn't answer either question, it didn't much matter which way he put it.",
        "He felt that he was dozing off, and had just begun to dream that he was walking hand in hand with Daft itself, and saying to it very earnestly, 'Now, tell me the truth: did you ever process a batched bat?'",
        "when suddenly—thump! thump!—down he landed upon a heap of logs and job reports, and the pipeline was over.",
    ]

    def mock_embedding_response(input_data):
        if isinstance(input_data, list):
            num_texts = len(input_data)
        else:
            num_texts = 1

        embedding_values = [0.1] * embedding_dim
        outputs = EmbeddingOutput(embedding=embedding_values)
        return [
            EmbeddingRequestOutput(request_id=Mock(), outputs=outputs, prompt_token_ids=Mock(), finished=Mock())
        ] * num_texts

    with patch("daft.ai.vllm.protocols.text_embedder.LLM") as MockLLM:
        instance = MockLLM.return_value
        instance.embed.side_effect = lambda input_data, *args, **kwargs: mock_embedding_response(input_data)

        descriptor = VLLMProvider().get_text_embedder(model=model)
        assert isinstance(descriptor, TextEmbedderDescriptor)
        assert descriptor.get_provider() == "vllm"
        assert descriptor.get_model() == model
        assert descriptor.get_dimensions().size == embedding_dim

        embedder = descriptor.instantiate()
        assert isinstance(embedder, TextEmbedder)
        embeddings = embedder.embed_text(text_data)
        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (len(text_data), embedding_dim)
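
To run just this test locally (the `pytest.importorskip("vllm")` guard skips it when vLLM isn't installed):

```bash
pip install -U "daft[vllm]"
pytest tests/ai/test_vllm.py -v
```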