Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,506 Bytes
02df9f8 5c3cea5 02df9f8 9bfa66b 02df9f8 bf65d9e 02df9f8 9bfa66b 02df9f8 9428a07 02df9f8 bf65d9e 39a2dae bf65d9e f7dc2d2 8fa0ae4 70487ef eaa0586 70487ef f7dc2d2 eaa0586 bf65d9e 3e9942b bf65d9e 0ad2aca f7dc2d2 0ad2aca bf65d9e 0ad2aca f7dc2d2 bf65d9e f7dc2d2 bf65d9e f7dc2d2 bf65d9e f7dc2d2 bf65d9e f7dc2d2 bf65d9e f7dc2d2 bf65d9e eaa0586 bf65d9e 02df9f8 3f861c3 513d0fe 3f861c3 9bfa66b bf65d9e 8fa0ae4 9bfa66b 1efd23b 9428a07 8fa0ae4 486c21f 6cc23f5 02df9f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Checkpoint identifier passed to both from_pretrained calls below.
model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
# Tokenizer and causal-LM are loaded once at module import and reused by
# predict_product on every call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Upper bound on the number of decoding steps in predict_product's
# generation loop (guards against the model never emitting EOS).
MAX_PRODUCT_DIGITS = 100
def preprocess(num):
    """Format a number the way the prompt expects it.

    The value is stringified, surrounding whitespace and all space
    characters are dropped, and the remaining characters are emitted in
    reverse order separated by single spaces (e.g. 123 -> '3 2 1').
    """
    text = str(num).strip()
    compact = text.replace(' ', '')
    return ' '.join(reversed(compact))
def postprocess(raw_output):
    """Undo the prompt formatting: drop spaces, then reverse the characters.

    For example '3 2 1' becomes '123'.
    """
    compact = raw_output.replace(' ', '')
    return ''.join(reversed(compact))
def _diff_digits(predicted_digits_reversed, correct_product):
    """Compare predicted digits with the true product, least-significant first.

    Args:
        predicted_digits_reversed: list of single-digit strings produced so
            far, ordered from least to most significant.
        correct_product: the true product as a decimal string in normal
            (most-significant-first) order.

    Returns:
        A ``(correct_digits, diff)`` pair, both in most-significant-first
        order and both in HighlightedText format (lists of ``(text, label)``
        tuples). ``correct_digits`` holds the true digits revealed so far
        with a ``None`` label; ``diff`` holds the predicted digits labelled
        ``"-"`` when correct and ``"+"`` when incorrect.
    """
    correct_digits_reversed = correct_product[::-1]
    correct_digits = []
    diff = []
    is_correct_sofar = True
    for i, predicted_digit in enumerate(predicted_digits_reversed):
        if i < len(correct_digits_reversed):
            correct_digit = correct_digits_reversed[i]
            correct_digits.append((correct_digit, None))
            is_correct_digit = predicted_digit == correct_digit
        else:
            # Past the most significant true digit: only a padding zero that
            # follows an all-correct prefix is accepted.
            is_correct_digit = predicted_digit == '0' and is_correct_sofar
        if not is_correct_digit:
            is_correct_sofar = False
        diff.append((predicted_digit, "-" if is_correct_digit else "+"))
    return correct_digits[::-1], diff[::-1]


@spaces.GPU
def predict_product(num1, num2):
    """Stream the model's predicted product of two numbers, digit by digit.

    Generator used as the Gradio callback. The prompt is built from the two
    reversed, space-separated operands; tokens are then decoded greedily one
    at a time until EOS or MAX_PRODUCT_DIGITS steps.

    Args:
        num1: first operand (text convertible with ``int``).
        num2: second operand (text convertible with ``int``).

    Yields:
        ``(correct_digits, diff, message)`` after every generated token.
        ``correct_digits`` and ``diff`` are lists of ``(text, label)``
        tuples (see ``_diff_digits``); ``message`` is an HTML string, empty
        unless the input could not be parsed.
    """
    try:
        correct_product = str(int(num1) * int(num2))
    except (TypeError, ValueError):
        # No ground truth to compare against, so report and stop before
        # spending any model time on an unparseable prompt.
        yield [], [], "Invalid input. Please enter two integers."
        return

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    model.to(device)

    eos_token_id = tokenizer.eos_token_id
    generated_ids = inputs['input_ids']
    input_len = generated_ids.shape[-1]

    # The first step feeds the whole prompt; afterwards only the newest token
    # is fed, because everything earlier already lives in past_key_values.
    step_input = generated_ids
    past_key_values = None
    for _ in range(MAX_PRODUCT_DIGITS):  # hard cap in case EOS never appears
        # no_grad is scoped to the forward pass only so that the grad mode
        # is never left altered while the generator is suspended at `yield`.
        with torch.no_grad():
            outputs = model(
                input_ids=step_input,
                past_key_values=past_key_values,
                use_cache=True,
            )
        next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        if next_token_id.item() == eos_token_id:
            break

        step_input = next_token_id.view(1, -1)
        generated_ids = torch.cat((generated_ids, step_input), dim=-1)
        past_key_values = outputs.past_key_values

        output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
        # split() (no argument) tolerates repeated spaces and empty output.
        predicted_digits_reversed = output_text.split()
        correct_digits, diff = _diff_digits(predicted_digits_reversed, correct_product)
        yield correct_digits, diff, ""
# Gradio UI: two text inputs, streamed outputs from predict_product.
demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 12 digits)', value='12345'),
        gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
    ],
    outputs=[
        # predict_product yields the ground truth as a list of (digit, None)
        # tuples, which is HighlightedText's format; a Textbox would show
        # the list's repr instead of the digits.
        gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False),
        gr.HighlightedText(label='Predicted Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}),
        gr.HTML(label='Result Message')
    ],
    title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
    description='This demo uses GPT2 to directly predict the product of two numbers without using any intermediate reasoning steps. The GPT2 model has been fine-tuned to internalize chain-of-thought reasoning within its hidden states, following our stepwise internalization approach detailed in the paper linked at the bottom of this page.',
    article="""
- [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    clear_btn=None,
    submit_btn="Multiply!",
    live=False
)
demo.launch()
|