# emoji-predictor / app.py
# Commit 8b891df by vincentclaes: "first version" (1.14 kB; Hugging Face scan: no virus).
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
# Hugging Face Hub checkpoint of the fine-tuned CLIP model used for prediction.
checkpoint = "vincentclaes/emoji-predictor-few-shot"
# Candidate emoji indices; emoji ``i`` is read from ``emojis/{i}.png``.
no_of_emojis = range(20)


def _load_emoji(index):
    """Read ``emojis/{index}.png`` fully into memory and close the file."""
    # Image.open is lazy: it keeps the file handle open until the pixel data is
    # loaded. Copying inside ``with`` forces the read now and releases the
    # handle, so corrupt image data surfaces at startup, not on first request.
    with Image.open(f"emojis/{index}.png") as image:
        return image.copy()


# Candidate emoji images, in index order (position i <-> emojis/{i}.png).
emojis_as_images = [_load_emoji(i) for i in no_of_emojis]
# Number of suggestions to show; keep in sync with get_emoji's K default.
K = 4
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)
def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K=K):
    """Return file paths of the ``K`` emojis that best match ``text``.

    Args:
        text: The prompt to score against every candidate emoji image.
        model: CLIP model used to score text/image pairs.
        processor: CLIP processor that tokenizes ``text`` and preprocesses
            ``emojis`` into model inputs.
        emojis: Candidate emoji images. Returned paths are built as
            ``emojis/{i}.png`` from each winning index ``i``, so this list is
            assumed to be ordered like those files (true for the default).
        K: Maximum number of suggestions; clamped to ``len(emojis)``.

    Returns:
        A list of up to ``K`` paths, best match first.
    """
    inputs = processor(text=text, images=emojis, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_text is (text_batch, image_batch); only the first (single)
    # prompt's row is used, as in the original implementation.
    probs = outputs.logits_per_text[0].softmax(dim=-1)
    # torch.topk raises when k exceeds the number of candidates.
    k = min(K, probs.shape[-1])
    top_indices = torch.topk(probs, k).indices.tolist()
    return [f"emojis/{i}.png" for i in top_indices]
# Free-text input box. gr.Textbox replaces the deprecated gr.inputs.Textbox;
# the gr.inputs namespace was removed in Gradio 4.
text = gr.Textbox()
title = "Predicting an Emoji"
# Show the suggested emoji image files returned by get_emoji in a gallery.
# NOTE(review): enable_queue is deprecated in later Gradio releases in favor of
# Interface.queue(); confirm against the Gradio version this Space pins.
gr.Interface(fn=get_emoji, inputs=text, outputs=gr.Gallery(), title=title, enable_queue=True).launch(debug=True)