pminervini committed on
Commit
6cbdb81
1 Parent(s): 3e8dc72
Files changed (1)
  1. app.py +37 -25
app.py CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
 
 import torch
 from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
+import openai
+
 from elasticsearch import Elasticsearch
 
 
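Note: the new `import openai` targets the legacy (pre-1.0) Python client, whose module-level `Completion` API is used further down. That client needs an API key; a minimal setup sketch (the explicit env-var lookup is an assumption, not part of this commit):

```python
import os
import openai

# Legacy (pre-1.0) openai client: the module reads OPENAI_API_KEY from the
# environment by default; setting it explicitly on the module is equivalent.
openai.api_key = os.environ.get("OPENAI_API_KEY")
```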
@@ -16,22 +18,20 @@ def search(query, index="pubmed", num_docs=3):
     Search the Elasticsearch index for the most relevant documents.
     """
 
-    print(f'Running query: {query}')
-
-    es_request_body = {
-        "query": {
-            "match": {
-                "content": query  # Assuming documents have a 'content' field
-            }
-        }, "size": num_docs
-    }
-
-    response = es.options(request_timeout=60).search(index=index, body=es_request_body)
-
-    # Extract and return the documents
-    docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
-
-    print(f'Received {len(docs)} documents from index {index}')
+    docs = []
+    if num_docs > 0:
+        print(f'Running query: {query}')
+        es_request_body = {
+            "query": {
+                "match": {
+                    "content": query  # Assuming documents have a 'content' field
+                }
+            }, "size": num_docs
+        }
+        response = es.options(request_timeout=60).search(index=index, body=es_request_body)
+        # Extract and return the documents
+        docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
+        print(f'Received {len(docs)} documents from index {index}')
 
     return docs
 
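The `num_docs > 0` guard makes retrieval optional: when zero documents are requested, the Elasticsearch round-trip is skipped entirely and `search` returns an empty list, so the pipeline degrades to a plain, retrieval-free LLM call. A quick sketch of that contract, using the names from the hunk above:

```python
# With num_docs=0, no Elasticsearch request is issued and the result is empty.
docs = search("Are ILC2s increased in nasal polyps?", index="pubmed", num_docs=0)
assert docs == []

# With num_docs > 0, the match query on the 'content' field runs as before.
docs = search("Are ILC2s increased in nasal polyps?", index="pubmed", num_docs=3)
assert len(docs) <= 3
```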
@@ -67,25 +67,37 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
     # Define the generation_kwargs with stopping criteria
     generation_kwargs = {
         "max_new_tokens": 128,
-        "generation_kwargs": {"stopping_criteria": stopping_criteria},
+        # "generation_kwargs": {"stopping_criteria": stopping_criteria},
         "return_full_text": False
     }
 
-    # Generate response using the LLM
-    response = generator(messages, **generation_kwargs)
-
-    print('RESPONSE', response)
+    if model_name.startswith('openai/'):
+        openai_model_name = model_name.split('/')[1]
+        openai_prompt = '\n\n'.join([m['content'] for m in messages])
+
+        openai_res = openai.Completion.create(model=openai_model_name,
+                                              prompt=openai_prompt,
+                                              max_tokens=generation_kwargs["max_new_tokens"],
+                                              n=1,
+                                              stop=None,
+                                              temperature=0)
+        response = openai_res.choices[0].text.strip()
+    else:
+        # Generate response using the HF LLM
+        hf_response = generator(messages, **generation_kwargs)
+        print('HF_RESPONSE', hf_response)
+        response = hf_response[0]['generated_text']
 
     # Return the generated text and the documents
-    return response[0]['generated_text'], joined_docs
+    return response, joined_docs
 
 # Create the Gradio interface
 iface = gr.Interface(fn=rag_pipeline,
                      inputs=[
                          gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
-                         gr.Dropdown(label="Index", choices=["pubmed", "wikipedia", "textbooks"], value="pubmed"),
-                         gr.Number(label="Number of Documents", value=3, step=1, minimum=1, maximum=10),
-                         gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf"], value="HuggingFaceH4/zephyr-7b-beta")
+                         gr.Dropdown(label="Index", choices=["pubmed", "wikipedia", "textbooks"], value="pubmed"),
+                         gr.Number(label="Number of Documents", value=3, step=1, minimum=0, maximum=10),
+                         gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", "openai/gpt-3.5-turbo"], value="HuggingFaceH4/zephyr-7b-beta")
                      ],
                      outputs=[
                          gr.Textbox(label="Generated Text"),
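Two things are worth noting in this hunk. First, the `stopping_criteria` entry (built from the still-imported `MaxTimeCriteria`) is now commented out, so only `max_new_tokens` bounds HF generation. Second, in the legacy client `openai.Completion.create` calls the v1/completions endpoint, which rejects chat-only models such as the `gpt-3.5-turbo` option added below; the chat endpoint would be the safer target. A sketch of that variant, assuming `messages` already follows the `{"role": ..., "content": ...}` chat schema (the diff only confirms each entry has a `'content'` key):

```python
# Sketch only (not what this commit does): legacy chat endpoint for
# chat-only models such as gpt-3.5-turbo.
openai_res = openai.ChatCompletion.create(model=openai_model_name,
                                          messages=messages,
                                          max_tokens=generation_kwargs["max_new_tokens"],
                                          n=1,
                                          temperature=0)
response = openai_res.choices[0].message.content.strip()
```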
 
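Taken together, the commit makes the backend selectable and the retrieval step optional through the same `rag_pipeline` entry point. Illustrative calls, as a sketch: the argument names come from the signature in the hunk header above, and the returned tuple shape from the new `return response, joined_docs`:

```python
# HF backend with retrieval enabled (the defaults).
text, docs = rag_pipeline("Are ILC2s increased in nasal polyps?",
                          index="pubmed", num_docs=3,
                          model_name="HuggingFaceH4/zephyr-7b-beta")

# OpenAI backend with retrieval switched off via the new num_docs=0 path.
text, docs = rag_pipeline("Are ILC2s increased in nasal polyps?",
                          index="pubmed", num_docs=0,
                          model_name="openai/gpt-3.5-turbo")
```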