da03 commited on
Commit
f7dc2d2
1 Parent(s): 685026a
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -23,31 +23,57 @@ def predict_product(num1, num2):
23
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
24
 
25
  generated_ids = inputs['input_ids']
26
- outputs = model.generate(generated_ids, max_new_tokens=40, do_sample=False)
27
- full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
- prediction = postprocess(full_output[len(input_text):])
29
 
30
  try:
31
  num1_int = int(num1)
32
  num2_int = int(num2)
33
  correct_product = str(num1_int * num2_int)
34
  except ValueError:
35
- return [], "Invalid input. Could not evaluate correctness."
36
 
37
- # Create the diff for HighlightedText
38
- diff = []
39
- max_len = max(len(prediction), len(correct_product))
40
- for i in range(max_len):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
42
- diff.append((prediction[i], None)) # No highlight for correct digits
43
  elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
44
- diff.append((prediction[i], "+")) # Highlight incorrect digits in red
45
  if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
46
- diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
47
-
48
- result_message = "Correct!" if prediction == correct_product else f"Incorrect! The correct product is {correct_product}."
49
 
50
- return diff, result_message
51
 
52
  demo = gr.Interface(
53
  fn=predict_product,
 
23
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
24
 
25
  generated_ids = inputs['input_ids']
26
+ prediction = ""
27
+ correct_product = ""
28
+ valid_input = True
29
 
30
  try:
31
  num1_int = int(num1)
32
  num2_int = int(num2)
33
  correct_product = str(num1_int * num2_int)
34
  except ValueError:
35
+ valid_input = False
36
 
37
+ eos_token_id = tokenizer.eos_token_id
38
+ for _ in range(100): # Set a maximum limit to prevent infinite loops
39
+ outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False)
40
+ generated_ids = torch.cat((generated_ids, outputs[:, -1:]), dim=-1)
41
+
42
+ if outputs[0, -1].item() == eos_token_id:
43
+ break
44
+
45
+ output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
46
+ prediction = postprocess(output_text[len(input_text):])
47
+
48
+ # Create the diff for HighlightedText
49
+ diff = []
50
+ for i in range(max(len(prediction), len(correct_product))):
51
+ if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
52
+ diff.append((prediction[i], None)) # No highlight for correct digits
53
+ elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
54
+ diff.append((prediction[i], "+")) # Highlight incorrect digits in red
55
+ if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
56
+ diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
57
+
58
+ yield diff, ""
59
+
60
+ if valid_input:
61
+ is_correct = prediction == correct_product
62
+ result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
63
+ else:
64
+ result_message = "Invalid input. Could not evaluate correctness."
65
+
66
+ # Final diff for the complete prediction
67
+ final_diff = []
68
+ for i in range(max(len(prediction), len(correct_product))):
69
  if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
70
+ final_diff.append((prediction[i], None)) # No highlight for correct digits
71
  elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
72
+ final_diff.append((prediction[i], "+")) # Highlight incorrect digits in red
73
  if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
74
+ final_diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
 
 
75
 
76
+ yield final_diff, result_message
77
 
78
  demo = gr.Interface(
79
  fn=predict_product,