Skip to content

Commit 5f18615

Browse files
authored
TableRead supports reading multiple splits in parallel (#11)
1 parent 165ad9b commit 5f18615

9 files changed

Lines changed: 286 additions & 64 deletions

File tree

java_based_implementation/api_impl.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
################################################################################
1818

1919
from java_based_implementation.java_gateway import get_gateway
20+
from java_based_implementation.util import constants
2021
from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
2122
from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
2223
write_builder, table_write, commit_message, table_commit)
@@ -27,31 +28,33 @@
2728

2829
class Catalog(catalog.Catalog):
    """Py4J-backed implementation of the Catalog API.

    Wraps a Java ``Catalog`` handle and retains the catalog options so they
    can be propagated to tables (and ultimately readers) created from it.
    """

    def __init__(self, j_catalog, catalog_options: dict):
        # Java-side catalog object (py4j proxy).
        self._j_catalog = j_catalog
        # Kept so reader-related options (e.g. max workers) reach TableRead.
        self._catalog_options = catalog_options

    @staticmethod
    def create(catalog_options: dict) -> 'Catalog':
        """Create a Catalog from option key/value pairs."""
        j_catalog_context = to_j_catalog_context(catalog_options)
        gateway = get_gateway()
        j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
        return Catalog(j_catalog, catalog_options)

    def get_table(self, identifier: str) -> 'Table':
        """Load a table by its string identifier (e.g. ``"db.tbl"``)."""
        gateway = get_gateway()
        j_identifier = gateway.jvm.Identifier.fromString(identifier)
        j_table = self._j_catalog.getTable(j_identifier)
        return Table(j_table, self._catalog_options)
4547

4648

4749
class Table(table.Table):
    """Py4J-backed implementation of the Table API."""

    def __init__(self, j_table, catalog_options: dict):
        # Java-side table object (py4j proxy).
        self._j_table = j_table
        # Propagated to ReadBuilder/TableRead for parallel-read settings.
        self._catalog_options = catalog_options

    def new_read_builder(self) -> 'ReadBuilder':
        """Create a ReadBuilder for scanning and reading this table."""
        j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
5558

5659
def new_batch_write_builder(self) -> 'BatchWriteBuilder':
5760
check_batch_write(self._j_table)
@@ -61,9 +64,10 @@ def new_batch_write_builder(self) -> 'BatchWriteBuilder':
6164

6265
class ReadBuilder(read_builder.ReadBuilder):
    """Py4J-backed implementation of the ReadBuilder API."""

    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
        # Java-side read builder (py4j proxy).
        self._j_read_builder = j_read_builder
        # Row type is needed later to build the Arrow schema on the Java side.
        self._j_row_type = j_row_type
        # Carries reader options such as the max worker count.
        self._catalog_options = catalog_options
6771

6872
def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
6973
self._j_read_builder.withProjection(projection)
@@ -79,7 +83,7 @@ def new_scan(self) -> 'TableScan':
7983

8084
def new_read(self) -> 'TableRead':
    """Build a TableRead bound to this builder's row type and options."""
    return TableRead(
        self._j_read_builder.newRead(),
        self._j_row_type,
        self._catalog_options,
    )
8387

8488

8589
class TableScan(table_scan.TableScan):
@@ -113,24 +117,40 @@ def to_j_split(self):
113117

114118
class TableRead(table_read.TableRead):
    """Py4J-backed TableRead that reads splits into Arrow record batches.

    Reading may run in parallel on the Java side; the worker count is taken
    from the catalog option ``constants.MAX_WORKERS`` (default 1, i.e.
    sequential reading).
    """

    def __init__(self, j_table_read, j_row_type, catalog_options):
        # Java-side TableRead (py4j proxy).
        self._j_table_read = j_table_read
        self._j_row_type = j_row_type
        self._catalog_options = catalog_options
        # Created lazily in _init() so the thread count is only read when a
        # read actually starts.
        self._j_bytes_reader = None
        self._arrow_schema = None

    def create_reader(self, splits):
        """Return a pyarrow RecordBatchReader over the given splits."""
        self._init()
        j_splits = list(map(lambda s: s.to_j_split(), splits))
        self._j_bytes_reader.setSplits(j_splits)
        batch_iterator = self._batch_generator()
        return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)

    def _init(self):
        # Lazily create the Java-side parallel bytes reader.
        if self._j_bytes_reader is None:
            # get thread num
            max_workers = self._catalog_options.get(constants.MAX_WORKERS)
            if max_workers is None:
                # default is sequential
                max_workers = 1
            else:
                max_workers = int(max_workers)
            if max_workers <= 0:
                raise ValueError("max_workers must be greater than 0")
            self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
                self._j_table_read, self._j_row_type, max_workers)

        # Deserialize the Arrow schema once; it is identical for all splits.
        if self._arrow_schema is None:
            schema_bytes = self._j_bytes_reader.serializeSchema()
            schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
            self._arrow_schema = schema_reader.schema
            schema_reader.close()
130153

