Skip to content

Commit 1fc3a37

Browse files
authored
feat: inline templates and accept URLs in models (#1452)
* feat: Allow inline templates * feat: Allow to specify url in model config files Signed-off-by: Ettore Di Giacinto <[email protected]> * feat: support 'huggingface://' format * style: reuse-code from gallery --------- Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 64a8471 commit 1fc3a37

File tree

9 files changed

+286
-133
lines changed

9 files changed

+286
-133
lines changed

api/api.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
4747
}
4848
}
4949

50+
if err := cl.Preload(options.Loader.ModelPath); err != nil {
51+
log.Error().Msgf("error downloading models: %s", err.Error())
52+
}
53+
5054
if options.Debug {
5155
for _, v := range cl.ListConfigs() {
5256
cfg, _ := cl.GetConfig(v)

api/api_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ var _ = Describe("API test", func() {
294294
Expect(content["backend"]).To(Equal("bert-embeddings"))
295295
})
296296

297-
It("runs openllama", Label("llama"), func() {
297+
It("runs openllama(llama-ggml backend)", Label("llama"), func() {
298298
if runtime.GOOS != "linux" {
299299
Skip("test supported only on linux")
300300
}
@@ -362,9 +362,10 @@ var _ = Describe("API test", func() {
362362
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
363363
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
364364
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
365+
365366
})
366367

367-
It("runs openllama gguf", Label("llama-gguf"), func() {
368+
It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() {
368369
if runtime.GOOS != "linux" {
369370
Skip("test supported only on linux")
370371
}

api/config/config.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"strings"
99
"sync"
1010

11+
"github.com/go-skynet/LocalAI/pkg/utils"
12+
"github.com/rs/zerolog/log"
1113
"gopkg.in/yaml.v3"
1214
)
1315

@@ -264,6 +266,36 @@ func (cm *ConfigLoader) ListConfigs() []string {
264266
return res
265267
}
266268

269+
func (cm *ConfigLoader) Preload(modelPath string) error {
270+
cm.Lock()
271+
defer cm.Unlock()
272+
273+
for i, config := range cm.configs {
274+
modelURL := config.PredictionOptions.Model
275+
modelURL = utils.ConvertURL(modelURL)
276+
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
277+
// md5 of model name
278+
md5Name := utils.MD5(modelURL)
279+
280+
// check if file exists
281+
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
282+
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
283+
log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent)
284+
})
285+
if err != nil {
286+
return err
287+
}
288+
}
289+
290+
cc := cm.configs[i]
291+
c := &cc
292+
c.PredictionOptions.Model = md5Name
293+
cm.configs[i] = *c
294+
}
295+
}
296+
return nil
297+
}
298+
267299
func (cm *ConfigLoader) LoadConfigs(path string) error {
268300
cm.Lock()
269301
defer cm.Unlock()

api/openai/chat.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
219219
c.Set("Transfer-Encoding", "chunked")
220220
}
221221

222-
templateFile := config.Model
222+
templateFile := ""
223+
224+
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
225+
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
226+
templateFile = config.Model
227+
}
223228

224229
if config.TemplateConfig.Chat != "" && !processFunctions {
225230
templateFile = config.TemplateConfig.Chat
@@ -229,18 +234,19 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
229234
templateFile = config.TemplateConfig.Functions
230235
}
231236

232-
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
233-
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
234-
SystemPrompt: config.SystemPrompt,
235-
SuppressSystemPrompt: suppressConfigSystemPrompt,
236-
Input: predInput,
237-
Functions: funcs,
238-
})
239-
if err == nil {
240-
predInput = templatedInput
241-
log.Debug().Msgf("Template found, input modified to: %s", predInput)
242-
} else {
243-
log.Debug().Msgf("Template failed loading: %s", err.Error())
237+
if templateFile != "" {
238+
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
239+
SystemPrompt: config.SystemPrompt,
240+
SuppressSystemPrompt: suppressConfigSystemPrompt,
241+
Input: predInput,
242+
Functions: funcs,
243+
})
244+
if err == nil {
245+
predInput = templatedInput
246+
log.Debug().Msgf("Template found, input modified to: %s", predInput)
247+
} else {
248+
log.Debug().Msgf("Template failed loading: %s", err.Error())
249+
}
244250
}
245251

246252
log.Debug().Msgf("Prompt (after templating): %s", predInput)

api/openai/completion.go

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,12 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
8181
c.Set("Transfer-Encoding", "chunked")
8282
}
8383

84-
templateFile := config.Model
84+
templateFile := ""
85+
86+
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
87+
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
88+
templateFile = config.Model
89+
}
8590

8691
if config.TemplateConfig.Completion != "" {
8792
templateFile = config.TemplateConfig.Completion
@@ -94,13 +99,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
9499

95100
predInput := config.PromptStrings[0]
96101

97-
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
98-
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
99-
Input: predInput,
100-
})
101-
if err == nil {
102-
predInput = templatedInput
103-
log.Debug().Msgf("Template found, input modified to: %s", predInput)
102+
if templateFile != "" {
103+
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
104+
Input: predInput,
105+
})
106+
if err == nil {
107+
predInput = templatedInput
108+
log.Debug().Msgf("Template found, input modified to: %s", predInput)
109+
}
104110
}
105111

