#!/usr/bin/env python
# coding: utf-8
"""Gradio demo that predicts a movie's genre from its plot description.

Loads the ``langfab/distilbert-base-uncased-finetuned-movie-genre``
sequence-classification checkpoint with ``transformers`` and serves a
text-in / text-out Gradio interface around it.
"""

import gradio as gr
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

model_ckpt = "langfab/distilbert-base-uncased-finetuned-movie-genre"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
config = AutoConfig.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config)
# from_pretrained already returns the model in evaluation mode; calling eval()
# explicitly documents that this script only ever runs inference.
model.eval()

# Class-index -> genre-name mapping taken from the checkpoint's config.
id2label = model.config.id2label


def predict(plot):
    """Return the most likely genre label for a plot description.

    Args:
        plot: Plot text from the Gradio textbox (a single string).

    Returns:
        The genre name from ``id2label`` whose softmax probability is highest.
    """
    encoding = tokenizer(plot, padding=True, truncation=True, return_tensors="pt")
    # Move every input tensor onto whatever device the model's weights live on.
    encoding = {k: v.to(model.device) for k, v in encoding.items()}

    # Inference only: disable autograd so no computation graph is kept.
    with torch.no_grad():
        logits = model(**encoding).logits

    # A single input string yields a batch of one, so take row 0 explicitly
    # instead of squeeze(), which would also collapse any other size-1 axis.
    probabilities = torch.nn.functional.softmax(logits[0].cpu(), dim=-1)
    return id2label[int(probabilities.argmax())]


iface = gr.Interface(
    title="Movie Plot Genre Predictor",
    fn=predict,
    inputs="text",
    outputs="text",
)
# share=True asks Gradio for a temporary public link in addition to the local server.
iface.launch(share=True)