File size: 5,469 Bytes
858ef78
 
482d0d4
 
8400fbd
28875eb
6cbdb81
858ef78
 
8400fbd
858ef78
 
 
 
 
4cf3bf5
858ef78
 
 
 
6cbdb81
 
 
 
 
 
 
 
 
 
 
 
 
 
858ef78
 
 
4cf3bf5
858ef78
 
 
4cf3bf5
 
858ef78
4cf3bf5
be47723
858ef78
be47723
 
529ee77
be47723
eaec240
be47723
 
 
 
 
eaec240
 
be47723
8400fbd
 
 
 
 
 
6cbdb81
8400fbd
 
 
6cbdb81
 
 
 
28875eb
 
6cbdb81
 
 
 
 
 
 
847df2e
 
 
6cbdb81
 
847df2e
6cbdb81
 
858ef78
 
6cbdb81
858ef78
 
 
 
bbe5900
6cbdb81
 
 
858ef78
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import functools
import os

import gradio as gr
import torch
from elasticsearch import Elasticsearch
from openai import OpenAI
from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria


# Connect to Elasticsearch; the password is read from the ES_PASSWORD
# environment variable (a KeyError is raised at import time if it is unset).
# NOTE(review): verify_certs=False disables TLS certificate verification and
# ssl_show_warn=False hides the resulting warning — consider supplying a CA
# bundle instead so the server identity is actually checked.
es = Elasticsearch(
    hosts=["https://data.neuralnoise.com:9200"],
    basic_auth=("elastic", os.environ["ES_PASSWORD"]),
    verify_certs=False,
    ssl_show_warn=False,
)

def search(query, index="pubmed", num_docs=3):
    """
    Search the Elasticsearch index for the most relevant documents.

    Runs a full-text ``match`` query against the ``content`` field of
    ``index`` using the module-level ``es`` client.

    Args:
        query: Free-text query string.
        index: Name of the Elasticsearch index to search.
        num_docs: Maximum number of hits to return. It is truncated to
            ``int`` first (Gradio ``Number`` inputs arrive as floats); a
            value <= 0 skips the request and returns an empty list.

    Returns:
        list: The ``_source["content"]`` value of each hit, in the order
        returned by Elasticsearch.
    """
    num_docs = int(num_docs)
    # Guard clause: no request at all when no documents are wanted.
    if num_docs <= 0:
        return []

    print(f'Running query: {query}')
    # Pass ``query``/``size`` as keyword arguments rather than ``body=``:
    # the 8.x client (required here by ``es.options`` and ``basic_auth``)
    # deprecated the ``body`` parameter in favour of per-field arguments.
    response = es.options(request_timeout=60).search(
        index=index,
        query={"match": {"content": query}},  # Assuming documents have a 'content' field
        size=num_docs,
    )
    # Extract and return the documents
    docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
    print(f'Received {len(docs)} documents from index {index}')
    return docs

def _build_messages(prompt, index, joined_docs):
    """Return the [system, user] chat messages sent to the LLM.

    The system message carries the assistant instructions followed by the
    retrieved documents; the user message is the raw ``prompt``.
    """
    return [
        {
            "role": "system",
            "content": f"You are an advanced medical support assistant, designed to help clinicians by providing quick access to medical information, guidelines, and evidence-based recommendations. Alongside your built-in knowledge, you have access to a curated set of documents retrieved from trustworthy sources such as Wikipedia and PubMed. These documents include up-to-date medical guidelines, research summaries, and clinical practice information. You should use these documents as a primary source of information to ensure your responses are based on the most current and credible evidence available. Your responses should be accurate, concise, and in full compliance with medical ethics. You must always remind users that your guidance does not substitute for professional medical advice, diagnosis, or treatment. Your tone should be professional, supportive, and respectful, recognizing the complexity of healthcare decisions and the importance of personalized patient care. While you can offer information and suggestions based on the documents provided and current medical knowledge, you must emphasize the importance of clinicians' expertise and judgment in making clinical decisions.\n\nRetrieved documents from {index}:\n\n{joined_docs}"
        }, {
            "role": "user",
            "content": prompt
        }
    ]


