diff --git a/src/api/providers/fetchers/__tests__/lmstudio.test.ts b/src/api/providers/fetchers/__tests__/lmstudio.test.ts
index 59b4388785..98fe5db32e 100644
--- a/src/api/providers/fetchers/__tests__/lmstudio.test.ts
+++ b/src/api/providers/fetchers/__tests__/lmstudio.test.ts
@@ -1,6 +1,6 @@
 import axios from "axios"
 import { vi, describe, it, expect, beforeEach } from "vitest"
-import { LMStudioClient, LLM, LLMInstanceInfo } from "@lmstudio/sdk" // LLMInfo is a type
+import { LMStudioClient, LLM, LLMInstanceInfo, LLMInfo } from "@lmstudio/sdk"
 import { getLMStudioModels, parseLMStudioModel } from "../lmstudio"
 import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types" // ModelInfo is a type
 
@@ -11,12 +11,16 @@ const mockedAxios = axios as any
 // Mock @lmstudio/sdk
 const mockGetModelInfo = vi.fn()
 const mockListLoaded = vi.fn()
+const mockListDownloadedModels = vi.fn()
 vi.mock("@lmstudio/sdk", () => {
 	return {
 		LMStudioClient: vi.fn().mockImplementation(() => ({
 			llm: {
 				listLoaded: mockListLoaded,
 			},
+			system: {
+				listDownloadedModels: mockListDownloadedModels,
+			},
 		})),
 	}
 })
@@ -28,6 +32,7 @@ describe("LMStudio Fetcher", () => {
 		MockedLMStudioClientConstructor.mockClear()
 		mockListLoaded.mockClear()
 		mockGetModelInfo.mockClear()
+		mockListDownloadedModels.mockClear()
 	})
 
 	describe("parseLMStudioModel", () => {
@@ -88,8 +93,40 @@ describe("LMStudio Fetcher", () => {
 			trainedForToolUse: false, // Added
 		}
 
-		it("should fetch and parse models successfully", async () => {
+		it("should fetch downloaded models using system.listDownloadedModels", async () => {
+			const mockLLMInfo: LLMInfo = {
+				type: "llm" as const,
+				modelKey: "mistralai/devstral-small-2505",
+				format: "safetensors",
+				displayName: "Devstral Small 2505",
+				path: "mistralai/devstral-small-2505",
+				sizeBytes: 13277565112,
+				architecture: "mistral",
+				vision: false,
+				trainedForToolUse: false,
+				maxContextLength: 131072,
+			}
+
+			mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
+			mockListDownloadedModels.mockResolvedValueOnce([mockLLMInfo])
+
+			const result = await getLMStudioModels(baseUrl)
+
+			expect(mockedAxios.get).toHaveBeenCalledTimes(1)
+			expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
+			expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
+			expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
+			expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
+			expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
+			expect(mockListLoaded).not.toHaveBeenCalled()
+
+			const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
+			expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
+		})
+
+		it("should fall back to listLoaded when listDownloadedModels fails", async () => {
 			mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
+			mockListDownloadedModels.mockRejectedValueOnce(new Error("Method not available"))
 			mockListLoaded.mockResolvedValueOnce([{ getModelInfo: mockGetModelInfo }])
 			mockGetModelInfo.mockResolvedValueOnce(mockRawModel)
 
@@ -99,6 +136,7 @@ describe("LMStudio Fetcher", () => {
 			expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
 			expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
 			expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
+			expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
 			expect(mockListLoaded).toHaveBeenCalledTimes(1)
 
 			const expectedParsedModel = parseLMStudioModel(mockRawModel)
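The two new tests pin down a two-step discovery strategy: prefer the SDK's downloaded-models listing, and only fall back to loaded models when that call is unavailable. A minimal sketch of that flow, using only the SDK calls exercised above (`discoverLlms` is an illustrative name, not part of this diff):

```ts
import { LMStudioClient, type LLMInfo, type LLMInstanceInfo } from "@lmstudio/sdk"

// Returns model descriptors whether or not anything is loaded into memory.
async function discoverLlms(baseUrl: string): Promise<Array<LLMInfo | LLMInstanceInfo>> {
	const client = new LMStudioClient({ baseUrl })
	try {
		// Preferred path: every model downloaded on disk (keyed later by `path`).
		return await client.system.listDownloadedModels("llm")
	} catch {
		// Fallback for servers without listDownloadedModels: only models
		// currently loaded in memory (keyed later by `modelKey`).
		const loaded = await client.llm.listLoaded()
		return Promise.all(loaded.map((m) => m.getModelInfo()))
	}
}
```

Note the asymmetry the tests assert: downloaded models are keyed by `path` (what users select in the UI), while the fallback keeps the old `modelKey` keying.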
diff --git a/src/api/providers/fetchers/lmstudio.ts b/src/api/providers/fetchers/lmstudio.ts
index ea1a590f1e..4b7ece71ea 100644
--- a/src/api/providers/fetchers/lmstudio.ts
+++ b/src/api/providers/fetchers/lmstudio.ts
@@ -2,14 +2,17 @@ import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
 import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
 import axios from "axios"
 
-export const parseLMStudioModel = (rawModel: LLMInstanceInfo): ModelInfo => {
+export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
+	// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
+	const contextLength = "contextLength" in rawModel ? rawModel.contextLength : rawModel.maxContextLength
+
 	const modelInfo: ModelInfo = Object.assign({}, lMStudioDefaultModelInfo, {
 		description: `${rawModel.displayName} - ${rawModel.path}`,
-		contextWindow: rawModel.contextLength,
+		contextWindow: contextLength,
 		supportsPromptCache: true,
 		supportsImages: rawModel.vision,
 		supportsComputerUse: false,
-		maxTokens: rawModel.contextLength,
+		maxTokens: contextLength,
 	})
 
 	return modelInfo
@@ -33,12 +36,25 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
 		await axios.get(`${baseUrl}/v1/models`)
 
 		const client = new LMStudioClient({ baseUrl: lmsUrl })
-		const response = (await client.llm.listLoaded().then((models: LLM[]) => {
-			return Promise.all(models.map((m) => m.getModelInfo()))
-		})) as Array<LLMInstanceInfo>
 
-		for (const lmstudioModel of response) {
-			models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
+		// First, try to get all downloaded models
+		try {
+			const downloadedModels = await client.system.listDownloadedModels("llm")
+			for (const model of downloadedModels) {
+				// Use the model path as the key since that's what users select
+				models[model.path] = parseLMStudioModel(model)
+			}
+		} catch (error) {
+			console.warn("Failed to list downloaded models, falling back to loaded models only")
+
+			// Fall back to listing only loaded models
+			const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
+				return Promise.all(models.map((m) => m.getModelInfo()))
+			})) as Array<LLMInstanceInfo>
+
+			for (const lmstudioModel of loadedModels) {
+				models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
+			}
 		}
 	} catch (error) {
 		if (error.code === "ECONNREFUSED") {
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index f689196d79..8bf2f6b95a 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -448,6 +448,9 @@ export const webviewMessageHandler = async (
 			// Specific handler for Ollama models only
 			const { apiConfiguration: ollamaApiConfig } = await provider.getState()
 			try {
+				// Flush cache first to ensure fresh models
+				await flushModels("ollama")
+
 				const ollamaModels = await getModels({
 					provider: "ollama",
 					baseUrl: ollamaApiConfig.ollamaBaseUrl,
@@ -469,6 +472,9 @@
 			// Specific handler for LM Studio models only
 			const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
 			try {
+				// Flush cache first to ensure fresh models
+				await flushModels("lmstudio")
+
 				const lmStudioModels = await getModels({
 					provider: "lmstudio",
 					baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
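Both webview handlers now follow the same flush-then-fetch pattern. A minimal sketch of it, assuming the `flushModels`/`getModels` signatures and import path from their call sites in this diff (not verified against the repo):

```ts
import { flushModels, getModels } from "../../api/providers/fetchers/modelCache"
import type { ModelInfo } from "@roo-code/types"

async function refreshRouterModels(
	provider: "ollama" | "lmstudio",
	baseUrl?: string,
): Promise<Record<string, ModelInfo>> {
	// Drop the cached entry first so getModels re-queries the local server
	// instead of returning a stale (possibly empty) model list.
	await flushModels(provider)
	return getModels({ provider, baseUrl })
}
```

Flushing before every request trades a little latency for correctness: local servers gain and lose models as users download or unload them, so a cached list goes stale quickly.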
diff --git a/webview-ui/src/components/settings/providers/LMStudio.tsx b/webview-ui/src/components/settings/providers/LMStudio.tsx
index 17af44871b..a907e43e1b 100644
--- a/webview-ui/src/components/settings/providers/LMStudio.tsx
+++ b/webview-ui/src/components/settings/providers/LMStudio.tsx
@@ -1,4 +1,4 @@
-import { useCallback, useState, useMemo } from "react"
+import { useCallback, useState, useMemo, useEffect } from "react"
 import { useEvent } from "react-use"
 import { Trans } from "react-i18next"
 import { Checkbox } from "vscrui"
@@ -9,6 +9,7 @@ import type { ProviderSettings } from "@roo-code/types"
 
 import { useAppTranslation } from "@src/i18n/TranslationContext"
 import { ExtensionMessage } from "@roo/ExtensionMessage"
 import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
+import { vscode } from "@src/utils/vscode"
 
 import { inputEventTransform } from "../transforms"
@@ -49,6 +50,12 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 	useEvent("message", onMessage)
 
+	// Refresh models on mount
+	useEffect(() => {
+		// Request fresh models - the handler now flushes cache automatically
+		vscode.postMessage({ type: "requestLmStudioModels" })
+	}, [])
+
 	// Check if the selected model exists in the fetched models
 	const modelNotAvailable = useMemo(() => {
 		const selectedModel = apiConfiguration?.lmStudioModelId
diff --git a/webview-ui/src/components/settings/providers/Ollama.tsx b/webview-ui/src/components/settings/providers/Ollama.tsx
index e118f68b46..263c3892f2 100644
--- a/webview-ui/src/components/settings/providers/Ollama.tsx
+++ b/webview-ui/src/components/settings/providers/Ollama.tsx
@@ -1,4 +1,4 @@
-import { useState, useCallback, useMemo } from "react"
+import { useState, useCallback, useMemo, useEffect } from "react"
 import { useEvent } from "react-use"
 
 import { VSCodeTextField, VSCodeRadioGroup, VSCodeRadio } from "@vscode/webview-ui-toolkit/react"
@@ -8,6 +8,7 @@
 import { ExtensionMessage } from "@roo/ExtensionMessage"
 
 import { useAppTranslation } from "@src/i18n/TranslationContext"
 import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
+import { vscode } from "@src/utils/vscode"
 
 import { inputEventTransform } from "../transforms"
@@ -48,6 +49,12 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
 	useEvent("message", onMessage)
 
+	// Refresh models on mount
+	useEffect(() => {
+		// Request fresh models - the handler now flushes cache automatically
+		vscode.postMessage({ type: "requestOllamaModels" })
+	}, [])
+
 	// Check if the selected model exists in the fetched models
 	const modelNotAvailable = useMemo(() => {
 		const selectedModel = apiConfiguration?.ollamaModelId
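The two settings panes add the same mount-time effect verbatim. Illustrative only, not part of this diff: the duplication could be factored into a shared hook along these lines (`useRefreshModelsOnMount` is a hypothetical name):

```tsx
import { useEffect } from "react"
import { vscode } from "@src/utils/vscode"

type RefreshMessage = "requestLmStudioModels" | "requestOllamaModels"

export function useRefreshModelsOnMount(messageType: RefreshMessage) {
	useEffect(() => {
		// The webview handler flushes its cache before re-fetching, so a
		// request on mount always yields a fresh model list.
		vscode.postMessage({ type: messageType })
	}, [messageType])
}
```

A pane would then call `useRefreshModelsOnMount("requestOllamaModels")` in place of the inline effect. Either way, the request is idempotent, so React 18 StrictMode's double-invocation of mount effects in development is harmless here.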