@@ -18,16 +18,18 @@ package ollama
 
 import (
 	"encoding/json"
+	"fmt"
+	"github.com/google-gemini/proxy-to-gemini/internal"
+	"github.com/google/generative-ai-go/genai"
+	"github.com/gorilla/mux"
+	"google.golang.org/api/iterator"
 	"io"
 	"net/http"
 	"strings"
 	"time"
-
-	"github.com/google-gemini/proxy-to-gemini/internal"
-	"github.com/google/generative-ai-go/genai"
-	"github.com/gorilla/mux"
 )
 
+// handlers provides HTTP handlers for the Ollama proxy API.
 type handlers struct {
 	client *genai.Client
 }
@@ -36,6 +38,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
 	handlers := &handlers{client: client}
 	r.HandleFunc("/api/generate", handlers.generateHandler)
 	r.HandleFunc("/api/embed", handlers.embedHandler)
+	r.HandleFunc("/api/chat", handlers.chatHandler)
 }
 
 func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
@@ -142,6 +145,245 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+// ChatRequest represents a chat completion request for the Ollama API.
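+// An illustrative request body (model name and values are hypothetical):
+//
+//	{"model": "gemini-1.5-flash", "stream": false,
+//	 "messages": [{"role": "user", "content": "Hello"}]}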
+type ChatRequest struct {
+	Model    string          `json:"model,omitempty"`
+	Messages []ChatMessage   `json:"messages,omitempty"`
+	Stream   bool            `json:"stream,omitempty"`
+	Format   json.RawMessage `json:"format,omitempty"`
+	Options  Options         `json:"options,omitempty"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+	Role    string   `json:"role,omitempty"`
+	Content string   `json:"content,omitempty"`
+	Images  []string `json:"images,omitempty"`
+}
+
+// ChatResponse represents a chat completion response for the Ollama API.
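+// Durations are reported in nanoseconds, mirroring Ollama's native API.
+// This proxy does not measure model load or prompt evaluation separately,
+// so LoadDuration and PromptEvalDuration are always zero.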
+type ChatResponse struct {
+	Model              string      `json:"model,omitempty"`
+	CreatedAt          time.Time   `json:"created_at,omitempty"`
+	Message            ChatMessage `json:"message,omitempty"`
+	Done               bool        `json:"done,omitempty"`
+	PromptEvalCount    int32       `json:"prompt_eval_count"`
+	EvalCount          int32       `json:"eval_count"`
+	TotalDuration      int64       `json:"total_duration"`
+	LoadDuration       int64       `json:"load_duration,omitempty"`
+	PromptEvalDuration int64       `json:"prompt_eval_duration,omitempty"`
+	EvalDuration       int64       `json:"eval_duration"`
+}
+
+// sanitizeJson strips newlines and Markdown code fences from a model
+// response so that clients that requested JSON receive a bare JSON string.
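+// For example (illustrative), "```json\n{\"ok\":true}\n```" becomes
+// "{\"ok\":true}".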
+func sanitizeJson(s string) string {
+	s = strings.ReplaceAll(s, "\n", "")
+	s = strings.TrimPrefix(s, "```json")
+	s = strings.TrimSuffix(s, "```")
+	s = strings.ReplaceAll(s, "'", "\\'")
+	return s
+}
+
+// chatHandler handles POST /api/chat requests.
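+// An illustrative invocation (address and model name are hypothetical):
+//
+//	curl -s localhost:8080/api/chat -d \
+//	  '{"model": "gemini-1.5-flash", "messages": [{"role": "user", "content": "Hello"}]}'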
+func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
+		return
+	}
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
+		return
+	}
+	defer r.Body.Close()
+
+	var req ChatRequest
+	if err := json.Unmarshal(body, &req); err != nil {
+		internal.ErrorHandler(w, r, http.StatusBadRequest, "failed to unmarshal chat request: %v", err)
+		return
+	}
+
+	// Handle the format parameter: "json" turns on JSON mode, and any
+	// other value is treated as a JSON schema to enforce on the response.
+	expectJson := false
+	if len(req.Format) > 0 {
+		expectJson = true
+		var formatVal interface{}
+		if err := json.Unmarshal(req.Format, &formatVal); err != nil {
+			internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
+			return
+		}
+		var instr string
+		switch v := formatVal.(type) {
+		case string:
+			if v == "json" {
+				instr = "Please respond with valid JSON."
+			} else {
+				instr = fmt.Sprintf("Please respond with format: %s.", v)
+			}
+		default:
+			schemaBytes, err := json.MarshalIndent(v, "", " ")
+			if err != nil {
+				schemaBytes = req.Format
+			}
+			instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
+		}
+		// Fold the instruction into an existing system message if present;
+		// otherwise prepend a new one.
+		found := false
+		for i, m := range req.Messages {
+			if m.Role == "system" {
+				req.Messages[i].Content = m.Content + "\n\n" + instr
+				found = true
+				break
+			}
+		}
+		if !found {
+			req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
+		}
+	}
+
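+	// Map Ollama-style generation options onto Gemini's GenerationConfig;
+	// fields left unset fall back to the model's server-side defaults.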
+	model := h.client.GenerativeModel(req.Model)
+	model.GenerationConfig = genai.GenerationConfig{
+		Temperature:     req.Options.Temperature,
+		MaxOutputTokens: req.Options.NumPredict,
+		TopK:            req.Options.TopK,
+		TopP:            req.Options.TopP,
+	}
+	if req.Options.Stop != nil {
+		model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
+	}
+
+	// Seed the chat history with all but the last message; the last message
+	// is sent as the new turn. System messages become the system instruction.
+	chat := model.StartChat()
+	var lastPart genai.Part
+	for i, m := range req.Messages {
+		if m.Role == "system" {
+			model.SystemInstruction = &genai.Content{
+				Role:  m.Role,
+				Parts: []genai.Part{genai.Text(m.Content)},
+			}
+			continue
+		}
+		if i == len(req.Messages)-1 {
+			lastPart = genai.Text(m.Content)
+			break
+		}
+		role := m.Role
+		if role == "assistant" {
+			// Gemini expects the model's past turns to use the role "model".
+			role = "model"
+		}
+		chat.History = append(chat.History, &genai.Content{
+			Role:  role,
+			Parts: []genai.Part{genai.Text(m.Content)},
+		})
+	}
+
+	if req.Stream {
+		h.streamingChatHandler(w, r, req.Model, chat, lastPart)
+		return
+	}
+	// Measure time spent generating the chat response.
+	start := time.Now()
+
+	gresp, err := chat.SendMessage(r.Context(), lastPart)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
+		return
+	}
+	// Collect the text parts of the first candidate; guard against an
+	// empty candidate list (e.g. when the response was blocked).
+	role := "model"
+	var builder strings.Builder
+	if len(gresp.Candidates) > 0 {
+		role = gresp.Candidates[0].Content.Role
+		for _, part := range gresp.Candidates[0].Content.Parts {
+			if txt, ok := part.(genai.Text); ok {
+				builder.WriteString(string(txt))
+			}
+		}
+	}
+
+	content := builder.String()
+	if expectJson {
+		content = sanitizeJson(content)
+	}
+	resp := ChatResponse{
+		Model:     req.Model,
+		CreatedAt: time.Now(),
+		Message: ChatMessage{
+			Role:    role,
+			Content: content,
+		},
+		Done: true,
+	}
+
+	if gresp.UsageMetadata != nil {
+		resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
+		// Compute the number of tokens in the response.
+		if gresp.UsageMetadata.CandidatesTokenCount > 0 {
+			resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
+		} else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
+			// Fallback: use total tokens minus prompt tokens.
+			resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
+		}
+	}
+	// Populate duration metadata (in nanoseconds). Load and prompt-eval
+	// time are not tracked separately by this proxy.
+	elapsed := time.Since(start).Nanoseconds()
+	resp.TotalDuration = elapsed
+	resp.LoadDuration = 0
+	resp.PromptEvalDuration = 0
+	resp.EvalDuration = elapsed
+	if err := json.NewEncoder(w).Encode(&resp); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
+		return
+	}
+}
+
+// streamingChatHandler handles streaming chat responses for /api/chat.
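+// Each chunk is written as a "data: <json>" line and the stream ends with a
+// "data: [DONE]" marker, for example (illustrative, values hypothetical):
+//
+//	data: {"model":"gemini-1.5-flash","message":{"role":"model","content":"Hi"},"done":false}
+//	data: [DONE]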
+func (h *handlers) streamingChatHandler(w http.ResponseWriter, r *http.Request, modelName string, chat *genai.ChatSession, lastPart genai.Part) {
+	// Measure total elapsed time for streaming.
+	start := time.Now()
+	iter := chat.SendMessageStream(r.Context(), lastPart)
+	for {
+		gresp, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to stream chat response: %v", err)
+			return
+		}
+		// Collect the text parts of the first candidate, guarding against
+		// an empty candidate list.
+		role := "model"
+		var builder strings.Builder
+		if len(gresp.Candidates) > 0 {
+			role = gresp.Candidates[0].Content.Role
+			for _, part := range gresp.Candidates[0].Content.Parts {
+				if txt, ok := part.(genai.Text); ok {
+					builder.WriteString(string(txt))
+				}
+			}
+		}
+		// Build the streaming chunk with duration metadata.
+		elapsed := time.Since(start).Nanoseconds()
+		chunk := ChatResponse{
+			Model:     modelName,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    role,
+				Content: builder.String(),
+			},
+			Done:               false,
+			TotalDuration:      elapsed,
+			LoadDuration:       0,
+			PromptEvalDuration: 0,
+			EvalDuration:       elapsed,
+		}
+		data, err := json.Marshal(chunk)
+		if err != nil {
+			internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to marshal chat chunk: %v", err)
+			return
+		}
+		fmt.Fprintf(w, "data: %s\n", data)
+	}
+	fmt.Fprint(w, "data: [DONE]\n")
+}
+
 type GenerateRequest struct {
 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`