Skip to content

Commit 583bf96

Browse files
authored
Fix: BytesReader flush ArrowFormatWriter and pass schema (#9)
1 parent c505c28 commit 583bf96

3 files changed

Lines changed: 20 additions & 24 deletions

File tree

java_based_implementation/api_impl.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
# limitations under the License.
1717
################################################################################
1818

19-
import itertools
20-
2119
from java_based_implementation.java_gateway import get_gateway
2220
from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
2321
from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
@@ -123,15 +121,15 @@ def __init__(self, j_table_read, j_row_type):
123121

124122
def create_reader(self, split: Split):
125123
self._j_bytes_reader.setSplit(split.to_j_split())
126-
batch_iterator = self._batch_generator()
127-
# to init arrow schema
128-
try:
129-
first_batch = next(batch_iterator)
130-
except StopIteration:
131-
return self._empty_batch_reader()
124+
# get schema
125+
if self._arrow_schema is None:
126+
schema_bytes = self._j_bytes_reader.serializeSchema()
127+
schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
128+
self._arrow_schema = schema_reader.schema
129+
schema_reader.close()
132130

133-
batches = itertools.chain((b for b in [first_batch]), batch_iterator)
134-
return RecordBatchReader.from_batches(self._arrow_schema, batches)
131+
batch_iterator = self._batch_generator()
132+
return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
135133

136134
def _batch_generator(self) -> Iterator[RecordBatch]:
137135
while True:
@@ -140,17 +138,8 @@ def _batch_generator(self) -> Iterator[RecordBatch]:
140138
break
141139
else:
142140
stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
143-
if self._arrow_schema is None:
144-
self._arrow_schema = stream_reader.schema
145141
yield from stream_reader
146142

147-
def _empty_batch_reader(self):
148-
import pyarrow as pa
149-
schema = pa.schema([])
150-
empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
151-
empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
152-
return empty_reader
153-
154143

155144
class BatchWriteBuilder(write_builder.BatchWriteBuilder):
156145

java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public class BytesReader {
4747

4848
public BytesReader(TableRead tableRead, RowType rowType) {
4949
this.tableRead = tableRead;
50-
this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
50+
this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
5151
}
5252

5353
public void setSplit(Split split) throws IOException {
@@ -56,6 +56,13 @@ public void setSplit(Split split) throws IOException {
5656
nextRow();
5757
}
5858

59+
public byte[] serializeSchema() {
60+
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
61+
ByteArrayOutputStream out = new ByteArrayOutputStream();
62+
ArrowUtils.serializeToIpc(vsr, out);
63+
return out.toByteArray();
64+
}
65+
5966
@Nullable
6067
public byte[] next() throws Exception {
6168
if (nextRow == null) {
@@ -68,6 +75,7 @@ public byte[] next() throws Exception {
6875
rowCount++;
6976
}
7077

78+
arrowFormatWriter.flush();
7179
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
7280
vsr.setRowCount(rowCount);
7381
ByteArrayOutputStream out = new ByteArrayOutputStream();

java_based_implementation/tests/test_write_and_read.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ def testReadEmptyPkTable(self):
9292
for split in splits
9393
for batch in table_read.create_reader(split)
9494
]
95-
result = pd.concat(data_frames)
96-
self.assertEqual(result.shape, (0, 0))
95+
self.assertEqual(len(data_frames), 0)
9796

9897
def testWriteReadAppendTable(self):
9998
create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
@@ -135,8 +134,8 @@ def testWriteReadAppendTable(self):
135134
]
136135
result = pd.concat(data_frames)
137136

138-
# check data
139-
pd.testing.assert_frame_equal(result, df)
137+
# check data (ignore index)
138+
pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))
140139

141140
def testWriteWrongSchema(self):
142141
create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)

0 commit comments

Comments (0)