|
import os |
|
import streamlit as st |
|
import openai |
|
import pandas as pd |
|
from typing import List, Tuple |
|
from uuid import uuid4 |
|
import time |
|
|
|
|
|
# Page metadata: display names, reference links and the page icon.
Site_Name = 'π Self-Taught Reasoner (STaR) App'
title = "π€π STaR: Self-Taught Reasoner - Bootstrapping Reasoning With Reasoning"
helpURL = 'https://arxiv.org/abs/2203.14465'
bugURL = 'https://arxiv.org/pdf/2203.14465'
icons = 'ππ€'

# Switch for applying the Streamlit page configuration below.
useConfig = True

if useConfig:
    # Entries of the app menu; the keys are the names Streamlit expects.
    _menu_items = {
        'Get Help': helpURL,
        'Report a bug': bugURL,
        'About': title,
    }
    st.set_page_config(
        page_title=title,
        page_icon=icons,
        layout="wide",
        initial_sidebar_state="auto",
        menu_items=_menu_items,
    )
|
|
|
|
|
# Read the OpenAI key from the environment; the value is None when the
# variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
|
|
def get_session_id():
    """Return the id stored in ``st.session_state``, creating it on first use.

    The id is a UUID4 rendered as a string and is kept under the
    ``session_id`` key of the Streamlit session state.
    """
    state = st.session_state
    if 'session_id' in state:
        return state.session_id
    state.session_id = str(uuid4())
    return state.session_id
