feat(text_embed): Add vLLM as a provider #5136
base: main
Conversation
Greptile Summary
This PR adds vLLM as a new text embedding provider to Daft's AI framework, following the established provider pattern used by existing providers like OpenAI and SentenceTransformers. The implementation includes four main components:
- Provider Class (`daft/ai/vllm/provider.py`): Implements `VLLMProvider` with the standard constructor pattern and a `get_text_embedder` method that returns a `vLLMTextEmbedderDescriptor`.
- Protocol Implementation (`daft/ai/vllm/protocols/text_embedder.py`): Contains the core embedding logic with the `vLLMTextEmbedderDescriptor` and `vLLMTextEmbedder` classes. The descriptor handles model configuration and dimension detection using `AutoConfig.from_pretrained()`, while the embedder wraps a vLLM `LLM` instance for the actual text embedding operations.
- Provider Registry Update (`daft/ai/provider.py`): Adds a `load_vllm()` function and registers it in the `PROVIDERS` dictionary, enabling users to instantiate the provider via `load_provider("vllm")`.
- Test Coverage (`tests/ai/test_vllm.py`): Provides comprehensive unit tests using mocking to validate the provider interface and embedding functionality without requiring actual vLLM dependencies.
The implementation integrates seamlessly with Daft's existing AI infrastructure, allowing users to leverage vLLM's high-performance inference capabilities for text embedding tasks. vLLM is known for its optimized inference engine for large language models, making this addition valuable for production environments requiring fast embedding generation. The code follows established patterns for lazy dependency loading, consistent interfaces, and proper error handling with `ProviderImportError` when vLLM is not installed.
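For context, a minimal usage sketch of the resulting surface area, based only on the entry points named above (`load_provider`, `get_text_embedder`); the `instantiate()` step is assumed from the descriptor pattern and may not match Daft's actual method name:

```python
from daft.ai.provider import load_provider

# Instantiate the new provider by name, as registered in the PROVIDERS dict.
provider = load_provider("vllm")

# Request a text embedder descriptor; per this PR the model defaults to
# "sentence-transformers/all-MiniLM-L6-v2" when none is given.
descriptor = provider.get_text_embedder()
embedder = descriptor.instantiate()  # assumption: descriptors expose an instantiate() step
embeddings = embedder.embed_text(["hello", "world"])
```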
Confidence score: 4/5
- This PR is safe to merge with only minor issues that don't affect core functionality
- Score reflects solid implementation following established patterns, but with a few inconsistencies in option handling and some missing validation
- Pay close attention to `daft/ai/vllm/provider.py` for the unused `_options` parameter and `daft/ai/vllm/protocols/text_embedder.py` for potential error handling improvements
5 files reviewed, 2 comments
```python
def __init__(self, name: str | None = None, **options: Any):
    self._name = name if name else "vllm"
    self._options = options
```
style: The `_options` are stored but never used in `get_text_embedder`. Consider passing them to the descriptor like `OpenAIProvider` does, or remove if truly unused.
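A minimal sketch of the first option (forwarding the stored options), assuming `vLLMTextEmbedderDescriptor` accepts an options dict as it does elsewhere in this PR; the merge semantics are illustrative:

```python
from typing import Any


class VLLMProvider:
    def __init__(self, name: str | None = None, **options: Any):
        self._name = name if name else "vllm"
        self._options = options

    def get_text_embedder(self, model: str | None = None, **options: Any) -> "TextEmbedderDescriptor":
        from daft.ai.vllm.protocols.text_embedder import vLLMTextEmbedderDescriptor

        # Forward provider-level options, letting call-site options win on
        # conflict, so _options is no longer dead state.
        merged = {**self._options, **options}
        return vLLMTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", merged)
```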
```python
def embed_text(self, text: list[str]) -> list[Embedding]:
    outputs = self.model.embed(text)
    embeddings = torch.tensor([o.outputs.embedding for o in outputs])
    return embeddings.cpu().numpy()
```
logic: Return type annotation says `list[Embedding]` but the function returns a numpy array. Should return `embeddings.cpu().numpy().tolist()` or update the annotation.
Suggested change:

```diff
 def embed_text(self, text: list[str]) -> list[Embedding]:
     outputs = self.model.embed(text)
     embeddings = torch.tensor([o.outputs.embedding for o in outputs])
-    return embeddings.cpu().numpy()
+    return embeddings.cpu().numpy().tolist()
```
+1

The `Embedding` type is currently an alias to `np.ndarray`, but these protocols should return a `list[np.ndarray]`.
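Given that `Embedding` aliases `np.ndarray`, one way to satisfy `list[Embedding]` without flattening to Python floats is a per-row conversion; a sketch under that assumption (the dtype choice is illustrative):

```python
import numpy as np


def embed_text(self, text: list[str]) -> "list[Embedding]":
    outputs = self.model.embed(text)
    # One np.ndarray per input string; avoids the torch round-trip entirely.
    return [np.asarray(o.outputs.embedding, dtype=np.float32) for o in outputs]
```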
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #5136      +/-   ##
==========================================
- Coverage   76.54%   71.92%    -4.63%
==========================================
  Files         953      955        +2
  Lines      130653   130581       -72
==========================================
- Hits       100006    93918     -6088
- Misses      30647    36663     +6016
```
```python
def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
    from daft.ai.vllm.protocols.text_embedder import vLLMTextEmbedderDescriptor

    return vLLMTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", options)
```
We should include the provider name and options in the descriptor. Please see `OpenAITextEmbedderDescriptor` for an example. The provider name answers "who produced this descriptor?", and the options are useful when the provider requires late initialization of something, which for OpenAI is the client: the provider takes the client options, but ultimately they must be plumbed into the descriptor->model instantiation.
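A sketch of what that plumbing could look like, loosely modeled on the `OpenAITextEmbedderDescriptor` pattern the comment points to; the field and method names are illustrative rather than Daft's actual API, and the vLLM constructor arguments are an assumption:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class vLLMTextEmbedderDescriptor:
    provider_name: str  # answers "who produced this descriptor?"
    model_name: str
    options: dict[str, Any] = field(default_factory=dict)

    def instantiate(self) -> "vLLMTextEmbedder":
        # Late initialization: options captured at provider construction are
        # only consumed here, when the vLLM engine is actually built.
        from vllm import LLM

        model = LLM(model=self.model_name, task="embed", **self.options)  # task kwarg assumed
        return vLLMTextEmbedder(model)
```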
```python
self.options = options

def embed_text(self, text: list[str]) -> list[Embedding]:
    outputs = self.model.embed(text)
```
That easy huh? 😄
Do you know if we need to do any batching on our own? I'm fine with just firing away to vLLM until we learn more, mostly asking out of curiosity.
vLLM does batching internally. Tbf there's a lot more that can be done here, but I want to move on and support endpoints like gemini (and maybe hack together something for ev)
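Purely for reference, since vLLM batches internally: a hypothetical manual-chunking fallback if request sizes ever become a problem (the chunk size and helper name are invented for illustration):

```python
import numpy as np


def embed_text_chunked(self, text: list[str], chunk_size: int = 256) -> "list[Embedding]":
    # self.model is assumed to be the wrapped vLLM LLM instance from this PR.
    results: "list[Embedding]" = []
    for i in range(0, len(text), chunk_size):
        outputs = self.model.embed(text[i : i + chunk_size])
        results.extend(np.asarray(o.outputs.embedding) for o in outputs)
    return results
```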
Changes Made
Adds vLLM as a provider for text embedding.