import os

import gradio as gr
import ray
import vllm
import torch

from transformers import (pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria,
                          AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, BitsAndBytesConfig)
from openai import OpenAI
from elasticsearch import Elasticsearch
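

# Stopping criterion for batched generation: a sequence is marked done once the given
# stop string (e.g. "\n") appears among its newly generated tokens, and generation stops
# when every sequence in the batch is done.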
class MultiTokenEOSCriteria(StoppingCriteria):
    def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, initial_decoder_input_length: int, batch_size: int = 1) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length:]
        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len:]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def search(query, index="pubmed", num_docs=3):
    """
    Search the Elasticsearch index for the most relevant documents.
    """
    docs = []
    if num_docs > 0:
        print(f'Running query: {query}')
        es_request_body = {
            "query": {
                "match": {
                    "content": query  # Assuming documents have a 'content' field
                }
            },
            "size": num_docs
        }
        # Connect to Elasticsearch
        es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
                           basic_auth=('elastic', os.environ['ES_PASSWORD']),
                           verify_certs=False, ssl_show_warn=False)
        response = es.options(request_timeout=60).search(index=index, body=es_request_body)
        # Extract and return the documents
        docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
        print(f'Received {len(docs)} documents from index {index}')
    return docs
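

# Answer the chat `messages` with the selected model: "openai/..." names are routed to the
# OpenAI chat completions API; any other name is loaded locally as a 4-bit quantised
# Hugging Face model and run through a text-generation pipeline.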
# NOTE: rag_pipeline() below calls generate.remote(...) and analyse.remote(...), so both
# functions must be declared as Ray tasks; num_gpus=1 is an assumed per-task GPU requirement.
@ray.remote(num_gpus=1)
def generate(model_name: str, messages):
    max_new_tokens = 1024

    if model_name.startswith('openai/'):
        openai_model_name = model_name.split('/')[1]
        client = OpenAI()
        openai_res = client.chat.completions.create(model=openai_model_name,
                                                    messages=messages,
                                                    max_tokens=max_new_tokens,
                                                    temperature=0)
        print('OAI_RESPONSE', openai_res)
        response = openai_res.choices[0].message.content.strip()
    else:
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", low_cpu_mem_usage=True, quantization_config=quantization_config)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load your language model from HuggingFace Transformers
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

        tokenized_prompt = tokenizer.apply_chat_template(messages, tokenize=True)

        # Define the stopping criteria using MaxTimeCriteria
        stopping_criteria = StoppingCriteriaList([
            # MaxTimeCriteria(32),
            MultiTokenEOSCriteria("\n", tokenizer, len(tokenized_prompt))
        ])

        # Define the generation_kwargs with stopping criteria
        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            # "stopping_criteria": stopping_criteria,
            "return_full_text": False
        }

        # Generate response using the HF LLM
        hf_response = generator(messages, **generation_kwargs)
        print('HF_RESPONSE', hf_response)
        response = hf_response[0]['generated_text'].strip()

    return response
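

# Post-hoc analysis of `passage` against the retrieved `reference` text using the FAVA
# editing model (fava-uw/fava-model); FAVA's edit tags are mapped to HTML spans so the
# highlighted result can be rendered in the Gradio HTML output.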
@ray.remote(num_gpus=1)  # also a Ray task (see note above); the GPU requirement is an assumption
def analyse(reference: str, passage: str) -> str:
    fava_input = "Read the following references:\n{evidence}\nPlease identify all the errors in the following text using the information in the references provided and suggest edits if necessary:\n[Text] {output}\n[Edited] "
    prompt = [fava_input.format_map({"evidence": reference, "output": passage})]

    model = vllm.LLM(model="fava-uw/fava-model")
    sampling_params = vllm.SamplingParams(temperature=0, top_p=1.0, max_tokens=500)
    outputs = model.generate(prompt, sampling_params)
    outputs = [it.outputs[0].text for it in outputs]

    # Map FAVA's edit tags to HTML for display
    output = outputs[0].replace("<mark>", "<span style='color: green; font-weight: bold;'> ")
    output = output.replace("</mark>", " </span>")
    output = output.replace("<delete>", "<span style='color: red; text-decoration: line-through;'>")
    output = output.replace("</delete>", "</span>")
    output = output.replace("<entity>", "<span style='background-color: #E9A2D9; border-bottom: 1px dotted;'>entity</span>")
    output = output.replace("<relation>", "<span style='background-color: #F3B78B; border-bottom: 1px dotted;'>relation</span>")
    output = output.replace("<contradictory>", "<span style='background-color: #FFFF9B; border-bottom: 1px dotted;'>contradictory</span>")
    output = output.replace("<unverifiable>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>unverifiable</span><u>")
    output = output.replace("<invented>", "<span style='background-color: #BFE9B9; border-bottom: 1px dotted;'>invented</span>")
    output = output.replace("<subjective>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>subjective</span><u>")
    output = output.replace("</entity>", "")
    output = output.replace("</relation>", "")
    output = output.replace("</contradictory>", "")
    output = output.replace("</unverifiable>", "</u>")
    output = output.replace("</invented>", "")
    output = output.replace("</subjective>", "</u>")
    output = output.replace("Edited:", "")

    return f'<div style="font-weight: normal;">{output}</div>'


def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/zephyr-7b-beta"):
    """
    A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
    """
    num_docs = int(num_docs)

    # Retrieve documents
    docs = search(prompt, index=index, num_docs=num_docs)
    joined_docs = '\n\n'.join(docs)

    messages = [
        {
            # Please append a newline only when you have finished answering.
            "role": "system",
            "content": f"You are an advanced medical support assistant, designed to help clinicians by providing quick access to medical information, guidelines, and evidence-based recommendations. Alongside your built-in knowledge, you have access to a curated set of documents retrieved from trustworthy sources such as Wikipedia and PubMed. These documents include up-to-date medical guidelines, research summaries, and clinical practice information. You should use these documents as a primary source of information to ensure your responses are based on the most current and credible evidence available. Your responses should be accurate, concise, and in full compliance with medical ethics. You must always remind users that your guidance does not substitute for professional medical advice, diagnosis, or treatment. Your tone should be professional, supportive, and respectful, recognizing the complexity of healthcare decisions and the importance of personalized patient care. While you can offer information and suggestions based on the documents provided and current medical knowledge, you must emphasize the importance of clinicians' expertise and judgment in making clinical decisions.\n\nRetrieved documents from {index}:\n\n{joined_docs}"
        }, {
            "role": "user",
            "content": prompt
        }
    ]

    response = ray.get(generate.remote(model_name, messages))

    # analysed_response = ray.get(analyse.remote(joined_docs, response))
    analysed_response = ray.get(analyse.remote(joined_docs, prompt))

    # Return the generated text and the documents
    return analysed_response, response, joined_docs


# Create the Gradio interface
iface = gr.Interface(fn=rag_pipeline,
                     inputs=[
                         gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
                         gr.Dropdown(label="Index", choices=["pubmed", "wikipedia", "textbooks", "statpearls"], value="pubmed"),
                         gr.Number(label="Number of Documents", value=3, step=1, minimum=0, maximum=10),
                         gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", "openai/gpt-3.5-turbo"], value="HuggingFaceH4/zephyr-7b-beta")
                     ],
                     outputs=[
                         gr.HTML(label="Analysed Answer"),
                         gr.Textbox(label="Generated Answer"),
                         # gr.Textbox(label="Analysed Answer"),
                         gr.Textbox(label="Retrieved Documents")
                     ],
                     description="Retrieval-Augmented Generation Pipeline")

# Launch the interface
iface.launch()