Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import datafusion as dfn
import numpy as np
import pyarrow as pa
import pytest
from datafusion import col, lit
from datafusion import functions as F
Expand All @@ -29,6 +30,8 @@ def _doctest_namespace(doctest_namespace: dict) -> None:
"""Add common imports to the doctest namespace."""
doctest_namespace["dfn"] = dfn
doctest_namespace["np"] = np
doctest_namespace["pa"] = pa
Comment thread
timsaucer marked this conversation as resolved.
doctest_namespace["col"] = col
doctest_namespace["lit"] = lit
doctest_namespace["F"] = F
doctest_namespace["ctx"] = dfn.SessionContext()
58 changes: 57 additions & 1 deletion crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use datafusion::execution::context::{
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::{
Expand Down Expand Up @@ -956,6 +956,39 @@ impl PySessionContext {
Ok(())
}

#[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
pub fn register_arrow(
    &self,
    name: &str,
    path: &str,
    schema: Option<PyArrowType<Schema>>,
    file_extension: &str,
    table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
    py: Python,
) -> PyDataFusionResult<()> {
    // Strip the PyArrowType wrappers so DataFusion sees plain DataTypes.
    let partition_cols: Vec<(String, DataType)> = table_partition_cols
        .into_iter()
        .map(|(col_name, col_type)| (col_name, col_type.0))
        .collect();

    let mut read_options = ArrowReadOptions::default().table_partition_cols(partition_cols);
    read_options.file_extension = file_extension;
    // Borrow the schema for the duration of the call; `schema` outlives it.
    read_options.schema = schema.as_ref().map(|s| &s.0);

    // Drive the async registration to completion on the Python side; the
    // inner `?` unwraps the DataFusion result, the outer one the wait itself.
    wait_for_future(py, self.ctx.register_arrow(name, path, read_options))??;
    Ok(())
}

pub fn register_batch(
&self,
name: &str,
batch: PyArrowType<RecordBatch>,
) -> PyDataFusionResult<()> {
self.ctx.register_batch(name, batch.0)?;
Ok(())
}

// Registers a PyArrow.Dataset
pub fn register_dataset(
&self,
Expand Down Expand Up @@ -1184,6 +1217,29 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}

#[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
pub fn read_arrow(
    &self,
    path: &str,
    schema: Option<PyArrowType<Schema>>,
    file_extension: &str,
    table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
    py: Python,
) -> PyDataFusionResult<PyDataFrame> {
    // Strip the PyArrowType wrappers so DataFusion sees plain DataTypes.
    let partition_cols: Vec<(String, DataType)> = table_partition_cols
        .into_iter()
        .map(|(col_name, col_type)| (col_name, col_type.0))
        .collect();

    let mut read_options = ArrowReadOptions::default().table_partition_cols(partition_cols);
    read_options.file_extension = file_extension;
    // Borrow the schema for the duration of the call; `schema` outlives it.
    read_options.schema = schema.as_ref().map(|s| &s.0);

    // Resolve the async read; the inner `?` unwraps the DataFusion result,
    // the outer one the wait itself.
    let frame = wait_for_future(py, self.ctx.read_arrow(path, read_options))??;
    Ok(PyDataFrame::new(frame))
}

pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
let session = self.clone().into_bound_py_any(table.py())?;
let table = PyTable::new(table, Some(session))?;
Expand Down
122 changes: 122 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,26 @@ def register_udtf(self, func: TableFunction) -> None:
"""Register a user defined table function."""
self.ctx.register_udtf(func._udtf)

def register_batch(self, name: str, batch: pa.RecordBatch) -> None:
    """Expose one :py:class:`pa.RecordBatch` to SQL under the given table name.

    Args:
        name: Name of the resultant table.
        batch: Record batch to register as a table.

    Examples:
        >>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
        >>> ctx.register_batch("batch_tbl", batch)
        >>> ctx.sql("SELECT * FROM batch_tbl").collect()[0].column(0)
        <pyarrow.lib.Int64Array object at ...>
        [
          1,
          2,
          3
        ]
    """
    # Delegate directly to the Rust-side context.
    self.ctx.register_batch(name, batch)

def register_record_batches(
self, name: str, partitions: list[list[pa.RecordBatch]]
) -> None:
Expand Down Expand Up @@ -1092,6 +1112,49 @@ def register_avro(
name, str(path), schema, file_extension, table_partition_cols
)

def register_arrow(
    self,
    name: str,
    path: str | pathlib.Path,
    schema: pa.Schema | None = None,
    file_extension: str = ".arrow",
    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> None:
    """Register an Arrow IPC file as a table.

    Once registered, the table can be referenced from SQL statements
    executed against this context.

    Args:
        name: Name of the table to register.
        path: Path to the Arrow IPC file.
        schema: The data source schema.
        file_extension: File extension to select.
        table_partition_cols: Partition columns.

    Examples:
        >>> import tempfile, os
        >>> table = pa.table({"x": [10, 20, 30]})
        >>> with tempfile.TemporaryDirectory() as tmpdir:
        ...     path = os.path.join(tmpdir, "data.arrow")
        ...     with pa.ipc.new_file(path, table.schema) as writer:
        ...         writer.write_table(table)
        ...     ctx.register_arrow("arrow_tbl", path)
        ...     ctx.sql("SELECT * FROM arrow_tbl").collect()[0].column(0)
        <pyarrow.lib.Int64Array object at ...>
        [
          10,
          20,
          30
        ]
    """
    # Normalize the optional partition-column list before crossing into Rust.
    partition_cols = _convert_table_partition_cols(table_partition_cols or [])
    self.ctx.register_arrow(name, str(path), schema, file_extension, partition_cols)

def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
"""Register a :py:class:`pa.dataset.Dataset` as a table.

Expand Down Expand Up @@ -1328,6 +1391,65 @@ def read_avro(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)

def read_arrow(
    self,
    path: str | pathlib.Path,
    schema: pa.Schema | None = None,
    file_extension: str = ".arrow",
    file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> DataFrame:
    """Create a :py:class:`DataFrame` for reading an Arrow IPC data source.

    Args:
        path: Path to the Arrow IPC file.
        schema: The data source schema.
        file_extension: File extension to select.
        file_partition_cols: Partition columns.

    Returns:
        DataFrame representation of the read Arrow IPC file.

    Examples:
        >>> import tempfile, os
        >>> table = pa.table({"a": [1, 2, 3]})
        >>> with tempfile.TemporaryDirectory() as tmpdir:
        ...     path = os.path.join(tmpdir, "data.arrow")
        ...     with pa.ipc.new_file(path, table.schema) as writer:
        ...         writer.write_table(table)
        ...     df = ctx.read_arrow(path)
        ...     df.collect()[0].column(0)
        <pyarrow.lib.Int64Array object at ...>
        [
          1,
          2,
          3
        ]
    """
    # Normalize the optional partition-column list before crossing into Rust.
    partition_cols = _convert_table_partition_cols(file_partition_cols or [])
    raw_df = self.ctx.read_arrow(str(path), schema, file_extension, partition_cols)
    return DataFrame(raw_df)
Comment thread
timsaucer marked this conversation as resolved.

def read_empty(self) -> DataFrame:
    """Create an empty :py:class:`DataFrame` with no columns or rows.

    This is an alias for :meth:`empty_table`.

    Returns:
        An empty DataFrame.

    Examples:
        >>> df = ctx.read_empty()
        >>> result = df.collect()
        >>> len(result)
        1
        >>> result[0].num_columns
        0
    """
    # Pure alias: forward to the canonical implementation.
    return self.empty_table()

def read_table(
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
) -> DataFrame:
Expand Down
62 changes: 62 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,68 @@ def test_read_avro(ctx):
assert avro_df is not None


def test_read_arrow(ctx, tmp_path):
    # Round-trip: write an Arrow IPC file, then load it through the context.
    expected = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    ipc_file = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(ipc_file), expected.schema) as writer:
        writer.write_table(expected)

    batches = ctx.read_arrow(str(ipc_file)).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])
    assert batches[0].column(1) == pa.array(["x", "y", "z"])

    # A pathlib.Path argument must behave the same as a string path.
    batches = ctx.read_arrow(ipc_file).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])


def test_read_empty(ctx):
    # Both the alias and the canonical method yield one batch with no columns.
    for frame in (ctx.read_empty(), ctx.empty_table()):
        batches = frame.collect()
        assert len(batches) == 1
        assert batches[0].num_columns == 0
Comment thread
timsaucer marked this conversation as resolved.


def test_register_arrow(ctx, tmp_path):
    # Write an Arrow IPC file, register it, and query it back via SQL.
    source = pa.table({"x": [10, 20, 30]})
    ipc_file = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(ipc_file), source.schema) as writer:
        writer.write_table(source)

    ctx.register_arrow("arrow_tbl", str(ipc_file))
    batches = ctx.sql("SELECT * FROM arrow_tbl").collect()
    assert batches[0].column(0) == pa.array([10, 20, 30])

    # A pathlib.Path argument must behave the same as a string path.
    ctx.register_arrow("arrow_tbl_path", ipc_file)
    batches = ctx.sql("SELECT * FROM arrow_tbl_path").collect()
    assert batches[0].column(0) == pa.array([10, 20, 30])


def test_register_batch(ctx):
    # A registered RecordBatch should be fully queryable through SQL.
    source = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
    ctx.register_batch("batch_tbl", source)
    batches = ctx.sql("SELECT * FROM batch_tbl").collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])
    assert batches[0].column(1) == pa.array([4, 5, 6])


def test_register_batch_empty(ctx):
    # Registering a zero-row batch must not fail, and queries return no rows.
    empty = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())})
    ctx.register_batch("empty_batch_tbl", empty)
    batches = ctx.sql("SELECT * FROM empty_batch_tbl").collect()
    assert batches[0].num_rows == 0


def test_create_sql_options():
SQLOptions()

Expand Down
Loading