kprsnt committed on
Commit
e6a3b31
1 Parent(s): 50b8ded

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -1,12 +1,10 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
- # Load the Microsoft GRIN-MoE model for the initial argument
5
- model1 = gr.load("models/microsoft/GRIN-MoE")
6
 
7
- # Load a different model for the counter-argument
8
- # We'll use GRIN as an example, but you can replace this with another suitable model
9
- model2 = gr.load("models/microsoft/GRIN-MoE")
10
 
11
  def generate_initial_argument(query):
12
  prompt = f"Provide a logical explanation for the following topic: {query}"
@@ -14,11 +12,20 @@ def generate_initial_argument(query):
14
  return response
15
 
16
def generate_counter_argument(query, initial_argument):
    """Produce a counter-argument for *query* that rebuts *initial_argument*."""
    prompt = f"Given the topic '{query}' and the argument '{initial_argument}', provide a well-reasoned counter-argument:"
    generations = model2(prompt, max_length=200, num_return_sequences=1, temperature=0.7)
    raw_text = generations[0]['generated_text']
    # Everything after the echoed prompt is the counter-argument itself.
    return raw_text.split(prompt)[-1].strip()
 
 
 
 
 
 
 
 
 
22
 
23
  def debate(query):
24
  initial_argument = generate_initial_argument(query)
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
 
6
+ model1 = AutoModelForCausalLM.from_pretrained("microsoft/GRIN-MoE")
7
+ tokenizer1 = AutoTokenizer.from_pretrained("microsoft/GRIN-MoE")
 
8
 
9
  def generate_initial_argument(query):
10
  prompt = f"Provide a logical explanation for the following topic: {query}"
 
12
  return response
13
 
14
def generate_counter_argument(query, initial_argument):
    """Generate a counter-argument to *initial_argument* on the topic *query*.

    Builds a prompt, runs the causal-LM generate loop, and returns the decoded
    text with the echoed prompt stripped. On any failure, returns an error
    message string instead of raising (best-effort, so the UI stays up).
    """
    try:
        prompt = f"Given the topic '{query}' and the initial argument '{initial_argument}', provide a well-reasoned counter-argument:"
        # BUGFIX: the original referenced undefined `tokenizer2`/`model2`;
        # only `model1`/`tokenizer1` are loaded at module level, so the old
        # code always hit the except branch with a NameError. Use the loaded
        # model/tokenizer instead.
        inputs = tokenizer1(prompt, return_tensors="pt")
        outputs = model1.generate(
            **inputs,
            # NOTE(review): max_length counts prompt tokens too; a long prompt
            # leaves little room for generation — consider max_new_tokens.
            max_length=200,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,  # temperature only takes effect with sampling
        )
        counter_argument = tokenizer1.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the counter-argument remains.
        return counter_argument.replace(prompt, "").strip()
    except Exception as e:
        # Best-effort: surface failures to the UI instead of crashing the app.
        return f"An error occurred: {str(e)}"
29
 
30
  def debate(query):
31
  initial_argument = generate_initial_argument(query)