File size: 1,435 Bytes
490823b d38df3e 490823b d38df3e 490823b d38df3e 490823b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import os
from pathlib import Path
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import pipeline
# Hugging Face Hub repository of the instruction-tuned model this app serves.
model_name = "databricks/dolly-v2-12b"
# Local cache directory for the model files, so restarts reuse the download.
local_dir = f"./models/{model_name}"

_model_path = Path(local_dir)
# Download only when the cache directory is missing or empty.
# NOTE: an interrupted earlier download leaves a non-empty directory and is
# NOT retried by this check; delete the directory to force a fresh download.
if not _model_path.is_dir() or not any(_model_path.iterdir()):
    snapshot_download(model_name, local_dir=local_dir)

# Build the text-generation pipeline from the local copy of the model.
# - torch_dtype=torch.bfloat16: load weights in bfloat16 (half of float32's memory).
# - trust_remote_code=True: allows transformers to execute Python code shipped
#   inside the model repository; presumably required because the Dolly repo
#   provides its own pipeline implementation. Only use with trusted repos.
# - device_map="auto": let transformers/accelerate place the weights on the
#   available devices.
generate_text = pipeline(
    model=local_dir,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
# Look-and-feel for the chat UI: Gradio's Monochrome preset with indigo/blue
# accents, slate neutrals, small corner radii and an "Open Sans" font that
# falls back to generic system sans-serif fonts.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
def _reply_text(output):
    """Extract the generated string from a text-generation pipeline result.

    The pipeline returns a list of dicts such as
    ``[{"generated_text": "..."}]``; a bare dict or plain string is also
    accepted so the chat always receives text rather than a Python container.
    """
    if isinstance(output, list):
        output = output[0] if output else ""
    if isinstance(output, dict):
        output = output.get("generated_text", "")
    return "" if output is None else str(output)


def user(user_message, history):
    """Append the user's message to the chat as a pending (unanswered) turn.

    Args:
        user_message: Text typed into the textbox.
        history: Chat history as a list of ``[user_text, bot_text]`` pairs,
            or ``None``/empty when there is no conversation yet.

    Returns:
        ``("", new_history)``. The empty string clears the textbox. Blank
        messages are not added to the history.
    """
    history = list(history or [])
    if not user_message or not user_message.strip():
        return "", history
    return "", history + [[user_message, None]]


def bot(history):
    """Generate the model's reply for the most recent pending turn.

    Args:
        history: Chat history as a list of ``[user_text, bot_text]`` pairs.

    Returns:
        The same history with the last turn's reply filled in. The history is
        returned unchanged when it is empty or its last turn is already
        answered (e.g. after a blank submission).
    """
    if not history or history[-1][1] is not None:
        return history
    # Uses the module-level `generate_text` pipeline; this call is the slow step.
    history[-1][1] = _reply_text(generate_text(history[-1][0]))
    return history


with gr.Blocks(theme=theme) as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # Two-step event chain: `user` echoes the message into the chat right away
    # (queue=False: this cheap step bypasses Gradio's queue), then `bot` runs
    # the slow model call and fills in the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # Reset the conversation by setting the chatbot's value to None.
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # Route events through Gradio's queue. Generation with a 12B model is slow,
    # and queued requests are processed in turn rather than hitting HTTP
    # request timeouts or running concurrently on the same GPU.
    demo.queue()
    # Listen on all network interfaces (0.0.0.0) on port 3000, e.g. so the app
    # is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=3000)
# def greet(name):
# return "Hello " + name + "!!"
# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# iface.launch() |