Tonic commited on
Commit
4e567e9
·
verified ·
1 Parent(s): 947f881

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ # Load model and tokenizer
6
+ model_path = "ibm-granite/granite-3.0-1b-a400m-instruct"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
+ model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
9
+ model.eval()
10
+
11
+ def generate_response(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
12
+ chat = [
13
+ {"role": "user", "content": prompt},
14
+ ]
15
+ chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
16
+
17
+ input_tokens = tokenizer(chat, return_tensors="pt").to(model.device)
18
+
19
+ output = model.generate(
20
+ **input_tokens,
21
+ max_new_tokens=max_new_tokens,
22
+ temperature=temperature,
23
+ top_p=top_p,
24
+ repetition_penalty=repetition_penalty,
25
+ do_sample=True
26
+ )
27
+
28
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
29
+ return response.split("Human:", 1)[0].strip()
30
+
31
+ with gr.Blocks() as demo:
32
+ gr.Markdown("# 🙋🏻‍♂️Welcome to 🌟Tonic's🪨Granite-3.0-1B-A400M-Instruct Demo")
33
+ gr.Markdown("Enter a prompt and adjust generation parameters to interact with the 🪨Granite-3.0-1B-A400M-Instruct model.")
34
+
35
+ with gr.Row():
36
+ with gr.Column():
37
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
38
+ generate_button = gr.Button("Generate Response")
39
+ max_new_tokens = gr.Slider(minimum=1, maximum=500, value=100, step=1, label="Max New Tokens")
40
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
41
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
42
+ repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
43
+
44
+ with gr.Column() :
45
+ output = gr.Textbox(label="🪨Granite3-1B", lines=10)
46
+
47
+ generate_button.click(
48
+ generate_response,
49
+ inputs=[prompt, max_new_tokens, temperature, top_p, repetition_penalty],
50
+ outputs=output
51
+ )
52
+
53
+ gr.Markdown("## Examples")
54
+ examples = gr.Examples(
55
+ examples=[
56
+ ["Tell me about the history of artificial intelligence.", 200, 0.7, 0.9, 1.1],
57
+ ["Write a short story about a robot learning to paint.", 300, 0.8, 0.95, 1.2],
58
+ ["Explain the concept of quantum computing to a 10-year-old.", 150, 0.6, 0.85, 1.0],
59
+ ],
60
+ inputs=[prompt, max_new_tokens, temperature, top_p, repetition_penalty],
61
+ )
62
+
63
+ demo.launch()