# Normandy_QA_2 / app.py
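"""Fine-tune GPT-2 on a Normans Wikipedia text dump and serve the result
as a simple Gradio chatbot."""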
import os

import gradio as gr
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
)
# Load the pre-trained GPT-2 model and tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 has no pad token; reuse the EOS token so generate() does not warn
tokenizer.pad_token = tokenizer.eos_token
# Directory for the fine-tuned model
output_dir = "./normans_fine-tuned"
os.makedirs(output_dir, exist_ok=True)
# Build the training dataset and causal-LM data collator;
# TextDataset reads and tokenizes normans_wikipedia.txt directly
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="normans_wikipedia.txt",
    block_size=512,  # adjust based on your requirements
)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # causal language modelling, not masked LM
)
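# Note: block_size=512 means TextDataset splits the tokenized file into
# contiguous 512-token training chunks; GPT-2 itself supports contexts of
# up to 1024 tokens, so larger blocks are possible at higher memory cost.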
# Fine-tuning configuration
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir=output_dir,  # write logs alongside the model
    logging_steps=100,  # adjust based on your requirements
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
# Training loop; Ctrl-C stops training cleanly
try:
    trainer.train()
except KeyboardInterrupt:
    print("Training interrupted by user.")
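# If training is interrupted, it can be resumed from the most recent
# checkpoint in output_dir via trainer.train(resume_from_checkpoint=True).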
# Save the fine-tuned model
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
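# Alternatively, trainer.save_model(output_dir) saves the same weights and
# is safe in multi-process training runs.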
# Load the fine-tuned model
fine_tuned_model = GPT2LMHeadModel.from_pretrained(output_dir)
# Function to generate responses from the fine-tuned model
def generate_response(user_input):
    # Tokenize and encode the user input
    user_input_ids = tokenizer.encode(user_input, return_tensors="pt")
    # Generate a reply; do_sample=True is needed for top_k/top_p/temperature
    # to take effect (they are ignored during pure beam search)
    generated_output = fine_tuned_model.generate(
        user_input_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=50,
        top_p=0.90,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode and return the generated response
    chatbot_response = tokenizer.decode(
        generated_output[0], skip_special_tokens=True
    )
    return "Chatbot: " + chatbot_response
# Create a Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    live=True,  # regenerates a response on every keystroke
)

# Launch the Gradio interface
iface.launch()
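# Quick sanity check without the UI (uncomment after training completes):
# print(generate_response("Who were the Normans?"))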