import os
import gradio as gr

import vllm

import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer
from openai import OpenAI

from elasticsearch import Elasticsearch


class MultiTokenEOSCriteria(StoppingCriteria):
    """Stopping criterion that halts generation once every sequence in the batch has produced the given stop string (which may span multiple tokens)."""

    def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, initial_decoder_input_length: int, batch_size: int = 1) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # Look back a couple of tokens further than the encoded stop string, since it may tokenize differently when preceded by other text
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker
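
# Usage sketch (with hypothetical `model` and `input_ids`): attach to generate() so decoding stops at a blank line
#   criteria = StoppingCriteriaList([MultiTokenEOSCriteria("\n\n", tokenizer, input_ids.shape[1])])
#   model.generate(input_ids, stopping_criteria=criteria)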


def search(query, index="pubmed", num_docs=3):
    """
    Search the Elasticsearch index for the most relevant documents.
    """
    
    docs = []
    if num_docs > 0:
        print(f'Running query: {query}')
        es_request_body = {
            "query": {
                "match": {
                    "content": query  # Assuming documents have a 'content' field
                }
            }, "size": num_docs
        }

        # Connect to Elasticsearch
        es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
                           basic_auth=('elastic', os.environ['ES_PASSWORD']),
                           verify_certs=False, ssl_show_warn=False)

        response = es.options(request_timeout=60).search(index=index, body=es_request_body)
        # Extract and return the documents
        docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
        print(f'Received {len(docs)} documents from index {index}')

    return docs
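
# Example call (assuming the index exists and ES_PASSWORD is set in the environment):
#   docs = search("ILC2s in chronic rhinosinusitis with nasal polyps", index="pubmed", num_docs=3)
# returns up to 3 document strings ranked by Elasticsearch's default BM25 relevance.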

def analyse(reference: str, passage: str) -> str:
    """Run the FAVA editing model over a passage and render its error markup as HTML."""
    fava_input = "Read the following references:\n{evidence}\nPlease identify all the errors in the following text using the information in the references provided and suggest edits if necessary:\n[Text] {output}\n[Edited] "
    prompt = [fava_input.format_map({"evidence": reference, "output": passage})]

    # Greedy decoding with the FAVA editing model (loaded on every call)
    model = vllm.LLM(model="fava-uw/fava-model")
    sampling_params = vllm.SamplingParams(temperature=0, top_p=1.0, max_tokens=500)
    outputs = model.generate(prompt, sampling_params)
    outputs = [it.outputs[0].text for it in outputs]

    # Convert FAVA's tag markup into styled HTML spans
    output = outputs[0].replace("<mark>", "<span style='color: green; font-weight: bold;'> ")
    output = output.replace("</mark>", " </span>")
    output = output.replace("<delete>", "<span style='color: red; text-decoration: line-through;'>")
    output = output.replace("</delete>", "</span>")
    output = output.replace("<entity>", "<span style='background-color: #E9A2D9; border-bottom: 1px dotted;'>entity</span>")
    output = output.replace("<relation>", "<span style='background-color: #F3B78B; border-bottom: 1px dotted;'>relation</span>")
    output = output.replace("<contradictory>", "<span style='background-color: #FFFF9B; border-bottom: 1px dotted;'>contradictory</span>")
    output = output.replace("<unverifiable>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>unverifiable</span><u>")
    output = output.replace("<invented>", "<span style='background-color: #BFE9B9; border-bottom: 1px dotted;'>invented</span>")
    output = output.replace("<subjective>", "<span style='background-color: #D3D3D3; border-bottom: 1px dotted;'>subjective</span><u>")
    output = output.replace("</entity>", "")
    output = output.replace("</relation>", "")
    output = output.replace("</contradictory>", "")
    output = output.replace("</unverifiable>", "</u>")
    output = output.replace("</invented>", "")
    output = output.replace("</subjective>", "</u>")
    output = output.replace("Edited:", "")
    return f'<div style="font-weight: normal;">{output}</div>'
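
# Note: analyse() builds a fresh vllm.LLM instance on every call, reloading the FAVA weights
# each time. If GPU memory allows (an assumption about the deployment), the model could be
# constructed once at module load and reused across requests.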


def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/zephyr-7b-beta"):
    """
    A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
    """
    num_docs = int(num_docs)

    # Retrieve documents
    docs = search(prompt, index=index, num_docs=num_docs)
    joined_docs = '\n\n'.join(docs)
    
    messages = [
        {
            "role": "system",
            "content": f"You are an advanced medical support assistant, designed to help clinicians by providing quick access to medical information, guidelines, and evidence-based recommendations. Alongside your built-in knowledge, you have access to a curated set of documents retrieved from trustworthy sources such as Wikipedia and PubMed. These documents include up-to-date medical guidelines, research summaries, and clinical practice information. You should use these documents as a primary source of information to ensure your responses are based on the most current and credible evidence available. Your responses should be accurate, concise, and in full compliance with medical ethics. You must always remind users that your guidance does not substitute for professional medical advice, diagnosis, or treatment. Your tone should be professional, supportive, and respectful, recognizing the complexity of healthcare decisions and the importance of personalized patient care. While you can offer information and suggestions based on the documents provided and current medical knowledge, you must emphasize the importance of clinicians' expertise and judgment in making clinical decisions. Please append a newline only when you have finished answering.\n\nRetrieved documents from {index}:\n\n{joined_docs}"
        }, {
            "role": "user",
            "content": prompt
        }
    ]

    for message in messages:
        print('MSG', message)

    max_new_tokens = 1024

    if model_name.startswith('openai/'):
        openai_model_name = model_name.split('/')[1]

        client = OpenAI()
        openai_res = client.chat.completions.create(model=openai_model_name,
                                                    messages=messages,
                                                    max_tokens=max_new_tokens,
                                                    temperature=0)
        print('OAI_RESPONSE', openai_res)
        response = openai_res.choices[0].message.content.strip()
    else:
        # Load the model in 4-bit (requires bitsandbytes) so the larger checkpoints fit in GPU memory
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True, load_in_4bit=True)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Wrap the model and tokenizer in a Transformers text-generation pipeline
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

        # Measure the prompt length as the generator will see it, so the stopping criterion only inspects newly generated tokens
        tokenized_prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)

        # Stop generation at the first newline emitted after the prompt
        # (a MaxTimeCriteria could also be added as a wall-clock cap)
        stopping_criteria = StoppingCriteriaList([
            # MaxTimeCriteria(32),
            MultiTokenEOSCriteria("\n", tokenizer, len(tokenized_prompt))
        ])

        # Define the generation kwargs; stopping_criteria is forwarded to model.generate()
        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "stopping_criteria": stopping_criteria,
            "return_full_text": False
        }

        # Generate response using the HF LLM
        hf_response = generator(messages, **generation_kwargs)

        print('HF_RESPONSE', hf_response)
        response = hf_response[0]['generated_text']
    
    analysed_response = analyse(joined_docs, response)

    # Return the generated text and the documents
    return response, analysed_response, joined_docs
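
# Note: the Hugging Face branch above reloads the selected model on every request.
# A simple cache keyed by model_name (an illustrative optimisation, not part of the
# original code) would avoid repeating the 4-bit load for consecutive queries.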

# Create the Gradio interface
iface = gr.Interface(fn=rag_pipeline,
                     inputs=[
                         gr.Textbox(label="Input Prompt", value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"),
                         gr.Dropdown(label="Index", choices=["pubmed", "wikipedia", "textbooks"], value="pubmed"),
                         gr.Number(label="Number of Documents", value=3, step=1, minimum=0, maximum=10),
                         gr.Dropdown(label="Model", choices=["HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", "openai/gpt-3.5-turbo"], value="HuggingFaceH4/zephyr-7b-beta")
                     ],
                     outputs=[
                         gr.Textbox(label="Generated Answer"),
                         gr.Textbox(label="Analysed Answer"),
                         gr.Textbox(label="Retrieved Documents")
                     ],
                     description="Retrieval-Augmented Generation Pipeline")

# Launch the interface
iface.launch()