|
| 1 | +require "json" |
| 2 | +require "net/http" |
| 3 | +require "pg" |
| 4 | + |
# Open a connection to the "pgvector_example" database.
conn = PG.connect(dbname: "pgvector_example")

# Run the setup statements in order: enable the `vector` extension, then
# recreate the documents table from scratch so each run starts empty.
[
  "CREATE EXTENSION IF NOT EXISTS vector",
  "DROP TABLE IF EXISTS documents",
  "CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1024))"
].each { |statement| conn.exec(statement) }
| 10 | + |
# Requests unsigned-binary ("ubinary") embeddings for a batch of texts from
# Cohere's embed endpoint and returns them as bit strings.
# https://docs.cohere.com/reference/embed
#
# @param texts [Array<String>] texts to embed
# @param input_type [String] value sent as the request's `input_type` field
#   (this script passes "search_document" and "search_query")
# @return [Array<String>] one string of "0"/"1" characters per text: eight
#   characters per byte in the response, most significant bit first
# @raise [KeyError] if the CO_API_KEY environment variable is not set
# @raise [StandardError] the HTTP error raised by Net::HTTPResponse#value when
#   the response status is not 2xx
def fetch_embeddings(texts, input_type)
  # Nothing to embed, so skip the HTTP round trip entirely.
  return [] if texts.empty?

  url = "https://api.cohere.com/v1/embed"
  headers = {
    "Authorization" => "Bearer #{ENV.fetch("CO_API_KEY")}",
    "Content-Type" => "application/json"
  }
  data = {
    texts: texts,
    model: "embed-english-v3.0",
    input_type: input_type,
    embedding_types: ["ubinary"]
  }

  # `value` raises on a non-2xx status; `tap` keeps the response as the result.
  response = Net::HTTP.post(URI(url), data.to_json, headers).tap(&:value)

  # Each embedding arrives as an array of byte values. Pack the whole array
  # into one binary string and expand it to bits in a single step instead of
  # converting and joining byte by byte.
  JSON.parse(response.body)["embeddings"]["ubinary"].map do |bytes|
    bytes.pack("C*").unpack1("B*")
  end
end
| 28 | + |
# Sample documents to store.
input = [
  "The dog is barking",
  "The cat is purring",
  "The bear is growling"
]

# Embed all documents in one request, then insert each text together with the
# embedding at the same position.
embeddings = fetch_embeddings(input, "search_document")
insert_sql = "INSERT INTO documents (content, embedding) VALUES ($1, $2)"
input.each_with_index do |content, index|
  conn.exec_params(insert_sql, [content, embeddings[index]])
end
| 38 | + |
# Embed the search text and print the content of up to five stored documents,
# ordered by the `<~>` operator against the query embedding
# (presumably pgvector's Hamming distance for bit columns; confirm).
query = "forest"
query_embedding = fetch_embeddings([query], "search_query").first
search_sql = "SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5"
result = conn.exec_params(search_sql, [query_embedding])
result.each { |row| puts row["content"] }
0 commit comments