pminervini committed
Commit
858ef78
1 Parent(s): 478e69d
Files changed (2)
  1. app.py +69 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import gradio as gr
+ from transformers import pipeline
+ from elasticsearch import Elasticsearch
+
+ # Connect to Elasticsearch
+ es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
+                    basic_auth=('elastic', os.environ['ES_PASSWORD']),
+                    verify_certs=False, ssl_show_warn=False)
+
+ # Load the language model from Hugging Face Transformers
+ generator = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2")
+
+ def search_es(query, index="pubmed", num_results=3):
+     """
+     Search the Elasticsearch index for the most relevant documents.
+     """
+     print(f'Running query: {query}')
+
+     response = es.search(
+         index=index,
+         body={
+             "query": {
+                 "match": {
+                     "content": query  # Assuming documents have a 'content' field
+                 }
+             },
+             "size": num_results
+         }
+     )
+
+     # Extract and return the document contents
+     docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
+
+     print(f'Received {len(docs)} documents')
+
+     return docs
+
+ def rag_pipeline(prompt, index="pubmed"):
+     """
+     A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
+     """
+     # Retrieve documents
+     docs = search_es(prompt, index=index)
+
+     # Combine the prompt with the retrieved documents
+     enriched_prompt = f"{prompt}\n\n{' '.join(docs)}"
+
+     # Generate a response using the LLM
+     response = generator(enriched_prompt, max_new_tokens=256, return_full_text=False)
+
+     # Return the generated text and the retrieved documents
+     return response[0]['generated_text'], "\n\n".join(docs)
+
+ # Create the Gradio interface
+ iface = gr.Interface(fn=rag_pipeline,
+                      inputs=[
+                          gr.Textbox(label="Input Prompt"),
+                          gr.Textbox(label="Elasticsearch Index", value="pubmed")
+                      ],
+                      outputs=[
+                          gr.Textbox(label="Generated Text"),
+                          gr.Textbox(label="Retrieved Documents")
+                      ],
+                      description="Retrieval-Augmented Generation Pipeline")
+
+ # Launch the interface
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ elasticsearch