Skip to content

Commit 7e49b66

Browse files
authored
Catalog supports creating tables from a pyarrow schema (#12)
1 parent a7b752a commit 7e49b66

8 files changed

Lines changed: 248 additions & 86 deletions

File tree

paimon_python_api/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .table_commit import BatchTableCommit
2525
from .table_write import BatchTableWrite
2626
from .write_builder import BatchWriteBuilder
27-
from .table import Table
27+
from .table import Table, Schema
2828
from .catalog import Catalog
2929

3030
__all__ = [
@@ -38,5 +38,6 @@
3838
'BatchTableWrite',
3939
'BatchWriteBuilder',
4040
'Table',
41+
'Schema',
4142
'Catalog'
4243
]

paimon_python_api/catalog.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
#################################################################################
1818

1919
from abc import ABC, abstractmethod
20-
from paimon_python_api import Table
20+
from typing import Optional
21+
from paimon_python_api import Table, Schema
2122

2223

2324
class Catalog(ABC):
@@ -34,3 +35,11 @@ def create(catalog_options: dict) -> 'Catalog':
3435
@abstractmethod
3536
def get_table(self, identifier: str) -> Table:
3637
"""Get paimon table identified by the given Identifier."""
38+
39+
@abstractmethod
40+
def create_database(self, name: str, ignore_if_exists: bool, properties: Optional[dict] = None):
    """Create a database called ``name``, optionally with ``properties``.

    Abstract method: concrete catalogs provide the behavior.
    ``ignore_if_exists`` presumably makes creation a no-op when the database
    already exists -- confirm against the implementing catalog.
    """
42+
43+
@abstractmethod
44+
def create_table(self, identifier: str, schema: Schema, ignore_if_exists: bool):
    """Create the table named by ``identifier`` as described by ``schema``.

    Abstract method: concrete catalogs provide the behavior. ``identifier``
    is a string such as ``'default.my_table'``. ``ignore_if_exists``
    presumably suppresses the failure when the table already exists --
    confirm against the implementing catalog.
    """

paimon_python_api/table.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
# limitations under the License.
1717
#################################################################################
1818

19+
import pyarrow as pa
20+
1921
from abc import ABC, abstractmethod
2022
from paimon_python_api import ReadBuilder, BatchWriteBuilder
23+
from typing import Optional, List
2124

2225

2326
class Table(ABC):
@@ -30,3 +33,19 @@ def new_read_builder(self) -> ReadBuilder:
3033
@abstractmethod
3134
def new_batch_write_builder(self) -> BatchWriteBuilder:
3235
"""Returns a builder for building batch table write and table commit."""
36+
37+
38+
class Schema:
    """Description of a table: a pyarrow schema plus optional metadata.

    Every constructor argument is stored unchanged on the same-named
    attribute; nothing is validated or copied here.
    """

    def __init__(
            self,
            pa_schema: pa.Schema,
            partition_keys: Optional[List[str]] = None,
            primary_keys: Optional[List[str]] = None,
            options: Optional[dict] = None,
            comment: Optional[str] = None):
        # Column names and types of the table, expressed as a pyarrow schema.
        self.pa_schema = pa_schema
        # Optional key lists; an omitted argument stays None (not []).
        self.partition_keys, self.primary_keys = partition_keys, primary_keys
        # Optional table options and free-text comment, kept exactly as given.
        self.options, self.comment = options, comment

paimon_python_java/pypaimon.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from paimon_python_java.java_gateway import get_gateway
2222
from paimon_python_java.util import java_utils, constants
2323
from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
24-
write_builder, table_write, commit_message, table_commit)
25-
from typing import List, Iterator
24+
write_builder, table_write, commit_message, table_commit, Schema)
25+
from typing import List, Iterator, Optional
2626

2727

2828
class Catalog(catalog.Catalog):
@@ -39,11 +39,20 @@ def create(catalog_options: dict) -> 'Catalog':
3939
return Catalog(j_catalog, catalog_options)
4040

