Skip to content

Commit 405fb73

Browse files
committed
Try converting SD to purego
1 parent 09457b9 commit 405fb73

File tree

6 files changed

+84
-94
lines changed

6 files changed

+84
-94
lines changed

backend/go/stablediffusion-ggml/Makefile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ else
128128
$(CXX) $(CXXFLAGS) gosd.cpp -o gosd.o -c
129129
endif
130130

131+
gosd.so: sources/stablediffusion-ggml.cpp build/libstable-diffusion.a
132+
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
133+
+bash -c "source $(ONEAPI_VARS); \
134+
$(CXX) $(CXXFLAGS) gosd.cpp -o gosd.so -shared -lggmlall -lstable-diffusion -L./ -L./build"
135+
else
136+
$(CXX) $(CXXFLAGS) gosd.cpp -o gosd.so -shared -lggmlall -lstable-diffusion -L./ -L./build
137+
endif
138+
139+
131140
## stablediffusion (ggml)
132141
sources/stablediffusion-ggml.cpp:
133142
git clone --recursive $(STABLEDIFFUSION_GGML_REPO) sources/stablediffusion-ggml.cpp && \
@@ -144,6 +153,9 @@ stablediffusion-ggml: libsd.a
144153
CC="$(CC)" CXX="$(CXX)" CGO_CXXFLAGS="$(CGO_CXXFLAGS)" \
145154
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o stablediffusion-ggml ./
146155

156+
stablediffusion-ggml-pure: main.go gosd.go gosd.so
157+
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o stablediffusion-ggml-pure ./
158+
147159
package:
148160
bash package.sh
149161

backend/go/stablediffusion-ggml/gosd.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <cstdint>
12
#define GGML_MAX_NAME 128
23

34
#include <stdio.h>
@@ -226,7 +227,7 @@ int load_model(char *model, char *model_path, char* options[], int threads, int
226227
return 0;
227228
}
228229

