Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class PyConnection {
py::object arrowTable);
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
const std::string& layout, py::object indptrTable);
const std::string& layout, py::object indptrTable, const std::string& dstColName = "to");
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);

static Value transformPythonValue(const py::handle& val);
Expand Down
7 changes: 4 additions & 3 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ void PyConnection::initialize(py::handle& m) {
py::arg("arrow_table"))
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"),
py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none())
py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none(),
py::arg("dst_col_name") = "to")
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
PyDateTime_IMPORT;
}
Expand Down Expand Up @@ -1070,7 +1071,7 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string&

std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::string& tableName,
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
const std::string& layout, py::object indptrTable) {
const std::string& layout, py::object indptrTable, const std::string& dstColName) {
auto& stateRef = refState();
py::gil_scoped_acquire acquire;

Expand All @@ -1097,7 +1098,7 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::stri
keepAlive.append(exportedIndptr.keepAlive);
result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName,
srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays),
std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays));
std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays), dstColName);
} else {
throw RuntimeException("Arrow relationship table layout must be FLAT or CSR");
}
Expand Down
3 changes: 3 additions & 0 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _setup_signatures() -> None:
ctypes.POINTER(_ArrowSchema),
ctypes.POINTER(_ArrowArray),
ctypes.c_uint64,
ctypes.c_char_p,
ctypes.POINTER(_LbugQueryResult),
]
_LIB.lbug_connection_create_arrow_rel_table_csr.restype = ctypes.c_int
Expand Down Expand Up @@ -2341,6 +2342,7 @@ def create_arrow_rel_table(
dst_table_name: str,
layout: Any = "FLAT",
indptr_dataframe: Any | None = None,
dst_col_name: str = "to",
) -> QueryResult:
layout_value = getattr(layout, "value", layout)
layout_value = str(layout_value).upper()
Expand Down Expand Up @@ -2385,6 +2387,7 @@ def create_arrow_rel_table(
ctypes.byref(indptr_schema),
indptr_arrays,
len(indptr_arrays),
dst_col_name.encode("utf-8"),
ctypes.byref(result),
)
if state != _LBUG_SUCCESS and not result._query_result:
Expand Down
11 changes: 9 additions & 2 deletions src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ def create_arrow_rel_table(
dst_table_name: str,
layout: ArrowRelTableLayout | str = ArrowRelTableLayout.FLAT,
indptr_dataframe: Any | None = None,
dst_col_name: str = "to",
) -> QueryResult:
"""
Create an Arrow memory-backed relationship table from a DataFrame.
Expand All @@ -908,13 +909,17 @@ def create_arrow_rel_table(
layout : ArrowRelTableLayout | str
Relationship layout. FLAT expects ``dataframe`` to contain ``from``
and ``to`` endpoint columns. CSR expects ``dataframe`` to contain a
``to`` destination offset column plus properties, and
``indptr_dataframe`` to contain source offsets.
destination offset column (named by ``dst_col_name``) plus
properties, and ``indptr_dataframe`` to contain source offsets.

indptr_dataframe : Any | None
A pandas DataFrame, polars DataFrame, or PyArrow table containing
CSR source offsets. Required when ``layout`` is CSR.

dst_col_name : str
Name of the destination offset column in the CSR indices table.
Defaults to ``"to"``. Only used when ``layout`` is CSR.

Returns
-------
QueryResult
Expand All @@ -936,6 +941,7 @@ def create_arrow_rel_table(
dst_table_name,
layout_value,
indptr_dataframe,
dst_col_name,
)
except NotImplementedError:
py_connection = self._get_pybind_connection()
Expand All @@ -949,6 +955,7 @@ def create_arrow_rel_table(
dst_table_name,
layout_value,
indptr_dataframe,
dst_col_name,
)
if not query_result_internal.isSuccess():
raise RuntimeError(query_result_internal.getErrorMessage())
Expand Down
58 changes: 57 additions & 1 deletion test/test_arrow_memory_backed_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,63 @@ def test_arrow_memory_backed_csr_arrow_rel_table(conn_db_empty: ConnDB) -> None:
conn.drop_arrow_table("arrow_csr_people")


def test_arrow_memory_backed_native_node_and_arrow_rel_table(
def test_arrow_memory_backed_csr_rel_table_custom_dst_col(
    conn_db_empty: ConnDB,
) -> None:
    """
    Build a CSR Arrow relationship table whose destination offset column is
    named ``destination`` (passed via ``dst_col_name``) rather than the
    default ``to``, then check that traversals return the expected edges
    and edge properties.
    """
    connection, _db = conn_db_empty

    import ladybug as lb

    pa = pytest.importorskip("pyarrow")

    node_table = "csr_custom_dst_people"
    rel_table = "csr_custom_dst_knows"

    # Node table with three rows; ids are 1, 2, 3.
    id_column = pa.array([1, 2, 3], type=pa.int64())
    node_data = pa.Table.from_arrays([id_column], names=["id"])
    connection.create_arrow_table(node_table, node_data)

    # CSR indices table: the destination offsets live in a column called
    # "destination" instead of the default "to"; "weight" is an edge property.
    dst_offsets = pa.array([1, 2, 2], type=pa.uint64())
    weights = pa.array([10, 20, 30], type=pa.int64())
    edge_data = pa.Table.from_arrays(
        [dst_offsets, weights],
        names=["destination", "weight"],
    )

    # CSR row pointers: row 0 owns entries [0, 2), row 1 owns [2, 3),
    # row 2 owns nothing.
    row_pointers = pa.Table.from_arrays(
        [pa.array([0, 2, 3, 3], type=pa.uint64())],
        names=["indptr"],
    )

    connection.create_arrow_rel_table(
        rel_table,
        edge_data,
        node_table,
        node_table,
        layout=lb.ArrowRelTableLayout.CSR,
        indptr_dataframe=row_pointers,
        dst_col_name="destination",
    )

    query_result = connection.execute(
        "MATCH (a:csr_custom_dst_people)-[r:csr_custom_dst_knows]->(b:csr_custom_dst_people) "
        "RETURN a.id, b.id, r.weight ORDER BY a.id, b.id"
    )

    # Drain the result set into a plain list for comparison.
    fetched = []
    while True:
        if not query_result.has_next():
            break
        fetched.append(query_result.get_next())

    expected = [
        [1, 2, 10],
        [1, 3, 20],
        [2, 3, 30],
    ]
    assert fetched == expected

    # Drop the relationship table before the node table it references.
    connection.drop_arrow_table(rel_table)
    connection.drop_arrow_table(node_table)


def test_arrow_memory_backed_rel_table_over_native_node_tables(
conn_db_empty: ConnDB,
) -> None:
"""Test an Arrow memory-backed relationship over native node tables."""
Expand Down
Loading