
Commit 17ec640

fix warmup
1 parent 43e4367 commit 17ec640

4 files changed: +64 −30 lines changed

src/Chat.js

Lines changed: 26 additions & 21 deletions

@@ -25,11 +25,13 @@ export function Chat() {
   let isGenerating;
   let error;
   let generatedText;
+  let warmup;
+
   if (INTERFACE === 'IMAGE') {
-    ({ generateText, isGenerating, error, generatedText } =
+    ({ generateText, isGenerating, error, generatedText, warmup } =
       useLLMVisionGeneration(LLM_VISION_MODEL_CONFIG));
   } else {
-    ({ generateCode: generateText, isGenerating, error, generatedCode: generatedText } =
+    ({ generateCode: generateText, isGenerating, error, generatedCode: generatedText, warmup } =
       useLLMHtmlGeneration(LLM_HTML_MODEL_CONFIG));
   }
   const [currentMessageId, setCurrentMessageId] = useState(null);
@@ -84,6 +86,16 @@ export function Chat() {
     }
   }, [generatedText, currentMessageId]);

+  const onChangeInput = (e) => {
+    setInput(e.target.value);
+    if (showWarning) {
+      const proceed = window.confirm("Warning: Using this chat will download AI models larger than 1GB in size. Do you want to continue?");
+      if (!proceed) return;
+      warmup();
+      setShowWarning(false);
+    }
+  }
+
   const handleImageUpload = (event) => {
     const file = event.target.files[0];
     if (file) {
@@ -100,12 +112,6 @@ export function Chat() {
     e.preventDefault();
     if ((!input.trim() && !selectedImage) || isGenerating) return;

-    if (showWarning) {
-      const proceed = window.confirm("Warning: Using this chat will download AI models larger than 1GB in size. Do you want to continue?");
-      if (!proceed) return;
-      setShowWarning(false);
-    }
-
     const userMessage = {
       role: "user",
       content: input,
@@ -128,17 +134,6 @@ export function Chat() {
     h(
       "div",
       { className: "fixed top-4 right-4 z-20 flex items-center gap-2" },
-      hasCache && h(
-        "button",
-        {
-          onClick: clearModelCache,
-
-          className: "bg-gray-900/80 backdrop-blur-sm rounded-full p-2 text-gray-300 hover:text-white hover:bg-gray-800 transition-colors flex items-center gap-2",
-          title: `Clear cached models (${cacheSize}MB)`
-        },
-        h(XCircle, { className: "w-5 h-5" }),
-        `Delete Cache ${cacheSize}MB`
-      ),
       h(
         "a",
         {
@@ -149,7 +144,17 @@ export function Chat() {
         },
         h(Github, { className: "w-5 h-5" }),
         "GitHub"
-      )
+      ), hasCache && h(
+        "button",
+        {
+          onClick: clearModelCache,
+
+          className: "bg-gray-900/80 backdrop-blur-sm rounded-full p-2 text-gray-300 hover:text-white hover:bg-gray-800 transition-colors flex items-center gap-2",
+          title: `Clear cached models (${cacheSize}MB)`
+        },
+        h(XCircle, { className: "w-5 h-5" }),
+        `Delete Cache ${cacheSize}MB`
+      ),
     ),
     !isGenerating && !messages.length
       ? h(
@@ -298,7 +303,7 @@ export function Chat() {
           h("input", {
             type: "text",
             value: input,
-            onChange: (e) => setInput(e.target.value),
+            onChange: (e) => onChangeInput(e),
             placeholder: "Describe what you want to create...",
            className:
              "flex-1 bg-transparent px-4 py-3 focus:outline-none placeholder-gray-500",

src/constants/chat.js

Lines changed: 6 additions & 0 deletions

@@ -22,6 +22,9 @@ export const LLM_HTML_MODEL_CONFIG = {
       temperature: 0.3,
       top_p: 0.9,
     },
+    warmup: {
+      max_tokens: 1,
+    }
   },
   huggingface: {
     modelId: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
@@ -35,6 +38,9 @@ export const LLM_HTML_MODEL_CONFIG = {
       top_p: 0.9,
       do_sample: true,
     },
+    warmup: {
+      max_new_tokens: 1,
+    }
   },
 },
 backend: "webllm",

src/hooks/useLLMGeneration.js

Lines changed: 28 additions & 6 deletions

@@ -374,19 +374,28 @@ class Qwen2VLBackend {

 class WebLLMBackend {
   constructor(modelId, config) {
-    this.modelId = modelId
-    this.config = config
+    this.modelId = modelId;
+    this.config = config;
   }

-  async generate(prompt, systemPrompt, callbacks) {
-    const engine = await CreateMLCEngine(this.modelId)
+  async warmup() {
+    if (!this.warmupPromise) {
+      this.warmupPromise = (async () => {
+        console.log('Creating new engine');
+        this.engine = await CreateMLCEngine(this.modelId);
+        return this.engine;
+      })();
+    }
+    return this.warmupPromise;
+  }

+  async generate(prompt, systemPrompt, callbacks) {
     const messages = [
       { role: "system", content: systemPrompt },
       { role: "user", content: prompt }
     ]

-    const asyncChunkGenerator = await engine.chat.completions.create({
+    const asyncChunkGenerator = await this.engine.chat.completions.create({
       messages,
       stream: true,
       ...this.config
@@ -491,8 +500,17 @@ export function useLLMGeneration(
     }
   }, [backend, modelConfig])

+  const warmup = React.useCallback(async () => {
+    const callbacks = {
+      onToken: () => { },
+      onComplete: () => { },
+      onError: () => { }
+    };
+    await backendRef.current.warmup();
+  });
+
   const generate = React.useCallback(
-    async (prompt, extras) => {
+    async (prompt, extras = {}, config = {}) => {
       if (!backendRef.current) {
         throw new Error(`No backend configured for ${backend}`)
       }
@@ -516,7 +534,10 @@ export function useLLMGeneration(
       }
     }

+      backendRef.config = { ...backendRef.config, ...config };
+
       try {
+        await backendRef.current.warmup();
         await backendRef.current.generate(prompt, systemPrompt, callbacks, extras)
       } catch (err) {
         callbacks.onError(err)
@@ -528,6 +549,7 @@ export function useLLMGeneration(

   return {
     generate,
+    warmup,
     isGenerating,
     error,
     partialText,
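The core of the fix is WebLLMBackend.warmup(): it memoizes the engine-creation promise, so the call triggered from the input handler and the call made at the start of generate() share a single CreateMLCEngine invocation instead of racing and creating two engines. A minimal standalone sketch of that pattern (createExpensiveEngine and LazyEngine are illustrative stand-ins, not code from this repo):

// Stand-in for CreateMLCEngine(modelId): pretend initialization is slow.
const createExpensiveEngine = async () => {
  await new Promise((resolve) => setTimeout(resolve, 1000));
  return { complete: async (prompt) => `echo: ${prompt}` };
};

class LazyEngine {
  warmup() {
    // The first caller kicks off initialization; every later (or concurrent)
    // caller awaits the same in-flight promise.
    if (!this.initPromise) {
      this.initPromise = (async () => {
        this.engine = await createExpensiveEngine();
        return this.engine;
      })();
    }
    return this.initPromise;
  }

  async generate(prompt) {
    await this.warmup(); // idempotent: cheap if warmup already ran or is still running
    return this.engine.complete(prompt);
  }
}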

src/hooks/useLLMHtmlGeneration.js

Lines changed: 4 additions & 3 deletions

@@ -84,7 +84,7 @@ export function useLLMHtmlGeneration({
   const [generatedCode, setGeneratedCode] = React.useState("");
   const lastGeneratedCode = React.useRef("");

-  const { generate, isGenerating, error, partialText } = useLLMGeneration(
+  const { generate, warmup, isGenerating, error, partialText } = useLLMGeneration(
     modelConfig,
     systemPrompt,
     backend
@@ -100,12 +100,12 @@ export function useLLMHtmlGeneration({
   }, [partialText]);

   const generateCode = React.useCallback(
-    async (prompt) => {
+    async (prompt, extras, config) => {
       const fullPrompt = lastGeneratedCode.current
         ? `Current HTML: \n${lastGeneratedCode.current}\n\nRequest: ${prompt}`
         : `Generate the HTML for: ${prompt}`;

-      await generate(fullPrompt);
+      await generate(fullPrompt, extras, config);
     },
     [generate]
   );
@@ -115,5 +115,6 @@ export function useLLMHtmlGeneration({
     isGenerating,
     error,
     generatedCode,
+    warmup,
   };
 }
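Since useLLMHtmlGeneration (and useLLMGeneration underneath) now also return warmup, a component can start the model download as soon as the user shows intent, the way Chat.js does above. A hedged usage sketch, assuming React and the h() helper used elsewhere in this repo (an alias for React.createElement); PromptBox, its props, and the import path are illustrative, not part of the commit:

import React from "react";
import { useLLMHtmlGeneration } from "./useLLMHtmlGeneration"; // assumed path
const h = React.createElement;

function PromptBox({ llmConfig }) {
  const { generateCode, warmup, isGenerating, generatedCode } =
    useLLMHtmlGeneration(llmConfig);
  const [prompt, setPrompt] = React.useState("");
  const warmed = React.useRef(false);

  const onChange = (e) => {
    setPrompt(e.target.value);
    if (!warmed.current) {
      warmed.current = true;
      warmup(); // start downloading/compiling the model while the user types
    }
  };

  const onSubmit = async (e) => {
    e.preventDefault();
    if (!prompt.trim() || isGenerating) return;
    // generate() awaits the backend's warmup() again, so submitting before
    // the warmup finishes is still safe.
    await generateCode(prompt);
  };

  return h("form", { onSubmit },
    h("input", { value: prompt, onChange }),
    h("button", { type: "submit", disabled: isGenerating }, "Generate"),
    h("pre", null, generatedCode)
  );
}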
