42 changes: 40 additions & 2 deletions src/api/providers/fetchers/__tests__/lmstudio.test.ts
@@ -1,6 +1,6 @@
import axios from "axios"
import { vi, describe, it, expect, beforeEach } from "vitest"
import { LMStudioClient, LLM, LLMInstanceInfo } from "@lmstudio/sdk" // LLMInfo is a type
import { LMStudioClient, LLM, LLMInstanceInfo, LLMInfo } from "@lmstudio/sdk"
import { getLMStudioModels, parseLMStudioModel } from "../lmstudio"
import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types" // ModelInfo is a type

@@ -11,12 +11,16 @@ const mockedAxios = axios as any
// Mock @lmstudio/sdk
const mockGetModelInfo = vi.fn()
const mockListLoaded = vi.fn()
const mockListDownloadedModels = vi.fn()
vi.mock("@lmstudio/sdk", () => {
return {
LMStudioClient: vi.fn().mockImplementation(() => ({
llm: {
listLoaded: mockListLoaded,
},
system: {
listDownloadedModels: mockListDownloadedModels,
},
})),
}
})
@@ -28,6 +32,7 @@ describe("LMStudio Fetcher", () => {
MockedLMStudioClientConstructor.mockClear()
mockListLoaded.mockClear()
mockGetModelInfo.mockClear()
mockListDownloadedModels.mockClear()
})

describe("parseLMStudioModel", () => {
@@ -88,8 +93,40 @@ describe("LMStudio Fetcher", () => {
trainedForToolUse: false, // Added
}

it("should fetch and parse models successfully", async () => {
it("should fetch downloaded models using system.listDownloadedModels", async () => {
const mockLLMInfo: LLMInfo = {
type: "llm" as const,
modelKey: "mistralai/devstral-small-2505",
format: "safetensors",
displayName: "Devstral Small 2505",
path: "mistralai/devstral-small-2505",
sizeBytes: 13277565112,
architecture: "mistral",
vision: false,
trainedForToolUse: false,
maxContextLength: 131072,
}

mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
mockListDownloadedModels.mockResolvedValueOnce([mockLLMInfo])

const result = await getLMStudioModels(baseUrl)

expect(mockedAxios.get).toHaveBeenCalledTimes(1)
expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
expect(mockListLoaded).not.toHaveBeenCalled()

const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
})

it("should fall back to listLoaded when listDownloadedModels fails", async () => {
mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
mockListDownloadedModels.mockRejectedValueOnce(new Error("Method not available"))
mockListLoaded.mockResolvedValueOnce([{ getModelInfo: mockGetModelInfo }])
mockGetModelInfo.mockResolvedValueOnce(mockRawModel)

@@ -99,6 +136,7 @@ describe("LMStudio Fetcher", () => {
expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
expect(mockListLoaded).toHaveBeenCalledTimes(1)

const expectedParsedModel = parseLMStudioModel(mockRawModel)
32 changes: 24 additions & 8 deletions src/api/providers/fetchers/lmstudio.ts
@@ -2,14 +2,17 @@ import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
import axios from "axios"

export const parseLMStudioModel = (rawModel: LLMInstanceInfo): ModelInfo => {
export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
const contextLength = "contextLength" in rawModel ? rawModel.contextLength : rawModel.maxContextLength

const modelInfo: ModelInfo = Object.assign({}, lMStudioDefaultModelInfo, {
description: `${rawModel.displayName} - ${rawModel.path}`,
contextWindow: rawModel.contextLength,
contextWindow: contextLength,
supportsPromptCache: true,
supportsImages: rawModel.vision,
supportsComputerUse: false,
maxTokens: rawModel.contextLength,
maxTokens: contextLength,
})

return modelInfo
@@ -33,12 +36,25 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
await axios.get(`${baseUrl}/v1/models`)

const client = new LMStudioClient({ baseUrl: lmsUrl })
const response = (await client.llm.listLoaded().then((models: LLM[]) => {
return Promise.all(models.map((m) => m.getModelInfo()))
})) as Array<LLMInstanceInfo>

for (const lmstudioModel of response) {
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
// First, try to get all downloaded models
try {
const downloadedModels = await client.system.listDownloadedModels("llm")
for (const model of downloadedModels) {
// Use the model path as the key since that's what users select
models[model.path] = parseLMStudioModel(model)
}
} catch (error) {
console.warn("Failed to list downloaded models, falling back to loaded models only")

// Fall back to listing only loaded models
const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
return Promise.all(models.map((m) => m.getModelInfo()))
})) as Array<LLMInstanceInfo>

for (const lmstudioModel of loadedModels) {
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
}
}
} catch (error) {
if (error.code === "ECONNREFUSED") {
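For context, a minimal sketch of how the updated fetcher reads from a caller's point of view. getLMStudioModels comes from the diff above; the return shape (a record of ModelInfo entries keyed by model path, or by modelKey on the fallback path) is inferred from the code here and is an assumption, not verified against the rest of the repo.

// Sketch (not part of the PR): consuming getLMStudioModels after this change.
// Assumes it resolves to a Record<string, ModelInfo>; keys are model paths when
// listDownloadedModels succeeds, or modelKey values when it falls back to listLoaded.
import { getLMStudioModels } from "./lmstudio"

async function logAvailableModels(): Promise<void> {
	const models = await getLMStudioModels("http://localhost:1234")
	for (const [id, info] of Object.entries(models)) {
		console.log(`${id}: context window ${info.contextWindow}`)
	}
}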
6 changes: 6 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
@@ -448,6 +448,9 @@ export const webviewMessageHandler = async (
// Specific handler for Ollama models only
const { apiConfiguration: ollamaApiConfig } = await provider.getState()
try {
// Flush cache first to ensure fresh models
await flushModels("ollama")

const ollamaModels = await getModels({
provider: "ollama",
baseUrl: ollamaApiConfig.ollamaBaseUrl,
@@ -469,6 +472,9 @@
// Specific handler for LM Studio models only
const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
try {
// Flush cache first to ensure fresh models
await flushModels("lmstudio")

const lmStudioModels = await getModels({
provider: "lmstudio",
baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
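A note on the ordering in these handlers: flushing before fetching is what makes the mount-time refresh added to the webview components below return live data instead of a previously cached entry. The sketch below shows the shared pattern using only the names visible in this diff; the cache semantics of flushModels/getModels are inferred from the comments above, and getModels may require additional options not shown here.

// Sketch (not part of the PR): the flush-then-fetch pattern both handlers now
// follow. Without the flush, getModels could serve a stale cached entry even
// though the webview just asked for a refresh; cache behavior is assumed from
// the comments in the diff, not verified elsewhere in the repo.
async function refreshRouterModels(provider: "ollama" | "lmstudio", baseUrl?: string) {
	await flushModels(provider)
	return getModels({ provider, baseUrl })
}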
9 changes: 8 additions & 1 deletion webview-ui/src/components/settings/providers/LMStudio.tsx
@@ -1,4 +1,4 @@
import { useCallback, useState, useMemo } from "react"
import { useCallback, useState, useMemo, useEffect } from "react"
import { useEvent } from "react-use"
import { Trans } from "react-i18next"
import { Checkbox } from "vscrui"
@@ -9,6 +9,7 @@ import type { ProviderSettings } from "@roo-code/types"
import { useAppTranslation } from "@src/i18n/TranslationContext"
import { ExtensionMessage } from "@roo/ExtensionMessage"
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
import { vscode } from "@src/utils/vscode"

import { inputEventTransform } from "../transforms"

@@ -49,6 +50,12 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi

useEvent("message", onMessage)

// Refresh models on mount
useEffect(() => {
// Request fresh models - the handler now flushes cache automatically
vscode.postMessage({ type: "requestLmStudioModels" })
}, [])

// Check if the selected model exists in the fetched models
const modelNotAvailable = useMemo(() => {
const selectedModel = apiConfiguration?.lmStudioModelId
9 changes: 8 additions & 1 deletion webview-ui/src/components/settings/providers/Ollama.tsx
@@ -1,4 +1,4 @@
import { useState, useCallback, useMemo } from "react"
import { useState, useCallback, useMemo, useEffect } from "react"
import { useEvent } from "react-use"
import { VSCodeTextField, VSCodeRadioGroup, VSCodeRadio } from "@vscode/webview-ui-toolkit/react"

@@ -8,6 +8,7 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"

import { useAppTranslation } from "@src/i18n/TranslationContext"
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
import { vscode } from "@src/utils/vscode"

import { inputEventTransform } from "../transforms"

@@ -48,6 +49,12 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro

useEvent("message", onMessage)

// Refresh models on mount
useEffect(() => {
// Request fresh models - the handler now flushes cache automatically
vscode.postMessage({ type: "requestOllamaModels" })
}, [])

// Check if the selected model exists in the fetched models
const modelNotAvailable = useMemo(() => {
const selectedModel = apiConfiguration?.ollamaModelId