feat(text_embed): Add vLLM as a provider #5136
base: main
daft/ai/vllm/protocols/text_embedder.py
@@ -0,0 +1,58 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoConfig
from vllm import LLM

from daft import DataType
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.typing import EmbeddingDimensions, Options

if TYPE_CHECKING:
    from daft.ai.typing import Embedding


@dataclass
class vLLMTextEmbedderDescriptor(TextEmbedderDescriptor):
    model: str
    options: Options

    def get_provider(self) -> str:
        return "vllm"

    def get_model(self) -> str:
        return self.model

    def get_options(self) -> Options:
        return self.options

    def get_dimensions(self) -> EmbeddingDimensions:
        dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
        return EmbeddingDimensions(size=dimensions, dtype=DataType.float32())

    def instantiate(self) -> TextEmbedder:
        return vLLMTextEmbedder(self.model, **self.options)


class vLLMTextEmbedder(TextEmbedder):
    model: LLM
    options: Options  # not currently used, torch hardcoded

    def __init__(self, model_name_or_path: str, **options: Any):
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        max_model_len = getattr(config, "n_ctx", None) or getattr(config, "max_position_embeddings", None)
        # Let vLLM automatically determine the optimal dtype to use based on the model config file.
        self.model = LLM(
            model=model_name_or_path,
            max_num_batched_tokens=max_model_len,
            task="embed",
        )
        self.options = options

    def embed_text(self, text: list[str]) -> list[Embedding]:
        outputs = self.model.embed(text)
        embeddings = torch.tensor([o.outputs.embedding for o in outputs])
        return embeddings.cpu().numpy()
Comment on lines +55 to +58
logic: Return type annotation says `list[Embedding]`, but `embed_text` returns a NumPy array from `embeddings.cpu().numpy()`.
+1
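One possible reconciliation, sketched here rather than taken from the PR (it assumes `Embedding` is compatible with a per-text NumPy vector and that `numpy` is imported as `np` in this module):

```python
def embed_text(self, text: list[str]) -> list[Embedding]:
    outputs = self.model.embed(text)
    # Build a Python list of per-text vectors so the runtime value matches the
    # declared list[Embedding] return type instead of a single 2-D ndarray.
    return [np.asarray(o.outputs.embedding, dtype=np.float32) for o in outputs]
```

Alternatively, the annotation could be changed to match the current ndarray return value; the test below only checks for an `np.ndarray` of the right shape.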
daft/ai/vllm/provider.py
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from daft.ai.provider import Provider

if TYPE_CHECKING:
    from daft.ai.protocols import TextEmbedderDescriptor
    from daft.ai.typing import Options


class VLLMProvider(Provider):
    _name: str
    _options: Options

    def __init__(self, name: str | None = None, **options: Any):
        self._name = name if name else "vllm"
        self._options = options
style: The

    @property
    def name(self) -> str:
        return self._name

    def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
        from daft.ai.vllm.protocols.text_embedder import vLLMTextEmbedderDescriptor

        return vLLMTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", options)
Comment on lines +24 to +27
We should include the provider name and options in the descriptor. Please see OpenAITextEmbedderDescriptor for an example. The provider name answers "who produced this descriptor?", and the options are useful when the provider requires late initialization of something, which for OpenAI is the client: the provider takes the client options, but ultimately they must be plumbed into the descriptor->model instantiation.
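A rough sketch of what that plumbing could look like; the field names here (`provider_name`, `provider_options`, `model_name`, `model_options`) are illustrative assumptions, not the actual OpenAITextEmbedderDescriptor fields:

```python
def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
    from daft.ai.vllm.protocols.text_embedder import vLLMTextEmbedderDescriptor

    # Pass the provider's identity and options through so the descriptor can
    # report who produced it and forward provider options at instantiation time.
    return vLLMTextEmbedderDescriptor(
        provider_name=self._name,        # hypothetical field
        provider_options=self._options,  # hypothetical field
        model_name=model or "sentence-transformers/all-MiniLM-L6-v2",
        model_options=options,
    )
```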
@@ -0,0 +1,70 @@
from __future__ import annotations

import pytest

pytest.importorskip("vllm")

from unittest.mock import Mock, patch

import numpy as np
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput

from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
from daft.ai.vllm.provider import VLLMProvider


@pytest.mark.parametrize(
    "model, embedding_dim",
    [
        ("sentence-transformers/all-MiniLM-L6-v2", 384),
        ("Qwen/Qwen3-Embedding-0.6B", 1024),
    ],
)
def test_vllm_text_embedder(model, embedding_dim):
    text_data = [
        "Either the pipeline had many steps, or Desmond processed the data very slowly, for he had plenty of time as he streamed along to look about him and to wonder what was going to happen next.",
        "First, he tried to peek ahead and see what partition he was coming to, but it was too opaque to see anything; then he looked at the sides of the pipeline and noticed that they were filled with changing schemas and modalities;",
        "here and there he saw potential OOMs and SEGFAULTs pinned upon walls of logs.",
        "He pulled a file from one of the shelves as he passed; it was labelled 'DELTA TABLE', but to his great disappointment it was just an empty Parquet: he did not like to drop the file for fear of corrupting somebody's downstream job, so managed to tuck it back into a catalog as he streamed past it.",
        "'Well!' thought Desmond to himself, 'after such a fall as this, I shall think nothing of GPU inference! How brave they'll all think me at home! Why, I wouldn't say anything about it, even if I fell off the top of production!' (Which was very likely true.)",
        "Down, down, down. Would the pipeline never come to an end? 'I wonder how many terabytes I've shuffled by this time?' he said aloud. 'I must be getting somewhere near the centre of the datalake.'",
        "Let me see: that would be four hundred million rows, I think— (for, you see, Desmond had learnt several things of this sort in his adventures with Daft, and though this was not a very good opportunity for showing off his scaling knowledge, as there was no one to listen to him, still it was good practice to say it over)",
        "—yes, that's about the right size—but then I wonder what cluster I've got to? (Desmond had no idea what a cluster was, but thought it was a nice grand word to say.)",
        "Presently he began again. 'I wonder if I shall fall right through the datalake! How funny it'll seem to come out among the people that walk with their queries upside down! The Eventualites, I think—'",
        "(he was rather glad there was no one listening, this time, as it didn't sound at all the right word) '—but I shall have to ask them what the name of the platform is, you know. Please, Ma'am, is this Ev or Eventual?'",
        "(and he tried to bow politely as he spoke—fancy bowing while streaming through compute nodes! Do you think you could manage it?) 'And what an ignorant fellow they'll think me for asking! No, it'll never do to ask: perhaps I shall see it written up somewhere—on a dashboard, or maybe in the logs.'",
        "Down, down, down. There was nothing else to do, so Desmond soon began talking again. 'I wonder who'll miss me while I'm debugging to-night, I should think! I hope they'll remember to check the metrics.'",
        "'Ah, Daft my dear! I wish you were down here with me! There are no bugs in the air, I'm afraid, but you might catch a straggling morsel, and that's very like a microbatch, you know.'",
        "'But do engines eat morsels, I wonder?' And here Desmond began to get rather sleepy, and went on saying to himself, in a dreamy sort of way, 'Do engines eat morsels? Do morsels eat engines?'",
        "and sometimes, 'Do Eventuals eat Dafts?' for, you see, as he couldn't answer either question, it didn't much matter which way he put it.",
        "He felt that he was dozing off, and had just begun to dream that he was walking hand in hand with Daft itself, and saying to it very earnestly, 'Now, tell me the truth: did you ever process a batched bat?'",
        "when suddenly—thump! thump!—down he landed upon a heap of logs and job reports, and the pipeline was over.",
    ]

    def mock_embedding_response(input_data):
        if isinstance(input_data, list):
            num_texts = len(input_data)
        else:
            num_texts = 1

        embedding_values = [0.1] * embedding_dim
        outputs = EmbeddingOutput(embedding=embedding_values)
        return [
            EmbeddingRequestOutput(request_id=Mock(), outputs=outputs, prompt_token_ids=Mock(), finished=Mock())
        ] * num_texts

    with patch("daft.ai.vllm.protocols.text_embedder.LLM") as MockLLM:
        instance = MockLLM.return_value
        instance.embed.side_effect = lambda input_data, *args, **kwargs: mock_embedding_response(input_data)

        descriptor = VLLMProvider().get_text_embedder(model=model)
        assert isinstance(descriptor, TextEmbedderDescriptor)
        assert descriptor.get_provider() == "vllm"
        assert descriptor.get_model() == model
        assert descriptor.get_dimensions().size == embedding_dim

        embedder = descriptor.instantiate()
        assert isinstance(embedder, TextEmbedder)
        embeddings = embedder.embed_text(text_data)
        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (len(text_data), embedding_dim)
That easy huh? 😄
Do you know if we need to do any batching on our own? I'm fine with just firing away to vLLM until we learn more, mostly asking out of curiosity.
vLLM does batching internally. Tbf there's a lot more that can be done here, but I want to move on and support endpoints like gemini (and maybe hack together something for ev)
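For reference, the pattern the PR relies on is a single call over the whole list of texts, with vLLM scheduling the requests internally; a minimal sketch (the model name and texts here are placeholders):

```python
from vllm import LLM

# vLLM accepts the full list of prompts and batches/schedules them internally,
# so the caller does not need to chunk the input itself.
llm = LLM(model="sentence-transformers/all-MiniLM-L6-v2", task="embed")
outputs = llm.embed(["first document", "second document", "third document"])
vectors = [o.outputs.embedding for o in outputs]
```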