4141
def get_table(self, identifier: str) -> 'Table':
42-
gateway = get_gateway()
43-
j_identifier = gateway.jvm.Identifier.fromString(identifier)
42+
j_identifier = java_utils.to_j_identifier(identifier)
4443
j_table = self._j_catalog.getTable(j_identifier)
4544
return Table(j_table, self._catalog_options)
4645

46+
def create_database(self, name: str, ignore_if_exists: bool, properties: Optional[dict] = None):
    """Create database ``name`` by delegating to the wrapped Java catalog.

    When ``properties`` is omitted (``None``) an empty dict is forwarded
    instead, so the Java side always receives a mapping.
    """
    effective_properties = {} if properties is None else properties
    self._j_catalog.createDatabase(name, ignore_if_exists, effective_properties)
50+
51+
def create_table(self, identifier: str, schema: Schema, ignore_if_exists: bool):
    """Create a table by delegating to the wrapped Java catalog.

    ``identifier`` and ``schema`` are first converted to their Java
    counterparts via ``java_utils`` (identifier first, then schema), and the
    results are handed to the Java catalog's ``createTable``.
    """
    self._j_catalog.createTable(
        java_utils.to_j_identifier(identifier),
        java_utils.to_paimon_schema(schema),
        ignore_if_exists)
55+
4756

4857
class Table(table.Table):
4958

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
import os
20+
import random
21+
import shutil
22+
import string
23+
import tempfile
24+
import pyarrow as pa
25+
import unittest
26+
27+
28+
from paimon_python_api import Schema
29+
from paimon_python_java import Catalog
30+
from paimon_python_java.util import java_utils
31+
from setup_utils import java_setuputils
32+
33+
34+
class DataTypesTest(unittest.TestCase):
    """Checks which Paimon type each pyarrow field type becomes on table creation."""

    @classmethod
    def setUpClass(cls):
        # Bring up the Java bridge, then open a catalog on a throwaway warehouse
        # directory and make sure the 'default' database exists for the tests.
        java_setuputils.setup_java_bridge()
        cls.warehouse = tempfile.mkdtemp()
        cls.simple_pa_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
        cls.catalog = Catalog.create({'warehouse': cls.warehouse})
        cls.catalog.create_database('default', False)

    @classmethod
    def tearDownClass(cls):
        # Undo the bridge setup and remove the temporary warehouse directory.
        java_setuputils.clean()
        if os.path.exists(cls.warehouse):
            shutil.rmtree(cls.warehouse)

    def test_int(self):
        """Signed integer widths map to TINYINT/SMALLINT/INT/BIGINT."""
        self._test_impl(
            pa.schema([
                ('_int8', pa.int8()),
                ('_int16', pa.int16()),
                ('_int32', pa.int32()),
                ('_int64', pa.int64())
            ]),
            ['TINYINT', 'SMALLINT', 'INT', 'BIGINT'])

    def test_float(self):
        """float16 and float32 are both expected as FLOAT; float64 as DOUBLE."""
        self._test_impl(
            pa.schema([
                ('_float16', pa.float16()),
                ('_float32', pa.float32()),
                ('_float64', pa.float64())
            ]),
            ['FLOAT', 'FLOAT', 'DOUBLE'])

    def test_string(self):
        """Both pyarrow string spellings are expected as STRING."""
        self._test_impl(
            pa.schema([
                ('_string', pa.string()),
                ('_utf8', pa.utf8())
            ]),
            ['STRING', 'STRING'])

    def test_bool(self):
        """bool_ is expected as BOOLEAN."""
        self._test_impl(pa.schema([('_bool', pa.bool_())]), ['BOOLEAN'])

    def test_null(self):
        """The pyarrow null type is expected as STRING."""
        self._test_impl(pa.schema([('_null', pa.null())]), ['STRING'])

    def test_unsupported_type(self):
        """Converting a list-typed field must raise ValueError with a precise message."""
        schema = Schema(pa.schema([('_array', pa.list_(pa.int32()))]))
        with self.assertRaises(ValueError) as ctx:
            java_utils.to_paimon_schema(schema)
        # Checked after the with-block so the assertion actually executes.
        self.assertEqual(
            str(ctx.exception),
            'Found unsupported data type list<item: int32> for field _array.')

    def _test_impl(self, pa_schema, expected_types):
        """Create a table from ``pa_schema`` and compare its Paimon field types.

        The table gets a random 10-letter name in the 'default' database so
        that the test methods sharing one catalog use different table names.
        """
        table_name = ''.join(random.choice(string.ascii_letters) for _ in range(10))
        identifier = 'default.' + table_name
        self.catalog.create_table(identifier, Schema(pa_schema), False)
        table = self.catalog.get_table(identifier)
        # Read the types back from the underlying Java table object.
        actual_types = [t.toString() for t in table._j_table.rowType().getFieldTypes()]
        self.assertListEqual(actual_types, expected_types)

