da03 committed on
Commit
bf65d9e
1 Parent(s): 0ad2aca
Files changed (1) hide show
  1. app.py +54 -31
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
6
  model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
 
9
 
10
  def preprocess(num):
11
  num = str(num).strip().replace(' ', '')
@@ -21,8 +22,10 @@ def predict_product(num1, num2):
21
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
22
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
23
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
24
 
25
- generated_ids = inputs['input_ids']
 
26
  prediction = ""
27
  correct_product = ""
28
  valid_input = True
@@ -34,55 +37,74 @@ def predict_product(num1, num2):
34
  except ValueError:
35
  valid_input = False
36
 
37
- eos_token_id = tokenizer.eos_token_id
38
  past_key_values = None
39
- for _ in range(100): # Set a maximum limit to prevent infinite loops
40
  outputs = model(
41
  input_ids=generated_ids,
42
  past_key_values=past_key_values,
43
  use_cache=True
44
  )
45
  logits = outputs.logits
46
- past_key_values = outputs.past_key_values
47
 
48
  next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
49
- generated_ids = torch.cat((generated_ids, next_token_id.unsqueeze(-1)), dim=-1)
50
 
51
  if next_token_id.item() == eos_token_id:
52
  break
 
53
 
54
- output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
55
- prediction = postprocess(output_text[len(input_text):])
 
 
56
 
57
  # Create the diff for HighlightedText
58
  diff = []
59
- for i in range(max(len(prediction), len(correct_product))):
60
- if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
61
- diff.append((prediction[i], None)) # No highlight for correct digits
62
- elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
63
- diff.append((prediction[i], "+")) # Highlight incorrect digits in red
64
- if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
65
- diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- yield diff, ""
68
 
69
- if valid_input:
70
- is_correct = prediction == correct_product
71
- result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
72
- else:
73
- result_message = "Invalid input. Could not evaluate correctness."
74
 
75
- # Final diff for the complete prediction
76
- final_diff = []
77
- for i in range(max(len(prediction), len(correct_product))):
78
- if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
79
- final_diff.append((prediction[i], None)) # No highlight for correct digits
80
- elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
81
- final_diff.append((prediction[i], "+")) # Highlight incorrect digits in red
82
- if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
83
- final_diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
84
 
85
- yield final_diff, result_message
86
 
87
  demo = gr.Interface(
88
  fn=predict_product,
@@ -91,7 +113,8 @@ demo = gr.Interface(
91
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
92
  ],
93
  outputs=[
94
- gr.HighlightedText(label='Predicted Product with Matching and Unmatching Digits Highlighted', combine_adjacent=True, show_legend=True, color_map={"-": "green", "+": "red"}),
 
95
  gr.HTML(label='Result Message')
96
  ],
97
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
 
6
  model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
+ MAX_PRODUCT_DIGITS = 100
10
 
11
  def preprocess(num):
12
  num = str(num).strip().replace(' ', '')
 
22
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
23
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
24
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
25
+ eos_token_id = tokenizer.eos_token_id
26
 
27
+ input_ids = inputs['input_ids']
28
+ input_len = input_ids.shape[-1]
29
  prediction = ""
30
  correct_product = ""
31
  valid_input = True
 
37
  except ValueError:
38
  valid_input = False
39
 
40
+ generated_ids = inputs['input_ids']
41
  past_key_values = None
42
+ for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
43
  outputs = model(
44
  input_ids=generated_ids,
45
  past_key_values=past_key_values,
46
  use_cache=True
47
  )
48
  logits = outputs.logits
 
49
 
50
  next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
51
+ generated_ids = torch.cat((generated_ids, next_token_id.view(1,-1)), dim=-1)
52
 
53
  if next_token_id.item() == eos_token_id:
54
  break
55
+ past_key_values = outputs.past_key_values
56
 
57
+ output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
58
+ #prediction = postprocess(output_text)
59
+ predicted_digits_reversed = output_text.strip().split(' ')
60
+ correct_digits_reversed = ' '.join(correct_product)[::-1]
61
 
62
  # Create the diff for HighlightedText
63
  diff = []
64
+ correct_digits = []
65
+ is_correct_sofar = True
66
+ for i in range(len(predicted_digits_reversed)):
67
+ predicted_digit = predicted_digits_reversed[i]
68
+ correct_digit = correct_digits_reversed[i]
69
+ correct_digits.append((correct_digit, None))
70
+ if i >= len(correct_digits_reversed):
71
+ if predicted_digit == '0' and is_correct_sofar:
72
+ is_correct_digit = True
73
+ else:
74
+ is_correct_digit = True
75
+ else:
76
+ if predicted_digit == correct_digit:
77
+ is_correct_digit = True
78
+ else:
79
+ is_correct_digit = False
80
+ if not is_correct_digit:
81
+ is_correct_sofar = False
82
+ if is_correct_digit:
83
+ diff.append((correct_product[i], "-"))
84
+ else:
85
+ diff.append((predicted_digit, "+"))
86
+ diff = diff[::-1]
87
+ correct_digits = correct_digits[::-1]
88
 
89
+ yield correct_digits, diff, ""
90
 
91
+ #if valid_input:
92
+ # is_correct = prediction == correct_product
93
+ # result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
94
+ #else:
95
+ # result_message = "Invalid input. Could not evaluate correctness."
96
 
97
+ ## Final diff for the complete prediction
98
+ #final_diff = []
99
+ #for i in range(max(len(prediction), len(correct_product))):
100
+ # if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
101
+ # final_diff.append((prediction[i], None)) # No highlight for correct digits
102
+ # elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
103
+ # final_diff.append((prediction[i], "+")) # Highlight incorrect digits in red
104
+ # if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
105
+ # final_diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
106
 
107
+ #yield final_diff, result_message
108
 
109
  demo = gr.Interface(
110
  fn=predict_product,
 
113
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
114
  ],
115
  outputs=[
116
+ gr.Textbox(label='Ground Truth Product'),
117
+ gr.HighlightedText(label='Predicted Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}),
118
  gr.HTML(label='Result Message')
119
  ],
120
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',