diff --git a/router-tests/modules/router_on_request_test.go b/router-tests/modules/router_on_request_test.go index 371b0b5cc9..eb8197cf96 100644 --- a/router-tests/modules/router_on_request_test.go +++ b/router-tests/modules/router_on_request_test.go @@ -1,11 +1,16 @@ package module_test import ( + "context" "encoding/json" - "github.com/wundergraph/cosmo/router-tests/modules/router-on-request" - "go.uber.org/zap/zapcore" "net/http" + "sync/atomic" "testing" + "time" + + "github.com/hasura/go-graphql-client" + router_on_request "github.com/wundergraph/cosmo/router-tests/modules/router-on-request" + "go.uber.org/zap/zapcore" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -101,4 +106,113 @@ func TestRouterOnRequestHook(t *testing.T) { assert.Len(t, retryRequestLog.All(), 2) }) }) + + t.Run("Test RouterOnRequest hook is called with subscriptions over sse", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "routerOnRequestModule": router_on_request.RouterOnRequestModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&router_on_request.RouterOnRequestModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var counter atomic.Uint32 + + go xEnv.GraphQLSubscriptionOverSSE(ctx, testenv.GraphQLRequest{ + OperationName: []byte(`CurrentTime`), + Query: `subscription CurrentTime { currentTime { unixTime timeStamp }}`, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Accept": {"text/event-stream"}, + "Connection": {"keep-alive"}, + "Cache-Control": {"no-cache"}, + }, + }, func(data string) { + counter.Add(1) + }) + + require.Eventually(t, func() bool { + return counter.Load() > 0 + }, time.Second*5, time.Millisecond*100) + + requestLog := xEnv.Observer().FilterMessage("RouterOnRequest Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test RouterOnRequest hook is called with subscriptions over ws", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "routerOnRequestModule": router_on_request.RouterOnRequestModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&router_on_request.RouterOnRequestModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + var counter atomic.Uint32 + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + var subscriptionOne struct { + currentTime struct { + unixTime float64 `graphql:"unixTime"` + timeStamp float64 `graphql:"timeStamp"` + } `graphql:"currentTime"` + } + + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + counter.Add(1) + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + require.Eventually(t, func() bool { + return counter.Load() > 0 + }, time.Second*5, time.Millisecond*100) + + err = client.Unsubscribe(subscriptionOneID) + require.NoError(t, err) + + // Close the client + client.Close() + err = <-clientRunCh + require.NoError(t, err) + + requestLog := xEnv.Observer().FilterMessage("RouterOnRequest Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) }