diff --git a/.changes/unreleased/Minor-20251117-145956.yaml b/.changes/unreleased/Minor-20251117-145956.yaml new file mode 100644 index 00000000..3c054182 --- /dev/null +++ b/.changes/unreleased/Minor-20251117-145956.yaml @@ -0,0 +1,3 @@ +kind: Minor +body: The server now performs comprehensive pre-flight checks to verify your environment, including Neo4j connection, query capabilities, APOC installation, and will gracefully start without GDS-specific tools if the GDS library is not found. +time: 2025-11-17T14:59:56.826192Z diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 84964ece..7cf69a16 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -162,10 +162,12 @@ func MyToolHandler(deps *ToolDependencies) mcp.ToolHandler { ) } ``` - **Note:** WithReadOnlyHintAnnotation marks a tool with a read-only hint is used for filtering. - When set to true, the tool will be considered read-only and included when selecting - tools for read-only mode. If the annotation is not present or set to false, - the tool is treated as a write-capable tool (i.e., not considered read-only). + + **Note:** WithReadOnlyHintAnnotation marks a tool with a read-only hint is used for filtering. + When set to true, the tool will be considered read-only and included when selecting + tools for read-only mode. If the annotation is not present or set to false, + the tool is treated as a write-capable tool (i.e., not considered read-only). + 2. **Implement tool handler**: ```go @@ -180,8 +182,12 @@ func MyToolHandler(deps *ToolDependencies) mcp.ToolHandler { ```go { - Tool: NewMyToolSpec(), - Handler: NewMyToolHandler(deps), + category: cypherCategory, + definition: server.ServerTool{ + Tool: cypher.GetSchemaSpec(), + Handler: cypher.GetSchemaHandler(deps), + }, + readonly: true, }, ``` diff --git a/README.md b/README.md index d364085d..36dec5e6 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,20 @@ BETA - Active development; not yet suitable for production. - APOC plugin installed in the Neo4j instance. - Any MCP-compatible client (e.g. [VSCode](https://code.visualstudio.com/) with [MCP support](https://code.visualstudio.com/docs/copilot/customization/mcp-servers)) +## Startup Checks & Adaptive Operation + +The server performs several pre-flight checks at startup to ensure your environment is correctly configured. + +**Mandatory Requirements** +The server verifies the following core requirements. If any of these checks fail (e.g., due to an invalid configuration, incorrect credentials, or a missing APOC installation), the server will not start: + +- A valid connection to your Neo4j instance. +- The ability to execute queries. +- The presence of the APOC plugin. + +**Optional Requirements** +If an optional dependency is missing, the server will start in an adaptive mode. For instance, if the Graph Data Science (GDS) library is not detected in your Neo4j installation, the server will still launch but will automatically disable all GDS-related tools, such as `list-gds-procedures`. All other tools will remain available. + ## Installation (Binary) Releases: https://github.com/neo4j/mcp/releases diff --git a/cmd/neo4j-mcp/main.go b/cmd/neo4j-mcp/main.go index 041608b9..6129cf8c 100644 --- a/cmd/neo4j-mcp/main.go +++ b/cmd/neo4j-mcp/main.go @@ -41,12 +41,6 @@ func main() { } }() - // Verify database connectivity - if err := driver.VerifyConnectivity(ctx); err != nil { - log.Printf("Failed to verify database connectivity: %v", err) - return - } - // Create database service dbService, err := database.NewNeo4jService(driver, cfg.Database) if err != nil { diff --git a/internal/database/interfaces.go b/internal/database/interfaces.go index 189ff239..2ef41177 100644 --- a/internal/database/interfaces.go +++ b/internal/database/interfaces.go @@ -27,8 +27,13 @@ type RecordFormatter interface { Neo4jRecordsToJSON(records []*neo4j.Record) (string, error) } +type Helpers interface { + VerifyConnectivity(ctx context.Context) error +} + // Service combines query execution and record formatting type Service interface { QueryExecutor RecordFormatter + Helpers } diff --git a/internal/database/mocks/mock_database.go b/internal/database/mocks/mock_database.go index d096e337..ac06b0ef 100644 --- a/internal/database/mocks/mock_database.go +++ b/internal/database/mocks/mock_database.go @@ -100,3 +100,17 @@ func (mr *MockServiceMockRecorder) Neo4jRecordsToJSON(records any) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Neo4jRecordsToJSON", reflect.TypeOf((*MockService)(nil).Neo4jRecordsToJSON), records) } + +// VerifyConnectivity mocks base method. +func (m *MockService) VerifyConnectivity(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VerifyConnectivity", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// VerifyConnectivity indicates an expected call of VerifyConnectivity. +func (mr *MockServiceMockRecorder) VerifyConnectivity(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyConnectivity", reflect.TypeOf((*MockService)(nil).VerifyConnectivity), ctx) +} diff --git a/internal/database/service.go b/internal/database/service.go index 57a28a33..0595b4ca 100644 --- a/internal/database/service.go +++ b/internal/database/service.go @@ -28,9 +28,18 @@ func NewNeo4jService(driver neo4j.DriverWithContext, database string) (*Neo4jSer }, nil } +// VerifyConnectivity checks the driver can establish a valid connection with a Neo4j instance; +func (s *Neo4jService) VerifyConnectivity(ctx context.Context) error { + // Verify database connectivity + if err := s.driver.VerifyConnectivity(ctx); err != nil { + log.Printf("Failed to verify database connectivity: %s", err.Error()) + return err + } + return nil +} + // ExecuteReadQuery executes a read-only Cypher query and returns raw records func (s *Neo4jService) ExecuteReadQuery(ctx context.Context, cypher string, params map[string]any) ([]*neo4j.Record, error) { - res, err := neo4j.ExecuteQuery(ctx, s.driver, cypher, params, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithDatabase(s.database), neo4j.ExecuteQueryWithReadersRouting()) if err != nil { wrappedErr := fmt.Errorf("failed to execute read query: %w", err) diff --git a/internal/server/server.go b/internal/server/server.go index 45955e42..22acb1e7 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "log" @@ -12,11 +13,12 @@ import ( // Neo4jMCPServer represents the MCP server instance type Neo4jMCPServer struct { - MCPServer *server.MCPServer - config *config.Config - dbService database.Service - version string - anService analytics.Service + MCPServer *server.MCPServer + config *config.Config + dbService database.Service + version string + anService analytics.Service + gdsInstalled bool } // NewNeo4jMCPServer creates a new MCP server instance @@ -31,23 +33,28 @@ func NewNeo4jMCPServer(version string, cfg *config.Config, dbService database.Se ) return &Neo4jMCPServer{ - MCPServer: mcpServer, - config: cfg, - dbService: dbService, - version: version, - anService: anService, + MCPServer: mcpServer, + config: cfg, + dbService: dbService, + version: version, + anService: anService, + gdsInstalled: false, } } // Start initializes and starts the MCP server using stdio transport func (s *Neo4jMCPServer) Start() error { log.Println("Starting Neo4j MCP Server...") + err := s.verifyRequirements() + if err != nil { + return err + } // track startup event s.anService.EmitEvent(s.anService.NewStartupEvent()) // Register tools - if err := s.RegisterTools(); err != nil { + if err := s.registerTools(); err != nil { return fmt.Errorf("failed to register tools: %w", err) } log.Println("Started Neo4j MCP Server. Now listening for input...") @@ -55,6 +62,62 @@ func (s *Neo4jMCPServer) Start() error { return server.ServeStdio(s.MCPServer) } +// verifyRequirements check the Neo4j requirements: +// - A valid connection with a Neo4j instance. +// - The ability to perform a read query (database name is correctly defined). +// - Required plugin installed: APOC (specifically apoc.meta.schema as it's used for get-schema) +// - In case GDS is not installed a flag is set in the server and tools will be registered accordingly +func (s *Neo4jMCPServer) verifyRequirements() error { + err := s.dbService.VerifyConnectivity(context.Background()) + if err != nil { + return fmt.Errorf("impossible to verify connectivity with the Neo4j instance: %w", err) + } + // Perform a dummy query to verify correctness of the connection, VerifyConnectivity is not exhaustive. + records, err := s.dbService.ExecuteReadQuery(context.Background(), "RETURN 1 as first", map[string]any{}) + + if err != nil { + return fmt.Errorf("impossible to verify connectivity with the Neo4j instance: %w", err) + } + if len(records) != 1 || len(records[0].Values) != 1 { + return fmt.Errorf("failed to verify connectivity with the Neo4j instance: unexpected response from test query") + } + one, ok := records[0].Values[0].(int64) + if !ok || one != 1 { + return fmt.Errorf("failed to verify connectivity with the Neo4j instance: unexpected response from test query") + } + // Check for apoc.meta.schema procedure + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + + // Check for apoc.meta.schema availability + records, err = s.dbService.ExecuteReadQuery(context.Background(), checkApocMetaSchemaQuery, nil) + if err != nil { + return fmt.Errorf("failed to check for APOC availability: %w", err) + } + if len(records) != 1 || len(records[0].Values) != 1 { + return fmt.Errorf("failed to verify APOC availability: unexpected response from test query") + } + apocMetaSchemaAvailable, ok := records[0].Values[0].(bool) + if !ok || !apocMetaSchemaAvailable { + return fmt.Errorf("please ensure the APOC plugin is installed and includes the 'meta' component") + } + // Call gds.version procedure to determine if GDS is installed + records, err = s.dbService.ExecuteReadQuery(context.Background(), "RETURN gds.version() as gdsVersion", nil) + if err != nil { + // GDS is optional, so we log a warning and continue, assuming it's not installed. + log.Print("Impossible to verify GDS installation.") + s.gdsInstalled = false + return nil + } + if len(records) == 1 && len(records[0].Values) == 1 { + _, ok := records[0].Values[0].(string) + if ok { + s.gdsInstalled = true + } + } + + return nil +} + // Stop gracefully stops the server func (s *Neo4jMCPServer) Stop() error { log.Println("Stopping Neo4j MCP Server...") diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 11570585..abfecb1f 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,12 +1,14 @@ package server_test import ( + "fmt" "testing" analytics "github.com/neo4j/mcp/internal/analytics/mocks" "github.com/neo4j/mcp/internal/config" - db_mock "github.com/neo4j/mcp/internal/database/mocks" + db "github.com/neo4j/mcp/internal/database/mocks" "github.com/neo4j/mcp/internal/server" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" "go.uber.org/mock/gomock" ) @@ -20,32 +22,75 @@ func TestNewNeo4jMCPServer(t *testing.T) { Password: "password", Database: "neo4j", } - - mockDB := db_mock.NewMockService(ctrl) analyticsService := analytics.NewMockService(ctrl) analyticsService.EXPECT().EmitEvent(gomock.Any()).AnyTimes() analyticsService.EXPECT().NewStartupEvent().AnyTimes() - t.Run("creates server successfully", func(t *testing.T) { + + t.Run("starts server successfully", func(t *testing.T) { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Times(1) + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{ + int64(1), + }, + }, + }, nil) + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), checkApocMetaSchemaQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"apocMetaSchemaAvailable"}, + Values: []any{ + bool(true), + }, + }, + }, nil) + gdsVersionQuery := "RETURN gds.version() as gdsVersion" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"gdsVersion"}, + Values: []any{ + string("2.22.0"), + }, + }, + }, nil) + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) if s == nil { t.Errorf("NewNeo4jMCPServer() expected non-nil server, got nil") } + + err := s.Start() + if err != nil { + t.Errorf("Start() unexpected error = %v", err) + } }) - t.Run("starts server successfully", func(t *testing.T) { + t.Run("starts server should fails when no connection can be established", func(t *testing.T) { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Times(1).Return(fmt.Errorf("connection error")) s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) if s == nil { t.Errorf("NewNeo4jMCPServer() expected non-nil server, got nil") } + err := s.Start() - if err != nil { - t.Errorf("Start() unexpected error = %v", err) + if err == nil { + t.Errorf("Start() expected an error, got nil") } }) - - t.Run("stops server successfully", func(t *testing.T) { + t.Run("starts server should fail when test query returns unexpected result", func(t *testing.T) { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Return(nil) + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{int64(2)}, // Return a value other than 1 + }, + }, nil) s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) if s == nil { @@ -53,26 +98,49 @@ func TestNewNeo4jMCPServer(t *testing.T) { } err := s.Start() - if err != nil { - t.Errorf("Start() unexpected error = %v", err) + if err == nil { + t.Errorf("Start() expected an error for unexpected query result, got nil") } }) t.Run("server creates successfully with all required components", func(t *testing.T) { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Times(1) + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{ + int64(1), + }, + }, + }, nil) + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), checkApocMetaSchemaQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"apocMetaSchemaAvailable"}, + Values: []any{ + bool(true), + }, + }, + }, nil) + gdsVersionQuery := "RETURN gds.version() as gdsVersion" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"gdsVersion"}, + Values: []any{ + string("2.22.0"), + }, + }, + }, nil) + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) if s == nil { t.Fatal("NewNeo4jMCPServer() returned nil") } - // Register tools should work without errors - err := s.RegisterTools() - if err != nil { - t.Errorf("RegisterTools() unexpected error = %v", err) - } - // Start should work without errors - err = s.Start() + err := s.Start() if err != nil { t.Errorf("Start() unexpected error = %v", err) } @@ -84,6 +152,82 @@ func TestNewNeo4jMCPServer(t *testing.T) { } }) + t.Run("starts server successfully if GDS is not found", func(t *testing.T) { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Times(1) + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{ + int64(1), + }, + }, + }, nil) + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), checkApocMetaSchemaQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"apocMetaSchemaAvailable"}, + Values: []any{ + bool(true), + }, + }, + }, nil) + gdsVersionQuery := "RETURN gds.version() as gdsVersion" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).Times(1).Return(nil, fmt.Errorf("Unknown function 'gds.version'")) + + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) + + if s == nil { + t.Errorf("NewNeo4jMCPServer() expected non-nil server, got nil") + } + err := s.Start() + if err != nil { + t.Errorf("Start() unexpected error = %v", err) + } + }) + + t.Run("stops server successfully", func(t *testing.T) { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Times(1) + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{ + int64(1), + }, + }, + }, nil) + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), checkApocMetaSchemaQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"apocMetaSchemaAvailable"}, + Values: []any{ + bool(true), + }, + }, + }, nil) + gdsVersionQuery := "RETURN gds.version() as gdsVersion" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"gdsVersion"}, + Values: []any{ + string("2.22.0"), + }, + }, + }, nil) + + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) + + if s == nil { + t.Errorf("NewNeo4jMCPServer() expected non-nil server, got nil") + } + + err := s.Start() + if err != nil { + t.Errorf("Start() unexpected error = %v", err) + } + }) + } func TestNewNeo4jMCPServerEvents(t *testing.T) { @@ -97,7 +241,34 @@ func TestNewNeo4jMCPServerEvents(t *testing.T) { Database: "neo4j", } - mockDB := db_mock.NewMockService(ctrl) + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).AnyTimes() + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).AnyTimes().Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{ + int64(1), + }, + }, + }, nil) + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), checkApocMetaSchemaQuery, gomock.Any()).AnyTimes().Return([]*neo4j.Record{ + { + Keys: []string{"apocMetaSchemaAvailable"}, + Values: []any{ + bool(true), + }, + }, + }, nil) + gdsVersionQuery := "RETURN gds.version() as gdsVersion" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).AnyTimes().Return([]*neo4j.Record{ + { + Keys: []string{"gdsVersion"}, + Values: []any{ + string("2.22.0"), + }, + }, + }, nil) analyticsService := analytics.NewMockService(ctrl) t.Run("emits startup and OSInfoEvent and StartupEvent events on start", func(t *testing.T) { diff --git a/internal/server/tool_register_test.go b/internal/server/tool_register_test.go index 5aa1a7c9..1ece1e11 100644 --- a/internal/server/tool_register_test.go +++ b/internal/server/tool_register_test.go @@ -1,13 +1,14 @@ package server_test import ( + "fmt" "testing" - "github.com/neo4j/mcp/internal/analytics" - analytics_mock "github.com/neo4j/mcp/internal/analytics/mocks" + analytics "github.com/neo4j/mcp/internal/analytics/mocks" "github.com/neo4j/mcp/internal/config" - db_mock "github.com/neo4j/mcp/internal/database/mocks" + db "github.com/neo4j/mcp/internal/database/mocks" "github.com/neo4j/mcp/internal/server" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" "go.uber.org/mock/gomock" ) @@ -15,27 +16,28 @@ func TestToolRegister(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockDB := db_mock.NewMockService(ctrl) - mockClient := analytics_mock.NewMockHTTPClient(ctrl) - analyticsService := analytics.NewAnalyticsWithClient("test-token", "http://localhost", mockClient, "bolt://localhost:7687") + aService := analytics.NewMockService(ctrl) + aService.EXPECT().EmitEvent(gomock.Any()).AnyTimes() + aService.EXPECT().NewStartupEvent().AnyTimes() t.Run("verifies expected tools are registered", func(t *testing.T) { + mockDB := getMockedDBService(ctrl, true) cfg := &config.Config{ URI: "bolt://test-host:7687", Username: "neo4j", Password: "password", Database: "neo4j", } - s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, aService) // Expected tools that should be registered // update this number when a tool is added or removed. expectedTotalToolsCount := 4 - // Register tools - err := s.RegisterTools() + // Start server and register tools + err := s.Start() if err != nil { - t.Fatalf("RegisterTools() failed: %v", err) + t.Fatalf("Start() failed: %v", err) } registeredTools := len(s.MCPServer.ListTools()) @@ -45,6 +47,7 @@ func TestToolRegister(t *testing.T) { }) t.Run("should register only readonly tools when readonly", func(t *testing.T) { + mockDB := getMockedDBService(ctrl, true) cfg := &config.Config{ URI: "bolt://test-host:7687", Username: "neo4j", @@ -52,16 +55,16 @@ func TestToolRegister(t *testing.T) { Database: "neo4j", ReadOnly: "true", } - s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, aService) // Expected tools that should be registered // update this number when a tool is added or removed. expectedTotalToolsCount := 3 - // Register tools - err := s.RegisterTools() + // Start server and register tools + err := s.Start() if err != nil { - t.Fatalf("RegisterTools() failed: %v", err) + t.Fatalf("Start() failed: %v", err) } registeredTools := len(s.MCPServer.ListTools()) @@ -69,7 +72,8 @@ func TestToolRegister(t *testing.T) { t.Errorf("Expected %d tools, but test configuration shows %d", expectedTotalToolsCount, registeredTools) } }) - t.Run("should not register only readonly tools when readonly is set to false", func(t *testing.T) { + t.Run("should register also not write tools when readonly is set to false", func(t *testing.T) { + mockDB := getMockedDBService(ctrl, true) cfg := &config.Config{ URI: "bolt://test-host:7687", Username: "neo4j", @@ -77,16 +81,16 @@ func TestToolRegister(t *testing.T) { Database: "neo4j", ReadOnly: "false", } - s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, analyticsService) + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, aService) // Expected tools that should be registered // update this number when a tool is added or removed. expectedTotalToolsCount := 4 - // Register tools - err := s.RegisterTools() + // Start server and register tools + err := s.Start() if err != nil { - t.Fatalf("RegisterTools() failed: %v", err) + t.Fatalf("Start() failed: %v", err) } registeredTools := len(s.MCPServer.ListTools()) @@ -94,4 +98,69 @@ func TestToolRegister(t *testing.T) { t.Errorf("Expected %d tools, but test configuration shows %d", expectedTotalToolsCount, registeredTools) } }) + + t.Run("should remove GDS tools if GDS is not present", func(t *testing.T) { + mockDB := getMockedDBService(ctrl, false) + cfg := &config.Config{ + URI: "bolt://test-host:7687", + Username: "neo4j", + Password: "password", + Database: "neo4j", + ReadOnly: "false", + } + s := server.NewNeo4jMCPServer("test-version", cfg, mockDB, aService) + + // Expected tools that should be registered + // update this number when a tool is added or removed. + expectedTotalToolsCount := 3 + + // Start server and register tools + err := s.Start() + if err != nil { + t.Fatalf("Start() failed: %v", err) + } + registeredTools := len(s.MCPServer.ListTools()) + + if expectedTotalToolsCount != registeredTools { + t.Errorf("Expected %d tools, but test configuration shows %d", expectedTotalToolsCount, registeredTools) + } + }) +} + +// utility to mock the invocation required by VerifyRequirements +func getMockedDBService(ctrl *gomock.Controller, withGDS bool) *db.MockService { + mockDB := db.NewMockService(ctrl) + mockDB.EXPECT().VerifyConnectivity(gomock.Any()).Times(1) + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), "RETURN 1 as first", gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"first"}, + Values: []any{ + int64(1), + }, + }, + }, nil) + checkApocMetaSchemaQuery := "SHOW PROCEDURES YIELD name WHERE name = 'apoc.meta.schema' RETURN count(name) > 0 AS apocMetaSchemaAvailable" + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), checkApocMetaSchemaQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"apocMetaSchemaAvailable"}, + Values: []any{ + bool(true), + }, + }, + }, nil) + gdsVersionQuery := "RETURN gds.version() as gdsVersion" + if withGDS { + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).Times(1).Return([]*neo4j.Record{ + { + Keys: []string{"gdsVersion"}, + Values: []any{ + string("2.22.0"), + }, + }, + }, nil) + return mockDB + } + mockDB.EXPECT().ExecuteReadQuery(gomock.Any(), gdsVersionQuery, gomock.Any()).Times(1).Return(nil, fmt.Errorf("Unknown function 'gds.version'")) + + return mockDB } diff --git a/internal/server/tools_register.go b/internal/server/tools_register.go index c2169438..62711a25 100644 --- a/internal/server/tools_register.go +++ b/internal/server/tools_register.go @@ -7,55 +7,117 @@ import ( "github.com/neo4j/mcp/internal/tools/gds" ) -// RegisterTools registers all enabled MCP tools and adds them to the provided MCP server. +// registerTools registers all enabled MCP tools and adds them to the provided MCP server. // Tools are filtered according to the server configuration. For example, when the read-only // mode is enabled (e.g. via the NEO4J_READ_ONLY environment variable or the Config.ReadOnly flag), // any tool that performs state mutation will be excluded; only tools annotated as read-only will be registered. // Note: this read-only filtering relies on the tool annotation "readonly" (ReadOnlyHint). If the annotation // is not defined or is set to false, the tool will be added (i.e., only tools with readonly=true are filtered in read-only mode). -func (s *Neo4jMCPServer) RegisterTools() error { +func (s *Neo4jMCPServer) registerTools() error { + filteredTools := s.getEnabledTools() + s.MCPServer.AddTools(filteredTools...) + return nil +} + +type toolFilter func(tools []ToolDefinition) []ToolDefinition + +type toolCategory int + +const ( + cypherCategory toolCategory = 0 + gdsCategory toolCategory = 1 +) + +type ToolDefinition struct { + category toolCategory + definition server.ServerTool + readonly bool +} + +func (s *Neo4jMCPServer) getEnabledTools() []server.ServerTool { + filters := make([]toolFilter, 0) + + // If read-only mode is enabled, expose only tools annotated as read-only. + if s.config != nil && s.config.ReadOnly == "true" { + filters = append(filters, filterWriteTools) + } + // If GDS is not installed, disable GDS tools. + if !s.gdsInstalled { + filters = append(filters, filterGDSTools) + } deps := &tools.ToolDependencies{ DBService: s.dbService, AnalyticsService: s.anService, } + toolDefs := getAllToolsDefs(deps) + + for _, filter := range filters { + toolDefs = filter(toolDefs) + } + enabledTools := make([]server.ServerTool, 0) + for _, toolDef := range toolDefs { + enabledTools = append(enabledTools, toolDef.definition) + } + return enabledTools - all := getAllTools(deps) +} - // If read-only mode is enabled, expose only tools annotated as read-only. - if deps != nil && s.config != nil && s.config.ReadOnly == "true" { - readOnlyTools := make([]server.ServerTool, 0, len(all)) - for _, t := range all { - if t.Tool.Annotations.ReadOnlyHint != nil && *t.Tool.Annotations.ReadOnlyHint { - readOnlyTools = append(readOnlyTools, t) - } +func filterWriteTools(tools []ToolDefinition) []ToolDefinition { + readOnlyTools := make([]ToolDefinition, 0, len(tools)) + for _, t := range tools { + if t.readonly { + readOnlyTools = append(readOnlyTools, t) } - s.MCPServer.AddTools(readOnlyTools...) - return nil } + return readOnlyTools +} - s.MCPServer.AddTools(all...) - return nil +func filterGDSTools(tools []ToolDefinition) []ToolDefinition { + nonGDSTools := make([]ToolDefinition, 0, len(tools)) + for _, t := range tools { + if t.category != gdsCategory { + nonGDSTools = append(nonGDSTools, t) + } + } + return nonGDSTools } -// getAllTools returns all available tools with their specs and handlers -func getAllTools(deps *tools.ToolDependencies) []server.ServerTool { - return []server.ServerTool{ +// getAllToolsDefs returns all available tools with their specs and handlers +func getAllToolsDefs(deps *tools.ToolDependencies) []ToolDefinition { + + return []ToolDefinition{ { - Tool: cypher.GetSchemaSpec(), - Handler: cypher.GetSchemaHandler(deps), + category: cypherCategory, + definition: server.ServerTool{ + Tool: cypher.GetSchemaSpec(), + Handler: cypher.GetSchemaHandler(deps), + }, + readonly: true, }, { - Tool: cypher.ReadCypherSpec(), - Handler: cypher.ReadCypherHandler(deps), + category: cypherCategory, + definition: server.ServerTool{ + Tool: cypher.ReadCypherSpec(), + Handler: cypher.ReadCypherHandler(deps), + }, + readonly: true, }, { - Tool: cypher.WriteCypherSpec(), - Handler: cypher.WriteCypherHandler(deps), + category: cypherCategory, + definition: server.ServerTool{ + Tool: cypher.WriteCypherSpec(), + Handler: cypher.WriteCypherHandler(deps), + }, + readonly: false, }, // GDS Category/Section { - Tool: gds.ListGDSProceduresSpec(), - Handler: gds.ListGdsProceduresHandler(deps), + category: gdsCategory, + definition: server.ServerTool{ + Tool: gds.ListGDSProceduresSpec(), + Handler: gds.ListGdsProceduresHandler(deps), + }, + readonly: true, }, // Add other categories below... } diff --git a/test/integration/containerrunner/container_runner.go b/test/integration/containerrunner/container_runner.go index 073a8d05..dd87aada 100644 --- a/test/integration/containerrunner/container_runner.go +++ b/test/integration/containerrunner/container_runner.go @@ -19,6 +19,7 @@ import ( var ( container testcontainers.Container driver *neo4j.DriverWithContext + cfg *config.Config once sync.Once ) @@ -37,6 +38,17 @@ func GetDriver() *neo4j.DriverWithContext { return driver } +func GetDriverConf() *config.Config { + if cfg == nil { + log.Fatal("getDriverConf invoked before configuration is initialized.") + } + return &config.Config{ + URI: cfg.URI, + Username: cfg.Username, + Password: cfg.Password, + } +} + // startOnce start the testcontainer imaged func startOnce(ctx context.Context) { ctr, boltURI, err := createNeo4jContainer(ctx) @@ -45,7 +57,7 @@ func startOnce(ctx context.Context) { } container = ctr - cfg := &config.Config{ + cfg = &config.Config{ URI: boltURI, Username: config.GetEnvWithDefault("NEO4J_USERNAME", "neo4j"), Password: config.GetEnvWithDefault("NEO4J_PASSWORD", "password"), diff --git a/test/integration/dbservice/dbservice.go b/test/integration/dbservice/dbservice.go index 2d99ca33..1be1b266 100644 --- a/test/integration/dbservice/dbservice.go +++ b/test/integration/dbservice/dbservice.go @@ -11,31 +11,33 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j" ) -type DBService struct { +type dbService struct { driver *neo4j.DriverWithContext useContainer bool } -func NewDBService() *DBService { - return &DBService{ +func NewDBService() *dbService { + useContainer := config.GetEnvWithDefault("USE_CONTAINER", "true") == "true" + log.Printf("Testing using container: %t", useContainer) + return &dbService{ driver: nil, - useContainer: config.GetEnvWithDefault("USE_CONTAINER", "true") == "true", + useContainer: useContainer, } } -func (dbs *DBService) Start(ctx context.Context) { +func (dbs *dbService) Start(ctx context.Context) { if dbs.useContainer { containerrunner.Start(ctx) } } -func (dbs *DBService) Stop(ctx context.Context) { +func (dbs *dbService) Stop(ctx context.Context) { if dbs.useContainer { containerrunner.Close(ctx) } } -func (dbs *DBService) GetDriver() *neo4j.DriverWithContext { +func (dbs *dbService) GetDriver() *neo4j.DriverWithContext { if dbs.driver != nil { return dbs.driver } @@ -59,3 +61,17 @@ func (dbs *DBService) GetDriver() *neo4j.DriverWithContext { return dbs.driver } + +func (dbs *dbService) GetDriverConf() *config.Config { + if dbs.useContainer == true { + return containerrunner.GetDriverConf() + } + + cfg := &config.Config{ + URI: config.GetEnvWithDefault("NEO4J_URI", "bolt://localhost:7687"), + Username: config.GetEnvWithDefault("NEO4J_USERNAME", "neo4j"), + Password: config.GetEnvWithDefault("NEO4J_PASSWORD", "password"), + } + + return cfg +} diff --git a/test/integration/helpers/helpers.go b/test/integration/helpers/helpers.go index 42bb4420..4980ba60 100644 --- a/test/integration/helpers/helpers.go +++ b/test/integration/helpers/helpers.go @@ -13,7 +13,7 @@ import ( "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" - analytics_mocks "github.com/neo4j/mcp/internal/analytics/mocks" + analytics "github.com/neo4j/mcp/internal/analytics/mocks" "github.com/neo4j/mcp/internal/database" "github.com/neo4j/mcp/internal/tools" "github.com/neo4j/neo4j-go-driver/v5/neo4j" @@ -30,12 +30,13 @@ func (ul UniqueLabel) String() string { // TestContext holds common test dependencies type TestContext struct { - ctx context.Context - t *testing.T - TestID string - Service database.Service - Deps *tools.ToolDependencies - createdLabels map[string]bool + ctx context.Context + t *testing.T + TestID string + Service database.Service + Deps *tools.ToolDependencies + createdLabels map[string]bool + AnalyticsService *analytics.MockService } // NewTestContext creates a new test context with automatic cleanup @@ -61,6 +62,7 @@ func NewTestContext(t *testing.T, driver *neo4j.DriverWithContext) *TestContext t.Fatalf("failed to create Neo4j service: %v", err) } analyticsService := getAnalyticsMock(t) + tc.AnalyticsService = analyticsService deps := &tools.ToolDependencies{DBService: databaseService, AnalyticsService: analyticsService} tc.Service = databaseService @@ -70,10 +72,10 @@ func NewTestContext(t *testing.T, driver *neo4j.DriverWithContext) *TestContext } // getAnalyticsMock is used to mock the analytics service, for integration test purpose. -func getAnalyticsMock(t *testing.T) *analytics_mocks.MockService { +func getAnalyticsMock(t *testing.T) *analytics.MockService { ctrl := gomock.NewController(t) defer ctrl.Finish() - analyticsService := analytics_mocks.NewMockService(ctrl) + analyticsService := analytics.NewMockService(ctrl) analyticsService.EXPECT().EmitEvent(gomock.Any()).AnyTimes() analyticsService.EXPECT().Disable().AnyTimes() analyticsService.EXPECT().Enable().AnyTimes() diff --git a/test/integration/server_test.go b/test/integration/server_test.go new file mode 100644 index 00000000..28ddc81a --- /dev/null +++ b/test/integration/server_test.go @@ -0,0 +1,156 @@ +//go:build integration + +package integration + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/neo4j/mcp/internal/config" + "github.com/neo4j/mcp/internal/database" + "github.com/neo4j/mcp/internal/server" + "github.com/neo4j/mcp/test/integration/helpers" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" +) + +func TestServerLifecycle(t *testing.T) { + t.Parallel() + testCFG := dbs.GetDriverConf() + testCases := []struct { + name string + config *config.Config + expectError bool + }{ + { + name: "Neo4jMCPServer should correctly start", + config: testCFG, + expectError: false, + }, + { + name: "Neo4jMCPServer should fail to start: invalid host", + config: &config.Config{ + URI: "bolt://not-a-valid-host:7687", + Username: testCFG.Username, + Password: testCFG.Password, + Database: testCFG.Database, + }, + expectError: true, + }, + { + name: "Neo4jMCPServer should fail to start: invalid database name", + config: &config.Config{ + URI: testCFG.URI, + Username: testCFG.Username, + Password: testCFG.Password, + Database: "not-a-valid-db-name", + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + driver, err := neo4j.NewDriverWithContext(tc.config.URI, neo4j.BasicAuth(tc.config.Username, tc.config.Password, "")) + if err != nil { + t.Fatalf("failed to create Neo4j driver: %s", err.Error()) + } + testContext := helpers.NewTestContext(t, &driver) + + ctx := context.Background() + defer func() { + if err := driver.Close(ctx); err != nil { + t.Fatalf("error closing driver: %s", err.Error()) + } + }() + + dbService, err := database.NewNeo4jService(driver, tc.config.Database) + if err != nil { + t.Fatalf("failed to create database service: %v", err) + return + } + + s := server.NewNeo4jMCPServer("test-version", tc.config, dbService, testContext.AnalyticsService) + + if s == nil { + t.Fatal("the NewNeo4jMCPServer() returned nil") + } + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + + var startErr error + go func() { + defer wg.Done() + startErr = s.Start() + }() + + for { + select { + case <-ctx.Done(): + if tc.expectError { + if startErr == nil { + t.Fatal("expected an error but got nil") + } + } else { + if startErr != nil { + t.Fatalf("Start returned an unexpected error: %s", startErr.Error()) + } + } + return + default: + time.Sleep(50 * time.Millisecond) + } + } + }) + } + + t.Run("server stop should return no errors", func(t *testing.T) { + + driver, err := neo4j.NewDriverWithContext(testCFG.URI, neo4j.BasicAuth(testCFG.Username, testCFG.Password, "")) + if err != nil { + t.Fatalf("failed to create Neo4j driver: %s", err.Error()) + } + testContext := helpers.NewTestContext(t, &driver) + ctx := context.Background() + defer func() { + if err := driver.Close(ctx); err != nil { + t.Fatalf("error closing driver: %s", err.Error()) + } + }() + + dbService, err := database.NewNeo4jService(driver, testCFG.Database) + if err != nil { + t.Fatalf("failed to create database service: %v", err) + } + + s := server.NewNeo4jMCPServer("test-version", testCFG, dbService, testContext.AnalyticsService) + if s == nil { + t.Fatal("NewNeo4jMCPServer() returned nil") + } + + var wg sync.WaitGroup + wg.Add(1) + + var startErr error + go func() { + defer wg.Done() + startErr = s.Start() + }() + + // Give the server a moment to start + time.Sleep(4 * time.Second) + + if startErr != nil { + t.Fatalf("Start() returned an unexpected error after stop: %v", startErr) + } + if err := s.Stop(); err != nil { + t.Fatalf("Stop() returned an unexpected error: %v", err) + } + }) +}