paimon_python_java/tests/test_write_and_read.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
# limitations under the License.
1717
################################################################################
1818

19+
import os
20+
import shutil
1921
import tempfile
2022
import unittest
2123
import pandas as pd
2224
import pyarrow as pa
2325
import setup_utils.java_setuputils as setuputils
2426

25-
from paimon_python_java import Catalog, Table
27+
from paimon_python_api import Schema
28+
from paimon_python_java import Catalog
2629
from paimon_python_java.java_gateway import get_gateway
27-
from paimon_python_java.tests.utils import create_simple_table
2830
from paimon_python_java.util import java_utils
2931
from py4j.protocol import Py4JJavaError
3032

@@ -35,15 +37,23 @@ class TableWriteReadTest(unittest.TestCase):
3537
def setUpClass(cls):
3638
setuputils.setup_java_bridge()
3739
cls.warehouse = tempfile.mkdtemp()
40+
cls.simple_pa_schema = pa.schema([
41+
('f0', pa.int32()),
42+
('f1', pa.string())
43+
])
44+
cls.catalog = Catalog.create({'warehouse': cls.warehouse})
45+
cls.catalog.create_database('default', False)
3846

3947
@classmethod
4048
def tearDownClass(cls):
4149
setuputils.clean()
50+
if os.path.exists(cls.warehouse):
51+
shutil.rmtree(cls.warehouse)
4252

4353
def testReadEmptyAppendTable(self):
44-
create_simple_table(self.warehouse, 'default', 'empty_append_table', False)
45-
catalog = Catalog.create({'warehouse': self.warehouse})
46-
table = catalog.get_table('default.empty_append_table')
54+
schema = Schema(self.simple_pa_schema)
55+
self.catalog.create_table('default.empty_append_table', schema, False)
56+
table = self.catalog.get_table('default.empty_append_table')
4757

4858
# read data
4959
read_builder = table.new_read_builder()
@@ -53,7 +63,10 @@ def testReadEmptyAppendTable(self):
5363
self.assertTrue(len(splits) == 0)
5464

5565
def testReadEmptyPkTable(self):
56-
create_simple_table(self.warehouse, 'default', 'empty_pk_table', True)
66+
schema = Schema(self.simple_pa_schema, primary_keys=['f0'], options={'bucket': '1'})
67+
self.catalog.create_table('default.empty_pk_table', schema, False)
68+
69+
# use Java API to generate data
5770
gateway = get_gateway()
5871
j_catalog_context = java_utils.to_j_catalog_context({'warehouse': self.warehouse})
5972
j_catalog = gateway.jvm.CatalogFactory.createCatalog(j_catalog_context)
@@ -84,7 +97,7 @@ def testReadEmptyPkTable(self):
8497
table_commit.close()
8598

8699
# read data
87-
table = Table(j_table, {})
100+
table = self.catalog.get_table('default.empty_pk_table')
88101
read_builder = table.new_read_builder()
89102
table_scan = read_builder.new_scan()
90103
table_read = read_builder.new_read()
@@ -98,19 +111,17 @@ def testReadEmptyPkTable(self):
98111
self.assertEqual(len(data_frames), 0)
99112