106112
responses := make(chan schema.OpenAIResponse)
@@ -145,14 +151,16 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
145151
totalTokenUsage := backend.TokenUsage{}
146152

147153
for k, i := range config.PromptStrings {
148-
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
149-
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
150-
SystemPrompt: config.SystemPrompt,
151-
Input: i,
152-
})
153-
if err == nil {
154-
i = templatedInput
155-
log.Debug().Msgf("Template found, input modified to: %s", i)
154+
if templateFile != "" {
155+
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
156+
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
157+
SystemPrompt: config.SystemPrompt,
158+
Input: i,
159+
})
160+
if err == nil {
161+
i = templatedInput
162+
log.Debug().Msgf("Template found, input modified to: %s", i)
163+
}
156164
}
157165

158166
r, tokenUsage, err := ComputeChoices(

api/openai/edit.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
3030

3131
log.Debug().Msgf("Parameter Config: %+v", config)
3232

33-
templateFile := config.Model
33+
templateFile := ""
34+
35+
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
36+
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
37+
templateFile = config.Model
38+
}
3439

3540
if config.TemplateConfig.Edit != "" {
3641
templateFile = config.TemplateConfig.Edit
@@ -40,15 +45,16 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
4045
totalTokenUsage := backend.TokenUsage{}
4146

4247
for _, i := range config.InputStrings {
43-
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
44-
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
45-
Input: i,
46-
Instruction: input.Instruction,
47-
SystemPrompt: config.SystemPrompt,
48-
})
49-
if err == nil {
50-
i = templatedInput
51-
log.Debug().Msgf("Template found, input modified to: %s", i)
48+
if templateFile != "" {
49+
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
50+
Input: i,
51+
Instruction: input.Instruction,
52+
SystemPrompt: config.SystemPrompt,
53+
})
54+
if err == nil {
55+
i = templatedInput
56+
log.Debug().Msgf("Template found, input modified to: %s", i)
57+
}
5258
}
5359

5460
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {

pkg/gallery/models.go

Lines changed: 2 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"hash"
77
"io"
8-
"net/http"
98
"os"
109
"path/filepath"
1110
"strconv"
@@ -115,89 +114,8 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
115114
// Create file path
116115
filePath := filepath.Join(basePath, file.Filename)
117116

118-
// Check if the file already exists
119-
_, err := os.Stat(filePath)
120-
if err == nil {
121-
// File exists, check SHA
122-
if file.SHA256 != "" {
123-
// Verify SHA
124-
calculatedSHA, err := calculateSHA(filePath)
125-
if err != nil {
126-
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
127-
}
128-
if calculatedSHA == file.SHA256 {
129-
// SHA matches, skip downloading
130-
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
131-
continue
132-
}
133-
// SHA doesn't match, delete the file and download again
134-
err = os.Remove(filePath)
135-
if err != nil {
136-
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
137-
}
138-
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
139-
140-
} else {
141-
// SHA is missing, skip downloading
142-
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
143-
continue
144-
}
145-
} else if !os.IsNotExist(err) {
146-
// Error occurred while checking file existence
147-
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
148-
}
149-
150-
log.Debug().Msgf("Downloading %q", file.URI)
151-
152-
// Download file
153-
resp, err := http.Get(file.URI)
154-
if err != nil {
155-
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
156-
}
157-
defer resp.Body.Close()
158-
159-
// Create parent directory
160-
err = os.MkdirAll(filepath.Dir(filePath), 0755)
161-
if err != nil {
162-
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
163-
}
164-
165-
// Create and write file content
166-
outFile, err := os.Create(filePath)
167-
if err != nil {
168-
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
169-
}
170-
defer outFile.Close()
171-
172-
progress := &progressWriter{
173-
fileName: file.Filename,
174-
total: resp.ContentLength,
175-
hash: sha256.New(),
176-
downloadStatus: downloadStatus,
177-
}
178-
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
179-
if err != nil {
180-
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
181-
}
182-
183-
if file.SHA256 != "" {
184-
// Verify SHA
185-
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
186-
if calculatedSHA != file.SHA256 {
187-
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
188-
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
189-
}
190-
} else {
191-
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
192-
}
193-
194-
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
195-
if utils.IsArchive(filePath) {
196-
log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath)
197-
if err := utils.ExtractArchive(filePath, basePath); err != nil {
198-
log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error())
199-
return err
200-
}
117+
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
118+
return err
201119
}
202120
}
203121

pkg/model/loader.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,17 +247,19 @@ func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateN
247247
// skip any error here - we run anyway if a template does not exist
248248
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
249249

250-
if !ml.ExistsInModelPath(modelTemplateFile) {
251-
return nil
252-
}
253-
254-
dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
255-
if err != nil {
256-
return err
250+
dat := ""
251+
if ml.ExistsInModelPath(modelTemplateFile) {
252+
d, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
253+
if err != nil {
254+
return err
255+
}
256+
dat = string(d)
257+
} else {
258+
dat = templateName
257259
}
258260

259261
// Parse the template
260-
tmpl, err := template.New("prompt").Parse(string(dat))
262+
tmpl, err := template.New("prompt").Parse(dat)
261263
if err != nil {
262264
return err
263265
}

0 commit comments

Comments
 (0)