diff --git a/src_cpp/include/py_connection.h b/src_cpp/include/py_connection.h index 61a971e..8b6f8fe 100644 --- a/src_cpp/include/py_connection.h +++ b/src_cpp/include/py_connection.h @@ -57,7 +57,7 @@ class PyConnection { py::object arrowTable); std::unique_ptr createArrowRelTable(const std::string& tableName, py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName, - const std::string& layout, py::object indptrTable); + const std::string& layout, py::object indptrTable, const std::string& dstColName = "to"); std::unique_ptr dropArrowTable(const std::string& tableName); static Value transformPythonValue(const py::handle& val); diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp index ffa1a56..660196c 100644 --- a/src_cpp/py_connection.cpp +++ b/src_cpp/py_connection.cpp @@ -55,7 +55,8 @@ void PyConnection::initialize(py::handle& m) { py::arg("arrow_table")) .def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"), py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"), - py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none()) + py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none(), + py::arg("dst_col_name") = "to") .def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name")); PyDateTime_IMPORT; } @@ -1070,7 +1071,7 @@ std::unique_ptr PyConnection::createArrowTable(const std::string& std::unique_ptr PyConnection::createArrowRelTable(const std::string& tableName, py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName, - const std::string& layout, py::object indptrTable) { + const std::string& layout, py::object indptrTable, const std::string& dstColName) { auto& stateRef = refState(); py::gil_scoped_acquire acquire; @@ -1097,7 +1098,7 @@ std::unique_ptr PyConnection::createArrowRelTable(const std::stri keepAlive.append(exportedIndptr.keepAlive); result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName, srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays), - std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays)); + std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays), dstColName); } else { throw RuntimeException("Arrow relationship table layout must be FLAT or CSR"); } diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py index 9d16f0f..c20c4b3 100644 --- a/src_py/_lbug_capi.py +++ b/src_py/_lbug_capi.py @@ -343,6 +343,7 @@ def _setup_signatures() -> None: ctypes.POINTER(_ArrowSchema), ctypes.POINTER(_ArrowArray), ctypes.c_uint64, + ctypes.c_char_p, ctypes.POINTER(_LbugQueryResult), ] _LIB.lbug_connection_create_arrow_rel_table_csr.restype = ctypes.c_int @@ -2341,6 +2342,7 @@ def create_arrow_rel_table( dst_table_name: str, layout: Any = "FLAT", indptr_dataframe: Any | None = None, + dst_col_name: str = "to", ) -> QueryResult: layout_value = getattr(layout, "value", layout) layout_value = str(layout_value).upper() @@ -2385,6 +2387,7 @@ def create_arrow_rel_table( ctypes.byref(indptr_schema), indptr_arrays, len(indptr_arrays), + dst_col_name.encode("utf-8"), ctypes.byref(result), ) if state != _LBUG_SUCCESS and not result._query_result: diff --git a/src_py/connection.py b/src_py/connection.py index 80822ad..2c9f9a0 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -887,6 +887,7 @@ def create_arrow_rel_table( dst_table_name: str, layout: ArrowRelTableLayout | str = ArrowRelTableLayout.FLAT, indptr_dataframe: Any | None = None, + dst_col_name: str = "to", ) -> QueryResult: """ Create an Arrow memory-backed relationship table from a DataFrame. @@ -908,13 +909,17 @@ def create_arrow_rel_table( layout : ArrowRelTableLayout | str Relationship layout. FLAT expects ``dataframe`` to contain ``from`` and ``to`` endpoint columns. CSR expects ``dataframe`` to contain a - ``to`` destination offset column plus properties, and - ``indptr_dataframe`` to contain source offsets. + destination offset column (named by ``dst_col_name``) plus + properties, and ``indptr_dataframe`` to contain source offsets. indptr_dataframe : Any | None A pandas DataFrame, polars DataFrame, or PyArrow table containing CSR source offsets. Required when ``layout`` is CSR. + dst_col_name : str + Name of the destination offset column in the CSR indices table. + Defaults to ``"to"``. Only used when ``layout`` is CSR. + Returns ------- QueryResult @@ -936,6 +941,7 @@ def create_arrow_rel_table( dst_table_name, layout_value, indptr_dataframe, + dst_col_name, ) except NotImplementedError: py_connection = self._get_pybind_connection() @@ -949,6 +955,7 @@ def create_arrow_rel_table( dst_table_name, layout_value, indptr_dataframe, + dst_col_name, ) if not query_result_internal.isSuccess(): raise RuntimeError(query_result_internal.getErrorMessage()) diff --git a/test/test_arrow_memory_backed_table.py b/test/test_arrow_memory_backed_table.py index d2edd0f..1cefa92 100644 --- a/test/test_arrow_memory_backed_table.py +++ b/test/test_arrow_memory_backed_table.py @@ -401,7 +401,63 @@ def test_arrow_memory_backed_csr_arrow_rel_table(conn_db_empty: ConnDB) -> None: conn.drop_arrow_table("arrow_csr_people") -def test_arrow_memory_backed_native_node_and_arrow_rel_table( +def test_arrow_memory_backed_csr_rel_table_custom_dst_col( + conn_db_empty: ConnDB, +) -> None: + """Test Arrow CSR relationship table with a custom destination column name.""" + conn, _ = conn_db_empty + + import ladybug as lb + + pa = pytest.importorskip("pyarrow") + + people = pa.Table.from_arrays( + [pa.array([1, 2, 3], type=pa.int64())], + names=["id"], + ) + conn.create_arrow_table("csr_custom_dst_people", people) + + # Use "destination" instead of the default "to" + indices = pa.Table.from_arrays( + [ + pa.array([1, 2, 2], type=pa.uint64()), + pa.array([10, 20, 30], type=pa.int64()), + ], + names=["destination", "weight"], + ) + indptr = pa.Table.from_arrays( + [pa.array([0, 2, 3, 3], type=pa.uint64())], + names=["indptr"], + ) + conn.create_arrow_rel_table( + "csr_custom_dst_knows", + indices, + "csr_custom_dst_people", + "csr_custom_dst_people", + layout=lb.ArrowRelTableLayout.CSR, + indptr_dataframe=indptr, + dst_col_name="destination", + ) + + result = conn.execute( + "MATCH (a:csr_custom_dst_people)-[r:csr_custom_dst_knows]->(b:csr_custom_dst_people) " + "RETURN a.id, b.id, r.weight ORDER BY a.id, b.id" + ) + rows = [] + while result.has_next(): + rows.append(result.get_next()) + + assert rows == [ + [1, 2, 10], + [1, 3, 20], + [2, 3, 30], + ] + + conn.drop_arrow_table("csr_custom_dst_knows") + conn.drop_arrow_table("csr_custom_dst_people") + + +def test_arrow_memory_backed_rel_table_over_native_node_tables( conn_db_empty: ConnDB, ) -> None: """Test an Arrow memory-backed relationship over native node tables."""