-
Notifications
You must be signed in to change notification settings - Fork 53
/
ensemble.py
46 lines (32 loc) · 1.38 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_core.output_parsers import StrOutputParser
from basic_chain import get_model
from rag_chain import make_rag_chain
from remote_loader import load_web_page
from splitter import split_documents
from vector_store import create_vector_db
from dotenv import load_dotenv
def ensemble_retriever_from_docs(docs, embeddings=None, weights=(0.5, 0.5)):
    """Build a hybrid retriever combining BM25 keyword search and vector search.

    The input documents are split into chunks with ``split_documents``. The same
    chunks are then indexed twice: once in a vector store built by
    ``create_vector_db`` and once in a BM25 keyword retriever. The two
    retrievers are merged with ``EnsembleRetriever``.

    Args:
        docs: Documents to index; passed unchanged to ``split_documents``.
        embeddings: Optional embedding model forwarded to ``create_vector_db``.
        weights: Relative weights for the BM25 retriever and the vector-store
            retriever, in that order. Defaults to an equal split.

    Returns:
        An ``EnsembleRetriever`` over ``[bm25_retriever, vector_retriever]``.

    Raises:
        ValueError: if ``weights`` does not contain exactly two values.
    """
    weights = list(weights)
    # One weight per retriever (BM25, vector) is required.
    if len(weights) != 2:
        raise ValueError(
            f"weights must contain exactly 2 values (BM25, vector), got {len(weights)}"
        )

    texts = split_documents(docs)
    vs = create_vector_db(texts, embeddings)
    vs_retriever = vs.as_retriever()
    # from_documents (instead of from_texts on page_content) keeps each chunk's
    # metadata, so BM25 results carry the same metadata as vector-store results.
    bm25_retriever = BM25Retriever.from_documents(texts)
    return EnsembleRetriever(
        retrievers=[bm25_retriever, vs_retriever],
        weights=weights,
    )
def main():
    """Answer a sample question about Russell's *The Problems of Philosophy*.

    Steps:
        1. Load environment variables with ``load_dotenv``.
        2. Load the Project Gutenberg page for the book via ``load_web_page``.
        3. Build an ensemble (BM25 + vector) retriever over it.
        4. Run a RAG chain using the model from ``get_model("ChatGPT")``.
        5. Print the string answer.
    """
    load_dotenv()
    book_url = "https://www.gutenberg.org/ebooks/5827.html.images"
    retriever = ensemble_retriever_from_docs(load_web_page(book_url))

    # Pipe the RAG chain's output through a parser so invoke() yields plain text.
    rag_pipeline = (
        make_rag_chain(get_model("ChatGPT"), retriever)
        | StrOutputParser()
    )
    question = "What are the key problems of philosophy according to Russell?"
    print(rag_pipeline.invoke(question))
if __name__ == "__main__":
    # Quiet the tokenizers library's parallelism warning.
    # setdefault respects an explicit TOKENIZERS_PARALLELISM value from the
    # environment instead of overwriting it.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    main()