Spaces:
Runtime error
Runtime error
import os | |
import json | |
import gradio as gr | |
import google.generativeai as genai | |
# The API key is read from the environment and passed to the SDK as-is;
# it is not validated here.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Gemini model to call. Overridable through the GEMINI_MODEL environment
# variable so a retired or renamed model id can be swapped without editing
# the code; defaults to the model this app was originally written for.
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-pro")

# Set up the model
generation_config = {
    # NOTE: with top_k=1 only the single most likely token is kept at each
    # step, so decoding is effectively greedy and temperature has little
    # practical influence on the output.
    "temperature": 0.9,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}

# Block responses rated at medium probability or above in each of the four
# harm categories. A blocked reply carries no text for the caller to read.
safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
    },
]

# Shared model client used by every classification request.
model = genai.GenerativeModel(
    model_name=GEMINI_MODEL,
    generation_config=generation_config,
    safety_settings=safety_settings,
)
# Prompt template placed first in every request. Its single "{}" placeholder
# is filled via str.format() with the user's comma-separated category list.
# The <div> tags are part of the literal text the model receives.
task_description = (
    " You are an SMS (Short Message Service) reader who reads every message"
    " that the short message service centre receives and you need to classify"
    " each message among the following categories: {}"
    "<div>Let the output be a softmax function output giving the probability"
    " of message belonging to each category.</div>"
    "<div>The sum of the probabilities should be 1</div>"
    "<div>The output must be in JSON format</div>"
)
def _extract_json_object(text):
    """Return the span of *text* from the first "{" to the last "}" inclusive.

    The model is asked for JSON but may wrap it in prose or a Markdown code
    fence, so only the outermost brace-delimited span is kept.
    Returns None when *text* contains no such span.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return None
    return text[start : end + 1]


def _to_confidences(parsed):
    """Normalize parsed model JSON into a ``{label: float}`` mapping.

    A single level of nesting such as ``{"probabilities": {...}}`` is
    unwrapped, and numeric strings like ``"0.8"`` are coerced to float.

    Raises:
        ValueError / TypeError: if *parsed* is not a JSON object or a score
            cannot be converted to float.
    """
    if not isinstance(parsed, dict):
        raise ValueError("expected a JSON object")
    if len(parsed) == 1:
        (inner,) = parsed.values()
        if isinstance(inner, dict):
            parsed = inner
    return {str(label): float(score) for label, score in parsed.items()}


def classify_msg(categories, message):
    """Ask Gemini for a probability distribution of *message* over *categories*.

    Args:
        categories: Comma-separated category names typed by the user; inserted
            verbatim into the prompt template.
        message: The text to classify.

    Returns:
        A ``gr.Label`` whose value maps each category name (as spelled by the
        model) to its probability.

    Raises:
        gr.Error: if an input is empty, if the model returned no text (e.g. the
            reply was blocked by the safety settings), or if the reply does not
            contain a JSON object of numeric scores. Gradio shows gr.Error
            messages to the user instead of a generic failure.
    """
    # The Clear button resets the textboxes to None; don't spend an API call
    # on a prompt that would literally contain "None".
    if not (categories or "").strip():
        raise gr.Error("Please enter at least one category.")
    if not (message or "").strip():
        raise gr.Error("Please enter a message to classify.")

    prompt_parts = [
        task_description.format(categories),
        f"Message: {message}",
        "Category: ",
    ]
    response = model.generate_content(prompt_parts)

    # response.text raises ValueError when the response has no usable text
    # part, e.g. when the safety settings blocked it.
    try:
        reply = response.text
    except ValueError as err:
        raise gr.Error(f"Gemini returned no text for this message: {err}") from err

    json_text = _extract_json_object(reply)
    if json_text is None:
        raise gr.Error(f"Gemini's reply did not contain JSON: {reply[:200]!r}")
    try:
        # json.JSONDecodeError is a subclass of ValueError.
        scores = _to_confidences(json.loads(json_text))
    except (ValueError, TypeError) as err:
        raise gr.Error(
            f"Could not read category scores from Gemini's reply: {reply[:200]!r}"
        ) from err
    return gr.Label(scores)
def clear_inputs_and_outputs():
    """Reset the form: empty the categories box, the message box and the prediction.

    Returns:
        A new list of three ``None`` values, one for each output component
        wired to the Clear button (categories, message, prediction label).
    """
    cleared = [None] * 3
    return cleared
with gr.Blocks() as demo:
    # Page header. gr.Markdown accepts inline HTML; the trailing backslashes
    # are Python line continuations inside the string literal, so those lines
    # are joined into one before the text is rendered.
    gr.Markdown(
        """
<h1 align="center">Multi-language Text Classifier using Gemini Pro</h1> \
This space uses Gemini Pro in order to classify texts.<br> \
Depending on the list of categories that you specify, you can have text classifier, a SPAM detector, a sentiment classifier, ... <br><br> \
<b>For the categories, enter a list of words separated by commas</b><br><br>
"""
    )
    # Two side-by-side columns: inputs and buttons on the left, prediction on the right.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                categories = gr.Textbox(
                    label="Categories",
                    placeholder="Input the list of categories as comma separated words",
                )
            with gr.Row():
                message = gr.Textbox(label="Message", placeholder="Enter Message")
            with gr.Row():
                clr_btn = gr.Button(value="Clear", variant="secondary")
                csf_btn = gr.Button(value="Classify")
        with gr.Column():
            lbl_output = gr.Label(label="Prediction")

    # Clear resets both inputs and the prediction in one step.
    clr_btn.click(
        fn=clear_inputs_and_outputs,
        inputs=[],
        outputs=[categories, message, lbl_output],
    )
    csf_btn.click(
        fn=classify_msg,
        inputs=[categories, message],
        outputs=[lbl_output],
    )
    gr.Examples(
        examples=[
            ["Normal, Promotional, Urgent", "Will you be passing by?"],
            ["Spam, Ham", "Plus de 300 % de perte de poids pendant le régime."],
            ["Χαρούμενος, Δυστυχισμένος", "Η εξυπηρέτηση σας ήταν απαίσια"],
            ["مهم، أقل أهمية ", "خبر عاجل"],
        ],
        inputs=[categories, message],
        outputs=[lbl_output],
        fn=classify_msg,
        # Caching runs classify_msg on every example while the app starts, so
        # a failing Gemini call aborts start-up. Only cache when an API key is
        # configured; without caching, clicking an example just fills the inputs.
        cache_examples=bool(GOOGLE_API_KEY),
    )
if __name__ == "__main__":
    # Start the server only when run as a script. launch(debug=True) blocks
    # the main thread, so importing this module (e.g. from a test) must not
    # trigger it.
    # api_open=False keeps clients from calling the backend REST routes
    # directly and bypassing the queue; show_api=False hides the API page.
    demo.queue(api_open=False)
    # NOTE(review): share=True opens a public tunnel when run locally, letting
    # anyone with the link spend this API key's quota; Hugging Face Spaces
    # ignores it. Confirm this is intended.
    demo.launch(debug=True, share=True, show_api=False)