11import warnings
2- from typing import List , Set
2+ from typing import Any , List , Set , Tuple
33
44from pandas .core .frame import DataFrame
55
@@ -40,15 +40,18 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
4040 "{readConcurrency: $read_concurrency, parameters: { nodes: $nodes, relationships: $relationships }})"
4141 )
4242
43+ node_query , nodes = self ._node_query (node_dfs [0 ])
44+ relationship_query , relationships = self ._relationship_query (relationship_dfs [0 ])
45+
4346 self ._query_runner .run_query (
4447 query ,
4548 {
4649 "graph_name" : self ._graph_name ,
47- "node_query" : self . _node_query ( node_dfs [ 0 ]) ,
48- "relationship_query" : self . _relationship_query ( relationship_dfs [ 0 ]) ,
50+ "node_query" : node_query ,
51+ "relationship_query" : relationship_query ,
4952 "read_concurrency" : self ._concurrency ,
50- "nodes" : node_dfs [ 0 ]. values . tolist () ,
51- "relationships" : relationship_dfs [ 0 ]. values . tolist () ,
53+ "nodes" : nodes ,
54+ "relationships" : relationships ,
5255 },
5356 )
5457
@@ -59,23 +62,32 @@ def _is_enterprise(self) -> bool:
5962
6063 return license == "Licensed"
6164
62- def _node_query (self , node_df : DataFrame ) -> str :
65+ def _node_query (self , node_df : DataFrame ) -> Tuple [str , List [List [Any ]]]:
66+ node_list = node_df .values .tolist ()
6367 node_id_index = node_df .columns .get_loc ("nodeId" )
6468
6569 label_query = ""
6670 if "labels" in node_df .keys ():
6771 label_index = node_df .columns .get_loc ("labels" )
6872 label_query = f", node[{ label_index } ] as labels"
6973
74+ # Make sure every node has a list of labels
75+ for node in node_list :
76+ labels = node [label_index ]
77+ if isinstance (labels , List ):
78+ continue
79+ node [label_index ] = [labels ]
80+
7081 property_query = ""
7182 property_columns : Set [str ] = set (node_df .keys ()) - {"nodeId" , "labels" }
7283 if len (property_columns ) > 0 :
7384 property_queries = (f", node[{ node_df .columns .get_loc (col )} ] as { col } " for col in property_columns )
7485 property_query = "" .join (property_queries )
7586
76- return f"UNWIND $nodes as node RETURN node[{ node_id_index } ] as id{ label_query } { property_query } "
87+ return f"UNWIND $nodes as node RETURN node[{ node_id_index } ] as id{ label_query } { property_query } " , node_list
7788
78- def _relationship_query (self , rel_df : DataFrame ) -> str :
89+ def _relationship_query (self , rel_df : DataFrame ) -> Tuple [str , List [List [Any ]]]:
90+ rel_list = rel_df .values .tolist ()
7991 source_id_index = rel_df .columns .get_loc ("sourceNodeId" )
8092 target_id_index = rel_df .columns .get_loc ("targetNodeId" )
8193
@@ -93,5 +105,6 @@ def _relationship_query(self, rel_df: DataFrame) -> str:
93105 return (
94106 "UNWIND $relationships as relationship "
95107 f"RETURN relationship[{ source_id_index } ] as source, relationship[{ target_id_index } ] as target"
96- f"{ type_query } { property_query } "
108+ f"{ type_query } { property_query } " ,
109+ rel_list ,
97110 )
0 commit comments