rohand1 committed
Commit cefa4a2
1 Parent(s): 90b99b8

Create app.py

Files changed (1)
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
+ # imports
+ import os
+ import together
+
+ """## RetrievalQA with LLaMA 2-70B on Together API"""
+
+ # set your Together API key (the hardcoded key is redacted here; supply your
+ # own and avoid committing a real key to the repo)
+ os.environ["TOGETHER_API_KEY"] = "YOUR_TOGETHER_API_KEY"
+
+ # !pip show langchain
+
+ together.api_key = os.environ["TOGETHER_API_KEY"]
+
+ # list available models and descriptions
+ models = together.Models.list()
+
+ # spin up the hosted LLaMA 2-70B chat endpoint
+ together.Models.start("togethercomputer/llama-2-70b-chat")
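+ # Note: Models.start provisions a hosted instance of the model on Together,
+ # so it may take a few minutes before the endpoint is ready to serve requests.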
+
+ import logging
+ from typing import Any, Dict, List, Mapping, Optional
+
+ from pydantic import Extra, Field, root_validator
+
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from langchain.llms.base import LLM
+ from langchain.llms.utils import enforce_stop_tokens
+ from langchain.utils import get_from_dict_or_env
+
+ class TogetherLLM(LLM):
+     """Together large language models."""
+
+     model: str = "togethercomputer/llama-2-70b-chat"
+     """model endpoint to use"""
+
+     together_api_key: str = os.environ["TOGETHER_API_KEY"]
+     """Together API key"""
+
+     temperature: float = 0.0
+     """What sampling temperature to use."""
+
+     max_tokens: int = 512
+     """The maximum number of tokens to generate in the completion."""
+
+     class Config:
+         extra = Extra.forbid
+
+     # @root_validator()
+     # def validate_environment(cls, values: Dict) -> Dict:
+     #     """Validate that the API key is set."""
+     #     api_key = get_from_dict_or_env(
+     #         values, "together_api_key", "TOGETHER_API_KEY"
+     #     )
+     #     values["together_api_key"] = api_key
+     #     return values
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of LLM."""
+         return "together"
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> str:
+         """Call the Together completion endpoint and return the generated text."""
+         together.api_key = self.together_api_key
+         output = together.Complete.create(prompt,
+                                           model=self.model,
+                                           max_tokens=self.max_tokens,
+                                           temperature=self.temperature)
+         text = output['output']['choices'][0]['text']
+         # honor stop sequences if the caller supplies them
+         if stop is not None:
+             text = enforce_stop_tokens(text, stop)
+         return text
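+ # A minimal standalone check of the wrapper (commented out; assumes the
+ # endpoint started above is ready):
+ # test_llm = TogetherLLM()
+ # print(test_llm("Say hello in one short sentence."))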
+
+ """# imports"""
+
+ from langchain.vectorstores import Chroma
+ from langchain.chains import RetrievalQA
+ from langchain.document_loaders import DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+ # from langchain.document_loaders import TextLoader
+ # from langchain.document_loaders import PyPDFLoader
+
+ from InstructorEmbedding import INSTRUCTOR
+
+ # load every document under the data directory
+ loader = DirectoryLoader('/content/Data')
+ documents = loader.load()
+
+ print(len(documents))
+
+ # split the text into overlapping chunks for retrieval
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+ texts = text_splitter.split_documents(documents)
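+ # (optional) inspect the split: chunk count and a sample chunk
+ # print(len(texts))
+ # print(texts[0].page_content[:200])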
+
+ # HF Instructor Embeddings
+ instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-base",
+                                                       # model_kwargs={"device": "cuda"})
+                                                       model_kwargs={"device": "cpu"})
+
+ """## create the DB
+
+ This will take a bit of time on a T4 GPU
+ """
+
+ persist_directory = 'db'
+
+ # here are the new embeddings being used
+ embedding = instructor_embeddings
+
+ vectordb = Chroma.from_documents(documents=texts,
+                                  embedding=embedding,
+                                  persist_directory=persist_directory)
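+
+ # (optional) sanity-check the store; similarity_search returns the closest
+ # chunks for a query string:
+ # hits = vectordb.similarity_search("example query", k=3)
+ # print(hits[0].page_content[:200])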
+
+ """## Make a retriever"""
+
+ retriever = vectordb.as_retriever(search_kwargs={"k": 5})
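+
+ # (optional) query the retriever on its own; it should return the top-5
+ # chunks configured above:
+ # docs = retriever.get_relevant_documents("example query")
+ # print(len(docs))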
+
+ """## Make a chain"""
+
+ llm = TogetherLLM(
+     model="togethercomputer/llama-2-70b-chat",
+     temperature=0.0,
+     max_tokens=1024
+ )
+
+ DEFAULT_SYSTEM_PROMPT = """
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+ """.strip()
+
+ # wrap a prompt in the LLaMA 2 chat format: [INST] <<SYS>> ... <</SYS>> ... [/INST]
+ def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+     return f"""
+ [INST] <<SYS>>
+ {system_prompt}
+ <</SYS>>
+
+ {prompt} [/INST]
+ """.strip()
+
+ # SYSTEM_PROMPT = "Answer from following context, if question is out of context respond you don't know and do not explain the same"
+ SYSTEM_PROMPT = "Answer from the following context; if the question is out of context, respond \"I don't know\""
+
+ template = generate_prompt(
+     """
+ {context}
+
+ Question: {question}
+ """,
+     system_prompt=SYSTEM_PROMPT,
+ )
+
+ print(template)
+
+ from langchain import PromptTemplate
+
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+ print(prompt)
+
+ # create the chain to answer questions
+ qa_chain = RetrievalQA.from_chain_type(llm=llm,
+                                        chain_type="stuff",
+                                        retriever=retriever,
+                                        return_source_documents=True,
+                                        chain_type_kwargs={"prompt": prompt})
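+
+ # (optional) end-to-end check: the chain returns a dict with the answer under
+ # 'result' and the retrieved chunks under 'source_documents':
+ # response = qa_chain("example question about the documents")
+ # print(response['result'])
+ # print(len(response['source_documents']))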
+
+ import gradio
+
+ # answer a user query with the RetrievalQA chain
+ def greet(query):
+     llm_response = qa_chain(query)
+     return llm_response['result']
+
+ gradio.Interface(greet, "text", "text").launch()
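+ # When running locally (outside a hosted Space), launch(share=True) would also
+ # expose a temporary public URL:
+ # gradio.Interface(greet, "text", "text").launch(share=True)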