Spaces:
Sleeping
Sleeping
File size: 10,748 Bytes
dcdf02a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 |
import os
import time
from typing import Tuple
from uuid import uuid4

import openai
import pandas as pd
import streamlit as st
# Read the OpenAI API key from the OPENAI_API_KEY environment variable
# (the value is None when the variable is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Per-session identifier, created lazily and kept in Streamlit session state.
def get_session_id():
    """Return this session's ID, generating and storing a UUID4 string on first use."""
    state = st.session_state
    if 'session_id' in state:
        return state.session_id
    fresh_id = str(uuid4())
    state.session_id = fresh_id
    return fresh_id
# Built-in few-shot examples; each entry has 'Problem', 'Rationale' and 'Answer' keys.
EXAMPLES = [
    dict(
        Problem='What is deductive reasoning?',
        Rationale='Deductive reasoning starts from general premises to arrive at a specific conclusion.',
        Answer='It involves deriving specific conclusions from general premises.',
    ),
    dict(
        Problem='What is inductive reasoning?',
        Rationale='Inductive reasoning involves drawing generalizations based on specific observations.',
        Answer='It involves forming general rules from specific examples.',
    ),
    dict(
        Problem='Explain abductive reasoning.',
        Rationale='Abductive reasoning finds the most likely explanation for incomplete observations.',
        Answer='It involves finding the best possible explanation.',
    ),
]
# STaR Algorithm Implementation
class SelfTaughtReasoner:
    """Demo of the Self-Taught Reasoner (STaR) loop on top of the OpenAI completion API.

    For every problem in a dataset the reasoner asks the model for a rationale
    and an answer using few-shot prompting. When the answer does not match the
    expected one, it retries with the expected answer supplied as a hint
    ("rationalization"). Results are recorded in two DataFrames and each
    iteration ends with a *simulated* fine-tuning step.

    NOTE(review): the calls below use ``openai.Completion.create`` with the
    ``engine`` argument; confirm that the installed ``openai`` package version
    and the configured model still support this endpoint.
    """

    # Column layout shared by ``generated_data`` and ``rationalized_data``.
    RESULT_COLUMNS = ['Problem', 'Rationale', 'Answer', 'Is_Correct']
    # Token budgets for the rationale call and the follow-up answer call.
    RATIONALE_MAX_TOKENS = 150
    ANSWER_MAX_TOKENS = 10

    def __init__(self, model_engine="text-davinci-003"):
        """Create a reasoner seeded with the predefined few-shot examples.

        Args:
            model_engine: Engine name passed to the completion API.
        """
        self.model_engine = model_engine
        # Copy the predefined examples so add_prompt_example() never mutates
        # the module-level EXAMPLES list (or its dictionaries).
        self.prompt_examples = [dict(example) for example in EXAMPLES]
        self.iterations = 0
        self.generated_data = pd.DataFrame(columns=self.RESULT_COLUMNS)
        self.rationalized_data = pd.DataFrame(columns=self.RESULT_COLUMNS)
        self.fine_tuned_model = None  # Name of the (simulated) fine-tuned model

    def add_prompt_example(self, problem: str, rationale: str, answer: str):
        """Append one problem/rationale/answer triple to the few-shot examples."""
        self.prompt_examples.append({
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer
        })

    def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> str:
        """Build the few-shot prompt for ``problem``.

        Args:
            problem: The problem the model should reason about.
            include_answer: When True, add ``answer`` as a hint line before
                the rationale cue (used for rationalization).
            answer: The hint text; only used when ``include_answer`` is True.

        Returns:
            All few-shot examples followed by the new problem, ending with
            ``"Rationale:"`` so the model continues with a rationale.
        """
        parts = [
            f"Problem: {example['Problem']}\n"
            f"Rationale: {example['Rationale']}\n"
            f"Answer: {example['Answer']}\n\n"
            for example in self.prompt_examples
        ]
        parts.append(f"Problem: {problem}\n")
        if include_answer:
            parts.append(f"Answer (as hint): {answer}\n")
        parts.append("Rationale:")
        return "".join(parts)

    def _complete(self, prompt: str, *, max_tokens: int, temperature: float, stop) -> str:
        """Run one completion request and return the stripped text of the first choice."""
        response = openai.Completion.create(
            engine=self.model_engine,
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop
        )
        return response.choices[0].text.strip()

    def _rationale_then_answer(self, prompt: str) -> Tuple[str, str]:
        """Generate a rationale for ``prompt``, then an answer conditioned on it.

        Any exception from the API calls is reported through ``st.error`` and
        turned into a pair of empty strings so the UI keeps running.
        """
        try:
            rationale = self._complete(
                prompt,
                max_tokens=self.RATIONALE_MAX_TOKENS,
                temperature=0.7,
                stop=["\n\n", "Problem:", "Answer:"],
            )
            # Feed the rationale back so the answer is derived from it.
            answer = self._complete(
                f"{prompt} {rationale}\nAnswer:",
                max_tokens=self.ANSWER_MAX_TOKENS,
                temperature=0,
                stop=["\n", "\n\n", "Problem:"],
            )
        except Exception as e:
            st.error(f"β Error generating rationale and answer: {e}")
            return "", ""
        return rationale, answer

    def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
        """Generate a rationale and an answer for ``problem`` without any hint.

        Returns:
            ``(rationale, answer)``; ``("", "")`` when the API call fails.
        """
        return self._rationale_then_answer(self.construct_prompt(problem))

    def rationalize(self, problem: str, correct_answer) -> Tuple[str, str]:
        """Retry ``problem`` with the expected answer given to the model as a hint.

        Returns:
            ``(rationale, answer)``; ``("", "")`` when the API call fails.
        """
        hinted_prompt = self.construct_prompt(
            problem, include_answer=True, answer=str(correct_answer)
        )
        return self._rationale_then_answer(hinted_prompt)

    def fine_tune_model(self):
        """Simulate fine-tuning: wait briefly and record a synthetic model name.

        No training request is sent; the method only sleeps for one second and
        derives a model name from the engine and the session ID.
        """
        time.sleep(1)  # Simulate time taken for fine-tuning
        self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
        st.success(f"✅ Model fine-tuned: {self.fine_tuned_model}")

    @staticmethod
    def _answers_match(answer, correct_answer) -> bool:
        """Compare two answers ignoring case and surrounding whitespace.

        Both sides go through ``str`` so non-string dataset values (for
        example numbers read from a CSV) do not raise.
        """
        return str(answer).strip().lower() == str(correct_answer).strip().lower()

    @classmethod
    def _with_rows(cls, frame: pd.DataFrame, rows: list) -> pd.DataFrame:
        """Return ``frame`` extended by ``rows`` (``DataFrame.append`` no longer exists in pandas 2)."""
        if not rows:
            return frame
        addition = pd.DataFrame(rows, columns=cls.RESULT_COLUMNS)
        if frame.empty:
            return addition
        return pd.concat([frame, addition], ignore_index=True)

    def run_iteration(self, dataset: pd.DataFrame):
        """Run one STaR iteration over ``dataset``.

        Args:
            dataset: Frame with ``'Problem'`` and ``'Answer'`` columns.

        Every attempt is recorded in ``generated_data``. Problems answered
        incorrectly are retried via :meth:`rationalize`; retries that match
        the expected answer are recorded in ``rationalized_data``. The
        iteration finishes with the simulated fine-tuning step.
        """
        st.write(f"### Iteration {self.iterations + 1}")
        progress_bar = st.progress(0)
        total = len(dataset)
        generated_rows = []
        rationalized_rows = []
        # Count rows by position: the frame's index labels need not be 0..n-1.
        for position, (_, row) in enumerate(dataset.iterrows(), start=1):
            problem = row['Problem']
            correct_answer = row['Answer']
            # Generate rationale and answer
            rationale, answer = self.generate_rationale_and_answer(problem)
            is_correct = self._answers_match(answer, correct_answer)
            # Record the generated data
            generated_rows.append({
                'Problem': problem,
                'Rationale': rationale,
                'Answer': answer,
                'Is_Correct': is_correct
            })
            # If incorrect, perform rationalization
            if not is_correct:
                rationale, answer = self.rationalize(problem, correct_answer)
                if self._answers_match(answer, correct_answer):
                    rationalized_rows.append({
                        'Problem': problem,
                        'Rationale': rationale,
                        'Answer': answer,
                        'Is_Correct': True
                    })
            progress_bar.progress(position / total)
        self.generated_data = self._with_rows(self.generated_data, generated_rows)
        self.rationalized_data = self._with_rows(self.rationalized_data, rationalized_rows)
        # Fine-tune the model on correct rationales
        st.write("π Fine-tuning the model on correct rationales...")
        self.fine_tune_model()
        self.iterations += 1
