tamas.kiss committed on
Commit
36319c9
β€’
1 Parent(s): c86f312

Initialize app

Files changed (3)
  1. README.md +2 -2
  2. app.py +204 -0
  3. requirements.txt +8 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Kubectl V2
  emoji: 🌍
- colorFrom: purple
- colorTo: purple
+ colorFrom: blue
+ colorTo: gray
  sdk: gradio
  sdk_version: 4.4.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,204 @@
+ import os
+
+ import gradio as gr
+ from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import PeftModel
+ import torch
+ import pinecone
+ from sentence_transformers import SentenceTransformer
+ from tqdm import tqdm
+ from sentence_transformers.cross_encoder import CrossEncoder
+ import numpy as np
+ from torch import nn
+
+ # Set up semantic search
+ PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
+
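+ # Note: on Hugging Face Spaces the key is typically stored as a repository secret,
+ # which the platform exposes to the app as an environment variable, as read above.
+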
+ def get_embedding(text):
+     embed_text = sentencetransformer_model.encode(text)
+     vector_text = embed_text.tolist()
+
+     return vector_text
+
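+ # multi-qa-mpnet-base-cos-v1 produces 768-dimensional sentence embeddings, matching
+ # the dimension=768 used when the Pinecone index is created below.
+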
+ def query_from_pinecone(query, top_k=3):
+     # get embedding from THE SAME embedder as the documents
+     query_embedding = get_embedding(query)
+
+     return index.query(
+         vector=query_embedding,
+         top_k=top_k,
+         include_metadata=True  # gets the metadata (dates, text, etc.)
+     ).get('matches')
+
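+ # Each match returned by Pinecone is a dict of the form
+ # {'id': ..., 'score': ..., 'metadata': {'text': ...}}; the functions below rely on
+ # the 'text' field stored in the metadata at indexing time.
+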
+ def get_results_from_pinecone(query, top_k=3, re_rank=True, verbose=True):
+     results_from_pinecone = query_from_pinecone(query, top_k=top_k)
+     if not results_from_pinecone:
+         return []
+
+     if verbose:
+         print("Query:", query)
+
+     final_results = []
+
+     if re_rank:
+         if verbose:
+             print('Document ID (Hash)\t\tRetrieval Score\tCE Score\tText')
+
+         sentence_combinations = [[query, result_from_pinecone['metadata']['text']] for result_from_pinecone in results_from_pinecone]
+
+         # Compute the similarity scores for these combinations
+         similarity_scores = cross_encoder.predict(sentence_combinations, activation_fct=nn.Sigmoid())
+
+         # Sort the scores in decreasing order
+         sim_scores_argsort = reversed(np.argsort(similarity_scores))
+
+         # Append the results in re-ranked order, printing the scores if verbose
+         for idx in sim_scores_argsort:
+             result_from_pinecone = results_from_pinecone[idx]
+             final_results.append(result_from_pinecone)
+             if verbose:
+                 print(f"{result_from_pinecone['id']}\t{result_from_pinecone['score']:.2f}\t{similarity_scores[idx]:.2f}\t{result_from_pinecone['metadata']['text'][:50]}")
+         return final_results
+
+     if verbose:
+         print('Document ID (Hash)\t\tRetrieval Score\tText')
+     for result_from_pinecone in results_from_pinecone:
+         final_results.append(result_from_pinecone)
+         if verbose:
+             print(f"{result_from_pinecone['id']}\t{result_from_pinecone['score']:.2f}\t{result_from_pinecone['metadata']['text'][:50]}")
+
+     return final_results
+
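+ # The function above implements the retrieve-then-rerank pattern: the bi-encoder
+ # retrieves top_k candidates cheaply from the vector index, and the cross-encoder
+ # then re-scores each (query, passage) pair jointly for higher precision.
+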
+ def semantic_search(prompt):
+     final_results = get_results_from_pinecone(prompt, top_k=3, re_rank=True, verbose=True)
+
+     texts = [result['metadata']['text'].replace('\n', ' ') for result in final_results]
+     labels = ['First result:', 'Second result:', 'Third result:']
+     return '\n'.join(f'{label}\n{text}' for label, text in zip(labels, texts))
+
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
+ sentencetransformer_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-cos-v1')
+
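+ # Two separate models: the CrossEncoder re-ranks retrieved passages, while the
+ # SentenceTransformer is the bi-encoder used to embed queries for the index.
+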
+ pinecone_key = PINECONE_API_KEY
+
+ INDEX_NAME = 'k8s-semantic-search'
+ NAMESPACE = 'default'
+
+ pinecone.init(api_key=pinecone_key, environment="gcp-starter")
+
+ if INDEX_NAME not in pinecone.list_indexes():
+     pinecone.create_index(
+         INDEX_NAME,          # The name of the index
+         dimension=768,       # The dimensionality of the vectors
+         metric='cosine',     # The similarity metric to use when searching the index
+         pod_type='starter'   # The type of Pinecone pod
+     )
+
+ index = pinecone.Index(INDEX_NAME)
+
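+ # The index is created only if it does not already exist; subsequent restarts of
+ # the app reuse the already-populated 'k8s-semantic-search' index.
+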
+ # Set up the Mistral model
+ base_model_id = 'mistralai/Mistral-7B-Instruct-v0.1'
+ lora_model_id = 'ComponentSoft/mistral-kubectl-instruct'
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     lora_model_id,
+     padding_side="left",
+     add_eos_token=False,
+     add_bos_token=True,
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+
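+ # Left padding is the safe choice for decoder-only generation: right padding would
+ # place pad tokens between the prompt and the newly generated tokens.
+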
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
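+ # NF4 4-bit quantization with double quantization shrinks the 7B base model to
+ # roughly 4 GB of weights, while computation still runs in bfloat16.
+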
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_id,
+     quantization_config=bnb_config,
+     use_cache=True,
+     trust_remote_code=True,
+ )
+
+ model = PeftModel.from_pretrained(base_model, lora_model_id)
+ model.eval()
+
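+ # The LoRA adapter is applied on top of the frozen 4-bit base model (QLoRA-style
+ # inference); model.eval() disables dropout for inference.
+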
+ stop_terms = ["</s>", "#End"]
+ eos_token_ids_custom = [torch.tensor(tokenizer.encode(term, add_special_tokens=False)).to("cuda") for term in stop_terms]
+
+ category_terms = ["</s>", "\n"]
+ category_eos_token_ids_custom = [torch.tensor(tokenizer.encode(term, add_special_tokens=False)).to("cuda") for term in category_terms]
+
+
+ class EvalStopCriterion(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs):
+         return any(torch.equal(e, input_ids[0][-len(e):]) for e in eos_token_ids_custom)
+
+
+ class CategoryStopCriterion(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs):
+         return any(torch.equal(e, input_ids[0][-len(e):]) for e in category_eos_token_ids_custom)
+
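+ # Each criterion checks whether the tail of the generated sequence equals one of
+ # the encoded stop strings ('#End' for full answers, '\n' for the short category
+ # label), so generation halts as soon as a stop term appears.
+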
+ start_template = '### Answer:'
+ command_template = '# Command:'
+ end_template = '#End'
+
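+ # These markers presumably mirror the fine-tuning format of
+ # ComponentSoft/mistral-kubectl-instruct: answers open with '### Answer:', kubectl
+ # commands with '# Command:', and '#End' terminates the response so that the
+ # relevant span can be sliced out below.
+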
+ def text_to_text_generation(prompt):
+     prompt = prompt.strip()
+
+     # Classify the prompt: 0 = kubectl instruction, 1 = Kubernetes definition question, 2 = anything else
+     is_kubectl_prompt = (
+         f"[INST] You are a helpful assistant who classifies prompts into three categories. Respond with 0 if the prompt pertains to a 'kubectl' operation, i.e. an instruction that can be answered with a 'kubectl' action. Look for keywords like 'get', 'list', 'create', 'show', 'view', and other command-like words; this category is an instruction rather than a question. Respond with 1 only if the prompt is a question about a definition related to Kubernetes, or another non-action inquiry. Respond with 2 in every other scenario, for example if the prompt is a general question not related to Kubernetes or 'kubectl'.\n"
+         f"So for instance the following:\n"
+         f"List all pods in Kubernetes\n"
+         f"Would get a response:\n"
+         f"0 [/INST]"
+         f'text: "{prompt}"'
+         f'response (0/1/2): '
+     )
+
+     model_input = tokenizer(is_kubectl_prompt, return_tensors="pt").to("cuda")
+     with torch.no_grad():
+         response = tokenizer.decode(model.generate(**model_input, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.15, stopping_criteria=StoppingCriteriaList([CategoryStopCriterion()]))[0], skip_special_tokens=True)
+     response = response[len(is_kubectl_prompt):]
+
+     print('-----------------------------QUERY START-----------------------------')
+     print('Prompt: ' + prompt)
+     print('Classified as: ' + response)
+     response_num = 2  # Default to generic question
+     if '0' in response:
+         response_num = 0
+     elif '1' in response:
+         response_num = 1
+
+     # Build the generation prompt for the detected category
+     if response_num == 0:
+         prompt = f"[INST] {prompt}\n Let's think step by step. [/INST] {start_template}"
+     elif response_num == 1:
+         retrieved_results = semantic_search(prompt)
+         prompt = f'[INST] You are an assistant who summarizes results retrieved from a book about Kubernetes. This summary should answer the question. If the answer is not in the retrieved results, use your general knowledge. [/INST] Question: {prompt}\nRetrieved results:\n{retrieved_results}\nResponse:'
+         print('Query:')
+         print(prompt)
+     else:
+         prompt = f'[INST] {prompt} [/INST]'
+
+     # Generate output
+     model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
+     with torch.no_grad():
+         response = tokenizer.decode(model.generate(**model_input, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.15, stopping_criteria=StoppingCriteriaList([EvalStopCriterion()]))[0], skip_special_tokens=True)
+
+     # Get the relevant parts
+     start = response.index(start_template) + len(start_template) if start_template in response else len(prompt)
+     start = response.index(command_template) + len(command_template) if command_template in response else start
+     end = response.index(end_template) if end_template in response else len(response)
+     true_response = response[start:end].strip()
+     print('Returned: ' + true_response)
+     print('------------------------------QUERY END------------------------------')
+
+     return true_response
+
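+ # Illustrative usage (actual outputs depend on the model):
+ #   text_to_text_generation("List all pods")   -> e.g. "kubectl get pods"
+ #   text_to_text_generation("What is a pod?")  -> summary built from retrieved passages
+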
+ iface = gr.Interface(fn=text_to_text_generation, inputs="text", outputs="text")
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers
+ accelerate  # assumed requirement: transformers needs it to load 4-bit (bitsandbytes) models
+ peft
+ bitsandbytes
+ torch
+ scipy
+ pinecone-client
+ sentence_transformers
+ tqdm