File size: 1,041 Bytes
f2aeda3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env python
# coding: utf-8

# In[2]:


import gradio as gr
import torch


# In[3]:


# The tokenizer, config and classifier weights are all loaded from the same
# checkpoint identifier (a Hugging Face Hub repo id, judging by its format),
# so the vocabulary, the label mapping read later via `model.config`, and the
# classification head stay consistent with each other.
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

model_ckpt = "langfab/distilbert-base-uncased-finetuned-movie-genre"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
config = AutoConfig.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    config=config,
)


# In[4]:


# Class-index -> label-name mapping from the loaded model's config
# (presumably genre names, per the checkpoint name).
id2label = model.config.id2label


def predict(plot):
    """Return the most likely label for a single movie-plot text.

    Args:
        plot: The plot description as one string (Gradio's ``"text"`` input).

    Returns:
        The entry of ``id2label`` for the class with the highest score.
    """
    encoding = tokenizer(plot, padding=True, truncation=True, return_tensors="pt")
    # Move every input tensor onto whatever device the model lives on.
    encoding = {k: v.to(model.device) for k, v in encoding.items()}

    # Pure inference: disable autograd so no computation graph is recorded
    # for each request.
    with torch.no_grad():
        logits = model(**encoding).logits

    # A single string yields a batch of one row. Softmax is monotonic, so the
    # argmax of the raw logits selects the same class as the argmax of the
    # probabilities; computing the softmax would be wasted work.
    predicted_id = int(logits[0].argmax(dim=-1))
    return id2label[predicted_id]
# Wire `predict` into a simple text-in / text-out web UI and start serving it;
# `share=True` asks Gradio for a public share link in addition to the local one.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Movie Plot Genre Predictor",
)
iface.launch(share=True)