Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
predInput = evaluator.TemplateMessages(input.Messages, config, funcs, shouldUseFn)
predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)

log.Debug().Msgf("Prompt (after templating): %s", predInput)
if config.Grammar != "" {
Expand Down
12 changes: 8 additions & 4 deletions core/http/endpoints/openai/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
predInput := config.PromptStrings[0]

templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
Input: predInput,
SystemPrompt: config.SystemPrompt,
Input: predInput,
SystemPrompt: config.SystemPrompt,
ReasoningEffort: input.ReasoningEffort,
Metadata: input.Metadata,
})
if err == nil {
predInput = templatedInput
Expand Down Expand Up @@ -160,8 +162,10 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e

for k, i := range config.PromptStrings {
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
SystemPrompt: config.SystemPrompt,
Input: i,
ReasoningEffort: input.ReasoningEffort,
Metadata: input.Metadata,
})
if err == nil {
i = templatedInput
Expand Down
8 changes: 5 additions & 3 deletions core/http/endpoints/openai/edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat

for _, i := range config.InputStrings {
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
ReasoningEffort: input.ReasoningEffort,
Metadata: input.Metadata,
})
if err == nil {
i = templatedInput
Expand Down
4 changes: 4 additions & 0 deletions core/schema/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ type OpenAIRequest struct {
Backend string `json:"backend" yaml:"backend"`

ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`

ReasoningEffort string `json:"reasoning_effort" yaml:"reasoning_effort"`

Metadata map[string]string `json:"metadata" yaml:"metadata"`
}

type ModelsDataResponse struct {
Expand Down
6 changes: 5 additions & 1 deletion core/templates/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type PromptTemplateData struct {
Instruction string
Functions []functions.Function
MessageIndex int
ReasoningEffort string
Metadata map[string]string
}

type ChatMessageTemplateData struct {
Expand Down Expand Up @@ -133,7 +135,7 @@ func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, te
return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation)
}

func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {
func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.BackendConfig, funcs []functions.Function, shouldUseFn bool) string {

if config.TemplateConfig.JinjaTemplate {
var messageData []ChatMessageTemplateData
Expand Down Expand Up @@ -283,6 +285,8 @@ func (e *Evaluator) TemplateMessages(messages []schema.Message, config *config.B
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
ReasoningEffort: input.ReasoningEffort,
Metadata: input.Metadata,
})
if err == nil {
predInput = templatedInput
Expand Down
6 changes: 3 additions & 3 deletions core/templates/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ var _ = Describe("Templates", func() {
for key := range chatMLTestMatch {
foo := chatMLTestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
Expand All @@ -232,7 +232,7 @@ var _ = Describe("Templates", func() {
for key := range llama3TestMatch {
foo := llama3TestMatch[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
Expand All @@ -245,7 +245,7 @@ var _ = Describe("Templates", func() {
for key := range jinjaTest {
foo := jinjaTest[key]
It("renders correctly `"+key+"`", func() {
templated := evaluator.TemplateMessages(foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.BackendConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool))
Expect(templated).To(Equal(foo["expected"]), templated)
})
}
Expand Down
Loading