@@ -41,11 +41,12 @@ def to_df(r):
 
 
 class StreamingResult:
-    def __init__(self, c_result, conn, result_func):
+    def __init__(self, c_result, conn, result_func, supports_record_batch):
         self._result = c_result
         self._result_func = result_func
         self._conn = conn
         self._exhausted = False
+        self._supports_record_batch = supports_record_batch
 
     def fetch(self):
         """Fetch next chunk of streaming results"""
@@ -80,15 +81,182 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+        self.cancel()
+
+    def close(self):
+        self.cancel()
 
     def cancel(self):
-        self._exhausted = True
+        if not self._exhausted:
+            self._exhausted = True
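+            # the guard above makes cancel() idempotent, so __exit__ and close() can both delegate here safely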
+            try:
+                self._conn.streaming_cancel_query(self._result)
+            except Exception as e:
+                raise RuntimeError(f"Failed to cancel streaming query: {str(e)}") from e
 
-        try:
-            self._conn.streaming_cancel_query(self._result)
-        except Exception as e:
-            raise RuntimeError(f"Failed to cancel streaming query: {str(e)}") from e
+    def record_batch(self, rows_per_batch: int = 1000000) -> pa.RecordBatchReader:
+        """
+        Create a PyArrow RecordBatchReader from this StreamingResult.
+
+        This method requires that the StreamingResult was created with arrow format.
+        It wraps the streaming result with ChdbRecordBatchReader to provide efficient
+        batching with configurable batch sizes.
+
+        Args:
+            rows_per_batch (int): Number of rows per batch. Defaults to 1000000.
+
+        Returns:
+            pa.RecordBatchReader: PyArrow RecordBatchReader for efficient streaming
+
+        Raises:
+            ValueError: If the StreamingResult was not created with arrow format
+        """
+        if not self._supports_record_batch:
+            raise ValueError(
+                "record_batch() can only be used with arrow format. "
+                "Please use format='Arrow' when calling send_query."
+            )
+
+        chdb_reader = ChdbRecordBatchReader(self, rows_per_batch)
+        return pa.RecordBatchReader.from_batches(chdb_reader.schema(), chdb_reader)
+
+
+class ChdbRecordBatchReader:
+    """
+    A PyArrow RecordBatchReader wrapper for chdb StreamingResult.
+
+    This class provides an efficient way to read large result sets as PyArrow RecordBatches
+    with configurable batch sizes to optimize memory usage and performance.
+    """
+
+    def __init__(self, chdb_stream_result, batch_size_rows):
+        self._stream_result = chdb_stream_result
+        self._schema = None
+        self._closed = False
+        self._pending_batches = []
+        self._accumulator = []
+        self._batch_size_rows = batch_size_rows
+        self._current_rows = 0
+        self._first_batch = None
+        self._first_batch_consumed = True
+        self._schema = self.schema()
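+        # calling schema() here fetches the first chunk eagerly, so the schema is known before iteration begins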
+
+    def schema(self):
+        if self._schema is None:
+            # Get the first chunk to determine schema
+            chunk = self._stream_result.fetch()
+            if chunk is not None:
+                arrow_bytes = chunk.bytes()
+                reader = pa.RecordBatchFileReader(arrow_bytes)
+                self._schema = reader.schema
+
+                table = reader.read_all()
+                if table.num_rows > 0:
+                    batches = table.to_batches()
+                    self._first_batch = batches[0]
+                    if len(batches) > 1:
+                        self._pending_batches = batches[1:]
+                    self._first_batch_consumed = False
+                else:
+                    self._first_batch = None
+                    self._first_batch_consumed = True
+            else:
+                self._schema = pa.schema([])
+                self._first_batch = None
+                self._first_batch_consumed = True
+                self._closed = True
+        return self._schema
+
+    def read_next_batch(self):
+        if self._accumulator:
+            result = self._accumulator.pop(0)
+            return result
+
+        if self._closed:
+            raise StopIteration
+
+        while True:
+            batch = None
+
+            # 1. Return the first batch if not consumed yet
+            if not self._first_batch_consumed:
+                self._first_batch_consumed = True
+                batch = self._first_batch
+
+            # 2. Check pending batches from current chunk
+            elif self._pending_batches:
+                batch = self._pending_batches.pop(0)
+
+            # 3. Fetch new chunk from chdb stream
+            else:
+                chunk = self._stream_result.fetch()
+                if chunk is None:
+                    # No more data - return accumulated batches if any
+                    break
+
+                arrow_bytes = chunk.bytes()
+                if not arrow_bytes:
+                    continue
+
+                reader = pa.RecordBatchFileReader(arrow_bytes)
+                table = reader.read_all()
+
+                if table.num_rows > 0:
+                    batches = table.to_batches()
+                    batch = batches[0]
+                    if len(batches) > 1:
+                        self._pending_batches = batches[1:]
+                else:
+                    continue
+
+            # Process the batch if we got one
+            if batch is not None:
+                self._accumulator.append(batch)
+                self._current_rows += batch.num_rows
+
+                # If accumulated enough rows, return combined batch
+                if self._current_rows >= self._batch_size_rows:
+                    if len(self._accumulator) == 1:
+                        result = self._accumulator.pop(0)
+                    else:
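+                        # pa.concat_batches is not available on older pyarrow releases; fall back to returning one batch at a time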
+                        if hasattr(pa, 'concat_batches'):
+                            result = pa.concat_batches(self._accumulator)
+                            self._accumulator = []
+                        else:
+                            result = self._accumulator.pop(0)
+
+                    self._current_rows = 0
+                    return result
+
+        # End of stream - return any accumulated batches
+        if self._accumulator:
+            if len(self._accumulator) == 1:
+                result = self._accumulator.pop(0)
+            else:
+                if hasattr(pa, 'concat_batches'):
+                    result = pa.concat_batches(self._accumulator)
+                    self._accumulator = []
+                else:
+                    result = self._accumulator.pop(0)
+
+            self._current_rows = 0
+            self._closed = True
+            return result
+
+        # No more data
+        self._closed = True
+        raise StopIteration
+
+    def close(self):
+        if not self._closed:
+            self._stream_result.close()
+            self._closed = True
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.read_next_batch()
 
 
 class Connection:
@@ -112,12 +280,13 @@ def query(self, query: str, format: str = "CSV") -> Any:
 
     def send_query(self, query: str, format: str = "CSV") -> StreamingResult:
         lower_output_format = format.lower()
+        supports_record_batch = lower_output_format == "arrow"
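+        # record_batch() is only enabled when the caller asked for the plain "arrow" format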
         result_func = _process_result_format_funs.get(lower_output_format, lambda x: x)
         if lower_output_format in _arrow_format:
             format = "Arrow"
 
         c_stream_result = self._conn.send_query(query, format)
-        return StreamingResult(c_stream_result, self._conn, result_func)
+        return StreamingResult(c_stream_result, self._conn, result_func, supports_record_batch)
 
     def close(self) -> None:
         # print("close")
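
Usage sketch for the new API (a minimal, illustrative example; the ":memory:" connection string and the system.numbers query are assumptions for demonstration, not part of this diff):

    import chdb

    conn = chdb.connect(":memory:")

    # send_query must use the plain Arrow format, otherwise record_batch() raises ValueError
    with conn.send_query("SELECT number FROM system.numbers LIMIT 1000000", format="Arrow") as stream:
        reader = stream.record_batch(rows_per_batch=100000)  # pa.RecordBatchReader
        for batch in reader:
            print(batch.num_rows)
    # leaving the with-block now calls cancel(), which the new guard makes safe to invoke more than once

    conn.close()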