acecalisto3 committed
Commit 646c35d · verified · 1 Parent(s): 432d2fc

Update app.py

Files changed (1)
  1. app.py +44 -32
app.py CHANGED
@@ -3,12 +3,15 @@ import subprocess
 import random
 from huggingface_hub import InferenceClient
 import gradio as gr
-from safe_search import safe_search
+from safe_search import safe_search  # Make sure you have this function defined
 from i_search import google
 from i_search import i_search as i_s
 from datetime import datetime
 import logging
 import json
+import nltk  # Import nltk for the generate_text_chunked function
+
+nltk.download('punkt')  # Download the punkt tokenizer if you haven't already
 
 now = datetime.now()
 date_time_str = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -112,15 +115,10 @@ def run_gpt(
     return resp
 
 def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.7, max_new_tokens=2048, top_p=0.8, repetition_penalty=1.5, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
-    content = PREFIX.format(
-        date_time_str=date_time_str,
-        purpose=purpose,
-        safe_search=safe_search,
-    ) + prompt_template.format(**prompt_kwargs)
-    if VERBOSE:
-        logging.info(LOG_PROMPT.format(content))  # Log the prompt
-
-    stream = client.text_generation(content, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    # Use 'prompt' here instead of 'message'
+    formatted_prompt = format_prompt(prompt, history, max_history_turns=5)  # Truncated history
+    logging.info(f"Formatted Prompt: {formatted_prompt}")
+    stream = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, stream=True, details=True, return_full_text=False)
     resp = ""
     for response in stream:
         resp += response.token.text
@@ -320,6 +318,20 @@ def generate(
     temperature = 1e-2
     top_p = float(top_p)
 
+    # Add the system prompt to the beginning of the prompt
+    formatted_prompt = f"{system_prompt} {prompt}"
+
+    # Use 'prompt' here instead of 'message'
+    formatted_prompt = format_prompt(formatted_prompt, history, max_history_turns=5)  # Truncated history
+    logging.info(f"Formatted Prompt: {formatted_prompt}")
+    stream = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, stream=True, details=True, return_full_text=False)
+    resp = ""
+    for response in stream:
+        resp += response.token.text
+
+    if VERBOSE:
+        logging.info(LOG_RESPONSE.format(resp))  # Log the response
+    return resp
 
 
 def generate_text_chunked(input_text, model, generation_parameters, max_tokens_to_generate):
@@ -476,38 +488,32 @@ def project_explorer(path):
     Displays the file tree of a given path in a Streamlit app.
     """
     tree = get_file_tree(path)
-    display_file_tree(tree)
+    tree_str = json.dumps(tree, indent=4)  # Convert the tree to a string for display
+    return tree_str
 
 def chat_app_logic(message, history, purpose, agent_name, sys_prompt, temperature, max_new_tokens, top_p, repetition_penalty, model):
     # Your existing code here
 
     try:
-        # Attempt to join the generator output
+        # Pass 'message' as 'prompt'
         response = ''.join(generate(
             model=model,
-            message=message,
-            stream=True,
-            temperature=0.7,
-            max_tokens=1500
+            prompt=message,  # Use 'prompt' here
+            history=history,
+            agent_name=agent_name,
+            sys_prompt=sys_prompt,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
         ))
     except TypeError:
-        # If joining fails, collect the output in a list
+        # ... (rest of the exception handling)
+
         response_parts = []
         for part in generate(
             model=model,
-            message=message,
-            stream=True,
-            temperature=0.7,
-            max_tokens=1500
-        ):
-            if isinstance(part, str):
-                response_parts.append(part)
-            elif isinstance(part, dict) and 'content' in part:
-                response_parts.append(part['content']),
-
-        response = ''.join(response_parts,
-        # Run the model and get the response (convert generator to string)
-        prompt=message,
+            prompt=message,  # Use 'prompt' here
             history=history,
             agent_name=agent_name,
             sys_prompt=sys_prompt,
@@ -515,11 +521,17 @@ def chat_app_logic(message, history, purpose, agent_name, sys_prompt, temperatur
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             repetition_penalty=repetition_penalty,
-            model=model  # Pass the model argument here
-        )
+        ):
+            if isinstance(part, str):
+                response_parts.append(part)
+            elif isinstance(part, dict) and 'content' in part:
+                response_parts.append(part['content'])
+
+        response = ''.join(response_parts)
         history.append((message, response))
         return history
 
+    history.append((message, response))
     return history
 
 def main():
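
Note: the new generate() path calls a format_prompt(prompt, history, max_history_turns=5) helper whose definition falls outside these hunks. A minimal sketch of what it presumably does, assuming history is a list of (user, bot) tuples (as history.append((message, response)) in chat_app_logic suggests) and the [INST] convention used with Mixtral-8x7B-Instruct:

    # Hypothetical reconstruction; not part of this commit's hunks.
    def format_prompt(message, history, max_history_turns=5):
        prompt = "<s>"
        # Keep only the most recent turns so the prompt stays inside the context window
        for user_turn, bot_turn in history[-max_history_turns:]:
            prompt += f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        prompt += f"[INST] {message} [/INST]"
        return prompt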
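
Likewise, nltk and the punkt download are added "for the generate_text_chunked function", whose body is also outside these hunks. A sketch of how punkt-based chunking might look there, under the assumption of a simple character-length budget per chunk (the helper name and budget are illustrative, not from the commit):

    # Hypothetical reconstruction; not part of this commit's hunks.
    import nltk

    def chunk_by_sentence(input_text, max_chunk_chars=2000):
        sentences = nltk.sent_tokenize(input_text)  # uses the punkt tokenizer
        chunks, current = [], ""
        for sentence in sentences:
            # Close the current chunk before it exceeds the budget
            if current and len(current) + len(sentence) + 1 > max_chunk_chars:
                chunks.append(current)
                current = sentence
            else:
                current = f"{current} {sentence}".strip()
        if current:
            chunks.append(current)
        return chunks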