Spaces:
Runtime error
Runtime error
File size: 1,559 Bytes
719f5b1 528d9ae 719f5b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import gradio as gr
import pandas as pd
from tqdm import tqdm
from facility_predict import Preprocess, Facility_Model, obj_Facility_Model, processor
def predict_batch_from_csv(input_file, output_file):
    """Run facility-name inference over every row of a CSV and save the result.

    The CSV at ``input_file`` is loaded with pandas. For each row the value in
    the ``'facility_name'`` column is passed through ``processor.clean_text``,
    then ``processor.process_tokenizer``, and finally
    ``obj_Facility_Model.inference``. The collected results are added as a
    ``'prediction'`` column next to the original columns and the combined
    table is written to ``output_file`` without the index column.

    Args:
        input_file: CSV source accepted by ``pandas.read_csv``.
        output_file: Destination accepted by ``DataFrame.to_csv``.
    """
    frame = pd.read_csv(input_file)

    def _infer(record):
        # NOTE: 'facility_name' must match the column that holds the text;
        # change it here if the input CSV uses a different header.
        raw_name = record['facility_name']
        model_input = processor.process_tokenizer(processor.clean_text(raw_name))
        return obj_Facility_Model.inference(model_input)

    # tqdm wraps the row iterator purely to display progress.
    progress = tqdm(frame.iterrows(), total=len(frame))
    results = [_infer(record) for _, record in progress]

    # Place the predictions side by side with the original columns.
    combined = pd.concat([frame, pd.DataFrame({'prediction': results})], axis=1)
    combined.to_csv(output_file, index=False)
def predict_batch(input_csv, output_csv):
    """Gradio callback: predict for an uploaded CSV and report where it was saved.

    Args:
        input_csv: The uploaded file as delivered by the Gradio ``"file"``
            input. Depending on the Gradio version this is either a path
            string or a temp-file wrapper whose ``name`` attribute holds the
            path of the uploaded file.
        output_csv: Path (text entered by the user) where the results CSV
            is written.

    Returns:
        A confirmation message containing ``output_csv``.
    """
    # Read from the on-disk path rather than the wrapper object itself: the
    # wrapper's own handle is not guaranteed to contain the uploaded bytes.
    # Plain path strings have no ``name`` attribute and pass through as-is.
    input_path = getattr(input_csv, "name", input_csv)
    predict_batch_from_csv(input_path, output_csv)
    return "Prediction completed. Results saved to " + output_csv
# Gradio UI: one file upload (the input CSV) plus one text box (the output
# path), wired to predict_batch; the returned status message is shown as text.
iface = gr.Interface(
    title="Batch Facility Name Prediction",
    description="Upload a CSV file with facility names and get the predictions in a CSV file",
    fn=predict_batch,
    inputs=["file", "text"],
    outputs="text",
    # Example inputs are currently disabled:
    # examples=[["input.csv", "output.csv"]],
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    iface.launch()