Skip to content

Commit 79209f2

Browse files
Add Groq provider (ragapp#77)
--------- Co-authored-by: leehuwuj <[email protected]>
1 parent 0342480 commit 79209f2

File tree

8 files changed

+845
-3
lines changed

8 files changed

+845
-3
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"ragbox": patch
3+
---
4+
5+
Add Groq provider

admin/client/providers/groq.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { z } from "zod";
2+
import { BaseConfigSchema } from "./base";
3+
4+
export const GroqConfigSchema = BaseConfigSchema.extend({
5+
model_provider: z.literal("groq"),
6+
groq_api_key: z
7+
.string()
8+
.nullable()
9+
.optional()
10+
.refine(
11+
(value) => value && value.trim() !== "",
12+
"Groq API Key is required",
13+
),
14+
});
15+
16+
export const DEFAULT_GROQ_CONFIG: z.input<typeof GroqConfigSchema> = {
17+
model_provider: "groq",
18+
model: "llama3-8b",
19+
embedding_model: "all-mpnet-base-v2",
20+
embedding_dim: 768,
21+
groq_api_key: "",
22+
};

admin/client/providers/index.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { z } from "zod";
22
import { getBaseURL } from "../utils";
33
import { AzureOpenAIConfigSchema, DEFAULT_AZURE_OPENAI_CONFIG } from "./azure";
44
import { DEFAULT_GEMINI_CONFIG, GeminiConfigSchema } from "./gemini";
5+
import { DEFAULT_GROQ_CONFIG, GroqConfigSchema } from "./groq";
56
import { DEFAULT_OLLAMA_CONFIG, OllamaConfigSchema } from "./ollama";
67
import { DEFAULT_OPENAI_CONFIG, OpenAIConfigSchema } from "./openai";
78

@@ -11,6 +12,7 @@ export const ModelConfigSchema = z
1112
GeminiConfigSchema,
1213
OllamaConfigSchema,
1314
AzureOpenAIConfigSchema,
15+
GroqConfigSchema,
1416
])
1517
.refine((data) => {
1618
switch (data.model_provider) {
@@ -22,6 +24,8 @@ export const ModelConfigSchema = z
2224
return OllamaConfigSchema.parse(data);
2325
case "azure-openai":
2426
return AzureOpenAIConfigSchema.parse(data);
27+
case "groq":
28+
return GroqConfigSchema.parse(data);
2529
default:
2630
return true;
2731
}
@@ -46,6 +50,10 @@ export const supportedProviders = [
4650
name: "Azure OpenAI",
4751
value: "azure-openai",
4852
},
53+
{
54+
name: "Groq",
55+
value: "groq",
56+
},
4957
];
5058

5159
export const getDefaultProviderConfig = (provider: string) => {
@@ -58,6 +66,8 @@ export const getDefaultProviderConfig = (provider: string) => {
5866
return DEFAULT_GEMINI_CONFIG;
5967
case "azure-openai":
6068
return DEFAULT_AZURE_OPENAI_CONFIG;
69+
case "groq":
70+
return DEFAULT_GROQ_CONFIG;
6171
default:
6272
throw new Error(`Provider ${provider} not supported`);
6373
}

admin/sections/config/model.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import { useForm } from "react-hook-form";
2828
import { useMutation, useQuery } from "react-query";
2929
import { AzureOpenAIForm } from "./providers/azureOpenai";
3030
import { GeminiForm } from "./providers/gemini";
31+
import { GroqForm } from "./providers/groq";
3132
import { OllamaForm } from "./providers/ollama";
3233
import { OpenAIForm } from "./providers/openai";
3334

@@ -97,6 +98,8 @@ export const ModelConfig = ({
9798
switch (defaultValues.model_provider ?? "") {
9899
case "openai":
99100
return <OpenAIForm form={form} defaultValues={defaultValues} />;
101+
case "groq":
102+
return <GroqForm form={form} defaultValues={defaultValues} />;
100103
case "ollama":
101104
return <OllamaForm form={form} defaultValues={defaultValues} />;
102105
case "gemini":
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import {
2+
FormControl,
3+
FormDescription,
4+
FormField,
5+
FormItem,
6+
FormLabel,
7+
FormMessage,
8+
} from "@/components/ui/form";
9+
import { PasswordInput } from "@/components/ui/password-input";
10+
import { UseFormReturn } from "react-hook-form";
11+
import { ModelForm } from "./shared";
12+
13+
export const GroqForm = ({
14+
form,
15+
defaultValues,
16+
}: {
17+
form: UseFormReturn;
18+
defaultValues: any;
19+
}) => {
20+
const supportingModels = ["llama3-8b", "llama3-70b", "mixtral-8x7b"];
21+
22+
return (
23+
<>
24+
<FormField
25+
control={form.control}
26+
name="groq_api_key"
27+
render={({ field }) => (
28+
<FormItem>
29+
<FormLabel>Groq API Key (*)</FormLabel>
30+
<FormControl>
31+
<PasswordInput
32+
placeholder={defaultValues.openai_api_key ?? "sk-xxx"}
33+
showCopy
34+
{...field}
35+
/>
36+
</FormControl>
37+
<FormDescription>
38+
Get your API key from{" "}
39+
<a href="https://console.groq.com/keys" target="_blank">
40+
https://console.groq.com/keys
41+
</a>
42+
</FormDescription>
43+
<FormMessage />
44+
</FormItem>
45+
)}
46+
/>
47+
<ModelForm
48+
form={form}
49+
defaultValues={defaultValues}
50+
supportedModels={supportingModels}
51+
/>
52+
</>
53+
);
54+
};

poetry.lock

Lines changed: 733 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ e2b-code-interpreter = "^0.0.10"
3434
llama-index-tools-openapi = "^0.1.3"
3535
llama-index-tools-requests = "^0.1.3"
3636
jsonschema = "^4.22.0"
37+
llama-index-embeddings-huggingface = "^0.2.2"
38+
llama-index-llms-groq = "^0.1.4"
3739

3840

3941
[tool.poetry.group.dev.dependencies]

src/models/model_config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ class GeminiConfig(BaseModel):
2020
)
2121

2222

23+
class GroqConfig(BaseModel):
24+
groq_api_key: str | None = Field(
25+
default=None,
26+
description="The Groq API key to use",
27+
env="GROQ_API_KEY",
28+
)
29+
30+
2331
class OllamaConfig(BaseModel):
2432
ollama_base_url: str | None = Field(
2533
default=None,
@@ -64,7 +72,12 @@ class AzureOpenAIConfig(BaseModel):
6472
# We're using inheritance to flatten all the fields into a single class
6573
# Todo: Refactor API to nested structure
6674
class ModelConfig(
67-
BaseEnvConfig, OpenAIConfig, GeminiConfig, OllamaConfig, AzureOpenAIConfig
75+
BaseEnvConfig,
76+
OpenAIConfig,
77+
GroqConfig,
78+
GeminiConfig,
79+
OllamaConfig,
80+
AzureOpenAIConfig,
6881
):
6982
model_provider: str | None = Field(
7083
default=None,
@@ -104,6 +117,8 @@ def configured(self) -> bool:
104117
return True
105118
elif self.model_provider == "azure-openai":
106119
return True
120+
elif self.model_provider == "groq":
121+
return self.groq_api_key is not None
107122
return False
108123

109124
@classmethod

0 commit comments

Comments
 (0)