Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions python/python/benchmarks/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,49 @@ def test_ann_with_refine(test_dataset, benchmark):
assert result.num_rows > 0


N_BATCH_QUERIES = 32


@pytest.mark.benchmark(group="query_ann_batch")
def test_batch_ann_search(test_dataset, benchmark):
# One request carrying all query vectors: the index shares each partition's
# scan across the batch (issue #6822).
queries = np.random.randn(N_BATCH_QUERIES, N_DIMS).astype(np.float32)
result = benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=queries,
k=100,
nprobes=10,
),
)
assert result.num_rows > 0


@pytest.mark.benchmark(group="query_ann_batch")
def test_repeated_single_ann_search(test_dataset, benchmark):
# Baseline: the same query vectors issued one indexed search at a time.
queries = np.random.randn(N_BATCH_QUERIES, N_DIMS).astype(np.float32)

def run():
for q in queries:
test_dataset.to_table(
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
k=100,
nprobes=10,
),
)

benchmark(run)


@pytest.mark.benchmark(group="query_ann")
@pytest.mark.parametrize("selectivity", (0.25, 0.75))
@pytest.mark.parametrize("prefilter", (False, True))
Expand Down
40 changes: 40 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,46 @@ def test_batch_flat_query_matches_repeated_single_queries(dataset, queries):
)


@pytest.mark.parametrize("metric", ["l2", "cosine"])
@pytest.mark.parametrize("query_count", [3, 1], ids=["three_queries", "single_query"])
def test_batch_indexed_query_matches_repeated_single_queries(
dataset, metric, query_count
):
indexed = dataset.create_index(
"vector",
index_type="IVF_PQ",
num_partitions=4,
num_sub_vectors=16,
metric=metric,
)
# Give the query vectors deliberately different magnitudes: a cosine batch
# that normalized the whole concatenated key by one global norm would scale
# them unequally and diverge from per-query single search.
scales = np.linspace(0.1, 10.0, query_count).reshape(-1, 1)
queries = (np.random.randn(query_count, 128) * scales).astype(np.float32)
k = 5

# nprobes covers every partition so the shared-scan batch path and the
# repeated single-query path search the same partitions deterministically.
nearest_kwargs = {"use_index": True, "nprobes": 4}
batch = indexed.to_table(
columns=["id"],
nearest={"column": "vector", "q": queries, "k": k, **nearest_kwargs},
)

assert batch.column_names == ["query_index", "id", "_distance"]
assert batch["query_index"].to_pylist() == sum(
[[i] * k for i in range(query_count)], []
)

_assert_batch_matches_single_queries(
indexed,
queries,
k=k,
nearest_kwargs=nearest_kwargs,
)


def _assert_batch_matches_single_queries(ds, queries, k, nearest_kwargs):
batch = ds.to_table(
columns=["id"],
Expand Down
48 changes: 48 additions & 0 deletions rust/lance-index/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,54 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index {
)))
}

/// Whether this index can search multiple query vectors in a single pass
/// via [`VectorIndex::search_partitions_batch`], reading each partition's
/// storage once and scoring every query that probes it.
///
/// Defaults to `false`; the batch scan planner falls back to repeated
/// single-query search for indices that return `false`.
fn supports_batch_partition_search(&self) -> bool {
false
}

/// Search a batch of query vectors against a shared set of partitions.
///
/// `query.key` holds all query vectors concatenated (length
/// `query_count * dim`, where `query_count == partitions_per_query.len()`).
/// `partitions_per_query[i]` / `q_c_dists_per_query[i]` are the ranked
/// partition ids and query-to-centroid distances for query `i`.
///
/// Returns one [RecordBatch] per query (in query order) with the
/// [`VECTOR_RESULT_SCHEMA`] (`_distance`, `_rowid`) and at most `query.k`
/// rows each. Implementations should read each distinct partition's storage
/// only once and score every query assigned to it against the loaded data.
///
/// The default implementation returns an error; callers must gate on
/// [`VectorIndex::supports_batch_partition_search`].
#[allow(clippy::too_many_arguments)]
async fn search_partitions_batch(
self: Arc<Self>,
query: Query,
partitions_per_query: Vec<Arc<UInt32Array>>,
q_c_dists_per_query: Vec<Arc<Float32Array>>,
pre_filter: Arc<dyn PreFilter>,
metrics: Arc<dyn MetricsCollector>,
) -> Result<Vec<RecordBatch>>
where
Self: 'static,
{
let _ = (
query,
partitions_per_query,
q_c_dists_per_query,
pre_filter,
metrics,
);
Err(Error::not_supported(
"batch partition search is not supported for this index",
))
}

/// If the index is loadable by IVF, so it can be a sub-index that
/// is loaded on demand by IVF.
fn is_loadable(&self) -> bool;
Expand Down
Loading