Commit 1caccc5

Implement proxy for Ollama chat API
1 parent c82c356 commit 1caccc5

2 files changed: +314 -6 lines


README.md

Lines changed: 68 additions & 2 deletions
@@ -145,10 +145,76 @@ $ curl http://127.0.0.1:5555/api/embed \
 {"model":"text-embedding-004","embeddings":[[0.04824496,0.0117766075,-0.011552069,-0.018164534,-0.0026110192,0.05092675,0.08172899,0.007869772,0.054475933,0.026131334,-0.06593486,-0.002256868,0.038781915,...]]}
 ```
 
+Create chat completions:
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [
+      { "role": "user", "content": "Hello, world!" }
+    ]
+  }'
+```
+
+Non-streaming response example:
+
+```json
+{
+  "model": "gemini-1.5-pro",
+  "created_at": "2024-07-28T15:00:00Z",
+  "message": { "role": "assistant", "content": "Hello back to you!" },
+  "done": true,
+  "total_duration": 123456789,
+  "prompt_eval_count": 5,
+  "eval_count": 10
+}
+```
+
+Stream the chat responses:
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [
+      { "role": "user", "content": "Hello, world!" }
+    ],
+    "stream": true
+  }'
+```
+
+Streaming responses are sent as Server-Sent Events (SSE) prefixed with `data:`. For example:
+
+```
+data: {"model":"gemini-1.5-pro", ...}
+data: [DONE]
+```
+
+Advanced usage:
+
+- Use the `format` parameter to request JSON output or to enforce a JSON schema.
+- Use the `options` parameter to set model generation parameters: `temperature`, `num_predict`, `top_k`, and `top_p`.
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [...],
+    "stream": false,
+    "format": "json",
+    "options": { "temperature": 0.5, "num_predict": 100 }
+  }'
+```
+
 ### Known Ollama Limitations
-* Streaming is not yet supported.
+* Streaming is supported only for the chat API.
 * Images are not supported.
-* Response format is not supported.
+* Response format is supported only for the chat API.
+* Tools are not supported for the chat API.
 * Model parameters not supported by Gemini are ignored.
 
 ## Notes

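The README hunk above documents the streaming output as SSE-style chunks prefixed with `data:` and terminated by `data: [DONE]`. As a minimal sketch (not part of this commit), the following Go client consumes that stream; it assumes the proxy is listening on `127.0.0.1:5555` as in the README examples, and the local `chunk` struct mirrors only the fields of `ChatResponse` it actually reads.

```go
// streamclient is a minimal sketch of a consumer for the proxy's /api/chat
// streaming output. It posts a chat request with "stream": true and prints
// the content of each "data:"-prefixed chunk until the [DONE] sentinel.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"
)

// chunk mirrors the subset of the proxy's ChatResponse used below.
type chunk struct {
	Message struct {
		Content string `json:"content"`
	} `json:"message"`
	Done bool `json:"done"`
}

func main() {
	reqBody := []byte(`{
	  "model": "gemini-1.5-pro",
	  "messages": [{ "role": "user", "content": "Hello, world!" }],
	  "stream": true
	}`)
	resp, err := http.Post("http://127.0.0.1:5555/api/chat", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue // skip anything that is not a data: line
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break // end-of-stream sentinel emitted by the proxy
		}
		var c chunk
		if err := json.Unmarshal([]byte(payload), &c); err != nil {
			log.Fatalf("malformed chunk %q: %v", payload, err)
		}
		fmt.Print(c.Message.Content)
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```

Note that this parsing differs from a stock Ollama client, which expects newline-delimited JSON without the `data:` prefix.
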
ollama/ollama.go

Lines changed: 246 additions & 4 deletions
@@ -18,16 +18,18 @@ package ollama
 
 import (
 	"encoding/json"
+	"fmt"
+	"github.com/google-gemini/proxy-to-gemini/internal"
+	"github.com/google/generative-ai-go/genai"
+	"github.com/gorilla/mux"
+	"google.golang.org/api/iterator"
 	"io"
 	"net/http"
 	"strings"
 	"time"
-
-	"github.com/google-gemini/proxy-to-gemini/internal"
-	"github.com/google/generative-ai-go/genai"
-	"github.com/gorilla/mux"
 )
 
+// handlers provides HTTP handlers for the Ollama proxy API.
 type handlers struct {
 	client *genai.Client
 }
@@ -36,6 +38,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
 	handlers := &handlers{client: client}
 	r.HandleFunc("/api/generate", handlers.generateHandler)
 	r.HandleFunc("/api/embed", handlers.embedHandler)
+	r.HandleFunc("/api/chat", handlers.chatHandler)
 }
 
 func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
@@ -142,6 +145,245 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+// ChatRequest represents a chat completion request for the Ollama API.
+type ChatRequest struct {
+	Model    string          `json:"model,omitempty"`
+	Messages []ChatMessage   `json:"messages,omitempty"`
+	Stream   bool            `json:"stream,omitempty"`
+	Format   json.RawMessage `json:"format,omitempty"`
+	Options  Options         `json:"options,omitempty"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+	Role    string   `json:"role,omitempty"`
+	Content string   `json:"content,omitempty"`
+	Images  []string `json:"images,omitempty"`
+}
+
+// ChatResponse represents a chat completion response for the Ollama API.
+type ChatResponse struct {
+	Model              string      `json:"model,omitempty"`
+	CreatedAt          time.Time   `json:"created_at,omitempty"`
+	Message            ChatMessage `json:"message,omitempty"`
+	Done               bool        `json:"done,omitempty"`
+	PromptEvalCount    int32       `json:"prompt_eval_count"`
+	EvalCount          int32       `json:"eval_count"`
+	TotalDuration      int64       `json:"total_duration"`
+	LoadDuration       int64       `json:"load_duration,omitempty"`
+	PromptEvalDuration int64       `json:"prompt_eval_duration,omitempty"`
+	EvalDuration       int64       `json:"eval_duration"`
+}
+
+func sanitizeJson(s string) string {
+	s = strings.ReplaceAll(s, "\n", "")
+	s = strings.TrimPrefix(s, "```json")
+	s = strings.TrimSuffix(s, "```")
+	s = strings.ReplaceAll(s, "'", "\\'")
+	return s
+}
+
+// chatHandler handles POST /api/chat requests.
+func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
+		return
+	}
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
+		return
+	}
+	defer r.Body.Close()
+
+	var req ChatRequest
+	if err := json.Unmarshal(body, &req); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to unmarshal chat request: %v", err)
+		return
+	}
+
+	// Handle advanced format parameter: JSON mode or JSON schema enforcement
+	expectJson := false
+	if len(req.Format) > 0 {
+		expectJson = true
+		var formatVal interface{}
+		if err := json.Unmarshal(req.Format, &formatVal); err != nil {
+			internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
+			return
+		}
+		var instr string
+		switch v := formatVal.(type) {
+		case string:
+			if v == "json" {
+				instr = "Please respond with valid JSON."
+			} else {
+				instr = fmt.Sprintf("Please respond with format: %s.", v)
+			}
+		default:
+			schemaBytes, err := json.MarshalIndent(v, "", " ")
+			if err != nil {
+				schemaBytes = req.Format
+			}
+			instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
+		}
+		// Integrate with existing system message if present
+		found := false
+		for i, m := range req.Messages {
+			if m.Role == "system" {
+				req.Messages[i].Content = m.Content + "\n\n" + instr
+				found = true
+				break
+			}
+		}
+		if !found {
+			req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
+		}
+	}
+
+	model := h.client.GenerativeModel(req.Model)
+	model.GenerationConfig = genai.GenerationConfig{
+		Temperature:     req.Options.Temperature,
+		MaxOutputTokens: req.Options.NumPredict,
+		TopK:            req.Options.TopK,
+		TopP:            req.Options.TopP,
+	}
+	if req.Options.Stop != nil {
+		model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
+	}
+
+	chat := model.StartChat()
+	var lastPart genai.Part
+	for i, m := range req.Messages {
+		if m.Role == "system" {
+			model.SystemInstruction = &genai.Content{
+				Role:  m.Role,
+				Parts: []genai.Part{genai.Text(m.Content)},
+			}
+			continue
+		}
+		if i == len(req.Messages)-1 {
+			lastPart = genai.Text(m.Content)
+			break
+		}
+		chat.History = append(chat.History, &genai.Content{
+			Role:  m.Role,
+			Parts: []genai.Part{genai.Text(m.Content)},
+		})
+	}
+
+	if req.Stream {
+		h.streamingChatHandler(w, r, req.Model, chat, lastPart)
+		return
+	}
+	// Measure time spent generating the chat response
+	start := time.Now()
+
+	gresp, err := chat.SendMessage(r.Context(), lastPart)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
+		return
+	}
+	var builder strings.Builder
+	if len(gresp.Candidates) > 0 {
+		for _, part := range gresp.Candidates[0].Content.Parts {
+			if txt, ok := part.(genai.Text); ok {
+				builder.WriteString(string(txt))
+			}
+		}
+	}
+
+	var resp ChatResponse
+	if expectJson {
+		resp = ChatResponse{
+			Model:     req.Model,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    gresp.Candidates[0].Content.Role,
+				Content: sanitizeJson(builder.String()),
+			},
+			Done: true,
+		}
+	} else {
+		resp = ChatResponse{
+			Model:     req.Model,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    gresp.Candidates[0].Content.Role,
+				Content: builder.String(),
+			},
+			Done: true,
+		}
+	}
+
+	if gresp.UsageMetadata != nil {
+		resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
+		// Compute number of tokens in the response.
+		if gresp.UsageMetadata.CandidatesTokenCount > 0 {
+			resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
+		} else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
+			// Fallback: use total tokens minus prompt tokens
+			resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
+		}
+	}
+	// Populate duration metadata (in nanoseconds)
+	elapsed := time.Since(start).Nanoseconds()
+	resp.TotalDuration = elapsed
+	resp.LoadDuration = 0
+	resp.PromptEvalDuration = 0
+	resp.EvalDuration = elapsed
+	if err := json.NewEncoder(w).Encode(&resp); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
+		return
+	}
+}
+
+// streamingChatHandler handles streaming chat responses for /api/chat.
+func (h *handlers) streamingChatHandler(w http.ResponseWriter, r *http.Request, modelName string, chat *genai.ChatSession, lastPart genai.Part) {
+	// Measure total elapsed time for streaming
+	start := time.Now()
+	iter := chat.SendMessageStream(r.Context(), lastPart)
+	for {
+		gresp, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to stream chat response: %v", err)
+			return
+		}
+		var builder strings.Builder
+		if len(gresp.Candidates) > 0 {
+			for _, part := range gresp.Candidates[0].Content.Parts {
+				if txt, ok := part.(genai.Text); ok {
+					builder.WriteString(string(txt))
+				}
+			}
+		}
+		// Build streaming chunk with duration metadata
+		elapsed := time.Since(start).Nanoseconds()
+		chunk := ChatResponse{
+			Model:     modelName,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    gresp.Candidates[0].Content.Role,
+				Content: builder.String(),
+			},
+			Done:               false,
+			TotalDuration:      elapsed,
+			LoadDuration:       0,
+			PromptEvalDuration: 0,
+			EvalDuration:       elapsed,
+		}
+		data, err := json.Marshal(chunk)
+		if err != nil {
+			internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to marshal chat chunk: %v", err)
+			return
+		}
+		fmt.Fprintf(w, "data: %s\n", data)
+	}
+	fmt.Fprint(w, "data: [DONE]\n")
+}
+
 type GenerateRequest struct {
 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`

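To exercise the new route end to end, here is a minimal integration-style sketch (not part of this commit) that registers the handlers on a gorilla/mux router, serves them with `net/http/httptest`, and issues a non-streaming chat request. The import path `github.com/google-gemini/proxy-to-gemini/ollama` and the `GEMINI_API_KEY` environment variable name are assumptions; `RegisterHandlers`, `ChatRequest`, `ChatMessage`, and `ChatResponse` are the exported names from the diff above.

```go
// Integration-style sketch: it talks to the real Gemini API, so it skips
// itself unless an API key is available in the environment.
package ollama_test

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"

	"github.com/google-gemini/proxy-to-gemini/ollama" // assumed import path
	"github.com/google/generative-ai-go/genai"
	"github.com/gorilla/mux"
	"google.golang.org/api/option"
)

func TestChatRoundTrip(t *testing.T) {
	key := os.Getenv("GEMINI_API_KEY") // assumed variable name
	if key == "" {
		t.Skip("GEMINI_API_KEY not set")
	}
	client, err := genai.NewClient(context.Background(), option.WithAPIKey(key))
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	// Wire the proxy handlers, including the new /api/chat route, into a test server.
	r := mux.NewRouter()
	ollama.RegisterHandlers(r, client)
	srv := httptest.NewServer(r)
	defer srv.Close()

	body, err := json.Marshal(ollama.ChatRequest{
		Model:    "gemini-1.5-pro",
		Messages: []ollama.ChatMessage{{Role: "user", Content: "Say hi in one word."}},
	})
	if err != nil {
		t.Fatal(err)
	}
	resp, err := http.Post(srv.URL+"/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// The handler replies with a single ChatResponse object when stream is false.
	var out ollama.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		t.Fatal(err)
	}
	if !out.Done || out.Message.Content == "" {
		t.Fatalf("unexpected chat response: %+v", out)
	}
}
```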