100113
def testWriteReadAppendTable(self):
101-
create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
102-
103-
catalog = Catalog.create({'warehouse': self.warehouse})
104-
table = catalog.get_table('default.simple_append_table')
114+
schema = Schema(self.simple_pa_schema)
115+
self.catalog.create_table('default.simple_append_table', schema, False)
116+
table = self.catalog.get_table('default.simple_append_table')
105117

106118
# prepare data
107119
data = {
108120
'f0': [1, 2, 3],
109121
'f1': ['a', 'b', 'c'],
110122
}
111123
df = pd.DataFrame(data)
112-
df['f0'] = df['f0'].astype('int32')
113-
record_batch = pa.RecordBatch.from_pandas(df)
124+
record_batch = pa.RecordBatch.from_pandas(df, schema=self.simple_pa_schema)
114125

115126
# write and commit data
116127
write_builder = table.new_batch_write_builder()
@@ -138,13 +149,15 @@ def testWriteReadAppendTable(self):
138149
result = pd.concat(data_frames)
139150

140151
# check data (ignore index)
141-
pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))
152+
expected = df
153+
expected['f0'] = df['f0'].astype('int32')
154+
pd.testing.assert_frame_equal(
155+
result.reset_index(drop=True), expected.reset_index(drop=True))
142156

143157
def testWriteWrongSchema(self):
144-
create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)
145-
146-
catalog = Catalog.create({'warehouse': self.warehouse})
147-
table = catalog.get_table('default.test_wrong_schema')
158+
schema = Schema(self.simple_pa_schema)
159+
self.catalog.create_table('default.test_wrong_schema', schema, False)
160+
table = self.catalog.get_table('default.test_wrong_schema')
148161

149162
data = {
150163
'f0': [1, 2, 3],
@@ -155,7 +168,7 @@ def testWriteWrongSchema(self):
155168
('f0', pa.int64()),
156169
('f1', pa.string())
157170
])
158-
record_batch = pa.RecordBatch.from_pandas(df, schema)
171+
record_batch = pa.RecordBatch.from_pandas(df, schema=schema)
159172

160173
write_builder = table.new_batch_write_builder()
161174
table_write = write_builder.new_write()
@@ -169,16 +182,9 @@ def testWriteWrongSchema(self):
169182
\tInput schema is: [f0: Int(64, true), f1: Utf8]''')
170183

171184
def testCannotWriteDynamicBucketTable(self):
172-
create_simple_table(
173-
self.warehouse,
174-
'default',
175-
'test_dynamic_bucket',
176-
True,
177-
{'bucket': '-1'}
178-
)
179-
180-
catalog = Catalog.create({'warehouse': self.warehouse})
181-
table = catalog.get_table('default.test_dynamic_bucket')
185+
schema = Schema(self.simple_pa_schema, primary_keys=['f0'])
186+
self.catalog.create_table('default.test_dynamic_bucket', schema, False)
187+
table = self.catalog.get_table('default.test_dynamic_bucket')
182188

183189
with self.assertRaises(TypeError) as e:
184190
table.new_batch_write_builder()
@@ -187,9 +193,9 @@ def testCannotWriteDynamicBucketTable(self):
187193
"Doesn't support writing dynamic bucket or cross partition table.")
188194

189195
def testParallelRead(self):
190-
create_simple_table(self.warehouse, 'default', 'test_parallel_read', False)
191-
192196
catalog = Catalog.create({'warehouse': self.warehouse, 'max-workers': '2'})
197+
schema = Schema(self.simple_pa_schema)
198+
catalog.create_table('default.test_parallel_read', schema, False)
193199
table = catalog.get_table('default.test_parallel_read')
194200

195201
# prepare data
@@ -207,8 +213,7 @@ def testParallelRead(self):
207213
expected_data['f1'].append(str(i * 2))
208214

209215
df = pd.DataFrame(data)
210-
df['f0'] = df['f0'].astype('int32')
211-
record_batch = pa.RecordBatch.from_pandas(df)
216+
record_batch = pa.RecordBatch.from_pandas(df, schema=self.simple_pa_schema)
212217

213218
# write and commit data
214219
write_builder = table.new_batch_write_builder()

0 commit comments

Comments
 (0)