Skip to content

Commit ab2774e

Browse files
committed
Added hybrid search example [skip ci]
1 parent 1aac1f1 commit ab2774e

3 files changed

Lines changed: 74 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Or check out some examples:
3333

3434
- [Embeddings](examples/openai/example.js) with OpenAI
3535
- [Sentence embeddings](examples/transformers/example.js) with Transformers.js
36+
- [Hybrid search](examples/hybrid-search/example.js) with Transformers.js
3637
- [Recommendations](examples/disco/example.js) with Disco
3738
- [Bulk loading](examples/loading/example.js) with `COPY`
3839

examples/hybrid-search/example.js

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import { pipeline } from '@xenova/transformers';
2+
import pg from 'pg';
3+
import pgvector from 'pgvector/pg';
4+
5+
// Open a connection to the example database.
const client = new pg.Client({database: 'pgvector_example'});
await client.connect();

// Enable the vector extension, then register pgvector's types on this client.
await client.query('CREATE EXTENSION IF NOT EXISTS vector');
await pgvector.registerTypes(client);

// Rebuild the documents table on every run, with a GIN index over the
// English tsvector of `content`. Statements run one at a time, in order.
const schemaStatements = [
  'DROP TABLE IF EXISTS documents',
  'CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))',
  "CREATE INDEX ON documents USING GIN (to_tsvector('english', content))",
];
for (const statement of schemaStatements) {
  await client.query(statement);
}
14+
15+
// Sample documents that get embedded and inserted into the documents table.
const input = [
  'The dog is barking',
  'The cat is purring',
  'The bear is growling',
];

// Feature-extraction pipeline used by generateEmbedding() for both the
// documents and the search query.
const modelName = 'Xenova/multi-qa-MiniLM-L6-cos-v1';
const extractor = await pipeline('feature-extraction', modelName);
22+
23+
/**
 * Embeds a piece of text with the module-level `extractor` pipeline,
 * requesting mean pooling and normalization.
 *
 * @param {string} content - Text to embed.
 * @returns {Promise<number[]>} The pipeline output's `data`, copied into a plain array.
 */
async function generateEmbedding(content) {
  const options = {pooling: 'mean', normalize: true};
  const { data } = await extractor(content, options);
  return Array.from(data);
}
27+
28+
// Embed each sample document and insert it together with its embedding.
// Rows are processed one at a time; each insert is awaited before the next
// document is embedded. The array index is not needed, so iterate the values
// directly.
for (const content of input) {
  const embedding = await generateEmbedding(content);
  await client.query(
    'INSERT INTO documents (content, embedding) VALUES ($1, $2)',
    [content, pgvector.toSql(embedding)]
  );
}
32+
33+
// Hybrid search query. Parameters:
//   $1 - query text, parsed with plainto_tsquery('english', ...)
//   $2 - query embedding, compared with the <=> operator
//   $3 - constant added to each rank in the score denominator
// Each CTE keeps its top 20 ids with a rank. The final score is the sum of
// 1 / ($3 + rank) over the two rankings (reciprocal rank fusion); an id that
// is missing from one side contributes 0 for that side.
// Both CTEs alias the window function explicitly as `rank`, because the outer
// query reads semantic_search.rank and keyword_search.rank.
const sql = `
  WITH semantic_search AS (
    SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
    FROM documents
    ORDER BY embedding <=> $2
    LIMIT 20
  ),
  keyword_search AS (
    SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC) AS rank
    FROM documents, plainto_tsquery('english', $1) query
    WHERE to_tsvector('english', content) @@ query
    ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
    LIMIT 20
  )
  SELECT
    COALESCE(semantic_search.id, keyword_search.id) AS id,
    COALESCE(1.0 / ($3 + semantic_search.rank), 0.0) +
    COALESCE(1.0 / ($3 + keyword_search.rank), 0.0) AS score
  FROM semantic_search
  FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
  ORDER BY score DESC
  LIMIT 5
`;
56+
// Run the hybrid search for a sample query and print each result row.
const query = 'growling bear';
const embedding = await generateEmbedding(query);
// Passed as $3: the constant added to each rank in the score denominator.
const k = 60;
const { rows } = await client.query(sql, [query, pgvector.toSql(embedding), k]);
for (const row of rows) {
  console.log(row);
}

// Close the database connection so the process can exit.
await client.end();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"private": true,
3+
"type": "module",
4+
"dependencies": {
5+
"@xenova/transformers": "^2.6.0",
6+
"pg": "^8.11.3",
7+
"pgvector": "file:../.."
8+
}
9+
}

0 commit comments

Comments
 (0)