@functools.lru_cache(maxsize=1)
def _get_hf_generator(model_name):
    """Return a (cached) HuggingFace text-generation pipeline for ``model_name``.

    Loading model weights is expensive, so the pipeline is built once and
    reused across requests instead of being rebuilt on every call.
    ``maxsize=1`` keeps only the most recently requested model in the cache,
    so switching models in the UI releases the cache's reference to the
    previous one rather than accumulating several models.
    """
    return pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16)


def _generate_openai(model_name, messages, max_new_tokens):
    """Generate a reply with an OpenAI chat model named ``openai/<model>``.

    Uses the chat-completions endpoint: chat models such as ``gpt-3.5-turbo``
    are rejected by the legacy text-completions endpoint. The structured
    ``messages`` are sent as-is, so the system/user roles are preserved.
    """
    # split once so model ids that themselves contain '/' are kept intact
    openai_model_name = model_name.split('/', 1)[1]

    client = OpenAI()
    openai_res = client.chat.completions.create(model=openai_model_name,
                                                messages=messages,
                                                max_tokens=max_new_tokens,
                                                n=1,
                                                temperature=0)
    # message.content may be None (e.g. no text produced); treat that as empty
    return (openai_res.choices[0].message.content or '').strip()


def _generate_hf(model_name, messages, generation_kwargs):
    """Generate a reply with a locally loaded HuggingFace chat model."""
    generator = _get_hf_generator(model_name)

    # Generate response using the HF LLM
    hf_response = generator(messages, **generation_kwargs)

    print('HF_RESPONSE', hf_response)
    # NOTE(review): with chat-style input and return_full_text=False this is
    # expected to be the newly generated text only — confirm against the
    # installed transformers version.
    return hf_response[0]['generated_text']


def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/zephyr-7b-beta"):
    """
    A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.

    Args:
        prompt: The user's question; used both as the search query and as
            the user message.
        index: Elasticsearch index to retrieve documents from.
        num_docs: Number of documents to retrieve (converted to ``int``;
            Gradio ``Number`` inputs arrive as floats).
        model_name: Either ``"openai/<model>"`` (routed to the OpenAI chat
            API) or a HuggingFace model id (run via a transformers pipeline).

    Returns:
        tuple: ``(response, joined_docs)`` — the generated text and the
        retrieved documents joined by blank lines.
    """
    num_docs = int(num_docs)

    # Retrieve documents
    docs = search(prompt, index=index, num_docs=num_docs)
    joined_docs = '\n\n'.join(docs)

    messages = _build_messages(prompt, index, joined_docs)
    for message in messages:
        print('MSG', message)

    generation_kwargs = {
        "max_new_tokens": 128,
        "return_full_text": False
    }

    if model_name.startswith('openai/'):
        response = _generate_openai(model_name, messages, generation_kwargs["max_new_tokens"])
    else:
        response = _generate_hf(model_name, messages, generation_kwargs)

    # Return the generated text and the documents
    return response, joined_docs

# Build the Gradio UI. The four input components are passed positionally to
# rag_pipeline as (prompt, index, num_docs, model_name); its two return values
# (response, joined_docs) fill the two output textboxes in order.
iface = gr.Interface(
    fn=rag_pipeline,
    inputs=[
        gr.Textbox(
            label="Input Prompt",
            value="Are group 2 innate lymphoid cells (ILC2s) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?",
        ),
        gr.Dropdown(
            label="Index",
            choices=["pubmed", "wikipedia", "textbooks"],
            value="pubmed",
        ),
        gr.Number(
            label="Number of Documents",
            value=3,
            step=1,
            minimum=0,
            maximum=10,
        ),
        gr.Dropdown(
            label="Model",
            choices=[
                "HuggingFaceH4/zephyr-7b-beta",
                "meta-llama/Llama-2-7b-chat-hf",
                "meta-llama/Llama-2-13b-chat-hf",
                "meta-llama/Llama-2-70b-chat-hf",
                "openai/gpt-3.5-turbo",
            ],
            value="HuggingFaceH4/zephyr-7b-beta",
        ),
    ],
    outputs=[
        gr.Textbox(label="Generated Text"),
        gr.Textbox(label="Retrieved Documents"),
    ],
    description="Retrieval-Augmented Generation Pipeline",
)

# Start serving the interface (runs as a side effect of executing this module).
iface.launch()