pminervini committed on
Commit 4467ed0
1 Parent(s): 3e44f8a
Files changed (1)
  1. app.py +38 -7
app.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import gradio as gr
 
+import vllm
+
 import torch
 from transformers import pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer
 from openai import OpenAI
@@ -28,11 +30,6 @@ class MultiTokenEOSCriteria(StoppingCriteria):
         return False not in self.done_tracker
 
 
-# Connect to Elasticsearch
-es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
-                   basic_auth=('elastic', os.environ['ES_PASSWORD']),
-                   verify_certs=False, ssl_show_warn=False)
-
 def search(query, index="pubmed", num_docs=3):
     """
     Search the Elasticsearch index for the most relevant documents.
@@ -48,6 +45,12 @@ def search(query, index="pubmed", num_docs=3):
             }
         }, "size": num_docs
     }
+
+    # Connect to Elasticsearch
+    es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
+                       basic_auth=('elastic', os.environ['ES_PASSWORD']),
+                       verify_certs=False, ssl_show_warn=False)
+
     response = es.options(request_timeout=60).search(index=index, body=es_request_body)
     # Extract and return the documents
     docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
@@ -55,6 +58,31 @@ def search(query, index="pubmed", num_docs=3):
 
     return docs
 
+def analyse(text: str) -> str:
+    model = vllm.LLM(model="fava-uw/fava-model")
+    sampling_params = vllm.SamplingParams(temperature=0, top_p=1.0, max_tokens=500)
+    outputs = model.generate(text, sampling_params)
+    outputs = [it.outputs[0].text for it in outputs]
+    output = outputs[0].replace("<mark>", "<span style='color: green; font-weight: bold;'> ")
+    output = output.replace("</mark>", " </span>")
+    output = output.replace("<delete>", "<span style='color: red; text-decoration: line-through;'>")
+    output = output.replace("</delete>", "</span>")
+    output = output.replace("<entity>", "<span style='background-color: #E9A2D9; border-bottom: 1px dotted;'>entity</span>")
+    output = output.replace("<relation>", "<span style='background-color: #F3B78B; border-bottom: 1px dotted;'>relation</span>")
+    output = output.replace("<contradictory>", "<span style='background-color: #FFFF9B; border-bottom: 1px dotted;'>contradictory</span>")
+    output = output.replace("<unverifiable>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>unverifiable</span><u>")
+    output = output.replace("<invented>", "<span style='background-color: #BFE9B9; border-bottom: 1px dotted;'>invented</span>")
+    output = output.replace("<subjective>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>subjective</span><u>")
+    output = output.replace("</entity>", "")
+    output = output.replace("</relation>", "")
+    output = output.replace("</contradictory>", "")
+    output = output.replace("</unverifiable>", "</u>")
+    output = output.replace("</invented>", "")
+    output = output.replace("</subjective>", "</u>")
+    output = output.replace("Edited:", "")
+    return f'<div style="font-weight: normal;">{output}</div>'
+
+
 def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/zephyr-7b-beta"):
     """
     A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
@@ -118,8 +146,10 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
     print('HF_RESPONSE', hf_response)
     response = hf_response[0]['generated_text']
 
+    analysed_response = analyse(response)
+
     # Return the generated text and the documents
-    return response, joined_docs
+    return response, analysed_response, joined_docs
 
 # Create the Gradio interface
 iface = gr.Interface(fn=rag_pipeline,
@@ -130,7 +160,8 @@ iface = gr.Interface(fn=rag_pipeline,
         gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", "openai/gpt-3.5-turbo"], value="HuggingFaceH4/zephyr-7b-beta")
     ],
     outputs=[
-        gr.Textbox(label="Generated Text"),
+        gr.Textbox(label="Generated Answer"),
+        gr.Textbox(label="Analysed Answer"),
         gr.Textbox(label="Retrieved Documents")
     ],
     description="Retrieval-Augmented Generation Pipeline")
 
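
For reference, the new analyse function renders FAVA's editing markup (<mark>, <delete>, and the six error-type tags) as inline-styled HTML through a chain of str.replace calls. A minimal, table-driven sketch of the same mapping follows; the tag-to-HTML pairs are copied from the commit, but FAVA_TAG_HTML and render_fava_markup are illustrative names, not part of this commit.

# Illustrative, dict-driven equivalent of the chained str.replace calls in
# analyse(); the mapping mirrors the commit, the helper names are hypothetical.
FAVA_TAG_HTML = {
    "<mark>": "<span style='color: green; font-weight: bold;'> ",
    "</mark>": " </span>",
    "<delete>": "<span style='color: red; text-decoration: line-through;'>",
    "</delete>": "</span>",
    "<entity>": "<span style='background-color: #E9A2D9; border-bottom: 1px dotted;'>entity</span>",
    "</entity>": "",
    "<relation>": "<span style='background-color: #F3B78B; border-bottom: 1px dotted;'>relation</span>",
    "</relation>": "",
    "<contradictory>": "<span style='background-color: #FFFF9B; border-bottom: 1px dotted;'>contradictory</span>",
    "</contradictory>": "",
    "<unverifiable>": "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>unverifiable</span><u>",
    "</unverifiable>": "</u>",
    "<invented>": "<span style='background-color: #BFE9B9; border-bottom: 1px dotted;'>invented</span>",
    "</invented>": "",
    "<subjective>": "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>subjective</span><u>",
    "</subjective>": "</u>",
    "Edited:": "",
}

def render_fava_markup(text: str) -> str:
    # Replace each FAVA tag with its HTML rendering, in insertion order.
    for tag, html in FAVA_TAG_HTML.items():
        text = text.replace(tag, html)
    return f'<div style="font-weight: normal;">{text}</div>'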
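After this change, rag_pipeline's three return values line up one-to-one with the interface's three output components. A hypothetical smoke test outside Gradio (the prompt is illustrative only, and the launch call assumes the usual Gradio entry point, which sits outside the hunks shown):

# Hypothetical usage, not part of the commit: rag_pipeline now returns
# (generated answer, FAVA-analysed answer, retrieved documents).
answer, analysed_answer, retrieved = rag_pipeline(
    "What is the first-line treatment for type 2 diabetes?",  # illustrative prompt
    index="pubmed",
    num_docs=3,
    model_name="HuggingFaceH4/zephyr-7b-beta",
)
print(analysed_answer)  # HTML with FAVA's edits highlighted

iface.launch()  # standard Gradio entry point; not visible in the hunks above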