sabazo committed on
Commit
8b3a6e5
·
1 Parent(s): 30148ce

added streaming output to Gradio

Browse files
Files changed (1) hide show
  1. app.py +9 -16
app.py CHANGED
@@ -6,7 +6,7 @@ from botocore.client import Config
6
  from langchain.document_loaders import WebBaseLoader
7
 
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=350, chunk_overlap=10)
10
 
11
  from langchain.llms import HuggingFaceHub
12
  model_id = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta", model_kwargs={"temperature":0.1, "max_new_tokens":1024})
@@ -17,25 +17,16 @@ embeddings = HuggingFaceHubEmbeddings()
17
  from langchain.vectorstores import Chroma
18
 
19
  from langchain.chains import RetrievalQA
20
- from langchain.chains import RetrievalQAWithSourcesChain
21
-
22
- from langchain.prompts import ChatPromptTemplate
23
-
24
- #web_links = ["https://www.databricks.com/","https://help.databricks.com","https://docs.databricks.com","https://kb.databricks.com/","http://docs.databricks.com/getting-started/index.html","http://docs.databricks.com/introduction/index.html","http://docs.databricks.com/getting-started/tutorials/index.html","http://docs.databricks.com/machine-learning/index.html","http://docs.databricks.com/sql/index.html"]
25
- #loader = WebBaseLoader(web_links)
26
- #documents = loader.load()
27
 
28
  s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
29
  s3.download_file('rad-rag-demos', 'vectorstores/chroma.sqlite3', './chroma_db/chroma.sqlite3')
30
 
31
  db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
32
  db.get()
33
- #texts = text_splitter.split_documents(documents)
34
- #db = Chroma.from_documents(texts, embedding_function=embeddings)
35
  retriever = db.as_retriever()
36
 
37
  global qa
38
- qa = RetrievalQAWithSourcesChain.from_chain_type(llm=model_id, chain_type="stuff", retriever=retriever)
39
 
40
 
41
  def add_text(history, text):
@@ -44,14 +35,16 @@ def add_text(history, text):
44
 
45
  def bot(history):
46
  response = infer(history[-1][0])
47
- history[-1][1] = response['result']
48
- return history
 
 
 
49
 
50
  def infer(question):
51
 
52
- #query = question
53
- #result = qa({"query": query})
54
- result = qa({"question": question})
55
  return result
56
 
57
  css="""
 
6
  from langchain.document_loaders import WebBaseLoader
7
 
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
10
 
11
  from langchain.llms import HuggingFaceHub
12
  model_id = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta", model_kwargs={"temperature":0.1, "max_new_tokens":1024})
 
17
  from langchain.vectorstores import Chroma
18
 
19
  from langchain.chains import RetrievalQA
 
 
 
 
 
 
 
20
 
21
  s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
22
  s3.download_file('rad-rag-demos', 'vectorstores/chroma.sqlite3', './chroma_db/chroma.sqlite3')
23
 
24
  db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
25
  db.get()
 
 
26
  retriever = db.as_retriever()
27
 
28
  global qa
29
+ qa = RetrievalQA.from_chain_type(llm=model_id, chain_type="stuff", retriever=retriever)
30
 
31
 
32
  def add_text(history, text):
 
35
 
36
  def bot(history):
37
  response = infer(history[-1][0])
38
+ history[-1][1] = ""
39
+ for character in response['result']:
40
+ history[-1][1] += character
41
+ time.sleep(0.05)
42
+ yield history
43
 
44
  def infer(question):
45
 
46
+ query = question
47
+ result = qa({"query": query})
 
48
  return result
49
 
50
  css="""