# AI Travel Planner — Streamlit app that fine-tunes T5 on the TravelPlanner
# dataset and generates travel plans from user queries.
import os
import random
import sys
from datetime import datetime

import pandas as pd
import streamlit as st
import torch
from torch.utils.data import Dataset
from transformers import (
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)
# Seed both RNGs so data shuffling and weight init are reproducible.
random.seed(42)
torch.manual_seed(42)

# Tolerate duplicate OpenMP runtimes (common torch/MKL clash on some hosts).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
class TravelDataset(Dataset):
    """Seq2seq dataset mapping a travel `query` to its `reference_information`.

    Each item is a dict of 1-D tensors: `input_ids`, `attention_mask`, and
    `labels`, all padded/truncated to `max_length`.
    """

    def __init__(self, data, tokenizer, max_length=512):
        """
        Args:
            data: pandas DataFrame with 'query' and 'reference_information'
                columns (positional indexing via .iloc is used).
            tokenizer: tokenizer applied to both inputs and targets.
            max_length: pad/truncate length for both sides.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Input: query
        input_text = row['query']
        # Target: reference_information
        target_text = row['reference_information']

        # Tokenize inputs
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Tokenize targets
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # FIX: mask padding positions in the labels with -100 so the
        # cross-entropy loss ignores them; otherwise the model is trained
        # to emit pad tokens for most of every target sequence.
        labels = target_encodings['input_ids'].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].squeeze(),
            'attention_mask': input_encodings['attention_mask'].squeeze(),
            'labels': labels
        }
def load_dataset():
    """Load the TravelPlanner training split from the Hugging Face Hub.

    Returns:
        A pandas DataFrame guaranteed to contain the 'query' and
        'reference_information' columns.

    Exits the process (status 1) if the download or schema check fails.
    """
    try:
        frame = pd.read_csv("hf://datasets/osunlp/TravelPlanner/train.csv")
        # Validate the schema before anything downstream relies on it.
        for required in ('query', 'reference_information'):
            if required not in frame.columns:
                raise ValueError(f"Missing required column: {required}")
        print(f"Dataset loaded successfully with {len(frame)} rows.")
        return frame
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)
def train_model():
    """Fine-tune t5-base on the TravelPlanner dataset and save the result.

    Downloads the dataset and pretrained weights, trains with an 80/20
    split, and writes model + tokenizer to ./trained_travel_planner.

    Returns:
        (model, tokenizer) on success, (None, None) on any failure.
    """
    try:
        data = load_dataset()

        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        # Positional 80/20 train/validation split.
        split_at = int(0.8 * len(data))
        train_dataset = TravelDataset(data[:split_at], tokenizer)
        val_dataset = TravelDataset(data[split_at:], tokenizer)

        training_args = TrainingArguments(
            output_dir="./trained_travel_planner",
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            evaluation_strategy="steps",
            eval_steps=50,
            # save_steps is a multiple of eval_steps, as required by
            # load_best_model_at_end.
            save_steps=100,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
        )

        collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=collator,
        )

        print("Training model...")
        trainer.train()

        model.save_pretrained("./trained_travel_planner")
        tokenizer.save_pretrained("./trained_travel_planner")
        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        print(f"Training error: {e}")
        return None, None
def generate_travel_plan(query, model, tokenizer):
    """Generate a travel plan for `query` with a trained seq2seq model.

    Args:
        query: free-text trip request.
        model: seq2seq model exposing `eval()` and `generate()`; moved to
            GPU if one is available.
        tokenizer: tokenizer matching the model.

    Returns:
        The decoded plan string, or an error-message string on failure.
    """
    try:
        inputs = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True
        )
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
            model = model.cuda()

        # FIX: inference only — disable dropout and skip building the
        # autograd graph (saves memory and makes outputs deterministic
        # w.r.t. dropout).
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512,
                num_beams=4,
                no_repeat_ngram_size=3,
                num_return_sequences=1
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating travel plan: {e}"
def main():
    """Streamlit entry point: model-management sidebar + plan-generation UI."""
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar to (re)train the model on demand.
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
            if model is not None:
                st.session_state['model'] = model
                st.session_state['tokenizer'] = tokenizer
                st.success("Model retrained successfully!")
            else:
                st.error("Model retraining failed.")

    # Train/load the model on first run.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = train_model()
        # FIX: train_model() returns (None, None) on failure; the original
        # cached that in session_state, so every later "Generate Plan"
        # silently used a broken model. Fail loudly and don't cache.
        if model is None:
            st.error("Model loading failed. Use 'Retrain Model' to try again.")
            st.stop()
        st.session_state['model'] = model
        st.session_state['tokenizer'] = tokenizer

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")
    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
            st.subheader("Your Travel Plan")
            st.write(travel_plan)

if __name__ == "__main__":
    main()