acecalisto3 committed
Commit 350b121 · verified · 1 Parent(s): 3ae40d2

Update app.py

Files changed (1)
  1. app.py +229 -60
app.py CHANGED
@@ -13,9 +13,9 @@ import json
 now = datetime.now()
 date_time_str = now.strftime("%Y-%m-%d %H:%M:%S")
 
-# Define the model globally (or pass it as an argument to main)
-model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-client = InferenceClient(model)
+client = InferenceClient(
+    "mistralai/Mixtral-8x7B-Instruct-v0.1"
+)
 
 # --- Set up logging ---
 logging.basicConfig(
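
For reference, a minimal standalone sketch of the streaming huggingface_hub call this client is used for further down (assumes huggingface_hub is installed and the model endpoint is reachable; illustrative, not part of the commit):

    from huggingface_hub import InferenceClient

    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
    # With stream=True and details=True, text_generation yields token objects,
    # which is why the app reads response.token.text below.
    for chunk in client.text_generation("Hello", max_new_tokens=20, stream=True, details=True):
        print(chunk.token.text, end="")
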
@@ -24,7 +24,7 @@ logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(message)s",
 )
 
-agents =[
+agents = [
     "WEB_DEV",
     "AI_SYSTEM_PROMPT",
     "PYTHON_CODE_DEV"
@@ -33,7 +33,7 @@ agents =[
 
 VERBOSE = True
 MAX_HISTORY = 5
-#MODEL = "gpt-3.5-turbo" # "gpt-4"
+# MODEL = "gpt-3.5-turbo" # "gpt-4"
 
 PREFIX = """
 {date_time_str}
@@ -111,7 +111,200 @@ def run_gpt(
     logging.info(LOG_RESPONSE.format(resp)) # Log the response
     return resp
 
-def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.7, max_new_tokens=2048, top_p=0.8, repetition_penalty=1.5):
+def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.7, max_new_tokens=2048, top_p=0.8, repetition_penalty=1.5, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
+    content = PREFIX.format(
+        date_time_str=date_time_str,
+        purpose=purpose,
+        safe_search=safe_search,
+    ) + prompt_template.format(**prompt_kwargs)
+    if VERBOSE:
+        logging.info(LOG_PROMPT.format(content)) # Log the prompt
+
+    stream = client.text_generation(content, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    resp = ""
+    for response in stream:
+        resp += response.token.text
+
+    if VERBOSE:
+        logging.info(LOG_RESPONSE.format(resp)) # Log the response
+    return resp
+
+
+def compress_history(purpose, task, history, directory):
+    resp = run_gpt(
+        COMPRESS_HISTORY_PROMPT,
+        stop_tokens=["observation:", "task:", "action:", "thought:"],
+        max_tokens=512,
+        purpose=purpose,
+        task=task,
+        history=history,
+    )
+    history = "observation: {}\n".format(resp)
+    return history
+
+def call_search(purpose, task, history, directory, action_input):
+    logging.info(f"CALLING SEARCH: {action_input}")
+    try:
+
+        if "http" in action_input:
+            if "<" in action_input:
+                action_input = action_input.strip("<")
+            if ">" in action_input:
+                action_input = action_input.strip(">")
+
+            response = i_s(action_input)
+            #response = google(search_return)
+            logging.info(f"Search Result: {response}")
+            history += "observation: search result is: {}\n".format(response)
+        else:
+            history += "observation: I need to provide a valid URL to 'action: SEARCH action_input=https://URL'\n"
+    except Exception as e:
+        history += "observation: {}'\n".format(e)
+    return "MAIN", None, history, task
+
+def call_main(purpose, task, history, directory, action_input):
+    logging.info(f"CALLING MAIN: {action_input}")
+    resp = run_gpt(
+        ACTION_PROMPT,
+        stop_tokens=["observation:", "task:", "action:","thought:"],
+        max_tokens=32000,
+        purpose=purpose,
+        task=task,
+        history=history,
+    )
+    lines = resp.strip().strip("\n").split("\n")
+    for line in lines:
+        if line == "":
+            continue
+        if line.startswith("thought: "):
+            history += "{}\n".format(line)
+            logging.info(f"Thought: {line}")
+        elif line.startswith("action: "):
+
+            action_name, action_input = parse_action(line)
+            logging.info(f"Action: {action_name} - {action_input}")
+            history += "{}\n".format(line)
+            if "COMPLETE" in action_name or "COMPLETE" in action_input:
+                task = "END"
+                return action_name, action_input, history, task
+            else:
+                return action_name, action_input, history, task
+        else:
+            history += "{}\n".format(line)
+            logging.info(f"Other Output: {line}")
+            #history += "observation: the following command did not produce any useful output: '{}', I need to check the commands syntax, or use a different command\n".format(line)
+
+            #return action_name, action_input, history, task
+            #assert False, "unknown action: {}".format(line)
+    return "MAIN", None, history, task
+
+
+def call_set_task(purpose, task, history, directory, action_input):
+    logging.info(f"CALLING SET_TASK: {action_input}")
+    task = run_gpt(
+        TASK_PROMPT,
+        stop_tokens=[],
+        max_tokens=64,
+        purpose=purpose,
+        task=task,
+        history=history,
+    ).strip("\n")
+    history += "observation: task has been updated to: {}\n".format(task)
+    return "MAIN", None, history, task
+
+def end_fn(purpose, task, history, directory, action_input):
+    logging.info(f"CALLING END_FN: {action_input}")
+    task = "END"
+    return "COMPLETE", "COMPLETE", history, task
+
+NAME_TO_FUNC = {
+    "MAIN": call_main,
+    "UPDATE-TASK": call_set_task,
+    "SEARCH": call_search,
+    "COMPLETE": end_fn,
+
+}
+
+def run_action(purpose, task, history, directory, action_name, action_input):
+    logging.info(f"RUNNING ACTION: {action_name} - {action_input}")
+    try:
+        if "RESPONSE" in action_name or "COMPLETE" in action_name:
+            action_name="COMPLETE"
+            task="END"
+            return action_name, "COMPLETE", history, task
+
+        # compress the history when it is long
+        if len(history.split("\n")) > MAX_HISTORY:
+            logging.info("COMPRESSING HISTORY")
+            history = compress_history(purpose, task, history, directory)
+        if not action_name in NAME_TO_FUNC:
+            action_name="MAIN"
+        if action_name == "" or action_name == None:
+            action_name="MAIN"
+        assert action_name in NAME_TO_FUNC
+
+        logging.info(f"RUN: {action_name} - {action_input}")
+        return NAME_TO_FUNC[action_name](purpose, task, history, directory, action_input)
+    except Exception as e:
+        history += "observation: the previous command did not produce any useful output, I need to check the commands syntax, or use a different command\n"
+        logging.error(f"Error in run_action: {e}")
+        return "MAIN", None, history, task
+
+def run(purpose, history):
+
+    #print(purpose)
+    #print(hist)
+    task=None
+    directory="./"
+    if history:
+        history=str(history).strip("[]")
+    if not history:
+        history = ""
+
+    action_name = "UPDATE-TASK" if task is None else "MAIN"
+    action_input = None
+    while True:
+        logging.info(f"---")
+        logging.info(f"Purpose: {purpose}")
+        logging.info(f"Task: {task}")
+        logging.info(f"---")
+        logging.info(f"History: {history}")
+        logging.info(f"---")
+
+        action_name, action_input, history, task = run_action(
+            purpose,
+            task,
+            history,
+            directory,
+            action_name,
+            action_input,
+        )
+        yield (history)
+        #yield ("",[(purpose,history)])
+        if task == "END":
+            return (history)
+            #return ("", [(purpose,history)])
+
+
+
+################################################
+
+def format_prompt(message, history, max_history_turns=5):
+    prompt = "<s>"
+    # Keep only the last 'max_history_turns' turns
+    for user_prompt, bot_response in history[-max_history_turns:]:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+agents =[
+    "WEB_DEV",
+    "AI_SYSTEM_PROMPT",
+    "PYTHON_CODE_DEV"
+]
+def generate(
+    prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0, model="mistralai/Mixtral-8x7B-Instruct-v0.1"
+):
     seed = random.randint(1,1111111111111111)
 
     # Correct the line:
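
The format_prompt helper added in this hunk builds Mixtral's [INST] instruct template. A minimal sketch of the string it produces, using the same logic as above (illustrative call, not part of the commit):

    def format_prompt(message, history, max_history_turns=5):
        prompt = "<s>"
        for user_prompt, bot_response in history[-max_history_turns:]:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"
        return prompt

    print(format_prompt("And in Python?", [("Hi", "Hello!")]))
    # <s>[INST] Hi [/INST] Hello!</s> [INST] And in Python? [/INST]
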
@@ -127,21 +320,31 @@ def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0
     temperature = 1e-2
     top_p = float(top_p)
 
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=seed,
-    )
 
-    formatted_prompt = format_prompt(prompt, history, max_history_turns=5) # Truncated history
-    logging.info(f"Formatted Prompt: {formatted_prompt}")
 
-    messages = [{"role": "user", "content": formatted_prompt}]
+def generate_text_chunked(input_text, model, generation_parameters, max_tokens_to_generate):
+    """Generates text in chunks to avoid token limit errors."""
+    sentences = nltk.sent_tokenize(input_text)
+    generated_text = []
+    generator = pipeline('text-generation', model=model)
+
+    for sentence in sentences:
+        # Tokenize the sentence and check if it's within the limit
+        tokens = generator.tokenizer(sentence).input_ids
+        if len(tokens) + max_tokens_to_generate <= 32768:
+            # Generate text for this chunk
+            response = generator(sentence, max_length=max_tokens_to_generate, **generation_parameters)
+            generated_text.append(response[0]['generated_text'])
+        else:
+            # Handle cases where the sentence is too long
+            # You could split the sentence further or skip it
+            print(f"Sentence too long: {sentence}")
+
+    return ''.join(generated_text)
 
-    stream = client.text_generation(messages, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    formatted_prompt = format_prompt(prompt, history, max_history_turns=5) # Truncated history
+    logging.info(f"Formatted Prompt: {formatted_prompt}")
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
     for response in stream:
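
generate_text_chunked above relies on nltk.sent_tokenize and a transformers pipeline, neither of which is imported in this hunk; a sketch of the setup they would need (assumption: the Space does not already bundle the 'punkt' tokenizer data):

    import nltk
    nltk.download("punkt")  # one-time download of the sentence tokenizer data

    from nltk import sent_tokenize
    print(sent_tokenize("First sentence. Second one."))
    # ['First sentence.', 'Second one.']
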
@@ -275,7 +478,8 @@ def project_explorer(path):
     tree = get_file_tree(path)
     display_file_tree(tree)
 
-def chat_app_logic(message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model): # Add 'model' as an argument
+def chat_app_logic(message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model):
+    # Your existing code here
 
     try:
         # Attempt to join the generator output
@@ -311,41 +515,7 @@ def chat_app_logic(message, history, purpose, agent_name, sys_prompt, temperatur
         max_new_tokens=max_new_tokens,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
-    )
-    try:
-        # Attempt to join the generator output
-        response = ''.join(generate(
-            model=model, # Now you can use 'model' here
-            messages=messages,
-            stream=True,
-            temperature=0.7,
-            max_tokens=1500
-        ))
-    except TypeError:
-        # If joining fails, collect the output in a list
-        response_parts = []
-        for part in generate(
-            model=model, # Now you can use 'model' here
-            messages=messages,
-            stream=True,
-            temperature=0.7,
-            max_tokens=1500
-        ):
-            if isinstance(part, str):
-                response_parts.append(part)
-            elif isinstance(part, dict) and 'content' in part:
-                response_parts.append(part['content']),
-
-        response = ''.join(response_parts,
-        # Run the model and get the response (convert generator to string)
-        prompt=message,
-        history=history,
-        agent_name=agent_name,
-        sys_prompt=sys_prompt,
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
+        model=model # Pass the model argument here
     )
     history.append((message, response))
     return history
@@ -353,7 +523,6 @@ def chat_app_logic(message, history, purpose, agent_name, sys_prompt, temperatur
         return history
 
 def main():
-
     with gr.Blocks() as demo:
         gr.Markdown("## FragMixt")
         gr.Markdown("### Agents w/ Agents")
@@ -371,6 +540,7 @@ def main():
         max_new_tokens = gr.Slider(label="Max new tokens", value=1048*10, minimum=0, maximum=1048*10, step=64, interactive=True, info="The maximum numbers of new tokens")
         top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens")
         repetition_penalty = gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
+        model_input = gr.Textbox(label="Model", value="mistralai/Mixtral-8x7B-Instruct-v0.1", visible=False)
 
         # Button to submit the message
         submit_button = gr.Button(value="Send")
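
The hidden model_input textbox added here is a standard Gradio way to thread a fixed value into event handlers; a self-contained sketch of the pattern (illustrative echo handler, not from this app):

    import gradio as gr

    def echo(message, model):
        return f"[{model}] {message}"

    with gr.Blocks() as demo:
        # visible=False hides the widget, but its value is still passed as an input
        model_box = gr.Textbox(value="mistralai/Mixtral-8x7B-Instruct-v0.1", visible=False)
        msg = gr.Textbox(label="Message")
        out = gr.Textbox(label="Output")
        msg.submit(echo, inputs=[msg, model_box], outputs=out)

    demo.launch()
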
@@ -381,16 +551,15 @@ def main():
         explore_button = gr.Button(value="Explore")
         project_output = gr.Textbox(label="File Tree", lines=20)
 
-    # Chat App Logic Tab
+        # Chat App Logic Tab
         with gr.Tab("Chat App"):
             history = gr.State([])
             for example in examples:
-                gr.Button(value=example[0]).click(lambda: chat_app_logic(example[0], history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model), outputs=chatbot)
+                gr.Button(value=example[0]).click(lambda: chat_app_logic(example[0], history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model=model_input), outputs=chatbot)
 
             # Connect components to the chat app logic
-            submit_button.click(chat_app_logic, inputs=[message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model], outputs=chatbot) # Pass 'model'
-            message.submit(chat_app_logic, inputs=[message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model], outputs=chatbot) # Pass 'model'
-
+            submit_button.click(chat_app_logic, inputs=[message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model_input], outputs=chatbot)
+            message.submit(chat_app_logic, inputs=[message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model_input], outputs=chatbot)
 
         # Connect components to the project explorer
         explore_button.click(project_explorer, inputs=project_path, outputs=project_output)
 