praveenpankaj commited on
Commit
02b1187
·
verified ·
1 Parent(s): c7b6715

Create rag_output.py

Browse files
Files changed (1) hide show
  1. rag_output.py +138 -0
rag_output.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ import torch
3
+ from transformers import BitsAndBytesConfig, AutoModelForCausalLM
4
+
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.vectorstores import FAISS
7
+ from langchain_community.document_loaders.csv_loader import CSVLoader
8
+ from langchain_community.document_loaders import HuggingFaceDatasetLoader
9
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
10
+
11
+ import transformers
12
+ from langchain.llms import HuggingFacePipeline
13
+ from langchain.prompts import PromptTemplate
14
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
15
+ from langchain.chains import LLMChain
16
+
17
+ import profanity_check
18
+ from langdetect import detect
19
+ import langid
20
+ import os
21
+
22
+ auth_token = os.environ.get("HF_TOKEN") or True
23
+
24
+ base_model = "praveenpankaj/aksara_1_unsloth_q4" #our finetuned model
25
+
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(
28
+ base_model,
29
+ padding_side = "left",
30
+ add_eos_token = True,
31
+
32
+ )
33
+ tokenizer.pad_token = tokenizer.eos_token
34
+ tokenizer.add_bos_token, tokenizer.add_eos_token
35
+
36
+ bnb_config = BitsAndBytesConfig(
37
+ load_in_4bit= True,
38
+ bnb_4bit_quant_type= "nf4",
39
+ bnb_4bit_compute_dtype= torch.bfloat16,
40
+ bnb_4bit_use_double_quant= False,
41
+ )
42
+
43
+ # model = AutoModelForCausalLM.from_pretrained(
44
+ # base_model,
45
+ # use_auth_token=auth_token,
46
+ # quantization_config=bnb_config,
47
+ # torch_dtype=torch.bfloat16,
48
+ # device_map="auto",
49
+ # trust_remote_code=True,
50
+ # )
51
+
52
+
53
+ # loader = CSVLoader(file_path='vsdb.csv')
54
+ # data = loader.load()
55
+ loader = HuggingFaceDatasetLoader('cropinailab/context_pop', 'pop')
56
+ data = loader.load()
57
+
58
+ db = FAISS.from_documents(data,
59
+ HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2'))
60
+
61
+
62
+ # Connect query to FAISS index using a retriever
63
+ retriever = db.as_retriever(
64
+ search_type="similarity_score_threshold",
65
+ search_kwargs={"score_threshold": 0.25, "k": 2}
66
+ )
67
+
68
+ def fetch(query):
69
+ res = retriever.get_relevant_documents(query)
70
+ docs = []
71
+ for i in res:
72
+ docs.append(i.page_content[5:])
73
+ return docs
74
+
75
+
76
+ text_generation_pipeline = transformers.pipeline(
77
+ model=model,
78
+ tokenizer=tokenizer,
79
+ task="text-generation",
80
+ temperature=0.000001,
81
+ repetition_penalty=1.2,
82
+ top_k=50,
83
+ top_p=0.95,
84
+ return_full_text=True,
85
+ max_new_tokens=512,
86
+ num_return_sequences=1,
87
+ do_sample=True
88
+ )
89
+
90
+ # Do not answer if you are not sure, just say I don't know
91
+
92
+ prompt_template = """
93
+ ### [INST]
94
+ Instruction: You are an expert Agronomist have a fruitful conversation with the user. Answer the question based on your knowledge. Just say I don't know if you are not sure of the answer. First check if question belongs to agriculture domain, if not then say "I don't know". Here is some context to enhance your response:
95
+ NOTE: Don't use the context if it is not factually related to the question. Don't mention you are answering based on the documents or context, rather you can say based on your training knowledge. Always provide disclaimer whenever you mention about any kind of chemicals.
96
+ {context}
97
+ ### USER
98
+ {question}
99
+ [/INST]
100
+ """
101
+ llm = Llama(
102
+ model_path=hf_hub_download(
103
+ repo_id="praveenpankaj/aksara_1_unsloth_q4",
104
+ filename="aksara_-unsloth.Q4_K_M.gguf",
105
+ ),
106
+ n_ctx=1024,
107
+ )
108
+ mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
109
+
110
+ # Create prompt from prompt template
111
+ prompt = PromptTemplate(
112
+ input_variables=["context", "question"],
113
+ template=prompt_template,
114
+ )
115
+
116
+ # Create llm chain
117
+ llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)
118
+
119
+ from langchain.schema.runnable import RunnablePassthrough
120
+
121
+ rag_chain = (
122
+ {"context": fetch, "question": RunnablePassthrough()}
123
+ | llm_chain
124
+ )
125
+
126
+
127
+ #check profanity
128
+ def check_if_profane(inp):
129
+ return profanity_check.predict([inp])
130
+
131
+ def rag_response(query):
132
+ if langid.classify(query)[0] != 'en':
133
+ return "Please provide a question in English language, I will be happy to help you."
134
+ elif check_if_profane(query):
135
+ return "Profanity detected in the query, I cannot provide the answer"
136
+ else:
137
+ res = rag_chain.invoke(query)
138
+ return res['text']