# trip_planner/app.py
# Author: Abdulla Fahem ("Add application file", commit a86a6db)
import os
import sys
import torch
import pandas as pd
import streamlit as st
from datetime import datetime
from transformers import (
T5ForConditionalGeneration,
T5Tokenizer,
Trainer,
TrainingArguments,
DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
import random
# Ensure reproducibility: fixed seeds for torch (model init / sampling)
# and the stdlib random module.
torch.manual_seed(42)
random.seed(42)

# Environment setup: tolerate duplicate OpenMP runtimes — a common
# torch/MKL conflict on some platforms that otherwise aborts the process.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
class TravelDataset(Dataset):
    """Seq2seq dataset pairing trip-planning queries with reference answers.

    Expects ``data`` to be a pandas DataFrame with ``query`` and
    ``reference_information`` columns. Each item is tokenized to a fixed
    length. Padding positions in the labels are replaced with -100 so
    the cross-entropy loss ignores them.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Input: query; target: reference_information.
        input_text = row['query']
        target_text = row['reference_information']

        # Tokenize inputs to a fixed max_length (pad + truncate).
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Tokenize targets the same way.
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = target_encodings['input_ids'].squeeze()
        # BUG FIX: mask label padding with -100 so the loss ignores pad
        # positions. The previous code left pad token ids in the labels,
        # training the model to emit <pad> tokens.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        return {
            'input_ids': input_encodings['input_ids'].squeeze(),
            'attention_mask': input_encodings['attention_mask'].squeeze(),
            'labels': labels
        }
def load_dataset():
    """Load the TravelPlanner training split from the Hugging Face hub.

    Returns:
        pandas.DataFrame with at least the 'query' and
        'reference_information' columns.

    Raises:
        RuntimeError: if the download fails or a required column is
        missing.

    BUG FIX: the previous version called ``sys.exit(1)`` on failure.
    ``SystemExit`` is not an ``Exception`` subclass, so it bypassed
    ``train_model``'s ``except Exception`` handler and killed the whole
    Streamlit app. Raising RuntimeError lets callers degrade gracefully.
    """
    try:
        data = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")
    except Exception as e:
        raise RuntimeError(f"Error loading dataset: {e}") from e
    # Validate the schema expected by TravelDataset.
    for col in ('query', 'reference_information'):
        if col not in data.columns:
            raise RuntimeError(f"Missing required column: {col}")
    print(f"Dataset loaded successfully with {len(data)} rows.")
    return data
def train_model():
    """Fine-tune t5-base on the TravelPlanner data and save the result.

    Loads the CSV dataset, splits it 80/20 into train/validation,
    trains with the Hugging Face ``Trainer``, and writes model +
    tokenizer to ``./trained_travel_planner``.

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(None, None)`` on
        any failure (the error is printed, not re-raised).
    """
    try:
        # Load dataset
        data = load_dataset()

        # Initialize model and tokenizer
        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        # Split data 80/20 by row order — no shuffling is performed here.
        train_size = int(0.8 * len(data))
        train_data = data[:train_size]
        val_data = data[train_size:]

        train_dataset = TravelDataset(train_data, tokenizer)
        val_dataset = TravelDataset(val_data, tokenizer)

        # NOTE(review): recent transformers versions renamed
        # `evaluation_strategy` to `eval_strategy` — confirm against the
        # pinned transformers version.
        training_args = TrainingArguments(
            output_dir="./trained_travel_planner",
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            evaluation_strategy="steps",
            eval_steps=50,
            save_steps=100,  # a multiple of eval_steps, as load_best_model_at_end expects
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
        )

        # Per-batch dynamic padding; the dataset already pads to
        # max_length, so this is effectively a no-op but harmless.
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator
        )

        print("Training model...")
        trainer.train()

        # Persist weights and tokenizer files side by side so they can be
        # reloaded together later.
        model.save_pretrained("./trained_travel_planner")
        tokenizer.save_pretrained("./trained_travel_planner")
        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        print(f"Training error: {e}")
        return None, None
def generate_travel_plan(query, model, tokenizer):
    """Generate a travel plan for *query* using the fine-tuned model.

    Returns the decoded plan text, or an error message string if
    tokenization or generation fails.
    """
    try:
        encoded = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )

        # Move everything to the GPU when one is available.
        if torch.cuda.is_available():
            model = model.cuda()
            encoded = {name: tensor.cuda() for name, tensor in encoded.items()}

        generated = model.generate(
            **encoded,
            max_length=512,
            num_beams=4,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
        )
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as err:
        return f"Error generating travel plan: {err}"
def main():
    """Streamlit entry point: model-management sidebar plus plan generator."""
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar: retrain the model on demand.
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
            if model:
                st.session_state['model'] = model
                st.session_state['tokenizer'] = tokenizer
                st.success("Model retrained successfully!")
            else:
                st.error("Model retraining failed.")

    # First run: train/load the model once and cache it in session state.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = train_model()
        # BUG FIX: only cache a successful load. The previous code stored
        # (None, None) unconditionally, so a failed load masqueraded as a
        # loaded model on every later interaction.
        if model:
            st.session_state['model'] = model
            st.session_state['tokenizer'] = tokenizer
        else:
            st.error("Model loading failed. Use 'Retrain Model' to try again.")

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")
    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        elif 'model' not in st.session_state:
            # Guard against generating with no model (e.g. startup load failed).
            st.error("Model is not available. Train it first from the sidebar.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
            st.subheader("Your Travel Plan")
            st.write(travel_plan)
# Script entry point: launch the Streamlit app when run directly.
if __name__ == "__main__":
    main()