Skip to content

Commit 1ace0ae

Browse files
committed
Address GitHub review comments
1 parent f1d148d commit 1ace0ae

File tree

11 files changed

+341
-354
lines changed

11 files changed

+341
-354
lines changed

cmd/stackrox-mcp/main_test.go

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,43 @@ import (
1111
"github.com/rs/zerolog"
1212
"github.com/stackrox/stackrox-mcp/internal/config"
1313
"github.com/stackrox/stackrox-mcp/internal/server"
14+
"github.com/stackrox/stackrox-mcp/internal/testutil"
1415
"github.com/stackrox/stackrox-mcp/internal/toolsets"
1516
"github.com/stretchr/testify/assert"
1617
"github.com/stretchr/testify/require"
1718
)
1819

19-
// waitForServerReady polls the server until it's ready to accept connections
20-
func waitForServerReady(address string, timeout time.Duration) error {
21-
deadline := time.Now().Add(timeout)
22-
client := &http.Client{Timeout: 100 * time.Millisecond}
23-
24-
for time.Now().Before(deadline) {
25-
resp, err := client.Get(address)
26-
if err == nil {
27-
_ = resp.Body.Close()
28-
return nil
29-
}
30-
time.Sleep(100 * time.Millisecond)
31-
}
32-
33-
return fmt.Errorf("server did not become ready within %v", timeout)
34-
}
35-
36-
func TestSetupLogging(t *testing.T) {
37-
setupLogging()
38-
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
39-
}
40-
41-
func TestGetToolsets(t *testing.T) {
42-
cfg := &config.Config{
20+
func getDefaultConfig() *config.Config {
21+
return &config.Config{
22+
Global: config.GlobalConfig{
23+
ReadOnlyTools: false,
24+
},
4325
Central: config.CentralConfig{
4426
URL: "central.example.com:8443",
4527
},
28+
Server: config.ServerConfig{
29+
Address: "localhost",
30+
Port: 8080,
31+
},
4632
Tools: config.ToolsConfig{
4733
Vulnerability: config.VulnerabilityConfig{
4834
Enabled: true,
4935
},
5036
ConfigManager: config.ConfigManagerConfig{
51-
Enabled: true,
37+
Enabled: false,
5238
},
5339
},
5440
}
41+
}
42+
43+
func TestSetupLogging(t *testing.T) {
44+
setupLogging()
45+
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
46+
}
47+
48+
func TestGetToolsets(t *testing.T) {
49+
cfg := getDefaultConfig()
50+
cfg.Tools.ConfigManager.Enabled = true
5551

5652
allToolsets := getToolsets(cfg)
5753
require.NotNil(t, allToolsets)
@@ -68,9 +64,7 @@ func TestGracefulShutdown(t *testing.T) {
6864
cfg, err := config.LoadConfig("")
6965
require.NoError(t, err)
7066
require.NotNil(t, cfg)
71-
72-
// Use a different port to avoid conflicts
73-
cfg.Server.Port = 9999
67+
cfg.Server.Port = testutil.GetPortForTest(t)
7468

7569
registry := toolsets.NewRegistry(cfg, getToolsets(cfg))
7670
srv := server.NewServer(cfg, registry)
@@ -84,7 +78,7 @@ func TestGracefulShutdown(t *testing.T) {
8478

8579
// Wait for server to be ready by polling
8680
serverURL := fmt.Sprintf("http://%s:%d", cfg.Server.Address, cfg.Server.Port)
87-
err = waitForServerReady(serverURL, 3*time.Second)
81+
err = testutil.WaitForServerReady(serverURL, 3*time.Second)
8882
require.NoError(t, err, "Server should start within timeout")
8983

9084
// Establish actual HTTP connection to verify server is responding

internal/config/config_test.go

Lines changed: 42 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,32 @@ package config
22

33
import (
44
"os"
5-
"path/filepath"
65
"testing"
76

7+
"github.com/stackrox/stackrox-mcp/internal/testutil"
88
"github.com/stretchr/testify/assert"
99
"github.com/stretchr/testify/require"
1010
)
1111

12-
func TestLoadConfig_FromYAML(t *testing.T) {
13-
tmpDir := t.TempDir()
14-
configPath := filepath.Join(tmpDir, "config.yaml")
12+
// getDefaultConfig returns a default config for testing validation logic.
13+
func getDefaultConfig() *Config {
14+
return &Config{
15+
Central: CentralConfig{
16+
URL: "central.example.com:8443",
17+
},
18+
Server: ServerConfig{
19+
Address: "localhost",
20+
Port: 8080,
21+
},
22+
Tools: ToolsConfig{
23+
Vulnerability: VulnerabilityConfig{
24+
Enabled: true,
25+
},
26+
},
27+
}
28+
}
1529

30+
func TestLoadConfig_FromYAML(t *testing.T) {
1631
yamlContent := `
1732
central:
1833
url: central.example.com:8443
@@ -26,10 +41,7 @@ tools:
2641
config_manager:
2742
enabled: False
2843
`
29-
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
30-
require.NoError(t, err)
31-
defer func() { assert.NoError(t, os.Remove(configPath)) }()
32-
44+
configPath := testutil.WriteYAMLFile(t, yamlContent)
3345
cfg, err := LoadConfig(configPath)
3446
require.NoError(t, err)
3547
require.NotNil(t, cfg)
@@ -50,12 +62,7 @@ tools:
5062
vulnerability:
5163
enabled: false
5264
`
53-
54-
tmpDir := t.TempDir()
55-
configPath := filepath.Join(tmpDir, "config.yaml")
56-
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
57-
require.NoError(t, err)
58-
defer func() { assert.NoError(t, os.Remove(configPath)) }()
65+
configPath := testutil.WriteYAMLFile(t, yamlContent)
5966

6067
assert.NoError(t, os.Setenv("STACKROX_MCP__CENTRAL__URL", "override.example.com:443"))
6168
assert.NoError(t, os.Setenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED", "true"))
@@ -136,46 +143,30 @@ func TestLoadConfig_MissingFile(t *testing.T) {
136143
}
137144

138145
func TestLoadConfig_InvalidYAML(t *testing.T) {
139-
tmpDir := t.TempDir()
140-
configPath := filepath.Join(tmpDir, "config.yaml")
141-
142146
invalidYAML := `
143147
central:
144148
url: central.example.com:8443
145149
invalid yaml syntax here: [[[
146150
`
151+
configPath := testutil.WriteYAMLFile(t, invalidYAML)
147152

148-
err := os.WriteFile(configPath, []byte(invalidYAML), 0644)
149-
require.NoError(t, err)
150-
defer func() { assert.NoError(t, os.Remove(configPath)) }()
151-
152-
_, err = LoadConfig(configPath)
153+
_, err := LoadConfig(configPath)
153154
assert.Error(t, err)
154155
}
155156

156157
func TestLoadConfig_UnmarshalFailure(t *testing.T) {
157-
tmpDir := t.TempDir()
158-
configPath := filepath.Join(tmpDir, "config.yaml")
159-
160158
// YAML with type mismatch - port should be int
161159
invalidTypeYAML := `
162160
server:
163161
port: "not-a-number"
164162
`
165-
166-
err := os.WriteFile(configPath, []byte(invalidTypeYAML), 0644)
167-
require.NoError(t, err)
168-
defer func() { assert.NoError(t, os.Remove(configPath)) }()
169-
170-
_, err = LoadConfig(configPath)
163+
configPath := testutil.WriteYAMLFile(t, invalidTypeYAML)
164+
_, err := LoadConfig(configPath)
171165
require.Error(t, err)
172166
assert.Contains(t, err.Error(), "failed to unmarshal config")
173167
}
174168

175169
func TestLoadConfig_ValidationFailure(t *testing.T) {
176-
tmpDir := t.TempDir()
177-
configPath := filepath.Join(tmpDir, "config.yaml")
178-
179170
// Valid YAML but fails on central URL validation (no URL)
180171
validYAMLInvalidConfig := `
181172
central:
@@ -188,124 +179,60 @@ tools:
188179
enabled: true
189180
`
190181

191-
err := os.WriteFile(configPath, []byte(validYAMLInvalidConfig), 0644)
192-
require.NoError(t, err)
193-
defer func() { assert.NoError(t, os.Remove(configPath)) }()
194-
195-
_, err = LoadConfig(configPath)
182+
configPath := testutil.WriteYAMLFile(t, validYAMLInvalidConfig)
183+
_, err := LoadConfig(configPath)
196184
require.Error(t, err)
197185
assert.Contains(t, err.Error(), "invalid configuration")
198186
assert.Contains(t, err.Error(), "central.url is required")
199187
}
200188

201189
func TestValidate_MissingURL(t *testing.T) {
202-
cfg := &Config{
203-
Central: CentralConfig{
204-
URL: "",
205-
},
206-
Tools: ToolsConfig{
207-
Vulnerability: VulnerabilityConfig{
208-
Enabled: true,
209-
},
210-
},
211-
}
190+
cfg := getDefaultConfig()
191+
cfg.Central.URL = ""
212192

213193
err := cfg.Validate()
214194
require.Error(t, err)
215195
assert.Contains(t, err.Error(), "central.url is required")
216196
}
217197

218198
func TestValidate_AtLeastOneTool(t *testing.T) {
219-
cfg := &Config{
220-
Central: CentralConfig{
221-
URL: "central.example.com:8443",
222-
},
223-
Server: ServerConfig{
224-
Address: "localhost",
225-
Port: 8080,
226-
},
227-
}
199+
cfg := getDefaultConfig()
200+
cfg.Tools.Vulnerability.Enabled = false
228201

229202
err := cfg.Validate()
230203
require.Error(t, err)
231204
assert.Contains(t, err.Error(), "at least one tool has to be enabled")
232205
}
233206

234207
func TestValidate_ValidConfig(t *testing.T) {
235-
cfg := &Config{
236-
Central: CentralConfig{
237-
URL: "central.example.com:8443",
238-
Insecure: false,
239-
ForceHTTP1: false,
240-
},
241-
Global: GlobalConfig{
242-
ReadOnlyTools: true,
243-
},
244-
Server: ServerConfig{
245-
Address: "localhost",
246-
Port: 8080,
247-
},
248-
Tools: ToolsConfig{
249-
Vulnerability: VulnerabilityConfig{
250-
Enabled: true,
251-
},
252-
ConfigManager: ConfigManagerConfig{
253-
Enabled: false,
254-
},
255-
},
256-
}
208+
cfg := getDefaultConfig()
257209

258210
err := cfg.Validate()
259211
assert.NoError(t, err)
260212
}
261213

262214
func TestValidate_MissingServerAddress(t *testing.T) {
263-
cfg := &Config{
264-
Central: CentralConfig{
265-
URL: "central.example.com:8443",
266-
},
267-
Server: ServerConfig{
268-
Address: "",
269-
Port: 8080,
270-
},
271-
Tools: ToolsConfig{
272-
Vulnerability: VulnerabilityConfig{
273-
Enabled: true,
274-
},
275-
},
276-
}
215+
cfg := getDefaultConfig()
216+
cfg.Server.Address = ""
277217

278218
err := cfg.Validate()
279219
require.Error(t, err)
280220
assert.Contains(t, err.Error(), "server.address is required")
281221
}
282222

283223
func TestValidate_InvalidServerPort(t *testing.T) {
284-
tests := []struct {
285-
name string
224+
tests := map[string]struct {
286225
port int
287226
}{
288-
{"zero port", 0},
289-
{"negative port", -1},
290-
{"port too high", 65536},
227+
"zero port": {port: 0},
228+
"negative port": {port: -1},
229+
"port too high": {port: 65536},
291230
}
292231

293-
for _, tt := range tests {
294-
t.Run(tt.name, func(t *testing.T) {
295-
cfg := &Config{
296-
Central: CentralConfig{
297-
URL: "central.example.com:8443",
298-
},
299-
Server: ServerConfig{
300-
Address: "localhost",
301-
Port: tt.port,
302-
},
303-
Tools: ToolsConfig{
304-
Vulnerability: VulnerabilityConfig{
305-
Enabled: true,
306-
},
307-
},
308-
}
232+
for name, tt := range tests {
233+
t.Run(name, func(t *testing.T) {
234+
cfg := getDefaultConfig()
235+
cfg.Server.Port = tt.port
309236

310237
err := cfg.Validate()
311238
require.Error(t, err)

0 commit comments

Comments
 (0)