Skip to content

Commit 0354924

Browse files
committed
feat: Add provider filtering support to router models backend
Allows frontend to request specific subset of router models instead of fetching all providers. This significantly reduces payload sizes and memory usage when only specific providers are needed. - Honor message.values.providers filter in requestRouterModels handler - Fetch only requested providers when filter is present - Maintain backward compatibility with existing aggregate behavior - Add comprehensive test coverage for filtering logic
1 parent ff0c65a commit 0354924

File tree

2 files changed

+211
-37
lines changed

2 files changed

+211
-37
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import { webviewMessageHandler } from "../webviewMessageHandler"
3+
import type { ClineProvider } from "../ClineProvider"
4+
5+
// Mock vscode (minimal)
6+
vi.mock("vscode", () => ({
7+
window: {
8+
showErrorMessage: vi.fn(),
9+
showWarningMessage: vi.fn(),
10+
showInformationMessage: vi.fn(),
11+
},
12+
workspace: {
13+
workspaceFolders: undefined,
14+
getConfiguration: vi.fn(() => ({
15+
get: vi.fn(),
16+
update: vi.fn(),
17+
})),
18+
},
19+
env: {
20+
clipboard: { writeText: vi.fn() },
21+
openExternal: vi.fn(),
22+
},
23+
commands: {
24+
executeCommand: vi.fn(),
25+
},
26+
Uri: {
27+
parse: vi.fn((s: string) => ({ toString: () => s })),
28+
file: vi.fn((p: string) => ({ fsPath: p })),
29+
},
30+
ConfigurationTarget: {
31+
Global: 1,
32+
Workspace: 2,
33+
WorkspaceFolder: 3,
34+
},
35+
}))
36+
37+
// Mock modelCache getModels/flushModels used by the handler
38+
const getModelsMock = vi.fn()
39+
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
40+
getModels: (...args: any[]) => getModelsMock(...args),
41+
flushModels: vi.fn(),
42+
}))
43+
44+
describe("webviewMessageHandler - requestRouterModels providers filter", () => {
45+
let mockProvider: ClineProvider & {
46+
postMessageToWebview: ReturnType<typeof vi.fn>
47+
getState: ReturnType<typeof vi.fn>
48+
contextProxy: any
49+
log: ReturnType<typeof vi.fn>
50+
}
51+
52+
beforeEach(() => {
53+
vi.clearAllMocks()
54+
55+
mockProvider = {
56+
// Only methods used by this code path
57+
postMessageToWebview: vi.fn(),
58+
getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
59+
contextProxy: {
60+
getValue: vi.fn(),
61+
setValue: vi.fn(),
62+
globalStorageUri: { fsPath: "/mock/storage" },
63+
},
64+
log: vi.fn(),
65+
} as any
66+
67+
// Default mock: return distinct model maps per provider so we can verify keys
68+
getModelsMock.mockImplementation(async (options: any) => {
69+
switch (options?.provider) {
70+
case "roo":
71+
return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } }
72+
case "openrouter":
73+
return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } }
74+
case "requesty":
75+
return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } }
76+
case "deepinfra":
77+
return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } }
78+
case "glama":
79+
return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } }
80+
case "unbound":
81+
return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } }
82+
case "vercel-ai-gateway":
83+
return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } }
84+
case "io-intelligence":
85+
return { "io/model": { contextWindow: 8192, supportsPromptCache: false } }
86+
case "litellm":
87+
return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } }
88+
default:
89+
return {}
90+
}
91+
})
92+
})
93+
94+
it("fetches only requested provider when values.providers is present (['roo'])", async () => {
95+
await webviewMessageHandler(
96+
mockProvider as any,
97+
{
98+
type: "requestRouterModels",
99+
values: { providers: ["roo"] },
100+
} as any,
101+
)
102+
103+
// Should post a single routerModels message
104+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
105+
expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
106+
)
107+
108+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
109+
(c: any[]) => c[0]?.type === "routerModels",
110+
)
111+
expect(call).toBeTruthy()
112+
const payload = call[0]
113+
const routerModels = payload.routerModels as Record<string, Record<string, any>>
114+
115+
// Only "roo" key should be present
116+
const keys = Object.keys(routerModels)
117+
expect(keys).toEqual(["roo"])
118+
expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")
119+
120+
// getModels should have been called exactly once for roo
121+
const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
122+
expect(providersCalled).toEqual(["roo"])
123+
})
124+
125+
it("defaults to aggregate fetching when no providers filter is sent", async () => {
126+
await webviewMessageHandler(
127+
mockProvider as any,
128+
{
129+
type: "requestRouterModels",
130+
} as any,
131+
)
132+
133+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
134+
(c: any[]) => c[0]?.type === "routerModels",
135+
)
136+
expect(call).toBeTruthy()
137+
const routerModels = call[0].routerModels as Record<string, Record<string, any>>
138+
139+
// Aggregate handler initializes many known routers - ensure a few expected keys exist
140+
expect(routerModels).toHaveProperty("openrouter")
141+
expect(routerModels).toHaveProperty("roo")
142+
expect(routerModels).toHaveProperty("requesty")
143+
})
144+
145+
it("supports filtering another single provider (['openrouter'])", async () => {
146+
await webviewMessageHandler(
147+
mockProvider as any,
148+
{
149+
type: "requestRouterModels",
150+
values: { providers: ["openrouter"] },
151+
} as any,
152+
)
153+
154+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
155+
(c: any[]) => c[0]?.type === "routerModels",
156+
)
157+
expect(call).toBeTruthy()
158+
const routerModels = call[0].routerModels as Record<string, Record<string, any>>
159+
const keys = Object.keys(routerModels)
160+
161+
expect(keys).toEqual(["openrouter"])
162+
expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")
163+
164+
const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
165+
expect(providersCalled).toEqual(["openrouter"])
166+
})
167+
})

