da03 commited on
Commit
6cc23f5
1 Parent(s): eaa0586
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -2,7 +2,6 @@ import spaces
2
  import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
- import time
6
 
7
  model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -41,13 +40,15 @@ def predict_product(num1, num2):
41
  output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
42
  prediction = postprocess(output_text)
43
 
44
- result_html = "<div style='margin-bottom: 10px;'>Correct Result: " + " ".join(correct_product) + "</div><div>"
45
- for i, pred_digit in enumerate(prediction):
46
- color = "green" if i < len(correct_product) and pred_digit == correct_product[i] else "red"
47
- result_html += f"<span style='color: {color};'>{pred_digit}</span>"
48
- result_html += "</div>"
 
 
49
 
50
- yield result_html, ""
51
 
52
  if valid_input:
53
  is_correct = prediction == correct_product
@@ -55,7 +56,15 @@ def predict_product(num1, num2):
55
  else:
56
  result_message = "Invalid input. Could not evaluate correctness."
57
 
58
- yield result_html, result_message
 
 
 
 
 
 
 
 
59
 
60
  demo = gr.Interface(
61
  fn=predict_product,
@@ -64,7 +73,7 @@ demo = gr.Interface(
64
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
65
  ],
66
  outputs=[
67
- gr.HTML(label='Predicted Product with Matching Digits Highlighted'),
68
  gr.HTML(label='Result Message')
69
  ],
70
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
@@ -76,7 +85,7 @@ demo = gr.Interface(
76
  """,
77
  clear_btn=None,
78
  submit_btn="Multiply!",
79
- live=True
80
  )
81
 
82
  demo.launch()
 
2
  import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
5
 
6
  model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
40
  output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
41
  prediction = postprocess(output_text)
42
 
43
+ # Manually create the diff for HighlightedText
44
+ diff = []
45
+ for i in range(len(prediction)):
46
+ if i < len(correct_product) and prediction[i] == correct_product[i]:
47
+ diff.append((prediction[i], None)) # No highlight for correct digits
48
+ else:
49
+ diff.append((prediction[i], "+")) # Highlight incorrect digits in red
50
 
51
+ yield diff, ""
52
 
53
  if valid_input:
54
  is_correct = prediction == correct_product
 
56
  else:
57
  result_message = "Invalid input. Could not evaluate correctness."
58
 
59
+ # Final diff for the complete prediction
60
+ final_diff = []
61
+ for i in range(len(prediction)):
62
+ if i < len(correct_product) and prediction[i] == correct_product[i]:
63
+ final_diff.append((prediction[i], None)) # No highlight for correct digits
64
+ else:
65
+ final_diff.append((prediction[i], "+")) # Highlight incorrect digits in red
66
+
67
+ yield final_diff, result_message
68
 
69
  demo = gr.Interface(
70
  fn=predict_product,
 
73
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
74
  ],
75
  outputs=[
76
+ gr.HighlightedText(label='Predicted Product with Matching Digits Highlighted', combine_adjacent=True, show_legend=True, color_map={"+": "red"}),
77
  gr.HTML(label='Result Message')
78
  ],
79
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
 
85
  """,
86
  clear_btn=None,
87
  submit_btn="Multiply!",
88
+ live=False
89
  )
90
 
91
  demo.launch()