File size: 3,636 Bytes
197c6e8
 
 
 
 
d2763bd
197c6e8
 
 
 
d2763bd
 
 
 
197c6e8
 
 
d2763bd
 
 
 
 
 
 
 
 
 
197c6e8
 
 
 
 
d2763bd
197c6e8
 
 
 
d2763bd
197c6e8
 
 
 
 
 
 
 
 
 
d2763bd
197c6e8
 
 
 
d2763bd
197c6e8
 
 
d2763bd
197c6e8
 
 
3a6d513
 
 
 
197c6e8
 
 
 
 
d2763bd
 
 
 
197c6e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2763bd
 
 
 
197c6e8
 
 
 
 
 
 
532ddba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import json
import gradio as gr
import google.generativeai as genai

# Gemini credentials come from the environment (None when the variable is
# unset) and are registered with the SDK before the model object is built.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Set up the model
# Sampling parameters forwarded to genai.GenerativeModel for every request.
generation_config = dict(
    temperature=0.9,
    top_p=1,
    top_k=1,
    max_output_tokens=2048,
)

# Every harm category uses the same blocking threshold, so build the list
# from the category names instead of spelling out four near-identical dicts.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# One shared Gemini Pro client; classify_msg sends every prompt through it
# with the sampling and safety settings defined above.
model = genai.GenerativeModel(
    safety_settings=safety_settings,
    generation_config=generation_config,
    model_name="gemini-pro",
)

# Prompt template for the classifier. The single "{}" placeholder is filled
# with the user's comma-separated category list via str.format.
# NOTE(review): the <div> tags look like rich-text residue from a prompt
# editor; they are sent to the model verbatim.
task_description = (
    " You are an SMS (Short Message Service) reader who reads every message"
    " that the short message service centre receives and you need to classify"
    " each message among the following categories: {}"
    "<div>Let the output be a softmax function output giving the probability"
    " of message belonging to each category.</div>"
    "<div>The sum of the probabilities should be 1</div>"
    "<div>The output must be in JSON format</div>"
)


def classify_msg(categories: str, message: str):
    """Ask Gemini to score *message* against the user-supplied categories.

    Args:
        categories: Comma-separated category names typed by the user. They
            are inserted verbatim into ``task_description``.
        message: The text to classify.

    Returns:
        A ``gr.Label`` whose value maps each category name (as returned by
        the model) to its probability as a float.

    Raises:
        gr.Error: If either input is empty or missing. Also raised if reading
            ``response.text`` fails with ValueError (the SDK does this when the
            reply carries no text, for example after a safety block), or if
            the reply does not contain a JSON object of numeric scores.
    """
    # Inputs can be None after the "Clear" button resets the textboxes;
    # refuse early instead of paying for a meaningless model call.
    if not (categories or "").strip():
        raise gr.Error("Please enter at least one category.")
    if not (message or "").strip():
        raise gr.Error("Please enter a message to classify.")

    prompt_parts = [
        task_description.format(categories),
        f"Message: {message}",
        "Category: ",
    ]

    response = model.generate_content(prompt_parts)
    try:
        reply_text = response.text
    except ValueError as err:
        raise gr.Error(
            "The model did not return a usable answer for this message."
        ) from err

    return gr.Label(_parse_confidences(reply_text))


def _parse_confidences(text: str) -> dict[str, float]:
    """Extract a ``{label: probability}`` mapping from the model's raw reply."""
    # The model may surround the JSON with prose or Markdown fences, so take
    # the outermost {...} span rather than parsing the whole reply.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise gr.Error("The model reply did not contain a JSON object.")
    try:
        parsed = json.loads(text[start : end + 1])
    except json.JSONDecodeError as err:
        raise gr.Error("The model reply was not valid JSON.") from err

    # Tolerate a single wrapper object, e.g. {"probabilities": {...}}.
    if isinstance(parsed, dict) and len(parsed) == 1:
        (inner,) = parsed.values()
        if isinstance(inner, dict):
            parsed = inner

    # gr.Label expects numeric confidences; coerce values such as "0.7".
    try:
        return {str(label): float(score) for label, score in parsed.items()}
    except (AttributeError, TypeError, ValueError) as err:
        raise gr.Error(
            "The model reply did not map categories to probabilities."
        ) from err


def clear_inputs_and_outputs():
    """Return one ``None`` per wired output so the UI fields are emptied."""
    return [None] * 3


# Build the Gradio UI. Inside a Blocks context, components are placed in the
# order they are created, so statement order below determines the layout.
with gr.Blocks() as demo:
    # Page title and usage notes. The trailing backslashes are line
    # continuations inside the string, joining those lines into one.
    gr.Markdown(
        """
        <h1 align="center">Multi-language Text Classifier using Gemini Pro</h1> \
        This space uses Gemini Pro in order to classify texts.<br> \
        Depending on the list of categories that you specify, you can have text classifier, a SPAM detector, a sentiment classifier, ... <br><br> \
        <b>For the categories, enter a list of words separated by commas</b><br><br>
        """
    )
    # Two columns: inputs and buttons on the left, prediction on the right.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Free-text category list; passed to classify_msg as its
                # first argument.
                categories = gr.Textbox(
                    label="Categories",
                    placeholder="Input the list of categories as comma separated words",
                )
            with gr.Row():
                # Text to classify; passed to classify_msg as its second argument.
                message = gr.Textbox(label="Message", placeholder="Enter Message")
            with gr.Row():
                clr_btn = gr.Button(value="Clear", variant="secondary")
                csf_btn = gr.Button(value="Classify")
        with gr.Column():
            # Shows the label -> confidence mapping produced by classify_msg.
            lbl_output = gr.Label(label="Prediction")

    # "Clear" takes no inputs; its three outputs receive the three values
    # returned by clear_inputs_and_outputs.
    clr_btn.click(
        fn=clear_inputs_and_outputs,
        inputs=[],
        outputs=[categories, message, lbl_output],
    )
    # "Classify" sends (categories, message) to the model and shows the result.
    csf_btn.click(
        fn=classify_msg,
        inputs=[categories, message],
        outputs=[lbl_output],
    )

    # Clickable examples in several languages. cache_examples=True makes
    # Gradio precompute outputs by calling classify_msg on each example, which
    # calls the Gemini API before any user interaction (presumably at startup;
    # confirm the exact timing for the installed Gradio version).
    gr.Examples(
        examples=[
            ["Normal, Promotional, Urgent", "Will you be passing by?"],
            # French: "More than 300% weight loss during the diet."
            ["Spam, Ham", "Plus de 300 % de perte de poids pendant le régime."],
            # Greek: categories "Happy, Unhappy"; "Your service was awful".
            ["Χαρούμενος, Δυστυχισμένος", "Η εξυπηρέτηση σας ήταν απαίσια"],
            # Arabic: categories "Important, Less important" (separated by the
            # Arabic comma U+060C, not ","); message "Breaking news".
            ["مهم، أقل أهمية ", "خبر عاجل"],
        ],
        inputs=[categories, message],
        outputs=lbl_output,
        fn=classify_msg,
        cache_examples=True,
    )

# Configure the request queue (api_open=False), then start the app with the
# same launch options; Blocks.queue() returns the Blocks, so the calls chain.
demo.queue(api_open=False).launch(
    debug=True,
    share=True,
    show_api=False,
)