Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
# In[2]: | |
import gradio as gr | |
import torch | |
# In[3]: | |
# Hub checkpoint to serve (a DistilBERT sequence classifier, judging by its name).
model_ckpt = "langfab/distilbert-base-uncased-finetuned-movie-genre"

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# Load the checkpoint's config first, then the matching tokenizer and the
# classification model built from that same config.
config = AutoConfig.from_pretrained(model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    config=config,
)

# Class-index -> label-name mapping; predict() uses it to turn the winning
# index into the string shown in the UI.
id2label = model.config.id2label
def predict(plot):
    """Return the predicted genre label for a single movie plot.

    Args:
        plot: Free-text plot description (one string, as supplied by the
            Gradio text input).

    Returns:
        The entry of ``id2label`` for the class with the highest score.
    """
    encoding = tokenizer(plot, padding=True, truncation=True, return_tensors="pt")
    # Inputs must live on the same device as the model weights.
    encoding = {k: v.to(model.device) for k, v in encoding.items()}

    # Pure inference: disable autograd so no gradient graph is built per request.
    with torch.no_grad():
        logits = model(**encoding).logits

    # softmax is monotonic, so argmax over the raw logits selects the same class.
    # Index row 0 explicitly (one input string -> one row) instead of squeezing
    # and flattening, which would give out-of-range ids for batched input.
    predicted_id = int(logits[0].argmax(dim=-1))
    return id2label[predicted_id]
# Web UI: a single free-text input box, with the predicted genre shown as text.
iface = gr.Interface(
    title="Movie Plot Genre Predictor",
    fn=predict,
    inputs="text",
    outputs="text",
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    # share=True requests a public share link when run locally; Gradio documents
    # it as unsupported on Hugging Face Spaces, where it is ignored.
    iface.launch(share=True)