diff --git a/.gitignore b/.gitignore index 67e0a03..47bd8a3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ .idea/ .DS_Store +# Claude Code +.claude/ + # Test coverage output /*.out diff --git a/.golangci.yml b/.golangci.yml index e4d7468..92003a5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -22,6 +22,8 @@ linters: - ireturn # requires package-level variables created from errors.New() - err113 + # allow replacements in go.mod + - gomoddirectives issues: max-issues-per-linter: 0 max-same-issues: 0 diff --git a/README.md b/README.md index 331e507..2719f57 100644 --- a/README.md +++ b/README.md @@ -71,8 +71,16 @@ Configuration for connecting to StackRox Central. | Option | Environment Variable | Type | Required | Default | Description | |--------|---------------------|------|----------|---------|-------------| | `central.url` | `STACKROX_MCP__CENTRAL__URL` | string | Yes | central.stackrox:8443 | URL of StackRox Central instance | -| `central.insecure` | `STACKROX_MCP__CENTRAL__INSECURE` | bool | No | `false` | Skip TLS certificate verification | -| `central.force_http1` | `STACKROX_MCP__CENTRAL__FORCE_HTTP1` | bool | No | `false` | Force HTTP/1.1 instead of HTTP/2 | +| `central.auth_type` | `STACKROX_MCP__CENTRAL__AUTH_TYPE` | string | No | `passthrough` | Authentication type: `passthrough` (use token from MCP client headers) or `static` (use configured token) | +| `central.api_token` | `STACKROX_MCP__CENTRAL__API_TOKEN` | string | Conditional | - | API token for static authentication (required when `auth_type` is `static`, must not be set when `passthrough`) | +| `central.insecure_skip_tls_verify` | `STACKROX_MCP__CENTRAL__INSECURE_SKIP_TLS_VERIFY` | bool | No | `false` | Skip TLS certificate verification (use only for testing) | +| `central.force_http1` | `STACKROX_MCP__CENTRAL__FORCE_HTTP1` | bool | No | `false` | Route gRPC traffic through the HTTP/1 bridge (gRPC-Web/WebSockets) for environments that block HTTP/2 | +| `central.request_timeout` | `STACKROX_MCP__CENTRAL__REQUEST_TIMEOUT` | duration | No | `30s` | Maximum time to wait for a single request to complete (must be positive) | +| `central.max_retries` | `STACKROX_MCP__CENTRAL__MAX_RETRIES` | int | No | `3` | Maximum number of retry attempts (must be 0-10) | +| `central.initial_backoff` | `STACKROX_MCP__CENTRAL__INITIAL_BACKOFF` | duration | No | `1s` | Initial backoff duration for retries (must be positive) | +| `central.max_backoff` | `STACKROX_MCP__CENTRAL__MAX_BACKOFF` | duration | No | `10s` | Maximum backoff duration for retries (must be positive and >= initial_backoff) | + +When `central.force_http1` is enabled, the client uses the [StackRox gRPC-over-HTTP/1 bridge](https://github.com/stackrox/go-grpc-http1) to downgrade requests. This should only be turned on when Central is reached through an HTTP/1-only proxy or load balancer, as client-side streaming remains unsupported in downgrade mode. #### Global Configuration diff --git a/cmd/stackrox-mcp/main.go b/cmd/stackrox-mcp/main.go index 73f023e..4c67ce9 100644 --- a/cmd/stackrox-mcp/main.go +++ b/cmd/stackrox-mcp/main.go @@ -9,6 +9,7 @@ import ( "os/signal" "syscall" + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/logging" "github.com/stackrox/stackrox-mcp/internal/server" @@ -18,9 +19,9 @@ import ( ) // getToolsets initializes and returns all available toolsets. -func getToolsets(cfg *config.Config) []toolsets.Toolset { +func getToolsets(cfg *config.Config, c *client.Client) []toolsets.Toolset { return []toolsets.Toolset{ - toolsetConfig.NewToolset(cfg), + toolsetConfig.NewToolset(cfg, c), toolsetVulnerability.NewToolset(cfg), } } @@ -37,15 +38,26 @@ func main() { logging.Fatal("Failed to load configuration", err) } - slog.Info("Configuration loaded successfully", "config", cfg) + // Log full configuration with sensitive data redacted. + slog.Info("Configuration loaded successfully", "config", cfg.Redacted()) - registry := toolsets.NewRegistry(cfg, getToolsets(cfg)) + stackroxClient, err := client.NewClient(&cfg.Central) + if err != nil { + logging.Fatal("Failed to create StackRox client", err) + } + + registry := toolsets.NewRegistry(cfg, getToolsets(cfg, stackroxClient)) srv := server.NewServer(cfg, registry) // Set up context with signal handling for graceful shutdown. ctx, cancel := context.WithCancel(context.Background()) defer cancel() + err = stackroxClient.Connect(ctx) + if err != nil { + logging.Fatal("Failed to connect to StackRox server", err) + } + sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) diff --git a/cmd/stackrox-mcp/main_test.go b/cmd/stackrox-mcp/main_test.go index 38ddda8..83b97a3 100644 --- a/cmd/stackrox-mcp/main_test.go +++ b/cmd/stackrox-mcp/main_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/server" "github.com/stackrox/stackrox-mcp/internal/testutil" @@ -17,43 +18,20 @@ import ( "github.com/stretchr/testify/require" ) -func getDefaultConfig() *config.Config { - return &config.Config{ - Global: config.GlobalConfig{ - ReadOnlyTools: false, - }, - Central: config.CentralConfig{ - URL: "central.example.com:8443", - }, - Server: config.ServerConfig{ - Address: "localhost", - Port: 8080, - }, - Tools: config.ToolsConfig{ - Vulnerability: config.ToolsetVulnerabilityConfig{ - Enabled: true, - }, - ConfigManager: config.ToolConfigManagerConfig{ - Enabled: false, - }, - }, - } -} - func TestGetToolsets(t *testing.T) { - cfg := getDefaultConfig() - cfg.Tools.ConfigManager.Enabled = true + allToolsets := getToolsets(&config.Config{}, &client.Client{}) - allToolsets := getToolsets(cfg) + toolsetNames := make(map[string]bool) + for _, toolset := range allToolsets { + toolsetNames[toolset.GetName()] = true + } - require.NotNil(t, allToolsets) - assert.Len(t, allToolsets, 2, "Should have 2 allToolsets") - assert.Equal(t, "config_manager", allToolsets[0].GetName()) - assert.Equal(t, "vulnerability", allToolsets[1].GetName()) + assert.Contains(t, toolsetNames, "config_manager") + assert.Contains(t, toolsetNames, "vulnerability") } func TestGracefulShutdown(t *testing.T) { - // Set up minimal valid config. + // Set up minimal valid config. config.LoadConfig() validates configuration. t.Setenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED", "true") cfg, err := config.LoadConfig("") @@ -61,7 +39,7 @@ func TestGracefulShutdown(t *testing.T) { require.NotNil(t, cfg) cfg.Server.Port = testutil.GetPortForTest(t) - registry := toolsets.NewRegistry(cfg, getToolsets(cfg)) + registry := toolsets.NewRegistry(cfg, getToolsets(cfg, &client.Client{})) srv := server.NewServer(cfg, registry) ctx, cancel := context.WithCancel(context.Background()) @@ -78,12 +56,10 @@ func TestGracefulShutdown(t *testing.T) { // Establish actual HTTP connection to verify server is responding. //nolint:gosec,noctx resp, err := http.Get(serverURL) - if err == nil { - _ = resp.Body.Close() - } - require.NoError(t, err, "Should be able to establish HTTP connection to server") + _ = resp.Body.Close() + // Simulate shutdown signal by canceling context. cancel() @@ -94,7 +70,7 @@ func TestGracefulShutdown(t *testing.T) { if err != nil && errors.Is(err, context.Canceled) { t.Errorf("Server returned unexpected error: %v", err) } - case <-time.After(5 * time.Second): + case <-time.After(server.ShutdownTimeout): t.Fatal("Server did not shut down within timeout period") } } diff --git a/examples/config-read-only.yaml b/examples/config-read-only.yaml index 24d45d9..dfe1f5b 100644 --- a/examples/config-read-only.yaml +++ b/examples/config-read-only.yaml @@ -21,14 +21,41 @@ central: # The URL of your StackRox Central instance url: central.stackrox:8443 - # Allow insecure TLS connection (optional, default: false) - # Set to true to skip TLS certificate verification - insecure: false + # Authentication type (optional, default: passthrough) + # Options: "passthrough" or "static" + # - passthrough: Use the API token from the MCP client request headers + # - static: Use a statically configured API token (specified in api_token field) + auth_type: passthrough - # Force HTTP1 (optional, default: false) - # Force HTTP/1.1 instead of HTTP/2 + # API token for static authentication (required only when auth_type is "static") + # Must not be set when auth_type is "passthrough" + # api_token: your-stackrox-api-token-here + + # Skip TLS certificate verification (optional, default: false) + # Set to true to disable TLS certificate validation + # Warning: Only use this for testing or in trusted environments + insecure_skip_tls_verify: false + + # Force HTTP1 bridge via gRPC-Web/WebSockets (optional, default: false) + # Enable only when Central is reachable through an HTTP/1-only proxy/load balancer force_http1: false + # Request timeout (optional, default: 30s) + # Maximum time to wait for a single request to complete + request_timeout: 30s + + # Maximum number of retry attempts (optional, default: 3) + # Must be between 0 and 10 + max_retries: 3 + + # Initial backoff duration for retries (optional, default: 1s) + # Must be positive + initial_backoff: 1s + + # Maximum backoff duration for retries (optional, default: 10s) + # Must be positive and >= initial_backoff + max_backoff: 10s + # Global MCP server configuration global: # Allow only read-only MCP tools (optional, default: true) diff --git a/go.mod b/go.mod index 86e7876..7e59d79 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,54 @@ module github.com/stackrox/stackrox-mcp -go 1.24 +go 1.24.0 + +toolchain go1.24.7 require ( github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/pkg/errors v0.9.1 github.com/spf13/viper v1.21.0 + github.com/stackrox/rox v0.0.0-20210914215712-9ac265932e28 github.com/stretchr/testify v1.11.1 + golang.stackrox.io/grpc-http1 v0.5.1 + google.golang.org/grpc v1.77.0 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/coder/websocket v1.8.14 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/golang/glog v1.2.5 // indirect github.com/google/jsonschema-go v0.3.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/stackrox/scanner v0.0.0-20240830165150-d133ba942d59 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/sys v0.29.0 // indirect - golang.org/x/text v0.28.0 // indirect + golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 // indirect + golang.org/x/oauth2 v0.33.0 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +// StackRox library - pinned to specific commit SHA. +// Additional two libraries have to be replaced, because go is not able to resolve version "v0.0.0" used for them. +replace ( + github.com/heroku/docker-registry-client => github.com/stackrox/docker-registry-client v0.2.1 + github.com/operator-framework/helm-operator-plugins => github.com/stackrox/helm-operator v0.8.1-0.20250929095149-d1ee3c386305 + + github.com/stackrox/rox => github.com/stackrox/stackrox v0.0.0-20251113103849-f9a0378795b1 +) diff --git a/go.sum b/go.sum index c2ac974..b7e7649 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,29 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang/glog v1.2.5 h1:DrW6hGnjIhtvhOIiAKT6Psh/Kd/ldepEa81DKeiRJ5I= +github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -20,10 +34,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca h1:ujRGEVWJEoaxQ+8+HMl8YEpGaDAgohgZxJ5S+d2TTFQ= +github.com/planetscale/vtprotobuf v0.6.1-0.20240409071808-615f978279ca/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= @@ -36,24 +52,54 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stackrox/scanner v0.0.0-20240830165150-d133ba942d59 h1:XrOPpgBpAnwTXGbyAYSOongfFeVJJBWPTdWEgYw+Uck= +github.com/stackrox/scanner v0.0.0-20240830165150-d133ba942d59/go.mod h1:xVs4A0Vur2djLSTZYtIj/5hgUORT1t405Fg0RX4/1kA= +github.com/stackrox/stackrox v0.0.0-20251113103849-f9a0378795b1 h1:GJ9vYov/zhKD5lYai8finf4QE56Yb1vqgouWwbrWf2w= +github.com/stackrox/stackrox v0.0.0-20251113103849-f9a0378795b1/go.mod h1:P+FAmKKLctUshb3eh1BNRuX7WOrET9LeeABFuAgdxos= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= +golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= +golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= +golang.stackrox.io/grpc-http1 v0.5.1 h1:V37kybMyETA7E3o4Ea73R3f3jw/L6BENE479Aw3JpYo= +golang.stackrox.io/grpc-http1 v0.5.1/go.mod h1:c2XHQF7Inb0pBvDx1A1bYW8MAHNFpU6blsuXEwCZ8lU= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/client/auth/auth.go b/internal/client/auth/auth.go new file mode 100644 index 0000000..e480c8f --- /dev/null +++ b/internal/client/auth/auth.go @@ -0,0 +1,30 @@ +package auth + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type contextKey string + +const ( + // requestContextKey is the context key for storing MCP request. + requestContextKey contextKey = "mcp-request" +) + +// WithMCPRequestContext returns a new context with the MCP request. +func WithMCPRequestContext(ctx context.Context, mcpReq *mcp.CallToolRequest) context.Context { + return context.WithValue(ctx, requestContextKey, mcpReq) +} + +// mcpRequestFromContext extracts the MCP request from the context. +// Returns nil if no MCP request is found. +func mcpRequestFromContext(ctx context.Context) *mcp.CallToolRequest { + mcpReq, ok := ctx.Value(requestContextKey).(*mcp.CallToolRequest) + if !ok { + return nil + } + + return mcpReq +} diff --git a/internal/client/auth/auth_test.go b/internal/client/auth/auth_test.go new file mode 100644 index 0000000..cf1691d --- /dev/null +++ b/internal/client/auth/auth_test.go @@ -0,0 +1,40 @@ +package auth + +import ( + "context" + "net/http" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" +) + +func TestWithMCPRequestContext(t *testing.T) { + mcpReq := &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{ + "Authorization": []string{"Bearer test-token"}, + }, + }, + } + + ctxWithReq := WithMCPRequestContext(context.Background(), mcpReq) + assert.Equal(t, mcpReq, ctxWithReq.Value(requestContextKey)) +} + +func TestMCPRequestFromContext_WithoutRequest(t *testing.T) { + ctx := context.Background() + assert.Nil(t, mcpRequestFromContext(ctx)) +} + +func TestMCPRequestFromContext_WithWrongType(t *testing.T) { + ctx := context.WithValue(context.Background(), requestContextKey, "not a request") + assert.Nil(t, mcpRequestFromContext(ctx)) +} + +func TestMCPRequestFromContext_WithNilRequest(t *testing.T) { + var nilReq *mcp.CallToolRequest + + ctx := WithMCPRequestContext(context.Background(), nilReq) + assert.Nil(t, mcpRequestFromContext(ctx)) +} diff --git a/internal/client/auth/passthrough.go b/internal/client/auth/passthrough.go new file mode 100644 index 0000000..60a75a8 --- /dev/null +++ b/internal/client/auth/passthrough.go @@ -0,0 +1,93 @@ +// Package auth handles tokens required for StackRox Central API communication. +package auth + +import ( + "context" + "net/http" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + "google.golang.org/grpc/credentials" +) + +const ( + authorizationHeader = "Authorization" + bearerPrefix = "Bearer " +) + +// extractBearerToken returns the bearer token provided in the MCP call metadata header. +// Returns an error if the header or token are missing or malformed. +func extractBearerToken(mcpReq *mcp.CallToolRequest) (string, error) { + if mcpReq == nil { + return "", errors.New("MCP request is nil") + } + + extra := mcpReq.GetExtra() + if extra == nil { + return "", errors.New("MCP request metadata is missing") + } + + token, err := tokenFromHeader(extra.Header) + if err != nil { + return "", err + } + + return token, nil +} + +func tokenFromHeader(header http.Header) (string, error) { + if header == nil { + return "", errors.New("headers are missing") + } + + raw := header.Get(authorizationHeader) + if raw == "" { + return "", errors.New("authorization header is missing") + } + + if !strings.HasPrefix(raw, bearerPrefix) { + return "", errors.New("authorization header must contain a bearer token") + } + + token := strings.TrimSpace(raw[len(bearerPrefix):]) + if token == "" { + return "", errors.New("authorization token is empty") + } + + return token, nil +} + +// passthroughTokenCredentials implements credentials.PerRPCCredentials for context-based API token authentication. +// It reads the API token from the request context using the tokenContextKey. +// This is used when auth_type is "passthrough" in the configuration, allowing tools +// to provide their own tokens on a per-request basis. +type passthroughTokenCredentials struct{} + +// NewPassthroughTokenCredentials creates a new passthroughTokenCredentials. +func NewPassthroughTokenCredentials() credentials.PerRPCCredentials { + return &passthroughTokenCredentials{} +} + +// GetRequestMetadata implements credentials.PerRPCCredentials. +// It reads the token from the context and returns the authorization metadata. +func (t *passthroughTokenCredentials) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + mcpReq := mcpRequestFromContext(ctx) + if mcpReq == nil { + return nil, errors.New("MCP request is not found in context") + } + + token, err := extractBearerToken(mcpReq) + if err != nil { + return nil, errors.Wrap(err, "failed to extract bearer token from MCP request") + } + + return map[string]string{ + "authorization": "Bearer " + token, + }, nil +} + +// RequireTransportSecurity implements credentials.PerRPCCredentials. +func (t *passthroughTokenCredentials) RequireTransportSecurity() bool { + return true +} diff --git a/internal/client/auth/passthrough_test.go b/internal/client/auth/passthrough_test.go new file mode 100644 index 0000000..93f9dc2 --- /dev/null +++ b/internal/client/auth/passthrough_test.go @@ -0,0 +1,111 @@ +package auth + +import ( + "context" + "net/http" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractBearerToken_Failures(t *testing.T) { + tests := map[string]*mcp.CallToolRequest{ + "nil request": nil, + "request extra is nil": &mcp.CallToolRequest{}, + "request header is nil": &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{}, + }, + "missing auth header": &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{}, + }, + }, + "wrong bearer prefix": &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{ + "Authorization": []string{"Not-Bearer test"}, + }, + }, + }, + "empty token": &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{ + "Authorization": []string{"Bearer "}, + }, + }, + }, + } + + for testName, testMcpReq := range tests { + t.Run(testName, func(t *testing.T) { + token, err := extractBearerToken(testMcpReq) + + require.Error(t, err) + assert.Empty(t, token) + }) + } +} + +func TestExtractBearerToken_Success(t *testing.T) { + req := &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{ + "Authorization": []string{"Bearer my-token"}, + }, + }, + } + + token, err := extractBearerToken(req) + require.NoError(t, err) + assert.Equal(t, "my-token", token) +} + +func TestPassthroughTokenCredentials_Success(t *testing.T) { + mcpReq := &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{ + "Authorization": []string{"Bearer token-123"}, + }, + }, + } + + ctx := WithMCPRequestContext(context.Background(), mcpReq) + tokenCredentials := NewPassthroughTokenCredentials() + + meta, err := tokenCredentials.GetRequestMetadata(ctx) + require.NoError(t, err) + + assert.Equal(t, "Bearer token-123", meta["authorization"]) +} + +func TestPassthroughTokenCredentials_NoAuthHeader(t *testing.T) { + mcpReq := &mcp.CallToolRequest{ + Extra: &mcp.RequestExtra{ + Header: http.Header{}, + }, + } + + ctx := WithMCPRequestContext(context.Background(), mcpReq) + tokenCredentials := NewPassthroughTokenCredentials() + + _, err := tokenCredentials.GetRequestMetadata(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to extract bearer token from MCP request") +} + +func TestPassthroughTokenCredentials_MissingMCPRequest(t *testing.T) { + ctx := context.Background() + tokenCredentials := NewPassthroughTokenCredentials() + + _, err := tokenCredentials.GetRequestMetadata(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "MCP request is not found in context") +} + +func TestPassthroughTokenCredentials_RequireTransportSecurity(t *testing.T) { + tokenCredentials := NewPassthroughTokenCredentials() + + assert.True(t, tokenCredentials.RequireTransportSecurity()) +} diff --git a/internal/client/auth/static.go b/internal/client/auth/static.go new file mode 100644 index 0000000..35d0850 --- /dev/null +++ b/internal/client/auth/static.go @@ -0,0 +1,39 @@ +package auth + +import ( + "context" + "errors" + + "google.golang.org/grpc/credentials" +) + +// staticTokenCredentials implements credentials.PerRPCCredentials for static API token authentication. +// It adds a fixed API token as a Bearer token in the authorization header for each RPC call. +// This is used when auth_type is "static" in the configuration. +type staticTokenCredentials struct { + token string +} + +// NewStaticTokenCredentials creates a new staticTokenCredentials with the given API token. +func NewStaticTokenCredentials(token string) credentials.PerRPCCredentials { + return &staticTokenCredentials{ + token: token, + } +} + +// GetRequestMetadata implements credentials.PerRPCCredentials. +// It returns the authorization metadata to be added to each RPC request. +func (t *staticTokenCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + if t.token == "" { + return nil, errors.New("API token is empty") + } + + return map[string]string{ + "authorization": "Bearer " + t.token, + }, nil +} + +// RequireTransportSecurity implements credentials.PerRPCCredentials. +func (t *staticTokenCredentials) RequireTransportSecurity() bool { + return true +} diff --git a/internal/client/auth/static_test.go b/internal/client/auth/static_test.go new file mode 100644 index 0000000..b2c721f --- /dev/null +++ b/internal/client/auth/static_test.go @@ -0,0 +1,31 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStaticTokenCredentials_Success(t *testing.T) { + tokenCredentials := NewStaticTokenCredentials("static-token") + + meta, err := tokenCredentials.GetRequestMetadata(context.Background()) + require.NoError(t, err) + assert.Equal(t, "Bearer static-token", meta["authorization"]) +} + +func TestStaticTokenCredentials_EmptyToken(t *testing.T) { + tokenCredentials := NewStaticTokenCredentials("") + + _, err := tokenCredentials.GetRequestMetadata(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "API token is empty") +} + +func TestStaticTokenCredentials_RequireTransportSecurity(t *testing.T) { + tokenCredentials := NewStaticTokenCredentials("static-token") + + assert.True(t, tokenCredentials.RequireTransportSecurity()) +} diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000..07c4af2 --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,240 @@ +// Package client holds implementation of StachRock Central API client. +package client + +import ( + "context" + "crypto/tls" + "fmt" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/stackrox/rox/pkg/grpc/alpn" + "github.com/stackrox/stackrox-mcp/internal/client/auth" + "github.com/stackrox/stackrox-mcp/internal/config" + http1client "golang.stackrox.io/grpc-http1/client" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" +) + +const ( + minConnectTimeout = 5 * time.Second + backoffJitter = 0.2 +) + +// Client provides gRPC connection to StackRox Central API. +type Client struct { + config *config.CentralConfig + + mu sync.RWMutex + conn *grpc.ClientConn + connected bool +} + +// NewClient creates a new client with the given configuration and options. +func NewClient(config *config.CentralConfig) (*Client, error) { + if config == nil { + return nil, errors.New("config cannot be nil") + } + + return &Client{ + config: config, + connected: false, + }, nil +} + +// Connect establishes a connection to StackRox Central. +// Must be called before any API requests. +func (c *Client) Connect(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.shouldRedialLocked() { + return nil + } + + c.resetConnectionLocked() + + dialOpts, err := c.buildDialOptions() + if err != nil { + return err + } + + tlsConfig, err := c.tlsConfig() + if err != nil { + return err + } + + var conn *grpc.ClientConn + if c.config.ForceHTTP1 { + conn, err = c.connectHTTP1(ctx, dialOpts, tlsConfig) + } else { + transportDailOpt := grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) + dialOpts = append([]grpc.DialOption{transportDailOpt}, dialOpts...) + conn, err = grpc.NewClient(c.config.URL, dialOpts...) + } + + if err != nil { + return NewError(err, "Connect") + } + + c.conn = conn + c.connected = true + + return nil +} + +// Close gracefully closes the connection to StackRox Central. +// Safe to call multiple times. +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.connected || c.conn == nil { + return nil + } + + err := c.conn.Close() + c.conn = nil + c.connected = false + + return errors.Wrap(err, "failed to close connection") +} + +// IsConnected returns true if the client is connected to Central. +func (c *Client) IsConnected() bool { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.connected && c.conn != nil +} + +// Conn returns the underlying gRPC connection for creating service clients. +// Tools use this to instantiate their own typed service clients. +// +// Example usage: +// +// conn := client.Conn() +// deploymentClient := v1.NewDeploymentServiceClient(conn) +// +// Returns nil if client is not connected. +func (c *Client) Conn() *grpc.ClientConn { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.conn +} + +// ReadyConn ensures the connection to Central is healthy and returns it. +func (c *Client) ReadyConn(ctx context.Context) (*grpc.ClientConn, error) { + if err := c.Connect(ctx); err != nil { + return nil, err + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.conn == nil { + return nil, errors.New("client is not connected to StackRox Central") + } + + return c.conn, nil +} + +func (c *Client) shouldRedialLocked() bool { + if !c.connected || c.conn == nil { + return true + } + + state := c.conn.GetState() + //nolint:exhaustive + switch state { + case connectivity.TransientFailure, connectivity.Shutdown: + return true + default: + return false + } +} + +func (c *Client) resetConnectionLocked() { + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } + + c.connected = false +} + +func (c *Client) buildDialOptions() ([]grpc.DialOption, error) { + retryPolicy := NewRetryPolicy(c.config) + + dialOpts := []grpc.DialOption{ + grpc.WithChainUnaryInterceptor( + createLoggingInterceptor(), + createRetryInterceptor(retryPolicy, c.config.RequestTimeout), + ), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{ + BaseDelay: c.config.InitialBackoff, + Multiplier: backoffMultiplier, + Jitter: backoffJitter, + MaxDelay: c.config.MaxBackoff, + }, + MinConnectTimeout: minConnectTimeout, + }), + } + + authOpt, err := c.perRPCCredentialsOption() + if err != nil { + return nil, err + } + + dialOpts = append(dialOpts, authOpt) + + return dialOpts, nil +} + +func (c *Client) perRPCCredentialsOption() (grpc.DialOption, error) { + switch c.config.AuthType { + case config.AuthTypeStatic: + return grpc.WithPerRPCCredentials(auth.NewStaticTokenCredentials(c.config.APIToken)), nil + case config.AuthTypePassthrough: + return grpc.WithPerRPCCredentials(auth.NewPassthroughTokenCredentials()), nil + default: + return nil, fmt.Errorf("unsupported auth type: %s", c.config.AuthType) + } +} + +func (c *Client) tlsConfig() (*tls.Config, error) { + // Extract hostname for TLS verification. + // This is especially important for force_http1 mode where the gRPC-HTTP/1 bridge + // needs explicit ServerName for certificate validation. + hostname, err := c.config.GetURLHostname() + if err != nil { + return nil, errors.Wrap(err, "failed to get central URL hostname") + } + + return &tls.Config{ + InsecureSkipVerify: c.config.InsecureSkipTLSVerify, //nolint:gosec + MinVersion: tls.VersionTLS12, + ServerName: hostname, + }, nil +} + +func (c *Client) connectHTTP1( + ctx context.Context, + dialOpts []grpc.DialOption, + tlsConfig *tls.Config, +) (*grpc.ClientConn, error) { + connectOpts := []http1client.ConnectOption{ + http1client.ForceDowngrade(true), + http1client.ExtraH2ALPNs(alpn.PureGRPCALPNString), + http1client.DialOpts(dialOpts...), + } + + http1Client, err := http1client.ConnectViaProxy(ctx, c.config.URL, tlsConfig, connectOpts...) + + return http1Client, errors.Wrap(err, "unable to connect via http1") +} diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..be7e835 --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,140 @@ +package client + +import ( + "context" + "crypto/tls" + "net" + "strconv" + "testing" + "time" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestClientReconnectsAfterServerRestart(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + listener, server := startTestGRPCServer(t, "127.0.0.1:"+strconv.Itoa(testutil.GetPortForTest(t))) + + cfg := &config.CentralConfig{ + URL: listener.Addr().String(), + AuthType: config.AuthTypeStatic, + APIToken: "dummy", + InsecureSkipTLSVerify: true, + RequestTimeout: time.Second, + MaxRetries: 3, + InitialBackoff: time.Millisecond, + MaxBackoff: 5 * time.Millisecond, + } + + client, err := NewClient(cfg) + require.NoError(t, err) + + defer func() { assert.NoError(t, client.Close()) }() + + require.NoError(t, client.Connect(ctx)) + initialConn := client.Conn() + require.NotNil(t, initialConn) + + // Simulate server failure. + server.Stop() + // Note: server.Stop() already closes the listener, so we don't need to close it explicitly. + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second) + initialConn.WaitForStateChange(waitCtx, initialConn.GetState()) + waitCancel() + + // Restart server on the same address. + _, server2 := startTestGRPCServer(t, cfg.URL) + + defer server2.Stop() + + require.NoError(t, client.Connect(ctx)) + reconnected := client.Conn() + require.NotNil(t, reconnected) +} + +func startTestGRPCServer(t *testing.T, addr string) (net.Listener, *grpc.Server) { + t.Helper() + + //nolint:noctx + lis, err := net.Listen("tcp", addr) + require.NoError(t, err) + + srv := grpc.NewServer() + + go func() { + _ = srv.Serve(lis) + }() + + return lis, srv +} + +func TestClient_tlsConfig(t *testing.T) { + tests := map[string]struct { + url string + expectedServer string + }{ + "hostname": { + url: "central.stackrox.io:8443", + expectedServer: "central.stackrox.io", + }, + "https scheme": { + url: "https://central.stackrox.io:8443", + expectedServer: "central.stackrox.io", + }, + "IP address": { + url: "192.168.1.100:8443", + expectedServer: "192.168.1.100", + }, + "service name": { + url: "central.stackrox:443", + expectedServer: "central.stackrox", + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + client := &Client{ + config: &config.CentralConfig{ + URL: testCase.url, + }, + } + + tlsCfg, err := client.tlsConfig() + + require.NoError(t, err) + require.NotNil(t, tlsCfg) + assert.Equal(t, testCase.expectedServer, tlsCfg.ServerName) + assert.Equal(t, uint16(tls.VersionTLS12), tlsCfg.MinVersion) // TLS 1.2 + }) + } +} + +func TestClient_tlsConfig_insecureSkipVerify(t *testing.T) { + client := &Client{ + config: &config.CentralConfig{ + URL: "central.stackrox.io:8443", + InsecureSkipTLSVerify: true, + }, + } + + tlsCfg, err := client.tlsConfig() + + require.NoError(t, err) + require.NotNil(t, tlsCfg) + assert.True(t, tlsCfg.InsecureSkipVerify) + + client.config.InsecureSkipTLSVerify = false + + tlsCfg, err = client.tlsConfig() + + require.NoError(t, err) + require.NotNil(t, tlsCfg) + assert.False(t, tlsCfg.InsecureSkipVerify) +} diff --git a/internal/client/error.go b/internal/client/error.go new file mode 100644 index 0000000..0a9d497 --- /dev/null +++ b/internal/client/error.go @@ -0,0 +1,120 @@ +package client + +import ( + "fmt" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Error provides detailed error information for gRPC errors. +// Includes retry classification and human-readable error messages. +type Error struct { + // Error classification + Code codes.Code // gRPC status code + Retriable bool // Whether error should be retried + + // Human-readable information + Message string // Human-readable, actionable error message + OriginalErr error // Original gRPC error + + // Context + Operation string // Operation that failed (e.g., "Connect", "GetDeployment") +} + +// NewError creates an Error from a gRPC error. +// This function maps gRPC status codes to human-readable messages and determines if the error is retriable. +func NewError(err error, operation string) *Error { + if err == nil { + return nil + } + + grpcStatus, ok := status.FromError(err) + if !ok { + // Not a gRPC error. + return &Error{ + Code: codes.Unknown, + Retriable: false, + Message: fmt.Sprintf("Unknown error: %v", err), + OriginalErr: err, + Operation: operation, + } + } + + code := grpcStatus.Code() + message := formatMessage(code, grpcStatus.Message(), operation) + + return &Error{ + Code: code, + Retriable: IsRetriableGRPCError(err), + Message: message, + OriginalErr: err, + Operation: operation, + } +} + +// Error returns the error message (implements error interface). +func (e *Error) Error() string { + return e.Message +} + +// IsRetriable returns true if the error should be retried. +func (e *Error) IsRetriable() bool { + return e.Retriable +} + +// IsRetriableGRPCError determines if an error should be retried based on gRPC status code. +// Retriable errors are transient and may succeed on retry. +func IsRetriableGRPCError(err error) bool { + grpcStatus, ok := status.FromError(err) + if !ok { + return false + } + + return grpcStatus.Code() == codes.Unavailable || grpcStatus.Code() == codes.DeadlineExceeded +} + +// formatMessage generates a human-readable, actionable error message +// based on the gRPC status code. +// +//nolint:lll,cyclop +func formatMessage(code codes.Code, detail string, operation string) string { + baseMsg := fmt.Sprintf("Operation '%s' failed", operation) + if operation == "" { + baseMsg = "Operation failed" + } + + //nolint:exhaustive + switch code { + case codes.Unauthenticated: + return fmt.Sprintf("%s: Authentication failed - invalid or expired API token. Please check your configuration. %s", baseMsg, detail) + case codes.PermissionDenied: + return fmt.Sprintf("%s: Permission denied - your API token does not have sufficient permissions for this operation. %s", baseMsg, detail) + case codes.NotFound: + return fmt.Sprintf("%s: Resource not found - the requested resource does not exist. %s", baseMsg, detail) + case codes.InvalidArgument: + return fmt.Sprintf("%s: Invalid argument - the request contains invalid parameters. %s", baseMsg, detail) + case codes.Unavailable: + return fmt.Sprintf("%s: StackRox Central is temporarily unavailable. The request will be retried automatically. %s", baseMsg, detail) + case codes.DeadlineExceeded: + return fmt.Sprintf("%s: Request timed out after 30 seconds. StackRox Central may be overloaded. The request will be retried automatically. %s", baseMsg, detail) + case codes.ResourceExhausted: + return fmt.Sprintf("%s: Resource exhausted - rate limit exceeded or server overloaded. The request will be retried automatically. %s", baseMsg, detail) + case codes.Aborted: + return fmt.Sprintf("%s: Operation aborted due to concurrency conflict. The request will be retried automatically. %s", baseMsg, detail) + case codes.AlreadyExists: + return fmt.Sprintf("%s: Resource already exists. %s", baseMsg, detail) + case codes.FailedPrecondition: + return fmt.Sprintf("%s: Failed precondition - the system is not in the correct state for this operation. %s", baseMsg, detail) + case codes.Unimplemented: + return fmt.Sprintf("%s: Operation not implemented - this method is not available on the StackRox Central server. %s", baseMsg, detail) + case codes.Canceled: + return fmt.Sprintf("%s: Operation was cancelled. %s", baseMsg, detail) + case codes.Unknown: + return fmt.Sprintf("%s: Unknown error occurred. %s", baseMsg, detail) + case codes.Internal: + return fmt.Sprintf("%s: Internal server error - an error occurred on the StackRox Central server. %s", baseMsg, detail) + default: + return fmt.Sprintf("%s: Error code %s. %s", baseMsg, code.String(), detail) + } +} diff --git a/internal/client/error_test.go b/internal/client/error_test.go new file mode 100644 index 0000000..893a766 --- /dev/null +++ b/internal/client/error_test.go @@ -0,0 +1,220 @@ +package client + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestNewError_WithGRPCError(t *testing.T) { + grpcErr := status.Error(codes.Unauthenticated, "invalid token") + err := NewError(grpcErr, "GetDeployment") + + require.NotNil(t, err) + assert.Equal(t, codes.Unauthenticated, err.Code) + assert.False(t, err.Retriable) + assert.Contains(t, err.Message, "Authentication failed") + assert.Contains(t, err.Message, "GetDeployment") + assert.Equal(t, grpcErr, err.OriginalErr) + assert.Equal(t, "GetDeployment", err.Operation) +} + +func TestNewError_WithNonGRPCError(t *testing.T) { + nonGRPCErr := errors.New("connection refused") + err := NewError(nonGRPCErr, "Connect") + + require.NotNil(t, err) + assert.Equal(t, codes.Unknown, err.Code) + assert.False(t, err.Retriable) + assert.Contains(t, err.Message, "Unknown error") + assert.Contains(t, err.Message, "connection refused") + assert.Equal(t, nonGRPCErr, err.OriginalErr) + assert.Equal(t, "Connect", err.Operation) +} + +func TestNewError_WithNilError(t *testing.T) { + err := NewError(nil, "SomeOperation") + + assert.Nil(t, err) +} + +func TestError_ErrorMethod(t *testing.T) { + grpcErr := status.Error(codes.NotFound, "deployment not found") + err := NewError(grpcErr, "GetDeployment") + + assert.Equal(t, err.Message, err.Error()) +} + +func TestIsRetriableGRPCError(t *testing.T) { + tests := map[string]struct { + err error + expected bool + }{ + "Unavailable is retriable": { + err: status.Error(codes.Unavailable, "service unavailable"), + expected: true, + }, + "DeadlineExceeded is retriable": { + err: status.Error(codes.DeadlineExceeded, "timeout"), + expected: true, + }, + "Unauthenticated is not retriable": { + err: status.Error(codes.Unauthenticated, "invalid credentials"), + expected: false, + }, + "NotFound is not retriable": { + err: status.Error(codes.NotFound, "not found"), + expected: false, + }, + "PermissionDenied is not retriable": { + err: status.Error(codes.PermissionDenied, "permission denied"), + expected: false, + }, + "Non-gRPC error is not retriable": { + err: errors.New("regular error"), + expected: false, + }, + } + + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + result := IsRetriableGRPCError(tt.err) + assert.Equal(t, tt.expected, result) + }) + } +} + +//nolint:funlen +func TestFormatMessage_AllCodes(t *testing.T) { + tests := map[string]struct { + code codes.Code + detail string + operation string + expectedMessage string + }{ + "Unauthenticated": { + code: codes.Unauthenticated, + detail: "token expired", + operation: "GetDeployment", + expectedMessage: "Authentication failed", + }, + "PermissionDenied": { + code: codes.PermissionDenied, + detail: "insufficient permissions", + operation: "DeleteDeployment", + expectedMessage: "Permission denied", + }, + "NotFound": { + code: codes.NotFound, + detail: "deployment not found", + operation: "GetDeployment", + expectedMessage: "Resource not found", + }, + "InvalidArgument": { + code: codes.InvalidArgument, + detail: "invalid deployment ID", + operation: "GetDeployment", + expectedMessage: "Invalid argument", + }, + "Unavailable": { + code: codes.Unavailable, + detail: "service unavailable", + operation: "GetDeployment", + expectedMessage: "StackRox Central is temporarily unavailable", + }, + "DeadlineExceeded": { + code: codes.DeadlineExceeded, + detail: "timeout", + operation: "GetDeployment", + expectedMessage: "Request timed out after 30 seconds", + }, + "ResourceExhausted": { + code: codes.ResourceExhausted, + detail: "rate limit exceeded", + operation: "GetDeployment", + expectedMessage: "Resource exhausted", + }, + "Aborted": { + code: codes.Aborted, + detail: "transaction aborted", + operation: "UpdateDeployment", + expectedMessage: "Operation aborted", + }, + "AlreadyExists": { + code: codes.AlreadyExists, + detail: "deployment already exists", + operation: "CreateDeployment", + expectedMessage: "Resource already exists", + }, + "FailedPrecondition": { + code: codes.FailedPrecondition, + detail: "system not ready", + operation: "CreateDeployment", + expectedMessage: "Failed precondition", + }, + "Unimplemented": { + code: codes.Unimplemented, + detail: "method not supported", + operation: "GetDeployment", + expectedMessage: "Operation not implemented", + }, + "Canceled": { + code: codes.Canceled, + detail: "operation cancelled", + operation: "GetDeployment", + expectedMessage: "Operation was cancelled", + }, + "Unknown": { + code: codes.Unknown, + detail: "unknown error", + operation: "GetDeployment", + expectedMessage: "Unknown error occurred", + }, + "Internal": { + code: codes.Internal, + detail: "internal server error", + operation: "GetDeployment", + expectedMessage: "Internal server error", + }, + "Empty operation": { + code: codes.Unknown, + detail: "test", + operation: "", + expectedMessage: "Operation failed", + }, + "Default case": { + code: codes.DataLoss, + detail: "data loss", + operation: "GetDeployment", + expectedMessage: "Error code DataLoss", + }, + } + + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + message := formatMessage(tt.code, tt.detail, tt.operation) + assert.Contains(t, message, tt.expectedMessage) + assert.Contains(t, message, tt.detail) + }) + } +} + +func TestFormatMessage_ContainsOperationName(t *testing.T) { + message := formatMessage(codes.Unknown, "test error", "GetDeployment") + assert.Contains(t, message, "GetDeployment") +} + +func TestError_Fields(t *testing.T) { + grpcErr := status.Error(codes.Unavailable, "service down") + err := NewError(grpcErr, "Connect") + + assert.Equal(t, codes.Unavailable, err.Code) + assert.True(t, err.Retriable) + assert.NotEmpty(t, err.Message) + assert.Equal(t, grpcErr, err.OriginalErr) + assert.Equal(t, "Connect", err.Operation) +} diff --git a/internal/client/interceptors.go b/internal/client/interceptors.go new file mode 100644 index 0000000..b9ef4fc --- /dev/null +++ b/internal/client/interceptors.go @@ -0,0 +1,102 @@ +package client + +import ( + "context" + "log/slog" + "time" + + "google.golang.org/grpc" +) + +// createRetryInterceptor creates a retry interceptor closure that captures the retry policy and timeout. +func createRetryInterceptor(policy *RetryPolicy, requestTimeout time.Duration) grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req, reply any, + clientConn *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + var lastErr error + + for attempt := range policy.GetMaxRetries() { + // Create context with timeout for this attempt. + attemptCtx, cancel := context.WithTimeout(ctx, requestTimeout) + err := invoker(attemptCtx, method, req, reply, clientConn, opts...) + + cancel() + + if err == nil { + if attempt > 0 { + slog.Info("Request succeeded after retry", "method", method, "attempt", attempt+1) + } + + return nil + } + + if !IsRetriableGRPCError(err) { + return err + } + + lastErr = err + + if !policy.ShouldRetry(attempt) { + break + } + + backoff := policy.NextBackoff(attempt) + + slog.Info("Request failed, retrying", + "method", method, + "attempt", attempt+1, + "backoff", backoff, + "error", err, + ) + + // Wait for backoff duration or context cancellation. + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + // Continue to next attempt. + } + } + + slog.Warn("Request failed after all retries", + "method", method, + "attempts", policy.GetMaxRetries(), + "error", lastErr, + ) + + return lastErr + } +} + +// createLoggingInterceptor creates a logging interceptor closure. +func createLoggingInterceptor() grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req, reply any, + clientConn *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + slog.Debug("API request started", "method", method) + + startTime := time.Now() + err := invoker(ctx, method, req, reply, clientConn, opts...) + duration := time.Since(startTime) + + if err != nil { + slog.Error("API request failed", "method", method, "duration", duration, "error", err) + + return err + } + + slog.Debug("API request completed", "method", method, "duration", duration) + + return nil + } +} diff --git a/internal/client/interceptors_test.go b/internal/client/interceptors_test.go new file mode 100644 index 0000000..50c8281 --- /dev/null +++ b/internal/client/interceptors_test.go @@ -0,0 +1,396 @@ +package client + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// mockInvoker is a helper to create mock invokers for testing. +type mockInvoker struct { + calls int + responses []error +} + +func (m *mockInvoker) invoke( + _ context.Context, + _ string, + _, _ any, + _ *grpc.ClientConn, + _ ...grpc.CallOption, +) error { + if m.calls >= len(m.responses) { + return errors.New("no more mock responses available") + } + + err := m.responses[m.calls] + m.calls++ + + return err +} + +func TestCreateRetryInterceptor_Success(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: time.Nanosecond, + MaxBackoff: time.Nanosecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + mock := &mockInvoker{ + responses: []error{nil}, // Success on first attempt. + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.NoError(t, err) + assert.Equal(t, 1, mock.calls) +} + +func TestCreateRetryInterceptor_RetryableError(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: time.Nanosecond, + MaxBackoff: time.Nanosecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + mock := &mockInvoker{ + responses: []error{ + status.Error(codes.Unavailable, "service unavailable"), // Retry 1. + status.Error(codes.DeadlineExceeded, "deadline exceeded"), // Retry 2. + nil, // Success on third attempt. + }, + } + + startTime := time.Now() + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + duration := time.Since(startTime) + + require.NoError(t, err) + assert.Equal(t, 3, mock.calls) + + // Should have at least some delay due to backoff (allow for jitter and variability). + assert.Greater(t, duration, centralConfig.InitialBackoff) +} + +func TestCreateRetryInterceptor_NonRetriableError(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: time.Nanosecond, + MaxBackoff: time.Nanosecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + mock := &mockInvoker{ + responses: []error{ + status.Error(codes.Unauthenticated, "invalid token"), // Non-retriable. + }, + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.Error(t, err) + assert.Equal(t, codes.Unauthenticated, status.Code(err)) + + // Should not retry. + assert.Equal(t, 1, mock.calls) +} + +func TestCreateRetryInterceptor_MaxAttemptsReached(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: time.Nanosecond, + MaxBackoff: time.Nanosecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + mock := &mockInvoker{ + responses: []error{ + status.Error(codes.Unavailable, "service unavailable"), + status.Error(codes.Unavailable, "service unavailable"), + status.Error(codes.Unavailable, "service unavailable"), + }, + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + // All attempts used. + assert.Equal(t, 3, mock.calls) +} + +func TestCreateRetryInterceptor_ContextCancellation(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: 5 * time.Millisecond, + MaxBackoff: 5 * time.Millisecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + ctx, cancel := context.WithCancel(context.Background()) + + mock := &mockInvoker{ + responses: []error{ + status.Error(codes.Unavailable, "service unavailable"), + status.Error(codes.Unavailable, "service unavailable"), + }, + } + + // Cancel context after first attempt. + go func() { + time.Sleep(time.Millisecond) + cancel() + }() + + err := interceptor( + ctx, + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.Error(t, err) + assert.Equal(t, context.Canceled, err) + // Should have made first attempt, then cancelled during backoff. + assert.Equal(t, 1, mock.calls) +} + +func TestCreateRetryInterceptor_BackoffProgression(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 4, + InitialBackoff: time.Millisecond, + MaxBackoff: time.Millisecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + mock := &mockInvoker{ + responses: []error{ + status.Error(codes.Unavailable, "service unavailable"), + status.Error(codes.Unavailable, "service unavailable"), + status.Error(codes.Unavailable, "service unavailable"), + nil, // Success on 4th attempt. + }, + } + + startTime := time.Now() + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + duration := time.Since(startTime) + + require.NoError(t, err) + assert.Equal(t, 4, mock.calls) + // With exponential backoff and jitter, we should see some delay but exact timing is variable. + assert.Greater(t, duration, 2*centralConfig.InitialBackoff) +} + +func TestCreateRetryInterceptor_RequestTimeout(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: 5 * time.Millisecond, + MaxBackoff: 5 * time.Millisecond, + } + policy := NewRetryPolicy(centralConfig) + + // Very short request timeout. + interceptor := createRetryInterceptor(policy, time.Millisecond) + + attemptCount := 0 + slowInvoker := func( + ctx context.Context, + _ string, + _, _ any, + _ *grpc.ClientConn, + _ ...grpc.CallOption, + ) error { + attemptCount++ + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(2 * centralConfig.MaxBackoff): + return nil + } + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + slowInvoker, + ) + + // Should timeout on each attempt. + require.Error(t, err) + assert.GreaterOrEqual(t, attemptCount, 1) +} + +func TestCreateRetryInterceptor_ZeroRetries(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 1, // No retries. + InitialBackoff: time.Nanosecond, + MaxBackoff: time.Nanosecond, + } + policy := NewRetryPolicy(centralConfig) + interceptor := createRetryInterceptor(policy, time.Second) + + mock := &mockInvoker{ + responses: []error{ + status.Error(codes.Unavailable, "service unavailable"), + }, + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.Error(t, err) + assert.Equal(t, codes.Unavailable, status.Code(err)) + assert.Equal(t, 1, mock.calls) +} + +func TestCreateRetryInterceptor_ContextTimeoutDuringAttempt(t *testing.T) { + centralConfig := &config.CentralConfig{ + MaxRetries: 3, + InitialBackoff: time.Nanosecond, + MaxBackoff: time.Nanosecond, + } + policy := NewRetryPolicy(centralConfig) + + // Request timeout of 10ms. + interceptor := createRetryInterceptor(policy, time.Millisecond) + + callCount := 0 + slowInvoker := func( + _ context.Context, + _ string, + _, _ any, + _ *grpc.ClientConn, + _ ...grpc.CallOption, + ) error { + callCount++ + // First attempt times out, second succeeds quickly. + if callCount == 1 { + time.Sleep(5 * time.Millisecond) + + return status.Error(codes.DeadlineExceeded, "timeout") + } + + return nil + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + slowInvoker, + ) + + require.NoError(t, err) + assert.Equal(t, 2, callCount) +} + +func TestCreateLoggingInterceptor_Success(t *testing.T) { + interceptor := createLoggingInterceptor() + + mock := &mockInvoker{ + responses: []error{ + nil, + }, + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.NoError(t, err) + assert.Equal(t, 1, mock.calls) +} + +func TestCreateLoggingInterceptor_Error(t *testing.T) { + interceptor := createLoggingInterceptor() + + expectedErr := status.Error(codes.NotFound, "not found") + mock := &mockInvoker{ + responses: []error{ + expectedErr, + }, + } + + err := interceptor( + context.Background(), + "test.Method", + nil, + nil, + nil, + mock.invoke, + ) + + require.Error(t, err) + assert.Equal(t, expectedErr, err) +} diff --git a/internal/client/retry.go b/internal/client/retry.go new file mode 100644 index 0000000..7917f60 --- /dev/null +++ b/internal/client/retry.go @@ -0,0 +1,52 @@ +package client + +import ( + "math" + "time" + + "github.com/stackrox/stackrox-mcp/internal/config" +) + +const ( + backoffMultiplier = 2.0 +) + +// RetryPolicy holds configuration for retry behavior with exponential backoff. +type RetryPolicy struct { + config *config.CentralConfig +} + +// NewRetryPolicy creates a new RetryPolicy with default values. +func NewRetryPolicy(config *config.CentralConfig) *RetryPolicy { + return &RetryPolicy{ + config: config, + } +} + +// NextBackoff calculates the next backoff duration based on the attempt number. +// Uses exponential backoff with jitter to prevent thundering herd. +func (rp *RetryPolicy) NextBackoff(attempt int) time.Duration { + if attempt < 0 { + attempt = 0 + } + + // Calculate exponential backoff: initial * (multiplier ^ attempt). + backoff := float64(rp.config.InitialBackoff) * math.Pow(backoffMultiplier, float64(attempt)) + + // Apply max backoff cap. + if backoff > float64(rp.config.MaxBackoff) { + backoff = float64(rp.config.MaxBackoff) + } + + return time.Duration(backoff) +} + +// GetMaxRetries maximum allowed retries. +func (rp *RetryPolicy) GetMaxRetries() int { + return rp.config.MaxRetries +} + +// ShouldRetry returns true if attempts are not exhausted. +func (rp *RetryPolicy) ShouldRetry(attempt int) bool { + return attempt < rp.config.MaxRetries-1 +} diff --git a/internal/client/retry_test.go b/internal/client/retry_test.go new file mode 100644 index 0000000..067e5a0 --- /dev/null +++ b/internal/client/retry_test.go @@ -0,0 +1,86 @@ +package client + +import ( + "testing" + "time" + + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestRetryPolicy_ShouldRetry(t *testing.T) { + policy := NewRetryPolicy(&config.CentralConfig{ + MaxRetries: 3, + }) + + assert.True(t, policy.ShouldRetry(-1)) + assert.True(t, policy.ShouldRetry(0)) + assert.True(t, policy.ShouldRetry(policy.config.MaxRetries-2)) + assert.False(t, policy.ShouldRetry(policy.config.MaxRetries-1)) +} + +func TestRetryPolicy_NextBackoff(t *testing.T) { + tests := map[string]struct { + initialBackoff time.Duration + maxBackoff time.Duration + attempt int + expected time.Duration + }{ + "zero attempt": { + initialBackoff: 500 * time.Millisecond, + maxBackoff: 5 * time.Second, + attempt: 0, + expected: 500 * time.Millisecond, // 500ms * (2^0) = 500ms + }, + "first attempt": { + initialBackoff: time.Second, + maxBackoff: 10 * time.Second, + attempt: 0, + expected: time.Second, // 1s * (2^0) = 1s + }, + "second attempt": { + initialBackoff: time.Second, + maxBackoff: 10 * time.Second, + attempt: 1, + expected: 2 * time.Second, // 1s * (2^1) = 2s + }, + "third attempt": { + initialBackoff: time.Second, + maxBackoff: 10 * time.Second, + attempt: 2, + expected: 4 * time.Second, // 1s * (2^2) = 4s + }, + "capped at max backoff": { + initialBackoff: time.Second, + maxBackoff: 5 * time.Second, + attempt: 3, + expected: 5 * time.Second, // 8s capped to 5s + }, + "negative attempt treated as zero": { + initialBackoff: time.Second, + maxBackoff: 10 * time.Second, + attempt: -1, + expected: time.Second, // negative becomes 0 + }, + "max equals initial backoff": { + initialBackoff: time.Second, + maxBackoff: time.Second, + attempt: 5, + expected: time.Second, // always capped at 1s + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + policy := &RetryPolicy{ + config: &config.CentralConfig{ + InitialBackoff: testCase.initialBackoff, + MaxBackoff: testCase.maxBackoff, + }, + } + + result := policy.NextBackoff(testCase.attempt) + assert.Equal(t, testCase.expected, result) + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 6796d43..90caba1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,13 +2,23 @@ package config import ( + "fmt" + "net/url" "strings" + "time" "github.com/pkg/errors" "github.com/spf13/viper" ) -const defaultPort = 8080 +const ( + defaultPort = 8080 + + defaultRequestTimeout = 30 * time.Second + defaultMaxRetries = 3 + defaultInitialBackoff = time.Second + defaultMaxBackoff = 10 * time.Second +) // Config represents the complete application configuration. type Config struct { @@ -18,11 +28,31 @@ type Config struct { Tools ToolsConfig `mapstructure:"tools"` } +type authType string + +const ( + // AuthTypePassthrough defines auth flow where API token, used to communicate with MCP server, + // is passed and used in a communication with StackRox Central API. + AuthTypePassthrough authType = "passthrough" + + // AuthTypeStatic defines auth flow where API token is statically configured and + // defined in configuration or environment variable. + AuthTypeStatic authType = "static" +) + // CentralConfig contains StackRox Central connection configuration. type CentralConfig struct { - URL string `mapstructure:"url"` - Insecure bool `mapstructure:"insecure"` - ForceHTTP1 bool `mapstructure:"force_http1"` + URL string `mapstructure:"url"` + AuthType authType `mapstructure:"auth_type"` + APIToken string `mapstructure:"api_token"` + InsecureSkipTLSVerify bool `mapstructure:"insecure_skip_tls_verify"` + ForceHTTP1 bool `mapstructure:"force_http1"` + + // Timeouts and retry settings + RequestTimeout time.Duration `mapstructure:"request_timeout"` + MaxRetries int `mapstructure:"max_retries"` + InitialBackoff time.Duration `mapstructure:"initial_backoff"` + MaxBackoff time.Duration `mapstructure:"max_backoff"` } // GlobalConfig contains global MCP server configuration. @@ -64,6 +94,9 @@ func LoadConfig(configPath string) (*Config, error) { // Set up environment variable support. // Note: SetEnvPrefix adds a single underscore, so "STACKROX_MCP_" becomes the prefix. // We want double underscores between sections, so we use "__" in the replacer. + // + // For environment variable mapping to the config to work, we need to define a default for that config option. + // Every configuration option that can be set via environment variables must be defined in setDefaults(). viperInstance.SetEnvPrefix("STACKROX_MCP_") viperInstance.SetEnvKeyReplacer(strings.NewReplacer(".", "__")) viperInstance.AutomaticEnv() @@ -91,9 +124,16 @@ func LoadConfig(configPath string) (*Config, error) { // setDefaults sets default values for configuration. func setDefaults(viper *viper.Viper) { viper.SetDefault("central.url", "central.stackrox:8443") - viper.SetDefault("central.insecure", false) + viper.SetDefault("central.auth_type", "passthrough") + viper.SetDefault("central.api_token", "") + viper.SetDefault("central.insecure_skip_tls_verify", false) viper.SetDefault("central.force_http1", false) + viper.SetDefault("central.request_timeout", defaultRequestTimeout) + viper.SetDefault("central.max_retries", defaultMaxRetries) + viper.SetDefault("central.initial_backoff", defaultInitialBackoff) + viper.SetDefault("central.max_backoff", defaultMaxBackoff) + viper.SetDefault("global.read_only_tools", true) viper.SetDefault("server.address", "localhost") @@ -103,28 +143,115 @@ func setDefaults(viper *viper.Viper) { viper.SetDefault("tools.config_manager.enabled", false) } -var ( - errURLRequired = errors.New("central.url is required") - errAtLeastOneTool = errors.New("at least one tool has to be enabled") -) +// GetURLHostname returns URL hostname. +func (cc *CentralConfig) GetURLHostname() (string, error) { + parsedURL, err := url.Parse(cc.URL) + if err == nil && parsedURL.Hostname() != "" { + return parsedURL.Hostname(), nil + } -// Validate validates the configuration. -func (c *Config) Validate() error { - if c.Central.URL == "" { - return errURLRequired + // Many StackRox configurations use hostname:port format without a scheme, + // so we add a scheme if missing to ensure proper parsing. + parsedURL, err = url.Parse("https://" + cc.URL) + if err != nil { + return "", errors.Wrapf(err, "failed to parse URL %q", cc.URL) + } + + return parsedURL.Hostname(), nil +} + +//nolint:cyclop +func (cc *CentralConfig) validate() error { + if cc.URL == "" { + return errors.New("central.url is required") + } + + _, err := cc.GetURLHostname() + if err != nil { + return errors.Wrap(err, "central.url is not a valid URL") + } + + if cc.AuthType != AuthTypePassthrough && cc.AuthType != AuthTypeStatic { + return errors.New("central.auth_type must be either passthrough or static") + } + + if cc.AuthType == AuthTypeStatic && cc.APIToken == "" { + return fmt.Errorf("central.api_token is required for %q auth type", AuthTypeStatic) + } + + if cc.AuthType == AuthTypePassthrough && cc.APIToken != "" { + return fmt.Errorf("central.api_token can not be set for %q auth type", AuthTypePassthrough) + } + + if cc.RequestTimeout <= 0 { + return errors.New("central.request_timeout must be positive") + } + + if cc.MaxRetries < 0 || cc.MaxRetries > 10 { + return errors.New("central.max_retries must be between 0 and 10") + } + + if cc.InitialBackoff <= 0 { + return errors.New("central.initial_backoff must be positive") + } + + if cc.MaxBackoff <= 0 { + return errors.New("central.max_backoff must be positive") + } + + if cc.MaxBackoff < cc.InitialBackoff { + return errors.New("central.max_backoff has to be greater than or equal to central.initial_backoff") } - if c.Server.Address == "" { + return nil +} + +func (sc *ServerConfig) validate() error { + if sc.Address == "" { return errors.New("server.address is required") } - if c.Server.Port < 1 || c.Server.Port > 65535 { + if sc.Port < 1 || sc.Port > 65535 { return errors.New("server.port must be between 1 and 65535") } + return nil +} + +// Validate validates the configuration. +func (c *Config) Validate() error { + if err := c.Central.validate(); err != nil { + return err + } + + if err := c.Server.validate(); err != nil { + return err + } + if !c.Tools.Vulnerability.Enabled && !c.Tools.ConfigManager.Enabled { - return errAtLeastOneTool + return errors.New("at least one tool has to be enabled") } return nil } + +const redacted = "***REDACTED***" + +// Redacted returns a copy of the configuration with sensitive data redacted. +// This is useful for logging configuration without exposing secrets. +func (c *Config) Redacted() *Config { + redactedConfig := *c + redactedConfig.Central = c.Central.redacted() + + return &redactedConfig +} + +// redacted returns a copy of CentralConfig with sensitive data redacted. +func (cc *CentralConfig) redacted() CentralConfig { + redactedCentral := *cc + if cc.APIToken != "" { + redactedCentral.APIToken = redacted + } + + return redactedCentral +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b59bce2..dcc522b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "testing" + "time" "github.com/stackrox/stackrox-mcp/internal/testutil" "github.com/stretchr/testify/assert" @@ -13,7 +14,18 @@ import ( func getDefaultConfig() *Config { return &Config{ Central: CentralConfig{ - URL: "central.example.com:8443", + URL: "central.example.com:8443", + AuthType: AuthTypeStatic, + APIToken: "test-token", + InsecureSkipTLSVerify: false, + ForceHTTP1: false, + RequestTimeout: defaultRequestTimeout, + MaxRetries: defaultMaxRetries, + InitialBackoff: defaultInitialBackoff, + MaxBackoff: defaultMaxBackoff, + }, + Global: GlobalConfig{ + ReadOnlyTools: false, }, Server: ServerConfig{ Address: "localhost", @@ -23,6 +35,9 @@ func getDefaultConfig() *Config { Vulnerability: ToolsetVulnerabilityConfig{ Enabled: true, }, + ConfigManager: ToolConfigManagerConfig{ + Enabled: false, + }, }, } } @@ -31,7 +46,7 @@ func TestLoadConfig_FromYAML(t *testing.T) { yamlContent := ` central: url: central.example.com:8443 - insecure: true + insecure_skip_tls_verify: true force_http1: true global: read_only_tools: false @@ -51,7 +66,7 @@ tools: require.NotNil(t, cfg) assert.Equal(t, "central.example.com:8443", cfg.Central.URL) - assert.True(t, cfg.Central.Insecure) + assert.True(t, cfg.Central.InsecureSkipTLSVerify) assert.True(t, cfg.Central.ForceHTTP1) assert.False(t, cfg.Global.ReadOnlyTools) assert.True(t, cfg.Tools.Vulnerability.Enabled) @@ -84,7 +99,9 @@ tools: func TestLoadConfig_EnvVarOnly(t *testing.T) { t.Setenv("STACKROX_MCP__CENTRAL__URL", "env.example.com:8443") - t.Setenv("STACKROX_MCP__CENTRAL__INSECURE", "true") + t.Setenv("STACKROX_MCP__CENTRAL__AUTH_TYPE", string(AuthTypeStatic)) + t.Setenv("STACKROX_MCP__CENTRAL__API_TOKEN", "test-token") + t.Setenv("STACKROX_MCP__CENTRAL__INSECURE_SKIP_TLS_VERIFY", "true") t.Setenv("STACKROX_MCP__CENTRAL__FORCE_HTTP1", "true") t.Setenv("STACKROX_MCP__GLOBAL__READ_ONLY_TOOLS", "false") t.Setenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED", "true") @@ -95,7 +112,9 @@ func TestLoadConfig_EnvVarOnly(t *testing.T) { require.NotNil(t, cfg) assert.Equal(t, "env.example.com:8443", cfg.Central.URL) - assert.True(t, cfg.Central.Insecure) + assert.Equal(t, AuthTypeStatic, cfg.Central.AuthType) + assert.Equal(t, "test-token", cfg.Central.APIToken) + assert.True(t, cfg.Central.InsecureSkipTLSVerify) assert.True(t, cfg.Central.ForceHTTP1) assert.False(t, cfg.Global.ReadOnlyTools) assert.True(t, cfg.Tools.Vulnerability.Enabled) @@ -111,7 +130,7 @@ func TestLoadConfig_Defaults(t *testing.T) { require.NotNil(t, cfg) assert.Equal(t, "central.stackrox:8443", cfg.Central.URL) - assert.False(t, cfg.Central.Insecure) + assert.False(t, cfg.Central.InsecureSkipTLSVerify) assert.False(t, cfg.Central.ForceHTTP1) assert.True(t, cfg.Global.ReadOnlyTools) assert.False(t, cfg.Tools.Vulnerability.Enabled) @@ -225,8 +244,8 @@ func TestValidate_InvalidServerPort(t *testing.T) { "port too high": {port: 65536}, } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { cfg := getDefaultConfig() cfg.Server.Port = tt.port @@ -248,3 +267,252 @@ func TestLoadConfig_ServerDefaults(t *testing.T) { assert.Equal(t, "localhost", cfg.Server.Address) assert.Equal(t, 8080, cfg.Server.Port) } + +func TestValidate_AuthType_Invalid(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.AuthType = "bad-type" + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.auth_type must be either passthrough or static") +} + +func TestValidate_AuthTypeStatic_RequiresAPIToken(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.AuthType = AuthTypeStatic + cfg.Central.APIToken = "" + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.api_token is required for") + assert.Contains(t, err.Error(), "static") +} + +func TestValidate_AuthTypePassthrough_ForbidsAPIToken(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.AuthType = AuthTypePassthrough + cfg.Central.APIToken = "some-token" + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.api_token can not be set for") + assert.Contains(t, err.Error(), "passthrough") +} + +func TestValidate_AuthTypePassthrough_Success(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.AuthType = AuthTypePassthrough + cfg.Central.APIToken = "" + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestValidate_RequestTimeout_MustBePositive(t *testing.T) { + tests := map[string]struct { + timeout time.Duration + }{ + "zero timeout": { + timeout: 0, + }, + "negative timeout": { + timeout: -1 * time.Second, + }, + } + + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.RequestTimeout = tt.timeout + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.request_timeout must be positive") + }) + } +} + +func TestValidate_RequestTimeout_PositiveValue(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.RequestTimeout = 10 * time.Second + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestValidate_MaxRetries_Range(t *testing.T) { + tests := map[string]struct { + maxRetries int + shouldFail bool + }{ + "negative retries": { + maxRetries: -1, + shouldFail: true, + }, + "zero retries": { + maxRetries: 0, + shouldFail: false, + }, + "valid retries": { + maxRetries: 5, + shouldFail: false, + }, + "max retries": { + maxRetries: 10, + shouldFail: false, + }, + "over max retries": { + maxRetries: 11, + shouldFail: true, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.MaxRetries = testCase.maxRetries + + err := cfg.Validate() + if testCase.shouldFail { + require.Error(t, err) + assert.Contains(t, err.Error(), "central.max_retries must be between 0 and 10") + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidate_InitialBackoff_MustBePositive(t *testing.T) { + tests := map[string]struct { + backoff time.Duration + }{ + "zero backoff": { + backoff: 0, + }, + "negative backoff": { + backoff: -1 * time.Second, + }, + } + + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.InitialBackoff = tt.backoff + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.initial_backoff must be positive") + }) + } +} + +func TestValidate_InitialBackoff_PositiveValue(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.InitialBackoff = 2 * time.Second + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestValidate_MaxBackoff_MustBePositive(t *testing.T) { + tests := map[string]struct { + backoff time.Duration + }{ + "zero backoff": { + backoff: 0, + }, + "negative backoff": { + backoff: -1 * time.Second, + }, + } + + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.MaxBackoff = tt.backoff + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.max_backoff must be positive") + }) + } +} + +func TestValidate_MaxBackoff_PositiveValue(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.MaxBackoff = 30 * time.Second + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestValidate_MaxBackoff_MustBeGreaterThanInitialBackoff(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.InitialBackoff = 10 * time.Second + cfg.Central.MaxBackoff = 5 * time.Second + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "central.max_backoff has to be greater than or equal to central.initial_backoff") +} + +func TestValidate_MaxBackoff_EqualToInitialBackoff(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.InitialBackoff = 5 * time.Second + cfg.Central.MaxBackoff = 5 * time.Second + + err := cfg.Validate() + assert.NoError(t, err) +} + +func TestLoadConfig_TimeoutAndRetryDefaults(t *testing.T) { + t.Setenv("STACKROX_MCP__TOOLS__CONFIG_MANAGER__ENABLED", "true") + + cfg, err := LoadConfig("") + require.NoError(t, err) + require.NotNil(t, cfg) + + assert.Equal(t, defaultRequestTimeout, cfg.Central.RequestTimeout) + assert.Equal(t, defaultMaxRetries, cfg.Central.MaxRetries) + assert.Equal(t, defaultInitialBackoff, cfg.Central.InitialBackoff) + assert.Equal(t, defaultMaxBackoff, cfg.Central.MaxBackoff) +} + +func TestConfig_Redacted(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.APIToken = "super-secret-token" + + redactedConfig := cfg.Redacted() + + // Verify sensitive data is redacted. + assert.Equal(t, "***REDACTED***", redactedConfig.Central.APIToken) + + // Verify non-sensitive data is preserved. + assert.Equal(t, cfg.Central.URL, redactedConfig.Central.URL) + assert.Equal(t, cfg.Central.AuthType, redactedConfig.Central.AuthType) + assert.Equal(t, cfg.Central.InsecureSkipTLSVerify, redactedConfig.Central.InsecureSkipTLSVerify) + assert.Equal(t, cfg.Central.ForceHTTP1, redactedConfig.Central.ForceHTTP1) + assert.Equal(t, cfg.Central.RequestTimeout, redactedConfig.Central.RequestTimeout) + assert.Equal(t, cfg.Central.MaxRetries, redactedConfig.Central.MaxRetries) + assert.Equal(t, cfg.Central.InitialBackoff, redactedConfig.Central.InitialBackoff) + assert.Equal(t, cfg.Central.MaxBackoff, redactedConfig.Central.MaxBackoff) + + // Verify other config sections are preserved. + assert.Equal(t, cfg.Global, redactedConfig.Global) + assert.Equal(t, cfg.Server, redactedConfig.Server) + assert.Equal(t, cfg.Tools, redactedConfig.Tools) + + // Verify original config is unchanged. + assert.Equal(t, "super-secret-token", cfg.Central.APIToken) +} + +func TestConfig_Redacted_EmptyToken(t *testing.T) { + cfg := getDefaultConfig() + cfg.Central.APIToken = "" + + redactedConfig := cfg.Redacted() + + // Empty token should remain empty, not be replaced with redacted marker. + assert.Empty(t, redactedConfig.Central.APIToken) +} diff --git a/internal/server/server.go b/internal/server/server.go index 3d581b3..f522210 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,7 +16,9 @@ import ( ) const ( - shutdownTimeout = 5 * time.Second + // ShutdownTimeout represents allowed timeout for graceful shutdown to finish. + ShutdownTimeout = 5 * time.Second + readHeaderTimeout = 5 * time.Second ) @@ -71,7 +73,7 @@ func (s *Server) Start(ctx context.Context) error { errChan := make(chan error, 1) go func() { - if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := httpServer.ListenAndServe(); err != nil && errors.Is(err, http.ErrServerClosed) { errChan <- errors.Wrap(err, "HTTP server error") } }() @@ -81,7 +83,7 @@ func (s *Server) Start(ctx context.Context) error { case <-ctx.Done(): slog.Info("Shutting down HTTP server") // Create a context with timeout for graceful shutdown. - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ShutdownTimeout) defer shutdownCancel() //nolint:contextcheck return errors.Wrap(httpServer.Shutdown(shutdownCtx), "server shutting down failed") diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 199f004..ca11dfa 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -162,7 +162,7 @@ func TestServer_Start(t *testing.T) { if err != nil && errors.Is(err, context.Canceled) { t.Errorf("Server returned unexpected error: %v", err) } - case <-time.After(shutdownTimeout): + case <-time.After(ShutdownTimeout): t.Fatal("Server did not shut down within timeout period") } } diff --git a/internal/toolsets/config/tools.go b/internal/toolsets/config/tools.go index ab7c95b..2b0c45f 100644 --- a/internal/toolsets/config/tools.go +++ b/internal/toolsets/config/tools.go @@ -2,9 +2,13 @@ package config import ( "context" - "errors" + "fmt" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/client/auth" "github.com/stackrox/stackrox-mcp/internal/toolsets" ) @@ -18,13 +22,15 @@ type listClustersOutput struct { // listClustersTool implements the list_clusters tool. type listClustersTool struct { - name string + name string + client *client.Client } // NewListClustersTool creates a new list_clusters tool. -func NewListClustersTool() toolsets.Tool { +func NewListClustersTool(c *client.Client) toolsets.Tool { return &listClustersTool{ - name: "list_clusters", + name: "list_clusters", + client: c, } } @@ -53,9 +59,52 @@ func (t *listClustersTool) RegisterWith(server *mcp.Server) { // handle is the placeholder handler for list_clusters tool. func (t *listClustersTool) handle( - _ context.Context, - _ *mcp.CallToolRequest, + ctx context.Context, + req *mcp.CallToolRequest, _ listClustersInput, ) (*mcp.CallToolResult, *listClustersOutput, error) { - return nil, nil, errors.New("list_clusters tool is not yet implemented") + conn, err := t.client.ReadyConn(ctx) + if err != nil { + return nil, nil, errors.Wrap(err, "unable to connect to server") + } + + callCtx := auth.WithMCPRequestContext(ctx, req) + + // Create ClustersService client + clustersClient := v1.NewClustersServiceClient(conn) + + // Call GetClusters + resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{}) + if err != nil { + // Convert gRPC error to client error + clientErr := client.NewError(err, "GetClusters") + + return nil, nil, clientErr + } + + // Extract cluster information + clusters := make([]string, 0, len(resp.GetClusters())) + for _, cluster := range resp.GetClusters() { + // Format: "ID: , Name: , Type: " + clusterInfo := fmt.Sprintf("ID: %s, Name: %s, Type: %s", + cluster.GetId(), + cluster.GetName(), + cluster.GetType().String()) + clusters = append(clusters, clusterInfo) + } + + output := &listClustersOutput{ + Clusters: clusters, + } + + // Return result with text content + result := &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Found %d cluster(s)", len(clusters)), + }, + }, + } + + return result, output, nil } diff --git a/internal/toolsets/config/tools_test.go b/internal/toolsets/config/tools_test.go index 9712d95..f6475ca 100644 --- a/internal/toolsets/config/tools_test.go +++ b/internal/toolsets/config/tools_test.go @@ -4,25 +4,26 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewListClustersTool(t *testing.T) { - tool := NewListClustersTool() + tool := NewListClustersTool(&client.Client{}) require.NotNil(t, tool) assert.Equal(t, "list_clusters", tool.GetName()) } func TestListClustersTool_IsReadOnly(t *testing.T) { - tool := NewListClustersTool() + tool := NewListClustersTool(&client.Client{}) assert.True(t, tool.IsReadOnly(), "list_clusters should be read-only") } func TestListClustersTool_GetTool(t *testing.T) { - tool := NewListClustersTool() + tool := NewListClustersTool(&client.Client{}) mcpTool := tool.GetTool() @@ -32,7 +33,7 @@ func TestListClustersTool_GetTool(t *testing.T) { } func TestListClustersTool_RegisterWith(t *testing.T) { - tool := NewListClustersTool() + tool := NewListClustersTool(&client.Client{}) server := mcp.NewServer( &mcp.Implementation{ Name: "test-server", diff --git a/internal/toolsets/config/toolset.go b/internal/toolsets/config/toolset.go index 4c5ce03..092873a 100644 --- a/internal/toolsets/config/toolset.go +++ b/internal/toolsets/config/toolset.go @@ -2,6 +2,7 @@ package config import ( + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/toolsets" ) @@ -13,11 +14,11 @@ type Toolset struct { } // NewToolset creates a new config management toolset. -func NewToolset(cfg *config.Config) *Toolset { +func NewToolset(cfg *config.Config, c *client.Client) *Toolset { return &Toolset{ cfg: cfg, tools: []toolsets.Tool{ - NewListClustersTool(), + NewListClustersTool(c), }, } } diff --git a/internal/toolsets/config/toolset_test.go b/internal/toolsets/config/toolset_test.go index b062e23..f12b3d2 100644 --- a/internal/toolsets/config/toolset_test.go +++ b/internal/toolsets/config/toolset_test.go @@ -3,13 +3,14 @@ package config import ( "testing" + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewToolset(t *testing.T) { - toolset := NewToolset(&config.Config{}) + toolset := NewToolset(&config.Config{}, &client.Client{}) require.NotNil(t, toolset) assert.Equal(t, "config_manager", toolset.GetName()) } @@ -23,7 +24,7 @@ func TestToolset_IsEnabled_True(t *testing.T) { }, } - toolset := NewToolset(cfg) + toolset := NewToolset(cfg, &client.Client{}) assert.True(t, toolset.IsEnabled()) tools := toolset.GetTools() @@ -41,7 +42,7 @@ func TestToolset_IsEnabled_False(t *testing.T) { }, } - toolset := NewToolset(cfg) + toolset := NewToolset(cfg, &client.Client{}) assert.False(t, toolset.IsEnabled()) tools := toolset.GetTools()