Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# imports
|
2 |
+
import together
|
3 |
+
|
4 |
+
|
5 |
+
"""## RetrievalQA with LLaMA 2-70B on Together API"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
os.environ["TOGETHER_API_KEY"] = "6e132bb99c767328701e4870bad6b3234b94ee701dbf7b995cdbec44fb01687a"
|
9 |
+
|
10 |
+
# !pip show langchain
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
# set your API key
|
15 |
+
together.api_key = os.environ["TOGETHER_API_KEY"]
|
16 |
+
|
17 |
+
# list available models and descriptons
|
18 |
+
models = together.Models.list()
|
19 |
+
|
20 |
+
together.Models.start("togethercomputer/llama-2-70b-chat")
|
21 |
+
|
22 |
+
import together
|
23 |
+
|
24 |
+
import logging
|
25 |
+
from typing import Any, Dict, List, Mapping, Optional
|
26 |
+
|
27 |
+
from pydantic import Extra, Field, root_validator
|
28 |
+
|
29 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
30 |
+
from langchain.llms.base import LLM
|
31 |
+
from langchain.llms.utils import enforce_stop_tokens
|
32 |
+
from langchain.utils import get_from_dict_or_env
|
33 |
+
|
34 |
+
class TogetherLLM(LLM):
|
35 |
+
"""Together large language models."""
|
36 |
+
|
37 |
+
model: str = "togethercomputer/llama-2-70b-chat"
|
38 |
+
"""model endpoint to use"""
|
39 |
+
|
40 |
+
together_api_key: str = os.environ["TOGETHER_API_KEY"]
|
41 |
+
"""Together API key"""
|
42 |
+
|
43 |
+
temperature: float = 0.0
|
44 |
+
"""What sampling temperature to use."""
|
45 |
+
|
46 |
+
max_tokens: int = 512
|
47 |
+
"""The maximum number of tokens to generate in the completion."""
|
48 |
+
|
49 |
+
class Config:
|
50 |
+
extra = Extra.forbid
|
51 |
+
|
52 |
+
# @root_validator()
|
53 |
+
# def validate_environment(cls, values: Dict) -> Dict:
|
54 |
+
# """Validate that the API key is set."""
|
55 |
+
# api_key = get_from_dict_or_env(
|
56 |
+
# values, "together_api_key", "TOGETHER_API_KEY"
|
57 |
+
# )
|
58 |
+
# values["together_api_key"] = api_key
|
59 |
+
# return values
|
60 |
+
|
61 |
+
@property
|
62 |
+
def _llm_type(self) -> str:
|
63 |
+
"""Return type of LLM."""
|
64 |
+
return "together"
|
65 |
+
|
66 |
+
def _call(
|
67 |
+
self,
|
68 |
+
prompt: str,
|
69 |
+
**kwargs: Any,
|
70 |
+
) -> str:
|
71 |
+
"""Call to Together endpoint."""
|
72 |
+
together.api_key = self.together_api_key
|
73 |
+
output = together.Complete.create(prompt,
|
74 |
+
model=self.model,
|
75 |
+
max_tokens=self.max_tokens,
|
76 |
+
temperature=self.temperature,
|
77 |
+
)
|
78 |
+
text = output['output']['choices'][0]['text']
|
79 |
+
return text
|
80 |
+
|
81 |
+
import os
|
82 |
+
|
83 |
+
"""# import"""
|
84 |
+
|
85 |
+
from langchain.vectorstores import Chroma
|
86 |
+
from langchain.chains import RetrievalQA
|
87 |
+
from langchain.document_loaders import DirectoryLoader
|
88 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
89 |
+
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
90 |
+
|
91 |
+
# from langchain.document_loaders import TextLoader
|
92 |
+
# from langchain.document_loaders import PyPDFLoader
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
from InstructorEmbedding import INSTRUCTOR
|
97 |
+
|
98 |
+
|
99 |
+
loader = DirectoryLoader('/content/Data')
|
100 |
+
|
101 |
+
documents = loader.load()
|
102 |
+
|
103 |
+
len(documents)
|
104 |
+
|
105 |
+
#splitting the text into
|
106 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
|
107 |
+
texts = text_splitter.split_documents(documents)
|
108 |
+
|
109 |
+
# HF Instructor Embeddings
|
110 |
+
|
111 |
+
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
112 |
+
|
113 |
+
instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-base",
|
114 |
+
# model_kwargs={"device": "cuda"})
|
115 |
+
model_kwargs={"device": "cpu"})
|
116 |
+
|
117 |
+
"""## create the DB
|
118 |
+
|
119 |
+
This will take a bit of time on a T4 GPU
|
120 |
+
"""
|
121 |
+
|
122 |
+
persist_directory = 'db'
|
123 |
+
|
124 |
+
## Here is the nmew embeddings being used
|
125 |
+
embedding = instructor_embeddings
|
126 |
+
|
127 |
+
vectordb = Chroma.from_documents(documents=texts,
|
128 |
+
embedding=embedding,
|
129 |
+
persist_directory=persist_directory)
|
130 |
+
|
131 |
+
"""## Make a retriever"""
|
132 |
+
|
133 |
+
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
|
134 |
+
|
135 |
+
"""## Make a chain"""
|
136 |
+
|
137 |
+
llm = TogetherLLM(
|
138 |
+
model= "togethercomputer/llama-2-70b-chat",
|
139 |
+
temperature = 0.0,
|
140 |
+
max_tokens = 1024
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
DEFAULT_SYSTEM_PROMPT = """
|
146 |
+
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
147 |
+
|
148 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
149 |
+
""".strip()
|
150 |
+
|
151 |
+
def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
|
152 |
+
return f"""
|
153 |
+
[INST] <>
|
154 |
+
{system_prompt}
|
155 |
+
<>
|
156 |
+
|
157 |
+
{prompt} [/INST]
|
158 |
+
""".strip()
|
159 |
+
|
160 |
+
# SYSTEM_PROMPT = "Answer from following context, if question is out of context respond you don't know and do not explain the same"
|
161 |
+
SYSTEM_PROMPT = "Answer from following context, if question is out of context respond i don't know"
|
162 |
+
|
163 |
+
|
164 |
+
template = generate_prompt(
|
165 |
+
"""
|
166 |
+
{context}
|
167 |
+
|
168 |
+
Question: {question}
|
169 |
+
""",
|
170 |
+
system_prompt=SYSTEM_PROMPT,
|
171 |
+
)
|
172 |
+
|
173 |
+
print(template)
|
174 |
+
|
175 |
+
from langchain import HuggingFacePipeline, PromptTemplate
|
176 |
+
|
177 |
+
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
|
178 |
+
|
179 |
+
print(prompt)
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
# create the chain to answer questions
|
184 |
+
qa_chain = RetrievalQA.from_chain_type(llm=llm,
|
185 |
+
chain_type="stuff",
|
186 |
+
retriever=retriever,
|
187 |
+
return_source_documents=True,
|
188 |
+
chain_type_kwargs={"prompt": prompt})
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
import gradio
|
193 |
+
|
194 |
+
def greet(query):
|
195 |
+
llm_response = qa_chain(query)
|
196 |
+
return llm_response['result']
|
197 |
+
|
198 |
+
|
199 |
+
gradio.Interface(greet, "text", "text").launch()
|