|
|
|
|
|
class SelfTaughtReasoner:
    """Demo driver for the STaR (Self-Taught Reasoner) loop.

    The object keeps the few-shot prompt examples, asks the OpenAI chat API
    for a rationale and an answer per problem, retries wrong answers with the
    correct answer supplied as a hint (rationalization), and records the
    outcomes in two pandas tables. Fine-tuning is only simulated.
    """

    # Column layout shared by ``generated_data`` and ``rationalized_data``.
    _RESULT_COLUMNS = ['Problem', 'Rationale', 'Answer', 'Is_Correct']

    def __init__(self, model_engine="gpt-3.5-turbo"):
        """Create an empty reasoner.

        Args:
            model_engine: Model name passed to the chat completion calls.
        """
        self.model_engine = model_engine
        self.prompt_examples = []
        self.iterations = 0
        self.generated_data = pd.DataFrame(columns=self._RESULT_COLUMNS)
        self.rationalized_data = pd.DataFrame(columns=self._RESULT_COLUMNS)
        self.fine_tuned_model = None

    def add_prompt_example(self, problem: str, rationale: str, answer: str):
        """Append one problem/rationale/answer triple to the few-shot examples."""
        self.prompt_examples.append({
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer,
        })

    def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> List[dict]:
        """Build the chat message list for ``problem``.

        Every stored few-shot example becomes one ``system`` message, followed
        by a ``user`` message that states the problem and opens the rationale.

        Args:
            problem: The problem text to solve.
            include_answer: When true, a trailing ``system`` message carrying
                ``answer`` is added (used as the hint for rationalization).
            answer: The answer text used when ``include_answer`` is true.

        Returns:
            A new list of ``{"role": ..., "content": ...}`` dicts.
        """
        messages = [
            {
                "role": "system",
                "content": f"Problem: {example['Problem']}\nRationale: {example['Rationale']}\nAnswer: {example['Answer']}\n",
            }
            for example in self.prompt_examples
        ]
        messages.append({"role": "user", "content": f"Problem: {problem}\nRationale:"})
        if include_answer:
            messages.append({"role": "system", "content": f"Answer: {answer}"})
        return messages

    def _rationale_then_answer(self, messages: List[dict]) -> Tuple[str, str]:
        """Run the two-step completion: a rationale first, then a short answer.

        ``messages`` is extended in place with the generated rationale before
        the second call. Any exception is reported with ``st.error`` and
        turned into a pair of empty strings, so callers never see it raise.
        """
        try:
            # NOTE(review): `openai.ChatCompletion` is the pre-1.0 interface of
            # the openai package — confirm the pinned version before upgrading.
            response = openai.ChatCompletion.create(
                model=self.model_engine,
                messages=messages,
                max_tokens=150,
                temperature=0.7,
            )
            rationale = response.choices[0].message['content'].strip()

            # Feed the rationale back so the second call only has to complete
            # the final answer.
            messages.append({"role": "system", "content": f"Rationale: {rationale}\nAnswer:"})
            answer_response = openai.ChatCompletion.create(
                model=self.model_engine,
                messages=messages,
                max_tokens=10,
                temperature=0,
            )
            answer = answer_response.choices[0].message['content'].strip()
            return rationale, answer
        except Exception as e:
            st.error(f"β Error generating rationale and answer: {e}")
            return "", ""

    def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
        """Generate a rationale and an answer for ``problem``.

        Returns:
            ``(rationale, answer)``; ``("", "")`` when the API call fails.
        """
        return self._rationale_then_answer(self.construct_prompt(problem))

    def rationalize(self, problem: str, correct_answer: str) -> Tuple[str, str]:
        """Generate a rationale with the correct answer supplied as a hint.

        This is the retry step ``run_iteration`` uses after a wrong answer: the
        prompt is the normal one plus a message carrying ``correct_answer``.

        Returns:
            ``(rationale, answer)``; ``("", "")`` when the API call fails.
        """
        messages = self.construct_prompt(problem, include_answer=True, answer=correct_answer)
        return self._rationale_then_answer(messages)

    def fine_tune_model(self):
        """Simulate fine-tuning.

        No training request is made: the method sleeps for one second, stores
        a synthetic model name built from the engine and the session id, and
        shows a success message.
        """
        time.sleep(1)
        self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
        st.success(f"✅ Model fine-tuned: {self.fine_tuned_model}")

    @staticmethod
    def _answers_match(answer, correct_answer) -> bool:
        """Compare two answers ignoring case and surrounding whitespace.

        Values are converted with ``str`` first so non-string cells (for
        example numbers in a dataset) do not raise.
        """
        return str(answer).strip().lower() == str(correct_answer).strip().lower()

    @staticmethod
    def _append_rows(frame: pd.DataFrame, rows: List[dict]) -> pd.DataFrame:
        """Return ``frame`` with ``rows`` added at the end, re-indexed from 0.

        ``DataFrame.append`` no longer exists in pandas 2.x, so the rows are
        turned into a frame and concatenated instead.
        """
        if not rows:
            return frame
        addition = pd.DataFrame(rows, columns=frame.columns)
        if frame.empty:
            # Nothing to keep from the empty frame; skipping the concat also
            # avoids concatenating with an empty object-typed frame.
            return addition
        return pd.concat([frame, addition], ignore_index=True)

    def run_iteration(self, dataset: pd.DataFrame):
        """Run one STaR iteration over ``dataset``.

        For every row the model's rationale and answer are recorded in
        ``generated_data``. Wrong answers are retried through ``rationalize``
        and, when the retry is correct, recorded in ``rationalized_data``.
        The simulated fine-tuning step runs afterwards and ``iterations`` is
        incremented.

        Args:
            dataset: Frame with ``Problem`` and ``Answer`` columns.
        """
        st.write(f"### Iteration {self.iterations + 1}")
        progress_bar = st.progress(0)
        total = len(dataset)

        generated_rows = []
        rationalized_rows = []
        # Progress is counted by position: index labels need not be 0..n-1.
        for position, (_, row) in enumerate(dataset.iterrows(), start=1):
            problem = row['Problem']
            correct_answer = row['Answer']

            rationale, answer = self.generate_rationale_and_answer(problem)
            is_correct = self._answers_match(answer, correct_answer)
            generated_rows.append({
                'Problem': problem,
                'Rationale': rationale,
                'Answer': answer,
                'Is_Correct': is_correct,
            })

            if not is_correct:
                rationale, answer = self.rationalize(problem, correct_answer)
                if self._answers_match(answer, correct_answer):
                    rationalized_rows.append({
                        'Problem': problem,
                        'Rationale': rationale,
                        'Answer': answer,
                        'Is_Correct': True,
                    })
            progress_bar.progress(position / total)

        self.generated_data = self._append_rows(self.generated_data, generated_rows)
        self.rationalized_data = self._append_rows(self.rationalized_data, rationalized_rows)

        st.write("π Fine-tuning the model on correct rationales...")
        self.fine_tune_model()
        self.iterations += 1
|
|
|
|
|
# Predefined problem/answer pairs offered in the UI, stored as dicts with
# "Problem" and "Answer" keys.
EXAMPLE_PROBLEM_ANSWERS = [
    {"Problem": problem_text, "Answer": answer_text}
    for problem_text, answer_text in (
        (
            "What is deductive reasoning?",
            "It is a logical process that draws specific conclusions from general principles.",
        ),
        (
            "What is inductive reasoning?",
            "It is reasoning that forms general principles from specific examples.",
        ),
        (
            "Explain abductive reasoning.",
            "It involves finding the best explanation for incomplete observations.",
        ),
        ("What is the capital of France?", "Paris."),
        ("Who wrote Hamlet?", "William Shakespeare."),
    )
]
|
|
|
|
|
# Problems offered in the "Solve Problem" dropdown.
TEST_PROBLEM_SET = [
    "What is the Pythagorean theorem?",
    "Who developed the theory of relativity?",
    "What is the main ingredient in bread?",
    "Who is the author of 1984?",
    "What is the boiling point of water?",
]
|
|
|
|
|
def format_examples_for_text_area(examples):
    """Render examples as ``Problem | Answer`` lines joined by newlines.

    Each item of ``examples`` must provide ``'Problem'`` and ``'Answer'``
    keys. An empty input yields an empty string.
    """
    rendered = []
    for entry in examples:
        rendered.append("{} | {}".format(entry['Problem'], entry['Answer']))
    return '\n'.join(rendered)
