divy131 committed on
Commit d8fcee7 · verified · 1 Parent(s): db3fd71

Upload utils.py

Files changed (1)
  1. utils.py +103 -0
utils.py ADDED
@@ -0,0 +1,103 @@
+ # Importing dependencies
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ from langchain.prompts import PromptTemplate
+ from langchain.llms import HuggingFacePipeline
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import RetrievalQA
+
+ # Path to the saved FAISS index
+ FAISS_INDEX = "vectorstore/"
+
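This file assumes a FAISS index has already been built and saved under vectorstore/. A minimal sketch of how such an index might be created from a plain-text corpus; the loader, file path, and chunking parameters are illustrative assumptions, not part of this repo:

# Hypothetical ingest step: build and save the index this module loads.
# The corpus path and chunking parameters are illustrative assumptions.
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

documents = TextLoader("data/indian_law.txt").load()   # assumed corpus file
chunks = RecursiveCharacterTextSplitter(
    chunk_size=500,        # assumed chunk size
    chunk_overlap=50
).split_documents(documents)

FAISS.from_documents(chunks, HuggingFaceEmbeddings()).save_local("vectorstore/")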
+ # Custom prompt template (Llama-2 chat format)
+ custom_prompt_template = """[INST] <<SYS>>
+ You are a trained bot that guides people on Indian law. Answer the user's query using your knowledge and the context provided.
+ If a question does not make sense or is not factually coherent, explain why instead of answering something incorrect. If you don't know the answer to a question, do not share false information.
+ Do not thank the user or announce that you are an AI assistant; be direct and open in your answers.
+ <</SYS>>
+ Use the following pieces of context to answer the user's question.
+ Context: {context}
+ Question: {question}
+ Answer: [/INST]
+ """
+
+ # Return the custom prompt template
+ def set_custom_prompt_template():
+     """
+     Set the custom prompt template for the LLM chain
+     """
+     prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
+
+     return prompt
+
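For reference, the template can be rendered on its own to see exactly what the chain will send to the model; the context and question values below are illustrative:

# Sanity-check the rendered prompt with illustrative values.
prompt = set_custom_prompt_template()
print(prompt.format(
    context="Article 21 guarantees the protection of life and personal liberty.",
    question="What does Article 21 of the Indian Constitution protect?"
))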
+ # Return the LLM
+ def load_llm():
+     """
+     Load the LLM
+     """
+     # Model ID on the Hugging Face Hub
+     repo_id = 'meta-llama/Llama-2-7b-chat-hf'
+
+     # Load the model in 4-bit (requires bitsandbytes and a CUDA GPU)
+     model = AutoModelForCausalLM.from_pretrained(
+         repo_id,
+         device_map='auto',
+         load_in_4bit=True
+     )
+
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         repo_id,
+         use_fast=True
+     )
+
+     # Create the text-generation pipeline; max_new_tokens bounds only the
+     # generated tokens, so a long retrieved context cannot exhaust the budget
+     pipe = pipeline(
+         'text-generation',
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=512
+     )
+
+     # Wrap the pipeline as a LangChain LLM
+     llm = HuggingFacePipeline(pipeline=pipe)
+
+     return llm
+
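Passing load_in_4bit=True directly to from_pretrained works here but has since been deprecated in newer transformers releases in favour of an explicit BitsAndBytesConfig. An equivalent sketch on the newer API, with the same loading behaviour:

# Equivalent 4-bit load with an explicit quantization config
# (newer transformers API; behaviour matches load_in_4bit=True above).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-chat-hf',
    device_map='auto',
    quantization_config=bnb_config
)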
+ # Return the retrieval QA chain
+ def retrieval_qa_chain(llm, prompt, db):
+     """
+     Create the Retrieval QA chain
+     """
+     # 'stuff' places the top-k retrieved chunks directly into the prompt's
+     # {context} slot; k=2 keeps the stuffed prompt short
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type='stuff',
+         retriever=db.as_retriever(search_kwargs={'k': 2}),
+         return_source_documents=True,
+         chain_type_kwargs={'prompt': prompt}
+     )
+
+     return qa_chain
+
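Since chain_type='stuff' inserts the retrieved chunks verbatim, k and the chunk size together bound the prompt length. The retriever can also be probed without invoking the LLM; a small sketch, assuming db is the FAISS store loaded in qa_pipeline() and using an illustrative query:

# Probe the retriever alone; the query string is illustrative.
retriever = db.as_retriever(search_kwargs={'k': 2})
for doc in retriever.get_relevant_documents("anticipatory bail under CrPC"):
    print(doc.page_content[:200])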
+ # Return the full QA pipeline
+ def qa_pipeline():
+     """
+     Create the QA pipeline
+     """
+     # Load the HuggingFace embeddings (must match those used to build the index)
+     embeddings = HuggingFaceEmbeddings()
+
+     # Load the FAISS index from disk
+     db = FAISS.load_local(FAISS_INDEX, embeddings)
+
+     # Load the LLM
+     llm = load_llm()
+
+     # Set the custom prompt template
+     qa_prompt = set_custom_prompt_template()
+
+     # Create the retrieval QA chain
+     chain = retrieval_qa_chain(llm, qa_prompt, db)
+
+     return chain
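A minimal end-to-end usage sketch, assuming the index exists under vectorstore/ and the caller has access to the gated Llama-2 weights; the query is illustrative:

# Build the chain once, then answer queries against the index.
chain = qa_pipeline()
result = chain({"query": "What does Section 302 of the IPC deal with?"})  # illustrative query
print(result["result"])                  # generated answer
for doc in result["source_documents"]:   # available via return_source_documents=True
    print(doc.metadata)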