Jacaranda's picture
Upload app.py
528d9ae
raw
history blame
1.56 kB
import gradio as gr
import pandas as pd
from tqdm import tqdm
from facility_predict import Preprocess, Facility_Model, obj_Facility_Model, processor
def predict_batch_from_csv(input_file, output_file, text_column="facility_name"):
    """Run the facility-name model over every row of a CSV and save the results.

    Each value in ``text_column`` goes through ``processor.clean_text``, then
    ``processor.process_tokenizer``, then ``obj_Facility_Model.inference``.
    The per-row results are written as a ``prediction`` column next to all of
    the original input columns.

    Args:
        input_file: Path to the input CSV, or an uploaded-file object whose
            path is available as a string ``.name`` attribute (as some Gradio
            versions pass for ``"file"`` inputs), or any other object
            ``pandas.read_csv`` accepts.
        output_file: Destination for the output CSV (written without the index).
        text_column: Column holding the text to classify.
            Defaults to ``"facility_name"``.

    Raises:
        KeyError: If ``text_column`` is not a column of the input CSV.
    """
    import os

    # Some Gradio versions hand "file" inputs over as a temp-file wrapper that
    # may already be closed; its usable path is ``.name``. Real paths
    # (str / PathLike) are left alone because Path.name is only the basename.
    upload_path = getattr(input_file, "name", None)
    if not isinstance(input_file, (str, os.PathLike)) and isinstance(upload_path, str):
        input_file = upload_path

    batch_data = pd.read_csv(input_file)
    if text_column not in batch_data.columns:
        raise KeyError(
            f"Input CSV has no {text_column!r} column; "
            f"available columns: {list(batch_data.columns)}"
        )

    predictions = []
    # Iterate only the text column: iterrows() would build a full Series per row.
    for text in tqdm(batch_data[text_column], total=len(batch_data)):
        if pd.isna(text):
            # Blank cell: don't feed NaN (a float) into the text cleaner.
            predictions.append(None)
            continue
        # str() guards against cells pandas parsed as numbers.
        cleaned_text = processor.clean_text(str(text))
        prepared_data = processor.process_tokenizer(cleaned_text)
        predictions.append(obj_Facility_Model.inference(prepared_data))

    # Build the predictions on the input's own index so concat aligns row-for-row.
    output_data = pd.DataFrame({"prediction": predictions}, index=batch_data.index)
    pred_output_df = pd.concat([batch_data, output_data], axis=1)
    pred_output_df.to_csv(output_file, index=False)
def predict_batch(input_csv, output_csv):
    """Gradio callback: run batch prediction and report where results were saved.

    Args:
        input_csv: The uploaded CSV from the ``"file"`` input.
        output_csv: Output filename typed into the ``"text"`` input. A blank
            (or missing) value falls back to ``"predictions.csv"`` instead of
            failing inside ``DataFrame.to_csv``.

    Returns:
        A status message naming the output path.
    """
    # NOTE(review): output_csv comes straight from an untrusted text box and is
    # used as a server-side path, so any user can overwrite files the app
    # process can write (e.g. app.py itself). Consider restricting it to a
    # dedicated output directory and returning the file through a gr.File output.
    output_csv = (output_csv or "").strip() or "predictions.csv"
    predict_batch_from_csv(input_csv, output_csv)
    return "Prediction completed. Results saved to " + output_csv
# Web UI: the uploaded CSV and the typed output filename are passed to
# predict_batch in that order, and its status message is shown as plain text.
# To pre-fill sample inputs, add: examples=[["input.csv", "output.csv"]]
iface = gr.Interface(
    predict_batch,
    inputs=["file", "text"],
    outputs="text",
    title="Batch Facility Name Prediction",
    description="Upload a CSV file with facility names and get the predictions in a CSV file",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()