131-
batch_iterator = self._batch_generator()
132-
return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
133-
134154
def _batch_generator(self) -> Iterator[RecordBatch]:
135155
while True:
136156
next_bytes = self._j_bytes_reader.next()

java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/InvocationUtil.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ public static ReadBuilder getReadBuilder(Table table) {
3939
return table.newReadBuilder();
4040
}
4141

42-
public static BytesReader createBytesReader(TableRead tableRead, RowType rowType) {
43-
return new BytesReader(tableRead, rowType);
42+
public static ParallelBytesReader createParallelBytesReader(
43+
TableRead tableRead, RowType rowType, int threadNum) {
44+
return new ParallelBytesReader(tableRead, rowType, threadNum);
4445
}
4546

4647
public static BytesWriter createBytesWriter(TableWrite tableWrite, RowType rowType) {
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.paimon.python;
20+
21+
import org.apache.paimon.arrow.ArrowUtils;
22+
import org.apache.paimon.arrow.vector.ArrowFormatWriter;
23+
import org.apache.paimon.data.InternalRow;
24+
import org.apache.paimon.reader.RecordReader;
25+
import org.apache.paimon.reader.RecordReaderIterator;
26+
import org.apache.paimon.table.source.Split;
27+
import org.apache.paimon.table.source.TableRead;
28+
import org.apache.paimon.types.RowType;
29+
import org.apache.paimon.utils.ThreadPoolUtils;
30+
31+
import org.apache.paimon.shade.guava30.com.google.common.collect.Iterators;
32+
33+
import org.apache.arrow.vector.VectorSchemaRoot;
34+
35+
import javax.annotation.Nullable;
36+
37+
import java.io.ByteArrayOutputStream;
38+
import java.io.IOException;
39+
import java.util.ArrayDeque;
40+
import java.util.ArrayList;
41+
import java.util.Collection;
42+
import java.util.Iterator;
43+
import java.util.List;
44+
import java.util.Queue;
45+
import java.util.concurrent.ConcurrentLinkedQueue;
46+
import java.util.concurrent.ExecutionException;
47+
import java.util.concurrent.ExecutorService;
48+
import java.util.concurrent.Future;
49+
import java.util.concurrent.ThreadPoolExecutor;
50+
import java.util.function.Function;
51+
52+
/** Parallely read Arrow bytes from multiple splits. */
53+
public class ParallelBytesReader {
54+
55+
private static final String THREAD_NAME_PREFIX = "PARALLEL_SPLITS_READER";
56+
private static final int DEFAULT_WRITE_BATCH_SIZE = 1024;
57+
58+
private final TableRead tableRead;
59+
private final RowType rowType;
60+
private final int threadNum;
61+
62+
private final ConcurrentLinkedQueue<RecordReaderIterator<InternalRow>> iterators;
63+
private final ConcurrentLinkedQueue<ArrowFormatWriter> arrowFormatWriters;
64+
65+
private ThreadPoolExecutor executor;
66+
private Iterator<byte[]> bytesIterator;
67+
68+
public ParallelBytesReader(TableRead tableRead, RowType rowType, int threadNum) {
69+
this.tableRead = tableRead;
70+
this.rowType = rowType;
71+
this.threadNum = threadNum;
72+
this.iterators = new ConcurrentLinkedQueue<>();
73+
this.arrowFormatWriters = new ConcurrentLinkedQueue<>();
74+
}
75+
76+
public void setSplits(List<Split> splits) {
77+
bytesIterator = randomlyExecute(getExecutor(), makeProcessor(), splits);
78+
}
79+
80+
public byte[] serializeSchema() {
81+
ArrowFormatWriter arrowFormatWriter = newWriter();
82+
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
83+
ByteArrayOutputStream out = new ByteArrayOutputStream();
84+
ArrowUtils.serializeToIpc(vsr, out);
85+
arrowFormatWriter.close();
86+
return out.toByteArray();
87+
}
88+
89+
@Nullable
90+
public byte[] next() {
91+
if (bytesIterator.hasNext()) {
92+
return bytesIterator.next();
93+
} else {
94+
closeResources();
95+
return null;
96+
}
97+
}
98+
99+
private ThreadPoolExecutor getExecutor() {
100+
if (executor == null) {
101+
executor = ThreadPoolUtils.createCachedThreadPool(threadNum, THREAD_NAME_PREFIX);
102+
}
103+
return executor;
104+
}
105+
106+
private Function<Split, Iterator<byte[]>> makeProcessor() {
107+
return split -> {
108+
try {
109+
RecordReader<InternalRow> recordReader = tableRead.createReader(split);
110+
RecordReaderIterator<InternalRow> iterator =
111+
new RecordReaderIterator<>(recordReader);
112+
iterators.add(iterator);
113+
ArrowFormatWriter arrowFormatWriter = newWriter();
114+
arrowFormatWriters.add(arrowFormatWriter);
115+
return new RecordBytesIterator(iterator, arrowFormatWriter);
116+
} catch (IOException e) {
117+
throw new RuntimeException(e);
118+
}
119+
};
120+
}
121+
122+
private <U, T> Iterator<T> randomlyExecute(
123+
ExecutorService executor, Function<U, Iterator<T>> processor, Collection<U> input) {
124+
List<Future<Iterator<T>>> futures = new ArrayList<>(input.size());
125+
for (U u : input) {
126+
futures.add(executor.submit(() -> processor.apply(u)));
127+
}
128+
return futuresToIterIter(futures);
129+
}
130+
131+
private <T> Iterator<T> futuresToIterIter(List<Future<Iterator<T>>> futures) {
132+
final Queue<Future<Iterator<T>>> queue = new ArrayDeque<>(futures);
133+
return Iterators.concat(
134+
new Iterator<Iterator<T>>() {
135+
public boolean hasNext() {
136+
return !queue.isEmpty();
137+
}
138+
139+
public Iterator<T> next() {
140+
try {
141+
return queue.poll().get();
142+
} catch (InterruptedException e) {
143+
Thread.currentThread().interrupt();
144+
throw new RuntimeException(e);
145+
} catch (ExecutionException e) {
146+
throw new RuntimeException(e);
147+
}
148+
}
149+
});
150+
}
151+
152+
private void closeResources() {
153+
for (RecordReaderIterator<InternalRow> iterator : iterators) {
154+
try {
155+
iterator.close();
156+
} catch (Exception e) {
157+
throw new RuntimeException(e);
158+
}
159+
}
160+
iterators.clear();
161+
162+
for (ArrowFormatWriter arrowFormatWriter : arrowFormatWriters) {
163+
arrowFormatWriter.close();
164+
}
165+
arrowFormatWriters.clear();
166+
}
167+
168+
private ArrowFormatWriter newWriter() {
169+
return new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
170+
}
171+
}

java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java renamed to java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/RecordBytesIterator.java

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,54 +21,33 @@
2121
import org.apache.paimon.arrow.ArrowUtils;
2222
import org.apache.paimon.arrow.vector.ArrowFormatWriter;
2323
import org.apache.paimon.data.InternalRow;
24-
import org.apache.paimon.reader.RecordReader;
2524
import org.apache.paimon.reader.RecordReaderIterator;
26-
import org.apache.paimon.table.source.Split;
27-
import org.apache.paimon.table.source.TableRead;
28-
import org.apache.paimon.types.RowType;
2925

3026
import org.apache.arrow.vector.VectorSchemaRoot;
3127

32-
import javax.annotation.Nullable;
33-
3428
import java.io.ByteArrayOutputStream;
35-
import java.io.IOException;
36-
37-
/** Read Arrow bytes from split. */
38-
public class BytesReader {
29+
import java.util.Iterator;
3930

40-
private static final int DEFAULT_WRITE_BATCH_SIZE = 2048;
31+
/** Cast a {@link RecordReaderIterator} to bytes iterator. */
32+
public class RecordBytesIterator implements Iterator<byte[]> {
4133

42-
private final TableRead tableRead;
34+
private final RecordReaderIterator<InternalRow> iterator;
4335
private final ArrowFormatWriter arrowFormatWriter;
4436

45-
private RecordReaderIterator<InternalRow> iterator;
4637
private InternalRow nextRow;
4738

48-
public BytesReader(TableRead tableRead, RowType rowType) {
49-
this.tableRead = tableRead;
50-
this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
51-
}
52-
53-
public void setSplit(Split split) throws IOException {
54-
RecordReader<InternalRow> recordReader = tableRead.createReader(split);
55-
iterator = new RecordReaderIterator<InternalRow>(recordReader);
39+
public RecordBytesIterator(
40+
RecordReaderIterator<InternalRow> iterator, ArrowFormatWriter arrowFormatWriter) {
41+
this.iterator = iterator;
42+
this.arrowFormatWriter = arrowFormatWriter;
5643
nextRow();
5744
}
5845

59-
public byte[] serializeSchema() {
60-
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
61-
ByteArrayOutputStream out = new ByteArrayOutputStream();
62-
ArrowUtils.serializeToIpc(vsr, out);
63-
return out.toByteArray();
46+
public boolean hasNext() {
47+
return nextRow != null;
6448
}
6549

66-
@Nullable
67-
public byte[] next() throws Exception {
68-
if (nextRow == null) {
69-
return null;
70-
}
71-
50+
public byte[] next() {
7251
int rowCount = 0;
7352
while (nextRow != null && arrowFormatWriter.write(nextRow)) {
7453
nextRow();
@@ -80,11 +59,6 @@ public byte[] next() throws Exception {
8059
vsr.setRowCount(rowCount);
8160
ByteArrayOutputStream out = new ByteArrayOutputStream();
8261
ArrowUtils.serializeToIpc(vsr, out);
83-
if (nextRow == null) {
84-
// close resource
85-
arrowFormatWriter.close();
86-
iterator.close();
87-
}
8862
return out.toByteArray();
8963
}
9064

0 commit comments

Comments
 (0)