chahah commited on
Commit
df96d22
·
verified ·
1 Parent(s): c7748d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://python.langchain.com/docs/tutorials/rag/
2
+ import gradio as gr
3
+ from langchain import hub
4
+ from langchain_chroma import Chroma
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.runnables import RunnablePassthrough
7
+ from langchain_mistralai import MistralAIEmbeddings
8
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from langchain_mistralai import ChatMistralAI
11
+ from langchain_community.document_loaders import PyPDFLoader
12
+ import requests
13
+ from pathlib import Path
14
+ from langchain_community.document_loaders import WebBaseLoader
15
+ import bs4
16
+ from langchain_core.rate_limiters import InMemoryRateLimiter
17
+ from urllib.parse import urljoin
18
+
19
+ rate_limiter = InMemoryRateLimiter(
20
+ requests_per_second=0.1, # <-- MistralAI free. We can only make a request once every second
21
+ check_every_n_seconds=0.01, # Wake up every 100 ms to check whether allowed to make a request,
22
+ max_bucket_size=10, # Controls the maximum burst size.
23
+ )
24
+
25
+ # get data
26
+ urlsfile = open("urls.txt")
27
+ urls = urlsfile.readlines()
28
+ urls = [url.replace("\n","") for url in urls]
29
+ urlsfile.close()
30
+
31
+ # Load, chunk and index the contents of the blog.
32
+ loader = WebBaseLoader(urls)
33
+ docs = loader.load()
34
+
35
+ def format_docs(docs):
36
+ return "\n\n".join(doc.page_content for doc in docs)
37
+
38
+ def RAG(llm, docs, embeddings):
39
+
40
+ # Split text
41
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
42
+ splits = text_splitter.split_documents(docs)
43
+
44
+ # Create vector store
45
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
46
+
47
+ # Retrieve and generate using the relevant snippets of the documents
48
+ retriever = vectorstore.as_retriever()
49
+
50
+ # Prompt basis example for RAG systems
51
+ prompt = hub.pull("rlm/rag-prompt")
52
+
53
+ # Create the chain
54
+ rag_chain = (
55
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
56
+ | prompt
57
+ | llm
58
+ | StrOutputParser()
59
+ )
60
+
61
+ return rag_chain
62
+
63
+ # LLM model
64
+ llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
65
+
66
+ # Embeddings
67
+ embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
68
+ # embed_model = "nvidia/NV-Embed-v2"
69
+ embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
70
+ # embeddings = MistralAIEmbeddings()
71
+
72
+ # RAG chain
73
+ rag_chain = RAG(llm, docs, embeddings)
74
+
75
+ def handle_prompt(message, history):
76
+ try:
77
+ # Stream output
78
+ out=""
79
+ for chunk in rag_chain.stream(message):
80
+ out += chunk
81
+ yield out
82
+ except:
83
+ raise gr.Error("Requests rate limit exceeded")
84
+
85
+ greetingsmessage = "Hi, I'm ChangBot, a chat bot here to assist you with any question related to Chang's research"
86
+ example_questions = [
87
+ "What is the DESI BGS?",
88
+ "What is Quijote?",
89
+ "What is a galaxy bispectrum?",
90
+ "Tell me more about SimBIG"
91
+ ]
92
+
93
+ demo = gr.ChatInterface(handle_prompt, type="messages", title="ChangBot", examples=example_questions, theme=gr.themes.Soft(), description=greetingsmessage)#, chatbot=chatbot)
94
+
95
+ demo.launch()