# Streamlit App
def _render_example_form(star):
    """Step 1 input: pick a predefined example, edit it, and add it to the few-shot list."""
    st.header("Step 1: Add Few-Shot Prompt Examples")
    st.write("Choose an example from the dropdown or input your own.")
    # Select by index and only format the label, so nothing has to be parsed
    # back out of the display text.
    example_idx = st.selectbox(
        "Select a predefined example",
        options=list(range(len(EXAMPLES))),
        format_func=lambda i: f"Example {i + 1}: {EXAMPLES[i]['Problem']}",
    )
    chosen = EXAMPLES[example_idx]
    # Key the inputs by the selected example so each selection has its own
    # widget state and starts from that example's text.
    # Heights respect the 68 px minimum documented for st.text_area.
    problem_text = st.text_area(
        "Problem", value=chosen['Problem'], height=68,
        key=f"example_problem_{example_idx}",
    )
    rationale_text = st.text_area(
        "Rationale", value=chosen['Rationale'], height=100,
        key=f"example_rationale_{example_idx}",
    )
    answer_text = st.text_input(
        "Answer", value=chosen['Answer'],
        key=f"example_answer_{example_idx}",
    )
    if st.button("Add Example"):
        star.add_prompt_example(problem_text, rationale_text, answer_text)
        st.success("Example added successfully!")


