# frozen_string_literal: true

# Hybrid search example.
#
# Stores a few sentences together with their embeddings in PostgreSQL, then
# ranks them for a query two ways -- by embedding distance (pgvector) and by
# full-text match -- and merges both rankings into one "RRF score"
# (sum of 1 / (k + rank) over the rankings a document appears in).

require "pg"
require "pgvector"
require "transformers-rb"

conn = PG.connect(dbname: "pgvector_example")

begin
  conn.exec("CREATE EXTENSION IF NOT EXISTS vector")

  # Recreate the table on every run so the script always starts empty.
  conn.exec("DROP TABLE IF EXISTS documents")
  # NOTE(review): vector(384) presumably matches the output size of the model
  # loaded below -- confirm if the model is ever swapped.
  conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))")
  # Expression index on the same to_tsvector('english', content) expression
  # the keyword search below filters and ranks on.
  conn.exec("CREATE INDEX ON documents USING GIN (to_tsvector('english', content))")

  model = Transformers::SentenceTransformer.new("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

  input = [
    "The dog is barking",
    "The cat is purring",
    "The bear is growling"
  ]
  embeddings = model.encode(input)
  input.zip(embeddings) do |content, embedding|
    conn.exec_params("INSERT INTO documents (content, embedding) VALUES ($1, $2)", [content, embedding])
  end

  # Bind parameters:
  #   $1 - query text for the full-text search
  #   $2 - query embedding for the vector search
  #   $3 - constant added to every rank in the score denominator
  #
  # Each CTE keeps at most 20 rows. The FULL OUTER JOIN keeps documents found
  # by only one of the searches; COALESCE(..., 0.0) makes the missing side
  # contribute nothing to the score.
  sql = <<~SQL
    WITH semantic_search AS (
      SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
      FROM documents
      ORDER BY embedding <=> $2
      LIMIT 20
    ),
    keyword_search AS (
      SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC) AS rank
      FROM documents, plainto_tsquery('english', $1) query
      WHERE to_tsvector('english', content) @@ query
      ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
      LIMIT 20
    )
    SELECT
      COALESCE(semantic_search.id, keyword_search.id) AS id,
      COALESCE(1.0 / ($3 + semantic_search.rank), 0.0) +
      COALESCE(1.0 / ($3 + keyword_search.rank), 0.0) AS score
    FROM semantic_search
    FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
    ORDER BY score DESC
    LIMIT 5
  SQL

  query = "growling bear"
  query_embedding = model.encode(query)
  k = 60
  result = conn.exec_params(sql, [query, query_embedding, k])
  result.each do |row|
    puts "document: #{row["id"]}, RRF score: #{row["score"]}"
  end
ensure
  # Release the database connection even when a step above raises.
  conn.finish
end