arjunanand13 committed on
Commit
dd2e81b
1 Parent(s): cd1ab97

Create app.py

Files changed (1)
app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import accelerate
import einops
import langchain
import xformers
import os
import bitsandbytes
import sentence_transformers
import huggingface_hub
import torch
from torch import cuda, bfloat16
from transformers import StoppingCriteria, StoppingCriteriaList
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from huggingface_hub import InferenceClient

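# Note: several of the imports above (accelerate, einops, langchain, xformers,
# bitsandbytes, sentence_transformers, InferenceClient) are not referenced directly
# in this file; they appear to be imported only so that a missing dependency fails
# fast at startup rather than midway through model loading.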
# Login to Hugging Face using a token
# huggingface_hub.login(HF_TOKEN)

"""
Load the Llama 3 model.
"""
HF_TOKEN = os.environ.get("HF_TOKEN", None)
model_id = 'meta-llama/Meta-Llama-3-8B'  # only used by the commented-out alternative loading path below
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'


"""Set the quantization configuration so the large model can be loaded with less GPU memory;
this requires the `bitsandbytes` library."""
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)

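# Rough, hedged estimate: with 4-bit NF4 weights an 8B-parameter model needs on the
# order of 4-5 GB of GPU memory for the weights (versus roughly 16 GB in bfloat16),
# plus additional activation and KV-cache overhead at inference time.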
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token=HF_TOKEN)
# device_map="auto" handles placement, so an explicit .to("cuda:0") is not needed
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", token=HF_TOKEN, quantization_config=bnb_config)
# Llama 3 end-of-turn terminators (currently unused by the pipeline below)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

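# Hedged sketch (not wired in): following the common Llama 3 usage pattern, `terminators`
# could be passed as eos_token_id at generation time so decoding also stops at <|eot_id|>,
# e.g. with the text-generation pipeline defined further below:
# outputs = generate_text(prompt, eos_token_id=terminators, max_new_tokens=512)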
"""Alternative (unquantized / CPU-capable) loading path, kept commented out for reference."""

# model_config = transformers.AutoConfig.from_pretrained(
#     model_id,
#     token=HF_TOKEN,
#     # use_auth_token=hf_auth
# )
# model = transformers.AutoModelForCausalLM.from_pretrained(
#     model_id,
#     trust_remote_code=True,
#     config=model_config,
#     # quantization_config=bnb_config,
#     token=HF_TOKEN,
#     # use_auth_token=hf_auth
# )
# model.eval()
# tokenizer = transformers.AutoTokenizer.from_pretrained(
#     model_id,
#     token=HF_TOKEN,
#     # use_auth_token=hf_auth
# )
# generate_text = transformers.pipeline(
#     model=model, tokenizer=tokenizer,  # this is module-level code, not a method, so no self.
#     return_full_text=True,
#     task='text-generation',
#     temperature=0.01,
#     max_new_tokens=512
# )

"""
Set up the stop list that defines the custom stopping criteria.
"""

stop_list = ['\nHuman:', '\n```\n']

# add_special_tokens=False keeps the BOS token out of the stop sequences so the
# suffix comparison in StopOnTokens below can actually match
stop_token_ids = [tokenizer(x, add_special_tokens=False)['input_ids'] for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]


# define custom stopping criteria object
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])


generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    task='text-generation',
    # we pass model parameters here too
    stopping_criteria=stopping_criteria,  # without this the model rambles during chat
    temperature=0.1,  # 'randomness' of outputs; lower values are more deterministic
    max_new_tokens=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # without this the output begins repeating
)

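# Note (hedged): in transformers, temperature only takes effect when sampling is
# enabled; with the default greedy decoding it is ignored and a warning may be logged.
# If sampled outputs are intended, do_sample=True would need to be added above.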
llm = HuggingFacePipeline(pipeline=generate_text)

loader = DirectoryLoader('data/text/', loader_cls=TextLoader)
documents = loader.load()
print('number of documents loaded:', len(documents))

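# Assumption: data/text/ ships with the Space and contains plain-text files.
# DirectoryLoader's default glob picks up every file in the tree; restricting it,
# e.g. DirectoryLoader('data/text/', glob='**/*.txt', loader_cls=TextLoader),
# is one way to avoid feeding non-text files to TextLoader.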
text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
all_splits = text_splitter.split_documents(documents)

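# Note: chunk_size and chunk_overlap are measured in characters, not tokens, so a
# 5000-character chunk is very roughly on the order of a thousand-plus tokens.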
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": device}  # reuse the device detected above instead of hard-coding "cuda"

embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

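# all-mpnet-base-v2 produces 768-dimensional sentence embeddings; it runs fine on GPU
# here and would also work on CPU if GPU memory is tight.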
# storing embeddings in the vector store
vectorstore = FAISS.from_documents(all_splits, embeddings)

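# Hedged sketch: the index is rebuilt from scratch on every app start. If startup time
# becomes a problem, LangChain's FAISS wrapper can persist and reload it (index path
# below is illustrative; recent versions may also require
# allow_dangerous_deserialization=True on load_local):
# vectorstore.save_local("faiss_index")
# vectorstore = FAISS.load_local("faiss_index", embeddings)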
chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)

chat_history = []

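# Note: chat_history stays empty because qa_infer below never appends to it, so every
# query is answered independently. A stateful variant would append after each turn, e.g.:
# chat_history.append((query, result['answer']))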
def format_prompt(query):
    # This prompt guides the LLM to respond in a specified format
    prompt = f"""
    You are a helpful assistant with access to a specialized knowledge base. I have a question that I need answered. After providing a brief general response to the question, please return a JSON-formatted output with the following keys:
    - "question": The original question.
    - "answer": The detailed answer.

    Also list related question/answer pairs found in the knowledge base, each as:
    - "question": a related question.
    - "answer": its answer.

    Please answer the following question:
    {query}
    """
    return prompt

def qa_infer(query):
    formatted_prompt = format_prompt(query)
    result = chain({"question": formatted_prompt, "chat_history": chat_history})
    return result['answer']

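# Hedged sketch: since the chain is built with return_source_documents=True, the
# retrieved chunks are available in result['source_documents'] and could be surfaced
# alongside the answer, e.g.:
# sources = [doc.metadata.get('source', '') for doc in result['source_documents']]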
# query = "What is the best TS pin configuration for BQ24040 in normal battery charge mode"
# qa_infer(query)

EXAMPLES = [
    "What is the best TS pin configuration for BQ24040 in normal battery charge mode",
    "Can BQ25896 support I2C interface?",
    "Can you please provide me with Gerber/CAD file for UCC2897A"
]

demo = gr.Interface(fn=qa_infer, inputs="text", outputs="text", allow_flagging='never',
                    examples=EXAMPLES, cache_examples=False)

# launch the app!
# demo.launch(enable_queue=True, share=True)
# demo.queue(default_enabled=True).launch(debug=True, share=True)
demo.launch()