# emoji-predictor / app.py
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# CLIP checkpoint fine-tuned to match text against emoji images.
checkpoint = "vincentclaes/emoji-predictor"

# Load the 20 candidate emoji images the model chooses from.
no_of_emojis = 20
emojis_as_images = [Image.open(f"emojis/{i}.png") for i in range(no_of_emojis)]

# Number of emoji suggestions to return.
K = 4

processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)

def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K=K):
    """Return the file paths of the top-K emojis predicted for the given text."""
    inputs = processor(text=text, images=emojis, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    logits_per_text = outputs.logits_per_text
    # take the softmax to get probabilities over the emoji candidates
    probs = logits_per_text.softmax(dim=1)
    # indices of the top-K most likely emojis for the (single) input text
    top_k_indices = torch.topk(probs[0], K).indices.tolist()
    return [f"emojis/{i}.png" for i in top_k_indices]

# Gradio UI: a text box in, a gallery of the predicted emoji images out.
text = gr.Textbox()
title = "Predicting an Emoji"
examples = ["I'm so glad I finally arrived in my holiday resort!"]

gr.Interface(fn=get_emoji, inputs=text, outputs=gr.Gallery(), examples=examples, title=title).launch()
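
# A minimal sketch of calling the predictor without the web UI (assumes the
# checkpoint can be downloaded and the local emojis/ folder exists); kept as a
# comment so the Space still only launches the Gradio interface above:
#
#   suggestions = get_emoji("I'm so glad I finally arrived in my holiday resort!")
#   print(suggestions)  # e.g. ['emojis/3.png', 'emojis/17.png', ...] (illustrative output)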