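# text-classifier / app.py
# Author: fbadine · commit 532ddba ("Update app.py") · 3.64 kB · malware scan: no virus
# (Header copied from the Hugging Face file viewer; page controls: raw, history, blame, contribute, delete.)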
import os
import json
import gradio as gr
import google.generativeai as genai
# The Gemini API key comes from the environment (None when the variable is
# unset) and is registered with the client library before the model is built.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Sampling parameters applied to every generate_content call made via `model`.
# NOTE(review): with top_k=1 only the single most likely token is eligible, so
# the 0.9 temperature likely has little effect — confirm this is intended.
generation_config = dict(
    temperature=0.9,
    top_p=1,
    top_k=1,
    max_output_tokens=2048,
)

# Same threshold for each of the four harm categories.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Shared Gemini model handle used by the classifier.
model = genai.GenerativeModel(
    model_name="gemini-pro",
    safety_settings=safety_settings,
    generation_config=generation_config,
)
# Prompt template placed before every message. The single `{}` placeholder is
# filled with the user's comma-separated category list via str.format.
task_description = (
    " You are an SMS (Short Message Service) reader who reads every message"
    " that the short message service centre receives and you need to classify"
    " each message among the following categories: {}"
    "<div>Let the output be a softmax function output giving the probability"
    " of message belonging to each category.</div>"
    "<div>The sum of the probabilities should be 1</div>"
    "<div>The output must be in JSON format</div>"
)
def _scores_from_reply(text):
    """Extract the ``{category: probability}`` mapping from the model's reply text.

    The reply may surround the JSON object with other text (for example
    Markdown code fences), so only the span from the first ``{`` to the last
    ``}`` is parsed.

    Raises:
        gr.Error: If no JSON object is present, it cannot be parsed, or one of
            its values cannot be converted to ``float``.
    """
    start, end = text.find("{"), text.rfind("}")
    # Without this check a missing brace makes the slice "" (or a stray
    # fragment) and json.loads fails with an unhelpful message.
    if start == -1 or end < start:
        raise gr.Error("The model did not return a JSON object. Please try again.")
    try:
        scores = json.loads(text[start : end + 1])
    except json.JSONDecodeError as err:
        raise gr.Error("The model returned malformed JSON. Please try again.") from err
    try:
        # gr.Label expects numeric confidences: coerce numeric strings such as
        # "0.7" and reject anything else (e.g. nested objects or null).
        return {label: float(score) for label, score in scores.items()}
    except (AttributeError, TypeError, ValueError) as err:
        raise gr.Error(
            "The model returned non-numeric probabilities. Please try again."
        ) from err


def classify_msg(categories, message):
    """Classify ``message`` into the comma-separated ``categories`` using Gemini.

    Args:
        categories: Comma-separated category names entered by the user.
        message: Text to classify.

    Returns:
        dict[str, float]: Probability per category, the value format rendered
        by the ``gr.Label`` output component.

    Raises:
        gr.Error: On empty input, when the model returns no text, or when its
            reply does not contain usable JSON. Gradio shows the message in
            the UI instead of a generic error.
    """
    # The Clear button resets the inputs to None, so guard against None as
    # well as blank strings before spending an API call.
    if not (categories and categories.strip()):
        raise gr.Error("Please enter at least one category.")
    if not (message and message.strip()):
        raise gr.Error("Please enter a message to classify.")

    prompt_parts = [
        task_description.format(categories),
        f"Message: {message}",
        "Category: ",
    ]
    response = model.generate_content(prompt_parts)
    try:
        # google-generativeai raises ValueError from `.text` when the reply
        # has no valid text part, e.g. because a safety filter blocked it.
        reply = response.text
    except ValueError as err:
        raise gr.Error(
            "The model returned no answer (it may have been blocked by the safety settings)."
        ) from err
    # A plain dict is the documented value type for a Label output, so return
    # it directly instead of constructing a new gr.Label component.
    return _scores_from_reply(reply)
def clear_inputs_and_outputs():
    """Empty the Categories box, the Message box and the Prediction label.

    The Clear button's click handler lists those three components as its
    outputs, and each receives the matching ``None`` from the returned list.
    """
    return [None] * 3
# Build the Gradio UI: a header, the input column (categories, message,
# buttons) beside the prediction label, the event wiring, and clickable
# examples. NOTE(review): the nesting below was reconstructed because the
# source lost its indentation — confirm against the deployed layout.
with gr.Blocks() as demo:
    # Intro text. Inside the non-raw triple-quoted string, each trailing
    # backslash joins its line with the next one.
    gr.Markdown(
        """
<h1 align="center">Multi-language Text Classifier using Gemini Pro</h1> \
This space uses Gemini Pro in order to classify texts.<br> \
Depending on the list of categories that you specify, you can have text classifier, a SPAM detector, a sentiment classifier, ... <br><br> \
<b>For the categories, enter a list of words separated by commas</b><br><br>
"""
    )
    with gr.Row():
        # Left column: user inputs and action buttons.
        with gr.Column():
            with gr.Row():
                categories = gr.Textbox(
                    label="Categories",
                    placeholder="Input the list of categories as comma separated words",
                )
            with gr.Row():
                message = gr.Textbox(label="Message", placeholder="Enter Message")
            with gr.Row():
                clr_btn = gr.Button(value="Clear", variant="secondary")
                csf_btn = gr.Button(value="Classify")
        # Right column: per-category probabilities returned by classify_msg.
        with gr.Column():
            lbl_output = gr.Label(label="Prediction")
    # Clear resets both inputs and the prediction (three outputs, three Nones).
    clr_btn.click(
        fn=clear_inputs_and_outputs,
        inputs=[],
        outputs=[categories, message, lbl_output],
    )
    csf_btn.click(
        fn=classify_msg,
        inputs=[categories, message],
        outputs=[lbl_output],
    )
    # Multilingual sample inputs. With cache_examples=True Gradio precomputes
    # their outputs by calling classify_msg (and therefore the Gemini API),
    # presumably at startup — see the gr.Examples docs.
    gr.Examples(
        examples=[
            ["Normal, Promotional, Urgent", "Will you be passing by?"],
            ["Spam, Ham", "Plus de 300 % de perte de poids pendant le régime."],
            ["Χαρούμενος, Δυστυχισμένος", "Η εξυπηρέτηση σας ήταν απαίσια"],
            ["مهم، أقل أهمية ", "خبر عاجل"],
        ],
        inputs=[categories, message],
        outputs=lbl_output,
        fn=classify_msg,
        cache_examples=True,
    )
# Enable the request queue; api_open=False is meant to stop direct REST access
# to queued endpoints (verify against the installed Gradio version).
demo.queue(api_open=False)
# NOTE(review): share=True asks Gradio for a public tunnel link, which is
# likely unnecessary when hosted on Hugging Face Spaces; show_api=False hides
# the "Use via API" page.
demo.launch(debug=True, share=True, show_api=False)