README.md (46 changes: 45 additions & 1 deletion)
@@ -145,10 +145,54 @@ $ curl http://127.0.0.1:5555/api/embed \
{"model":"text-embedding-004","embeddings":[[0.04824496,0.0117766075,-0.011552069,-0.018164534,-0.0026110192,0.05092675,0.08172899,0.007869772,0.054475933,0.026131334,-0.06593486,-0.002256868,0.038781915,...]]}
```

Create chat completions:

```sh
$ curl http://127.0.0.1:5555/api/chat \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-1.5-pro",
"messages": [
{ "role": "user", "content": "Hello, world!" }
]
}'
```

Chat response example:

```json
{
"model": "gemini-1.5-pro",
"created_at": "2024-07-28T15:00:00Z",
"message": { "role": "assistant", "content": "Hello back to you!" },
"done": true,
"total_duration": 123456789,
"prompt_eval_count": 5,
"eval_count": 10
}
```
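
A `system` message may also be included; the proxy maps it onto Gemini's system instruction and treats the remaining messages as the conversation, sending the last one as the prompt. An illustrative request (the prompt text is made up):

```sh
$ curl http://127.0.0.1:5555/api/chat \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemini-1.5-pro",
    "messages": [
      { "role": "system", "content": "Answer in one short sentence." },
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
```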

Advanced usage:

- Use the `format` parameter to request JSON output or to enforce a JSON schema; a schema example follows the snippet below.
- Use the `options` parameter to set model generation parameters: `temperature`, `num_predict`, `top_k`, and `top_p`.

```sh
$ curl http://127.0.0.1:5555/api/chat \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-1.5-pro",
"messages": [...],
"format": "json",
"options": { "temperature": 0.5, "num_predict": 100 }
}'
```
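
To enforce a JSON schema instead, pass a schema object as the `format` value; the proxy folds it into the instruction it sends to Gemini. A hypothetical request (the schema shown is only an example):

```sh
$ curl http://127.0.0.1:5555/api/chat \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemini-1.5-pro",
    "messages": [
      { "role": "user", "content": "Name two primary colors." }
    ],
    "format": {
      "type": "object",
      "properties": {
        "colors": { "type": "array", "items": { "type": "string" } }
      },
      "required": ["colors"]
    }
  }'
```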

### Known Ollama Limitations
* Streaming is not yet supported.
* Images are not supported.
* Response format is supported only for chat API.
* Tools are not supported for chat API.
* Model parameters not supported by Gemini are ignored.

## Notes
ollama/ollama.go (197 changes: 193 additions & 4 deletions)
@@ -18,16 +18,17 @@ package ollama

import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/google-gemini/proxy-to-gemini/internal"
"github.com/google/generative-ai-go/genai"
"github.com/gorilla/mux"
)

// handlers provides HTTP handlers for the Ollama proxy API.
type handlers struct {
client *genai.Client
}
@@ -36,6 +37,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
handlers := &handlers{client: client}
r.HandleFunc("/api/generate", handlers.generateHandler)
r.HandleFunc("/api/embed", handlers.embedHandler)
r.HandleFunc("/api/chat", handlers.chatHandler)
}

func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
@@ -142,6 +144,193 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
}
}

// ChatRequest represents a chat completion request for the Ollama API.
type ChatRequest struct {
Model string `json:"model,omitempty"`
Messages []ChatMessage `json:"messages,omitempty"`
Format json.RawMessage `json:"format,omitempty"`
Options Options `json:"options,omitempty"`
}

// ChatMessage represents a single message in a chat.
type ChatMessage struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Images []string `json:"images,omitempty"`
}

// ChatResponse represents a chat completion response for the Ollama API.
type ChatResponse struct {
Model string `json:"model,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
Message ChatMessage `json:"message,omitempty"`
Done bool `json:"done,omitempty"`
PromptEvalCount int32 `json:"prompt_eval_count"`
EvalCount int32 `json:"eval_count"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration,omitempty"`
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
EvalDuration int64 `json:"eval_duration"`
}

// sanitizeJson strips newlines and the Markdown ```json fences that Gemini
// often wraps around JSON output, and escapes single quotes, so JSON-mode
// responses come back as a bare JSON string.
func sanitizeJson(s string) string {
s = strings.ReplaceAll(s, "\n", "")
s = strings.TrimPrefix(s, "```json")
s = strings.TrimSuffix(s, "```")
s = strings.ReplaceAll(s, "'", "\\'")
return s
}

// chatHandler handles POST /api/chat requests.
func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
return
}
defer r.Body.Close()

var req ChatRequest
if err := json.Unmarshal(body, &req); err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to unmarshal chat request: %v", err)
return
}

// Handle advanced format parameter: JSON mode or JSON schema enforcement
expectJson := false
if len(req.Format) > 0 {
expectJson = true
var formatVal interface{}
if err := json.Unmarshal(req.Format, &formatVal); err != nil {
internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
return
}
var instr string
switch v := formatVal.(type) {
case string:
if v == "json" {
instr = "Please respond with valid JSON."
} else {
instr = fmt.Sprintf("Please respond with format: %s.", v)
}
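// Any non-string format value is treated as a JSON schema and is embedded
// in the instruction sent to the model.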
default:
schemaBytes, err := json.MarshalIndent(v, "", " ")
if err != nil {
schemaBytes = req.Format
}
instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
}
// Integrate with existing system message if present
found := false
for i, m := range req.Messages {
if m.Role == "system" {
req.Messages[i].Content = m.Content + "\n\n" + instr
found = true
break
}
}
if !found {
req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
}
}

model := h.client.GenerativeModel(req.Model)
model.GenerationConfig = genai.GenerationConfig{
Temperature: req.Options.Temperature,
MaxOutputTokens: req.Options.NumPredict,
TopK: req.Options.TopK,
TopP: req.Options.TopP,
}
if req.Options.Stop != nil {
model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
}
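// Map Ollama chat messages onto Gemini: a system message becomes the model's
// system instruction, earlier turns are added to the chat history, and the
// final message is sent as the prompt.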

chat := model.StartChat()
var lastPart genai.Part
for i, m := range req.Messages {
if m.Role == "system" {
model.SystemInstruction = &genai.Content{
Role: m.Role,
Parts: []genai.Part{genai.Text(m.Content)},
}
continue
}
if i == len(req.Messages)-1 {
lastPart = genai.Text(m.Content)
break
}
chat.History = append(chat.History, &genai.Content{
Role: m.Role,
Parts: []genai.Part{genai.Text(m.Content)},
})
}

// Measure time spent generating the chat response
start := time.Now()

gresp, err := chat.SendMessage(r.Context(), lastPart)
if err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
return
}
var builder strings.Builder
if len(gresp.Candidates) > 0 {
for _, part := range gresp.Candidates[0].Content.Parts {
if txt, ok := part.(genai.Text); ok {
builder.WriteString(string(txt))
}
}
}

// Build the response; JSON-mode output is stripped of Markdown fences first.
content := builder.String()
if expectJson {
content = sanitizeJson(content)
}
resp := ChatResponse{
Model:     req.Model,
CreatedAt: time.Now(),
Message: ChatMessage{
Role:    gresp.Candidates[0].Content.Role,
Content: content,
},
Done: true,
}

if gresp.UsageMetadata != nil {
resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
// Compute number of tokens in the response.
if gresp.UsageMetadata.CandidatesTokenCount > 0 {
resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
} else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
// Fallback: use total tokens minus prompt tokens
resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
}
}
// Populate duration metadata (in nanoseconds)
elapsed := time.Since(start).Nanoseconds()
resp.TotalDuration = elapsed
resp.LoadDuration = 0
resp.PromptEvalDuration = 0
resp.EvalDuration = elapsed
if err := json.NewEncoder(w).Encode(&resp); err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
return
}
}

type GenerateRequest struct {
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`