1616# limitations under the License.
1717################################################################################
1818
19+ import itertools
20+
1921from java_based_implementation .java_gateway import get_gateway
2022from java_based_implementation .util .java_utils import to_j_catalog_context
2123from paimon_python_api import (catalog , table , read_builder , table_scan , split , table_read ,
2224 write_builder , table_write , commit_message , table_commit )
23- from pyarrow import RecordBatchReader , RecordBatch
24- from typing import List
25+ from pyarrow import (RecordBatch , BufferOutputStream , RecordBatchStreamWriter ,
26+ RecordBatchStreamReader , BufferReader , RecordBatchReader )
27+ from typing import List , Iterator
2528
2629
2730class Catalog (catalog .Catalog ):
@@ -49,18 +52,19 @@ def __init__(self, j_table):
4952 self ._j_table = j_table
5053
def new_read_builder(self) -> 'ReadBuilder':
    """Create a ReadBuilder for the wrapped Java table.

    The Java-side builder is obtained through ``InvocationUtil.getReadBuilder``
    on the gateway JVM, and is handed to the Python wrapper together with
    the table's Java row type.
    """
    j_table = self._j_table
    invocation_util = get_gateway().jvm.InvocationUtil
    return ReadBuilder(invocation_util.getReadBuilder(j_table), j_table.rowType())
5457
def new_batch_write_builder(self) -> 'BatchWriteBuilder':
    """Create a BatchWriteBuilder for the wrapped Java table.

    The Java-side builder is obtained through
    ``InvocationUtil.getBatchWriteBuilder`` on the gateway JVM, and is handed
    to the Python wrapper together with the table's Java row type.
    """
    j_table = self._j_table
    invocation_util = get_gateway().jvm.InvocationUtil
    return BatchWriteBuilder(invocation_util.getBatchWriteBuilder(j_table), j_table.rowType())
5861
5962
6063class ReadBuilder (read_builder .ReadBuilder ):
6164
def __init__(self, j_read_builder, j_row_type):
    """Store the Java read builder and the Java row type of the table.

    :param j_read_builder: Java read builder object this wrapper delegates to.
    :param j_row_type: Java row type, passed on to the TableRead created
        by ``new_read``.
    """
    self._j_row_type = j_row_type
    self._j_read_builder = j_read_builder
6468
6569 def with_projection (self , projection : List [List [int ]]) -> 'ReadBuilder' :
6670 self ._j_read_builder .withProjection (projection )
@@ -75,8 +79,8 @@ def new_scan(self) -> 'TableScan':
7579 return TableScan (j_table_scan )
7680
def new_read(self) -> 'TableRead':
    """Create a TableRead from the Java builder's ``newRead()`` result.

    The stored Java row type is forwarded to the TableRead as well.
    """
    return TableRead(self._j_read_builder.newRead(), self._j_row_type)
8084
8185
8286class TableScan (table_scan .TableScan ):
@@ -110,23 +114,56 @@ def to_j_split(self):
110114
class TableRead(table_read.TableRead):
    """Read splits as Arrow record batches through a Java bytes reader.

    The Java side hands over serialized Arrow IPC stream bytes; each chunk is
    decoded with ``RecordBatchStreamReader`` and its batches are exposed
    through a ``RecordBatchReader``.
    """

    def __init__(self, j_table_read, j_row_type):
        """Wrap a Java table read and create its bytes reader.

        :param j_table_read: Java table read object.
        :param j_row_type: Java row type, passed to
            ``InvocationUtil.createBytesReader``.
        """
        self._j_table_read = j_table_read
        self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createBytesReader(
            j_table_read, j_row_type)
        # Learned lazily from the first decoded Arrow stream and then reused.
        self._arrow_schema = None

    def create_reader(self, split: Split) -> RecordBatchReader:
        """Return a RecordBatchReader over the batches of ``split``.

        The remaining batches are pulled lazily from the shared Java bytes
        reader while the returned reader is consumed.
        NOTE(review): presumably ``setSplit`` re-targets that shared reader, so
        a previously returned reader should be drained first - confirm.

        :param split: the split to read.
        :return: a reader over the split's batches; an empty reader when the
            split produces no batch.
        """
        self._j_bytes_reader.setSplit(split.to_j_split())
        batch_iterator = self._batch_generator()
        # Pull the first batch eagerly: decoding it is what initializes
        # self._arrow_schema, which from_batches() needs up front.
        try:
            first_batch = next(batch_iterator)
        except StopIteration:
            return self._empty_batch_reader()

        batches = itertools.chain([first_batch], batch_iterator)
        return RecordBatchReader.from_batches(self._arrow_schema, batches)

    def _batch_generator(self) -> Iterator[RecordBatch]:
        """Yield every record batch of the current split.

        Keeps asking the Java bytes reader for the next chunk until it returns
        None; each chunk is decoded as an Arrow IPC stream.
        """
        while True:
            next_bytes = self._j_bytes_reader.next()
            if next_bytes is None:
                return
            stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
            if self._arrow_schema is None:
                self._arrow_schema = stream_reader.schema
            yield from stream_reader

    def _empty_batch_reader(self) -> RecordBatchReader:
        """Return a reader that carries no rows.

        When the Arrow schema is already known, the empty reader carries that
        schema so results of empty and non-empty splits stay compatible.
        Otherwise it falls back to a field-less schema with a single empty
        batch.
        """
        if self._arrow_schema is not None:
            return RecordBatchReader.from_batches(self._arrow_schema, [])

        import pyarrow as pa
        schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
        return pa.RecordBatchReader.from_batches(schema, [empty_batch])
