diff --git a/README.md b/README.md
index 43e380d..5e04f2a 100644
--- a/README.md
+++ b/README.md
@@ -145,10 +145,54 @@ $ curl http://127.0.0.1:5555/api/embed \
 {"model":"text-embedding-004","embeddings":[[0.04824496,0.0117766075,-0.011552069,-0.018164534,-0.0026110192,0.05092675,0.08172899,0.007869772,0.054475933,0.026131334,-0.06593486,-0.002256868,0.038781915,...]]}
 ```
+Create chat completions:
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [
+      { "role": "user", "content": "Hello, world!" }
+    ]
+  }'
+```
+
+Chat response example:
+
+```json
+{
+  "model": "gemini-1.5-pro",
+  "created_at": "2024-07-28T15:00:00Z",
+  "message": { "role": "assistant", "content": "Hello back to you!" },
+  "done": true,
+  "total_duration": 123456789,
+  "prompt_eval_count": 5,
+  "eval_count": 10
+}
+```
+
+Advanced usage:
+
+- Use the `format` parameter to request JSON output or to enforce a JSON schema.
+- Use the `options` parameter to set model generation parameters: `temperature`, `num_predict`, `top_k`, and `top_p`.
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [...],
+    "format": "json",
+    "options": { "temperature": 0.5, "num_predict": 100 }
+  }'
+```
+
 ### Known Ollama Limitations
 
 * Streaming is not yet supported.
 * Images are not supported.
-* Response format is not supported.
+* Response format is supported only by the chat API.
+* Tools are not supported in the chat API.
 * Model parameters not supported by Gemini are ignored.
 
 ## Notes
diff --git a/ollama/ollama.go b/ollama/ollama.go
index 78c4c87..22000d8 100644
--- a/ollama/ollama.go
+++ b/ollama/ollama.go
@@ -18,16 +18,17 @@ package ollama
 
 import (
     "encoding/json"
+    "fmt"
+    "github.com/google-gemini/proxy-to-gemini/internal"
+    "github.com/google/generative-ai-go/genai"
+    "github.com/gorilla/mux"
     "io"
     "net/http"
     "strings"
     "time"
-
-    "github.com/google-gemini/proxy-to-gemini/internal"
-    "github.com/google/generative-ai-go/genai"
-    "github.com/gorilla/mux"
 )
 
+// handlers provides HTTP handlers for the Ollama proxy API.
 type handlers struct {
     client *genai.Client
 }
@@ -36,6 +37,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
     handlers := &handlers{client: client}
     r.HandleFunc("/api/generate", handlers.generateHandler)
     r.HandleFunc("/api/embed", handlers.embedHandler)
+    r.HandleFunc("/api/chat", handlers.chatHandler)
 }
 
 func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
@@ -142,6 +144,193 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
     }
 }
 
+// ChatRequest represents a chat completion request for the Ollama API.
+type ChatRequest struct {
+    Model    string          `json:"model,omitempty"`
+    Messages []ChatMessage   `json:"messages,omitempty"`
+    Format   json.RawMessage `json:"format,omitempty"`
+    Options  Options         `json:"options,omitempty"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+    Role    string   `json:"role,omitempty"`
+    Content string   `json:"content,omitempty"`
+    Images  []string `json:"images,omitempty"`
+}
+
+// ChatResponse represents a chat completion response for the Ollama API.
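+// Durations are reported in nanoseconds; prompt_eval_count and eval_count are
+// token counts taken from Gemini's usage metadata.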
+type ChatResponse struct {
+    Model              string      `json:"model,omitempty"`
+    CreatedAt          time.Time   `json:"created_at,omitempty"`
+    Message            ChatMessage `json:"message,omitempty"`
+    Done               bool        `json:"done,omitempty"`
+    PromptEvalCount    int32       `json:"prompt_eval_count"`
+    EvalCount          int32       `json:"eval_count"`
+    TotalDuration      int64       `json:"total_duration"`
+    LoadDuration       int64       `json:"load_duration,omitempty"`
+    PromptEvalDuration int64       `json:"prompt_eval_duration,omitempty"`
+    EvalDuration       int64       `json:"eval_duration"`
+}
+
+// sanitizeJson strips newlines and the markdown code fences that Gemini
+// sometimes wraps around JSON output, so that clients receive bare JSON.
+func sanitizeJson(s string) string {
+    s = strings.ReplaceAll(s, "\n", "")
+    s = strings.TrimPrefix(s, "```json")
+    s = strings.TrimSuffix(s, "```")
+    return s
+}
+
+// chatHandler handles POST /api/chat requests.
+func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
+    if r.Method != http.MethodPost {
+        internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
+        return
+    }
+    body, err := io.ReadAll(r.Body)
+    if err != nil {
+        internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
+        return
+    }
+    defer r.Body.Close()
+
+    var req ChatRequest
+    if err := json.Unmarshal(body, &req); err != nil {
+        internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to unmarshal chat request: %v", err)
+        return
+    }
+
+    // Handle the advanced format parameter: JSON mode or JSON schema enforcement.
+    expectJson := false
+    if len(req.Format) > 0 {
+        expectJson = true
+        var formatVal interface{}
+        if err := json.Unmarshal(req.Format, &formatVal); err != nil {
+            internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
+            return
+        }
+        var instr string
+        switch v := formatVal.(type) {
+        case string:
+            if v == "json" {
+                instr = "Please respond with valid JSON."
+            } else {
+                instr = fmt.Sprintf("Please respond with format: %s.", v)
+            }
+        default:
+            schemaBytes, err := json.MarshalIndent(v, "", " ")
+            if err != nil {
+                schemaBytes = req.Format
+            }
+            instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
+        }
+        // Integrate with an existing system message if present.
+        found := false
+        for i, m := range req.Messages {
+            if m.Role == "system" {
+                req.Messages[i].Content = m.Content + "\n\n" + instr
+                found = true
+                break
+            }
+        }
+        if !found {
+            req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
+        }
+    }
+
+    model := h.client.GenerativeModel(req.Model)
+    model.GenerationConfig = genai.GenerationConfig{
+        Temperature:     req.Options.Temperature,
+        MaxOutputTokens: req.Options.NumPredict,
+        TopK:            req.Options.TopK,
+        TopP:            req.Options.TopP,
+    }
+    if req.Options.Stop != nil {
+        model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
+    }
+
+    chat := model.StartChat()
+    var lastPart genai.Part
+    for i, m := range req.Messages {
+        if m.Role == "system" {
+            model.SystemInstruction = &genai.Content{
+                Role:  m.Role,
+                Parts: []genai.Part{genai.Text(m.Content)},
+            }
+            continue
+        }
+        if i == len(req.Messages)-1 {
+            lastPart = genai.Text(m.Content)
+            break
+        }
+        // Gemini expects chat history roles to be "user" or "model".
+        role := m.Role
+        if role == "assistant" {
+            role = "model"
+        }
+        chat.History = append(chat.History, &genai.Content{
+            Role:  role,
+            Parts: []genai.Part{genai.Text(m.Content)},
+        })
+    }
+
+    // Measure time spent generating the chat response.
+    start := time.Now()
+
+    gresp, err := chat.SendMessage(r.Context(), lastPart)
+    if err != nil {
+        internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
+        return
+    }
+    var builder strings.Builder
+    if len(gresp.Candidates) > 0 {
+        for _, part := range gresp.Candidates[0].Content.Parts {
+            if txt, ok := part.(genai.Text); ok {
+                builder.WriteString(string(txt))
+            }
+        }
+    }
+
+    content := builder.String()
+    if expectJson {
+        content = sanitizeJson(content)
+    }
+    resp := ChatResponse{
+        Model:     req.Model,
+        CreatedAt: time.Now(),
+        // Ollama clients expect the reply message to use the "assistant" role.
+        Message: ChatMessage{Role: "assistant", Content: content},
+        Done:    true,
+    }
+
+    if gresp.UsageMetadata != nil {
+        resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
+        // Compute the number of tokens in the response.
+        if gresp.UsageMetadata.CandidatesTokenCount > 0 {
+            resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
+        } else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
+            // Fallback: use total tokens minus prompt tokens.
+            resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
+        }
+    }
+    // Populate duration metadata (in nanoseconds).
+    elapsed := time.Since(start).Nanoseconds()
+    resp.TotalDuration = elapsed
+    resp.LoadDuration = 0
+    resp.PromptEvalDuration = 0
+    resp.EvalDuration = elapsed
+    if err := json.NewEncoder(w).Encode(&resp); err != nil {
+        internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
+        return
+    }
+}
+
 type GenerateRequest struct {
     Model  string `json:"model,omitempty"`
     Prompt string `json:"prompt,omitempty"`