# Bloom_test / app.py
# Hugging Face Space by andrew3279 -- commit 086b4de: "Update app.py"
import transformers
from transformers import BloomForCausalLM
from transformers import BloomTokenizerFast
import torch
import gradio as gr
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model_name = "bigscience/bloom-1b1"
# from_pretrained() loads the weights on the CPU. The generation function
# sends its input tensors to `device`, so the model must live there too;
# otherwise generate() fails with a device-mismatch error on CUDA hosts.
model = BloomForCausalLM.from_pretrained(model_name).to(device)
tokenizer = BloomTokenizerFast.from_pretrained(model_name)
# Text-generation function used as the Gradio click handler
def beam_gradio_pipeline(prompt, length=100):
    """Generate a continuation of ``prompt`` using 2-beam search.

    Args:
        prompt: Text to continue. An empty or missing prompt yields an
            empty string without calling the model.
        length: Value forwarded to ``generate`` as ``max_length``. It is
            coerced to ``int`` so numeric UI widgets that deliver floats
            are accepted.

    Returns:
        The decoded text of the first returned sequence, with special
        tokens (such as the end-of-sequence marker) removed.
    """
    # Nothing to continue: tokenizing "" gives a zero-length input tensor,
    # which generate() cannot extend.
    if not prompt:
        return ""

    result_length = int(length)

    # Place the inputs on whichever device the model actually lives on, so
    # the two can never disagree.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Unpacking the encoding forwards the attention mask together with the
    # input ids instead of leaving generate() to guess it.
    output_ids = model.generate(
        **inputs,
        max_length=result_length,
        num_beams=2,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Drop special tokens so markers like "</s>" do not leak into the UI.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Static HTML fragments shown at the top of the page.
_TITLE_HTML = "<h1><center>Andrew Lim Bloom Test </center></h1>"
_INTRO_HTML = """<h2><center>Generate your story with a sentence or ask a question:<br><br>
<img src=https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1634806038075-5df7e9e5da6d0311fd3d53f9.png?w=200&h=200&f=face width=200px></center></h2>"""
_DIVIDER_HTML = """<center>******</center>"""

# Page layout: heading, intro with image, divider, prompt box, submit
# button and output box. Clicking the button feeds the prompt text to
# beam_gradio_pipeline and writes its return value into the output box.
with gr.Blocks() as web:
    gr.Markdown(value=_TITLE_HTML)
    gr.Markdown(value=_INTRO_HTML)
    gr.Markdown(value=_DIVIDER_HTML)

    input_text = gr.Textbox(lines=6, label="Prompt")
    buton = gr.Button(value="Submit ")
    output_text = gr.Textbox(label="The story start with :", lines=6)

    buton.click(
        fn=beam_gradio_pipeline,
        inputs=[input_text],
        outputs=output_text,
    )

# Start the Gradio server for the assembled page.
web.launch()