# Reset the documents table. sparsevec(30522) matches the vocabulary size of
# the opensearch-neural-sparse-encoding-v1 model (a BERT-style vocab), so each
# sparse embedding holds at most one weight per vocabulary token.
conn.exec("DROP TABLE IF EXISTS documents")
conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))")
# Produces sparse lexical embeddings (SPLADE-style) from a masked-LM:
# max-pool the vocab logits over sequence positions, apply log(1 + relu(x))
# saturation, and zero out special-token columns so only real vocabulary
# terms carry weight.
class EmbeddingModel
  def initialize(model_id)
    @model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id)
    @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
    # Vocab ids of the special tokens ([CLS], [SEP], [PAD], ...); their
    # activations are masked to zero in #embed.
    @special_token_ids = @tokenizer.special_tokens_map.map { |_, token| @tokenizer.vocab[token] }
  end

  # input: array of strings. Returns one vocab-sized array of floats per
  # input string (most entries zero after the relu/log transform).
  def embed(input)
    feature = @tokenizer.(input, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false)
    output = @model.(**feature)[0]
    # Max over sequence positions; multiplying by the attention mask first
    # keeps padding positions from contributing. Torch.max with dim: returns
    # [values, indices]; we only need the values.
    values = Torch.max(output * feature[:attention_mask].unsqueeze(-1), dim: 1)[0]
    values = Torch.log(1 + Torch.relu(values))
    # Zero the special-token columns for every row in the batch.
    values[0.., @special_token_ids] = 0
    values.to_a
  end
end
# Model trained specifically for sparse retrieval (30522-entry BERT vocab).
model = EmbeddingModel.new("opensearch-project/opensearch-neural-sparse-encoding-v1")
# Sample corpus: embed every document in one batch, then store each one as a
# sparse vector alongside its text.
input = [
  "The dog is barking",
  "The cat is purring",
  "The bear is growling"
]
embeddings = model.embed(input)
input.zip(embeddings) do |content, embedding|
  conn.exec_params("INSERT INTO documents (content, embedding) VALUES ($1, $2)", [content, Pgvector::SparseVector.new(embedding)])
end
# Retrieve the documents most similar to the query. The query goes through the
# same model; #embed takes a batch, so wrap it in an array and take element 0.
query = "forest"
query_embedding = model.embed([query])[0]
# <#> is pgvector's negative-inner-product operator, so ascending ORDER BY
# returns the highest-scoring (most similar) documents first.
result = conn.exec_params("SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5", [Pgvector::SparseVector.new(query_embedding)])
result.each do |row|
  puts row["content"]
end
|
0 commit comments