229-
int gen_image(char *text, char *negativeText, int width, int height, int steps, int seed , char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
230+
int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
230231

231232
sd_image_t* results;
232233

@@ -252,14 +253,14 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
252253
// Handle input image for img2img
253254
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
254255
bool has_mask_image = (mask_image != NULL && strlen(mask_image) > 0);
255-
256+
256257
uint8_t* input_image_buffer = NULL;
257258
uint8_t* mask_image_buffer = NULL;
258259
std::vector<uint8_t> default_mask_image_vec;
259-
260+
260261
if (has_input_image) {
261262
fprintf(stderr, "Loading input image: %s\n", src_image);
262-
263+
263264
int c = 0;
264265
int img_width = 0;
265266
int img_height = 0;
@@ -273,29 +274,29 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
273274
free(input_image_buffer);
274275
return 1;
275276
}
276-
277+
277278
// Resize input image if dimensions don't match
278279
if (img_width != width || img_height != height) {
279280
fprintf(stderr, "Resizing input image from %dx%d to %dx%d\n", img_width, img_height, width, height);
280-
281+
281282
uint8_t* resized_image_buffer = (uint8_t*)malloc(height * width * 3);
282283
if (resized_image_buffer == NULL) {
283284
fprintf(stderr, "Failed to allocate memory for resized image\n");
284285
free(input_image_buffer);
285286
return 1;
286287
}
287-
288+
288289
stbir_resize(input_image_buffer, img_width, img_height, 0,
289290
resized_image_buffer, width, height, 0, STBIR_TYPE_UINT8,
290291
3, STBIR_ALPHA_CHANNEL_NONE, 0,
291292
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
292293
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
293294
STBIR_COLORSPACE_SRGB, nullptr);
294-
295+
295296
free(input_image_buffer);
296297
input_image_buffer = resized_image_buffer;
297298
}
298-
299+
299300
p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
300301
p.strength = strength;
301302
fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
@@ -304,11 +305,11 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
304305
p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
305306
p.strength = 0.0f;
306307
}
307-
308+
308309
// Handle mask image for inpainting
309310
if (has_mask_image) {
310311
fprintf(stderr, "Loading mask image: %s\n", mask_image);
311-
312+
312313
int c = 0;
313314
int mask_width = 0;
314315
int mask_height = 0;
@@ -318,30 +319,30 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
318319
if (input_image_buffer) free(input_image_buffer);
319320
return 1;
320321
}
321-
322+
322323
// Resize mask if dimensions don't match
323324
if (mask_width != width || mask_height != height) {
324325
fprintf(stderr, "Resizing mask image from %dx%d to %dx%d\n", mask_width, mask_height, width, height);
325-
326+
326327
uint8_t* resized_mask_buffer = (uint8_t*)malloc(height * width);
327328
if (resized_mask_buffer == NULL) {
328329
fprintf(stderr, "Failed to allocate memory for resized mask\n");
329330
free(mask_image_buffer);
330331
if (input_image_buffer) free(input_image_buffer);
331332
return 1;
332333
}
333-
334+
334335
stbir_resize(mask_image_buffer, mask_width, mask_height, 0,
335336
resized_mask_buffer, width, height, 0, STBIR_TYPE_UINT8,
336337
1, STBIR_ALPHA_CHANNEL_NONE, 0,
337338
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
338339
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
339340
STBIR_COLORSPACE_SRGB, nullptr);
340-
341+
341342
free(mask_image_buffer);
342343
mask_image_buffer = resized_mask_buffer;
343344
}
344-
345+
345346
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
346347
fprintf(stderr, "Using inpainting with mask\n");
347348
} else {
@@ -353,17 +354,17 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
353354
// Handle reference images
354355
std::vector<sd_image_t> ref_images_vec;
355356
std::vector<uint8_t*> ref_image_buffers;
356-
357+
357358
if (ref_images_count > 0 && ref_images != NULL) {
358359
fprintf(stderr, "Loading %d reference images\n", ref_images_count);
359-
360+
360361
for (int i = 0; i < ref_images_count; i++) {
361362
if (ref_images[i] == NULL || strlen(ref_images[i]) == 0) {
362363
continue;
363364
}
364-
365+
365366
fprintf(stderr, "Loading reference image %d: %s\n", i + 1, ref_images[i]);
366-
367+
367368
int c = 0;
368369
int ref_width = 0;
369370
int ref_height = 0;
@@ -377,33 +378,33 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
377378
free(ref_image_buffer);
378379
continue;
379380
}
380-
381+
381382
// Resize reference image if dimensions don't match
382383
if (ref_width != width || ref_height != height) {
383384
fprintf(stderr, "Resizing reference image from %dx%d to %dx%d\n", ref_width, ref_height, width, height);
384-
385+
385386
uint8_t* resized_ref_buffer = (uint8_t*)malloc(height * width * 3);
386387
if (resized_ref_buffer == NULL) {
387388
fprintf(stderr, "Failed to allocate memory for resized reference image\n");
388389
free(ref_image_buffer);
389390
continue;
390391
}
391-
392+
392393
stbir_resize(ref_image_buffer, ref_width, ref_height, 0,
393394
resized_ref_buffer, width, height, 0, STBIR_TYPE_UINT8,
394395
3, STBIR_ALPHA_CHANNEL_NONE, 0,
395396
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
396397
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
397398
STBIR_COLORSPACE_SRGB, nullptr);
398-
399+
399400
free(ref_image_buffer);
400401
ref_image_buffer = resized_ref_buffer;
401402
}
402-
403+
403404
ref_image_buffers.push_back(ref_image_buffer);
404405
ref_images_vec.push_back({(uint32_t)width, (uint32_t)height, 3, ref_image_buffer});
405406
}
406-
407+
407408
if (!ref_images_vec.empty()) {
408409
p.ref_images = ref_images_vec.data();
409410
p.ref_images_count = ref_images_vec.size();

backend/go/stablediffusion-ggml/gosd.go

Lines changed: 22 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
package main
22

3-
// #cgo CXXFLAGS: -I${SRCDIR}/sources/stablediffusion-ggml.cpp/thirdparty -I${SRCDIR}/sources/stablediffusion-ggml.cpp -I${SRCDIR}/sources/stablediffusion-ggml.cpp/ggml/include
4-
// #cgo LDFLAGS: -L${SRCDIR}/ -lsd -lstdc++ -lm -lggmlall -lgomp
5-
// #include <gosd.h>
6-
// #include <stdlib.h>
7-
import "C"
8-
93
import (
104
"fmt"
115
"os"
@@ -25,25 +19,19 @@ type SDGGML struct {
2519
cfgScale float32
2620
}
2721

22+
var (
23+
LoadModel func(model, model_apth string, options []string, threads int32, diff int) int
24+
GenImage func(text, negativeText string, width, height, steps int, seed int64, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
25+
)
26+
2827
func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
2928

3029
sd.threads = int(opts.Threads)
3130

3231
modelPath := opts.ModelPath
3332

34-
modelFile := C.CString(opts.ModelFile)
35-
defer C.free(unsafe.Pointer(modelFile))
36-
37-
modelPathC := C.CString(modelPath)
38-
defer C.free(unsafe.Pointer(modelPathC))
39-
40-
var options **C.char
41-
// prepare the options array to pass to C
42-
43-
size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
44-
length := C.size_t(len(opts.Options))
45-
options = (**C.char)(C.malloc((length + 1) * size))
46-
view := (*[1 << 30]*C.char)(unsafe.Pointer(options))[0 : len(opts.Options)+1 : len(opts.Options)+1]
33+
modelFile := opts.ModelFile
34+
modelPathC := modelPath
4735

4836
var diffusionModel int
4937

@@ -68,14 +56,12 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
6856

6957
fmt.Fprintf(os.Stderr, "Options: %+v\n", oo)
7058

71-
for i, x := range oo {
72-
view[i] = C.CString(x)
73-
}
74-
view[len(oo)] = nil
59+
options := make([]string, len(oo), len(oo) + 1)
60+
*(*uintptr)(unsafe.Add(unsafe.Pointer(&options), uintptr(len(oo)))) = 0
7561

7662
sd.cfgScale = opts.CFGScale
7763

78-
ret := C.load_model(modelFile, modelPathC, options, C.int(opts.Threads), C.int(diffusionModel))
64+
ret := LoadModel(modelFile, modelPathC, options, opts.Threads, diffusionModel)
7965
if ret != 0 {
8066
return fmt.Errorf("could not load model")
8167
}
@@ -84,65 +70,33 @@ func (sd *SDGGML) Load(opts *pb.ModelOptions) error {
8470
}
8571

8672
func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
87-
t := C.CString(opts.PositivePrompt)
88-
defer C.free(unsafe.Pointer(t))
73+
t := opts.PositivePrompt
74+
dst := opts.Dst
75+
negative := opts.NegativePrompt
76+
srcImage := opts.Src
8977

90-
dst := C.CString(opts.Dst)
91-
defer C.free(unsafe.Pointer(dst))
92-
93-
negative := C.CString(opts.NegativePrompt)
94-
defer C.free(unsafe.Pointer(negative))
95-
96-
// Handle source image path
97-
var srcImage *C.char
98-
if opts.Src != "" {
99-
srcImage = C.CString(opts.Src)
100-
defer C.free(unsafe.Pointer(srcImage))
101-
}
102-
103-
// Handle mask image path
104-
var maskImage *C.char
78+
var maskImage string
10579
if opts.EnableParameters != "" {
106-
// Parse EnableParameters for mask path if provided
107-
// This is a simple approach - in a real implementation you might want to parse JSON
10880
if strings.Contains(opts.EnableParameters, "mask:") {
10981
parts := strings.Split(opts.EnableParameters, "mask:")
11082
if len(parts) > 1 {
11183
maskPath := strings.TrimSpace(parts[1])
11284
if maskPath != "" {
113-
maskImage = C.CString(maskPath)
114-
defer C.free(unsafe.Pointer(maskImage))
85+
maskImage = maskPath
11586
}
11687
}
11788
}
11889
}
11990

120-
// Handle reference images
121-
var refImages **C.char
122-
var refImagesCount C.int
123-
if len(opts.RefImages) > 0 {
124-
refImagesCount = C.int(len(opts.RefImages))
125-
// Allocate array of C strings
126-
size := C.size_t(unsafe.Sizeof((*C.char)(nil)))
127-
refImages = (**C.char)(C.malloc((C.size_t(len(opts.RefImages)) + 1) * size))
128-
view := (*[1 << 30]*C.char)(unsafe.Pointer(refImages))[0 : len(opts.RefImages)+1 : len(opts.RefImages)+1]
129-
130-
for i, refImagePath := range opts.RefImages {
131-
view[i] = C.CString(refImagePath)
132-
defer C.free(unsafe.Pointer(view[i]))
133-
}
134-
view[len(opts.RefImages)] = nil
135-
}
91+
refImagesCount := len(opts.RefImages)
92+
refImages := make([]string, refImagesCount, refImagesCount + 1)
93+
copy(refImages, opts.RefImages)
94+
*(*uintptr)(unsafe.Add(unsafe.Pointer(&refImages), refImagesCount)) = 0
13695

13796
// Default strength for img2img (0.75 is a good default)
138-
strength := C.float(0.75)
139-
if opts.Src != "" {
140-
// If we have a source image, use img2img mode
141-
// You could also parse strength from EnableParameters if needed
142-
strength = C.float(0.75)
143-
}
97+
strength := float32(0.75)
14498

145-
ret := C.gen_image(t, negative, C.int(opts.Width), C.int(opts.Height), C.int(opts.Step), C.int(opts.Seed), dst, C.float(sd.cfgScale), srcImage, strength, maskImage, refImages, refImagesCount)
99+
ret := GenImage(t, negative, int(opts.Width), int(opts.Height), int(opts.Step), int64(opts.Seed), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
146100
if ret != 0 {
147101
return fmt.Errorf("inference failed")
148102
}

backend/go/stablediffusion-ggml/main.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,35 @@ package main
33
// Note: this is started internally by LocalAI and a server is allocated for each model
44
import (
55
"flag"
6+
"fmt"
7+
"runtime"
68

9+
"github.com/ebitengine/purego"
710
grpc "github.com/mudler/LocalAI/pkg/grpc"
811
)
912

1013
var (
1114
addr = flag.String("addr", "localhost:50051", "the address to connect to")
1215
)
1316

17+
func getLibrary() string {
18+
switch runtime.GOOS {
19+
case "linux":
20+
return "./gosd.so"
21+
default:
22+
panic(fmt.Errorf("GOOS=%s is not supported", runtime.GOOS))
23+
}
24+
}
25+
1426
func main() {
27+
gosd, err := purego.Dlopen(getLibrary(), purego.RTLD_NOW|purego.RTLD_GLOBAL)
28+
if err != nil {
29+
panic(err)
30+
}
31+
32+
purego.RegisterLibFunc(&LoadModel, gosd, "load_model")
33+
purego.RegisterLibFunc(&GenImage, gosd, "gen_image")
34+
1535
flag.Parse()
1636

1737
if err := grpc.StartServer(*addr, &SDGGML{}); err != nil {

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ require (
1212
github.com/chasefleming/elem-go v0.26.0
1313
github.com/containerd/containerd v1.7.19
1414
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2
15+
github.com/ebitengine/purego v0.8.4
1516
github.com/fsnotify/fsnotify v1.7.0
1617
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20240626202019-c118733a29ad
1718
github.com/go-audio/wav v1.1.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdf
134134
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
135135
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
136136
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
137+
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
138+
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
137139
github.com/elastic/gosigar v0.12.0/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=
138140
github.com/elastic/gosigar v0.14.3 h1:xwkKwPia+hSfg9GqrCUKYdId102m9qTJIIr7egmK/uo=
139141
github.com/elastic/gosigar v0.14.3/go.mod h1:iXRIGg2tLnu7LBdpqzyQfGDEidKCfWcCMS0WKyPWoMs=

0 commit comments

Comments
 (0)