"""Gradio chat demo backed by a locally cached Dolly v2 text-generation pipeline.

On import the script downloads the model snapshot into ``./models/<model_name>``
if that directory is missing or empty, loads it as a ``transformers`` pipeline,
and builds a ``gr.Blocks`` chat UI. Running the file as a script serves the UI
on ``0.0.0.0:3000``.
"""

import os
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import pipeline

model_name = "databricks/dolly-v2-12b"
local_dir = f"./models/{model_name}"

# Fetch the snapshot only when there is no populated local copy yet.
if not Path(local_dir).exists() or not os.listdir(local_dir):
    snapshot_download(model_name, local_dir=local_dir)

generate_text = pipeline(
    model=local_dir,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

with gr.Blocks(theme=theme) as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the submitted message as a new, unanswered chat turn.

        Returns an empty string (to clear the textbox) and a new history list
        whose last entry is ``[user_message, None]``.
        """
        return "", history + [[user_message, None]]

    def bot(history):
        """Fill in the reply for the last turn of ``history`` and return it.

        The last turn's user message is sent to ``generate_text``; the reply
        text is written into that turn in place.
        """
        prompt = history[-1][0]
        result = generate_text(prompt)
        # The Dolly pipeline returns a list of records shaped like
        # [{"generated_text": "..."}] (see the model card usage example);
        # the Chatbot needs the reply string, not the raw record list.
        history[-1][1] = result[0]["generated_text"]
        return history

    # First echo the user's message immediately (unqueued), then generate the
    # model's reply as a follow-up step.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # Returning None resets the chatbot component.
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=3000)