da03 committed on
Commit
d8750b1
1 Parent(s): ad4fc9e
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -51,14 +51,11 @@ def predict_product(num1, num2):
51
  generated_ids_per_model = {model_name: inputs['input_ids'].data.clone() for model_name in models}
52
  finished_per_model = {model_name: False for model_name in models}
53
  past_key_values_per_model = {model_name: None for model_name in models}
54
- predicted_results_per_model = {}
55
  for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
56
  # Ground Truth
57
- ground_truth_results = []
58
- for i in range(step+1):
59
- ground_truth_digit = ground_truth_digits_reversed[i]
60
- ground_truth_results.append((ground_truth_digit, None))
61
- ground_truth_results = ground_truth_results[::-1]
62
  # Predicted
63
  for model_name in models:
64
  model = models[model_name]
@@ -91,7 +88,7 @@ def predict_product(num1, num2):
91
  output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
92
  predicted_digits_reversed = output_text.strip().split(' ')
93
 
94
- predicted_results = []
95
  is_correct_sofar = True
96
  for i in range(len(predicted_digits_reversed)):
97
  predicted_digit = predicted_digits_reversed[i]
@@ -109,17 +106,17 @@ def predict_product(num1, num2):
109
  if not is_correct_digit:
110
  is_correct_sofar = False
111
  if is_correct_digit:
112
- predicted_results.append((predicted_digit, "correct"))
113
  else:
114
- predicted_results.append((predicted_digit, "wrong"))
115
- predicted_results = predicted_results[::-1]
116
- predicted_results_per_model[model_name] = predicted_results
117
 
118
- predicted_results_implicit_cot = predicted_results_per_model['implicit']
119
- predicted_results_nocot = predicted_results_per_model['no']
120
- predicted_results_explicit_cot = predicted_results_per_model['explicit']
121
 
122
- yield ground_truth_results, predicted_results_implicit_cot, predicted_results_nocot, predicted_results_explicit_cot
123
 
124
  color_map = {"correct": "green", "wrong": "red"}
125
 
 
51
  generated_ids_per_model = {model_name: inputs['input_ids'].data.clone() for model_name in models}
52
  finished_per_model = {model_name: False for model_name in models}
53
  past_key_values_per_model = {model_name: None for model_name in models}
54
+ predicted_annotations_per_model = {}
55
  for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
56
  # Ground Truth
57
+ ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
58
+ ground_truth_annotations = ground_truth_annotations[::-1]
 
 
 
59
  # Predicted
60
  for model_name in models:
61
  model = models[model_name]
 
88
  output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
89
  predicted_digits_reversed = output_text.strip().split(' ')
90
 
91
+ predicted_annotations = []
92
  is_correct_sofar = True
93
  for i in range(len(predicted_digits_reversed)):
94
  predicted_digit = predicted_digits_reversed[i]
 
106
  if not is_correct_digit:
107
  is_correct_sofar = False
108
  if is_correct_digit:
109
+ predicted_annotations.append((predicted_digit, "correct"))
110
  else:
111
+ predicted_annotations.append((predicted_digit, "wrong"))
112
+ predicted_annotations = predicted_annotations[::-1]
113
+ predicted_annotations_per_model[model_name] = predicted_annotations
114
 
115
+ predicted_annotations_implicit_cot = predicted_annotations_per_model['implicit']
116
+ predicted_annotations_nocot = predicted_annotations_per_model['no']
117
+ predicted_annotations_explicit_cot = predicted_annotations_per_model['explicit']
118
 
119
+ yield ground_truth_annotations, predicted_annotations_implicit_cot, predicted_annotations_nocot, predicted_annotations_explicit_cot
120
 
121
  color_map = {"correct": "green", "wrong": "red"}
122