 # See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
+import os

 # pypaimon.api implementation based on Java code & py4j lib

                             table_commit, Schema, predicate)
 from typing import List, Iterator, Optional, Any, TYPE_CHECKING

+from pypaimon.pynative.common.exception import PyNativeNotImplementedError
+from pypaimon.pynative.common.predicate import PyNativePredicate
+from pypaimon.pynative.common.row.internal_row import InternalRow
+from pypaimon.pynative.util.reader_converter import ReaderConverter
+
 if TYPE_CHECKING:
     import ray
     from duckdb.duckdb import DuckDBPyConnection
@@ -72,7 +78,12 @@ def __init__(self, j_table, catalog_options: dict):

     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
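+        # Pass the table's primary keys (if any) through to the pynative read path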
+        if self._j_table.primaryKeys().isEmpty():
+            primary_keys = None
+        else:
+            primary_keys = [str(key) for key in self._j_table.primaryKeys()]
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options,
+                           primary_keys)

     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
@@ -82,16 +93,21 @@ def new_batch_write_builder(self) -> 'BatchWriteBuilder':

 class ReadBuilder(read_builder.ReadBuilder):

-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, primary_keys: List[str]):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
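+        # Cache filter/projection/primary-key info so new_read() can hand it to the pynative TableRead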
+        self._primary_keys = primary_keys
+        self._predicate = None
+        self._projection = None

     def with_filter(self, predicate: 'Predicate'):
+        self._predicate = predicate
         self._j_read_builder.withFilter(predicate.to_j_predicate())
         return self

     def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        self._projection = projection
         field_names = list(map(lambda field: field.name(), self._j_row_type.getFields()))
         int_projection = list(map(lambda p: field_names.index(p), projection))
         gateway = get_gateway()
@@ -111,7 +127,8 @@ def new_scan(self) -> 'TableScan':

     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
+        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options,
+                         self._predicate, self._projection, self._primary_keys)

     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -185,14 +202,29 @@ def file_paths(self) -> List[str]:

 class TableRead(table_read.TableRead):

-    def __init__(self, j_table_read, j_read_type, catalog_options):
+    def __init__(self, j_table_read, j_read_type, catalog_options, predicate, projection,
+                 primary_keys: List[str]):
+        self._j_table_read = j_table_read
+        self._j_read_type = j_read_type
+        self._catalog_options = catalog_options
+
+        self._predicate = predicate
+        self._projection = projection
+        self._primary_keys = primary_keys
+
         self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
         self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
             j_table_read, j_read_type, TableRead._get_max_workers(catalog_options))

-    def to_arrow(self, splits):
-        record_batch_reader = self.to_arrow_batch_reader(splits)
-        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+    def to_arrow(self, splits: List['Split']) -> pa.Table:
+        record_generator = self.to_record_generator(splits)
+
+        # If needed, set the environment variable named by constants.IMPLEMENT_MODE to 'py4j'
+        # to force the py4j reader.
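+        # (e.g. os.environ[constants.IMPLEMENT_MODE] = 'py4j' before calling to_arrow)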
+        if os.environ.get(constants.IMPLEMENT_MODE, '') != 'py4j' and record_generator is not None:
+            return TableRead._iterator_to_pyarrow_table(record_generator, self._arrow_schema)
+        else:
+            record_batch_reader = self.to_arrow_batch_reader(splits)
+            return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)

     def to_arrow_batch_reader(self, splits):
         j_splits = list(map(lambda s: s.to_j_split(), splits))
@@ -219,6 +251,60 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":

         return ray.data.from_arrow(self.to_arrow(splits))

+    def to_record_generator(self, splits: List['Split']) -> Optional[Iterator[Any]]:
+        """
+        Returns a generator for iterating over the records in the table.
+        Returns None if the pynative reader is not available.
+        """
+        try:
+            j_splits = list(s.to_j_split() for s in splits)
+            j_reader = get_gateway().jvm.InvocationUtil.createReader(self._j_table_read, j_splits)
+            converter = ReaderConverter(self._predicate, self._projection, self._primary_keys)
+            pynative_reader = converter.convert_java_reader(j_reader)
+
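+            # Drain each batch record by record, release it, then fetch the next batch;
+            # the underlying reader is closed once the generator is exhausted or closed.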
+            def _record_generator():
+                try:
+                    batch = pynative_reader.read_batch()
+                    while batch is not None:
+                        record = batch.next()
+                        while record is not None:
+                            yield record
+                            record = batch.next()
+                        batch.release_batch()
+                        batch = pynative_reader.read_batch()
+                finally:
+                    pynative_reader.close()
+
+            return _record_generator()
+
+        except PyNativeNotImplementedError as e:
+            print(f"Generating pynative reader failed, will use py4j reader instead, "
+                  f"error message: {str(e)}")
+            return None
+
+    @staticmethod
+    def _iterator_to_pyarrow_table(record_generator, arrow_schema):
+        """
+        Converts a record generator into a pyarrow Table using the provided Arrow schema.
+        """
+        record_batches = []
+        current_batch = []
+        batch_size = 1024  # can be adjusted as needed
+
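+        # Build plain-dict rows keyed by field name and flush them into RecordBatches of batch_size rows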
+        for record in record_generator:
+            record_dict = {field: record.get_field(i) for i, field in enumerate(arrow_schema.names)}
+            current_batch.append(record_dict)
+            if len(current_batch) >= batch_size:
+                batch = pa.RecordBatch.from_pylist(current_batch, schema=arrow_schema)
+                record_batches.append(batch)
+                current_batch = []
+
+        if current_batch:
+            batch = pa.RecordBatch.from_pylist(current_batch, schema=arrow_schema)
+            record_batches.append(batch)
+
+        return pa.Table.from_batches(record_batches, schema=arrow_schema)
+
     @staticmethod
     def _get_max_workers(catalog_options):
         # default is sequential
@@ -317,12 +403,16 @@ def close(self):

 class Predicate(predicate.Predicate):

-    def __init__(self, j_predicate_bytes):
+    def __init__(self, py_predicate: PyNativePredicate, j_predicate_bytes):
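+        # Keep both a Python-native predicate (for the pynative reader) and the serialized Java predicate (for py4j)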
+        self.py_predicate = py_predicate
         self._j_predicate_bytes = j_predicate_bytes

     def to_j_predicate(self):
         return deserialize_java_object(self._j_predicate_bytes)

+    def test(self, record: InternalRow) -> bool:
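+        # Evaluate the predicate against a single row entirely in Python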
+        return self.py_predicate.test(record)
+

 class PredicateBuilder(predicate.PredicateBuilder):

@@ -350,7 +440,8 @@ def _build(self, method: str, field: str, literals: Optional[List[Any]] = None):
             index,
             literals
         )
-        return Predicate(serialize_java_object(j_predicate))
+        return Predicate(PyNativePredicate(method, index, field, literals),
+                         serialize_java_object(j_predicate))

     def equal(self, field: str, literal: Any) -> Predicate:
         return self._build('equal', field, [literal])
@@ -396,11 +487,13 @@ def between(self, field: str, included_lower_bound: Any, included_upper_bound: A
         return self._build('between', field, [included_lower_bound, included_upper_bound])

     def and_predicates(self, predicates: List[Predicate]) -> Predicate:
-        predicates = list(map(lambda p: p.to_j_predicate(), predicates))
-        j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(predicates)
-        return Predicate(serialize_java_object(j_predicate))
+        j_predicates = list(map(lambda p: p.to_j_predicate(), predicates))
+        j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(j_predicates)
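+        # The PyNativePredicate composes the original Python Predicate objects, not the Java ones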
+        return Predicate(PyNativePredicate('and', None, None, predicates),
+                         serialize_java_object(j_predicate))

     def or_predicates(self, predicates: List[Predicate]) -> Predicate:
-        predicates = list(map(lambda p: p.to_j_predicate(), predicates))
-        j_predicate = get_gateway().jvm.PredicationUtil.buildOr(predicates)
-        return Predicate(serialize_java_object(j_predicate))
+        j_predicates = list(map(lambda p: p.to_j_predicate(), predicates))
+        j_predicate = get_gateway().jvm.PredicationUtil.buildOr(j_predicates)
+        return Predicate(PyNativePredicate('or', None, None, predicates),
+                         serialize_java_object(j_predicate))
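
A minimal usage sketch of the read path touched above. It assumes a table handle already obtained from a catalog, a schema containing columns f0 and f1, and the pre-existing new_scan().plan().splits() flow of the pypaimon API; the column names and the scan/plan calls are illustrative, not introduced by this change.

    read_builder = table.new_read_builder()
    predicate_builder = read_builder.new_predicate_builder()

    # The filter and projection are pushed down to the Java side and also cached
    # on the Python side for the pynative reader.
    predicate = predicate_builder.and_predicates([
        predicate_builder.equal('f0', 1),
        predicate_builder.between('f1', 0, 100),
    ])
    read_builder.with_filter(predicate)
    read_builder.with_projection(['f0', 'f1'])

    table_read = read_builder.new_read()
    splits = read_builder.new_scan().plan().splits()

    # Preferred path: pure-Python record iteration; returns None when the
    # pynative reader cannot handle this table yet.
    records = table_read.to_record_generator(splits)
    if records is not None:
        for record in records:
            ...  # each record is an InternalRow
    else:
        # py4j-backed fallback; to_arrow() also applies this fallback internally.
        arrow_table = table_read.to_arrow(splits)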