src/core/webview/webviewMessageHandler.ts

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -757,20 +757,38 @@ export const webviewMessageHandler = async (
757757
case "requestRouterModels":
758758
const { apiConfiguration } = await provider.getState()
759759

760-
const routerModels: Record<RouterName, ModelRecord> = {
761-
openrouter: {},
762-
"vercel-ai-gateway": {},
763-
huggingface: {},
764-
litellm: {},
765-
deepinfra: {},
766-
"io-intelligence": {},
767-
requesty: {},
768-
unbound: {},
769-
glama: {},
770-
ollama: {},
771-
lmstudio: {},
772-
roo: {},
773-
}
760+
// Optional providers filter coming from the webview
761+
const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined
762+
const requestedProviders = providersFilterRaw
763+
?.filter((p: unknown) => typeof p === "string")
764+
.map((p: string) => {
765+
try {
766+
return toRouterName(p)
767+
} catch {
768+
return undefined
769+
}
770+
})
771+
.filter((p): p is RouterName => !!p)
772+
773+
const hasFilter = !!requestedProviders && requestedProviders.length > 0
774+
const requestedSet = new Set<RouterName>(requestedProviders || [])
775+
776+
const routerModels: Record<RouterName, ModelRecord> = hasFilter
777+
? ({} as Record<RouterName, ModelRecord>)
778+
: {
779+
openrouter: {},
780+
"vercel-ai-gateway": {},
781+
huggingface: {},
782+
litellm: {},
783+
deepinfra: {},
784+
"io-intelligence": {},
785+
requesty: {},
786+
unbound: {},
787+
glama: {},
788+
ollama: {},
789+
lmstudio: {},
790+
roo: {},
791+
}
774792

775793
const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
776794
try {
@@ -785,7 +803,8 @@ export const webviewMessageHandler = async (
785803
}
786804
}
787805

788-
const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
806+
// Base candidates (only those handled by this aggregate fetcher)
807+
const candidates: { key: RouterName; options: GetModelsOptions }[] = [
789808
{ key: "openrouter", options: { provider: "openrouter" } },
790809
{
791810
key: "requesty",
@@ -818,29 +837,28 @@ export const webviewMessageHandler = async (
818837
},
819838
]
820839

821-
// Add IO Intelligence if API key is provided.
822-
const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey
823-
824-
if (ioIntelligenceApiKey) {
825-
modelFetchPromises.push({
840+
// IO Intelligence is conditional on api key
841+
if (apiConfiguration.ioIntelligenceApiKey) {
842+
candidates.push({
826843
key: "io-intelligence",
827-
options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
844+
options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
828845
})
829846
}
830847

831-
// Don't fetch Ollama and LM Studio models by default anymore.
832-
// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.
833-
848+
// LiteLLM is conditional on baseUrl+apiKey
834849
const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
835850
const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl
836851

837852
if (litellmApiKey && litellmBaseUrl) {
838-
modelFetchPromises.push({
853+
candidates.push({
839854
key: "litellm",
840855
options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
841856
})
842857
}
843858

859+
// Apply providers filter (if any)
860+
const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? true : requestedSet.has(key)))
861+
844862
const results = await Promise.allSettled(
845863
modelFetchPromises.map(async ({ key, options }) => {
846864
const models = await safeGetModels(options)
@@ -854,18 +872,7 @@ export const webviewMessageHandler = async (
854872
if (result.status === "fulfilled") {
855873
routerModels[routerName] = result.value.models
856874

857-
// Ollama and LM Studio settings pages still need these events.
858-
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
859-
provider.postMessageToWebview({
860-
type: "ollamaModels",
861-
ollamaModels: result.value.models,
862-
})
863-
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
864-
provider.postMessageToWebview({
865-
type: "lmStudioModels",
866-
lmStudioModels: result.value.models,
867-
})
868-
}
875+
// Ollama and LM Studio settings pages still need these events. They are not fetched here.
869876
} else {
870877
// Handle rejection: Post a specific error message for this provider.
871878
const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)

0 commit comments

Comments
 (0)