Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions router-tests/graphql_over_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -387,4 +388,54 @@ func TestSubscriptionOverGET(t *testing.T) {
wg.Wait()
})
})

t.Run("subscription over sse send heartbeat", func(t *testing.T) {
t.Parallel()

type currentTimePayload struct {
Data struct {
CurrentTime struct {
UnixTime float64 `json:"unixTime"`
Timestamp string `json:"timestamp"`
} `json:"currentTime"`
} `json:"data"`
}

twentySecs := 20 * time.Second
config := config.TrafficShapingRules{
All: config.GlobalSubgraphRequestRule{
KeepAliveProbeInterval: &twentySecs,
},
}

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithSubgraphTransportOptions(core.NewSubgraphTransportOptions(config)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

var checkHeartbeat atomic.Uint32

go xEnv.GraphQLSubscriptionOverSSERaw(ctx, testenv.GraphQLRequest{
OperationName: []byte(`CountEmp`),
Query: `subscription CountEmp { countEmp(max: 1, intervalMilliseconds: 6000) }`,
Header: map[string][]string{
"Content-Type": {"application/json"},
"Accept": {"text/event-stream"},
"Connection": {"keep-alive"},
"Cache-Control": {"no-cache"},
},
}, func(data string) {
if data == ":heartbeat\n" {
checkHeartbeat.Add(1)
}
})

require.Eventually(t, func() bool {
return checkHeartbeat.Load() > 0
}, 20*time.Second, 6*time.Second, "did not receive any heartbeat")
})
})
}
37 changes: 32 additions & 5 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,31 @@ func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query u
return conn
}

func (e *Environment) GraphQLSubscriptionOverSSERaw(ctx context.Context, request GraphQLRequest, handler func(data string)) {
req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request)
if err != nil {
e.t.Fatalf("could not create request: %s", err)
}

resp, err := e.RouterClient.Do(req)
if err != nil {
e.t.Fatalf("could not make request: %s", err)
}
defer resp.Body.Close()

require.Equal(e.t, "text/event-stream", resp.Header.Get("Content-Type"))
require.Equal(e.t, "no-cache", resp.Header.Get("Cache-Control"))
require.Equal(e.t, "keep-alive", resp.Header.Get("Connection"))
require.Equal(e.t, "no", resp.Header.Get("X-Accel-Buffering"))

// Check for the correct response status code
if resp.StatusCode != http.StatusOK {
e.t.Fatalf("expected status code 200, got %d", resp.StatusCode)
}

e.ReadSSE(ctx, resp.Body, true, handler)
}

func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request GraphQLRequest, handler func(data string)) {
req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request)
if err != nil {
Expand All @@ -2378,7 +2403,7 @@ func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request Gr
e.t.Fatalf("expected status code 200, got %d", resp.StatusCode)
}

e.ReadSSE(ctx, resp.Body, handler)
e.ReadSSE(ctx, resp.Body, false, handler)
}

func (e *Environment) GraphQLSubscriptionOverSSEWithQueryParam(ctx context.Context, request GraphQLRequest, handler func(data string)) {
Expand All @@ -2403,10 +2428,10 @@ func (e *Environment) GraphQLSubscriptionOverSSEWithQueryParam(ctx context.Conte
e.t.Fatalf("expected status code 200, got %d", resp.StatusCode)
}

e.ReadSSE(ctx, resp.Body, handler)
e.ReadSSE(ctx, resp.Body, false, handler)
}

func (e *Environment) ReadSSE(ctx context.Context, body io.ReadCloser, handler func(data string)) {
func (e *Environment) ReadSSE(ctx context.Context, body io.ReadCloser, raw bool, handler func(data string)) {
reader := bufio.NewReader(body)

// Process incoming events
Expand All @@ -2425,8 +2450,10 @@ func (e *Environment) ReadSSE(ctx context.Context, body io.ReadCloser, handler f
return
}

// SSE lines typically start with "event", "data", etc.
if strings.HasPrefix(line, "data: ") {
if raw {
handler(line)
} else if strings.HasPrefix(line, "data: ") {
// SSE lines typically start with "event", "data", etc.
data := strings.TrimPrefix(line, "data: ")
handler(data)
}
Expand Down
Loading