Skip to content

Commit 3c0134f

Browse files
committed
Improved sparse search example [skip ci]
1 parent dd29e35 commit 3c0134f

1 file changed

Lines changed: 19 additions & 15 deletions

File tree

examples/sparse_search.rb

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,37 @@
1212
conn.exec("DROP TABLE IF EXISTS documents")
1313
conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))")
1414

15-
model_id = "opensearch-project/opensearch-neural-sparse-encoding-v1"
16-
model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id)
17-
tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
18-
special_token_ids = tokenizer.special_tokens_map.map { |_, token| tokenizer.vocab[token] }
19-
20-
fetch_embeddings = lambda do |input|
21-
feature = tokenizer.(input, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false)
22-
output = model.(**feature)[0]
23-
24-
values, _ = Torch.max(output * feature[:attention_mask].unsqueeze(-1), dim: 1)
25-
values = Torch.log(1 + Torch.relu(values))
26-
values[0.., special_token_ids] = 0
27-
values.to_a
15+
# Generates sparse text embeddings from a Hugging Face masked-LM sparse
# encoder (e.g. opensearch-neural-sparse-encoding-v1).
class EmbeddingModel
  # Loads the model and tokenizer for +model_id+ and caches the vocabulary
  # ids of the tokenizer's special tokens so they can be zeroed out of every
  # embedding (special tokens carry no retrieval signal).
  def initialize(model_id)
    @model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id)
    @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
    @special_token_ids = @tokenizer.special_tokens_map.map { |_, token| @tokenizer.vocab[token] }
  end

  # Embeds an array of strings; returns one sparse vector (array of floats,
  # vocab-sized) per input string.
  def embed(input)
    encoded = @tokenizer.(input, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false)
    logits = @model.(**encoded)[0]
    # Zero padded positions, then take the per-vocab-term max over the sequence.
    masked = logits * encoded[:attention_mask].unsqueeze(-1)
    weights = Torch.max(masked, dim: 1)[0]
    # log(1 + relu(x)) keeps weights non-negative and dampens large logits.
    weights = Torch.log(1 + Torch.relu(weights))
    weights[0.., @special_token_ids] = 0
    weights.to_a
  end
end
2931

32+
model = EmbeddingModel.new("opensearch-project/opensearch-neural-sparse-encoding-v1")

documents = [
  "The dog is barking",
  "The cat is purring",
  "The bear is growling"
]

# Embed all documents in one batch and store each sparse vector alongside
# its text.
embeddings = model.embed(documents)
documents.zip(embeddings) do |content, embedding|
  conn.exec_params("INSERT INTO documents (content, embedding) VALUES ($1, $2)", [content, Pgvector::SparseVector.new(embedding)])
end

query = "forest"
query_embedding = model.embed([query])[0]
# <#> is pgvector's negative-inner-product operator, so ascending order
# returns the best matches first.
result = conn.exec_params("SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5", [Pgvector::SparseVector.new(query_embedding)])
4347
result.each do |row|
4448
puts row["content"]

0 commit comments

Comments (0)