def _render_current_examples(star):
    """Step 1 display: list the few-shot examples currently held by the reasoner."""
    if not star.prompt_examples:
        return
    st.subheader("Current Prompt Examples:")
    for idx, example in enumerate(star.prompt_examples):
        st.write(f"**Example {idx + 1}:**")
        st.write(f"Problem: {example['Problem']}")
        st.write(f"Rationale: {example['Rationale']}")
        st.write(f"Answer: {example['Answer']}")


def _render_dataset_step():
    """Step 2: load the dataset from manual entry or a CSV upload into session state."""
    st.header("Step 2: Input Dataset")
    dataset_input_method = st.radio(
        "How would you like to input the dataset?", ("Manual Entry", "Upload CSV")
    )
    if dataset_input_method == "Manual Entry":
        dataset_problems = st.text_area(
            "Enter problems and answers in the format 'Problem | Answer', one per line.",
            height=200,
        )
        if st.button("Submit Dataset"):
            rows = []
            for line in dataset_problems.strip().splitlines():
                problem, separator, answer = line.partition('|')
                if separator:  # Lines without a '|' are ignored
                    rows.append({'Problem': problem.strip(), 'Answer': answer.strip()})
            if rows:
                st.session_state.dataset = pd.DataFrame(rows, columns=['Problem', 'Answer'])
                st.success("Dataset loaded.")
            else:
                st.warning("No lines in the format 'Problem | Answer' were found.")
    else:
        uploaded_file = st.file_uploader(
            "Upload a CSV file with 'Problem' and 'Answer' columns.", type=['csv']
        )
        if uploaded_file is not None:
            uploaded = pd.read_csv(uploaded_file)
            missing = [name for name in ('Problem', 'Answer') if name not in uploaded.columns]
            if missing:
                # Reject early instead of failing later inside the STaR loop.
                st.error(f"The uploaded CSV is missing required column(s): {', '.join(missing)}")
            else:
                st.session_state.dataset = uploaded
                st.success("Dataset loaded.")
    if 'dataset' in st.session_state:
        st.subheader("Current Dataset:")
        st.dataframe(st.session_state.dataset.head())


def _render_run_step(star):
    """Step 3: run the requested number of STaR iterations and show the recorded data."""
    st.header("Step 3: Run STaR Process")
    num_iterations = st.number_input(
        "Number of Iterations to Run:", min_value=1, max_value=10, value=1
    )
    if not st.button("Run STaR"):
        return
    if 'dataset' not in st.session_state or st.session_state.dataset.empty:
        st.warning("Please load a dataset in Step 2 before running STaR.")
        return
    for _ in range(int(num_iterations)):
        star.run_iteration(st.session_state.dataset)
    st.header("Results")
    st.subheader("Generated Data")
    st.dataframe(star.generated_data)
    st.subheader("Rationalized Data")
    st.dataframe(star.rationalized_data)
    st.write("The model has been fine-tuned iteratively.")


def _render_test_step(star):
    """Step 4: solve a single user-supplied problem with the reasoner."""
    st.header("Step 4: Test the Fine-Tuned Model")
    test_problem = st.text_area("Enter a new problem to solve:", height=100)
    if not st.button("Solve Problem"):
        return
    if not test_problem:
        st.warning("Please enter a problem to solve.")
        return
    rationale, answer = star.generate_rationale_and_answer(test_problem)
    st.subheader("Rationale:")
    st.write(rationale)
    st.subheader("Answer:")
    st.write(answer)


def _render_footer():
    """Footer: separator, credit line, and a small custom HTML banner."""
    st.markdown("---")
    st.write("Developed as a demonstration of the STaR method with enhanced Streamlit capabilities.")
    st.components.v1.html("""
<div style="text-align: center; margin-top: 20px;">
<h3>π Boost Your AI Reasoning with STaR! π</h3>
</div>
""")


def main():
    """Render the STaR demo: examples, dataset input, STaR run, and model test.

    The ``SelfTaughtReasoner`` instance is kept in ``st.session_state`` so it
    is reused for as long as that session state persists.
    """
    st.title("π€ Self-Taught Reasoner (STaR) Demonstration")
    # Initialize the Self-Taught Reasoner once per session
    if 'star' not in st.session_state:
        st.session_state.star = SelfTaughtReasoner()
    star = st.session_state.star
    # Wide format layout: col1 for input, col2 for display
    col1, col2 = st.columns([1, 2])
    with col1:
        _render_example_form(star)
    with col2:
        _render_current_examples(star)
    # NOTE(review): the flattened source does not show whether Steps 2-4 were
    # nested inside a column; they are rendered full-width here - confirm.
    _render_dataset_step()
    _render_run_step(star)
    _render_test_step(star)
    _render_footer()
# Start the app only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|