Spaces:
Sleeping
Sleeping
File size: 9,810 Bytes
858ef78 482d0d4 5f90f73 482d0d4 238b842 28875eb 6cbdb81 858ef78 8400fbd 818a1c7 4cf3bf5 858ef78 6cbdb81 4467ed0 6cbdb81 858ef78 5f90f73 78b0064 5f90f73 eb1a71a 5f90f73 c4c0ebe 4467ed0 c4c0ebe 4467ed0 4cf3bf5 858ef78 4cf3bf5 858ef78 4cf3bf5 be47723 858ef78 be47723 78b0064 529ee77 78b0064 eaec240 be47723 a744ad5 eb1a71a 4467ed0 858ef78 eb1a71a 858ef78 bbe5900 7aff13a 6cbdb81 858ef78 eb1a71a 4467ed0 f064e36 858ef78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import os
import gradio as gr
import ray
import vllm
import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, BitsAndBytesConfig
from openai import OpenAI
from elasticsearch import Elasticsearch
class MultiTokenEOSCriteria(StoppingCriteria):
    """Stopping criterion that fires once a stop string appears in every batch row.

    The stop string is matched against the *decoded* tail of each generated
    sequence. This means it is found even when the tokenizer splits it over
    several tokens. A row counts as finished from the first step at which
    the string is seen. Generation stops when all rows are finished.
    """

    def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, initial_decoder_input_length: int, batch_size: int = 1) -> None:
        # Prompt length in tokens; positions from here on are generated text.
        self.initial_decoder_input_length = initial_decoder_input_length
        self.sequence = sequence
        self.tokenizer = tokenizer
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # Look back two tokens more than the stop string's own encoding.
        # Presumably this tolerates the model producing the string with a
        # different token split (heuristic slack; TODO confirm).
        self.sequence_id_len = len(self.sequence_ids) + 2
        # One sticky flag per batch row: once True it is never re-evaluated.
        self.done_tracker = [False] * batch_size

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True once every batch row has produced the stop string.

        Only the generated suffix of ``input_ids`` (a 2-D batch x positions
        tensor, given the ``[:, ...]`` indexing) is examined, and only its
        last ``sequence_id_len`` tokens. This keeps decoding cheap.
        """
        generated = input_ids[:, self.initial_decoder_input_length :]
        tail = generated[:, -self.sequence_id_len :]
        decoded_tails = self.tokenizer.batch_decode(tail)
        for row, finished in enumerate(self.done_tracker):
            if finished:
                continue
            self.done_tracker[row] = self.sequence in decoded_tails[row]
        return all(self.done_tracker)
def search(query, index="pubmed", num_docs=3):
    """Search the Elasticsearch index for the documents most relevant to ``query``.

    Runs a ``match`` query on the ``content`` field of ``index`` and returns
    the ``content`` of the top ``num_docs`` hits.

    Args:
        query: Free-text query string.
        index: Name of the Elasticsearch index to search.
        num_docs: Maximum number of documents to return. When it is not
            positive, no request is made and an empty list is returned.

    Returns:
        A list of document ``content`` strings, in the order returned by
        Elasticsearch. Hits whose ``_source`` lacks a ``content`` field are
        skipped rather than aborting the whole request.

    Raises:
        KeyError: if the ``ES_PASSWORD`` environment variable is not set.
    """
    if num_docs <= 0:
        return []
    print(f'Running query: {query}')
    # NOTE(review): certificate verification is disabled, so this connection
    # (and the password sent over it) is exposed to man-in-the-middle attacks.
    # It is kept unchanged so the existing deployment keeps working. Consider
    # passing ``ca_certs=`` with the server's CA and enabling verification.
    es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
                       basic_auth=('elastic', os.environ['ES_PASSWORD']),
                       verify_certs=False, ssl_show_warn=False)
    try:
        response = es.options(request_timeout=60).search(
            index=index,
            # Documents are assumed to store their text in a 'content' field.
            query={"match": {"content": query}},
            size=num_docs,
        )
    finally:
        # A fresh client is built per call; close it so its connection pool
        # is released instead of leaking.
        es.close()
    docs = [
        hit["_source"]["content"]
        for hit in response["hits"]["hits"]
        if "content" in hit.get("_source", {})
    ]
    print(f'Received {len(docs)} documents from index {index}')
    return docs
@ray.remote(num_gpus=1, max_calls=1)
def generate(model_name: str, messages):
    """Generate a chat reply to ``messages`` with the selected model.

    Runs as a Ray task that reserves one GPU. ``max_calls=1`` limits each Ray
    worker to a single execution of this function, so the model loaded
    here is not kept around between requests.

    Args:
        model_name: Either ``"openai/<model>"``, which calls the OpenAI chat
            completions API, or a Hugging Face Hub model id, which is loaded
            locally with 4-bit bitsandbytes quantization.
        messages: Chat messages as a list of ``{"role": ..., "content": ...}``
            dicts.

    Returns:
        The generated reply text with surrounding whitespace stripped.
    """
    max_new_tokens = 1024
    if model_name.startswith('openai/'):
        # Keep everything after the first '/', so ids that contain '/' survive.
        openai_model_name = model_name.split('/', 1)[1]
        client = OpenAI()
        openai_res = client.chat.completions.create(model=openai_model_name,
                                                    messages=messages,
                                                    max_tokens=max_new_tokens,
                                                    temperature=0)
        print('OAI_RESPONSE', openai_res)
        # ``content`` may be None (e.g. on a refusal); treat that as empty text.
        return (openai_res.choices[0].message.content or "").strip()

    quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", low_cpu_mem_usage=True, quantization_config=quantization_config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
    # The chat-formatted messages are handed to the pipeline directly.
    # return_full_text=False returns only the newly generated text.
    # To stop at the first newline, pass
    # stopping_criteria=StoppingCriteriaList([MultiTokenEOSCriteria("\n", tokenizer, <prompt length>)]);
    # it is intentionally not used here.
    hf_response = generator(messages, max_new_tokens=max_new_tokens, return_full_text=False)
    print('HF_RESPONSE', hf_response)
    return hf_response[0]['generated_text'].strip()
# Ordered (FAVA tag, HTML replacement) pairs used to render FAVA's edit markup.
# None of the replacement strings contains a later tag, so each pair is applied
# independently. "Edited:" is stripped last.
_FAVA_TAG_HTML = (
    ("<mark>", "<span style='color: green; font-weight: bold;'> "),
    ("</mark>", " </span>"),
    ("<delete>", "<span style='color: red; text-decoration: line-through;'>"),
    ("</delete>", "</span>"),
    ("<entity>", "<span style='background-color: #E9A2D9; border-bottom: 1px dotted;'>entity</span>"),
    ("<relation>", "<span style='background-color: #F3B78B; border-bottom: 1px dotted;'>relation</span>"),
    ("<contradictory>", "<span style='background-color: #FFFF9B; border-bottom: 1px dotted;'>contradictory</span>"),
    ("<unverifiable>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>unverifiable</span><u>"),
    ("<invented>", "<span style='background-color: #BFE9B9; border-bottom: 1px dotted;'>invented</span>"),
    ("<subjective>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>subjective</span><u>"),
    ("</entity>", ""),
    ("</relation>", ""),
    ("</contradictory>", ""),
    ("</unverifiable>", "</u>"),
    ("</invented>", ""),
    ("</subjective>", "</u>"),
    ("Edited:", ""),
)


def _fava_markup_to_html(text):
    """Convert FAVA's tagged output text into HTML that is safe to render.

    The text is HTML-escaped first. Any markup the model copied from its
    inputs (user prompt, retrieved documents) is therefore displayed
    literally instead of being rendered or executed by the ``gr.HTML``
    output. Only the escaped FAVA tags are then turned into styled spans.
    """
    import html

    rendered = html.escape(text, quote=False)
    for tag, markup in _FAVA_TAG_HTML:
        rendered = rendered.replace(html.escape(tag, quote=False), markup)
    return rendered


@ray.remote(num_gpus=1, max_calls=1)
def analyse(reference: str, passage: str) -> str:
    """Use the FAVA model to flag and correct errors in ``passage`` against ``reference``.

    Runs as a Ray task that reserves one GPU and loads ``fava-uw/fava-model``
    with vLLM. Decoding is greedy (temperature 0) and capped at 500 new tokens.

    Args:
        reference: Evidence text (the retrieved documents) used for checking.
        passage: Text to check.

    Returns:
        An HTML ``<div>`` containing FAVA's edited passage. Insertions are
        green, deletions are struck through, error types are shown as coloured
        labels, and any other HTML in the model output is escaped.
    """
    fava_input = "Read the following references:\n{evidence}\nPlease identify all the errors in the following text using the information in the references provided and suggest edits if necessary:\n[Text] {output}\n[Edited] "
    prompt = [fava_input.format_map({"evidence": reference, "output": passage})]
    model = vllm.LLM(model="fava-uw/fava-model")
    sampling_params = vllm.SamplingParams(temperature=0, top_p=1.0, max_tokens=500)
    outputs = model.generate(prompt, sampling_params)
    # A single prompt was submitted, so only the first result is relevant.
    output = _fava_markup_to_html(outputs[0].outputs[0].text)
    return f'<div style="font-weight: normal;">{output}</div>'
def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/zephyr-7b-beta"):
    """A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.

    Steps:
    1. Retrieve up to ``num_docs`` documents from ``index``.
    2. Put them in the system prompt and generate an answer with ``model_name``.
    3. Run FAVA over the generated answer, using the retrieved documents as
       evidence.

    Args:
        prompt: The user's question.
        index: Elasticsearch index to retrieve from.
        num_docs: Number of documents to retrieve. ``None`` (e.g. a cleared
            number field) is treated as 0.
        model_name: Model id passed through to ``generate``.

    Returns:
        A tuple ``(analysed_response_html, generated_answer, joined_docs)``.
    """
    num_docs = int(num_docs) if num_docs is not None else 0
    docs = search(prompt, index=index, num_docs=num_docs)
    joined_docs = '\n\n'.join(docs)
    messages = [
        {
            "role": "system",
            "content": f"You are an advanced medical support assistant, designed to help clinicians by providing quick access to medical information, guidelines, and evidence-based recommendations. Alongside your built-in knowledge, you have access to a curated set of documents retrieved from trustworthy sources such as Wikipedia and PubMed. These documents include up-to-date medical guidelines, research summaries, and clinical practice information. You should use these documents as a primary source of information to ensure your responses are based on the most current and credible evidence available. Your responses should be accurate, concise, and in full compliance with medical ethics. You must always remind users that your guidance does not substitute for professional medical advice, diagnosis, or treatment. Your tone should be professional, supportive, and respectful, recognizing the complexity of healthcare decisions and the importance of personalized patient care. While you can offer information and suggestions based on the documents provided and current medical knowledge, you must emphasize the importance of clinicians' expertise and judgment in making clinical decisions.\n\nRetrieved documents from {index}:\n\n{joined_docs}"
        }, {
            "role": "user",
            "content": prompt
        }
    ]
    response = ray.get(generate.remote(model_name, messages))
    # Check the generated *answer* (not the user's question) against the
    # retrieved evidence; this is what the "Analysed Answer" output shows.
    analysed_response = ray.get(analyse.remote(joined_docs, response))
    return analysed_response, response, joined_docs
# Gradio UI. The four input components map positionally onto rag_pipeline's
# (prompt, index, num_docs, model_name) parameters. The three output
# components receive its (analysed_response, response, joined_docs) tuple.
_INDEX_CHOICES = ["pubmed", "wikipedia", "textbooks", "statpearls"]
_MODEL_CHOICES = [
    "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "openai/gpt-3.5-turbo",
]

_input_components = [
    gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
    gr.Dropdown(label="Index", choices=_INDEX_CHOICES, value=_INDEX_CHOICES[0]),
    gr.Number(label="Number of Documents", value=3, step=1, minimum=0, maximum=10),
    gr.Dropdown(label="Model", choices=_MODEL_CHOICES, value=_MODEL_CHOICES[0]),
]

_output_components = [
    gr.HTML(label="Analysed Answer"),
    gr.Textbox(label="Generated Answer"),
    gr.Textbox(label="Retrieved Documents"),
]

iface = gr.Interface(
    fn=rag_pipeline,
    inputs=_input_components,
    outputs=_output_components,
    description="Retrieval-Augmented Generation Pipeline",
)

# Start the web app.
iface.launch()