|
|
|
|
|
def main():
    """Render the STaR demo page.

    The page has three steps: add few-shot prompt examples to the
    ``SelfTaughtReasoner`` kept in ``st.session_state.star``, load a
    ``Problem | Answer`` dataset into ``st.session_state.dataset``, and ask
    the reasoner for a rationale and answer to a selected test problem.
    """
    st.title("π€ Self-Taught Reasoners (STaR)")
    st.markdown("\n".join([
        "",
        "# π Papers:",
        "1. π€«π Quiet-STaR: Language Models Can Teach Themselves to Think π€ Before Speaking π£οΈ",
        "- π https://arxiv.org/abs/2403.09629 - π https://arxiv.org/pdf/2403.09629",
        "2. ππ€ STaR: Self-Taught Reasoner - Bootstrapping Reasoning With Reasoning",
        "- π https://arxiv.org/abs/2203.14465 - π https://arxiv.org/pdf/2203.14465",
        "",
    ]))

    # The reasoner lives in the session state so that examples added in one
    # run of the script are still there on the next one.
    if 'star' not in st.session_state:
        st.session_state.star = SelfTaughtReasoner()
    star = st.session_state.star

    # --- Step 1: few-shot examples -------------------------------------
    st.header("Step 1: Add Few-Shot Prompt Examples")
    st.write("Choose an example from the dropdown or input your own.")

    # The selectbox works on indices and only formats them for display, so the
    # chosen example never has to be parsed back out of its label.
    example_idx = st.selectbox(
        "Select a predefined example",
        range(len(EXAMPLE_PROBLEM_ANSWERS)),
        format_func=lambda i: f"Example {i + 1}: {EXAMPLE_PROBLEM_ANSWERS[i]['Problem']}",
    )
    chosen_example = EXAMPLE_PROBLEM_ANSWERS[example_idx]

    # 68 px is the smallest text-area height Streamlit accepts.
    st.text_area("Problem", value=chosen_example['Problem'], height=68, key="example_problem")
    st.text_input("Answer", value=chosen_example['Answer'], key="example_answer")

    if st.button("Add Example"):
        star.add_prompt_example(
            st.session_state.example_problem,
            "Rationale placeholder",
            st.session_state.example_answer,
        )
        st.success("Example added successfully!")

    # --- Step 2: dataset ------------------------------------------------
    st.header("Step 2: Input Dataset")

    prefilled_data = format_examples_for_text_area(EXAMPLE_PROBLEM_ANSWERS)
    dataset_problems = st.text_area(
        "Enter problems and answers in the format 'Problem | Answer', one per line.",
        value=prefilled_data,
        height=200,
    )

    if st.button("Submit Dataset"):
        rows = []
        for line in dataset_problems.strip().split('\n'):
            # Only the first '|' separates problem from answer; lines without
            # one are ignored.
            problem, separator, answer = line.partition('|')
            if separator:
                rows.append({'Problem': problem.strip(), 'Answer': answer.strip()})

        if rows:
            st.session_state.dataset = pd.DataFrame(rows, columns=['Problem', 'Answer'])
            st.success("Dataset loaded.")
        else:
            # Keep any previously loaded dataset rather than replacing it
            # with an empty frame.
            st.warning("No 'Problem | Answer' lines found; dataset not loaded.")

    if 'dataset' in st.session_state:
        st.subheader("Current Dataset:")
        st.dataframe(st.session_state.dataset.head())

    # NOTE(review): nothing in this function calls `star.run_iteration`, so the
    # loaded dataset is displayed but not processed here — confirm whether a
    # trigger is missing.

    # --- Step 3: try the model -------------------------------------------
    st.header("Step 3: Test the Fine-Tuned Model")

    test_problem = st.selectbox(
        "Select a problem to test the fine-tuned model",
        TEST_PROBLEM_SET,
    )

    if st.button("Solve Problem"):
        if not test_problem:
            st.warning("Please enter or select a problem to solve.")
        else:
            rationale, answer = star.generate_rationale_and_answer(test_problem)
            st.subheader("Rationale:")
            st.write(rationale)
            st.subheader("Answer:")
            st.write(answer)

    st.write("---")
    st.write("Developed as a demonstration of the STaR method.")
|
|
|
# Script entry point: build the page when the file is executed directly.
if __name__ == "__main__":
    main()
|
|