from io import StringIO
from typing import Optional, Tuple

import gradio as gr
import pandas as pd

from utils import pipeline
from utils.models import list_models


def _read_text_file(path: str) -> str:
    """Return the whole contents of a UTF-8 text file, closing the handle."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()


def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular input file into a DataFrame.

    Args:
        filepath: Path to a ``.csv`` or ``.xlsx`` file. The extension
            check is case-insensitive.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the extension is neither ``.csv`` nor ``.xlsx``.
    """
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')


def _load_input(text: str, file) -> pd.DataFrame:
    """Build the input DataFrame from the uploaded file, else from CSV text.

    Raises:
        ValueError: If neither source provides any rows, or the uploaded
            file type is unsupported.
    """
    if file:
        # Accept both a tempfile-like wrapper exposing ``.name`` and a plain
        # path string (the two shapes an upload may be passed as).
        df = read_data(getattr(file, 'name', file))
    elif text and text.strip():
        df = pd.read_csv(StringIO(text))
    else:
        raise ValueError('No input data')
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(df) == 0:
        raise ValueError('No input data')
    return df


def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> Tuple[None, pd.DataFrame, str]:
    """Run the selected creativity-assessment task on the user's input.

    The uploaded ``file`` takes precedence; otherwise ``text`` is parsed
    as CSV.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``.
        model_name: Model identifier forwarded to the pipeline function.
        pooling: Pooling strategy forwarded to the pipeline function
            (the UI offers ``'mean'`` and ``'cls'``).
        text: CSV-formatted text, used when no file is uploaded.
        file: Optional uploaded file (object with ``.name`` or a path).

    Returns:
        ``(None, preview, path)``. The ``None`` fills the text output slot.
        ``preview`` is the first 10 result rows. ``path`` is the CSV file
        (``output.csv``) containing the full result.

    Raises:
        ValueError: On missing input, unsupported file type or unknown task.
            Errors are raised rather than returned through the text output.
    """
    df = _load_input(text, file)

    if task_name == 'Originality':
        df = pipeline.p0_originality(df, model_name, pooling)
    elif task_name == 'Flexibility':
        df = pipeline.p1_flexibility(df, model_name, pooling)
    else:
        raise ValueError('Task not supported')

    path = 'output.csv'
    # utf-8-sig writes a BOM so that Excel detects the encoding correctly.
    df.to_csv(path, index=False, encoding='utf-8-sig')
    return None, df.iloc[:10], path


# input
task_name_dropdown = gr.components.Dropdown(
    label='Task Name', value='Originality', choices=['Originality', 'Flexibility']
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name', value=list_models[0], choices=list_models
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling', value='mean', choices=['mean', 'cls']
)
text_input = gr.components.Textbox(
    value=_read_text_file('data/example_xlm.csv'),
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])

# output
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])

app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=_read_text_file('data/description.txt'),
    title='TransDis-CreativityAutoAssessment',
)
app.launch()