Spaces:
Sleeping
Sleeping
pminervini
committed on
Commit
•
6cbdb81
1
Parent(s):
3e8dc72
update
Browse files
app.py
CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
|
|
3 |
|
4 |
import torch
|
5 |
from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
|
|
|
|
|
6 |
from elasticsearch import Elasticsearch
|
7 |
|
8 |
|
@@ -16,22 +18,20 @@ def search(query, index="pubmed", num_docs=3):
|
|
16 |
Search the Elasticsearch index for the most relevant documents.
|
17 |
"""
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
"
|
24 |
-
"
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
print(f'Received {len(docs)} documents from index {index}')
|
35 |
|
36 |
return docs
|
37 |
|
@@ -67,25 +67,37 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
|
|
67 |
# Define the generation_kwargs with stopping criteria
|
68 |
generation_kwargs = {
|
69 |
"max_new_tokens": 128,
|
70 |
-
"generation_kwargs": {"stopping_criteria": stopping_criteria},
|
71 |
"return_full_text": False
|
72 |
}
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
# Return the generated text and the documents
|
80 |
-
return response
|
81 |
|
82 |
# Create the Gradio interface
|
83 |
iface = gr.Interface(fn=rag_pipeline,
|
84 |
inputs=[
|
85 |
gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
|
86 |
-
|
87 |
-
|
88 |
-
gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf"], value="HuggingFaceH4/zephyr-7b-beta")
|
89 |
],
|
90 |
outputs=[
|
91 |
gr.Textbox(label="Generated Text"),
|
|
|
3 |
|
4 |
import torch
|
5 |
from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
|
6 |
+
import openai
|
7 |
+
|
8 |
from elasticsearch import Elasticsearch
|
9 |
|
10 |
|
|
|
18 |
Search the Elasticsearch index for the most relevant documents.
|
19 |
"""
|
20 |
|
21 |
+
docs = []
|
22 |
+
if num_docs > 0:
|
23 |
+
print(f'Running query: {query}')
|
24 |
+
es_request_body = {
|
25 |
+
"query": {
|
26 |
+
"match": {
|
27 |
+
"content": query # Assuming documents have a 'content' field
|
28 |
+
}
|
29 |
+
}, "size": num_docs
|
30 |
+
}
|
31 |
+
response = es.options(request_timeout=60).search(index=index, body=es_request_body)
|
32 |
+
# Extract and return the documents
|
33 |
+
docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
|
34 |
+
print(f'Received {len(docs)} documents from index {index}')
|
|
|
|
|
35 |
|
36 |
return docs
|
37 |
|
|
|
67 |
# Define the generation_kwargs with stopping criteria
|
68 |
generation_kwargs = {
|
69 |
"max_new_tokens": 128,
|
70 |
+
# "generation_kwargs": {"stopping_criteria": stopping_criteria},
|
71 |
"return_full_text": False
|
72 |
}
|
73 |
|
74 |
+
if model_name.startswith('openai/'):
|
75 |
+
openai_model_name = model_name.split('/')[1]
|
76 |
+
openai_prompt = '\n\n'.join([m['content'] for m in messages])
|
77 |
+
|
78 |
+
openai_res = openai.Completion.create(model=openai_model_name,
|
79 |
+
prompt=openai_prompt,
|
80 |
+
max_tokens=generation_kwargs["max_new_tokens"],
|
81 |
+
n=1,
|
82 |
+
stop=None,
|
83 |
+
temperature=0)
|
84 |
+
response = openai_res.choices[0].text.strip()
|
85 |
+
else:
|
86 |
+
# Generate response using the HF LLM
|
87 |
+
hf_response = generator(messages, **generation_kwargs)
|
88 |
+
print('HF_RESPONSE', hf_response)
|
89 |
+
response = hf_response[0]['generated_text']
|
90 |
|
91 |
# Return the generated text and the documents
|
92 |
+
return response, joined_docs
|
93 |
|
94 |
# Create the Gradio interface
|
95 |
iface = gr.Interface(fn=rag_pipeline,
|
96 |
inputs=[
|
97 |
gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
|
98 |
+
gr.Dropdown(label="Index", choices=["pubmed", "wikipedia", "textbooks"], value="pubmed"),
|
99 |
+
gr.Number(label="Number of Documents", value=3, step=1, minimum=0, maximum=10),
|
100 |
+
gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", "openai/gpt-3.5-turbo"], value="HuggingFaceH4/zephyr-7b-beta")
|
101 |
],
|
102 |
outputs=[
|
103 |
gr.Textbox(label="Generated Text"),
|