Skip to content

Commit 2ac3212

Browse files
committed
Filtered a warning, decreased the number of wiki pages that fail to be found by disabling auto_suggest, raised the minimum-length filter for too-short paragraphs from 50 to 100 characters, and updated documentation.
1 parent 8399641 commit 2ac3212

3 files changed

Lines changed: 22 additions & 15 deletions

File tree

contextplus/main.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@
33
from contextplus import model, wiki
44

55

6-
def context(query, n_wiki_pages=5, n_top_chunks=8, min_summary_length=100, max_summary_length=200, verbose=False):
6+
def context(query, min_summary_length=100, max_summary_length=200, n_wiki_pages=5, n_top_chunks=8, verbose=False):
77
"""
88
provides context for the query by searching for relevant wikipedia pages, extracting the most relevant information
99
and summarizing the facts
1010
:param query: query as a string for which the context should be provided
11-
:param n_wiki_pages: (optional) number of wikipedia pages that should be searched
12-
:param n_top_chunks: (optional) number of highest scoring chunks that should be summarized
1311
:param min_summary_length: (optional) minimum length of the summary (in tokens)
1412
:param max_summary_length: (optional) maximum length of the summary (in tokens)
13+
:param n_wiki_pages: (optional, not recommended to change) number of wikipedia pages that should be searched
14+
:param n_top_chunks: (optional, not recommended to change) number of highest scoring chunks that should be summarized
1515
:param verbose: (optional) whether to print the progress
16-
:return: summarized facts from the wikipedia pages as a string
16+
:return: summarized facts from the wikipedia pages as a string, None if no wikipedia pages were found
1717
"""
18-
# todo: finding optimal default values for the parameters
18+
1919
time1, time2, time3, time4, time5, time6, time7, time8, time9 = 0, 0, 0, 0, 0, 0, 0, 0, 0
2020
if verbose:
2121
print("Query:", query)
2222
time1 = time.time()
23+
2324
# create wikipedia search prompt
2425
wiki_search_prompt = model.create_wiki_search_prompt(query, verbose=verbose)
2526
if verbose:
@@ -30,14 +31,16 @@ def context(query, n_wiki_pages=5, n_top_chunks=8, min_summary_length=100, max_s
3031
page_titles = wiki.get_pages(wiki_search_prompt, n_results=n_wiki_pages)
3132
if verbose:
3233
print("Page titles:", page_titles)
33-
# get the content of the wikipedia pages and split it into chunks
34-
if verbose:
3534
time3 = time.time()
3635
print("Time taken to get wiki pages:", time3 - time2, "seconds")
36+
37+
# get the content of the wikipedia pages and split it into chunks
3738
wiki_chunks = wiki.get_text_chunks(page_titles, chunk_length=512, verbose=verbose)
3839
if verbose:
3940
time4 = time.time()
4041
print("Time taken to get wiki chunks:", time4 - time3, "seconds")
42+
if not wiki_chunks:
43+
return None
4144

4245
# get the embeddings for the query and the wiki chunks
4346
query_embedding = model.get_embeddings([query])
@@ -48,21 +51,22 @@ def context(query, n_wiki_pages=5, n_top_chunks=8, min_summary_length=100, max_s
4851
if verbose:
4952
time6 = time.time()
5053
print("Time taken to get wiki embeddings:", time6 - time5, "seconds")
54+
5155
# calculate the similarity between the query and the wiki chunks
5256
similarities = model.calculate_similarity(query_embedding, wiki_embeddings, top_k=n_top_chunks)
5357
if verbose:
5458
time7 = time.time()
5559
print("Time taken to calculate similarity:", time7 - time6, "seconds")
56-
top_chunks = ""
5760

61+
top_chunks = ""
5862
for i, similarity in enumerate(similarities):
5963
top_chunks += "<" + str(i + 1) + "> " + wiki_chunks[similarity['corpus_id']] + " </" + str(i + 1) + ">\n\n"
6064
if verbose:
6165
print("Chunk" + str(i + 1) + ":", wiki_chunks[similarity['corpus_id']], "\t\t\tscore:", similarity['score'])
62-
6366
if verbose:
6467
time8 = time.time()
6568
print("Time taken to get concatenated top chunk string:", time8 - time7, "seconds")
69+
6670
# summarize facts from the top wiki chunks
6771
summarized_facts = model.summarize_facts(top_chunks, min_length=min_summary_length, max_length=max_summary_length)
6872
if verbose:
@@ -74,5 +78,5 @@ def context(query, n_wiki_pages=5, n_top_chunks=8, min_summary_length=100, max_s
7478

7579
if __name__ == "__main__":
7680
user_query = "What are the names of Barack Obamas children?"
77-
context = context(user_query, verbose=False)
81+
context = context(user_query, verbose=True)
7882
print(context)

contextplus/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def get_embeddings(texts):
1818
:param texts: list of texts for which the embeddings should be calculated
1919
:return: embeddings
2020
"""
21-
# todo: check out arguments of the encode method for example 'prompt' or 'precision'
2221
return gist_embedding.encode(texts)
2322

2423

@@ -62,8 +61,6 @@ def create_wiki_search_prompt(query, verbose=False):
6261
return keywords
6362

6463

65-
# todo: try out to look at different titles and let the model decide which will be the most promising ones
66-
6764
# ------------------------------------------------ Bart Large CNN -----------------------------------------------------
6865

6966
def summarize_facts(top_chunks, min_length, max_length):
@@ -74,6 +71,8 @@ def summarize_facts(top_chunks, min_length, max_length):
7471
:param max_length: maximum length of the summary (in tokens)
7572
:return: summarized facts from the wiki content as a string
7673
"""
74+
if len(top_chunks) > 3700:
75+
top_chunks = top_chunks[:3700]
7776
summary = bart_summarizer(top_chunks, min_length=min_length, max_length=max_length, do_sample=False)
7877
summary = summary[0]['summary_text']
7978
if summary.startswith(" "):

contextplus/wiki.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import concurrent.futures
2+
import warnings
23
import wikipedia
34

45

@@ -44,7 +45,10 @@ def get_page_content(page_title):
4445
:param page_title: page_title of the wikipedia page from which the content should be extracted
4546
:return: content of the wikipedia page
4647
"""
47-
return wikipedia.page(page_title).content
48+
with warnings.catch_warnings():
49+
warnings.filterwarnings("ignore", category=UserWarning)
50+
page_content = wikipedia.page(page_title, auto_suggest=False).content
51+
return page_content
4852

4953

5054
def preprocess_and_chunk_wiki_content(wiki_content, chunk_length=512):
@@ -62,7 +66,7 @@ def preprocess_and_chunk_wiki_content(wiki_content, chunk_length=512):
6266
# split into paragraphs
6367
chunks = wiki_content.split("\n")
6468
# remove headings, empty chunks and too short chunks
65-
chunks = [chunk for chunk in chunks if not ((chunk.startswith("=") and chunk.endswith("=")) or len(chunk) < 50)]
69+
chunks = [chunk for chunk in chunks if not ((chunk.startswith("=") and chunk.endswith("=")) or len(chunk) < 100)]
6670
# split too long chunks
6771
additional_chunks = []
6872
for i, chunk in enumerate(chunks):

0 commit comments

Comments (0)