File size: 2,124 Bytes
e9a8d2e
 
d70f01c
 
 
 
e9a8d2e
 
 
 
 
 
44602c4
 
d70f01c
 
e9a8d2e
 
 
 
 
d70f01c
 
 
 
 
 
 
 
 
 
e9a8d2e
 
 
 
6807f38
 
 
e9a8d2e
 
 
 
 
 
6807f38
44602c4
6807f38
e9a8d2e
 
 
 
d70f01c
 
44602c4
 
e9a8d2e
 
 
 
6807f38
 
 
d70f01c
6807f38
 
 
 
 
 
 
 
 
d70f01c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import uuid
from datetime import datetime, timezone

import boto3
import gradio as gr
import pandas as pd
import torch
from transformers import CLIPModel, CLIPProcessor

# Hugging Face Hub id of the fine-tuned CLIP checkpoint.
checkpoint = "vincentclaes/emoji-predictor"
# Default candidate tags: the first column of adjectives.txt (file has no header row).
adjectives = pd.read_table("./adjectives.txt", header=None).iloc[:, 0].tolist()
K = 10  # maximum number of tags reported per emoji
THRESHOLD = 0.05  # minimum probability for a tag to be reported
APP_NAME = "emoji-tagging"  # used in the S3 object key when logging inferences
BUCKET = "drift-pilot-ml-model"  # S3 bucket that inference logs are written to

# Load the preprocessor and the model weights once, at import time.
processor, model = (
    CLIPProcessor.from_pretrained(checkpoint),
    CLIPModel.from_pretrained(checkpoint),
)


def log_inference(body=b"", key=None):
    """Upload an inference record to S3 when the ``CLIENT`` env var is set.

    Args:
        body: Payload stored as the S3 object body (bytes or str).
        key: Object key to write. Defaults to a unique key under the
            ``APP_NAME`` prefix (UTC timestamp plus a random UUID), so that
            successive calls do not overwrite each other.

    Returns:
        The key that was written, or ``None`` when logging is disabled
        (``CLIENT`` unset or empty).
    """
    # .get(): a missing CLIENT variable means "logging disabled", not a crash.
    if not os.environ.get("CLIENT"):
        return None
    if key is None:
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ")
        # Plain f-string interpolation; the original "${...}" wrote a literal '$'.
        key = f"{APP_NAME}/{stamp}-{uuid.uuid4().hex}"
    boto3.client("s3").put_object(
        Body=body,
        Bucket=BUCKET,
        Key=key,
    )
    return key


def get_tag(emoji, tags="", expected="", model=model, processor=processor, K=K):
    """Suggest tags for an emoji image using the CLIP model.

    Args:
        emoji: Image handed to ``processor`` (the Gradio input is configured
            with ``type="pil"``, so presumably a PIL image).
        tags: Optional comma-separated candidate tags. Whitespace around each
            tag is stripped and empty entries are dropped; if nothing remains,
            the module-level ``adjectives`` list is used instead.
        expected: Unused; kept for interface compatibility.
        model: CLIP model that scores image/text pairs.
        processor: CLIP processor that tokenizes the tags and preprocesses
            the image.
        K: Maximum number of tags to report.

    Returns:
        ``"Tag (confidence): "`` followed by up to ``K`` entries of the form
        ``"tag (p)"``, highest probability first, keeping only those with
        probability >= ``THRESHOLD``.
    """
    candidates = [t.strip() for t in tags.split(",")] if tags else []
    candidates = [t for t in candidates if t] or adjectives

    inputs = processor(
        text=candidates, images=emoji, return_tensors="pt", padding=True, truncation=True
    )
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_text is (num_texts, num_images); softmax over dim=0 turns each
    # image's column into a distribution over the candidate tags. Only one
    # image is passed, so column 0 holds its probabilities.
    probs = outputs.logits_per_text.softmax(dim=0)[:, 0]
    # topk raises if k exceeds the number of candidates (e.g. < K custom tags).
    values, indices = probs.topk(min(K, len(candidates)))
    return "Tag (confidence): " + ", ".join(
        f"{candidates[int(i)]} ({round(v.item(), 2)})"
        for v, i in zip(values, indices)
        if v >= THRESHOLD
    )


title = "Tagging an Emoji"
# Built from K and THRESHOLD so the text cannot drift from what get_tag does.
# get_tag keeps tags with confidence >= THRESHOLD, hence "at least".
description = f"""You provide an Emoji and our few-shot fine-tuned CLIP model will suggest some tags that are appropriate.\n

- We use the [228 most common adjectives in English](https://grammar.yourdictionary.com/parts-of-speech/adjectives/list-of-adjective-words.html).\n
- We show at most {K} tags and only when the confidence is at least {THRESHOLD:.0%} ({THRESHOLD})

"""

# One single-input example per bundled image: emojis/0.png ... emojis/31.png.
examples = [[f"emojis/{i}.png"] for i in range(32)]

# NOTE(review): this Textbox is never passed to the Interface, and its
# placeholder reads like it belongs to a text->emoji app. It is kept only so
# the module-level name `text` still exists. gr.Textbox replaces the legacy
# gr.inputs namespace, consistent with the other components in this file.
text = gr.Textbox(
    placeholder="Enter a text and we will try to predict an emoji..."
)
# Gradio UI: one image input mapped positionally to get_tag's first parameter
# (`emoji`); the remaining get_tag parameters keep their defaults.
app = gr.Interface(
    fn=get_tag,
    inputs=[
        gr.components.Image(type="pil", label="emoji"),
    ],
    outputs=gr.Textbox(),
    examples=examples,
    # Show every example on a single page; derived from the list so the two
    # stay in sync if examples are added or removed.
    examples_per_page=len(examples),
    title=title,
    description=description,
)

if __name__ == "__main__":
    app.launch()