Commit cfca7e3

Added Citus example [skip ci]
1 parent d9928c1 commit cfca7e3

2 files changed

Lines changed: 51 additions & 0 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ Or check out some examples:
 - [Topic modeling](examples/topic_modeling.rb) with tomoto.rb
 - [User-based recommendations](examples/disco_user_recs.rb) with Disco
 - [Item-based recommendations](examples/disco_item_recs.rb) with Disco
+- [Horizontal scaling](examples/citus.rb) with Citus
 - [Bulk loading](examples/bulk_loading.rb) with `COPY`
 
 ## pg

examples/citus.rb

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+require "numo/narray"
+require "pg"
+require "pgvector"
+
+# generate random data
+rows = 1000000
+dimensions = 128
+embeddings = Numo::SFloat.new(rows, dimensions).rand
+categories = Numo::Int64.new(rows).rand(100)
+queries = Numo::SFloat.new(10, dimensions).rand
+
+# enable extensions
+conn = PG.connect(dbname: "pgvector_example")
+conn.exec("CREATE EXTENSION IF NOT EXISTS citus")
+conn.exec("CREATE EXTENSION IF NOT EXISTS vector")
+
+# GUC variables set on the session do not propagate to Citus workers
+# https://github.com/citusdata/citus/issues/462
+# you can either:
+# 1. set them on the system, user, or database and reconnect
+# 2. set them for a transaction with SET LOCAL
+conn.exec("ALTER DATABASE pgvector_example SET maintenance_work_mem = '512MB'")
+conn.exec("ALTER DATABASE pgvector_example SET hnsw.ef_search = 20")
+conn.close
+
+# reconnect for updated GUC variables to take effect
+conn = PG.connect(dbname: "pgvector_example")
+
+puts "Creating distributed table"
+conn.exec("DROP TABLE IF EXISTS items")
+conn.exec("CREATE TABLE items (id bigserial, embedding vector(#{dimensions}), category_id bigint, PRIMARY KEY (id, category_id))")
+conn.exec("SET citus.shard_count = 4")
+conn.exec("SELECT create_distributed_table('items', 'category_id')")
+
+puts "Loading data in parallel"
+coder = PG::BinaryEncoder::CopyRow.new
+conn.copy_data("COPY items (embedding, category_id) FROM STDIN WITH (FORMAT BINARY)", coder) do
+  embeddings.each_over_axis(0).with_index do |embedding, i|
+    conn.put_copy_data([Pgvector::Vector.new(embedding).to_binary, [categories[i]].pack("q>")])
+  end
+end
+
+puts "Creating index in parallel"
+conn.exec("CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)")
+
+puts "Running distributed queries"
+queries.each_over_axis(0) do |query|
+  items = conn.exec_params("SELECT id FROM items ORDER BY embedding <-> $1 LIMIT 10", [Pgvector::Vector.new(query)])
+  p items.map { |v| v["id"].to_i }
+end
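
Note on the GUC comment in examples/citus.rb: the comment lists two ways to apply settings, and the commit uses the first one (ALTER DATABASE followed by a reconnect). The second one, SET LOCAL inside a transaction, might look roughly like the sketch below. `with_local_settings` and `set_local_sql` are hypothetical helpers, not part of the commit or of the pg gem. `conn` is assumed to respond to `transaction` and `exec` the way `PG::Connection` does, and the string quoting assumes `standard_conforming_strings` is on.

```ruby
# Hypothetical helper (not in the commit): builds a SET LOCAL statement.
# SET does not accept bind parameters, so the value is inlined as a literal.
def set_local_sql(name, value)
  unless name.to_s.match?(/\A[a-z_][a-z0-9_.]*\z/i)
    raise ArgumentError, "invalid setting name: #{name}"
  end
  # Integers are written bare; everything else becomes a quoted string
  # with embedded single quotes doubled.
  literal = value.is_a?(Integer) ? value.to_s : "'#{value.to_s.gsub("'", "''")}'"
  "SET LOCAL #{name} = #{literal}"
end

# Hypothetical helper (not in the commit): opens a transaction, applies
# each setting with SET LOCAL, then yields the connection to the block.
# `conn` is assumed to behave like PG::Connection (#transaction, #exec).
def with_local_settings(conn, settings)
  conn.transaction do |tx|
    settings.each do |name, value|
      tx.exec(set_local_sql(name, value))
    end
    yield tx
  end
end
```

With a real connection, a query from the example could then be wrapped as `with_local_settings(conn, "hnsw.ef_search" => 20) { |tx| tx.exec_params(...) }`, so the setting lasts only for that transaction.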

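
Note on the COPY loop in examples/citus.rb: `category_id` is sent as `[categories[i]].pack("q>")`. PostgreSQL's binary COPY format carries a bigint as 8 bytes in network (big-endian) byte order, which is what the `q>` directive produces for a signed 64-bit integer. A small round trip shows the encoding; `encode_bigint` and `decode_bigint` are hypothetical names used only for illustration, not part of the commit.

```ruby
# Encode an Integer as a binary COPY bigint field:
# 8 bytes, signed, big-endian.
def encode_bigint(value)
  [value].pack("q>")
end

# Decode the 8 bytes back into an Integer, for inspection only.
def decode_bigint(bytes)
  bytes.unpack1("q>")
end

encoded = encode_bigint(42)
p encoded.bytes           # => [0, 0, 0, 0, 0, 0, 0, 42]
p decode_bigint(encoded)  # => 42
```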