diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ade8a59..7011cdd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: stable hooks: - id: black - language_version: python3.7 + language_version: python3.8 - repo: https://gitlab.com/pycqa/flake8 rev: "" hooks: diff --git a/python_graphql_client/graphql_client.py b/python_graphql_client/graphql_client.py index 2787971..a6164a5 100644 --- a/python_graphql_client/graphql_client.py +++ b/python_graphql_client/graphql_client.py @@ -47,7 +47,9 @@ def execute( ) result = requests.post( - self.endpoint, json=request_body, headers=self.__request_headers(headers), + self.endpoint, + json=request_body, + headers=self.__request_headers(headers), ) result.raise_for_status() @@ -92,7 +94,9 @@ async def subscribe( ) async with websockets.connect( - self.endpoint, subprotocols=["graphql-ws"] + self.endpoint, + subprotocols=["graphql-ws"], + extra_headers=self.__request_headers(headers), ) as websocket: await websocket.send(connection_init_message) await websocket.send(request_message) @@ -100,5 +104,7 @@ async def subscribe( response_body = json.loads(response_message) if response_body["type"] == "connection_ack": logging.info("the server accepted the connection") + elif response_body["type"] == "ka": + logging.info("the server sent a keep alive message") else: handle(response_body["payload"]) diff --git a/tests/test_graphql_client.py b/tests/test_graphql_client.py index a35980e..a3d408f 100644 --- a/tests/test_graphql_client.py +++ b/tests/test_graphql_client.py @@ -248,3 +248,62 @@ async def test_subscribe(self, mock_connect): call({"data": {"messageAdded": "two"}}), ] ) + + @patch("logging.info") + @patch("websockets.connect") + async def test_does_not_crash_with_keep_alive(self, mock_connect, mock_info): + """Subsribe a GraphQL subscription.""" + mock_websocket = mock_connect.return_value.__aenter__.return_value + mock_websocket.send = AsyncMock() + mock_websocket.__aiter__.return_value = [ + '{"type": "ka"}', + ] + + client = GraphqlClient(endpoint="ws://www.test-api.com/graphql") + query = """ + subscription onMessageAdded { + messageAdded + } + """ + + await client.subscribe(query=query, handle=MagicMock()) + + mock_info.assert_has_calls([call("the server sent a keep alive message")]) + + @patch("websockets.connect") + async def test_headers_passed_to_websocket_connect(self, mock_connect): + """Subsribe a GraphQL subscription.""" + mock_websocket = mock_connect.return_value.__aenter__.return_value + mock_websocket.send = AsyncMock() + mock_websocket.__aiter__.return_value = [ + '{"type": "data", "id": "1", "payload": {"data": {"messageAdded": "one"}}}', + ] + + expected_endpoint = "ws://www.test-api.com/graphql" + client = GraphqlClient(endpoint=expected_endpoint) + + query = """ + subscription onMessageAdded { + messageAdded + } + """ + + mock_handle = MagicMock() + + expected_headers = {"some": "header"} + + await client.subscribe( + query=query, handle=mock_handle, headers=expected_headers + ) + + mock_connect.assert_called_with( + expected_endpoint, + subprotocols=["graphql-ws"], + extra_headers=expected_headers, + ) + + mock_handle.assert_has_calls( + [ + call({"data": {"messageAdded": "one"}}), + ] + )