File size: 3,028 Bytes
8f29b1d
3f6f474
cf575f8
 
 
 
ec1d54e
cf575f8
d654474
 
cf575f8
 
 
 
e32f803
cf575f8
 
 
 
 
 
 
613e689
 
 
dd2409d
613e689
 
 
8f29b1d
e691ea0
8f29b1d
 
 
 
 
 
 
 
 
 
ec1d54e
 
 
 
8f29b1d
 
 
 
 
 
 
 
 
 
 
 
 
 
e691ea0
 
 
 
 
 
 
 
 
8f29b1d
d654474
cf575f8
0e97d35
d654474
 
 
 
 
 
3f6f474
d654474
 
cf575f8
dd2409d
 
 
 
 
3f6f474
8920952
cf575f8
3f6f474
613e689
3f6f474
0e97d35
613e689
 
 
cf575f8
 
 
dd2409d
b49f004
a4780a0
 
ec1d54e
cf575f8
ec1d54e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import traceback
from io import StringIO
from typing import Optional, Tuple

import gradio as gr
import pandas as pd
from loguru import logger

from utils import pipeline
from utils.models import list_models


def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular input file into a DataFrame.

    The reader is chosen from the file extension, compared
    case-insensitively so that e.g. ``DATA.CSV`` is accepted:
    ``.xlsx`` files go through ``pd.read_excel`` and ``.csv`` files
    through ``pd.read_csv``.

    Args:
        filepath: Path to the file to read.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``.
    """
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')


def _load_input(text: str, file=None) -> pd.DataFrame:
    """Build the input DataFrame from an uploaded file, else from CSV text.

    Raises:
        ValueError: If neither source is provided or the result has no rows.
    """
    if file:
        # Accept either an object exposing ``.name`` (the on-disk path) or a
        # plain path string, since the upload value's form varies by Gradio
        # version / File ``type`` setting.
        df = read_data(getattr(file, 'name', file))
    elif text:
        df = pd.read_csv(StringIO(text))
    else:
        raise ValueError('No input data')
    # Explicit check (not ``assert``) so it survives ``python -O``; applies
    # to uploaded files as well as pasted text.
    if len(df) < 1:
        raise ValueError('No input data')
    return df


def _run_task(task_name: str, df: pd.DataFrame, model_name: str,
              pooling: str) -> pd.DataFrame:
    """Dispatch ``df`` to the pipeline step implementing ``task_name``.

    Raises:
        ValueError: If ``task_name`` is not a supported task.
    """
    if task_name == 'Originality':
        return pipeline.p0_originality(df, model_name, pooling)
    if task_name == 'Flexibility':
        return pipeline.p1_flexibility(df, model_name, pooling)
    raise ValueError('Task not supported')


def process(
        task_name: str,
        model_name: str,
        pooling: str,
        text: str,
        file=None,
) -> Tuple[Optional[dict], Optional[pd.DataFrame], Optional[str]]:
    """Run one assessment task and return values for the Gradio outputs.

    Input comes from the uploaded ``file`` when present, otherwise from
    ``text`` parsed as CSV. The selected task is applied through
    ``utils.pipeline``. The full result is written to ``output.csv`` and the
    first 10 rows are returned for display.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``.
        model_name: Model identifier forwarded to the pipeline.
        pooling: Pooling mode forwarded to the pipeline.
        text: CSV text used when no file is uploaded.
        file: Optional upload; either an object with ``.name`` or a path.

    Returns:
        ``(None, first_10_rows, 'output.csv')`` on success, or
        ``(error_info, None, None)`` on failure, where ``error_info`` is a
        dict with ``'Info'`` and ``'Error'`` (formatted traceback) keys.
    """
    try:
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')
        df = _load_input(text, file)

        # check
        if len(df) > 10000:
            raise ValueError('Data exceeds 10,000 rows')

        df = _run_task(task_name, df, model_name, pooling)

        # save
        path = 'output.csv'
        df.to_csv(path, index=False, encoding='utf-8-sig')
        return None, df.iloc[:10], path

    except Exception:
        # UI boundary: report any failure to the user instead of crashing,
        # but let KeyboardInterrupt/SystemExit propagate (no bare except).
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        return {'Info': 'Something wrong', 'Error': error}, None, None


def _read_text(path: str) -> str:
    """Return the UTF-8 text content of ``path``, closing the file afterwards."""
    with open(path, 'r', encoding='utf-8') as fh:
        return fh.read()


# input
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility']
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value='mean',
    choices=['mean', 'cls']
)
# Pre-fill the text box with the bundled example CSV.
text_input = gr.components.Textbox(
    value=_read_text('data/example_xlm.csv'),
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])

# output
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])

# Inputs/outputs are positional and must match process()'s signature and
# its 3-tuple return value.
app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=_read_text('data/description.txt'),
    title='TransDis-CreativityAutoAssessment',
    concurrency_limit=1,
)
app.launch(max_threads=1)