116152
117153
118154class BatchWriteBuilder (write_builder .BatchWriteBuilder ):
119155
def __init__(self, j_batch_write_builder, j_row_type):
    """Store the Java batch write builder and the Java row type of the table.

    :param j_batch_write_builder: Java batch write builder this wrapper
        delegates to.
    :param j_row_type: Java row type, passed on to the BatchTableWrite
        created by ``new_write``.
    """
    self._j_row_type = j_row_type
    self._j_batch_write_builder = j_batch_write_builder
122159
def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
    """Forward ``static_partition`` to the Java builder's ``withOverwrite``.

    :param static_partition: partition spec passed to the Java builder as is.
    :return: this builder, so calls can be chained.
    """
    j_builder = self._j_batch_write_builder
    j_builder.withOverwrite(static_partition)
    return self
126163
def new_write(self) -> 'BatchTableWrite':
    """Create a BatchTableWrite from the Java builder's ``newWrite()`` result.

    The stored Java row type is forwarded to the BatchTableWrite as well.
    """
    return BatchTableWrite(self._j_batch_write_builder.newWrite(), self._j_row_type)
130167
131168 def new_commit (self ) -> 'BatchTableCommit' :
132169 j_batch_table_commit = self ._j_batch_write_builder .newCommit ()
@@ -135,17 +172,27 @@ def new_commit(self) -> 'BatchTableCommit':
135172
class BatchTableWrite(table_write.BatchTableWrite):
    """Write Arrow record batches to a table through a Java bytes writer.

    Each batch is serialized as an Arrow IPC stream and the resulting bytes
    are handed to the Java side.
    """

    def __init__(self, j_batch_table_write, j_row_type):
        """Wrap a Java batch table write and create its bytes writer.

        :param j_batch_table_write: Java batch table write object.
        :param j_row_type: Java row type, passed to
            ``InvocationUtil.createBytesWriter``.
        """
        self._j_batch_table_write = j_batch_table_write
        self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
            j_batch_table_write, j_row_type)

    def write(self, record_batch: RecordBatch):
        """Serialize ``record_batch`` as an Arrow IPC stream and send it to Java.

        :param record_batch: the batch to write.
        """
        stream = BufferOutputStream()
        # Leaving the with-block closes the writer, which finalizes the
        # stream; no explicit close() is needed.
        with RecordBatchStreamWriter(stream, record_batch.schema) as writer:
            writer.write(record_batch)
        arrow_bytes = stream.getvalue().to_pybytes()
        self._j_bytes_writer.write(arrow_bytes)

    def prepare_commit(self) -> List['CommitMessage']:
        """Return the Java commit messages wrapped as CommitMessage objects."""
        j_commit_messages = self._j_batch_table_write.prepareCommit()
        return [CommitMessage(cm) for cm in j_commit_messages]

    def close(self):
        """Close the Java table write and then the Java bytes writer.

        The bytes writer is closed even if closing the table write raises, so
        it is not leaked; the original exception still propagates.
        """
        try:
            self._j_batch_table_write.close()
        finally:
            self._j_bytes_writer.close()
195+
149196
150197class CommitMessage (commit_message .CommitMessage ):
151198
@@ -164,3 +211,6 @@ def __init__(self, j_batch_table_commit):
def commit(self, commit_messages: List[CommitMessage]):
    """Commit the given messages through the Java batch table commit.

    Each message is converted with ``to_j_commit_message()`` before the
    resulting list is passed to the Java ``commit`` call.

    :param commit_messages: the commit messages to commit.
    """
    j_commit_messages = [message.to_j_commit_message() for message in commit_messages]
    self._j_batch_table_commit.commit(j_commit_messages)
214+
def close(self):
    """Close the wrapped Java batch table commit."""
    j_commit = self._j_batch_table_commit
    j_commit.close()
0 commit comments