# Copied from the Hugging Face file view (page residue kept as a comment so the
# script parses): author not-lain · commit c8fdb3b ("skip original prompt") ·
# 527 bytes.
import gradio as gr
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
# Checkpoint loaded for both the tokenizer and the model; kept in one place so
# the two can never drift apart.
MODEL_ID = "google/gemma-7b"
# Access token for the Hub (the repo presumably requires one — confirm). If
# HF_TOKEN is unset, token=None lets from_pretrained fall back to a token saved
# by `huggingface-cli login` instead of failing here with a bare KeyError.
token = os.environ.get("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=token)
# Console streamer: writes generated text to stdout, omitting the echoed prompt.
# Kept for compatibility with code that references `streamer`; it does not
# deliver text to the Gradio UI.
streamer = TextStreamer(tokenizer, skip_prompt=True)
def generate(inputs, history=None):
    """Stream the model's continuation of a single chat message.

    gr.ChatInterface calls its function as ``fn(message, history)`` and, when
    the function is a generator, re-renders the reply with each yielded string.

    Args:
        inputs: The user's latest chat message (plain text).
        history: Previous chat turns passed by gr.ChatInterface. Currently
            unused — only the latest message is sent to the model.

    Yields:
        The reply text generated so far (cumulative, not per-chunk deltas).
    """
    from threading import Thread

    from transformers import TextIteratorStreamer

    encoded = tokenizer([inputs], return_tensors="pt").to(model.device)
    # A fresh iterator streamer per call: the module-level TextStreamer only
    # prints to stdout, so its text never reaches the chat UI. skip_prompt keeps
    # the user's message out of the reply; skip_special_tokens drops markers
    # such as the end-of-sequence token from the decoded text.
    text_streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # model.generate blocks until generation finishes, so run it in a worker
    # thread and consume decoded chunks here as they are produced.
    worker = Thread(
        target=model.generate,
        kwargs={**encoded, "streamer": text_streamer},
    )
    worker.start()
    reply = ""
    for chunk in text_streamer:
        reply += chunk
        yield reply
    worker.join()
# Wire the generator into Gradio's chat UI and start the server; debug=True
# keeps the process in the foreground and surfaces errors in the console.
app = gr.ChatInterface(fn=generate)
app.launch(debug=True)