import os
import gradio as gr

import ray
import vllm

import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, BitsAndBytesConfig
from openai import OpenAI

from elasticsearch import Elasticsearch


class MultiTokenEOSCriteria(StoppingCriteria):
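    """Stopping criterion that ends generation once every sequence in the
    batch has emitted the stop string `sequence` in its newly generated text."""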
    def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, initial_decoder_input_length: int, batch_size: int = 1) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
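        # Look back a couple of extra tokens, since the stop string can tokenize
        # differently depending on the tokens that precede it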
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def search(query, index="pubmed", num_docs=3):
    """
    Search the Elasticsearch index for the most relevant documents.
    """
    
    docs = []
    if num_docs > 0:
        print(f'Running query: {query}')
        es_request_body = {
            "query": {
                "match": {
                    "content": query  # Assuming documents have a 'content' field
                }
            }, "size": num_docs
        }

        # Connect to Elasticsearch (demo endpoint: TLS verification is disabled,
        # and the password is read from the ES_PASSWORD environment variable)
        es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
                           basic_auth=('elastic', os.environ['ES_PASSWORD']),
                           verify_certs=False, ssl_show_warn=False)

        response = es.options(request_timeout=60).search(index=index, body=es_request_body)
        # Extract and return the documents
        docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
        print(f'Received {len(docs)} documents from index {index}')

    return docs
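
# Example usage (assumes ES_PASSWORD is set and the "pubmed" index is reachable):
#   docs = search("group 2 innate lymphoid cells nasal polyps", num_docs=3)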

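# num_gpus=1 reserves a GPU for each call; max_calls=1 shuts the worker process
# down after every invocation, so GPU memory is released between requests.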
@ray.remote(num_gpus=1, max_calls=1)
def generate(model_name: str, messages):
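    """
    Generate an answer to `messages`, either through the OpenAI API (model names
    prefixed with 'openai/') or with a locally loaded, 4-bit-quantized
    Hugging Face model.
    """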
    max_new_tokens = 1024

    if model_name.startswith('openai/'):
        openai_model_name = model_name.split('/')[1]

        client = OpenAI()
        openai_res = client.chat.completions.create(model=openai_model_name,
                                                    messages=messages,
                                                    max_tokens=max_new_tokens,
                                                    temperature=0)
        print('OAI_RESPONSE', openai_res)
        response = openai_res.choices[0].message.content.strip()
    else:
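        # Quantize the model to 4 bits with bitsandbytes so it fits on a single
        # GPU; matrix multiplications are computed in bfloat16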
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", low_cpu_mem_usage=True, quantization_config=quantization_config)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Wrap the quantized model and tokenizer in a text-generation pipeline
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

        tokenized_prompt = tokenizer.apply_chat_template(messages, tokenize=True)

        # Stop generation once the model emits a newline (MaxTimeCriteria is left disabled)
        stopping_criteria = StoppingCriteriaList([
            # MaxTimeCriteria(32),
            MultiTokenEOSCriteria("\n", tokenizer, len(tokenized_prompt))
        ])

        # Generation settings; the stopping criteria above are currently disabled
        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            # "stopping_criteria": stopping_criteria,
            "return_full_text": False
        }

        # Generate response using the HF LLM
        hf_response = generator(messages, **generation_kwargs)

        print('HF_RESPONSE', hf_response)
        response = hf_response[0]['generated_text'].strip()
    return response

@ray.remote(num_gpus=1, max_calls=1)
def analyse(reference: str, passage: str) -> str:
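    """
    Detect and mark up likely errors in `passage` with the FAVA editing model
    (fava-uw/fava-model), using `reference` as evidence. FAVA tags spans with
    error types (entity, relation, contradictory, unverifiable, invented,
    subjective); the tags are converted to styled HTML below.
    """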
    fava_input = "Read the following references:\n{evidence}\nPlease identify all the errors in the following text using the information in the references provided and suggest edits if necessary:\n[Text] {output}\n[Edited] "
    prompt = [fava_input.format_map({"evidence": reference, "output": passage})]

    model = vllm.LLM(model="fava-uw/fava-model") 
    sampling_params = vllm.SamplingParams(temperature=0, top_p=1.0, max_tokens=500)
    outputs = model.generate(prompt, sampling_params)
    outputs = [it.outputs[0].text for it in outputs]
    output = outputs[0]
    # Map FAVA's edit tags onto styled HTML spans
    tag_to_html = {
        "<mark>": "<span style='color: green; font-weight: bold;'> ",
        "</mark>": " </span>",
        "<delete>": "<span style='color: red; text-decoration: line-through;'>",
        "</delete>": "</span>",
        "<entity>": "<span style='background-color: #E9A2D9; border-bottom: 1px dotted;'>entity</span>",
        "</entity>": "",
        "<relation>": "<span style='background-color: #F3B78B; border-bottom: 1px dotted;'>relation</span>",
        "</relation>": "",
        "<contradictory>": "<span style='background-color: #FFFF9B; border-bottom: 1px dotted;'>contradictory</span>",
        "</contradictory>": "",
        "<unverifiable>": "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>unverifiable</span><u>",
        "</unverifiable>": "</u>",
        "<invented>": "<span style='background-color: #BFE9B9; border-bottom: 1px dotted;'>invented</span>",
        "</invented>": "",
        "<subjective>": "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>subjective</span><u>",
        "</subjective>": "</u>",
        "Edited:": "",
    }
    for tag, html in tag_to_html.items():
        output = output.replace(tag, html)
    return f'<div style="font-weight: normal;">{output}</div>'


def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/zephyr-7b-beta"):
    """
    A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
    """
    num_docs = int(num_docs)

    # Retrieve documents
    docs = search(prompt, index=index, num_docs=num_docs)
    joined_docs = '\n\n'.join(docs)
    
    messages = [
        {
            # Please append a newline only when you have finished answering.
            "role": "system",
            "content": f"You are an advanced medical support assistant, designed to help clinicians by providing quick access to medical information, guidelines, and evidence-based recommendations. Alongside your built-in knowledge, you have access to a curated set of documents retrieved from trustworthy sources such as Wikipedia and PubMed. These documents include up-to-date medical guidelines, research summaries, and clinical practice information. You should use these documents as a primary source of information to ensure your responses are based on the most current and credible evidence available. Your responses should be accurate, concise, and in full compliance with medical ethics. You must always remind users that your guidance does not substitute for professional medical advice, diagnosis, or treatment. Your tone should be professional, supportive, and respectful, recognizing the complexity of healthcare decisions and the importance of personalized patient care. While you can offer information and suggestions based on the documents provided and current medical knowledge, you must emphasize the importance of clinicians' expertise and judgment in making clinical decisions.\n\nRetrieved documents from {index}:\n\n{joined_docs}"
        }, {
            "role": "user",
            "content": prompt
        }
    ]

    response = ray.get(generate.remote(model_name, messages))
    # Run the FAVA check on the generated answer, using the retrieved documents as evidence
    analysed_response = ray.get(analyse.remote(joined_docs, response))

    # Return the generated text and the documents
    return analysed_response, response, joined_docs
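
# Example usage (assumes a GPU worker and the Elasticsearch index are reachable):
#   analysed_html, answer, docs = rag_pipeline("Are ILC2s increased in nasal polyps?")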

# Create the Gradio interface
iface = gr.Interface(fn=rag_pipeline,
                     inputs=[
                         gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
                         gr.Dropdown(label="Index", choices=["pubmed", "wikipedia", "textbooks", "statpearls"], value="pubmed"),
                         gr.Number(label="Number of Documents", value=3, step=1, minimum=0, maximum=10),
                         gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", "openai/gpt-3.5-turbo"], value="HuggingFaceH4/zephyr-7b-beta")
                     ],
                     outputs=[
                         gr.HTML(label="Analysed Answer"),
                         gr.Textbox(label="Generated Answer"),
                         # gr.Textbox(label="Analysed Answer"),
                         gr.Textbox(label="Retrieved Documents")
                     ],
                     description="Retrieval-Augmented Generation Pipeline")

# Launch the interface
iface.launch()