Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}

if tokenCallback != nil {

if c.TemplateConfig.ReplyPrefix != "" {
tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
}

ss := ""

var partialRune []byte
Expand Down Expand Up @@ -165,8 +170,13 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

response := string(reply.Message)
if c.TemplateConfig.ReplyPrefix != "" {
response = c.TemplateConfig.ReplyPrefix + response
}

return LLMResponse{
Response: string(reply.Message),
Response: response,
Usage: tokenUsage,
}, err
}
Expand Down
52 changes: 27 additions & 25 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,28 +130,28 @@ type LLMConfig struct {
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`

ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
DType string `yaml:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj"`
ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
DType string `yaml:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj"`

FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`
Expand All @@ -171,9 +171,9 @@ type LLMConfig struct {

// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image"`
LimitVideoPerPrompt int `yaml:"video"`
LimitAudioPerPrompt int `yaml:"audio"`
LimitImagePerPrompt int `yaml:"image"`
LimitVideoPerPrompt int `yaml:"video"`
LimitAudioPerPrompt int `yaml:"audio"`
}

// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
Expand Down Expand Up @@ -213,6 +213,8 @@ type TemplateConfig struct {
Multimodal string `yaml:"multimodal"`

JinjaTemplate bool `yaml:"jinja_template"`

ReplyPrefix string `yaml:"reply_prefix"`
}

func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
Expand Down
Loading