"""Fine-tune GPT-2 on school-math Q/A pairs and serve it via Streamlit.

NOTE: training runs at module import time (Streamlit re-executes the script
on every interaction); for real use, cache or gate the training step.
"""

import streamlit as st
import torch
from datasets import load_dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

# Download the dataset and extract (question, answer) string pairs.
ds = load_dataset("higgsfield/school-math-questions")
qa_pairs = [(item['prompt'], item['completion']) for item in ds['train']]


class MathDataset(torch.utils.data.Dataset):
    """Causal-LM dataset of "Q: ... A: ..." sequences for GPT-2 fine-tuning.

    Each item yields ``input_ids``, ``attention_mask`` and ``labels`` aligned
    token-for-token, as required by a decoder-only language model.
    """

    def __init__(self, qa_pairs, tokenizer, max_length=128):
        self.qa_pairs = qa_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        question, answer = self.qa_pairs[idx]
        # FIX: the original encoded the question as input_ids and the answer
        # as labels in two independent, padded sequences.  GPT-2 is a causal
        # LM: labels must be the *same* sequence as input_ids (the model
        # shifts them internally), so we train on the concatenated prompt +
        # completion instead.
        text = f"Q: {question} A: {answer.strip()}{self.tokenizer.eos_token}"
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        # FIX: mask padding positions via the attention mask, not by matching
        # pad_token_id — pad_token == eos_token here, so matching on the id
        # would also mask the genuine end-of-sequence token out of the loss.
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so batching works.
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_name)

math_dataset = MathDataset(qa_pairs, tokenizer)

# Set training arguments.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    save_steps=500,  # was 10: checkpointing every 10 steps thrashes the disk
    save_total_limit=2,
)

# Create a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=math_dataset,
)

# Fine-tune, then persist the final weights so the chat bot can load them.
# FIX: the original bot re-downloaded the *base* "gpt2" checkpoint and
# silently ignored everything the Trainer had just learned.
trainer.train()
trainer.save_model("./results/final")
tokenizer.save_pretrained("./results/final")


class MathChatBot:
    """Answer "Q: ... A:"-style prompts with a (fine-tuned) GPT-2 model.

    ``model_name`` may be a Hub id or a local checkpoint directory.
    """

    def __init__(self, model_name="gpt2"):
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.model.eval()  # inference only — disable dropout

    def get_response(self, question):
        """Return the model's answer text for a single question string."""
        input_text = f"Q: {question} A:"
        enc = self.tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            output = self.model.generate(
                enc["input_ids"],
                attention_mask=enc["attention_mask"],
                max_length=50,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        answer = self.tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the text after the final "A:" marker.
        return answer.split("A:")[-1].strip()


# Usage
if __name__ == "__main__":
    # Load the fine-tuned checkpoint saved above, not the base model.
    bot = MathChatBot("./results/final")
    user_input = st.text_area("Enter your question:")
    # Streamlit renders with an empty string before the user types anything;
    # don't run generation on an empty prompt.
    if user_input.strip():
        response = bot.get_response(user_input)
        st.write(f"Bot: {response}")