
Commit 59fafc4

Created a first prototype
1 parent 9d2f55e commit 59fafc4

3 files changed

Lines changed: 243 additions & 0 deletions


main.py

Lines changed: 71 additions & 0 deletions
import time

import wiki
import model


def context(query, n_wiki_pages=5, n_top_chunks=8, min_summary_length=100, max_summary_length=200, verbose=False):
    """
    provides context for the query by searching for relevant wikipedia pages, extracting the most relevant information
    and summarizing the facts
    :param query: query as a string for which the context should be provided
    :param n_wiki_pages: (optional) number of wikipedia pages that should be searched
    :param n_top_chunks: (optional) number of highest-scoring chunks that should be summarized
    :param min_summary_length: (optional) minimum length of the summary (in tokens)
    :param max_summary_length: (optional) maximum length of the summary (in tokens)
    :param verbose: (optional) whether to print the progress
    :return: summarized facts from the wikipedia pages as a string
    """
    # todo: find optimal default values for the parameters
    # todo: remove time measurements
    if verbose:
        print("Query:", query)
    time1 = time.time()
    # create wikipedia search prompt
    wiki_search_prompt = model.create_wiki_search_prompt(query, verbose=verbose)
    time2 = time.time()
    print("Time taken to get wiki search prompt:", time2 - time1, "seconds")

    # get relevant wikipedia pages
    page_titles = wiki.get_pages(wiki_search_prompt, n_results=n_wiki_pages)
    if verbose:
        print("Page titles:", page_titles)
    # get the content of the wikipedia pages and split it into chunks
    time3 = time.time()
    print("Time taken to get wiki pages:", time3 - time2, "seconds")
    wiki_chunks = wiki.get_text_chunks(page_titles, chunk_length=512, verbose=verbose)
    time4 = time.time()
    print("Time taken to get wiki chunks:", time4 - time3, "seconds")

    # get the embeddings for the query and the wiki chunks
    query_embedding = model.get_embeddings([query])
    time5 = time.time()
    print("Time taken to get query embedding:", time5 - time4, "seconds")
    wiki_embeddings = model.get_embeddings(wiki_chunks)
    time6 = time.time()
    print("Time taken to get wiki embeddings:", time6 - time5, "seconds")
    # calculate the similarity between the query and the wiki chunks
    similarities = model.calculate_similarity(query_embedding, wiki_embeddings, top_k=n_top_chunks)
    time7 = time.time()
    print("Time taken to calculate similarity:", time7 - time6, "seconds")
    top_chunks = ""

    # wrap each top chunk in numbered tags so the summarizer sees clearly separated passages
    for i, similarity in enumerate(similarities):
        top_chunks += "<" + str(i + 1) + "> " + wiki_chunks[similarity['corpus_id']] + " </" + str(i + 1) + ">\n\n"
        if verbose:
            print("Chunk " + str(i + 1) + ":", wiki_chunks[similarity['corpus_id']], "\t\t\tscore:", similarity['score'])

    time8 = time.time()
    print("Time taken to get concatenated top chunk string:", time8 - time7, "seconds")
    # summarize facts from the top wiki chunks
    summarized_facts = model.summarize_facts(top_chunks, min_length=min_summary_length, max_length=max_summary_length)
    time9 = time.time()
    print("Time taken to summarize facts:", time9 - time8, "seconds")
    print("Total time taken:", time9 - time1, "seconds")
    return summarized_facts


if __name__ == "__main__":
    user_query = "What are the names of Barack Obama's children?"
    summarized_context = context(user_query, verbose=True)  # named so it does not shadow the context() function
    print(summarized_context)
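
# Illustrative call with custom parameters (the summary text depends on the models and on live Wikipedia content,
# so no concrete output is shown here):
#   context("Who founded the Ford Motor Company?", n_wiki_pages=3, n_top_chunks=5)
#   -> roughly a 100-200 token summary string distilled from the highest-scoring wikipedia chunks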

model.py

Lines changed: 81 additions & 0 deletions
from sentence_transformers.util import semantic_search
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Loading models
device = "cpu"  # todo: only for cpu testing, can be removed to automatically choose the device
gist_embedding = SentenceTransformer("avsolatorio/GIST-small-Embedding-v0", device=device)
bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
flan_t5 = pipeline("text2text-generation", model="google/flan-t5-base", device=device)


# ------------------------------------------------ Embedding Model ----------------------------------------------------

def get_embeddings(texts):
    """
    gets the embeddings for the texts using the sentence transformer embedding model
    :param texts: list of texts for which the embeddings should be calculated
    :return: embeddings
    """
    # todo: check out arguments of the encode method, for example 'prompt' or 'precision'
    return gist_embedding.encode(texts)


def calculate_similarity(query_embedding, wiki_embeddings, top_k=10):
    """
    calculates the similarity between the query_embedding and the wiki_embeddings, which can be used to filter the
    wiki content for the most relevant information
    :param query_embedding: embedding of the query
    :param wiki_embeddings: list of chunked embeddings of the wikipedia content
    :param top_k: number of most similar chunks that should be returned
    :return: list of dictionaries with the similarity score ['score'] between the query_embedding and each embedding
             chunk and the index of the chunk ['corpus_id']
    """
    return semantic_search(query_embedding, wiki_embeddings, top_k=top_k)[0]
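
# Illustrative usage of the two functions above (the score is a placeholder value; actual numbers
# depend on the embedding model):
#   query_embedding = get_embeddings(["capital of France"])
#   wiki_embeddings = get_embeddings(["Paris is the capital of France.", "Berlin is in Germany."])
#   calculate_similarity(query_embedding, wiki_embeddings, top_k=1)
#   -> [{'corpus_id': 0, 'score': 0.87}]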

# --------------------------------------------------- Flan T5 ---------------------------------------------------------

def create_wiki_search_prompt(query, verbose=False):
    """
    extracts the most relevant keywords from the query and returns them as a prompt for the wikipedia search
    :param query: query for which the keywords should be extracted
    :param verbose: whether to print the wiki search prompt
    :return: keywords for the wikipedia search
    """
    # note the trailing spaces: the string parts are concatenated implicitly, so each part must end with a separator
    prompt = ("I will give you a query and you have to create a list of keywords separated by commas to search the "
              "internet for additional information. "
              "Example Query 1: What is the capital of France? "
              "Keywords: capital, France "
              "Example Query 2: Person that won the Nobel Prize in Literature in 2020 "
              "Keywords: Nobel Prize, Literature, 2020 "
              "Example Query 3: What variation of house music was produced by artists such as Madonna and Kylie Minogue? "
              "Keywords: house music, Madonna, Kylie Minogue "
              "Now it's your turn! "
              f"Query: {query} Keywords:")

    keywords = flan_t5(prompt, max_length=50, do_sample=False)[0]['generated_text']
    if verbose:
        print("wiki search prompt:", keywords)
    return keywords
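
# Illustrative run (flan-t5 is deterministic with do_sample=False, but the keywords shown here are hypothetical):
#   create_wiki_search_prompt("What are the names of Barack Obama's children?")
#   -> "Barack Obama, children"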

# todo: try looking at different titles and letting the model decide which ones are the most promising


# ------------------------------------------------ Bart Large CNN -----------------------------------------------------

def summarize_facts(top_chunks, min_length, max_length):
    """
    summarizes the facts from the wiki content
    :param top_chunks: chunks of the wiki content with the highest similarity to the query
    :param min_length: minimum length of the summary (in tokens)
    :param max_length: maximum length of the summary (in tokens)
    :return: summarized facts from the wiki content as a string
    """
    summary = bart_summarizer(top_chunks, min_length=min_length, max_length=max_length, do_sample=False)
    summary = summary[0]['summary_text']
    # the summarizer sometimes prepends a single space; remove it
    if summary.startswith(" "):
        summary = summary[1:]
    return summary
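
# Illustrative call (the exact wording depends on the bart checkpoint; the shown output is hypothetical):
#   summarize_facts("<1> Paris is the capital and largest city of France. </1>", min_length=5, max_length=20)
#   -> "Paris is the capital and largest city of France."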

wiki.py

Lines changed: 91 additions & 0 deletions
import concurrent.futures
import wikipedia


def get_pages(search_prompt, n_results=5):
    """
    gets the wikipedia pages for the search prompt using the wikipedia api
    :param search_prompt: search prompt for the wikipedia search
    :param n_results: number of page titles that should be returned
    :return: page titles
    """
    return wikipedia.search(search_prompt, results=n_results)
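
# Illustrative call (actual titles depend on the live Wikipedia search index; this result list is hypothetical):
#   get_pages("Nobel Prize, Literature, 2020", n_results=3)
#   -> ['2020 Nobel Prize in Literature', 'Nobel Prize in Literature', 'Louise Glück']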


def get_text_chunks(page_titles, chunk_length=512, verbose=False):
    """
    gets the content of the wikipedia pages using multiple threads (API calls take time) and splits it into chunks
    :param page_titles: list of page titles for which the content should be extracted
    :param chunk_length: maximum number of characters that a chunk should have
    :param verbose: whether to print the progress
    :return: list of wiki text chunks
    """
    wiki_chunks = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_page = {executor.submit(get_page_content, page_title): page_title for page_title in page_titles}
        for future in concurrent.futures.as_completed(future_to_page):
            page_title = future_to_page[future]
            try:
                wiki_content = future.result()
                wiki_content = preprocess_and_chunk_wiki_content(wiki_content, chunk_length=chunk_length)
                if verbose:
                    print(f"getting content of page {page_title}")
                wiki_chunks.extend(wiki_content)
            # a tuple is required to catch both exception types
            except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e:
                if verbose:
                    print(f"page {page_title} not found, {e}")
                continue  # skip the page if it is not available
    return wiki_chunks


def get_page_content(page_title):
    """
    gets the content of the wikipedia page using the wikipedia api
    :param page_title: title of the wikipedia page from which the content should be extracted
    :return: content of the wikipedia page
    """
    return wikipedia.page(page_title).content


def preprocess_and_chunk_wiki_content(wiki_content, chunk_length=512):
    """
    preprocesses the wiki content:
    - splits the content into paragraphs
    - removes headings, empty lines and too short paragraphs
    - splits too long paragraphs into smaller chunks without cutting sentences
    :param wiki_content: content of the wikipedia page
    :param chunk_length: maximum number of characters that a chunk should have
    :return: list of wiki text chunks
    """
    # remove everything from the references section onwards
    wiki_content = wiki_content.split("== References ==")[0]
    # split into paragraphs
    chunks = wiki_content.split("\n")
    # remove headings, empty chunks and too short chunks
    chunks = [chunk for chunk in chunks if not ((chunk.startswith("=") and chunk.endswith("=")) or len(chunk) < 50)]
    # split too long chunks
    additional_chunks = []
    for i, chunk in enumerate(chunks):
        if len(chunk) > chunk_length:
            # split into sentences
            sentences = chunk.split(". ")
            # split into sub-chunks without cutting sentences
            sub_chunks = [""]
            sub_chunk_index = 0
            for sentence in sentences:
                if len(sentence) > chunk_length:
                    # cut the sentence if it's longer than the chunk_length (this should happen rarely)
                    sentence = sentence[:chunk_length]
                # if the sentence fits into the current sub-chunk
                if len(sub_chunks[sub_chunk_index]) + len(sentence) < chunk_length:
                    sub_chunks[sub_chunk_index] += sentence + ". "
                else:
                    sub_chunk_index += 1
                    sub_chunks.append(sentence + ". ")
            # replace the original chunk with the first sub-chunk
            chunks[i] = sub_chunks[0]
            # add the remaining sub-chunks to the end of the list
            additional_chunks.extend(sub_chunks[1:])
    chunks.extend(additional_chunks)
    return chunks
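
# Illustrative behavior (toy input; placeholder text in angle brackets): the heading and the too-short
# line are dropped, and the long paragraph is split at sentence boundaries into chunks of at most
# chunk_length characters:
#   preprocess_and_chunk_wiki_content("== History ==\nshort\n<one paragraph longer than 512 characters>")
#   -> ['<first sentences up to ~512 characters>', '<remaining sentences>']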
