diff --git a/router-tests/graphql_over_get_test.go b/router-tests/graphql_over_get_test.go index 38b84eab65..0a7f6527e7 100644 --- a/router-tests/graphql_over_get_test.go +++ b/router-tests/graphql_over_get_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "sync" + "sync/atomic" "testing" "time" @@ -387,4 +388,54 @@ func TestSubscriptionOverGET(t *testing.T) { wg.Wait() }) }) + + t.Run("subscription over sse send heartbeat", func(t *testing.T) { + t.Parallel() + + type currentTimePayload struct { + Data struct { + CurrentTime struct { + UnixTime float64 `json:"unixTime"` + Timestamp string `json:"timestamp"` + } `json:"currentTime"` + } `json:"data"` + } + + twentySecs := 20 * time.Second + config := config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + KeepAliveProbeInterval: &twentySecs, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithSubgraphTransportOptions(core.NewSubgraphTransportOptions(config)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var checkHeartbeat atomic.Uint32 + + go xEnv.GraphQLSubscriptionOverSSERaw(ctx, testenv.GraphQLRequest{ + OperationName: []byte(`CountEmp`), + Query: `subscription CountEmp { countEmp(max: 1, intervalMilliseconds: 6000) }`, + Header: map[string][]string{ + "Content-Type": {"application/json"}, + "Accept": {"text/event-stream"}, + "Connection": {"keep-alive"}, + "Cache-Control": {"no-cache"}, + }, + }, func(data string) { + if data == ":heartbeat\n" { + checkHeartbeat.Add(1) + } + }) + + require.Eventually(t, func() bool { + return checkHeartbeat.Load() > 0 + }, 20*time.Second, 6*time.Second, "did not receive any heartbeat") + }) + }) } diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 0cba295266..1df5914fa2 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2356,6 +2356,31 @@ func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query u return conn } +func (e *Environment) GraphQLSubscriptionOverSSERaw(ctx context.Context, request GraphQLRequest, handler func(data string)) { + req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request) + if err != nil { + e.t.Fatalf("could not create request: %s", err) + } + + resp, err := e.RouterClient.Do(req) + if err != nil { + e.t.Fatalf("could not make request: %s", err) + } + defer resp.Body.Close() + + require.Equal(e.t, "text/event-stream", resp.Header.Get("Content-Type")) + require.Equal(e.t, "no-cache", resp.Header.Get("Cache-Control")) + require.Equal(e.t, "keep-alive", resp.Header.Get("Connection")) + require.Equal(e.t, "no", resp.Header.Get("X-Accel-Buffering")) + + // Check for the correct response status code + if resp.StatusCode != http.StatusOK { + e.t.Fatalf("expected status code 200, got %d", resp.StatusCode) + } + + e.ReadSSE(ctx, resp.Body, true, handler) +} + func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request GraphQLRequest, handler func(data string)) { req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request) if err != nil { @@ -2378,7 +2403,7 @@ func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request Gr e.t.Fatalf("expected status code 200, got %d", resp.StatusCode) } - e.ReadSSE(ctx, resp.Body, handler) + e.ReadSSE(ctx, resp.Body, false, handler) } func (e *Environment) GraphQLSubscriptionOverSSEWithQueryParam(ctx context.Context, request GraphQLRequest, handler func(data string)) { @@ -2403,10 +2428,10 @@ func (e *Environment) GraphQLSubscriptionOverSSEWithQueryParam(ctx context.Conte e.t.Fatalf("expected status code 200, got %d", resp.StatusCode) } - e.ReadSSE(ctx, resp.Body, handler) + e.ReadSSE(ctx, resp.Body, false, handler) } -func (e *Environment) ReadSSE(ctx context.Context, body io.ReadCloser, handler func(data string)) { +func (e *Environment) ReadSSE(ctx context.Context, body io.ReadCloser, raw bool, handler func(data string)) { reader := bufio.NewReader(body) // Process incoming events @@ -2425,8 +2450,10 @@ func (e *Environment) ReadSSE(ctx context.Context, body io.ReadCloser, handler f return } - // SSE lines typically start with "event", "data", etc. - if strings.HasPrefix(line, "data: ") { + if raw { + handler(line) + } else if strings.HasPrefix(line, "data: ") { + // SSE lines typically start with "event", "data", etc. data := strings.TrimPrefix(line, "data: ") handler(data) }