# tuned-lens / app.py
# Last commit 4c39b84 by Lev McKinney: "Attempting to render plot" (584 bytes at that revision)
import functools
from platform import python_version

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

from tuned_lens.nn import TunedLens
from tuned_lens.plotting import plot_lens
# Location of the tuned-lens checkpoint intended for ``TunedLens.load``.
# Still a placeholder value: set it to a real path before loading a lens.
LENS_PATH = "<PATH TO LENS>"
@functools.lru_cache(maxsize=None)
def _load_model_and_tokenizer(model_name="gpt2"):
    """Load a causal LM and its tokenizer once, then reuse them across requests.

    ``from_pretrained`` reads (and on first use downloads) the weights, which is
    far too slow to repeat on every request, so the pair is memoized per
    ``model_name``.

    Args:
        model_name: Hugging Face model identifier or local path.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def plot_lens_outputs(text):
    """Gradio callback: plot the lens output of GPT-2 for ``text``.

    Args:
        text: Prompt entered in the interface's textbox.

    Returns:
        The value produced by ``plot_lens``, handed unchanged to the
        interface's Plot output. NOTE(review): assumes this is a figure object
        gradio's Plot component can render (e.g. a plotly or matplotlib
        figure) — confirm against the installed tuned_lens version.
    """
    # Cached after the first call, so only the first request pays the load cost.
    model, tokenizer = _load_model_and_tokenizer("gpt2")
    # TODO: once LENS_PATH points at a trained lens, load it once via
    # TunedLens.load(LENS_PATH) and pass it to plot_lens (check its signature).
    # The output component is declared on the Interface, so return the plot
    # value itself rather than constructing another gradio component around it.
    return plot_lens(model, tokenizer, text=text)
# Single-textbox interface whose result is shown by gradio's Plot component.
# ``gr.Plot`` is the component behind the deprecated ``gr.outputs.Plot``
# wrapper; it does not need a ``type`` argument.
iface = gr.Interface(fn=plot_lens_outputs, inputs="text", outputs=gr.Plot())

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. from tests or tooling) does
    # not start a web server; running ``python app.py`` still launches it.
    iface.launch()