Commit d9928c1

Added topic modeling example [skip ci]
1 parent 27d3370 commit d9928c1

3 files changed

Lines changed: 40 additions & 0 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ Or check out some examples:
 
 - [Embeddings](examples/openai_embeddings.rb) with OpenAI
 - [Binary embeddings](examples/cohere_embeddings.rb) with Cohere
+- [Topic modeling](examples/topic_modeling.rb) with tomoto.rb
 - [User-based recommendations](examples/disco_user_recs.rb) with Disco
 - [Item-based recommendations](examples/disco_item_recs.rb) with Disco
 - [Bulk loading](examples/bulk_loading.rb) with `COPY`

examples/Gemfile

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ gem "disco"
 gem "numo-narray"
 gem "pg"
 gem "sequel"
+gem "tomoto"

examples/topic_modeling.rb

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+require "pg"
+require "pgvector"
+require "tomoto"
+
+conn = PG.connect(dbname: "pgvector_example")
+conn.exec("CREATE EXTENSION IF NOT EXISTS vector")
+
+conn.exec("DROP TABLE IF EXISTS documents")
+conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(20))")
+
+def generate_embeddings(input)
+  model = Tomoto::LDA.new(k: 20)
+  stop_words = Set.new(["the", "is"])
+  input.each do |text|
+    model.add_doc(text.downcase.split.reject { |w| stop_words.include?(w) })
+  end
+  model.train(100) # iterations
+  input.map.with_index do |_, i|
+    model.docs[i].topics.values
+  end
+end
+
+input = [
+  "The dog is barking",
+  "The cat is purring",
+  "The bear is growling"
+]
+embeddings = generate_embeddings(input)
+
+input.zip(embeddings) do |content, embedding|
+  conn.exec_params("INSERT INTO documents (content, embedding) VALUES ($1, $2)", [content, embedding])
+end
+
+document_id = 1
+result = conn.exec_params("SELECT content FROM documents WHERE id != $1 ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = $1) LIMIT 5", [document_id])
+result.each do |row|
+  puts row["content"]
+end
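
The new example does two things that can be tried without a database or a trained model: it tokenizes each document (lowercase, split on whitespace, drop stop words) and it ranks the other rows by pgvector's `<=>` operator, which is cosine distance (1 minus cosine similarity). The sketch below is a rough pure-Ruby illustration of both steps, not part of the commit. The helper names `tokenize`, `cosine_distance`, and `nearest` are hypothetical, and the three-element vectors are made-up stand-ins for the 20-topic distributions that `Tomoto::LDA` would produce.

```ruby
require "set"

# Hypothetical helper mirroring the tokenization in the diff:
# lowercase, split on whitespace, drop stop words.
STOP_WORDS = Set.new(["the", "is"])

def tokenize(text)
  text.downcase.split.reject { |w| STOP_WORDS.include?(w) }
end

# Cosine distance, the quantity pgvector's <=> operator computes:
# 1 - (a . b) / (|a| * |b|)
def cosine_distance(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  norm_a = Math.sqrt(a.sum { |x| x * x })
  norm_b = Math.sqrt(b.sum { |x| x * x })
  1.0 - dot / (norm_a * norm_b)
end

# Hypothetical in-memory equivalent of the SELECT in the diff:
# exclude the query row, order by distance to it, keep the first `limit` ids.
def nearest(rows, id, limit: 5)
  query = rows.fetch(id)
  rows.reject { |other_id, _| other_id == id }
      .sort_by { |_, vec| cosine_distance(query, vec) }
      .first(limit)
      .map(&:first)
end

# Made-up 3-topic distributions keyed by id (the example itself uses k: 20).
rows = {
  1 => [0.7, 0.2, 0.1],
  2 => [0.6, 0.3, 0.1],
  3 => [0.1, 0.1, 0.8]
}

p tokenize("The dog is barking") # prints ["dog", "barking"]
p nearest(rows, 1)               # prints [2, 3]
```

Row 2 ranks ahead of row 3 because its weight sits on the same topic as row 1. In the committed example the same ordering is done by PostgreSQL through `ORDER BY embedding <=> ...`.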
