# Hugging Face Spaces page status at capture time: Runtime error
from functools import lru_cache
from platform import python_version

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import TunedLens
from tuned_lens.plotting import plot_lens
# Filesystem location of a trained tuned lens. This is still a placeholder
# value and must be replaced before a lens can be loaded with TunedLens.load.
LENS_PATH = "<PATH TO LENS>"
@lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name="gpt2"):
    """Load a causal LM and its tokenizer once and reuse them.

    Loading pretrained weights is slow, so it is memoized per process rather
    than repeated on every Gradio request.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def plot_lens_outputs(text):
    """Build the lens plot for *text* using GPT-2.

    Args:
        text: Prompt entered in the Gradio text box.

    Returns:
        Whatever ``tuned_lens.plotting.plot_lens`` returns for the prompt
        (presumably a figure the Interface's Plot output can render; confirm
        against the installed tuned_lens version).

    Raises:
        gr.Error: If *text* is empty or whitespace-only, so the user sees a
            readable message instead of a model-side failure.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyse.")

    model, tokenizer = _load_model_and_tokenizer()

    # TODO: load a lens with TunedLens.load(LENS_PATH) and pass it to
    # plot_lens once LENS_PATH points at a real lens; until then no lens
    # argument is passed.

    # Return the plot value itself. The output component is declared on the
    # Interface; wrapping the value in a component here (as before) is wrong.
    return plot_lens(model, tokenizer, text=text)
# One text box in, one rendered lens plot out. gr.Plot replaces the
# deprecated gr.outputs.Plot namespace, which Gradio 4 removed.
iface = gr.Interface(fn=plot_lens_outputs, inputs="text", outputs=gr.Plot())
iface.launch()