acecalisto3 committed on
Commit 461c0f4 (verified)
Parent: f5516d4

Update app.py

Files changed (1)
  app.py  +28 -1
app.py CHANGED

@@ -9,6 +9,7 @@ import gradio as gr
 from huggingface_hub import InferenceClient
 from safe_search import safe_search
 from i_search import google, i_search as i_s
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # --- Configuration ---
 VERBOSE = True
@@ -141,6 +142,31 @@ def generate(
     logging.info(LOG_RESPONSE.format(resp=response))
     return response.text
 
+# --- Mixtral Integration ---
+def mixtral_generate(
+    prompt: str,
+    history: List[Tuple[str, str]],
+    agent_name: str = agents[0],
+    sys_prompt: str = "",
+    temperature: float = TEMPERATURE,
+    max_new_tokens: int = MAX_TOKENS,
+    top_p: float = TOP_P,
+    repetition_penalty: float = REPETITION_PENALTY,
+) -> str:
+    """Generates a response using the Mixtral model."""
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+    content = PREFIX.format(
+        date_time_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+        purpose=f"Generating response as {agent_name}",
+        safe_search=safe_search,
+    ) + sys_prompt + "\n" + prompt
+
+    inputs = tokenizer(content, return_tensors="pt")
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
 def main():
     """Main function to launch the Gradio interface."""
     with gr.Blocks() as demo:
@@ -242,7 +268,8 @@ def main():
         ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
             """Handles the chat interaction, generating responses and updating history."""
             prompt = format_prompt(message, history)
-            response = generate(
+            # Use Mixtral for generation
+            response = mixtral_generate(
                 prompt,
                 history,
                 agent_name,
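
Note on the change above: as committed, mixtral_generate calls AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained on every chat turn, which reloads the full Mixtral weights per request, and transformers typically ignores sampling parameters such as temperature and top_p unless do_sample=True is passed to generate. The following is a minimal sketch of a cached variant, not part of the commit; it assumes the MODEL_NAME, TEMPERATURE, MAX_TOKENS, TOP_P, and REPETITION_PENALTY constants defined elsewhere in app.py, and _load_mixtral / mixtral_generate_cached are hypothetical names used only for illustration.

# Sketch only, under the assumptions stated above.
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer

@lru_cache(maxsize=1)
def _load_mixtral():
    # First call loads the weights; subsequent calls reuse the cached pair.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return tokenizer, model

def mixtral_generate_cached(content: str) -> str:
    tokenizer, model = _load_mixtral()
    inputs = tokenizer(content, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,  # without this, temperature/top_p are usually ignored
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=REPETITION_PENALTY,
    )
    # Decode only the newly generated tokens; decoding outputs[0] in full,
    # as the committed version does, echoes the prompt back in the response.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)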