# swahili_llm/app.py
from datasets import load_dataset
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
    pipeline,
)
import gradio as gr
import logging
# Enable detailed logging
logging.basicConfig(level=logging.INFO)
# Load dataset
dataset = load_dataset("mwitiderrick/swahili")
# Print dataset columns for verification
print(f"Dataset columns: {dataset['train'].column_names}")
# Select a subset of the dataset (e.g., the first 50,000 rows)
subset_size = 50000  # Adjust the size as needed
subset_dataset = dataset["train"].select(range(min(subset_size, len(dataset["train"]))))
print(f"Using a subset of {len(subset_dataset)} rows for training.")
# Initialize the tokenizer and model
model_name = "gpt2" # Use GPT-2 for text generation
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# GPT-2 ships without a padding token, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token  # this also sets pad_token_id
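# Quick sanity check (illustrative, not in the original script): tokenize a
# short Swahili sentence to confirm the tokenizer is wired up. The sample
# sentence is an arbitrary assumption.
sample = tokenizer("Habari ya leo?", truncation=True, max_length=512)
print(f"Sample token ids: {sample['input_ids']}")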
# Preprocess the dataset
def preprocess_function(examples):
    # Tokenize, truncating/padding every example to a fixed length
    encodings = tokenizer(
        examples['text'],      # Raw Swahili text column
        truncation=True,
        padding='max_length',  # Pad to max_length for uniform tensors
        max_length=512
    )
    # Labels for causal LM fine-tuning are the input ids themselves; mask
    # padding positions with -100 so the loss ignores them (pad == eos here,
    # and GPT2Tokenizer does not insert eos into the text, so only true
    # padding gets masked)
    encodings['labels'] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
        for ids in encodings['input_ids']
    ]
    return encodings
# Tokenize the dataset
try:
    tokenized_datasets = subset_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=subset_dataset.column_names,  # drop the raw text columns
    )
except Exception as e:
    print(f"Error during tokenization: {e}")
    raise  # later steps depend on tokenized_datasets, so fail fast
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=4,
    num_train_epochs=1,
    logging_dir='./logs',
    logging_steps=500,         # Log every 500 steps
    evaluation_strategy="no",  # No eval_dataset is passed to the Trainer
    save_steps=10_000,         # Save a checkpoint every 10,000 steps
    save_total_limit=2,        # Keep only the last 2 checkpoints
)
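# Optional note (an assumption, not part of the original setup): on a
# GPU-backed Space, the standard TrainingArguments flags fp16=True and
# gradient_accumulation_steps can cut memory use and wall-clock time; the
# values would need tuning for this dataset.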
# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
)
# Start training
try:
    trainer.train()
except Exception as e:
    print(f"Error during training: {e}")
# Build a text-generation pipeline around the fine-tuned model
nlp = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Define the Gradio interface function
def generate_text(prompt):
    try:
        return nlp(prompt, max_length=50)[0]['generated_text']
    except Exception as e:
        return f"Error during text generation: {e}"
# Create and launch the Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Swahili Language Model",
    description="Generate Swahili text with a GPT-2 model fine-tuned on a Swahili corpus.",
)
iface.launch()