Skip to content

Commit 6e47a8f

Browse files
authored
Merge pull request #138 from soerenreichardt/stream-topology-endpoint
Stream topology endpoint
2 parents e825beb + b4626fa commit 6e47a8f

File tree

5 files changed

+60
-1
lines changed

5 files changed

+60
-1
lines changed

graphdatascience/graph/graph_entity_ops_runner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ def drop(
151151

152152
return self._query_runner.run_query(query, params).squeeze() # type: ignore
153153

154+
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
155+
def stream(self, G: Graph, relationship_types: List[str] = ["*"], **config: Any) -> DataFrame:
156+
self._namespace += ".stream"
157+
query = f"CALL {self._namespace}($graph_name, $relationship_types, $config)"
158+
159+
params = {"graph_name": G.name(), "relationship_types": relationship_types, "config": config}
160+
161+
return self._query_runner.run_query(query, params)
162+
154163

155164
class GraphPropertyRunner(CallerBase, UncallableNamespace, IllegalAttrChecker):
156165
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .arrow_graph_constructor import ArrowGraphConstructor
1111
from .graph_constructor import GraphConstructor
1212
from .query_runner import QueryRunner
13+
from graphdatascience.server_version.compatible_with import (
14+
IncompatibleServerVersionError,
15+
)
1316
from graphdatascience.server_version.server_version import ServerVersion
1417

1518

@@ -107,6 +110,19 @@ def run_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Data
107110
endpoint,
108111
{"relationship_properties": property_names, "relationship_types": relationship_types},
109112
)
113+
elif "gds.beta.graph.relationships.stream" in query:
114+
graph_name = params["graph_name"]
115+
relationship_types = params["relationship_types"]
116+
117+
if self._server_version < new_endpoint_server_version:
118+
raise IncompatibleServerVersionError(
119+
f"The call gds.beta.graph.relationships.stream with parameters {params} via Arrow requires GDS "
120+
f"server version >= 2.2.0. The current version is {self._server_version}"
121+
)
122+
else:
123+
endpoint = "gds.beta.graph.relationships.stream"
124+
125+
return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
110126

111127
return self._fallback_query_runner.run_query(query, params)
112128

graphdatascience/tests/integration/test_coverage.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
IGNORED_ENDPOINTS = {
99
"gds.alpha.graph.removeGraphProperty", # Exists but undocumented for GDS 2.1
1010
"gds.alpha.graph.streamGraphProperty", # Exists but undocumented for GDS 2.1
11-
"gds.beta.graph.relationships.stream", # FIXME: Add support
1211
"gds.alpha.pipeline.linkPrediction.addMLP",
1312
"gds.alpha.pipeline.linkPrediction.addRandomForest",
1413
"gds.beta.pipeline.linkPrediction.addFeature",

graphdatascience/tests/integration/test_graph_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,32 @@ def test_graph_relationshipProperties_stream_without_arrow_separate_property_col
549549
assert {e for e in result["relY"]} == {5, 6, 7}
550550

551551

552+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
553+
def test_graph_relationships_stream_without_arrow(gds_without_arrow: GraphDataScience) -> None:
554+
G, _ = gds_without_arrow.graph.project(GRAPH_NAME, "*", "REL")
555+
556+
result = gds_without_arrow.beta.graph.relationships.stream(G, ["REL"])
557+
558+
expected = gds_without_arrow.run_cypher("MATCH (n)-[REL]->(m) RETURN id(n) AS src_id, id(m) AS trg_id")
559+
560+
assert list(result.keys()) == ["sourceNodeId", "targetNodeId", "relationshipType"]
561+
assert {e for e in result["sourceNodeId"]} == {i for i in expected["src_id"]}
562+
assert {e for e in result["targetNodeId"]} == {i for i in expected["trg_id"]}
563+
564+
565+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
566+
def test_graph_relationships_stream_with_arrow(gds: GraphDataScience) -> None:
567+
G, _ = gds.graph.project(GRAPH_NAME, "*", "REL")
568+
569+
result = gds.beta.graph.relationships.stream(G, ["REL"])
570+
571+
expected = gds.run_cypher("MATCH (n)-[REL]->(m) RETURN id(n) AS src_id, id(m) AS trg_id")
572+
573+
assert list(result.keys()) == ["sourceNodeId", "targetNodeId", "relationshipType"]
574+
assert {e for e in result["sourceNodeId"]} == {i for i in expected["src_id"]}
575+
assert {e for e in result["targetNodeId"]} == {i for i in expected["trg_id"]}
576+
577+
552578
def test_graph_writeNodeProperties(gds: GraphDataScience) -> None:
553579
G, _ = gds.graph.project(GRAPH_NAME, "*", "*")
554580

graphdatascience/tests/unit/test_graph_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,15 @@ def test_graph_property_drop(runner: CollectingQueryRunner, gds: GraphDataScienc
533533
assert runner.last_params() == {"graph_name": "g", "graph_property": "prop", "config": {}}
534534

535535

536+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 2, 0)])
537+
def test_graph_relationships_stream(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
538+
G, _ = gds.graph.project("g", "*", "*")
539+
540+
gds.beta.graph.relationships.stream(G, ["REL_A"])
541+
assert runner.last_query() == "CALL gds.beta.graph.relationships.stream($graph_name, $relationship_types, $config)"
542+
assert runner.last_params() == {"graph_name": "g", "relationship_types": ["REL_A"], "config": {}}
543+
544+
536545
def test_graph_generate(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
537546
gds.beta.graph.generate("g", 1337, 42, orientation="NATURAL")
538547

0 commit comments

Comments
 (0)