|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint from which the tokenizer, config and model are loaded.
model_ckpt = "langfab/distilbert-base-uncased-finetuned-movie-genre"

from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# All three artifacts come from the same checkpoint so the vocabulary,
# label set and weights stay consistent with one another.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
config = AutoConfig.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    config=config,
)

# Class-index -> label-name mapping read from the loaded model's config;
# predict() uses it to turn the winning class index into a label string.
id2label = model.config.id2label
|
|
|
def predict(plot):
    """Return the predicted genre label for a movie plot.

    Args:
        plot: Plot text as entered in the Gradio text box.

    Returns:
        The ``id2label`` entry for the class with the highest logit.
    """
    encoding = tokenizer(plot, padding=True, truncation=True, return_tensors="pt")
    # Put the inputs on the same device as the model's weights so the
    # forward pass cannot fail on a device mismatch.
    encoding = {name: tensor.to(model.device) for name, tensor in encoding.items()}

    # Inference only: disable autograd so no computation graph is built
    # (and kept in memory) for every request.
    with torch.no_grad():
        logits = model(**encoding).logits

    # A single text is encoded, so row 0 holds its class scores. Softmax is
    # monotonic, so the argmax over raw logits picks the same class.
    predicted_id = int(logits[0].argmax(dim=-1))
    return id2label[predicted_id]
|
|
|
# Web UI: one text box in (the plot), the predicted genre label out.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Movie Plot Genre Predictor",
)

# share=True is Gradio's option to also expose a public share link.
iface.launch(share=True)
|
|
|
|