Logeswaransr committed on
Commit
1995a08
·
verified ·
1 Parent(s): db314d7

Create initialize.py

Browse files
Files changed (1) hide show
  1. initialize.py +55 -0
initialize.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from haystack import Document
2
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
3
+ from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
4
+ from haystack.components.builders import PromptBuilder
5
+ from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
6
+ from haystack.pipeline import Pipeline
7
+
8
def init_doc_store(path, files):
    """Load text files into a new in-memory Haystack document store.

    Args:
        path: Directory that contains the files.
        files: Iterable of file names, each resolved relative to ``path``.

    Returns:
        A new ``InMemoryDocumentStore`` to which one ``Document`` per file has
        been written. Each document's ``content`` is the full text of the file
        and its ``meta`` holds the file name under the ``'name'`` key.

    Raises:
        OSError: If a file cannot be opened.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    # Kept local so this function does not rely on a module-level import.
    import os

    docs = []
    for file in files:
        # Decode explicitly as UTF-8 instead of the platform-default encoding,
        # so the same files load identically on every machine.
        with open(os.path.join(path, file), 'r', encoding='utf-8') as f:
            content = f.read()
        docs.append(Document(content=content, meta={'name': file}))

    document_store = InMemoryDocumentStore()
    document_store.write_documents(docs)
    return document_store
18
+
19
def define_components(document_store):
    """Create the retriever, prompt builder and generator for the RAG pipeline.

    Args:
        document_store: The document store the BM25 retriever searches.

    Returns:
        A ``(retriever, prompt_builder, generator)`` tuple:

        * ``retriever`` - an ``InMemoryBM25Retriever`` over ``document_store``
          configured with ``top_k=3``.
        * ``prompt_builder`` - a ``PromptBuilder`` whose template renders the
          content of each entry in ``documents`` followed by ``question``.
        * ``generator`` - a ``HuggingFaceLocalGenerator`` for the ``gpt2``
          model on which ``warm_up()`` has already been called.
    """
    retriever = InMemoryBM25Retriever(document_store, top_k=3)

    # The template lines stay flush left: any leading whitespace inside the
    # triple-quoted string would become part of the rendered prompt.
    template = """
Given the following information, answer the question.

Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}

Question: {{question}}
Answer:
"""
    prompt_builder = PromptBuilder(template=template)

    generator = HuggingFaceLocalGenerator(
        model="gpt2",
        task="text-generation",
        # device='cuda',  (left disabled, as in the original)
        generation_kwargs={
            "max_new_tokens": 100,
            "temperature": 0.9,
        },
    )
    # Called before returning so callers receive a ready-to-use generator
    # (presumably this loads the model - confirm against the Haystack docs).
    generator.warm_up()

    # The original returned the misspelled, undefined name `retreiver`,
    # which raised NameError on every call.
    return retriever, prompt_builder, generator
44
+
45
def define_pipeline(retreiver, prompt_builder, generator):
    """Assemble the retriever -> prompt builder -> generator pipeline.

    Args:
        retreiver: Retriever component, registered as ``"retriever"``. The
            parameter keeps its original (misspelled) name so existing
            keyword-argument callers continue to work.
        prompt_builder: Prompt builder component, registered as
            ``"prompt_builder"``.
        generator: Generator component, registered as ``"llm"``.

    Returns:
        A ``Pipeline`` in which ``retriever`` is connected to
        ``prompt_builder.documents`` and ``prompt_builder`` is connected to
        ``llm``.
    """
    basic_rag_pipeline = Pipeline()

    # Use the parameter itself: the original body referenced `retriever`,
    # a name that is not defined in this function.
    basic_rag_pipeline.add_component("retriever", retreiver)
    basic_rag_pipeline.add_component("prompt_builder", prompt_builder)
    basic_rag_pipeline.add_component("llm", generator)

    basic_rag_pipeline.connect("retriever", "prompt_builder.documents")
    basic_rag_pipeline.connect("prompt_builder", "llm")

    return basic_rag_pipeline